← cgad.ski 2023-10-26

When Gradient Descent Is a Kernel Method

Suppose that we sample a large number $N$ of independent random functions $f_i \colon \mathbb{R} \to \mathbb{R}$ from a certain distribution $\mathcal{F}$ and propose to solve a regression problem by choosing a linear combination
$$\bar{f} = \sum_i \lambda_i f_i.$$
For large $N$, adjusting the coefficients $\lambda_i$ to fit some fixed constraints of the form $\bar{f}(t_i) = y_i$ amounts to solving a highly underdetermined linear system, meaning that a high-dimensional space $\Lambda$ of vectors $(\lambda_1, \dots, \lambda_N)$ fits our constraints perfectly. So, choosing one element of $\Lambda$ requires some additional decision-making. To use the picturesque idea of a "loss landscape" over parameter space, our problem will have a ridge of equally performing parameters rather than just a single optimal peak.

Now, we make a very strange proposal. What if we simply initialize $\lambda_i = 1/\sqrt{N}$ for all $i$ and proceed by minimizing some loss function using gradient descent? If all goes well, we should end up with an element of $\Lambda$. Of course, many elements of $\Lambda$ give very bad models. To see this, it's enough to remember that we can expect a linear combination of $N$ random functions to fit any $N$ data points, so if we have $m$ data points, there exist models in $\Lambda$ that perfectly interpolate any adversarial selection of $N - m$ additional data points! Does gradient descent tend to make a "good" choice?

Let's test this empirically. In the widget below, I've chosen functions $f_i$ by sampling $200$ trajectories of a Wiener process, also known as Brownian noise. Click anywhere to introduce data points and click the play button in the top right to run gradient descent for the squared loss $\sum_i (\bar{f}(t_i) - y_i)^2$.
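The experiment can also be reproduced offline. What follows is a minimal NumPy sketch, not the widget's own code: the grid on $[0, 1]$, the data values, the step size, and the hypothetical helpers `sample_wiener` and `fit_by_gradient_descent` are illustrative assumptions. It minimizes half of the squared loss, which only rescales the learning rate.

```python
import numpy as np

def sample_wiener(n_paths, n_steps, rng, t_max=1.0):
    """Sample Wiener process paths on a uniform grid over [0, t_max]; row i is f_i."""
    dt = t_max / n_steps
    increments = rng.normal(scale=np.sqrt(dt), size=(n_paths, n_steps))
    paths = np.concatenate([np.zeros((n_paths, 1)), np.cumsum(increments, axis=1)], axis=1)
    return np.linspace(0.0, t_max, n_steps + 1), paths

def fit_by_gradient_descent(paths, idx, y, lam0, n_iters=5000):
    """Gradient descent on 0.5 * sum_j (fbar(t_j) - y_j)^2 over the coefficients lambda."""
    phi = paths[:, idx].T                              # phi[j, i] = f_i(t_j), shape (m, N)
    lr = 1.0 / np.linalg.eigvalsh(phi @ phi.T).max()   # safely below the stability limit
    lam = lam0.copy()
    for _ in range(n_iters):
        residual = phi @ lam - y                       # fbar(t_j) - y_j
        lam -= lr * (phi.T @ residual)                 # gradient step in parameter space
    return lam

rng = np.random.default_rng(0)
N = 200
ts, paths = sample_wiener(N, 1000, rng)
idx = np.array([200, 400, 600, 800])                   # data inputs t_j = 0.2, 0.4, 0.6, 0.8
y = np.array([0.5, -0.3, 0.8, 0.1])
lam = fit_by_gradient_descent(paths, idx, y, np.full(N, 1.0 / np.sqrt(N)))  # lambda_i = 1/sqrt(N)
fbar = lam @ paths                                     # fitted function on the whole grid
print(np.max(np.abs(fbar[idx] - y)))                   # essentially zero: the data are interpolated
```

Changing the seed gives a different fit that still passes through every data point.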

Interestingly, the functions we obtain are not that bad. They seem to concentrate around piecewise linear interpolations of our data. In fact, in the limit $N \to \infty$ of many random functions, it turns out that running gradient descent to convergence has a meaningful statistical interpretation. Specifically, if we view the Wiener process $\mathcal{F}$ as a prior, then running gradient descent to convergence samples from the posterior for our data points. Since many optimal solutions to our minimization problem are meaningless, it is not possible to explain this fact if we see gradient descent as "just some optimization method." What explains its relative success?

As we will show in this post, our intriguing Bayesian interpretation can be explained by the relationship between the behavior of gradient descent steps, the statistical properties of our random functions $f_i$, and our initialization. In particular, it does not depend on the loss function (so long as it leads gradient descent to converge to an exact interpolation), but does depend significantly on our choice of parameters at initialization. Our analysis will rely on a "tangent kernel" of the sort introduced in the Neural Tangent Kernel paper by Jacot et al. Specifically, viewing gradient descent as a process occurring in the function space of our regression problem, we will find that its dynamics can be described in terms of a certain kernel function, which in this case is just the kernel function of the process $\mathcal{F}$.

Of course, there are much easier ways to sample posteriors for low-dimensional Gaussian processes. Nevertheless, it's interesting to notice a relationship between Bayesian inference and gradient descent methods at large, since the latter tend to apply to situations where direct estimation of a posterior distribution is not practical. Furthermore, the results of Jacot suggest that kernel-based interpretations may also hold for large, non-trivial neural networks. To the extent that this is true, we can use the kernel interpretation to reason about many sorts of neural network phenomena, including the benefits of early stopping, the existence of "implicit regularization", and the fact that overparameterization often increases performance despite the apparent risk of overfitting.

In this post, we will focus exclusively on the toy problem introduced above. Our discussion (which admittedly takes a bit of a scenic route) is divided into three sections.

Kernel Functions

Let's begin by considering the effect that a single step of gradient descent has on our function $\bar{f}$. In general, the differential of a loss that depends on finitely many evaluations of $\bar{f}$ can be written as a linear combination of differentials $d\pi_t$, where $\pi_t$ is the evaluation of $\bar{f}$ at an input $t$, so by linearity it is enough for us to understand how $\bar{f}$ "responds" to differentials of this form.

In response to $d\pi_t$, the parameters $\lambda_i$ are assigned differentials
$$\frac{\partial \pi_t}{\partial \lambda_i} = f_i(t).$$
So, gradient descent will increase $\lambda_i$ proportionally to the value of $f_i(t)$. In terms of $\bar{f}$, we find that the value $\bar{f}(s)$ at another input $s$ will increase proportionally to
$$\Delta_t(s) = \sum_{i=1}^N f_i(s) f_i(t). \tag{1}$$
Note that this expression is independent of the coefficients $\lambda_i$. This means that the gradient descent step we apply to $\bar{f}$ depends only on our learning rate and the differential of the loss at $\bar{f}$. In other words, we can view gradient descent as a process happening in the function space of our regression problem.
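The claim that a single step moves $\bar{f}$ by a multiple of $\Delta_t$, whatever the current coefficients are, is easy to check numerically. A minimal sketch, assuming a discretized Wiener process, a loss whose differential is $g \, d\pi_t$, and a hypothetical helper `step_response`:

```python
import numpy as np

rng = np.random.default_rng(1)
N, n_steps = 500, 100
increments = rng.normal(scale=np.sqrt(1.0 / n_steps), size=(N, n_steps))
paths = np.concatenate([np.zeros((N, 1)), np.cumsum(increments, axis=1)], axis=1)  # row i is f_i

def step_response(paths, lam, t_idx, g, lr):
    """Change in fbar after one gradient step on a loss with differential g * d(pi_t)."""
    grad = g * paths[:, t_idx]           # d(pi_t) / d(lambda_i) = f_i(t)
    new_lam = lam - lr * grad
    return new_lam @ paths - lam @ paths

t_idx, g, lr = 40, 1.0, 1e-3             # request at t = 0.4
delta_t = paths.T @ paths[:, t_idx]      # Delta_t(s) = sum_i f_i(s) f_i(t), equation (1)
for lam in (np.zeros(N), rng.normal(size=N)):   # two unrelated coefficient vectors
    print(np.allclose(step_response(paths, lam, t_idx, g, lr), -lr * g * delta_t))
```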

For large $N$ we get the approximation
$$\frac{1}{N} \Delta_t(s) \approx E_{f \sim \mathcal{F}}[f(s) f(t)]. \tag{2}$$
This last expression is familiar in the study of Gaussian processes; for a centered process like ours, it is called the covariance kernel of $\mathcal{F}$ and denoted $K(s, t)$. For the Wiener process, the covariance kernel takes the form $K(s, t) = \min(s, t)$. In the following, we'll assume $N$ is large enough for us to make the approximation $(2)$ confidently. Then, we conclude that a request of $d\pi_t$ will cause gradient descent to push $\bar{f}$ in the direction of the function $K(-, t)$.
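A quick Monte Carlo check of approximation $(2)$ for the Wiener process, as a sketch under assumed choices of grid, $N$, and $t$:

```python
import numpy as np

rng = np.random.default_rng(2)
N, n_steps = 20_000, 100
ts = np.linspace(0.0, 1.0, n_steps + 1)
increments = rng.normal(scale=np.sqrt(1.0 / n_steps), size=(N, n_steps))
paths = np.concatenate([np.zeros((N, 1)), np.cumsum(increments, axis=1)], axis=1)

t_idx = 50                                    # t = 0.5
response = paths.T @ paths[:, t_idx] / N      # Delta_t(s) / N at every grid point s
kernel = np.minimum(ts, ts[t_idx])            # K(s, t) = min(s, t) for the Wiener process
print(np.max(np.abs(response - kernel)))      # small; shrinks roughly like 1/sqrt(N)
```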

You can get a visual sense of this behavior in the widget below. As above, I've generated $200$ trials of the Wiener process to use as my functions $f_i$. You can choose the request $d\pi_t$ by clicking on the graph, and your browser will compute the corresponding response $\Delta_t/N$. For comparison I've also drawn the prediction $K(-, t)$.

This is already a significant conclusion. In particular, it means that every step of gradient descent modifies $\bar{f}$ by a linear combination of the functions $K(-, t_i)$, where $t_i$ ranges over the inputs in our training set. Since the linear span of these functions is the space of continuous piecewise affine functions that vanish at $0$, bend only at the inputs $t_i$, and are constant after the largest $t_i$, if we initialize $\lambda_i = 0$ and run gradient descent to convergence with any reasonable loss function, we should approximately converge to a piecewise affine interpolation of our data points. We've run this experiment in the widget below.
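In the large-$N$ limit we can skip the random functions altogether and run the function-space update directly, adding a multiple of $K(-, t_j)$ for each data point at every step. A minimal sketch, assuming the half squared loss, a grid on $[0, 1]$, illustrative data, and a step size taken from the largest eigenvalue of the Gram matrix:

```python
import numpy as np

ts = np.linspace(0.0, 1.0, 101)
idx = np.array([20, 40, 60, 80])                  # data inputs t_j on the grid
y = np.array([0.5, -0.3, 0.8, 0.1])
K_cols = np.minimum.outer(ts, ts[idx])            # K_cols[s, j] = K(s, t_j) = min(s, t_j)
lr = 1.0 / np.linalg.eigvalsh(K_cols[idx]).max()  # K_cols[idx] is the Gram matrix [K(t_i, t_j)]

fbar = np.zeros_like(ts)                          # lambda_i = 0, so fbar starts at 0
for _ in range(3000):
    residual = y - fbar[idx]                      # y_j - fbar(t_j)
    fbar += lr * (K_cols @ residual)              # add a combination of the functions K(-, t_j)

# Piecewise affine interpolation through (0, 0) and the data, constant after the last point.
target = np.interp(ts, np.concatenate([[0.0], ts[idx]]), np.concatenate([[0.0], y]))
print(np.max(np.abs(fbar - target)))              # essentially zero
```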

This new model has less variance than before. In fact, in the large $N$ limit, its behavior will be exactly deterministic! In comparison, our previous model will always exhibit variance due to initialization of the functions $f_i$ even for large $N$. In other words, when given a finite training set, gradient descent cannot entirely "forget" its initialization even when run to convergence.

Another important conclusion is that, when we optimize least squares with gradient descent, the evolution of $\bar{f}$ is linear in the sense of approximately obeying a linear ODE. Indeed, for data points $(t_i, y_i)$ our loss differential will be
$$\frac{1}{2} \, d \sum_i (\bar{f}(t_i) - y_i)^2 = \sum_i (\bar{f}(t_i) - y_i) \, d\pi_{t_i}.$$
So, if we view $\bar{f}$ as evolving under the flow of gradient descent with respect to a continuous parameter $\tau$, we have
$$\frac{d\bar{f}}{d\tau} = \sum_i K(-, t_i) \, (y_i - \bar{f}(t_i)),$$
where the right-hand side is a linear function of the empirical error vector $h = (y_i - \bar{f}(t_i))_i$. Restricting this equation to the evaluations of $\bar{f}$ at the input points $t_i$, we find that the error vector $h$ solves the ODE
$$\frac{dh}{d\tau} = -K h,$$
where $K$ is the matrix $[K(t_i, t_j)]$. Since this matrix is positive definite (it is the covariance matrix of $(f(t_1), \dots, f(t_m))$, which is nondegenerate for distinct inputs $t_i > 0$), we conclude that $h$ will converge to $0$ over the training process if our learning rate is sufficiently small. Furthermore, knowing the eigenvalues of $K$ lets us understand the nature of our convergence: gradient descent will "correct" the error $h$ along the eigenvectors of $K$ with the largest eigenvalues first, and take longer to correct components with smaller eigenvalues.
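The error dynamics can be checked directly. A minimal sketch, assuming the Wiener kernel, four illustrative inputs, and $h(0) = y$ (which holds when $\bar{f}$ starts at $0$): it compares a forward-Euler integration of $dh/d\tau = -Kh$ with the closed form $e^{-K\tau} h(0)$ computed in the eigenbasis of $K$, and reports how much of each eigencomponent remains.

```python
import numpy as np

t = np.array([0.2, 0.4, 0.6, 0.8])
K = np.minimum.outer(t, t)                    # the matrix [K(t_i, t_j)]
mu, Q = np.linalg.eigh(K)                     # eigenvalues in ascending order, orthonormal eigenvectors
h0 = np.array([0.5, -0.3, 0.8, 0.1])          # initial error y - fbar(t_i) with fbar = 0

tau, n_euler = 5.0, 50_000
d_tau = tau / n_euler
h = h0.copy()
for _ in range(n_euler):
    h = h - d_tau * (K @ h)                   # forward Euler step for dh/dtau = -K h

h_exact = Q @ (np.exp(-mu * tau) * (Q.T @ h0))   # exp(-K tau) h0
remaining = np.abs(Q.T @ h) / np.abs(Q.T @ h0)   # fraction of each eigencomponent left
print(np.allclose(h, h_exact, atol=1e-6))
print(mu, remaining)                          # larger eigenvalue -> smaller fraction remaining
```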

Now, we could have chosen a distribution $\mathcal{F}$ of random functions with a different covariance kernel $K$. Here the functions $K(-, t)$ were easy to interpret, but in general, what does it mean to fit a data set using a linear combination of functions like these? One interesting perspective comes from the idea of regularization, which we discuss next.

Regularization

Consider a Hilbert space $H$ equipped with bounded projections $\pi_t \colon H \to \mathbb{R}$ for each $t \in \mathbb{R}$, and suppose that we want to find an element $v \in H$ that minimizes some loss function depending on the values $\pi_t(v)$ for $t$ belonging to some collection $\{t_i\}$. (Note that elements of $H$ can be viewed as functions from $\mathbb{R}$ to $\mathbb{R}$ by viewing $\pi_t$ as the evaluation at $t$.) If this problem is underdetermined, which necessarily will happen if $H$ is infinite-dimensional and our collection $\{t_i\}$ is finite, then we may ask for an element $v$ that both minimizes our loss and has minimal norm in $H$. In machine learning, this is called regularization.

Let's write $\Pi$ for the product of the projections $\pi_{t_i}$. Because $v$ is chosen with minimal norm, it cannot be made smaller by adjusting it by an element of $\ker \Pi$, so $v$ is orthogonal to $\ker \Pi$. But since the maps $\pi_t$ are continuous, they can be represented by vectors $K_t$ in the sense that $\pi_t(-) = \langle K_t, - \rangle$. (This is the Riesz representation theorem.) Since $\ker \Pi$ can be described as the orthogonal complement to the set $\{K_{t_i}\}$, the orthogonal complement to $\ker \Pi$ is exactly the closure of the span of the vectors $K_{t_i}$. We conclude that any regularized solution to our loss function is a linear combination (or a limit of linear combinations) of these vectors.
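In finite dimensions this argument can be checked with a least-squares solver. A sketch, assuming $H = \mathbb{R}^N$ with the standard inner product and randomly chosen functionals $\pi_{t_j}$, whose representers $K_{t_j}$ are the rows of a matrix `Phi`; `np.linalg.lstsq` returns the minimum-norm solution of an underdetermined system.

```python
import numpy as np

rng = np.random.default_rng(3)
m, N = 3, 10
Phi = rng.normal(size=(m, N))      # row j is the representer K_{t_j}: pi_{t_j}(v) = <K_{t_j}, v>
y = rng.normal(size=m)             # required values pi_{t_j}(v) = y_j

v_min, *_ = np.linalg.lstsq(Phi, y, rcond=None)   # minimum-norm solution of Phi v = y
coeffs = np.linalg.solve(Phi @ Phi.T, y)          # Gram system [<K_{t_i}, K_{t_j}>] c = y
print(np.allclose(v_min, Phi.T @ coeffs))         # v_min = sum_j c_j K_{t_j}

# Adding any nonzero element of ker(Pi) keeps the constraints but increases the norm.
null_dir = rng.normal(size=N)
null_dir -= Phi.T @ np.linalg.solve(Phi @ Phi.T, Phi @ null_dir)  # project onto ker(Pi)
other = v_min + null_dir
print(np.allclose(Phi @ other, y), np.linalg.norm(other) > np.linalg.norm(v_min))
```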

What are the projections of the "representative elements" $K_t$? By our own definition, we have $\pi_s(K_t) = \langle K_s, K_t \rangle$ for any other $s \in \mathbb{R}$. This last expression is a positive semidefinite kernel, which we will denote $K(s, t)$. In other words, the norm on $H$ and the projections $\pi_t$ work together to produce a kernel function $K$ whose partial evaluations $K(-, t)$ help us solve optimization problems regularized by the norm of $H$.

In the literature, a Hilbert space equipped with bounded projections indexed over a set $I$ is called a reproducing kernel Hilbert space (or RKHS). In fact, we can also go in the other direction: every positive definite kernel on $I$ is "reproduced" by some RKHS, which also turns out to be unique in a certain sense. This is known as the Moore-Aronszajn theorem.

What RKHS corresponds to our kernel $K(s, t) = \min(s, t)$? In general, determining the RKHS of a kernel is not entirely straightforward. In fact, notice that for positive definite kernels over a finite set $I$, the inner product of the RKHS expressed in the dual basis for our projections turns out to be the inverse of the matrix encoded by our kernel. Indeed, where $K_i$ are representatives of the projections $\pi_i$ and $e_i$ is a dual basis verifying $\pi_i(e_j) = \delta_{i,j}$, each $K_k$ expands as $K_k = \sum_j \pi_j(K_k) \, e_j$, so we find that
$$\sum_j \langle e_i, e_j \rangle K(j, k) = \sum_j \langle e_i, e_j \rangle \, \pi_j(K_k) = \langle e_i, K_k \rangle = \pi_k(e_i) = \delta_{i,k}.$$
So, interpreting an RKHS norm in terms of projections of elements requires solving some sort of inverse problem.
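The finite case is easy to verify numerically. A sketch, assuming four illustrative points and representing each element of the RKHS by its vector of values $(\pi_i(u))_i$, with the Gram matrix of the dual basis taken to be $K^{-1}$; the reproducing property then says that the $k$-th column of $K$ represents $\pi_k$. The helper `inner` is hypothetical.

```python
import numpy as np

t = np.array([0.2, 0.4, 0.6, 0.8])
K = np.minimum.outer(t, t)          # kernel on the finite index set I
G = np.linalg.inv(K)                # claimed Gram matrix <e_i, e_j> of the dual basis

def inner(u, w):
    """RKHS inner product of two functions on I, each given by its values (pi_i(u))_i."""
    return u @ G @ w

rng = np.random.default_rng(4)
u = rng.normal(size=len(t))
# Reproducing property: <K_k, u> = pi_k(u), where K_k is the k-th column of K.
print([bool(np.isclose(inner(K[:, k], u), u[k])) for k in range(len(t))])
print(np.allclose(G @ K, np.eye(len(t))))   # sum_j <e_i, e_j> K(j, k) = delta_{i,k}
```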

The RKHS for a centered Gaussian process $(X_t)$ can be viewed as an isometric embedding of the observables $X_t$ with respect to the $L^2$ norm for the process measure $\mathcal{F}$. Specifically, if we define $f(X_t) = K_t$, then clearly
$$\langle X_s, X_t \rangle_{L^2(\mathcal{F})} = \langle f(X_s), f(X_t) \rangle_{\text{RKHS}}.$$
Indeed, the space of observables of a Gaussian process is already an RKHS for its covariance kernel, if we take the projections to be the maps $\langle X_t, - \rangle$. However, we would like to view the RKHS more directly as a space of functions.

We may begin by observing, then, that observables of the Wiener process can be isometrically mapped into $L^2[0, \infty)$ by sending $X_s$ to the indicator function
$$K_s(t) = \begin{cases} 1 & t \le s, \\ 0 & \text{otherwise,} \end{cases}$$
since $\langle K_s, K_t \rangle_{L^2} = \min(s, t)$. Under this perspective, our projections become
$$\pi_t(f) = \langle K_t, f \rangle = \int_0^t f(s) \, ds.$$
Ultimately, we are led to view the RKHS of the Wiener process as the Sobolev space of absolutely continuous functions $f \colon [0, \infty) \to \mathbb{R}$ such that $f(0) = 0$ and such that the norm
$$\lVert f \rVert = \left( \int_0^\infty (f'(t))^2 \, dt \right)^{1/2}$$
is finite. In fact, solving the regularized interpolation problem
$$\begin{aligned} \textbf{minimize} & \quad \int_0^\infty (f'(s))^2 \, ds \\ \textbf{subject to} & \quad f(0) = 0, \; f(t_i) = y_i \text{ for all } i \end{aligned}$$
results in the piecewise affine interpolations we observed in the widget above.
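A numerical check of this picture, as a sketch with illustrative inputs: the squared RKHS norm $y^\top K^{-1} y$ of the minimal interpolant should equal $\int_0^\infty (f')^2 \, dt$ for the piecewise affine interpolant, and adding a hypothetical `bump` that vanishes at $0$ and at every $t_i$ should only increase this energy.

```python
import numpy as np

t = np.array([0.2, 0.5, 0.6, 0.9])
y = np.array([0.5, -0.3, 0.8, 0.1])
K = np.minimum.outer(t, t)
rkhs_sq_norm = y @ np.linalg.solve(K, y)              # y^T K^{-1} y

tt = np.concatenate([[0.0], t])                       # include the constraint f(0) = 0
yy = np.concatenate([[0.0], y])
dirichlet = np.sum(np.diff(yy) ** 2 / np.diff(tt))    # int (f')^2 for the piecewise affine interpolant

s = np.linspace(0.0, 0.9, 901)                        # fine grid containing every t_i
f = np.interp(s, tt, yy)
bump = 10 * s * (s - 0.2) * (s - 0.5) * (s - 0.6) * (s - 0.9)   # zero at 0 and at every t_i

def energy(g):
    """Discrete approximation of int (g')^2 ds from values g on the grid s."""
    return np.sum(np.diff(g) ** 2 / np.diff(s))

print(np.isclose(rkhs_sq_norm, dirichlet), np.isclose(energy(f), dirichlet))
print(energy(f + bump) > energy(f))                   # another interpolant costs more
```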

So far, we have shown that its relationship with kernel functions gives gradient descent a distinct flavor of implicit regularization. We did not have a penalty function in mind when we set up our problem, but our distribution of random functions ended up making our gradient updates interpretable in terms of an RKHS for an associated kernel function. In the last section of this post, we address how this fact is related to the statistical idea of a conditional distribution for a Gaussian process.

Bayesian Interpretation

When $X$ and $Y$ are jointly Gaussian distributed, we know that the remainder $Y - E(Y \mid X)$ of the conditional expectation is independent of $X$. So, we can decompose $Y$ into two components
$$Y = (Y - E(Y \mid X)) + E(Y \mid X),$$
the first being a Gaussian variable independent of $X$ and the second being $X$-measurable. This clarifies the nature of the conditional distribution of $Y$ on $X = x$: it will have constant variance equal to the variance of $Y - E(Y \mid X)$ and mean equal to $E(Y \mid X)$ evaluated at $X = x$, which for centered variables is a linear function of $x$. In particular, if we want to sample the conditional distribution of $Y$ given $X = x$, we could take
$$Y + E(Y \mid X = x) - E(Y \mid X) = Y + E(Y \mid X = x - X),$$
where the apparently nonsensical conditional expectation on the right-hand side should be interpreted as the evaluation of the conditional expectation $E(Y \mid X)$, viewed as a linear function of $X$, at $x - X$. Keep in mind that this is a very special property of Gaussian distributions; in general, the distribution of the remainder $Y - E(Y \mid X)$ conditional on $X$ will depend on $X$, and so we won't be able to sample the conditional distribution under another "counterfactual" value $X = x$ simply by translating a sample of the remainder.
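A Monte Carlo sketch of this translation trick, with an assumed bivariate covariance and an assumed observed value `x_obs`: the translated samples should have mean $E(Y \mid X = x)$, the variance of $Y - E(Y \mid X)$, and no covariance with $X$.

```python
import numpy as np

rng = np.random.default_rng(5)
var_x, var_y, cov_xy = 2.0, 1.0, 0.8
L = np.linalg.cholesky(np.array([[var_x, cov_xy], [cov_xy, var_y]]))
X, Y = L @ rng.normal(size=(2, 200_000))      # centered, jointly Gaussian samples

beta = cov_xy / var_x                          # E(Y | X) = beta * X for centered Gaussians
x_obs = 1.5
Y_cond = Y + beta * (x_obs - X)                # Y + E(Y | X = x - X)

print(Y_cond.mean(), beta * x_obs)                   # conditional mean
print(Y_cond.var(), var_y - cov_xy ** 2 / var_x)     # conditional variance
print(np.cov(Y_cond, X)[0, 1])                       # approximately 0
```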

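Now, consider a random function $f$ drawn from a Gaussian distribution $\mathcal{F}$ and let $\Pi$ give the values of our trajectory on a finite set of inputs. If we want to produce a sample from the distribution of $f$ conditional on some data $\Pi = \pi^*$, we can take
$$f + E(f \mid \Pi = \pi^* - \Pi).$$
But, as it turns out, the conditional expectation $E(f \mid \Pi = \pi^*)$ will be exactly the function in the RKHS of our process that solves the constraint $\Pi = \pi^*$ regularized by the RKHS norm! This explains the Bayesian interpretation of our toy model: with the initialization $\lambda_i = 1/\sqrt{N}$, the initial function $\bar{f}$ is itself a sample from $\mathcal{F}$, and (in the large-$N$ limit) gradient descent adds to it the linear combination of the functions $K(-, t_i)$ that makes it interpolate the data, which is precisely the correction term above.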
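A sketch of this recipe for the Wiener process, assuming a grid on $[0, 1]$ and illustrative data: each prior sample $f$ is translated by $k(s) K^{-1} (\pi^* - \Pi(f))$, where $k(s) = [K(s, t_j)]_j$ and $K = [K(t_i, t_j)]$. Every translated sample should pass through the data, and at $s = 0.3$ the samples should follow the Brownian bridge between $(0.2, 0.5)$ and $(0.4, -0.3)$, with mean $0.1$ and variance $0.1 \cdot 0.1 / 0.2 = 0.05$.

```python
import numpy as np

rng = np.random.default_rng(6)
n_samples, n_steps = 20_000, 100
ts = np.linspace(0.0, 1.0, n_steps + 1)
increments = rng.normal(scale=np.sqrt(1.0 / n_steps), size=(n_samples, n_steps))
prior = np.concatenate([np.zeros((n_samples, 1)), np.cumsum(increments, axis=1)], axis=1)

idx = np.array([20, 40, 60, 80])                    # data inputs t_j
y = np.array([0.5, -0.3, 0.8, 0.1])                 # observed values pi*
K = np.minimum.outer(ts[idx], ts[idx])              # covariance matrix of Pi = (f(t_j))_j
k = np.minimum.outer(ts, ts[idx])                   # k[s, j] = K(s, t_j)

# Translate each prior sample by E(f | Pi = pi* - Pi(f)) = k K^{-1} (pi* - Pi(f)).
coeffs = np.linalg.solve(K, (y - prior[:, idx]).T)  # shape (4, n_samples)
posterior = prior + (k @ coeffs).T

print(np.abs(posterior[:, idx] - y).max())          # every sample passes through the data
print(posterior[:, 30].mean(), posterior[:, 30].var())   # at s = 0.3: near 0.1 and 0.05
```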

One way to understand this is just to write out the expression for $E(f(t) \mid \Pi)$ at a given value of $t$. We know that this will be the linear function $\lambda \Pi$ of $\Pi$ uniquely determined by the equation $\operatorname{Cov}(f(t) - \lambda \Pi, \Pi) = 0$. Where $K = [K(t_i, t_j)]$ is the covariance matrix of $\Pi$ and $v = [\operatorname{Cov}(f(t), f(t_i))]$ gives the covariance of $f(t)$ with the components $f(t_i)$ of $\Pi$, this equation can be written as $v - \lambda K = 0$, and
$$E(f(t) \mid \Pi) = v K^{-1} \Pi.$$
We conclude that the function $t \mapsto E(f(t) \mid \Pi)$ is a linear combination of the functions $K(t, t_i)$ (the coordinates of the vector $v$) with constant coefficients, determined by the constraints that $E(f(t) \mid \Pi)$ should agree with $\Pi$ at the points $t = t_i$. But, as we saw above, this is the same as the solution to the problem of interpolating some constraints $f(t_i) = y_i$ regularized by the RKHS norm of our process.
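For the Wiener kernel, the formula $E(f(t) \mid \Pi = y) = v K^{-1} y$ can be evaluated on a grid and compared with the piecewise affine interpolation of the data pinned at $f(0) = 0$. A sketch with illustrative inputs:

```python
import numpy as np

t_data = np.array([0.2, 0.4, 0.6, 0.8])
y = np.array([0.5, -0.3, 0.8, 0.1])
K = np.minimum.outer(t_data, t_data)       # covariance matrix of Pi

s = np.linspace(0.0, 1.0, 101)
v = np.minimum.outer(s, t_data)            # row for each s: v = [Cov(f(s), f(t_j))] = [K(s, t_j)]
cond_mean = v @ np.linalg.solve(K, y)      # E(f(s) | Pi = y) = v K^{-1} y

# The minimal-norm interpolant: piecewise affine through (0, 0) and the data.
interp = np.interp(s, np.concatenate([[0.0], t_data]), np.concatenate([[0.0], y]))
print(np.max(np.abs(cond_mean - interp)))  # agrees up to rounding
```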

To see this connection more directly, remember that the mean of a Gaussian distribution coincides with its mode, the point of highest probability density under linear coordinates. So, for example, $E(f(t) \mid \Pi = \pi^*)$ is exactly the value of $f(t)$ that maximizes
$$\ln p(f(t), \pi^*) = C - \frac{1}{2} \begin{bmatrix} f(t) & \pi^*(1) & \cdots & \pi^*(m) \end{bmatrix} K^{-1} \begin{bmatrix} f(t) \\ \pi^*(1) \\ \vdots \\ \pi^*(m) \end{bmatrix},$$
where $K^{-1}$ is the inverse of the covariance matrix for $(f(t), \Pi)$. But from the previous section we know that $K^{-1}$ expresses the inner product of the RKHS derived from the covariance kernel of $(f(t), \Pi)$ in the dual basis for the projections. So in fact we are asking for the value of $(f(t), \Pi)$ that satisfies the constraint $\Pi = \pi^*$ and is regularized by the RKHS norm corresponding to a restriction of the covariance kernel of our process, which by the representer theorem will be a linear combination of restrictions of the functions $K(-, t_i)$.
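Numerically, the mode and the mean agree. A sketch for the Wiener kernel at an assumed point $t = 0.3$: setting the derivative of the quadratic form with respect to $x = f(t)$ to zero gives $x = -P_{0,1:} \pi^* / P_{00}$, where $P$ (called `prec` below) is the inverse covariance matrix of $(f(t), \Pi)$, the $K^{-1}$ of the text. This should coincide with $v K^{-1} \pi^*$.

```python
import numpy as np

t_data = np.array([0.2, 0.4, 0.6, 0.8])
pi_star = np.array([0.5, -0.3, 0.8, 0.1])
t = 0.3

points = np.concatenate([[t], t_data])
C = np.minimum.outer(points, points)     # covariance matrix of (f(t), Pi)
prec = np.linalg.inv(C)                  # precision matrix

# q(x) = [x, pi*] prec [x, pi*]^T has dq/dx = 2 * (prec[0, 0] * x + prec[0, 1:] @ pi*).
mode = -(prec[0, 1:] @ pi_star) / prec[0, 0]

mean = C[0, 1:] @ np.linalg.solve(C[1:, 1:], pi_star)   # v K^{-1} pi*
print(mode, mean)                        # both equal the interpolated value 0.1
```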

Abstractly, whenever we have a positive definite kernel $K \colon I \times I \to \mathbb{R}$ with RKHS $H$ and a finite subset $J \subseteq I$, we get a natural projection $P \colon H \to H_J$ onto the RKHS $H_J$ for the kernel restricted to $J \times J$, given by restriction:
$$P(v) = \sum_{j \in J} \pi_j(v) \, e_j,$$
where $e_j \in H_J$ is the dual basis verifying $\pi_i(e_j) = \delta_{i,j}$. Given a vector $v \in H$, what is the norm of $P(v)$ in $H_J$? Since $P$ is an isometry over the span of the elements $K_j$ for $j \in J$, we can view $\lVert P(v) \rVert_{H_J}$ as the minimum possible norm for an element $w \in H$ solving the equation $P(w) = P(v)$. In particular, solving a regularized problem over $H$ that depends on projections $\pi_j$ for $j \in J$ and then restricting the solution to $H_J$ is the same as restricting to $H_J$ and solving the regularized problem with respect to the norm on $H_J$.
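A finite-dimensional check, as a sketch: take $I$ to be ten assumed points, represent elements of $H_I$ by their values with squared norm $u^\top K_I^{-1} u$, and let $J$ pick out three of those points. The minimal-norm $w$ with $P(w) = P(v)$ should have exactly the $H_J$-norm of $P(v)$, and $v$ itself should be at least as long. The helper `sq_norm` is hypothetical.

```python
import numpy as np

I = np.linspace(0.1, 1.0, 10)                  # finite index set I
J = np.array([1, 4, 7])                        # positions of the subset J inside I
K_I = np.minimum.outer(I, I)
K_J = K_I[np.ix_(J, J)]                        # kernel restricted to J x J

def sq_norm(values, kernel):
    """Squared RKHS norm of a function on a finite set, given by its values: u^T K^{-1} u."""
    return values @ np.linalg.solve(kernel, values)

rng = np.random.default_rng(7)
v = rng.normal(size=len(I))                    # an arbitrary element of H_I (its values on I)
Pv = v[J]                                      # P(v): restriction to J

w = K_I[:, J] @ np.linalg.solve(K_J, Pv)       # minimal-norm w in H_I with P(w) = P(v)
print(np.allclose(w[J], Pv))
print(np.isclose(sq_norm(w, K_I), sq_norm(Pv, K_J)))   # ||w||_{H_I} = ||P(v)||_{H_J}
print(sq_norm(v, K_I) >= sq_norm(Pv, K_J))              # any other preimage is at least as long
```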

As a final remark, note that we can informally imagine the RKHS of a Gaussian process as specifying the "energy" of the process in a statistical mechanics sense. Although the RKHS norm is not defined on the same function space in which the process takes its values, we get the energy for the joint distribution of any finite projection $(f(t_1), \dots, f(t_m))$ as a function of $(y_1, \dots, y_m)$, namely $\frac{1}{2} \lVert f^* \rVert_H^2$, where $f^*$ solves
$$\begin{aligned} \textbf{minimize} & \quad \lVert f \rVert_H \\ \textbf{subject to} & \quad f(t_i) = y_i. \end{aligned}$$
This is the most satisfactory way that I've found to connect the interpretation of kernel functions in terms of regularization with their interpretation in terms of conditional expectations of a Gaussian process.
