Tradeoffs in Large Scale Learning: Statistical Accuracy vs. Numerical Precision

Sham M. Kakade
Microsoft Research, New England
The Stochastic Optimization Problem
You have a stochastic optimization problem:
min_w P(w),   where P(w) = E_{ψ∼D}[ψ(w)].
You get samples from D:
ψ_1, ψ_2, . . . , ψ_N
How do you estimate w∗, the minimizer of P?
Your expected error/excess risk/regret is:
E[P(ŵ_N) − P(w∗)]

(ŵ_N is the estimate you produce from the N samples.)
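To fix ideas in code, here is an illustrative sketch (not from the talk; the linear model, the Gaussian design, and the names draw_psi and excess_risk are assumptions, anticipating the least-squares example used later): a sample from D is a loss function ψ, and performance is the excess risk P(ŵ_N) − P(w∗).

import numpy as np

rng = np.random.default_rng(0)
d = 5
w_star = np.ones(d)

def draw_psi():
    # one sample psi ~ D: a squared loss on a fresh (X, Y) pair
    # (assumed instance: Y = w_star . X + standard Gaussian noise)
    x = rng.standard_normal(d)
    y = w_star @ x + rng.standard_normal()
    return lambda w: (y - w @ x) ** 2

def excess_risk(w):
    # P(w) - P(w_star) for this instance; equals ||w - w_star||^2
    # because X is isotropic Gaussian
    return float(np.sum((w - w_star) ** 2))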
Statistics vs. Computation
Stochastic approximation (e.g. stochastic gradient descent):
easy to implement steps (streaming)
1/ε steps for ε accuracy
good enough for statistics?

Numerical analysis (e.g. batch gradient descent):
costly to implement steps
log(1/ε) steps for ε accuracy
extra accuracy is worth it?
What would you do?
Without Computational Constraints...
What would you like to do?
Compute the empirical risk minimizer (ERM):

ŵ_N^ERM ∈ argmin_{w∈S} (1/N) Σ_{i=1}^N ψ_i(w).

(a regularizer can be included)
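As a concrete instance (an illustration, not part of the slides; it anticipates the least-squares losses ψ_i(w) = (Y_i − w · X_i)² used later in the talk), the ERM is just the ordinary least-squares solution, computable in one batch solve:

import numpy as np

def erm_least_squares(X, Y):
    # ERM for squared losses: argmin_w (1/N) sum_i (Y_i - w . X_i)^2,
    # computed by a direct least-squares solve over the full dataset
    w_hat, *_ = np.linalg.lstsq(X, Y, rcond=None)
    return w_hat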
Consider the ratio:
E[P(ŵ_N^ERM) − P(w∗)]  /  E[P(ŵ_N) − P(w∗)].
Can you compete with the ERM on every problem efficiently?
This Talk
Under regularity conditions (strong convexity + smoothness), we give a streaming algorithm which:
single pass algo, with space usage linear in the size of one sample
achieves the statistical rate of the ERM on every problem
(even considering constant factors)
(super)-polynomially decreases the initial error
(also, trivially parallelizable)
Related work: Polyak and Juditsky ’92; Defossez and Bach ’14
Related work
Stochastic Approximation:
Robbins & Monro (1951)
Polyak & Juditsky (1992)
Bottou & Bousquet (2008)
Defossez & Bach (2014)
Aside: Linear convergence for sums of convex functions
Strohmer & Vershynin (2009)
Le Roux, Schmidt, Bach (2012)
Shalev-Shwartz & Zhang (2013)
Stochastic Variance Reduced Gradient (SVRG), Johnson & Zhang (2013)
Talk Outline
1. Statistics: How does the ERM even behave?
2. Computation: Let us just compute the ERM.
3. Computation + Statistics: ?
Statistics: Empirical Risk Minimization
Precisely, what is the error of ŵ_N^ERM?
Define σ² := (1/2) E[ ‖∇ψ(w∗)‖²_{(∇²P(w∗))⁻¹} ].

Theorem: (classical, like Cramer-Rao)
independent samples ψ_1, ψ_2, . . .
ψ is convex (almost surely)
ψ is smooth (almost surely)
∇²P(w∗) exists and is positive definite.
We have:
lim_{N→∞}  E[P(ŵ_N^ERM) − P(w∗)] / (σ²/N)  = 1
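As a sanity check (added here; it assumes a well-specified linear model Y = w∗ · X + ε with E[ε | X] = 0 and Var(ε) = s², and uses the least-squares loss from the next slide):

∇ψ(w∗) = −2(Y − w∗ · X) X = −2εX,   ∇²P(w∗) = 2 E[XX^⊤] =: 2Σ

σ² = (1/2) E[ 4ε² X^⊤(2Σ)⁻¹X ] = s² E[ X^⊤Σ⁻¹X ] = s² tr(Σ⁻¹Σ) = s² d,

so σ²/N = s² d / N, the classical least-squares rate.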
Computation:
Offline optimization for sums of convex functions:
min_w F(w),   where F(w) = (1/N) Σ_i ψ_i(w)
Assume: F is µ-strongly convex and each ψi is L-smooth.
Least squares as a running example.
ψ_i(w) = (Y_i − w · X_i)²,   with w, X_i ∈ R^d, Y_i ∈ R

and so:

μ = λ_min(E[XX^⊤])  and  ‖X‖² ≤ L

effective condition number: κ = L/μ
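A small numpy sketch (illustrative; the function name and the plug-in estimates of μ and L are assumptions, not the talk's definitions) of reading these quantities off a data matrix X with one sample per row:

import numpy as np

def effective_condition_number(X):
    # mu: smallest eigenvalue of the empirical second moment (estimates lambda_min of E[X X^T])
    # L: the largest observed ||x||^2 (a plug-in bound for the smoothness constant)
    second_moment = X.T @ X / X.shape[0]
    mu = np.linalg.eigvalsh(second_moment).min()
    L = np.max(np.sum(X ** 2, axis=1))
    return L / mu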
Stochastic Gradient Descent
F(w) = (1/N) Σ_i ψ_i(w)
SGD update rule: at time t,
sample ψ ∼ uniform{ψ_1, ψ_2, . . . , ψ_N}
w_{t+1} ← w_t − η ∇ψ(w_t)
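A minimal sketch of this update rule (illustrative; grad_psi(i, w) is assumed to return ∇ψ_i(w), and all names are hypothetical):

import numpy as np

def sgd(grad_psi, N, w0, eta, steps, seed=0):
    # plain SGD on F(w) = (1/N) sum_i psi_i(w): at each step, pick a component
    # uniformly at random and step along its (noisy) negative gradient
    rng = np.random.default_rng(seed)
    w = w0
    for _ in range(steps):
        i = rng.integers(N)              # sample psi ~ uniform{psi_1, ..., psi_N}
        w = w - eta * grad_psi(i, w)     # w_{t+1} <- w_t - eta * grad psi(w_t)
    return w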
let w_min be the minimizer of F
even if w_t = w_min, after updating w_{t+1} ≠ w_min
How do you fix this?
Stochastic Variance Reduced Gradient (SVRG)
Iterate:
1. exact GD: at w̃_s, compute the exact (batch) gradient ∇F(w̃_s)
   (note: ∇ψ(w̃_s) − ∇F(w̃_s) is mean zero for random ψ)
2. corrected SGD: initialize w_0 = w̃_s; then repeatedly
   sample ψ ∼ uniform{ψ_1, ψ_2, . . . , ψ_N}
   w_{t+1} ← w_t − η (∇ψ(w_t) − ∇ψ(w̃_s) + ∇F(w̃_s))
3. update and repeat: w̃_{s+1} = w_m (after m steps)
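A sketch of one outer iteration (illustrative; as before, grad_psi(i, w) is assumed to return ∇ψ_i(w)):

import numpy as np

def svrg_epoch(grad_psi, N, w_tilde, eta, m, seed=0):
    # one outer SVRG iteration: exact batch gradient at the anchor w_tilde,
    # then m corrected stochastic steps; the new anchor is the last iterate
    rng = np.random.default_rng(seed)
    full_grad = sum(grad_psi(i, w_tilde) for i in range(N)) / N
    w = w_tilde
    for _ in range(m):
        i = rng.integers(N)
        # subtracting grad psi_i(w_tilde) and adding full_grad leaves the step
        # unbiased for grad F(w), since their difference has mean zero
        w = w - eta * (grad_psi(i, w) - grad_psi(i, w_tilde) + full_grad)
    return w

Iterating w̃_{s+1} = svrg_epoch(grad_psi, N, w̃_s, η, m) is the scheme whose rate appears on the next slide.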
SVRG: the rate
Thm: (Johnson & Zhang, ’13) SVRG has geometric convergence for fixed η:

E[F(w̃_s) − F(w_min)] ≤ e^{−s} · (F(w̃_0) − F(w_min))
must store dataset
still requires many passes
What about the statistical rate?
Back to the stochastic optimization problem
Streaming setting:
1. we get random samples one at a time, ψ ∼ D
2. we care about out-of-sample error on P
3. do we need to handle both a generalization analysis and a computational analysis?
Our algo: Streaming SVRG
always draw fresh samples
1. estimation: at some w̃_s, draw k_s fresh samples ψ_1, ψ_2, . . . and set
   ∇̂P(w̃_s) = (1/k_s) Σ_i ∇ψ_i(w̃_s).
2. update: initialize w_0 = w̃_s; for m steps,
   sample a fresh ψ
   w_{t+1} ← w_t − η (∇ψ(w_t) − ∇ψ(w̃_s) + ∇̂P(w̃_s)).
3. update and repeat: w̃_{s+1} = w_m
single pass computation; memory of O(one parameter)
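A sketch of one outer iteration in this streaming setting (illustrative; draw_grad_fn() is assumed to draw a fresh ψ ∼ D and return the map w ↦ ∇ψ(w); names are hypothetical):

def streaming_svrg_epoch(draw_grad_fn, w_tilde, k, eta, m):
    # step 1 (estimation): average k fresh gradients at the anchor w_tilde
    grad_hat = sum(draw_grad_fn()(w_tilde) for _ in range(k)) / k
    # step 2 (update): m corrected steps, each on a fresh sample; the same
    # sample's gradient is evaluated at both the iterate and the anchor
    w = w_tilde
    for _ in range(m):
        g = draw_grad_fn()
        w = w - eta * (g(w) - g(w_tilde) + grad_hat)
    # step 3: the last iterate becomes the new anchor
    return w

Nothing is stored beyond the current iterate, the anchor, and the averaged gradient estimate.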
Single Pass Least Squares Estimation
Theorem

κ: the effective condition number
choose p > 2
schedule: increasing batch sizes k_s = 2 k_{s−1}; fixed m and η = 2p.
If the total sample size N is larger than a multiple of κ (depending on p), then

E[P(ŵ_N) − P(w∗)] ≤ 1.5 · σ²/N + (P(w̃_0) − P(w∗)) / (N/κ)^p
σ 2 /N is the ERM rate.
the constant 1.5 → 1 with adaptation; (super)-polynomial decay of the initial error
analysis handles estimation error
general case: use self-concordance
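Tying the schedule in the theorem to the epoch sketch above (illustrative; the doubling k_s = 2 k_{s−1} is from the slide, the driver itself and its names are assumptions):

def streaming_svrg(draw_grad_fn, w0, k0, eta, m, epochs):
    # doubling schedule: k_s = 2 * k_{s-1}, with m and eta held fixed;
    # total samples consumed is roughly k0 * 2^epochs + m * epochs
    w_tilde, k = w0, k0
    for _ in range(epochs):
        w_tilde = streaming_svrg_epoch(draw_grad_fn, w_tilde, k, eta, m)
        k *= 2
    return w_tilde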
Thanks!
We can obtain (nearly) the same rate as the ERM in a single pass.
See website for paper.
Collaborators:
R. Frostig
R. Ge
A. Sidford