Machine Learning 4771
Instructor: Tony Jebara, Columbia University
Topic 18
• The Junction Tree Algorithm
• Collect & Distribute
• Algorithmic Complexity
• ArgMax Junction Tree Algorithm
Review: Junction Tree Algorithm
• Send a message from each clique to its separators of
  what it thinks the submarginal on the separator is.
• Normalize each clique by the incoming message
  from its separators so it agrees with them.

  Cliques V = {A,B}, W = {B,C}, separator S = {B}

  If they agree:
    ∑_{V\S} ψ_V = φ_S = p(S) = ∑_{W\S} ψ_W

  Else, send a message from V to W:
    φ*_S = ∑_{V\S} ψ_V ;  ψ*_W = (φ*_S / φ_S) ψ_W ;  ψ*_V = ψ_V

  …then send a message from W to V:
    φ**_S = ∑_{W\S} ψ*_W ;  ψ**_V = (φ**_S / φ*_S) ψ*_V ;  ψ**_W = ψ*_W

  Now they agree… Done!
    ∑_{V\S} ψ**_V = ∑_{V\S} (φ**_S / φ*_S) ψ*_V = (φ**_S / φ*_S) ∑_{V\S} ψ*_V
                  = (φ**_S / φ*_S) φ*_S = φ**_S = ∑_{W\S} ψ**_W
JTA with many cliques
• Problem: what if we have more than two cliques?
  Chain of cliques:  AB - [B] - BC - [C] - CD
  1) Update AB & BC
  2) Update BC & CD
• Problem: AB has not heard about CD!
  After BC updates, it will be inconsistent with AB
• Need to iterate the pairwise updates many times
• This will eventually converge to consistent marginals
• But this is inefficient… can we do better?
JTA: Collect & Distribute
• Use tree recursion rather than iterating messages mindlessly!

  initialize(DAG) {
      pick a root
      set all variables as:  ψ_Ci = p(x_i | π_i),  φ_S = 1 }

  collectEvidence(node) {
      for each child of node {
          update1(node, collectEvidence(child)); }
      return(node); }

  distributeEvidence(node) {
      for each child of node {
          update2(child, node);
          distributeEvidence(child); } }

  update1(node w, node v) {
      φ*_{V∩W} = ∑_{V\(V∩W)} ψ_V ;  ψ*_W = (φ*_{V∩W} / φ_{V∩W}) ψ_W }

  update2(node w, node v) {
      φ**_{V∩W} = ∑_{V\(V∩W)} ψ_V ;  ψ**_W = (φ**_{V∩W} / φ*_{V∩W}) ψ_W }

  normalize() {
      p(X_C) = ψ**_C / ∑_{X_C} ψ**_C  ∀C ;  p(X_S) = φ**_S / ∑_{X_S} φ**_S  ∀S }
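The collect/distribute recursion above can be sketched concretely on a three-clique chain AB, BC, CD with the potentials used in the evidence example later in these slides; this is a minimal numpy sketch, and the `absorb` helper with its axis bookkeeping is my own, not the lecture's code:

```python
import numpy as np

# Chain junction tree AB - [B] - BC - [C] - CD over binary variables.
# Axis 0 of each clique table is its first variable, axis 1 its second.
psi = {
    "AB": np.array([[8., 4.], [3., 1.]]),   # rows A=0,1; cols B=0,1
    "BC": np.array([[2., 3.], [1., 1.]]),   # rows B=0,1; cols C=0,1
    "CD": np.array([[1., 4.], [1., 1.]]),   # rows C=0,1; cols D=0,1
}
phi = {"B": np.ones(2), "C": np.ones(2)}    # separators initialized to 1

def absorb(v, s, w, axis_v, axis_w):
    """One update: marginalize clique v onto separator s, rescale clique w.
    axis_v: axis of psi[v] to sum out; axis_w: axis of psi[w] carrying s."""
    new_phi = psi[v].sum(axis=axis_v)
    ratio = new_phi / phi[s]
    psi[w] = psi[w] * (ratio[:, None] if axis_w == 0 else ratio[None, :])
    phi[s] = new_phi

# Collect toward the root CD, then distribute back out: one sweep each way.
absorb("AB", "B", "BC", axis_v=0, axis_w=0)   # collect: sum out A
absorb("BC", "C", "CD", axis_v=0, axis_w=0)   # collect: sum out B
absorb("CD", "C", "BC", axis_v=1, axis_w=1)   # distribute: sum out D
absorb("BC", "B", "AB", axis_v=1, axis_w=1)   # distribute: sum out C

# Every clique now agrees with its separators and shares one normalizer Z.
Z = psi["AB"].sum()
print(Z, psi["BC"].sum(), psi["CD"].sum())    # 211.0 211.0 211.0
print(psi["AB"].sum(axis=0), phi["B"])        # marginal on B, both [176. 35.]
```

After this single collect/distribute pass, dividing any clique table by Z yields a valid marginal, which is exactly the consistency that mindless pairwise iteration only reaches in the limit.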
Junction Tree Algorithm
• JTA: 1) Initialize  2) Collect  3) Distribute  4) Normalize
• Note: leaves do not change their ψ during collect
• Note: the first cliques that collect changes are the parents of leaves
• Note: the root does not change its ψ during distribute
Algorithmic Complexity
• The 5 steps of JTA are all efficient:
  1) Moralization: polynomial in # of nodes
  2) Introduce evidence (fixed or constant): polynomial in # of nodes
     (convert pdf to slices)
  3) Triangulate (Tarjan & Yannakakis 1984): suboptimal = polynomial, optimal = NP
  4) Construct junction tree (Kruskal): polynomial in # of cliques
  5) Junction tree algorithm (init, collect, distribute, normalize):
     polynomial (linear) in # of cliques, exponential in clique cardinality
Junction Tree Algorithm
• Convert the directed graph to a junction tree with cliques
  x1x2, x2x3, x3x4, x3x5, x5x6, x5x7
• Initialize separators to 1 (and Z = 1) and set the clique
  tables to the CPTs in the directed graph:

  p(X) = p(x1) p(x2|x1) p(x3|x2) p(x4|x3) p(x5|x3) p(x6|x5) p(x7|x5)

  p(X) = (1/Z) ∏_C ψ(X_C) / ∏_S φ(X_S)
       = (1/1) p(x1,x2) p(x3|x2) p(x4|x3) p(x5|x3) p(x6|x5) p(x7|x5) / (1×1×1×1×1)
• Run Collect, Distribute, Normalize
• Get valid marginals from all ψ , φ tables
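A quick numerical sanity check of this initialization: multiplying each CPT into one clique containing its family (clique x1x2 absorbs both p(x1) and p(x2|x1)) and leaving separators at 1 reproduces the joint exactly, with Z = 1. The `cpt` helper and the random CPTs are assumptions for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)

def cpt(n_parents):
    # hypothetical random CPT over binary variables, normalized over the child
    t = rng.random((2,) * (n_parents + 1))
    return t / t.sum(axis=-1, keepdims=True)

p1 = cpt(0)                                   # p(x1)
p21, p32, p43, p53, p65, p75 = (cpt(1) for _ in range(6))

# Initialization: clique x1x2 gets both p(x1) and p(x2|x1); every other
# clique gets exactly one conditional; all separators start at 1.
psi_x1x2 = p1[:, None] * p21

# With separators at 1, the product of clique tables IS the joint, so Z = 1.
joint = np.einsum("ab,bc,cd,ce,ef,eg->abcdefg",
                  psi_x1x2, p32, p43, p53, p65, p75)
print(round(joint.sum(), 10))                 # 1.0
```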
JTA with Extra Evidence
• If extra evidence is observed, we must slice the tables accordingly
• Example: p(A,B,C,D) = (1/Z) ψ_AB ψ_BC ψ_CD,  Z = 1

             B=0  B=1               C=0  C=1               D=0  D=1
  ψ_AB = A=0 [ 8    4 ]   ψ_BC = B=0 [ 2    3 ]   ψ_CD = C=0 [ 1    4 ]
         A=1 [ 3    1 ]          B=1 [ 1    1 ]          C=1 [ 1    1 ]

• You are given evidence: A=0. Replace the table with its slice:

  ψ_AB → [ 8  4 ]   (the A=0 row);  ψ_BC and ψ_CD are unchanged

• JTA now gives ψ, φ as marginals conditioned on the evidence:

  p(B | A=0) = ψ**_AB / ∑_B ψ**_AB
  p(B,C | A=0) = ψ**_BC / ∑_{B,C} ψ**_BC
  p(C,D | A=0) = ψ**_CD / ∑_{C,D} ψ**_CD

• All denominators equal the new normalizer Z′:

  Z′ = p(EVIDENCE) = ∑_B ψ**_AB = ∑_{B,C} ψ**_BC = ∑_{C,D} ψ**_CD
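A small numpy sketch of this example (the forward/backward message names are mine; the potentials are the slide's) confirms that after both sweeps every clique's total mass is the same number, Z′ = p(A=0):

```python
import numpy as np

# Chain A-B-C-D with the slide's clique potentials (Z = 1 before evidence)
psi_AB = np.array([[8., 4.], [3., 1.]])   # rows A=0,1; cols B=0,1
psi_BC = np.array([[2., 3.], [1., 1.]])   # rows B=0,1; cols C=0,1
psi_CD = np.array([[1., 4.], [1., 1.]])   # rows C=0,1; cols D=0,1

# Evidence A=0: keep only the A=0 slice of psi_AB (now a table over B alone)
psi_B = psi_AB[0]                          # [8, 4]

# Forward (collect) and backward (distribute) messages along the chain
fwd_C = psi_B @ psi_BC                     # unnormalized evidence mass on C
back_C = psi_CD.sum(axis=1)                # sum out D
back_B = psi_BC @ back_C                   # sum out C, then D

# Clique beliefs after both sweeps; each sums to Z' = p(A=0)
bel_B = psi_B * back_B                     # table over B
bel_BC = psi_B[:, None] * psi_BC * back_C[None, :]
bel_CD = fwd_C[:, None] * psi_CD
Zp = bel_B.sum()
print(Zp, bel_BC.sum(), bel_CD.sum())      # all equal: 156.0
print("p(B|A=0) =", bel_B / Zp)            # [128, 28] / 156
```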
ArgMax Junction Tree Algorithm
• We can also use JTA for finding the max, not the sum,
  over the joint to get the argmax of marginals & conditionals
• Say we have some evidence: p(X_F, X_E) = p(x1,…,xn, xn+1,…,xN)
• Most likely (highest probability) X_F?
    X*_F = argmax_{X_F} p(X_F, X_E)
• What is the most likely state of a patient with fever & headache?
    p*_F = max_{x2,x3,x4,x5} p(x1=1, x2, x3, x4, x5, x6=1)
         = max_{x2} p(x2 | x1=1) p(x1=1) max_{x3} p(x3 | x1=1)
           max_{x4} p(x4 | x2) max_{x5} p(x5 | x3) p(x6=1 | x2,x5)
• Solution: replace sum with max inside the JTA update code:
    φ*_{V∩W} = max_{V\(V∩W)} ψ_V ;  ψ*_W = (φ*_{V∩W} / φ_{V∩W}) ψ_W
    φ**_{V∩W} = max_{V\(V∩W)} ψ_V ;  ψ**_W = (φ**_{V∩W} / φ*_{V∩W}) ψ_W
• The final potentials are max-marginals:  ψ**(X_C) = max_{U\C} p(X)
• The highest value in each potential is most likely:  X*_C = argmax_{X_C} ψ**(X_C)
ArgMax Junction Tree Algorithm
• Why do we need the ArgMax junction tree algorithm?
• Can't we just compute marginals using the sum algorithm
  and then find the highest value in each marginal?
• No! Here's a counter-example:

                  x1=A  x1=B  x1=C
  p(x1,x2) = x2=0 [ .14   .05   .27 ]
             x2=1 [ .24   .20   .10 ]

• The most likely joint state is x1* = C and x2* = 0
• But the sub-marginals p(x1) and p(x2) do not reveal this…

  p(x1) = [ 0.38  0.25  0.37 ]   (x1 = A, B, C)
  p(x2) = [ 0.46  0.54 ]         (x2 = 0, 1)

• The marginals would falsely imply that x1* = A and x2* = 1
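The counter-example is easy to verify numerically; a minimal sketch using the slide's table:

```python
import numpy as np

# The slide's joint: rows x2 = 0,1; columns x1 = A,B,C
p = np.array([[.14, .05, .27],
              [.24, .20, .10]])

# Joint argmax: the 0.27 entry, i.e. x1* = C (index 2), x2* = 0
x2_star, x1_star = np.unravel_index(p.argmax(), p.shape)
print(x1_star, x2_star)                        # 2 0

# The marginal argmaxes point elsewhere: x1 = A and x2 = 1
print(p.sum(axis=0))                           # [0.38 0.25 0.37] -> peak at A
print(p.sum(axis=1))                           # [0.46 0.54]      -> peak at x2=1
```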
Example
  Junction tree:  AB - [B] - BC
• Note that products are element-wise
• Let us send a regular JTA message from AB to BC

             B=0  B=1               C=0  C=1                 B=0
  ψ_AB = A=0 [ 8    4 ]   ψ_BC = B=0 [ 2    3 ]   φ_B = [ 1 ]
         A=1 [ 3    1 ]          B=1 [ 1    1 ]         [ 1 ]  B=1

  φ*_B = ∑_A ψ_AB = ∑_A [ 8  4 ; 3  1 ] = [ 11  5 ]   (B=0, B=1)

  ψ*_BC = (φ*_B / φ_B) ψ_BC = [ 11·2  11·3 ] = [ 22  33 ]  B=0
                              [  5·1   5·1 ]   [  5   5 ]  B=1
                                                 C=0  C=1

• For argmax JTA, just change the separator update to:

  φ*_B = max_A ψ_AB = max_A [ 8  4 ; 3  1 ] = [ 8  4 ]   (B=0, B=1)
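Both message types from this example can be reproduced in a few lines of numpy (the variable names are mine; the tables are the slide's):

```python
import numpy as np

psi_AB = np.array([[8., 4.], [3., 1.]])    # rows A=0,1; cols B=0,1
psi_BC = np.array([[2., 3.], [1., 1.]])    # rows B=0,1; cols C=0,1
phi_B = np.ones(2)                         # separator initialized to 1

# Sum-product message AB -> BC: marginalize out A, then rescale psi_BC
phi_B_sum = psi_AB.sum(axis=0)             # [11, 5]
psi_BC_sum = (phi_B_sum / phi_B)[:, None] * psi_BC
print(phi_B_sum, psi_BC_sum)               # [[22, 33], [5, 5]]

# Max-product: only the separator update changes, from sum to max
phi_B_max = psi_AB.max(axis=0)             # [8, 4]
psi_BC_max = (phi_B_max / phi_B)[:, None] * psi_BC
print(phi_B_max, psi_BC_max)               # [[16, 24], [4, 4]]
```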