Tony Jebara, Columbia University
Machine Learning 4771 — Instructor: Tony Jebara

Topic 18
• The Junction Tree Algorithm
• Collect & Distribute
• Algorithmic Complexity
• ArgMax Junction Tree Algorithm

Review: Junction Tree Algorithm
• Send a message from each clique to its separators stating what the clique thinks the submarginal on the separator is.
• Normalize each clique by the incoming message from its separators so that it agrees with them.
• Example: V = {A,B}, S = {B}, W = {B,C}
  If they agree:  Σ_{V\S} ψ_V = φ_S = p(S) = Σ_{W\S} ψ_W
  Else, send a message from V to W:
    φ*_S = Σ_{V\S} ψ_V ,   ψ*_W = (φ*_S / φ_S) ψ_W ,   ψ*_V = ψ_V
  Then send a message from W to V:
    φ**_S = Σ_{W\S} ψ*_W ,   ψ**_V = (φ**_S / φ*_S) ψ*_V ,   ψ**_W = ψ*_W
  …Done! Now they agree:
    Σ_{V\S} ψ**_V = (φ**_S / φ*_S) Σ_{V\S} ψ*_V = φ**_S = Σ_{W\S} ψ**_W

JTA with many cliques
• Problem: what if we have more than two cliques?
  1) Update AB & BC:   AB — B — BC — C — CD
  2) Update BC & CD:   AB — B — BC — C — CD
• Problem: AB has not heard about CD! After BC updates, it will be inconsistent with AB.
• We need to iterate the pairwise updates many times.
• This will eventually converge to consistent marginals.
• But it is inefficient… can we do better?

JTA: Collect & Distribute
• Use tree recursion rather than iterating messages mindlessly!
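The two-message exchange above can be sketched numerically. This is a minimal illustration (not the lecture's code), assuming NumPy arrays for the potential tables, with V = {A,B}, W = {B,C}, S = {B}:

```python
import numpy as np

# Clique tables: axis 0 of psi_V is A, axis 1 is B; axis 0 of psi_W is B, axis 1 is C.
psi_V = np.array([[8., 4.], [3., 1.]])   # psi(A, B)
psi_W = np.array([[2., 3.], [1., 1.]])   # psi(B, C)
phi_S = np.ones(2)                       # separator potential over B

# Message V -> W:  phi*_S = sum over V\S of psi_V, then rescale psi_W
phi_star = psi_V.sum(axis=0)                     # sum out A
psi_W = (phi_star / phi_S)[:, None] * psi_W      # broadcast over C
phi_S = phi_star

# Message W -> V:  phi**_S = sum over W\S of psi_W, then rescale psi_V
phi_dstar = psi_W.sum(axis=1)                    # sum out C
psi_V = psi_V * (phi_dstar / phi_S)[None, :]     # broadcast over A
phi_S = phi_dstar

# Both cliques now agree on the separator submarginal phi**_S
```

After the exchange, summing either clique onto B gives the same vector, exactly the agreement condition on the slide.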
initialize(DAG) {
  Pick a root
  Set all potentials: ψ_Ci = p(x_i | π_i),  φ_S = 1
}
collectEvidence(node) {
  for each child of node {
    update1(node, collectEvidence(child));
  }
  return(node);
}
distributeEvidence(node) {
  for each child of node {
    update2(child, node);
    distributeEvidence(child);
  }
}
update1(node w, node v) {
  φ*_{V∩W} = Σ_{V\(V∩W)} ψ_V ;   ψ_W = (φ*_{V∩W} / φ_{V∩W}) ψ_W
}
update2(node w, node v) {
  φ**_{V∩W} = Σ_{V\(V∩W)} ψ_V ;   ψ_W = (φ**_{V∩W} / φ*_{V∩W}) ψ_W
}
normalize() {
  p(X_C) = ψ**_C / Σ_C ψ**_C  ∀C ;   p(X_S) = φ**_S / Σ_S φ**_S  ∀S
}

Junction Tree Algorithm
• JTA: 1) Initialize  2) Collect  3) Distribute  4) Normalize
• Note: leaves do not change their ψ during collect
• Note: the first cliques to change during collect are the parents of leaves
• Note: the root does not change its ψ during distribute

Algorithmic Complexity
• The 5 steps of JTA are all efficient:
  1) Moralization: polynomial in # of nodes
  2) Introduce Evidence (fixed or constant): polynomial in # of nodes (convert pdf to slices)
  3) Triangulation (Tarjan & Yannakakis 1984): suboptimal = polynomial, optimal = NP-hard
  4) Construct Junction Tree (Kruskal): polynomial in # of cliques
  5) Junction Tree Algorithm (Init, Collect, Distribute, Normalize): polynomial (linear) in # of cliques, exponential in clique cardinality

Junction Tree Algorithm
• Convert the Directed Graph to a Junction Tree
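The collect/distribute recursion can be sketched on a small clique chain. This is a hypothetical, hard-coded illustration (the helper name `pass_msg` and the example tables are mine, not the lecture's), with cliques AB —B— BC —C— CD rooted at CD:

```python
import numpy as np

# Axes: psi['AB'] is (A, B); psi['BC'] is (B, C); psi['CD'] is (C, D).
psi = {'AB': np.array([[8., 4.], [3., 1.]]),
       'BC': np.array([[2., 3.], [1., 1.]]),
       'CD': np.array([[1., 4.], [1., 1.]])}
phi = {'B': np.ones(2), 'C': np.ones(2)}   # separators initialized to 1

def pass_msg(src, sep, dst, sum_axis, dst_axis):
    """update1/update2 from the slides: absorb src into dst through sep."""
    new = psi[src].sum(axis=sum_axis)      # marginalize src onto the separator
    ratio = new / phi[sep]
    psi[dst] *= ratio[:, None] if dst_axis == 0 else ratio[None, :]
    phi[sep] = new

# collectEvidence toward the root CD
pass_msg('AB', 'B', 'BC', sum_axis=0, dst_axis=0)   # sum out A
pass_msg('BC', 'C', 'CD', sum_axis=0, dst_axis=0)   # sum out B
# distributeEvidence back from the root
pass_msg('CD', 'C', 'BC', sum_axis=1, dst_axis=1)   # sum out D
pass_msg('BC', 'B', 'AB', sum_axis=1, dst_axis=1)   # sum out C

# Every clique table now sums to the same normalizer Z, and normalizing
# any single table gives a valid marginal of the joint.
Z = psi['AB'].sum()
```

One collect sweep and one distribute sweep suffice, instead of iterating pairwise updates until convergence.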
• The junction tree has cliques x1x2, x2x3, x3x4, x3x5, x5x6, x5x7
• Initialize separators to 1 (and Z = 1) and set the clique tables to the CPTs in the Directed Graph:
  p(X) = p(x1) p(x2|x1) p(x3|x2) p(x4|x3) p(x5|x3) p(x6|x5) p(x7|x5)
  p(X) = (1/Z) Π_C ψ(X_C) / Π_S φ(X_S)
       = (1/1) p(x1,x2) p(x3|x2) p(x4|x3) p(x5|x3) p(x6|x5) p(x7|x5) / (1×1×1×1×1)
• Run Collect, Distribute, Normalize
• Get valid marginals from all ψ, φ tables

JTA with Extra Evidence
• If extra evidence is observed, we must slice the tables accordingly
• Example: p(A,B,C,D) = (1/Z) ψ_AB ψ_BC ψ_CD with Z = 1 and
  ψ_AB (rows A=0,1; cols B=0,1) = [ 8 4 ; 3 1 ]
  ψ_BC (rows B=0,1; cols C=0,1) = [ 2 3 ; 1 1 ]
  ψ_CD (rows C=0,1; cols D=0,1) = [ 1 4 ; 1 1 ]
• You are given evidence A=0. Replace ψ_AB with its A=0 slice:
  ψ_AB → [ 8 4 ]
• JTA now gives ψ, φ as marginals conditioned on the evidence:
  p(B | A=0)   = ψ**_AB / Σ_B ψ**_AB
  p(B,C | A=0) = ψ**_BC / Σ_{B,C} ψ**_BC
  p(C,D | A=0) = ψ**_CD / Σ_{C,D} ψ**_CD
• All denominators equal the new normalizer Z':
  Z' = p(EVIDENCE) = Σ_B ψ**_AB = Σ_{B,C} ψ**_BC = Σ_{C,D} ψ**_CD

ArgMax Junction Tree Algorithm
• We can also use JTA for finding the max (rather than the sum) over the joint, to get the argmax of marginals & conditionals
• Say we have some evidence: p(X_F, X_E) = p(x1,…,xn, xn+1,…,xN)
• Most likely (highest probability) X_F:  X*_F = argmax_{X_F} p(X_F, X_E)
• What is the most likely state of a patient with fever & headache?
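Slicing-in evidence can be sketched on the same chain. A minimal illustration (variable names are mine; only the collect pass toward the root CD is inlined), assuming the tables from the slide and evidence A=0:

```python
import numpy as np

psi_AB = np.array([[8., 4.], [3., 1.]])   # rows A=0,1; cols B=0,1
psi_BC = np.array([[2., 3.], [1., 1.]])   # rows B; cols C
psi_CD = np.array([[1., 4.], [1., 1.]])   # rows C; cols D

psi_AB = psi_AB[0:1, :]                   # keep only the A=0 slice: [[8, 4]]

# One collect pass toward the root CD (separators start at 1):
phi_B = psi_AB.sum(axis=0)                # message onto separator B
psi_BC *= phi_B[:, None]
phi_C = psi_BC.sum(axis=0)                # message onto separator C
psi_CD *= phi_C[:, None]

# The root table now holds the unnormalized p(C, D, A=0), so its total
# mass is the new normalizer Z' = p(A=0) (up to the model's overall Z).
Z_prime = psi_CD.sum()
p_CD_given_A0 = psi_CD / Z_prime
```

Dividing any fully updated clique table by Z' yields a marginal conditioned on the evidence, as the slide states.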
  p*_F = max_{x2,x3,x4,x5} p(x1=1, x2, x3, x4, x5, x6=1)
       = max_{x2} p(x2|x1=1) p(x1=1) max_{x3} p(x3|x1=1) max_{x4} p(x4|x2) max_{x5} p(x5|x3) p(x6=1 | x2, x5)
• Solution: replace sum with max inside the JTA update code:
  φ*_{V∩W} = max_{V\(V∩W)} ψ_V ,   ψ_W = (φ*_{V∩W} / φ_{V∩W}) ψ_W
  φ**_{V∩W} = max_{V\(V∩W)} ψ_V ,   ψ_W = (φ**_{V∩W} / φ*_{V∩W}) ψ_W
• Final potentials are max-marginals: ψ**(X_C) = max_{U\C} p(X)
• The highest value in a potential is the most likely: X*_C = argmax_{X_C} ψ**(X_C)

ArgMax Junction Tree Algorithm
• Why do I need the ArgMax junction tree algorithm?
• Can't I just compute marginals using the Sum algorithm and then find the highest value in each marginal???
• No!! Here's a counter-example:
  p(x1,x2) (rows x2=0,1; cols x1=A,B,C) = [ .14 .05 .27 ; .24 .20 .10 ]
• The most likely configuration is x1*=C and x2*=0
• But the sub-marginals p(x1) and p(x2) do not reveal this:
  p(x1) = [ 0.38 0.25 0.37 ]  (over x1=A,B,C)
  p(x2) = [ 0.46 ; 0.54 ]     (over x2=0,1)
• The marginals would falsely imply that x1*=A and x2*=1

Example
• Note that products are element-wise
• Let us send a regular JTA message from AB to BC (cliques AB — B — BC):
  ψ_AB (rows A=0,1; cols B=0,1) = [ 8 4 ; 3 1 ] ,  φ_B = [ 1 ; 1 ] ,  ψ_BC (rows B=0,1; cols C=0,1) = [ 2 3 ; 1 1 ]
  φ*_B = Σ_A ψ_AB = [ 11 ; 5 ]  (over B=0,1)
  ψ*_BC = (φ*_B / φ_B) ψ_BC = [ 11·2 11·3 ; 5·1 5·1 ] = [ 22 33 ; 5 5 ]
• For argmax JTA, just change the separator update to:
  φ*_B = max_A ψ_AB = [ 8 ; 4 ]  (over B=0,1)
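The counter-example and the max-vs-sum separator message can both be checked numerically. A minimal sketch (not the lecture's code), using the tables from the slides:

```python
import numpy as np

# Counter-example: sum-marginals point to the wrong joint configuration.
p = np.array([[.14, .05, .27],    # rows: x2 = 0, 1
              [.24, .20, .10]])   # cols: x1 = A, B, C

p_x1 = p.sum(axis=0)              # sum-marginal over x1: picks x1 = A
p_x2 = p.sum(axis=1)              # sum-marginal over x2: picks x2 = 1
i2, i1 = np.unravel_index(p.argmax(), p.shape)   # true joint argmax

# Argmax JTA only changes the separator update from sum to max,
# e.g. for the AB -> BC message:
psi_AB = np.array([[8., 4.], [3., 1.]])
phi_sum = psi_AB.sum(axis=0)      # regular JTA separator message
phi_max = psi_AB.max(axis=0)      # argmax JTA separator message
```

The joint argmax lands on (x1=C, x2=0) with probability 0.27, while each sum-marginal alone would choose x1=A and x2=1.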