Manifold MCMC for Mixture Models
Vassilios Stathopoulos and Mark A. Girolami
Department of Statistical Science, University College London
April 2011

Outline
1 Manifold MCMC
  Metropolis-Hastings
  Metropolis Adjusted Langevin Algorithm
  Manifold Metropolis Adjusted Langevin Algorithm
  Manifold Hamiltonian Monte Carlo
2 Finite Gaussian Mixture Models
  Model re-parameterisation
  Approximations of the Metric Tensor
  Examples
3 Ad-Mixtures
  Model re-parameterisation
  Metropolis within Gibbs scheme
  Example
4 Conclusions and Discussion

Metropolis-Hastings
Aim: sample from an intricate distribution $\pi(\theta)$; in the Bayesian setting $\pi(\theta) \propto p(\theta)\, p(x|\theta)$.

Metropolis-Hastings Algorithm
At initialisation $t = 0$, set $\theta^{0}$ arbitrarily.
At each iteration $t \geq 1$:
1 sample $\theta^{*}$ from $q(\cdot \,|\, \theta^{t-1})$
2 with probability
  $\min\!\left(1,\ \frac{\pi(\theta^{*})\, q(\theta^{t-1} | \theta^{*})}{\pi(\theta^{t-1})\, q(\theta^{*} | \theta^{t-1})}\right)$
  accept: set $\theta^{t} = \theta^{*}$; else reject: set $\theta^{t} = \theta^{t-1}$.
In the simplest case $q(\theta^{*} | \theta^{t-1}) = \mathcal{N}(\theta^{*} | \theta^{t-1}, I)$, a random walk.

Metropolis Adjusted Langevin Algorithm
• [Roberts & Tweedie 1996]: for $\theta \in \mathbb{R}^{D}$ with density $\pi(\theta)$ and $L(\theta) \equiv \log \pi(\theta)$, define the Langevin diffusion
  $d\theta(t) = \tfrac{1}{2} \nabla_{\theta} L(\theta(t))\, dt + db(t)$
• First-order Euler-Maruyama discretisation of the diffusion with step size $\epsilon$:
  $\theta(\tau + \epsilon) = \theta(\tau) + \tfrac{\epsilon^{2}}{2} \nabla_{\theta} L(\theta(\tau)) + \epsilon\, z(\tau), \qquad z(\tau) \sim \mathcal{N}(0, I)$
• Proposal $q(\theta^{*} | \theta) = \mathcal{N}(\theta^{*} | \mu(\theta, \epsilon), \epsilon^{2} I)$ with $\mu(\theta, \epsilon) = \theta + \tfrac{\epsilon^{2}}{2} \nabla_{\theta} L(\theta)$
• The isotropic diffusion is inefficient; employ pre-conditioning
  $\theta^{*} = \theta + \tfrac{\epsilon^{2}}{2} M \nabla_{\theta} L(\theta) + \epsilon \sqrt{M}\, z$
• How to set $M$ systematically? Tuning differs between the transient and stationary phases.

Manifold Metropolis Adjusted Langevin
[Girolami & Calderhead 2011]: the discretised Langevin diffusion on the manifold defines the proposal mechanism
  $\theta^{*}_{d} = \theta_{d} + \tfrac{\epsilon^{2}}{2} \big(G^{-1}(\theta) \nabla_{\theta} L(\theta)\big)_{d} - \epsilon^{2} \sum_{i,j}^{D} \big(G^{-1}(\theta)\big)_{i,j} \Gamma^{d}_{i,j} + \epsilon \big(\sqrt{G^{-1}(\theta)}\, z\big)_{d}$
  $q(\theta^{*} | \theta^{t-1}) = \mathcal{N}\big(\theta^{*} \,\big|\, \mu(\theta^{t-1}, \epsilon),\ \epsilon^{2} G^{-1}(\theta^{t-1})\big)$
where
  $\mu(\theta, \epsilon)_{d} = \theta_{d} + \tfrac{\epsilon^{2}}{2} \big(G^{-1}(\theta) \nabla_{\theta} L(\theta)\big)_{d} - \epsilon^{2} \sum_{i,j}^{D} \big(G^{-1}(\theta)\big)_{i,j} \Gamma^{d}_{i,j}$
The metric tensor $G(\theta) = \mathrm{cov}\big(\nabla_{\theta} L(\theta)\big)$ is the expected Fisher information (FI) and the $\Gamma^{d}_{i,j}$ are the Christoffel symbols
  $\Gamma^{d}_{i,j} = \tfrac{1}{2} \sum_{m} \big(G^{-1}(\theta)\big)_{d,m} \left( \frac{\partial G(\theta)_{m,i}}{\partial \theta_{j}} + \frac{\partial G(\theta)_{m,j}}{\partial \theta_{i}} - \frac{\partial G(\theta)_{i,j}}{\partial \theta_{m}} \right)$
The simplified MMALA assumes that the metric is not changing locally, i.e. it drops the Christoffel-symbol terms from the proposal mean.
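To make the simplified MMALA proposal concrete, here is a minimal sketch of a single transition in Python/NumPy (not part of the original slides); `log_post`, `grad_log_post` and `metric` are assumed user-supplied callables for the log target, its gradient and the metric tensor $G(\theta)$.

```python
import numpy as np

def simplified_mmala_step(theta, log_post, grad_log_post, metric, eps, rng):
    """One simplified MMALA transition: the Christoffel-symbol terms are dropped,
    i.e. the metric is treated as locally constant."""

    def mean_and_cov(t):
        G = metric(t)                          # expected / empirical Fisher information
        G_inv = np.linalg.inv(G)
        mu = t + 0.5 * eps**2 * G_inv @ grad_log_post(t)
        return mu, eps**2 * G_inv

    def log_q(x, mu, cov):
        # log N(x | mu, cov) up to the dimension-dependent constant (cancels in the ratio)
        diff = x - mu
        _, logdet = np.linalg.slogdet(cov)
        return -0.5 * (logdet + diff @ np.linalg.solve(cov, diff))

    mu_fwd, cov_fwd = mean_and_cov(theta)
    prop = rng.multivariate_normal(mu_fwd, cov_fwd)
    mu_rev, cov_rev = mean_and_cov(prop)       # reverse proposal for the MH correction

    log_alpha = (log_post(prop) - log_post(theta)
                 + log_q(theta, mu_rev, cov_rev) - log_q(prop, mu_fwd, cov_fwd))
    if np.log(rng.uniform()) < log_alpha:
        return prop, True
    return theta, False
```

Because the proposal covariance depends on the current position, the Metropolis-Hastings correction must evaluate the reverse proposal density at the proposed point, as done above.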
Manifold Hamiltonian Monte Carlo
[Girolami & Calderhead 2011]: design a proposal mechanism that follows direct paths on the manifold - geodesics.
Introduce an auxiliary variable $p \sim \mathcal{N}(0, G(\theta))$. The Hamiltonian defined on the Riemann manifold using the negative joint log likelihood is
  $H(\theta, p) = \underbrace{-L(\theta) + \tfrac{1}{2} \log\big((2\pi)^{D} |G(\theta)|\big)}_{\text{potential energy}} + \underbrace{\tfrac{1}{2} p^{\mathsf{T}} G(\theta)^{-1} p}_{\text{kinetic energy}}$
Sample using a Metropolis within Gibbs scheme
  $p^{t} \sim \mathcal{N}(0, G(\theta^{t-1})), \qquad \theta^{t} \sim p(\theta | p^{t})$
The second step involves integrating
  $\frac{d\theta}{dt} = \frac{\partial}{\partial p} H(\theta, p), \qquad \frac{dp}{dt} = -\frac{\partial}{\partial \theta} H(\theta, p)$
using a time-reversible, volume-preserving numerical integrator such as the generalised leapfrog.

Finite Gaussian Mixture Models
  $p(x | \theta) = \sum_{k=1}^{K} \pi_{k}\, \mathcal{N}(x | \mu_{k}, \Sigma_{k})$
• A Gibbs sampler is straightforward using a data-augmentation scheme and conjugate priors [McLachlan & Peel 2000].
• Poor mixing, especially when components overlap [Celeux et al. 2000].
• Attracted by local modes of the posterior [Marin et al. 2005].
• Both problems are related to the conditional dependence between the latent variables and the model parameters.
• Simulated tempering with Metropolis-Hastings base kernels [Celeux et al. 2000, Jasra et al. 2005].
• Manifold MCMC algorithms can be used as base kernels for simulated tempering and population MCMC.

Re-parameterisation & priors
• Univariate mixtures of Gaussians:
  $\pi_{k} = \frac{e^{\alpha_{k}}}{\sum_{k'=1}^{K} e^{\alpha_{k'}}}, \qquad p(\alpha_{k}) = \mathcal{G}(e^{\alpha_{k}} | \lambda, 1)\, e^{\alpha_{k}}$
  $\sigma_{k}^{2} = e^{\gamma_{k}}, \qquad p(\gamma_{k}) = \mathcal{IG}(\sigma_{k}^{2} | b, c)\, e^{\gamma_{k}}$
• Multivariate mixtures of Gaussians [Pinheiro & Bates 1996]: $\Sigma_{k} = L_{k} L_{k}^{\mathsf{T}}$ with
  $L_{k} = \begin{pmatrix} e^{(B_{k})_{1,1}} & 0 & \cdots & 0 \\ (B_{k})_{2,1} & e^{(B_{k})_{2,2}} & \cdots & 0 \\ \vdots & & \ddots & \vdots \\ (B_{k})_{D,1} & \cdots & (B_{k})_{D,D-1} & e^{(B_{k})_{D,D}} \end{pmatrix}$
  $p(B_{k}) = \prod_{d=1}^{D} \mathcal{G}\big(e^{(B_{k})_{d,d}} | \omega, \psi\big)\, e^{(B_{k})_{d,d}} \prod_{d' \neq d} \mathcal{N}\big((B_{k})_{d',d} | 0, \beta\big)$

Metric Tensor
• For mixture models the expected FI is not available in closed form.
• Empirical FI using a finite-sample estimate [McLachlan & Peel 2000]:
  $G(\theta) = \tfrac{1}{N} S^{\mathsf{T}} S - \bar{s}\, \bar{s}^{\mathsf{T}} \ \xrightarrow{N \to \infty}\ \mathrm{cov}\big(\nabla_{\theta} L(\theta)\big)$
  $\frac{\partial G(\theta)}{\partial \theta_{d}} = \frac{1}{N} \left( \frac{\partial S^{\mathsf{T}}}{\partial \theta_{d}} S + S^{\mathsf{T}} \frac{\partial S}{\partial \theta_{d}} \right) - \frac{\partial \bar{s}}{\partial \theta_{d}} \bar{s}^{\mathsf{T}} - \bar{s}\, \frac{\partial \bar{s}^{\mathsf{T}}}{\partial \theta_{d}}$
  where $S_{i,d} = \frac{\partial \log p(x_{i} | \theta)}{\partial \theta_{d}}$ and $\bar{s} = \tfrac{1}{N} \sum_{n=1}^{N} S_{n,\cdot}$.
• Alternative metric between densities [Basu et al. 1998]:
  $L_{2} = \int \big| p(x | \theta) - p(x | \theta + \delta\theta) \big|^{2}\, dx, \qquad G(\theta) = \int \nabla_{\theta} p(x | \theta)\, \nabla_{\theta}^{\mathsf{T}} p(x | \theta)\, dx$
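As an illustration of the finite-sample estimate above, the following Python/NumPy sketch (not part of the slides) builds $G(\theta)$ from a matrix of per-observation scores; the Gaussian location check at the end, where the expected FI is $1/\sigma^{2}$, is an illustrative addition used only to sanity-check the estimator.

```python
import numpy as np

def empirical_fisher(score_matrix):
    """Empirical Fisher information from per-observation score vectors.

    score_matrix: (N, D) array with S[i, d] = d log p(x_i | theta) / d theta_d.
    Returns G = S^T S / N - s_bar s_bar^T, which approaches cov(grad log p) as N grows.
    """
    S = np.asarray(score_matrix)
    N = S.shape[0]
    s_bar = S.mean(axis=0)
    return S.T @ S / N - np.outer(s_bar, s_bar)

# Quick check on a Gaussian location model, where the expected FI is 1 / sigma^2.
rng = np.random.default_rng(0)
sigma = 2.0
x = rng.normal(loc=1.0, scale=sigma, size=50_000)
scores = ((x - 1.0) / sigma**2)[:, None]     # d/d mu of log N(x | mu, sigma^2)
print(empirical_fisher(scores))              # close to [[1 / sigma^2]] = [[0.25]]
```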
Example
  $p(x | \mu, \sigma^{2}) = 0.7\, \mathcal{N}(x | 0, \sigma^{2}) + 0.3\, \mathcal{N}(x | \mu, \sigma^{2})$
Figure: Gradient field and inverse metric tensor over $(\mu, \ln\sigma)$. Arrows correspond to the gradients and ellipses to the inverse metric tensor; dashed lines are isocontours of the joint log density.

Figure: Comparison of MALA (left), simplified mMALA (middle) and mMALA (right) convergence paths and autocorrelation plots. Autocorrelation plots are from the stationary chains, i.e. once the chains have converged to the stationary distribution.

Figure: Comparison of HMC (left), RMHMC (middle) and Gibbs (right) convergence paths and autocorrelation plots. Autocorrelation plots are from the stationary chains, i.e. once the chains have converged to the stationary distribution.

Experiments with univariate mixtures
Figure: Densities used to generate the synthetic datasets (Kurtotic, Bimodal, Skewed, Trimodal, Claw).
[Bar charts: ESS and ESS/time for MMALA, simplified MMALA, RMHMC, Gibbs and HMC on each dataset, under both the empirical FI metric and the L2 metric, together with the corresponding running times.]

Experiments with bi-variate mixtures
Table: Densities used to generate the synthetic datasets.
  Density 1: $\pi_{1} = 0.8$, $\pi_{2} = 0.2$, $\mu_{1} = \mu_{2} = [0, 0]^{\mathsf{T}}$,
    $\Sigma_{1} = \begin{pmatrix} 1 & 0.5 \\ 0.5 & 1 \end{pmatrix}$, $\Sigma_{2} = \begin{pmatrix} 0.1 & 0 \\ 0 & 0.01 \end{pmatrix}$
  Density 2: $\pi_{1} = 0.5$, $\pi_{2} = 0.5$, $\mu_{1} = [0, 0]^{\mathsf{T}}$, $\mu_{2} = [0.8, 0.8]^{\mathsf{T}}$,
    $\Sigma_{1} = \begin{pmatrix} 1 & 0.8 \\ 0.8 & 1 \end{pmatrix}$, $\Sigma_{2} = \begin{pmatrix} 1 & -0.8 \\ -0.8 & 1 \end{pmatrix}$
  Density 3: $\pi_{1} = 0.55$, $\pi_{2} = 0.225$, $\pi_{3} = 0.225$, $\mu_{1} = [0, 0]^{\mathsf{T}}$, $\mu_{2} = [1.8, 0]^{\mathsf{T}}$, $\mu_{3} = [-1.8, 0]^{\mathsf{T}}$,
    $\Sigma_{1} = \begin{pmatrix} 1 & 0 \\ 0 & 0.5 \end{pmatrix}$, $\Sigma_{2} = \Sigma_{3} = \begin{pmatrix} 0.2 & 0 \\ 0 & 1 \end{pmatrix}$
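The slides report results on data simulated from these densities but do not show the simulation step itself. As a minimal sketch (the sample size and seed below are illustrative assumptions, not values from the slides), Density 2 from the table above could be generated as follows.

```python
import numpy as np

def sample_gaussian_mixture(weights, means, covs, n, rng):
    """Draw n samples from a finite Gaussian mixture (illustrative data generation)."""
    weights = np.asarray(weights)
    comps = rng.choice(len(weights), size=n, p=weights)   # latent component allocations
    return np.stack([rng.multivariate_normal(means[k], covs[k]) for k in comps])

# Density 2: equal weights, means at the origin and at (0.8, 0.8),
# correlations of opposite sign in the two components.
rng = np.random.default_rng(1)
x = sample_gaussian_mixture(
    weights=[0.5, 0.5],
    means=[np.zeros(2), np.array([0.8, 0.8])],
    covs=[np.array([[1.0, 0.8], [0.8, 1.0]]),
          np.array([[1.0, -0.8], [-0.8, 1.0]])],
    n=2000, rng=rng)
```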
Experiments with bi-variate mixtures
[Bar charts: ESS and ESS/time for MMALA, simplified MMALA, RMHMC, Gibbs and HMC on the Kurtotic, Cross and Trimodal bivariate datasets. Table: wall-clock running times for each sampler on each dataset.]
Computational complexity:
• manifold MCMC: $\mathcal{O}\big( (K \times [D + D(D+1)/2 + 1])^{3} \big)$
• Gibbs sampler: $\mathcal{O}\big( K\, [D + (D(D+1)/2)^{3} + 1] \big)$

Ad-Mixtures
• Popular models in population genetics.
• The genotype of an individual is an ad-mixture of K unknown sub-populations.
• Goal: given the genotypes of a sample of individuals at specific loci, infer their ancestry proportions and the sub-population allele frequencies.
• Applications in document analysis:
  • a document is an ad-mixture of K unknown topics;
  • given a collection of documents, infer the topic distributions and the topic proportions for each document.

Ad-Mixtures
• $\theta_{n}$: mixing proportions for the $n$th observation.
• $\phi_{k}$: parameters of the $k$th sub-population/topic.
The generative process for ad-mixtures is
  $\theta_{n} \sim \mathcal{D}_{K}(\alpha), \quad n \in \{1, \ldots, N\}$
  $M_{n} \sim \mathcal{P}(\lambda), \quad n \in \{1, \ldots, N\}$
  $\phi_{k} \sim \mathcal{D}_{T}(\beta), \quad k \in \{1, \ldots, K\}$
  $x_{n} \sim \mathcal{M}(\theta_{n} \Phi, M_{n}), \quad n \in \{1, \ldots, N\}$
The likelihood can be written as
  $p(X | \Theta, \Phi) = \prod_{n} \prod_{t} \left( \sum_{k} \theta_{n,k}\, \phi_{k,t} \right)^{x_{n,t}}$
Re-parameterisation:
  $\theta_{n,k} = \frac{e^{\gamma_{n,k}}}{\sum_{k'} e^{\gamma_{n,k'}}}, \qquad \phi_{k,t} = \frac{e^{\psi_{k,t}}}{\sum_{t'} e^{\psi_{k,t'}}}$

Manifold MCMC for Ad-Mixtures
• The number of parameters increases linearly with the number of observations: $K \times (N + T)$.
• Exploit conditional independence:
  $\mathbb{E}\left[ \frac{\partial L(\Gamma, \Psi)}{\partial \gamma_{n,k}}\, \frac{\partial L(\Gamma, \Psi)}{\partial \gamma_{i,j}} \right] = 0, \quad \forall\, i \neq n$
• Metropolis within Gibbs scheme:
  $\gamma_{n} \,|\, \Gamma_{/n}, \Psi \sim p(\gamma_{n} | \Gamma_{/n}, \Psi) \ \ \forall n, \qquad \Psi \,|\, \Gamma \sim p(\Psi | \Gamma)$
• Each step is a simplified MMALA sampler.
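The synthetic-data example on the next slide is drawn from this generative process. As a hedged illustration (not the authors' code), it can be simulated as below; the vocabulary/marker size `T` and Poisson rate `lam` are illustrative assumptions, since the slides only state $N = 200$, $K = 6$, $\alpha = 0.5$, $\beta = 0.1$ and an average observation length of about 25.

```python
import numpy as np

def simulate_admixture(N, K, T, alpha, beta, lam, rng):
    """Simulate data from the ad-mixture generative process sketched above."""
    theta = rng.dirichlet(np.full(K, alpha), size=N)    # theta_n ~ D_K(alpha)
    phi = rng.dirichlet(np.full(T, beta), size=K)       # phi_k  ~ D_T(beta)
    M = rng.poisson(lam, size=N)                        # observation lengths, M_n ~ P(lambda)
    X = np.empty((N, T), dtype=int)
    for n in range(N):
        p = theta[n] @ phi                              # mixed term/allele probabilities
        X[n] = rng.multinomial(M[n], p / p.sum())       # guard against round-off in p
    return theta, phi, X

# T = 50 and lam = 25 are illustrative; N, K, alpha, beta follow the slide's example.
rng = np.random.default_rng(0)
theta, phi, X = simulate_admixture(N=200, K=6, T=50, alpha=0.5, beta=0.1,
                                   lam=25, rng=rng)
```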
Synthetic data example
• 200 observations with an average length of 25 markers/terms, generated from 6 sub-populations, with α = 0.5, β = 0.1.
• Effective sample size (mean ± std) for θ and φ:

  Sampler      | ESS θ         | ESS φ
  M-H          | 6.52 ± 2.4    | 4.83 ± 1.8
  Simp. MMALA  | 9.1 ± 2.2     | 8.25 ± 2.6
  HMC          | 10.32 ± 2.9   | 7.12 ± 2.2
  Gibbs        | 55.86 ± 10.1  | 20.45 ± 5.7

Figure: Trace plots of $\phi_{1,1}$ (top) and $\phi_{1,4}$ (bottom) over 1000 iterations for the Gibbs sampler (left) and the simplified MMALA sampler (right).

Conclusions
• Approximations of the expected FI.
• The observed FI is effective but computationally expensive.
• Simplified MMALA is as effective as MMALA.
• Problematic on skewed posteriors.
• The algorithms scale as $\mathcal{O}(D^{3})$, with D the number of parameters.
• Derivations are cumbersome; consider automatic differentiation.

Discussion points
• What is an appropriate geometry for finite mixtures?
• What geometry suits mixtures with an unknown number of components, and how can it be used in MCMC?
• How can the computational cost be overcome?
• Are there useful alternative re-parameterisations?

References
A. Basu, I. R. Harris, N. L. Hjort, and M. C. Jones. Robust and efficient estimation by minimising a density power divergence. Biometrika, 85(3):549–559, 1998.
G. Celeux, M. Hurn, and C. P. Robert. Computational and inferential difficulties with mixture posterior distributions. Journal of the American Statistical Association, 95(451):957–970, 2000.
M. Girolami and B. Calderhead. Riemann manifold Langevin and Hamiltonian Monte Carlo methods. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 73(2):123–214, 2011.
A. Jasra, C. Holmes, and D. A. Stephens. MCMC and the label switching problem in Bayesian mixture models. Statistical Science, 20:50–67, 2005.
J.-M. Marin, K. Mengersen, and C. P. Robert. Bayesian modelling and inference on mixtures of distributions. Handbook of Statistics, vol. 25, 2005.
G. McLachlan and D. Peel. Finite Mixture Models. Wiley Series in Probability and Statistics. Wiley-Interscience, 2000.
J. C. Pinheiro and D. M. Bates. Unconstrained parametrizations for variance-covariance matrices. Statistics and Computing, 6(3):289–296, 1996.
G. O. Roberts and R. L. Tweedie. Exponential convergence of Langevin distributions and their discrete approximations. Bernoulli, 2(4):341–363, 1996.