Exact and approximate inference in probabilistic graphical models
Kevin Murphy, MIT CSAIL / UBC CS & Stats
www.ai.mit.edu/~murphyk/AAAI04
AAAI 2004 tutorial  SP2-1

Outline
• Introduction
• Exact inference
• Approximate inference  SP2-2

Outline
• Introduction
  – What are graphical models?
  – What is inference?
• Exact inference
• Approximate inference  SP2-3

Probabilistic graphical models
• Graphical models are a family of probabilistic models; they come in two main flavours:
  – Undirected (Markov random fields, MRFs)
  – Directed (Bayesian networks)  SP2-4

Bayesian networks
• Directed acyclic graph (DAG)
  – Nodes = random variables
  – Edges = direct influence ("causation")
• Xi ⊥ Xancestors | Xparents
  – e.g., C ⊥ {R, B, E} | A
• Simplifies the chain rule by using conditional independencies
[Figure: DAG with Earthquake and Burglary as parents of Alarm, Earthquake → Radio, Alarm → Call]
Pearl, 1988  SP2-5

Conditional probability distributions (CPDs)
• Each node specifies a distribution over its values given its parents' values, P(Xi | XPai)
• A full joint table needs 2^5 − 1 = 31 parameters; the BN needs only 10
• Example CPD P(A | E, B):
    E   B    P(A=t)  P(A=f)
    e   b     0.9     0.1
    e   ¬b    0.2     0.8
    ¬e  b     0.9     0.1
    ¬e  ¬b    0.01    0.99
Pearl, 1988  SP2-6

Example BN: Hidden Markov Model (HMM)
• Hidden states, e.g. words; observations, e.g. sounds
[Figure: HMM with hidden chain X1 → X2 → X3 and observations Y1, Y2, Y3]  SP2-7

CPDs for HMMs
• A = state transition matrix, P(Xt+1 | Xt)
• B = observation matrix, P(Yt | Xt)
• Initial state distribution, P(X1)
• Parameter tying: the same A and B are shared across all time slices
[Figure: HMM with A labelling the X → X edges and B labelling the X → Y edges]  SP2-8

Markov Random Fields (MRFs)
• Undirected graph
• Xi ⊥ Xrest | Xnbrs
• Each clique c has a potential function ψc
• Hammersley-Clifford theorem: P(X) = (1/Z) ∏c ψc(Xc), where the normalization constant (partition function) is Z = Σx ∏c ψc(xc)  (see the code sketch below)
[Figure: undirected graph on X1,...,X5]  SP2-9

Potentials for MRFs
• One potential per maximal clique (e.g., ψ123, ψ34, ψ35), or one potential per edge (e.g., ψ12, ψ13, ψ23, ψ34, ψ35)
[Figure: the same graph annotated with maximal-clique potentials vs pairwise (edge) potentials]  SP2-10

Example MRF: Ising/Potts model
• Hidden grid of labels x, observed image y
• Pairwise potentials: compatibility with neighbours (parameter tying across the grid)
• Local evidence: compatibility with the image
[Figure: grid-structured MRF with hidden nodes x coupled to observed pixels y]  SP2-11

Conditional Random Field (CRF)
• As above, but the potentials are conditioned on the image
• Pairwise potentials: compatibility with neighbours (parameter tying)
• Local evidence: compatibility with the image
Lafferty01, Kumar03, etc.  SP2-12

Directed vs undirected models
• Undirected: separation ⇒ conditional independence; parameter learning is hard
• Directed: d-separation ⇒ conditional independence; parameter learning is easy
• Inference is the same!
• A directed model is converted to an undirected one by moralization (marry the parents, drop the arrows)
[Figure: a DAG on X1,...,X5 and its moralized undirected version]  SP2-13

Factor graphs
• The same model can be drawn as a Bayes net, a Markov net, a pairwise Markov net, or a factor graph
• A factor graph is a bipartite graph of variable nodes and factor nodes
[Figure: a 5-node model (X1,...,X5) shown in all four forms]
Kschischang01  SP2-14

Outline
• Introduction
  – What are graphical models?
  – What is inference?
• Exact inference
• Approximate inference  SP2-15
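Before turning to inference, here is a minimal sketch (not from the slides) that makes the Hammersley-Clifford factorization on slide SP2-9 concrete: it builds a tiny pairwise MRF on the five-node graph above, computes the partition function Z by brute-force enumeration, and reads off a marginal. The potential values are made up purely for illustration.

```python
# Minimal sketch (not from the slides): a pairwise MRF on the 5-node graph
# X1-X2, X1-X3, X2-X3, X3-X4, X3-X5 with binary variables.
# The potential values below are made up purely for illustration.
import itertools
import numpy as np

edges = [(0, 1), (0, 2), (1, 2), (2, 3), (2, 4)]
# One 2x2 potential table per edge: psi[xi, xj]
psi = {e: np.array([[2.0, 1.0], [1.0, 3.0]]) for e in edges}

def unnormalized(x):
    """Product of edge potentials, prod_c psi_c(x_c)."""
    p = 1.0
    for (i, j) in edges:
        p *= psi[(i, j)][x[i], x[j]]
    return p

# Partition function Z = sum_x prod_c psi_c(x_c)  (brute force, O(K^N))
states = list(itertools.product([0, 1], repeat=5))
Z = sum(unnormalized(x) for x in states)

# Marginal P(X3 = 1) by summing the normalized joint
p_x3 = sum(unnormalized(x) for x in states if x[2] == 1) / Z
print(f"Z = {Z:.1f}, P(X3=1) = {p_x3:.3f}")
```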
Inference (state estimation)
• Observe the evidence Call = t and compute posterior marginals over the hidden nodes:
  P(B=t | C=t) = 0.7,  P(E=t | C=t) = 0.1
[Figure: the Earthquake/Burglary/Alarm/Radio/Call network with bar charts of the posterior marginals]  SP2-16 to SP2-18

Inference
• Now also observe Radio = t:
  P(B=t | C=t, R=t) = 0.1,  P(E=t | C=t, R=t) = 0.97
• The radio report makes an earthquake the likely cause of the alarm, so the posterior probability of burglary drops: the explaining-away effect (see the enumeration sketch below).
• "Probability theory is nothing but common sense reduced to calculation" – Pierre Simon Laplace  SP2-19 to SP2-21

Inference tasks
• Posterior probabilities of Query variables given Evidence
  – Marginalize out the Nuisance variables
  – Sum-product
• Most Probable Explanation (MPE) / Viterbi
  – Max-product
• Marginal Maximum A Posteriori (MAP)
  – Max-sum-product  SP2-22

Causal vs diagnostic reasoning
• Sometimes it is easier to specify P(effect | cause) than P(cause | effect) [stable mechanism]
• Use Bayes' rule to invert the causal model
[Figure: causal model from hidden diseases H to visible symptoms v]  SP2-23

Applications of Bayesian inference
  Domain           Visible data v     Hidden cause h
  Speech rec.      Acoustics          Words
  Image proc.      Noisy image        Clean image
  Info. coding     Received msg       Transmitted msg
  Medical diag.    Symptoms           Diseases
  Fault diagnosis  Outliers           Faulty sensor
  Biology          Current species    Evolutionary tree
SP2-24
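A brute-force check of the explaining-away effect above, as a minimal sketch. The CPT for P(A | E, B) follows the table on slide SP2-6; the priors P(B), P(E) and the CPDs P(R | E), P(C | A) are not given in the deck, so the values below are assumptions, and the resulting numbers will not match the 0.7 / 0.1 / 0.97 shown on the slides. The qualitative effect, that observing R = t lowers the posterior on B, still appears.

```python
# Minimal sketch: enumeration in the Earthquake/Burglary network.
# P(A|E,B) follows the slide's CPT; the other CPDs are made-up assumptions.
import itertools

p_b = 0.01                      # P(B=t)   (assumed)
p_e = 0.02                      # P(E=t)   (assumed)
p_a = {(1, 1): 0.9, (1, 0): 0.2, (0, 1): 0.9, (0, 0): 0.01}  # P(A=t|E,B)
p_r = {1: 0.95, 0: 0.001}       # P(R=t|E) (assumed)
p_c = {1: 0.9, 0: 0.05}         # P(C=t|A) (assumed)

def joint(b, e, a, r, c):
    """P(B,E,A,R,C) via the chain rule with the BN's independencies."""
    pr = (p_b if b else 1 - p_b) * (p_e if e else 1 - p_e)
    pr *= p_a[(e, b)] if a else 1 - p_a[(e, b)]
    pr *= p_r[e] if r else 1 - p_r[e]
    pr *= p_c[a] if c else 1 - p_c[a]
    return pr

def posterior_b(evidence):
    """P(B=t | evidence) by summing the joint over all completions."""
    num = den = 0.0
    for b, e, a, r, c in itertools.product([0, 1], repeat=5):
        x = dict(b=b, e=e, a=a, r=r, c=c)
        if all(x[k] == v for k, v in evidence.items()):
            p = joint(b, e, a, r, c)
            den += p
            num += p * b
    return num / den

print("P(B=t | C=t)      =", round(posterior_b({'c': 1}), 3))
print("P(B=t | C=t, R=t) =", round(posterior_b({'c': 1, 'r': 1}), 3))
```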
Decision theory
• Decision theory = probability theory + utility theory
• Bayesian networks + actions/utilities = influence (decision) diagrams
• Maximize expected utility  SP2-25

Outline
• Introduction
• Exact inference
  – Brute force enumeration
  – Variable elimination algorithm
  – Complexity of exact inference
  – Belief propagation algorithm
  – Junction tree algorithm
  – Linear Gaussian models
• Approximate inference  SP2-26

Brute force enumeration
• We can compute any posterior P(XQ | xE) in O(K^N) time, where K = |Xi| is the number of states per node
• By using the BN, we can represent the joint in O(N) space
[Figure: the burglary network B, E → A → J, M]  SP2-27

Brute force enumeration
[Figure: the query expanded as a sum over the full joint, from Russell & Norvig]  SP2-28

Enumeration tree
[Figure: the enumeration tree for this query, from Russell & Norvig]  SP2-29

Enumeration tree
• The enumeration tree contains repeated sub-expressions
[Figure: repeated sub-trees highlighted]  SP2-30

Variable/bucket elimination [Kschischang01, Dechter96]
• Push sums inside products (generalized distributive law)
• Carry out summations right to left, storing intermediate results (factors) to avoid recomputation (dynamic programming)  SP2-31

Variable elimination
[Figure: worked example of pushing sums inside products]  SP2-32

VarElim: basic operations
• Pointwise product of factors
• Summing out a variable
• Only multiply factors which contain the summand (lazy evaluation)
(see the code sketch below)  SP2-33

Variable elimination
[Figure: worked example, from Russell & Norvig]  SP2-34

Outline
• Introduction
• Exact inference
  – Brute force enumeration
  – Variable elimination algorithm
  – Complexity of exact inference
  – Belief propagation algorithm
  – Junction tree algorithm
  – Linear Gaussian models
• Approximate inference  SP2-35

VarElim on loopy graphs
• Work right-to-left, eliminating variables and adding arcs to ensure that any two terms that co-occur in a factor are connected in the graph
[Figure: eliminating nodes 6, 5, 4 from a 6-node loopy graph, adding fill-in edges at each step]  SP2-36

Complexity of VarElim
• Time/space for a single query = O(N K^(w+1)) for N nodes of K states each, where w = w(G, π) is the width of the graph induced by elimination order π
• w* = min over π of w(G, π) = treewidth of G
• Thm: finding an order that achieves the treewidth is NP-complete [Yannakakis81]
• Does there exist a more efficient exact inference algorithm?  SP2-37

Exact inference is #P-complete [Dagum93]
• Can reduce 3SAT to exact inference ⇒ NP-hard
• Equivalent to counting the number of satisfying assignments ⇒ #P-complete
• Example: literals A, B, C, D with P(A)=P(B)=P(C)=P(D)=0.5; clauses C1 = A ∨ B ∨ C, C2 = C ∨ D ∨ ¬A, C3 = B ∨ C ∨ ¬D; sentence S = C1 ∧ C2 ∧ C3
[Figure: layered BN with literal nodes feeding clause nodes feeding the sentence node S]  SP2-38

Summary so far
• Brute force enumeration: O(K^N) time, O(N K^C) space (where C = max clique size)
• VarElim: O(N K^(w+1)) time/space, where w = w(G, π) is the induced width
• Exact inference is #P-complete
  – Motivates the need for approximate inference  SP2-39

Treewidth
• Low treewidth:
  – Chains: w* = 1
  – Trees (no loops): w* = #parents
• High treewidth:
  – n×n grids (N = n×n nodes): w* = O(n) = O(√N)
  – General loopy graphs (e.g., the ALARM medical diagnosis network): w* is NP-hard to find [Arnborg85]  SP2-40
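A minimal sketch of the VarElim basic operations from slide SP2-33 (pointwise product, summing out, and lazily multiplying only the factors that mention the summand), run on a tiny chain. The Factor class and the CPT values are illustrative assumptions, not code from the tutorial.

```python
# Minimal sketch (not the slides' code): factor product, summing out, and
# variable elimination on a tiny chain A -> B -> C with binary variables.
# The CPT values are made up for illustration.
import numpy as np

class Factor:
    def __init__(self, variables, table):
        self.vars = list(variables)           # variable names, one array axis each
        self.table = np.asarray(table, float)

def multiply(f, g):
    """Pointwise product: align axes, broadcast, and multiply."""
    all_vars = f.vars + [v for v in g.vars if v not in f.vars]
    def expand(h):
        t = h.table
        for v in all_vars:
            if v not in h.vars:
                t = np.expand_dims(t, axis=-1)      # append missing axes
        layout = h.vars + [v for v in all_vars if v not in h.vars]
        return np.transpose(t, [layout.index(v) for v in all_vars])
    return Factor(all_vars, expand(f) * expand(g))

def sum_out(f, var):
    """Sum a variable out of a factor."""
    ax = f.vars.index(var)
    return Factor([v for v in f.vars if v != var], f.table.sum(axis=ax))

def eliminate(factors, order):
    """Eliminate variables in the given order; only the factors that mention
    the variable being summed out are multiplied (lazy evaluation)."""
    factors = list(factors)
    for var in order:
        touched = [f for f in factors if var in f.vars]
        rest = [f for f in factors if var not in f.vars]
        prod = touched[0]
        for f in touched[1:]:
            prod = multiply(prod, f)
        factors = rest + [sum_out(prod, var)]
    result = factors[0]
    for f in factors[1:]:
        result = multiply(result, f)
    return result

# P(A), P(B|A), P(C|B): made-up numbers
pA = Factor(['A'], [0.6, 0.4])
pBgA = Factor(['A', 'B'], [[0.7, 0.3], [0.2, 0.8]])
pCgB = Factor(['B', 'C'], [[0.9, 0.1], [0.5, 0.5]])

pC = eliminate([pA, pBgA, pCgB], order=['A', 'B'])
print(dict(zip(['C=0', 'C=1'], pC.table)))   # marginal P(C)
```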
Graph triangulation [Golumbic80]
• A graph is triangulated (chordal, perfect) if it has no chordless cycles of length > 3.
• To triangulate a graph, for each node Xi in order π, ensure all neighbours of Xi form a clique by adding fill-in edges; then remove Xi.
[Figure: the 6-node example again: eliminating nodes 6, 5, 4 and adding fill-in edges]  SP2-41

Graph triangulation
• A graph is triangulated (chordal) if it has no chordless cycles of length > 3.
[Figure: a 3×3 grid with some extra edges: not triangulated! It contains a chordless 6-cycle]  SP2-42

Graph triangulation
• Triangulation is not just adding triangles.
[Figure: the same graph with more edges: still not triangulated; it contains a chordless 4-cycle]  SP2-43

Graph triangulation
• Triangulation creates large cliques.
[Figure: the same graph with enough fill-in edges: triangulated at last!]  SP2-44

Finding an elimination order [Kjaerulff90]
• The size of the induced cliques depends on the elimination order.
• Since this is NP-hard to optimize, it is common to apply greedy search: at each iteration, eliminate the node that would result in the smallest
  – number of fill-in edges [min-fill], or
  – resulting clique weight [min-weight] (the weight of a clique = product of the number of states of its nodes)
• There are also approximation algorithms [Amir01]  SP2-45

Speedup tricks for VarElim
• Remove nodes that are irrelevant to the query
• Exploit special forms of P(Xi | XPai) to sum out variables efficiently  SP2-46

Irrelevant variables
• M is irrelevant to computing P(j | b): summing out the hidden leaf gives Σm P(m | a) = 1
• Thm: Xi is irrelevant unless Xi ∈ Ancestors({XQ} ∪ XE)
• Here, Ancestors({J} ∪ {B}) = {A, E}
• ⇒ hidden leaves (barren nodes) can always be removed
[Figure: the burglary network B, E → A → J, M]  SP2-47

Irrelevant variables
• M, B and E are irrelevant to computing P(j | a)
• All variables relevant to a query can be found in O(N) time
• Variable elimination supports query-specific optimizations  SP2-48

Structured CPDs
• Sometimes P(Xi | XPai) has special structure, which we can exploit computationally:
  – Context-specific independence (e.g., CPD = decision tree) [Boutilier96b, Zhang99]
  – Causal independence (e.g., CPD = noisy-OR) [Rish98, Zhang96b]
  – Determinism [Zweig98, Bartels04]
• Such "non-graphical" structure complicates the search for the optimal triangulation [Bartels04]  SP2-49

Outline
• Introduction
• Exact inference
  – Brute force enumeration
  – Variable elimination algorithm
  – Complexity of exact inference
  – Belief propagation algorithm
  – Junction tree algorithm
  – Linear Gaussian models
• Approximate inference  SP2-50

What's wrong with VarElim
• Often we want to query all hidden nodes.
• VarElim takes O(N² K^(w+1)) time to compute P(Xi | xe) for all (hidden) nodes i.
• There exist message passing algorithms that can do this in O(N K^(w+1)) time.
• Later, we will use these to do approximate inference in O(N K²) time, independent of w.
[Figure: HMM with hidden chain X1, X2, X3 and observations Y1, Y2, Y3]  SP2-51

Repeated variable elimination leads to redundant calculations
• Running VarElim separately for each node takes O(N² K²) time to compute all N marginals in a chain.
[Figure: the same HMM]  SP2-52

Forwards-backwards algorithm [Rabiner89, etc.]
• P(Xt | y1:N) ∝ P(Xt | y1:t-1) × P(yt | Xt) × P(yt+1:N | Xt)
  (forwards prediction × local evidence × backwards prediction)
• Use dynamic programming to compute these terms.  SP2-53

Forwards algorithm (filtering)
• Predict: P(Xt | y1:t-1) = Σ over xt-1 of P(Xt | xt-1) P(xt-1 | y1:t-1)
• Update with the local evidence: P(Xt | y1:t) ∝ P(yt | Xt) P(Xt | y1:t-1)  SP2-54

Backwards algorithm
• βt(xt) ≡ P(yt+1:N | xt) = Σ over xt+1 of P(xt+1 | xt) P(yt+1 | xt+1) βt+1(xt+1)  SP2-55

Forwards-backwards algorithm
• Forwards pass, then backwards pass; the backwards messages are independent of the forwards messages.
• Combine: P(Xt | y1:N) ∝ P(Xt | y1:t) βt(Xt)
• O(N K²) time to compute all N marginals, not O(N² K²). (See the code sketch below.)  SP2-56

Belief propagation [Pearl88, Shafer90, Yedidia01, etc.]
• The forwards-backwards algorithm can be generalized to any tree-like graph (one with no loops).
• For now, we assume pairwise potentials.  SP2-57
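A minimal sketch of the forwards-backwards recursions from slides SP2-53 to SP2-56 for a discrete HMM. The transition matrix A, observation matrix B and initial distribution are made-up assumptions; the point is the O(T K²) cost of computing all the posterior marginals.

```python
# Minimal sketch of forwards-backwards for a discrete HMM.
# A, B and the initial distribution are made-up assumptions.
import numpy as np

A = np.array([[0.9, 0.1],        # A[i, j] = P(X_{t+1}=j | X_t=i)
              [0.2, 0.8]])
B = np.array([[0.7, 0.3],        # B[i, k] = P(Y_t=k | X_t=i)
              [0.1, 0.9]])
pi = np.array([0.5, 0.5])        # P(X_1)
obs = [0, 1, 1, 0, 1]            # an observation sequence

def forwards_backwards(obs):
    T, K = len(obs), len(pi)
    alpha = np.zeros((T, K))     # filtered P(X_t | y_{1:t})
    beta = np.ones((T, K))       # proportional to P(y_{t+1:T} | X_t)
    # Forwards pass: predict, then update with the local evidence
    alpha[0] = pi * B[:, obs[0]]
    alpha[0] /= alpha[0].sum()
    for t in range(1, T):
        pred = alpha[t - 1] @ A               # P(X_t | y_{1:t-1})
        alpha[t] = pred * B[:, obs[t]]
        alpha[t] /= alpha[t].sum()
    # Backwards pass (independent of the forwards messages)
    for t in range(T - 2, -1, -1):
        beta[t] = A @ (B[:, obs[t + 1]] * beta[t + 1])
        beta[t] /= beta[t].sum()              # rescale for numerical stability
    # Combine: P(X_t | y_{1:T}) proportional to alpha_t * beta_t
    gamma = alpha * beta
    gamma /= gamma.sum(axis=1, keepdims=True)
    return gamma

print(np.round(forwards_backwards(obs), 3))   # posterior marginals P(X_t | y_{1:T})
```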
Absorbing messages
[Figure: node Xt absorbs messages from its neighbours Xt-1, Xt+1 and its local evidence Yt]  SP2-58

Sending messages
[Figure: node Xt combines all but one of its incoming messages and sends the result to each neighbour]  SP2-59

Centralized protocol
• Collect to the root (post-order traversal), then distribute from the root (pre-order traversal).
• Computes all N marginals in 2 passes over the graph.
[Figure: a tree with the numbered order of messages towards and away from the root R]  SP2-60

Distributed protocol
• All nodes send messages in parallel; computes all N marginals in O(N) parallel updates.  SP2-61

Loopy belief propagation
• Applying BP to graphs with loops (cycles) can give the wrong answer, because it overcounts evidence.
• In practice it often works well (e.g., error-correcting codes).
[Figure: the loopy Cloudy → Sprinkler, Rain → WetGrass network]  SP2-62

Why loopy BP?
• We can compute exact answers by converting a loopy graph to a junction tree and running BP (see later).
• However, the resulting jtree has nodes with O(K^(w+1)) states, so inference takes O(N K^(w+1)) time [w+1 = clique size of the triangulated graph].
• We can apply BP to the original graph in O(N K^C) time [C = clique size of the original graph].
• To apply BP to a graph with non-pairwise potentials, it is simpler to use factor graphs.  SP2-63

Factor graphs
[Figure: the same 5-node model drawn as a Bayes net, Markov net, pairwise Markov net, and factor graph (bipartite graph of variables and factors)]
Kschischang01  SP2-64

BP for factor graphs [Kschischang01]
• Message from variable x to factor f: multiply the messages from x's other factors, μ_{x→f}(x) = ∏ over f' ≠ f of μ_{f'→x}(x)
• Message from factor f to variable x (for a factor f(x,y,z)): sum out the other variables, μ_{f→x}(x) = Σ over y,z of f(x,y,z) μ_{y→f}(y) μ_{z→f}(z)
• Belief: b(x) ∝ ∏ over f of μ_{f→x}(x)  SP2-65

Sum-product vs max-product
• Sum-product computes marginals using the summation rule above.
• Max-product computes max-marginals by replacing the sum with a max.
• Same algorithm on different commutative semirings: (+, ×) for sum-product and (max, ×) for max-product [Shafer90, Bistarelli97, Goodman99, Aji00]  SP2-66

Viterbi decoding
• Compute the most probable explanation (MPE) of the observed data.
[Figure: HMM with hidden states X1, X2, X3 and observations Y1, Y2, Y3, e.g. the utterance "Tomato"]  SP2-67

Viterbi algorithm for HMMs
• Run the max-product forwards algorithm, keeping track of the most probable predecessor of each state.
• Recover the MPE by pointer traceback. (See the code sketch below.)
• Can produce an "N-best list" (the N most probable configurations) in O(N T K²) time [Forney73, Nilsson01]  SP2-68

Loopy Viterbi
• Use max-product to compute or approximate the max-marginals mi(xi) = max over all x consistent with xi of P(x | y).
• If there are no ties and the max-marginals are exact, then xi* = argmax over xi of mi(xi) recovers the MPE.
• This method does not use traceback, so it can be used with distributed/loopy BP.
• We can break ties, and produce the N most probable configurations, by asserting that certain assignments are disallowed and rerunning [Yanover04]  SP2-69

BP speedup tricks
• Sometimes we can reduce the time to compute a message from O(K²) to O(K):
  – If ψ(xi, xj) = exp(−||f(xi) − f(xj)||²), then
    sum-product takes O(K log K) time [exact, FFT] or O(K) time [approximate] [Felzenszwalb03/04, Movellan04, deFreitas04],
    and max-product takes O(K) time [distance transform]
  – For general (discrete) potentials, we can dynamically add/delete states to reduce K [Coughlan04]
• Sometimes we can speed up convergence by
  – using a better message-passing schedule (e.g., along embedded spanning trees) [Wainwright01]
  – using a multiscale method [Felzenszwalb04]  SP2-70

Outline
• Introduction
• Exact inference
  – Variable elimination algorithm
  – Complexity of exact inference
  – Belief propagation algorithm
  – Junction tree algorithm
  – Linear Gaussian models
• Approximate inference  SP2-71
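Before moving on to junction trees, here is a minimal sketch of the Viterbi algorithm from slide SP2-68: a max-product forwards pass (in log space) that records the most probable predecessor of each state, followed by pointer traceback. The HMM parameters are the same made-up ones as in the forwards-backwards sketch above.

```python
# Minimal sketch of Viterbi decoding for a discrete HMM (max-product
# forwards pass plus pointer traceback). Parameters are made up.
import numpy as np

A = np.array([[0.9, 0.1],        # P(X_{t+1}=j | X_t=i)
              [0.2, 0.8]])
B = np.array([[0.7, 0.3],        # P(Y_t=k | X_t=i)
              [0.1, 0.9]])
pi = np.array([0.5, 0.5])
obs = [0, 1, 1, 0, 1]

def viterbi(obs):
    T, K = len(obs), len(pi)
    delta = np.zeros((T, K))            # best log-prob of any path ending in each state
    back = np.zeros((T, K), dtype=int)  # most probable predecessor of each state
    logA, logB = np.log(A), np.log(B)
    delta[0] = np.log(pi) + logB[:, obs[0]]
    for t in range(1, T):
        scores = delta[t - 1][:, None] + logA        # scores[i, j] = path ending i -> j
        back[t] = scores.argmax(axis=0)
        delta[t] = scores.max(axis=0) + logB[:, obs[t]]
    # Pointer traceback from the best final state
    path = [int(delta[-1].argmax())]
    for t in range(T - 1, 0, -1):
        path.append(int(back[t, path[-1]]))
    return path[::-1]

print(viterbi(obs))   # most probable hidden state sequence (MPE)
```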
Junction/join/clique trees
• To perform exact inference in an arbitrary graph, convert it to a junction tree and then perform belief propagation.
• A jtree is a tree whose nodes are sets of variables, and which has the jtree (running intersection) property: the sets containing any given variable form a connected subgraph (a variable cannot appear in 2 disjoint places).
• Example: moralize the Cloudy/Sprinkler/Rain/WetGrass network and make a jtree with maximal cliques {C,S,R} and {S,R,W} and separator {C,S,R} ∩ {S,R,W} = {S,R}.
[Figure: the moralized graph and the 2-clique jtree CSR – SR – SRW]  SP2-72

Making a junction tree [Jensen94]
• Moralize G → GM
• Triangulate (e.g., using order f, d, e, c, b, a) → GT
• Find the maximal cliques, e.g. {a,b,c}, {b,d}, {b,c,e}, {b,e,f}
• Build the junction graph with edge weights Wij = |Ci ∩ Cj|
• Find a maximum spanning tree of the junction graph → jtree
[Figure: the graph G on a,...,f, its moralized and triangulated versions, the junction graph with separator weights, and the resulting jtree]  SP2-73

Clique potentials
• Each model (clique) potential gets assigned to one jtree clique potential.
• Each observed variable assigns a delta function to one jtree clique potential: if we observe W = w*, set E(w) = δ(w, w*), else E(w) = 1.
[Figure: the CSR – SR – SRW jtree; square nodes are factors]  SP2-74

Separator potentials
• Separator potentials enforce consistency between neighbouring cliques on their common variables.
[Figure: the CSR – SR – SRW jtree; square nodes are factors]  SP2-75

BP on a jtree
• A jtree is an MRF with pairwise potentials.
• Each (clique) node potential contains CPDs and local evidence.
• Each edge potential acts like a projection function.
• We do a forwards (collect) pass, then a backwards (distribute) pass.
• The result is the Hugin / Shafer-Shenoy algorithm.
[Figure: the message order 1-4 on the CSR – SR – SRW jtree]  SP2-76

BP on a jtree (collect)
• Initial clique potentials contain the CPDs and the evidence.  SP2-77
• The message from a clique to a separator marginalizes the clique belief (projects onto the intersection), e.g. removing C.  SP2-78
• The separator potential gets the marginal belief from its parent clique.  SP2-79
• The message from a separator to a clique expands the marginal, e.g. adding W.  SP2-80
• After the collect pass, the root clique has seen all the evidence.  SP2-81

BP on a jtree (distribute)
• Marginalize out W and exclude the old evidence (ec, er).
• Combine upstream and downstream evidence.
• Add C back, again excluding the old evidence (ec, er), and combine upstream and downstream evidence.
[Figure: animation of the distribute pass on the CSR – SR – SRW jtree]  SP2-82 to SP2-86

Partial beliefs [Paskin04]
• The "beliefs"/messages at intermediate stages (before finishing both passes) may not be meaningful, because a given clique may not have "seen" all the model potentials/evidence (and hence may not be normalizable).
• This can cause problems in settings where messages may be lost (e.g., sensor networks).
• One must reparameterize using the decomposable model to ensure meaningful partial beliefs.
[Figure: the jtree mid-pass, with the evidence on R added at only one clique]  SP2-87

Hugin algorithm
• Hugin = BP applied to a jtree using a serial protocol: collect, then distribute.
• Messages pass through the separators; each new separator belief is divided by the old one. (See the code sketch below.)
[Figure: collect and distribute messages between cliques Ci and Cj through separator Sij; square nodes are separators]  SP2-88

Shafer-Shenoy algorithm [Lepar98]
• SS = BP on a jtree, but without separator nodes.
• Multiplies by all-but-one messages instead of dividing out old beliefs.
• Uses less space but more time than Hugin.  SP2-89
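A minimal sketch of Hugin-style message passing on the two-clique jtree {C,S,R} – {S,R} – {S,R,W} from slides SP2-74 to SP2-88: collect to the root, then distribute, dividing each new separator by the old one. The CPT values for the sprinkler network are assumptions; only the jtree structure comes from the slides.

```python
# Minimal sketch of Hugin message passing on the jtree CSR -- SR -- SRW.
# The CPT values are assumed; they are not given in the deck.
import numpy as np

# CPDs, indexed 0 = false, 1 = true
pC = np.array([0.5, 0.5])                                   # P(C)
pS = np.array([[0.5, 0.5], [0.9, 0.1]])                     # P(S|C): pS[c, s]
pR = np.array([[0.8, 0.2], [0.2, 0.8]])                     # P(R|C): pR[c, r]
pW = np.zeros((2, 2, 2))                                    # P(W|S,R): pW[s, r, w]
pW[0, 0] = [1.0, 0.0]; pW[0, 1] = [0.1, 0.9]
pW[1, 0] = [0.1, 0.9]; pW[1, 1] = [0.01, 0.99]

# Clique potentials: assign each model CPD to one clique
csr = pC[:, None, None] * pS[:, :, None] * pR[:, None, :]   # psi(c, s, r)
srw = pW.copy()                                             # psi(s, r, w)
srw[:, :, 0] = 0.0                                          # evidence W = true (delta function)
sep = np.ones((2, 2))                                       # separator potential over (s, r)

# Collect pass: CSR -> SR -> SRW
new_sep = csr.sum(axis=0)                                   # marginalize out C
srw *= (new_sep / sep)[:, :, None]                          # absorb, dividing by the old separator
sep = new_sep

# Distribute pass: SRW -> SR -> CSR
new_sep = srw.sum(axis=2)                                   # marginalize out W
csr *= (new_sep / sep)[None, :, :]
sep = new_sep

# Posterior marginals given W = true
belief_c = csr.sum(axis=(1, 2)); belief_c /= belief_c.sum()
belief_r = srw.sum(axis=(0, 2)); belief_r /= belief_r.sum()
print("P(C|W=t) =", np.round(belief_c, 3))
print("P(R|W=t) =", np.round(belief_r, 3))
```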
Other jtree variants
• Strong jtree [Cowell99]: built using a constrained elimination order. Useful for
  – decision (influence) diagrams
  – hybrid (conditional Gaussian) networks
  – MAP (max-sum-product) inference
• Lazy jtree [Madsen99]
  – multiplies model potentials only when needed
  – can exploit special structure in CPDs  SP2-90

Outline
• Introduction
• Exact inference
  – Variable elimination algorithm
  – Complexity of exact inference
  – Belief propagation algorithm
  – Junction tree algorithm
  – Linear Gaussian models
• Approximate inference  SP2-91

Gaussian MRFs
• Absent arcs correspond to 0s in the precision matrix V^-1 ("structural zeros").
[Figure: a 3-node chain 1 – 2 – 3 and the sparsity pattern of its precision matrix]  SP2-92

Gaussian Bayes nets [Shachter89]
• If each CPD is linear-Gaussian, P(Xi | XPai) = N(Xi; μi + Σ over j in Pai of bij (xj − μj), σi²), then the joint is a multivariate Gaussian N(μ, Σ) with precision Σ^-1 = (I − B)ᵀ diag(σi²)^-1 (I − B).
• Absent arcs correspond to 0s in the regression matrix B.  SP2-93

Gaussian marginalization/conditioning
• If (x1, x2) ~ N((μ1, μ2), [[Σ11, Σ12], [Σ21, Σ22]]), then x1 ~ N(μ1, Σ11) and x1 | x2 ~ N(μ1 + Σ12 Σ22^-1 (x2 − μ2), Σ11 − Σ12 Σ22^-1 Σ21).
• Hence exact inference can be done in O(N³) time (matrix inversion).
• Methods which exploit the sparsity structure of the precision matrix are equivalent to the junction tree algorithm and have complexity O(N w³), where w is the treewidth [Paskin03b]  SP2-94

Gaussian BP [Weiss01]
• Messages and beliefs are Gaussian potentials, conveniently kept in information (canonical) form; multiplying potentials adds their information vectors and precision matrices, and marginalization also has a closed form.  SP2-95

Kalman filtering
• Linear Dynamical System (LDS) / State Space Model (SSM): hidden state xt = A xt-1 + vt, vt ~ N(0, Q); noisy observations yt = C xt + wt, wt ~ N(0, R).
• BP on an LDS model = Kalman filtering/smoothing. (See the code sketch below.)
• See my Matlab toolbox.
[Figure: the LDS drawn as an HMM-shaped graph X1 → X2 → X3 with observations Y1, Y2, Y3]  SP2-96

Example: LDS for 2D tracking
• Constant-velocity model: the state contains position and velocity, and the observation is a noisy position.
• Sparse linear-Gaussian systems ⇒ sparse graphs.
[Figure: a tracked 2D trajectory and the corresponding graphical model]  SP2-97

Non-linear, Gaussian models
• Hidden state xt = f(xt-1) + vt, noisy observations yt = g(xt) + wt, with Gaussian noise.
• Extended Kalman filter (EKF): linearize f and g around the current state estimate.
• Unscented Kalman filter (UKF): more accurate than the EKF and avoids the need to compute gradients.
• Both EKF and UKF assume P(Xt | y1:t) is unimodal.
• See the ReBEL Matlab toolbox (OGI).  SP2-98

Non-Gaussian models
• If P(Xt | y1:t) is multi-modal, we can approximately represent it using
  – mixtures of Gaussians (assumed density filtering, ADF)
  – samples (particle filtering)
• Batch (offline) versions:
  – ADF → expectation propagation (EP) [Minka02]
  – particle filtering → particle smoothing or Gibbs sampling  SP2-99
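A minimal sketch of the Kalman filter predict/update cycle for the constant-velocity 2D tracking model of slides SP2-96 and SP2-97. The noise covariances, initial belief and simulated track are made-up assumptions.

```python
# Minimal sketch of a Kalman filter for the constant-velocity 2D tracking
# model: state = (x, y, vx, vy), observation = noisy (x, y).
# All noise parameters below are assumptions chosen for illustration.
import numpy as np

dt = 1.0
A = np.array([[1, 0, dt, 0],      # x  <- x + vx*dt
              [0, 1, 0, dt],      # y  <- y + vy*dt
              [0, 0, 1,  0],      # vx <- vx
              [0, 0, 0,  1]])     # vy <- vy
C = np.array([[1, 0, 0, 0],       # observe position only
              [0, 1, 0, 0]])
Q = 0.01 * np.eye(4)              # process noise covariance (assumed)
R = 0.25 * np.eye(2)              # observation noise covariance (assumed)

def kalman_filter(ys, mu0, V0):
    """Return the filtered means of P(X_t | y_{1:t}) for each time step."""
    mu, V, means = mu0, V0, []
    for y in ys:
        # Predict: push the belief through the linear dynamics
        mu_pred = A @ mu
        V_pred = A @ V @ A.T + Q
        # Update: condition on the new observation via the Kalman gain K
        S = C @ V_pred @ C.T + R
        K = V_pred @ C.T @ np.linalg.inv(S)
        mu = mu_pred + K @ (y - C @ mu_pred)
        V = (np.eye(4) - K @ C) @ V_pred
        means.append(mu)
    return np.array(means)

# Simulate a short noisy track and filter it
rng = np.random.default_rng(0)
truth = np.array([[t, 0.5 * t, 1.0, 0.5] for t in range(10)], float)
ys = truth[:, :2] + rng.normal(scale=0.5, size=(10, 2))
print(np.round(kalman_filter(ys, mu0=np.zeros(4), V0=np.eye(4)), 2))
```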