Tradeoffs in Large Scale Learning: Statistical Accuracy vs. Numerical Precision
Sham M. Kakade
Microsoft Research, New England


The Stochastic Optimization Problem

You have a stochastic optimization problem:

    min_w P(w),   where P(w) = E_{ψ∼D}[ψ(w)].

- You get samples from D: ψ_1, ψ_2, ..., ψ_N.
- How do you estimate w*, the minimizer of P?
- Your expected error / excess risk / regret is: E[P(ŵ_N) − P(w*)].


Statistics vs. Computation

  Stochastic approximation             | Numerical analysis
  e.g. stochastic gradient descent     | e.g. batch gradient descent
  easy-to-implement steps (streaming)  | costly-to-implement steps
  1/ε steps for ε accuracy             | log(1/ε) steps for ε accuracy
  good enough for statistics?          | is the extra accuracy worth it?

What would you do?


Without Computational Constraints...

What would you like to do? Compute the empirical risk minimizer (ERM):

    ŵ_N^ERM ∈ argmin_{w∈S} (1/N) Σ_{i=1}^N ψ_i(w)

(a regularizer can be included).

Consider the ratio:

    E[P(ŵ_N^ERM) − P(w*)] / E[P(ŵ_N) − P(w*)].

Can you compete with the ERM on every problem efficiently?


This Talk

Under regularity conditions (strong convexity + smoothness), a streaming algorithm which:

- is a single-pass algorithm, with space usage linear in the size of one sample
- achieves the statistical rate of the ERM on every problem (even considering constant factors)
- (super-)polynomially decreases the initial error
- (also, is trivially parallelizable)

Related work: Juditsky and Polyak '92; Defossez and Bach '14.


Related Work

Stochastic approximation:
- Robbins & Monro (1951)
- Polyak & Juditsky (1992)
- Bottou & Bousquet (2008)
- Defossez & Bach (2014)

Aside: linear convergence for sums of convex functions:
- Strohmer & Vershynin (2009)
- Le Roux, Schmidt, Bach (2012)
- Shalev-Shwartz & Zhang (2013)
- Stochastic Variance Reduced Gradient (SVRG), Johnson & Zhang (2013)


Talk Outline

1. Statistics: How does the ERM even behave?
2. Computation: Let us just compute the ERM.
3. Computation + Statistics: ?


Statistics: Empirical Risk Minimization

Precisely, what is the error of ŵ_N^ERM? Define

    σ² := (1/2) E[ ‖∇ψ(w*)‖²_{(∇²P(w*))⁻¹} ].

Theorem (classical, like Cramér-Rao). Assume:
- independent samples ψ_1, ψ_2, ...
- ψ is convex (almost surely)
- the ψ are smooth functions (almost surely)
- ∇²P(w*) exists and is positive definite.

Then

    lim_{N→∞}  E[P(ŵ_N^ERM) − P(w*)] / (σ²/N)  =  1.


Computation

Offline optimization for sums of convex functions:

    min_w F(w),   where F(w) = (1/N) Σ_i ψ_i(w).

Assume: F is µ-strongly convex and each ψ_i is L-smooth.

Least squares as a running example:

    ψ_i(w) = (Y_i − w·X_i)²,   w, X_i ∈ R^d,  Y_i ∈ R,

and so µ = λ_min(E[XX^T]) and ‖X‖² ≤ L. Effective condition number: κ = L/µ.


Stochastic Gradient Descent

    F(w) = (1/N) Σ_i ψ_i(w)

SGD update rule: at time t, sample ψ ∼ uniform{ψ_1, ψ_2, ..., ψ_N} and update

    w_{t+1} ← w_t − η ∇ψ(w_t).

Let w_min be the minimizer of F. Even if w_t = w_min, after updating, w_{t+1} ≠ w_min.

How do you fix this?
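The SGD update and the least-squares running example above can be made concrete. Below is a minimal sketch (not from the talk); the synthetic data, the step size eta, and the iteration count are illustrative assumptions. It also shows the issue just raised: even at the minimizer of F, a single stochastic gradient is generally nonzero, so fixed-step SGD keeps moving.

```python
import numpy as np

# Minimal sketch of fixed-step SGD on the least-squares objective
#   F(w) = (1/N) * sum_i (Y_i - w . X_i)^2.
# The synthetic data, step size eta, and iteration count are illustrative assumptions.

rng = np.random.default_rng(0)
N, d = 1000, 5
X = rng.normal(size=(N, d))
w_true = rng.normal(size=d)
Y = X @ w_true + 0.1 * rng.normal(size=N)

def grad_psi(w, i):
    # gradient of psi_i(w) = (Y_i - w . X_i)^2
    return -2.0 * (Y[i] - w @ X[i]) * X[i]

eta = 0.01
w = np.zeros(d)
for t in range(5000):
    i = rng.integers(N)           # sample psi uniformly from {psi_1, ..., psi_N}
    w = w - eta * grad_psi(w, i)  # w_{t+1} <- w_t - eta * grad psi_i(w_t)

# Even at the minimizer of F, grad_psi is generally nonzero for a single i,
# so the iterate keeps fluctuating around w_min; SVRG (next slide) corrects this.
print("||w - w_true|| =", np.linalg.norm(w - w_true))
```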
Stochastic Variance Reduced Gradient (SVRG)

Iterate:

1. Exact GD: at w̃_s, compute the exact (batch) gradient ∇F(w̃_s).
   Note: ∇ψ(w̃_s) − ∇F(w̃_s) is mean zero for random ψ.
2. Corrected SGD: initialize w_0 = w̃_s; at each step, sample ψ ∼ uniform{ψ_1, ψ_2, ..., ψ_N} and update
       w_{t+1} ← w_t − η (∇ψ(w_t) − ∇ψ(w̃_s) + ∇F(w̃_s)).
3. Update and repeat: w̃_{s+1} = w_m (after m steps).

(A code sketch of this iteration appears after the final slide.)


SVRG: the Rate

Theorem (Johnson & Zhang, '13). SVRG has geometric convergence for a fixed η:

    E[F(w̃_s) − F(w_min)] ≤ e^{−s} · (F(w̃_0) − F(w_min)).

- must store the dataset
- still requires many passes

What about the statistical rate?


Back to the Stochastic Optimization Problem

Streaming setting:

1. We get random samples one at a time, ψ ∼ D.
2. We care about out-of-sample error on P.
3. Do we need to handle both a generalization analysis and a computational analysis?


Our Algo: Streaming SVRG

Always draw fresh samples.

1. Estimation: at some w̃_s, draw k_s samples ψ_1, ψ_2, ... and form
       ∇̂P(w̃_s) = (1/k_s) Σ_i ∇ψ_i(w̃_s).
2. Update: for m steps, initialize w = w̃_s; sample ψ fresh and update
       w_{t+1} ← w_t − η (∇ψ(w_t) − ∇ψ(w̃_s) + ∇̂P(w̃_s)).
3. Update and repeat: w̃_{s+1} = w_m.

Single-pass computation; memory of O(one parameter). (A code sketch appears after the final slide.)


Single Pass Least Squares Estimation

Theorem. Let κ be the effective condition number, and choose p > 2. Use the schedule of increasing batch sizes k_s = 2 k_{s−1}, with m and η fixed. If the total sample size N is larger than a multiple of κ (which depends on p), then

    E[P(ŵ_N) − P(w*)] ≤ 1.5 σ²/N + (P(w̃_0) − P(w*)) / (N/κ)^p.

- σ²/N is the ERM rate.
- The constant → 1 with adaptation.
- (Super-)polynomial decay of the initial error.
- The analysis handles the estimation error.
- General case: use self-concordance.


Thanks!

We can obtain (nearly) the same rate as the ERM in a single pass.

See website for paper.

Collaborators: R. Frostig, R. Ge, A. Sidford.
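Sketches of the two algorithms from the slides follow. First, a minimal sketch of the batch SVRG iteration: an exact gradient at the anchor w̃_s, then m corrected SGD steps. The synthetic data and the constants eta, m, and the number of stages are illustrative assumptions, not the settings from Johnson & Zhang; note that it stores the full dataset and makes many passes over it.

```python
import numpy as np

# Minimal sketch of batch SVRG for least squares, F(w) = (1/N) sum_i (Y_i - w . X_i)^2.
# The data and the constants eta, m, and number of stages are illustrative assumptions.

rng = np.random.default_rng(1)
N, d = 1000, 5
X = rng.normal(size=(N, d))
w_true = rng.normal(size=d)
Y = X @ w_true + 0.1 * rng.normal(size=N)

def grad_psi(w, i):
    return -2.0 * (Y[i] - w @ X[i]) * X[i]

def grad_F(w):
    # exact (batch) gradient: needs the stored dataset and a full pass over it
    return -2.0 * X.T @ (Y - X @ w) / N

eta, m = 0.02, 2 * N
w_tilde = np.zeros(d)
for s in range(5):
    g_full = grad_F(w_tilde)      # step 1: exact gradient at the anchor w_tilde
    w = w_tilde.copy()
    for t in range(m):            # step 2: corrected SGD steps
        i = rng.integers(N)
        g = grad_psi(w, i) - grad_psi(w_tilde, i) + g_full  # mean-zero correction
        w = w - eta * g
    w_tilde = w                   # step 3: update the anchor and repeat

print("||w_tilde - w_true|| =", np.linalg.norm(w_tilde - w_true))
```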
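Second, a minimal sketch of the streaming variant: every gradient uses a fresh sample ψ ∼ D, the anchor gradient is replaced by an estimate from k_s fresh samples, and the batch size doubles across stages (k_s = 2 k_{s−1}). The data distribution and the constants eta, m, the initial batch size, and the number of stages are illustrative assumptions, not the tuned settings from the theorem.

```python
import numpy as np

# Minimal sketch of Streaming SVRG on streaming least-squares data:
#   psi(w) = (Y - w . X)^2, with (X, Y) drawn fresh from D at every step.
# The distribution D and the constants eta, m, k, and number of stages are
# illustrative assumptions, not the settings from the theorem.

rng = np.random.default_rng(2)
d = 5
w_true = rng.normal(size=d)

def sample():
    # one fresh draw (X, Y) ~ D
    x = rng.normal(size=d)
    y = x @ w_true + 0.1 * rng.normal()
    return x, y

def grad_psi(w, x, y):
    return -2.0 * (y - w @ x) * x

eta, m, k = 0.02, 200, 100
w_tilde = np.zeros(d)
for s in range(6):
    # step 1 (estimation): estimate grad P(w_tilde) from k fresh samples,
    # accumulated online so only O(one parameter) of memory is needed
    g_hat = np.zeros(d)
    for _ in range(k):
        x, y = sample()
        g_hat += grad_psi(w_tilde, x, y) / k

    # step 2 (update): m corrected SGD steps, each on a fresh sample
    w = w_tilde.copy()
    for t in range(m):
        x, y = sample()
        g = grad_psi(w, x, y) - grad_psi(w_tilde, x, y) + g_hat
        w = w - eta * g

    # step 3: update the anchor; double the batch size (k_s = 2 * k_{s-1})
    w_tilde = w
    k *= 2

print("||w_tilde - w_true|| =", np.linalg.norm(w_tilde - w_true))
```

Every sample is used exactly once, so this is a single pass over the stream, in contrast to the batch SVRG sketch above.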