High-dimensional limit of one-pass SGD on least squares

We give a description of the high-dimensional limit of one-pass, single-batch stochastic gradient descent (SGD) on a least squares problem. This limit is taken with non-vanishing step-size, and with the number of samples proportional to the problem dimensionality. The limit is described in terms of a stochastic differential equation in high dimensions, which is shown to approximate the state evolution of SGD. As a corollary, the statistical risk is shown to be approximated by the solution of a convolution-type Volterra equation, with errors vanishing as the dimensionality tends to infinity. The sense of convergence is the weakest that shows that the statistical risks of the two processes coincide. This work is distinguished from existing analyses by the type of high-dimensional limit given, as well as by the generality of the covariance structure of the samples.


Introduction
Stochastic optimization methods are the modern-day standard for many large-scale computational tasks, especially those that arise in machine learning. There is a long history of analyses of these algorithms, beginning with the seminal work of [RM51], which focused on long-time behavior in a fixed-dimensional space. However, modern applications of stochastic optimization have motivated a different regime of analysis, where the problem dimensionality grows proportionally with the run-time of the algorithm.
In this article, we derive the exact scaling behavior of stochastic gradient descent (SGD) on a least squares problem, in the one-pass setting (see below), as the dimension tends to infinity. We further draw a comparison to the recent works [Paq+22; Paq+21], in which the multi-pass version of this problem was considered.
Stochastic gradient descent for empirical risk minimization  Most versions of (minibatch) SGD can be formulated in the context of finite-sum problems:
\[
\min_{x \in \mathbb{R}^d} \; \frac{1}{n} \sum_{i=1}^{n} f_i(x). \tag{1.1}
\]
Empirical risk minimization fits in this context by supposing that there are n independent samples from some data distribution, and each f_i represents the loss of how the parameters x in some model fit the i-th datapoint. In this article we will exclusively consider the case of linear regression with ℓ²-regularizer. So we suppose that there are n iid samples ((a_i, b_i) : 1 ≤ i ≤ n) from some distribution D, with some assumptions to be specified. We arrange this data into a design matrix A, whose i-th row is given by a_i, and a label vector b. Finally, we specify the functions f_i in (1.1) by setting
\[
f_i(x) = \tfrac{1}{2}\,(a_i \cdot x - b_i)^2 + \tfrac{\delta}{2}\,\|x\|^2.
\]
The parameter δ ≥ 0 is fixed and is the strength of the regularizer, and throughout ‖·‖ will be the Euclidean norm.
Minibatch stochastic gradient descent in this context can be described by the recursion
\[
x_{k+1} = x_k - \gamma_{k+1}\,\nabla f_{i_k}(x_k) = x_k - \gamma_{k+1}\bigl( (a_{i_k} \cdot x_k - b_{i_k})\, a_{i_k} + \delta x_k \bigr),
\]
where {γ_k} are stepsize parameters, a_{i_k} = Aᵀe_{i_k} with e_i the i-th standard basis vector, and {i_k} is a sequence of data indices.
In this article we consider the one-pass case, in which i_k = k and the algorithm is terminated after n steps. In practice, the order of the data points might be shuffled once beforehand, but in the setting we have posed, with iid data, there is no point in including this additional randomization. There are other choices for how to pick i_k, and we highlight three of them, all of which are multi-pass variants.
In random (with-replacement) sample SGD, each i_k is chosen uniformly at random from {1, 2, . . ., n}. This is the setting considered in [Paq+22], and we shall refer to this flavor of SGD simply as multi-pass SGD in the bulk of the paper. But for context, we also mention single shuffle SGD, in which one takes i_k = k mod n, which differs from one-pass SGD only in that the algorithm performs the same operations every epoch. In random shuffle SGD, one modifies the above strategy by randomly permuting the data between each epoch. All of these strategies are extensively studied in the optimization literature: it is generally thought that the single shuffle and random shuffle strategies are faster than the random sample strategy [YSJ21] (see also [RR12; GOP21; SS20; AYS20]).
The one-pass case is the fundamental point of comparison for all of these methods, being both simpler phenomenologically and also representing an idealization of SGD, in which the run-time of the algorithm is the amount of data. Accordingly, running for longer (meaning increasing n) can only improve the statistical performance of the SGD estimator x_n; in this context the method is usually referred to as streaming SGD.
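As a concrete illustration, the one-pass recursion described above can be sketched in a few lines. Everything here (the parameter values, Gaussian data, identity covariance, and all function names) is a hypothetical instance chosen for demonstration, not the paper's setup:

```python
import numpy as np

rng = np.random.default_rng(0)

d, n = 500, 2000          # dimension and number of samples (n proportional to d)
gamma, delta = 0.4, 0.0   # step-size constant and ridge strength
eta = 0.5                 # label-noise standard deviation

x_star = rng.standard_normal(d) / np.sqrt(d)   # ground truth, norm O(1)
A = rng.standard_normal((n, d))                # rows a_i, covariance K = I here
b = A @ x_star + eta * rng.standard_normal(n)  # linear model with noise

def one_pass_sgd(A, b, gamma, delta, x0):
    """One-pass SGD: each sample is used exactly once, with step-size gamma/d."""
    n, d = A.shape
    x = x0.copy()
    risks = []
    for k in range(n):
        a_k, b_k = A[k], b[k]
        grad = (a_k @ x - b_k) * a_k + delta * x   # stochastic gradient of f_k
        x -= (gamma / d) * grad
        # monitor the population risk, which for K = I is
        # P(x) = 0.5 * ||x - x_star||^2 + 0.5 * eta^2
        risks.append(0.5 * np.sum((x - x_star) ** 2) + 0.5 * eta**2)
    return x, np.array(risks)

x_final, risks = one_pass_sgd(A, b, gamma, delta, np.zeros(d))
```

With these (illustrative) choices the risk curve decays over the single pass toward a noise-limited plateau.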
The performance of SGD is measured through the population risk P and sometimes through an ℓ²-regularized risk R, which are given by
\[
P(x) = \tfrac{1}{2}\,\mathbb{E}_{(a,b) \sim D}\,(a \cdot x - b)^2 \tag{1.3}
\]
and
\[
R(x) = P(x) + \tfrac{\delta}{2}\,\|x\|^2. \tag{1.4}
\]
This regularized risk appears naturally as the mean behavior of one-pass SGD, in that
\[
x_{k+1} = x_k - \gamma_{k+1}\,\nabla R(x_k) + \gamma_{k+1}\,\xi_{k+1}
\]
for martingale increments (ξ_k : 1 ≤ k ≤ n). Another natural statistical setting to consider is out-of-distribution regression, in which case we would measure the performance of SGD trained on D but tested, as in (1.3), after replacing D with another distribution D′.
We shall not pursue this case in detail, but we note that all of the above examples are quadratic functionals of the SGD state x.
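The closed forms of these quadratic functionals can be checked directly by Monte Carlo. The sketch below uses a hypothetical diagonal covariance and assumes the identities P(x) = ½(x − x⋆)ᵀK(x − x⋆) + ½η² and R(x) = P(x) + (δ/2)‖x‖², which follow from the linear model of the previous section:

```python
import numpy as np

rng = np.random.default_rng(1)
d, eta, delta = 50, 0.3, 0.1

# hypothetical problem instance for illustration
evals = np.linspace(0.1, 1.0, d)               # spectrum of the covariance K
sqrtK = np.diag(np.sqrt(evals))
x_star = rng.standard_normal(d) / np.sqrt(d)   # ground truth
x = rng.standard_normal(d) / np.sqrt(d)        # an arbitrary parameter value

# closed forms: P(x) = 0.5 (x - x*)^T K (x - x*) + 0.5 eta^2, R = P + 0.5 delta ||x||^2
v = x - x_star
P_closed = 0.5 * v @ (evals * v) + 0.5 * eta**2
R_closed = P_closed + 0.5 * delta * (x @ x)

# Monte Carlo estimate of P(x) = 0.5 E (a.x - b)^2
N = 400_000
a = rng.standard_normal((N, d)) @ sqrtK        # a ~ N(0, K)
b = a @ x_star + eta * rng.standard_normal(N)
P_mc = 0.5 * np.mean((a @ x - b) ** 2)
```

The Monte Carlo estimate `P_mc` agrees with `P_closed` up to sampling error of order N^(-1/2).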

Data and stepsize assumptions
The goal of this analysis is to allow the number of samples n to be large and proportional to the dimension of the problem, here d. This means that the data must be normalized to be nearly dimension independent. Further, we shall need good tail properties of some of the random variables involved, and so we recall the Orlicz norms ‖·‖_{ψ_p} for p ≥ 1, which are given by
\[
\|X\|_{\psi_p} = \inf\{\, t > 0 : \mathbb{E}\, e^{|X|^p/t^p} \le 2 \,\}.
\]
We refer the reader to [Ver18] for further exposition, properties and equivalent formulations.
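As a small worked example of this definition (our own, for illustration): for a standard Gaussian X, one has E exp(X²/t²) = (1 − 2/t²)^(−1/2) for t² > 2, so solving E exp(X²/t²) = 2 gives ‖X‖_{ψ₂} = √(8/3) ≈ 1.633. The sketch below recovers this by quadrature and bisection:

```python
import numpy as np

# numerical check: the psi_2 Orlicz norm of a standard Gaussian
xs = np.linspace(-14.0, 14.0, 200_001)
dx = xs[1] - xs[0]
pdf = np.exp(-xs**2 / 2) / np.sqrt(2 * np.pi)

def exp_moment(t):
    # E exp(X^2 / t^2) for X ~ N(0,1), by quadrature (finite only for t^2 > 2)
    return np.sum(np.exp(xs**2 / t**2) * pdf) * dx

# bisect in t:  ||X||_{psi_2} = inf{ t : E exp(X^2/t^2) <= 2 }
lo, hi = 1.5, 3.0
for _ in range(60):
    mid = 0.5 * (lo + hi)
    if exp_moment(mid) > 2.0:
        lo = mid        # t too small: the exponential moment exceeds 2
    else:
        hi = mid

# hi now approximates sqrt(8/3) ~ 1.63299
```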
We shall suppose throughout that under D, the labels are given by an underlying linear model with noise. Formally, we suppose:

Assumption 1.1. For (a, b) sampled from D, conditionally on a, the distribution of b is given by a · x⋆ + w, where w has mean 0 and variance η² ≥ 0 and is subgaussian with ‖w‖_{ψ₂} ≤ d^ε.
The ground truth x⋆ is assumed to have norm at most d^ε.
The constant ε will be small and fixed throughout; anything less than 1/18 will do. The data covariance is assumed to be normalized in such a way that it is bounded in norm, which is to say:

Assumption 1.2. The covariance matrix K := E[aaᵀ] has operator norm bounded independently of d.
Note that while we do not explicitly assume that a is centered, the mean would have to be small in some sense to achieve the assumption above.
Finally, we suppose that a has good tail properties, namely:

Assumption 1.3. The data vector a satisfies, for any deterministic x of norm less than 1, ‖a · x‖_{ψ₂} ≤ d^ε, and a satisfies the Hanson-Wright inequality: for all t ≥ 0 and for any deterministic matrix B,
\[
\Pr\bigl( \,|a^\mathsf{T} B a - \mathbb{E}\, a^\mathsf{T} B a| > t \,\bigr) \le 2 \exp\Bigl( -c \min\Bigl\{ \frac{t^2}{d^{4\varepsilon}\|B\|_F^2},\; \frac{t}{d^{2\varepsilon}\|B\|} \Bigr\} \Bigr).
\]
We remark that these assumptions hold in two important settings: (a) when a = √K u, where K is some deterministic matrix of bounded operator norm and u is a vector of iid subgaussian random variables, or (b) when a is a vector with the convex concentration property; see [Ada15] for details.
There are natural examples of the second case, such as random features models [RR08] with Lipschitz activation functions (see also [Paq+22, Proposition 6.2] for specifics in the case of random features). We note that, by truncation, it is also possible to work in setting (a) above solely under a uniform bound on a sufficiently high but finite moment, but we do not pursue this.
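Setting (a) is easy to probe empirically. The sketch below (a toy instance of ours, with a diagonal K and Rademacher coordinates for u) samples quadratic forms aᵀBa and checks that they concentrate around their mean tr(BK), as the Hanson-Wright inequality predicts:

```python
import numpy as np

rng = np.random.default_rng(2)
d, N = 200, 20_000

# setting (a): a = sqrt(K) u with u iid Rademacher (subgaussian), K of bounded norm
evals = np.linspace(0.2, 1.0, d)             # spectrum of K; K diagonal for simplicity
sqrtK = np.sqrt(evals)

G = rng.standard_normal((d, d))
B = (G + G.T) / 2
B /= np.linalg.norm(B)                       # deterministic test matrix, ||B||_F = 1

U = rng.choice([-1.0, 1.0], size=(N, d))     # iid Rademacher entries of u
A = U * sqrtK                                # rows are samples a = sqrt(K) u

quad = np.einsum('ni,ni->n', A @ B, A)       # quadratic forms a^T B a
mean_target = np.sum(np.diag(B) * evals)     # E a^T B a = tr(B K)
```

The empirical mean of `quad` matches `mean_target` to within a few multiples of N^(-1/2), and the fluctuations are O(1) since ‖B‖_F = 1.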
Finally, the step-size parameters γ_k must be normalized appropriately:

Assumption 1.4. The stepsize γ_k = γ/d for all k, for a fixed γ > 0.

We note that we may also pick γ_k = γ(k/d)/d for a bounded continuous function γ : [0, ∞) → [0, ∞), and this leads to no change anywhere in the arguments.
In light of all these assumptions, SGD reduces to the following stochastic recurrence:
\[
x_{k+1} = x_k - \frac{\gamma}{d}\bigl( (a_{k+1} \cdot x_k - b_{k+1})\, a_{k+1} + \delta x_k \bigr), \qquad 0 \le k < n. \tag{1.5}
\]

Homogenized SGD  Our theorem is most easily formulated as showing that the state of SGD can be compared to a certain diffusion in high dimensions. Homogenized SGD is defined to be the continuous-time process with initial condition X₀ = x₀ that solves the stochastic differential equation
\[
dX_t = -\gamma\, \nabla R(X_t)\, dt + \gamma \sqrt{\tfrac{2}{d}\, P(X_t)}\; \sqrt{K}\, dB_t, \tag{1.6}
\]
where B_t is a standard Brownian motion in dimension d, and R and P are the regularized and unregularized risks, respectively (recall (1.3) and (1.4)).
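This comparison can be probed numerically. The sketch below (all parameter values hypothetical, K = I, and assuming the drift-diffusion form dX_t = −γ∇R(X_t)dt + γ√(2P(X_t)/d) √K dB_t as in the display above) runs one-pass SGD alongside an Euler-Maruyama discretization of homogenized SGD and compares the final population risks:

```python
import numpy as np

rng = np.random.default_rng(4)
d, n = 1000, 2000
gamma, delta, eta = 0.3, 0.0, 0.3

x_star = rng.standard_normal(d) / np.sqrt(d)
A = rng.standard_normal((n, d))              # K = I for this sketch
b = A @ x_star + eta * rng.standard_normal(n)

def pop_risk(x):
    # P(x) = 0.5 ||x - x_star||^2 + 0.5 eta^2 when K = I
    v = x - x_star
    return 0.5 * (v @ v) + 0.5 * eta**2

# one-pass SGD, the recurrence (1.5)
x = np.zeros(d)
for k in range(n):
    x -= (gamma / d) * ((A[k] @ x - b[k]) * A[k] + delta * x)
sgd_risk = pop_risk(x)

# Euler-Maruyama for the assumed homogenized SGD; one SGD step = time 1/d
X = np.zeros(d)
dt = 1.0 / d
for k in range(n):
    drift = -gamma * ((X - x_star) + delta * X)   # grad R(X) for K = I
    noise = gamma * np.sqrt(2 * pop_risk(X) / d) * np.sqrt(dt) * rng.standard_normal(d)
    X += drift * dt + noise
hsgd_risk = pop_risk(X)
```

At this dimension the two risk values are close, reflecting the concentration asserted by the main theorem.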
Our main theorem shows that for quadratic statistics (in particular the risks (1.3) and (1.4)), homogenized SGD and SGD are interchangeable to leading order. We use the probabilistic modifier with overwhelming probability to mean that a statement holds except on an event of probability at most e^{−ω(log d)}, where ω(log d) tends to ∞ faster than log d as d → ∞. We further introduce a norm ‖·‖_{C²} on quadratic functions q : R^d → C,
\[
\|q\|_{C^2} := \|\nabla^2 q\|_{\mathrm{op}} + \|\nabla q(0)\| + |q(0)|,
\]
with the norms on the right-hand side being given by the operator and Euclidean norms, respectively.
Theorem 1.5. For any quadratic q : R^d → R and any deterministic initialization x₀ with ‖x₀‖ ≤ 1, there is a constant C(‖K‖) so that the processes {x_k}_{k=0}^n and {X_t}_{t ≥ 0} satisfy
\[
\max_{0 \le k \le n} \bigl| q(x_k) - q(X_{k/d}) \bigr| \le C(\|K\|)\, \|q\|_{C^2}\, d^{-\varepsilon} \tag{1.7}
\]
with overwhelming probability.
The processes x_k and X_t are independent, and hence this is also a statement about concentration. In particular, the statement remains true if we replace q(X_{k/d}) by E q(X_{k/d}).
Explicit risk curves  Using existing theory (see [Paq+22, Theorem 1.1]), P(X_{k/d}) and R(X_{k/d}) can be seen to concentrate around their means, which solve a convolution-type Volterra equation: EP(X_t) = Ψ(t) and ER(X_t) = Ω(t), where the forcing term is the risk along gradient flow, which for the problems here is explicitly solvable. The equation for Ψ is autonomous, and Ω is then solvable in terms of Ψ. From this Volterra equation it is easy to derive convergence rates and convergence thresholds, as well as expressions for the limiting risk (in the double scaling limit, d → ∞ followed by t → ∞). See the discussions in [Paq+21] and Figure 1 for an example.
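To make the Volterra description concrete, the sketch below solves a generic convolution-type Volterra equation Ψ(t) = F(t) + ∫₀ᵗ K(t−s)Ψ(s) ds by trapezoidal time-stepping. The forcing F and kernel here are toy stand-ins of ours (the paper's exact formulas are not reproduced in this excerpt), chosen so that the answer is known in closed form via the Laplace transform:

```python
import numpy as np

def solve_volterra(F, Kker, T, h):
    """Solve psi(t) = F(t) + int_0^t Kker(t - s) psi(s) ds on [0, T],
    using the trapezoidal rule for the convolution at each step."""
    m = int(round(T / h))
    t = np.arange(m + 1) * h
    psi = np.zeros(m + 1)
    psi[0] = F(0.0)
    K0 = Kker(0.0)
    for i in range(1, m + 1):
        w = Kker(t[i] - t[:i])                       # kernel at past time points
        conv = h * (0.5 * w[0] * psi[0] + np.sum(w[1:] * psi[1:i]))
        psi[i] = (F(t[i]) + conv) / (1.0 - 0.5 * h * K0)   # implicit endpoint
    return t, psi

# toy forcing and kernel; the exact solution of
# psi = F + K * psi is psi(t) = (2/3) e^{-2t} + (1/3) e^{-t/2}
F = lambda s: np.exp(-2.0 * s)
Kker = lambda u: 0.5 * np.exp(-u)
t, psi = solve_volterra(F, Kker, T=4.0, h=0.001)
```

The same time-stepping applies verbatim once the paper's forcing (the gradient-flow risk) and kernel are substituted.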

Comparison to the multi-pass case
In [Paq+22], the analog of Theorem 1.5 was proven for the (random sample) multi-pass case. In that case, the diffusion is different.
Introduce the empirical risk L(x) = (1/2n)‖Ax − b‖² and the regularized empirical risk f(x) = L(x) + (δ/2)‖x‖². Then homogenized SGD for the multi-pass case becomes
\[
dX_t = -\gamma\, \nabla f(X_t)\, dt + \gamma \sqrt{\tfrac{2}{d}\, L(X_t)}\; \sqrt{\tfrac{1}{n} A^\mathsf{T} A}\; dB_t.
\]
Hence the difference between the multi-pass and one-pass cases is that the population risks are traded for the empirical risks, and the data covariance matrix K is traded for the empirical covariance matrix (1/n)AᵀA. Note that if one conditions on the data (A, b), then multi-pass SGD in fact is streaming SGD, but with a finitely supported data distribution (specifically the empirical distribution of the data); however, the empirical distribution of the data is very far from satisfying Assumption 1.3.

Discussion
We have presented an approach to taking the high-dimensional limit of one-pass SGD on a least squares problem in which the number of steps is proportional to the dimension of the problem. The limit object is described in terms of a Langevin-type diffusion, which can be directly compared to the same object in the multi-pass case. The literature on scaling limits of one-pass SGD training is large, and so we mention just some of the closest works. [AGJ22] is perhaps the closest high-dimensional diffusion approximation; it applies in cases where there is a hidden finite-dimensional structure, and it covers the case studied here when K = I as well as cases in which K has boundedly many eigenvalues. See also [AGJ21].
There are other scaling limits that pursue a different formulation than the one here. [WL19; WML17] give a PDE limit for the state for a generalized linear model with identity covariance. [BGH23] give a scaling limit of SGD under smoothness assumptions on the covariance K, when interpreted geometrically; they further describe fluctuations of SGD in a certain sense. Note that Theorem 1.5 essentially gives the law of large numbers for the risks and not the fluctuations.
[Ger+22] uses dynamical mean-field theory to give closely related results for shallow neural networks with minibatch SGD of large batch size; dynamical mean-field theory provides an implicit characterization of the autocorrelation of the minibatch noise and a few other processes. See also the related work of [CCM21]. We comment that in the case of proportional batch sizes, there is also a discrete Volterra description in [Lee+22].
Organization  In Section 2 we give an overview of the main proof, reducing it to its main technicalities. In Section 3, we bound the stochastic error terms, representing the main technical contribution of the paper.

Main argument of proof
In order to compare SGD and homogenized SGD, we use a version of the martingale method for diffusion approximations (see [EK09]). In effect, we show that q(x_k) nearly satisfies the conclusion of Itô's lemma. Further, we show that the martingale terms in both of the Doob decompositions are small, and hence it suffices to show that the predictable parts of q(x_k) and q(X_t) are close.
To advance the discussion, we compute this Doob decomposition. To take advantage of the simpler structure afforded by removing x⋆, introduce
\[
v_k := x_k - x_\star, \qquad V_t := X_t - x_\star. \tag{2.1}
\]
We extend the first, integer-indexed process to real-valued indices by setting v_t := v_{⌊t⌋}. We also let (F_t : t ≥ 0) be the filtration generated by (v_s : s ≤ t) and (V_{s/d} : s ≤ t). Hence for all k ∈ N, v_k is measurable with respect to F_k. Recalling the recurrence (1.5), for a quadratic q, with increments Δ_k := v_k − v_{k−1},
\[
q(v_k) - q(v_{k-1}) = \nabla q(v_{k-1}) \cdot \Delta_k + \tfrac{1}{2}\, \Delta_k^\mathsf{T} (\nabla^2 q)\, \Delta_k. \tag{2.2}
\]
The equation above can be decomposed into a predictable part and two martingale increments:
\[
q(v_k) - q(v_{k-1}) = \mathbb{E}\bigl[ q(v_k) - q(v_{k-1}) \,\big|\, F_{k-1} \bigr] + \Delta M^{\mathrm{lin}}_k + \Delta M^{\mathrm{quad}}_k. \tag{2.3}
\]
Here ΔM^{quad}_k collects the martingale increments quadratic in Δ_k; the remainder of the martingale increments are given by ΔM^{lin}_k and are all linear in Δ_k.
The predictable parts have been further decomposed into the leading-order terms and an error term ΔE^{quad}_k.
These predictable parts, in turn, depend on different statistics q₁(v_{k−1}). In finite-dimensional settings, we would be able to relate this (or some suitably large finite set of summary statistics q, q₁, . . ., q_r) to itself through a closed system of recurrences. In this setting, this is not possible. On the other hand, for the problem at hand, we show there is a manifold of functions which approximately closes. Specifically, we let
\[
\mathcal{Q} := \bigl\{\, x \mapsto x^\mathsf{T} R(z; K)\, x,\;\; x \mapsto \nabla q(x)^\mathsf{T} R(z; K)\, x,\;\; x \mapsto x^\mathsf{T} R(z; K) (\nabla^2 q) R(y; K)\, x \;:\; z, y \in \Gamma \,\bigr\}. \tag{2.4}
\]
Here R(z; K) = (K − zI)^{−1} is the resolvent matrix, and Γ is a circle of radius max{1, 3‖K‖}.
In order to control the martingales, it is convenient to impose a stopping time
\[
\tau := \inf\{\, t \ge 0 : \|v_{td}\| > d^{\varepsilon} \text{ or } \|V_t\| > d^{\varepsilon} \,\},
\]
and we introduce the corresponding stopped processes
\[
v^{\tau}_k := v_{k \wedge \tau d}, \qquad V^{\tau}_t := V_{t \wedge \tau}. \tag{2.6}
\]
We prove a version of our theorem for the stopped processes and then show that the stopping time is greater than n/d with overwhelming probability.
Our key tool for comparing v_{td} and V_t is the following lemma.
Lemma 2.1. Given a quadratic q with ‖q‖_{C²} ≤ 1,
\[
\bigl| q(v^{\tau}_{td}) - q(V^{\tau}_t) \bigr| \le C(\|K\|) \Bigl( \int_0^t \sup_{g \in \mathcal{Q}} \bigl| g(v^{\tau}_{sd}) - g(V^{\tau}_s) \bigr|\, ds + \sup_{s \le t} |M^{\mathrm{SGD},\tau}_s| + \sup_{s \le t} |M^{\mathrm{HSGD},\tau}_s| + \sup_{k \le td} |E^{\mathrm{quad},\tau}_k| \Bigr). \tag{2.7}
\]
Here M^{HSGD,τ}_t is the martingale part in the semimartingale decomposition of q(V^τ_t), and M^{SGD,τ} is the corresponding martingale part for q(v^τ).
Sketch of Proof. Owing to the similarities of this claim with the proof of [Paq+22, Proposition 4.1], we just illustrate the main idea. If we take a g ∈ Q and apply (2.3), then the predictable part of g(v_t) involves integrals of further statistics; these appear with coefficients that can be bounded solely using ‖g‖_{C²} and ‖K‖. We get the same by applying Itô's lemma to g(V_t), albeit with the replacement v_t → V_t. We wish to bound, for example, I₁(v_t) − I₁(V_t). We do this by expressing its integrand as p(v_t) − p(V_t) for a polynomial p. If g is linear (the final row of (2.4)), then p is again linear.
For example, if g(x) = ∇q(x)ᵀ R(z; K) x, then p is again linear and can be computed using the resolvent identity (K − z)R(z; K) = I. Note that the function x ↦ xᵀR(z; K) x is contained in Q. Moreover, by Cauchy's integral formula, we can represent xᵀx as the contour average
\[
x^\mathsf{T} x = -\frac{1}{2\pi i} \oint_{\Gamma} x^\mathsf{T} R(y; K)\, x \, dy.
\]
The same manipulations finally show that every term included in Q can be controlled in a similar manner, using the other elements of the class Q.
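This contour representation is easy to verify numerically. The sketch below (with a small random PSD matrix of our own choosing standing in for K) discretizes the contour integral by the trapezoidal rule, which converges geometrically for an integrand analytic in a neighborhood of Γ:

```python
import numpy as np

rng = np.random.default_rng(3)
d = 30
G = rng.standard_normal((d, d))
K = G @ G.T / d                       # a PSD matrix standing in for the covariance
x = rng.standard_normal(d)

# contour: circle of radius max{1, 3||K||}, as in the text
r = max(1.0, 3.0 * np.linalg.norm(K, 2))
m = 512
z = r * np.exp(2j * np.pi * np.arange(m) / m)

# -1/(2 pi i) * contour integral of x^T R(z;K) x dz with R(z;K) = (K - zI)^{-1};
# writing dz = i z dtheta, the trapezoidal rule gives -(1/m) sum_k z_k x^T R(z_k) x
I = np.eye(d)
vals = np.array([x @ np.linalg.solve(K - zz * I, x) for zz in z])
recovered = -np.mean(z * vals)

# by Cauchy's integral formula, `recovered` equals x^T x up to machine precision
```

All eigenvalues of K lie well inside the contour, so the quadrature error decays like (‖K‖/r)^m and is negligible here.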
The second important idea is to discretize the set Q.

Lemma 2.2. There exists a deterministic subset Q₀ ⊂ Q with |Q₀| ≤ C(‖K‖) d^{4ε} such that every g ∈ Q is within distance d^{−2ε}, in ‖·‖_{C²}, of some element of Q₀.
Proof. On the contour Γ, we can bound the norm of the resolvent: since Γ has radius max{1, 3‖K‖} and the spectrum of K lies in [0, ‖K‖], the resolvent is norm-bounded by an absolute constant on Γ. The arc length of the curve is at most C(‖K‖), and so by choosing a minimal d^{−2ε}-net of the manifold Γ × Γ, the lemma follows.
Now the main technical part of the argument is to control the martingales and errors. As we work with the stopped process v^τ_k, we introduce the stopped processes M^{lin,τ}_k, M^{quad,τ}_k, E^{quad,τ}_k, which are defined analogously to (2.6).
Lemma 2.3. For any quadratic q with ‖q‖_{C²} ≤ 1 and any n ≤ d log d, the terms M^{lin,τ}_k, M^{quad,τ}_k, E^{quad,τ}_k satisfy, with overwhelming probability (with a bound which is uniform in q),
\[
\sup_{1 \le k \le n} \bigl( |M^{\mathrm{lin},\tau}_k| + |M^{\mathrm{quad},\tau}_k| + |E^{\mathrm{quad},\tau}_k| \bigr) \le d^{-1/2+5\varepsilon}.
\]
Combining Lemmas 2.1 and 2.2, along with the above (see also Lemma 3.2, in which the homogenized SGD martingales are bounded), we conclude that, for any q ∈ Q with ‖q‖_{C²} = 1,
\[
\bigl| q(v^{\tau}_{td}) - q(V^{\tau}_t) \bigr| \le C(\|K\|) \int_0^t \sup_{g \in \mathcal{Q}} \bigl| g(v^{\tau}_{sd}) - g(V^{\tau}_s) \bigr|\, ds + d^{-1/2+5\varepsilon}. \tag{2.8}
\]
Hence, by Lemma 2.2 and by bounding ‖g‖_{C²} over all of Q, the same bound extends (up to an additional d^{−2ε} error) to the supremum over Q. By Gronwall's inequality, this gives us that, with overwhelming probability, for all 0 ≤ t ≤ n/d,
\[
\sup_{g \in \mathcal{Q}} \bigl| g(v^{\tau}_{td}) - g(V^{\tau}_t) \bigr| \le C(\|K\|)\, e^{C(\|K\|)\, t}\, d^{-2\varepsilon}. \tag{2.10}
\]
Now we note that the norm function x ↦ ‖x‖² is one of the quadratics included in Q.
Hence, if we let G be the event in the above display, and we let E := {sup_{0 ≤ t ≤ n/d} ‖V_t‖ ≤ ½d^ε}, then on G ∩ E we must have τ > n/d. This is because on the event {τ ≤ n/d} ∩ E we must have had ‖v_{τd}‖ > d^ε; but in the step before τ, v_{τd−1} could be compared to V (due to G), and the norm of V at that time was small. Now it is easily seen that, with overwhelming probability, no single increment of SGD between time 0 and n/d can increase the norm by a power of d. So to complete the proof it suffices to show that E holds with overwhelming probability.
Thus the proof is completed by the following:

Lemma 2.4. For any δ > 0 and any t > 0, with overwhelming probability,
\[
\sup_{0 \le s \le t} \|X_s\| \le C(t, \|K\|)\, d^{\delta}.
\]
Proof. We apply Itô's formula to φ(X_t) := log(1 + ‖X_t‖²). The drift terms and the quadratic-variation terms can be bounded by some C(‖K‖). Hence, with this constant, for all r ≥ 0,
\[
\Pr\Bigl( \sup_{0 \le s \le t} \varphi(X_s) \ge \varphi(x_0) + C(\|K\|)\, t + r \Bigr) \le 2\, e^{-c\, r^2 / (C(\|K\|)\, t)}.
\]
Taking r = √(log d) · log log d, we conclude the claim with overwhelming probability.

Controlling the errors
The main goal of this section is to control the martingale terms and error terms; in particular, we prove Lemma 2.3. In order to obtain these bounds, we will need the following concentration lemma, which is standard (cf. [Ver18, Theorem 2.8.1], where the non-martingale bound is proven; the adaptation to the martingale case is a small extension). Suppose (M_k)_{k=1}^N is a martingale on the filtered probability space (Ω, (F_k)_{k=1}^N, P) whose increments ΔM_k := M_k − M_{k−1} satisfy ‖ΔM_k‖_{ψ₁} ≤ b_k, and define σ² := Σ_{k=1}^N b_k². Then there is an absolute constant C > 0 so that, for all t > 0,
\[
\Pr\Bigl( \max_{1 \le k \le N} |M_k| \ge t \Bigr) \le 2 \exp\Bigl( -\frac{1}{C} \min\Bigl\{ \frac{t^2}{\sigma^2},\; \frac{t}{\max_k b_k} \Bigr\} \Bigr). \tag{3.2}
\]
We will also record, for future use, an estimate on ∇q that follows from ‖·‖_{C²} control: for a quadratic q,
\[
\|\nabla q(x)\| \le 2\, \|q\|_{C^2}\, (1 + \|x\|). \tag{3.3}
\]

Martingale for gradient part of recurrence
Proof of Lemma 2.3, part (i). Comparing (2.2) and (2.3), we see that for k ≤ τd the increment ΔM^{lin,τ}_k splits into two contributions, which we bound separately in terms of their Orlicz norms. Conditioning on F_{k−1} and using Assumption 1.3, each contribution has ψ₁ norm at most Cd^{−1+2ε}, where C is some absolute constant. Combining these, we see that, for every k, ‖ΔM^{lin,τ}_k‖_{ψ₁} ≤ Cd^{−1+2ε}, and, by the martingale Bernstein inequality (3.2),
\[
\sup_{1 \le k \le n} |M^{\mathrm{lin},\tau}_k| \le d^{-1/2+5\varepsilon} \tag{3.9}
\]
with overwhelming probability.

Martingale for Hessian part of recurrence
Proof of Lemma 2.3, parts (ii) and (iii). Next we consider the contribution from the Hessian part of the recurrence. We write this contribution as (3.10), and, rearranging the terms, obtain the decomposition (3.12), which can be expanded as in (3.13). We focus first on obtaining subexponential bounds for the quantities appearing there. For the terms involving η_k, we use the Orlicz bounds from the assumptions in the set-up to obtain (3.17). Since also ‖η_k²‖_{ψ₁} ≤ d^{−1+2ε}, combining the bounds (3.16) and (3.17) yields (3.18). Furthermore, we have a bound that holds uniformly for all k, based on the assumptions on η_k and m_k. We now use (3.15), (3.18), and (3.19) to bound each term of (3.13) in turn.
For the first term of (3.13), we conclude that, with overwhelming probability, the sum over 1 ≤ k ≤ n satisfies the bound (3.20). For the second term of (3.13), we bound the relevant quantity by combining (3.14) and (3.18), as in (3.23); using this bound, we obtain (3.24), and thus (3.29) holds with overwhelming probability. This completes the proof of part (ii) of the lemma.
For part (iii), we observe that ΔE^{quad,τ}_k consists of the error terms arising from the u_k cross terms, so that the bound on E^{quad,τ}_k follows immediately from (3.25).

Figure 1: Risk curves for a simple linear regression problem in d = 2000. Multi-pass SGD, its high-dimensional equivalent (the expected risk under homogenized SGD, i.e. "Volterra"), streaming SGD (i.e. one-pass with varying dataset size), and the expected risk of homogenized SGD ("Streaming Volterra") are all plotted. Risk levels for streaming SGD at various levels of n are plotted for comparison against the corresponding multi-pass version. Note that at smaller dataset sizes, multi-pass SGD improves greatly over one-pass SGD. At higher dataset sizes, they are similar, and in fact multi-pass SGD always underperforms.

(3.4) Note that for k > τd, the stopped martingale increment is 0. Using (3.3), ‖w_{k−1}‖ ≤ C(γ, δ) d^ε. We separately bound the contributions from ΔM^{lin,1,τ}_k and ΔM^{lin,2,τ}_k in terms of their Orlicz norms. For the first part, for any fixed k, we condition on F_{k−1} and use Assumption 1.3 to conclude that, for n ≤ d log d,
\[
\sup_{1 \le k \le n} |M^{\mathrm{lin},\tau}_k| \le d^{-1/2+5\varepsilon}, \tag{3.14}
\]
and thus we have the subexponential bound
\[
\| A_k - \mathbb{E}[A_k] \|_{\psi_1} \le C d^{-1/2+2\varepsilon}. \tag{3.15}
\]
Next we obtain a subexponential bound for B_k. For the part of B_k not involving η_k, we use Hanson-Wright to get (3.16).