Challenges in Variational Inference: Optimization, Automation, and Accuracy
Rajesh Ranganath
December 11, 2015

Collaborators
Dave Blei, Alp Kucukelbir, Stephan Mandt, James McInerney, Dustin Tran

Review: Variational Inference
Goal: fit a distribution to the posterior with optimization.
Model: p(x, z); latent variables z; data x.
Variational inference: pick an approximating family q(z; λ) and minimize KL(q || p(z | x)), or equivalently maximize the ELBO
L(λ) = E_q[log p(x, z) − log q(z; λ)].

Problem: Local Optima
[Figure: the ELBO surface over the component means (μ1, μ2) for a mixture of two Gaussians, showing multiple local optima.]
Solution: anneal the likelihood, p(x | z)^(1/T).

Annealing
[Figure: the tempered variational objective over (μ1, μ2) at temperatures (left to right) T = 20, T = 13, T = 8.5, and T = 1.]
Annealing slowly reduces the temperature while optimizing the T-ELBO.

Why only anneal the likelihood?
Modern variational inference methods subsample data. What happens when we anneal the prior and subsample data?
Consider latent Dirichlet allocation:
The prior on the topics is Dirichlet with parameter α = η.
The annealed prior is Dirichlet with parameter α = η/T.
Given a batch of B documents from a corpus of D, the stochastic update for the topics is
λ_{t+1} = (1 − ρ_t) λ_t + ρ_t (α + (D/B) Σ_d ϕ_{dw} W_{dn}).
When a topic is not assigned a word, it is quickly driven to α.
For T = 10 and η = 0.01, exp(Ψ(η/T) − Ψ(η)) ≈ 10^−400; a numerical sketch of this underflow appears after the Questions slide below.
This is the digamma problem, or the zero-forcing problem.

Multicanonical Methods
Annealing introduces a temperature sequence T_1 to T_t, and results are very sensitive to the temperature schedule.
Solution: make T part of the model.
[Graphical model with temperature T, latent variables z, and data x.]
Place a multinomial prior on T.
The variational update needs the normalizer Z(T).
This trades a sequence of parameter settings for an integral computation to renormalize.

Results
[Figure: predictive log likelihood vs. iterations (500 topics, ArXiv data) and the expected or annealed temperature vs. iterations, comparing SVI, annealed VI (AVI, with schedules tA = 0.01, 0.1, 1), and multicanonical VI (MVI).]
Similar results on factorial mixtures.

Questions
Is this procedure always better? Can adding latent variables introduce new optima?
More generally, what are automated model transforms that preserve model semantics while improving computation?
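Before moving on to automation, here is a minimal numerical sketch of the zero-forcing claim above. It assumes NumPy and SciPy (not part of the slides) and simply evaluates the quoted quantity exp(Ψ(η/T) − Ψ(η)) in log space:

import numpy as np
from scipy.special import digamma

# The term quoted on the slide: exp(digamma(eta / T) - digamma(eta)),
# contrasting the annealed prior parameter eta / T with the unannealed eta.
eta, T = 0.01, 10.0
log_term = digamma(eta / T) - digamma(eta)           # about -900 in natural log

print("log10 of the term:", log_term / np.log(10))   # roughly -391, the ~1e-400 order quoted
print("naive float64 value:", np.exp(log_term))      # underflows to 0.0

Double precision bottoms out around 1e−308, so the term is numerically zero: a topic that receives no counts in the current minibatch is forced to essentially zero mass, which is the zero-forcing behavior described above.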
Problem: Automation
Many variational inference methods are black box, but what happens when we try to develop them inside a programming framework?

Specifying a Model in Stan

data {
  int N;              // number of observations
  int x[N];           // discrete-valued observations
}
parameters {
  // latent variable, must be positive
  real<lower=0> theta;
}
model {
  // non-conjugate prior for latent variable
  theta ~ weibull(1.5, 1);
  // likelihood
  for (n in 1:N)
    x[n] ~ poisson(theta);
}

Figure 1: Specifying a simple nonconjugate probability model in Stan. The gradient ∇_θ log p(x, θ) is valid within the support of the prior.

Automatic Differentiation Variational Inference (ADVI)
What does it work for? Differentiable models where the posterior has the same support as the prior.

How does it work?
Step 1: transform the support of the latent variables to the real coordinate space.
[Figure: (a) the latent variable space is R_{>0}; (a→b) a transformation T maps it to R; (b) the variational approximation is a Gaussian in real coordinate space, plotted against the prior and posterior.]
There are many ways to differentiably transform the support of a model to the real coordinate space, and the choice of transformation directly affects the shape of the variational approximation. After this automatic transformation step, the variational distribution can be chosen independently of the model.

Step 2: elliptical standardization.
[Figure: (a) the variational approximation is a Gaussian in real coordinate space; (a→b) S_λ absorbs the parameters of the Gaussian; (b) the ELBO is maximized in the standardized space, with a fixed standard Gaussian approximation.]

Step 3: use a Monte Carlo estimate of the reparameterization gradient to optimize the ELBO,
∇_λ L(λ) = E_{s(ϵ)}[∇_z log p(x, z) ∇_λ z(ϵ)] + ∇_λ H[q].
(A minimal sketch of this estimator appears after the results below.)

ADVI: Does it work?
[Figure 10: hierarchical generalized linear models; average log predictive vs. seconds for (a) linear regression and (b) hierarchical logistic regression, comparing ADVI (M = 1), ADVI (M = 10), NUTS, and HMC.]
Beyond these, consider two nonconjugate non-negative matrix factorization models: a constrained Gamma Poisson model and a Dirichlet Exponential Poisson model, both fit to the Frey Face dataset of 1956 frames (28 × 20 pixels) of facial expressions extracted from a video sequence.
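Returning to the Monte Carlo gradient above: the following is a minimal sketch of the estimator for a factorized Gaussian q(z; μ, ω) with standard deviations exp(ω) on the unconstrained space. It is only an illustration, not the Stan implementation; grad_log_joint is a hypothetical stand-in for ∇_z log p(x, z) (in ADVI the log Jacobian of the support transformation would be folded into it).

import numpy as np

def elbo_gradient(grad_log_joint, mu, omega, num_samples=10, rng=None):
    """Monte Carlo reparameterization gradient of the ELBO for q = N(mu, diag(exp(omega))^2)."""
    rng = np.random.default_rng() if rng is None else rng
    sigma = np.exp(omega)
    grad_mu = np.zeros_like(mu)
    grad_omega = np.zeros_like(omega)
    for _ in range(num_samples):
        eps = rng.standard_normal(mu.shape)   # eps ~ s(eps) = N(0, I), the standardized space
        z = mu + sigma * eps                  # z(eps): invert the standardization
        g = grad_log_joint(z)                 # grad_z log p(x, z)
        grad_mu += g                          # dz/dmu = 1
        grad_omega += g * eps * sigma         # dz/domega = eps * exp(omega)
    grad_mu /= num_samples
    grad_omega /= num_samples
    grad_omega += 1.0                         # gradient of the Gaussian entropy H[q] w.r.t. omega
    return grad_mu, grad_omega

# Toy check: with log p(x, z) = -0.5 * ||z||^2, ascent drives mu toward 0 and exp(omega) toward 1.
g_mu, g_omega = elbo_gradient(lambda z: -z, mu=np.ones(2), omega=np.zeros(2))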
The Role of Transforms
There exist multiple maps from the constrained to the unconstrained space. For example, from R_{>0} → R:
T1: log(x) and T2: log(exp(x) − 1).
[Figure 9: ADVI approximations to Gamma(1, 2), Gamma(2.5, 4.2), and Gamma(10, 10) densities under the two transformations, plotted against the true posterior.]
The optimal transform can be written as ϕ⁻¹(P(z)), where P is the posterior CDF and ϕ the standard normal CDF.

What's the value of automation?
Studying multiple models.
[Figure 11: non-negative matrix factorization of the Frey Faces dataset; average log predictive vs. seconds for the Gamma Poisson and Dirichlet Exponential models (ADVI vs. NUTS), along with the learned factors for each model.]

Questions
Can you learn to initialize from the Stan program?
Is there a lightweight way to choose hyperparameters?
Can we expand the class of models, say to those where the posterior support doesn't match the prior's?

Problem: Accuracy
Consider the stochastic volatility model
y_t ∼ N(0, exp(h_t / 2)),
where the volatility itself follows an autoregressive process
h_t ∼ N(μ + ϕ(h_{t−1} − μ), σ),
with initialization
h_1 ∼ N(μ, σ / √(1 − ϕ²)).
We posit the following priors for the latent variables:
μ ∼ Cauchy(0, 10), ϕ ∼ Unif(−1, 1), and σ ∼ LogNormal(0, 10).

Mean-Field Variational Bayes
[Figure 6: comparison of posterior mean estimates of the volatility h_t over time, sampling vs. mean-field.]
An autoregressive process underlies the latent volatility, so we expect posterior estimates of the volatility to be correlated, especially when the volatility deviates from the mean.

Posit a richer approximation
Instead of a factorized normal, consider a multivariate normal approximation on the unconstrained model (see the sketch below).
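A minimal sketch of the difference between the two approximations, under the usual assumption (not stated on the slides) that the multivariate normal is parameterized by a lower-triangular Cholesky factor L:

import numpy as np

rng = np.random.default_rng(0)
dim = 3
mu = np.zeros(dim)

# Mean-field: independent coordinates with scales exp(omega); correlations are impossible.
omega = np.zeros(dim)
z_mean_field = mu + np.exp(omega) * rng.standard_normal(dim)

# Full-rank: z = mu + L @ eps has covariance L @ L.T, so it can capture correlations
# between latent variables, e.g. between neighboring volatilities h_t.
L = np.array([[1.0, 0.0, 0.0],
              [0.8, 0.6, 0.0],
              [0.5, 0.4, 0.7]])
z_full_rank = mu + L @ rng.standard_normal(dim)

# The Gaussian entropy term of the ELBO only needs the diagonal of L.
entropy = 0.5 * dim * (1.0 + np.log(2.0 * np.pi)) + np.sum(np.log(np.diag(L)))

The reparameterization-gradient estimator sketched earlier carries over, with the gradient with respect to L given by the outer product of ∇_z log p(x, z) and ε.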
Correlations Help Expectations
[Figure 6: comparison of posterior mean estimates of the volatility h_t over time: sampling, mean-field, and full-rank.]
Fewer iterations are needed with the un-factorized approximation.

Variational Models
Finding good variational distributions is a modeling problem.
[Graphical model of a variational model over variational data D = {(s, t)}: latent input ξ and functions f_i feeding the latent variables z_i.]
ξ ∼ Normal(0, I), f_i ∼ GP(0, K)|_{D_i}.

Questions
Can you choose the dependence structure based on the property of the posterior you care about?
What are other distances between probability distributions amenable to finding good posterior approximations?

Thanks