The Rescorla-Wagner algorithm and Maximum Likelihood estimation of causal parameters
Alan Yuille
Department of Statistics
University of California at Los Angeles
Los Angeles, CA 90095
yuille@stat.ucla.edu
Abstract
This paper analyzes generalizations of the classic Rescorla-Wagner (RW) learning algorithm and studies their relationship to Maximum Likelihood estimation of causal parameters. We prove that the parameters
of two popular causal models, ∆P and P C, can be learnt by the same
generalized linear Rescorla-Wagner (GLRW) algorithm provided genericity conditions apply. We characterize the fixed points of these GLRW
algorithms and calculate the fluctuations about them, assuming that the
input is a set of i.i.d. samples from a fixed (unknown) distribution. We
describe how to determine convergence conditions and calculate convergence rates for the GLRW algorithms under these conditions.
1 Introduction
There has recently been growing interest in models of causal learning formulated as probabilistic inference [1,2,3,4,5]. There has also been considerable interest in relating this work
to the Rescorla-Wagner learning model [3,5,6] (also known as the delta rule). In addition,
there are studies of the equilibria of the Rescorla-Wagner model [6].
This paper proves mathematical results about these related topics. In Section (2), we describe two influential models, ∆P and P C, for causal inference and how their parameters
can be learnt by maximum likelihood estimation from training data. Section (3) introduces
the generalized linear Rescorla-Wagner (GLRW) algorithm, characterizes its fixed points and quantifies its fluctuations. We demonstrate that a simple GLRW can estimate the ML parameters for both the ∆P and P C models provided certain genericity conditions are
satisfied. But the experimental conditions studied by Cheng [2] require a non-linear generalization of Rescorla-Wagner (Yuille, in preparation). Section (4) gives a way to determine
convergence conditions and calculate the convergence rates of GLRW algorithms. Finally, Section (5) sketches how the results in this paper can be extended to allow for an arbitrary number of causes.
2 Causal Learning and Probabilistic Inference
The task is to estimate the causal effect of variables. There is an observed event E and
two causes C1 , C2 . Observers are asked to determine the causal power of the two causes.
The variables are binary-valued. E = 1 means the event occurs, E = 0 means it does
not. Similarly for causes C1 and C2 . Much of the work in this section can be generalized
to cases where there are an arbitrary number of causes C1 , C2 , ..., CN , see section (5).
The training data {(E^µ, C1^µ, C2^µ)} is assumed to be samples from an unknown distribution Pemp(E, C1, C2).
Two simple models, ∆P [1] and P C [2,3], have been proposed to account for how people
estimate causal power. There is also a more recent theory based on model selection [4].
The ∆P and P C theories are equivalent to assuming probability distributions for how
the training data is generated. Then the power of the causes is given by the maximum
likelihood estimation of the distribution parameters ω1 , ω2 . The two theories correspond to
probability distributions P∆P (E|C1 , C2 , ω1 , ω2 ) and PP C (E|C1 , C2 , ω1 , ω2 ) given by:
P∆P(E = 1|C1, C2, ω1, ω2) = ω1 C1 + ω2 C2. (∆P model)  (1)
PP C(E = 1|C1, C2, ω1, ω2) = ω1 C1 + ω2 C2 − ω1 ω2 C1 C2. (P C model)  (2)
The latter is a noisy-or model. The event E = 1 can be caused by C1 = 1 with probability
ω1 , by C2 = 1 with probability ω2 , or caused by both. The model can be derived by setting
PP C (E = 0|C1 , C2 , ω1 , ω2 ) = (1 − ω1 C1 )(1 − ω2 C2 ).
We assume that there is also a distribution on the causes P(C1, C2 |~γ) which the observers also learn from the training data. This is equivalent to maximizing (with respect to ω1, ω2, ~γ):

P({(E^µ, ~C^µ)} : ~ω, ~γ) = Π_µ P(E^µ, ~C^µ : ~ω, ~γ) = Π_µ P(E^µ |~C^µ : ~ω) P(~C^µ : ~γ).  (3)
By taking logarithms, we see that estimating ω1, ω2 and estimating ~γ are independent problems. So we will concentrate on estimating ω1, ω2.
If the training data {E^µ, ~C^µ} is consistent with the model – i.e. there exist parameters ω1, ω2 such that Pemp(E|C1, C2) = P(E|C1, C2, ω1, ω2) – then we can calculate the solution directly.
For the ∆P model, we have:

ω1 = Pemp(E = 1|C1 = 1, C2 = 0) = Pemp(E = 1|C1 = 1),
ω2 = Pemp(E = 1|C1 = 0, C2 = 1) = Pemp(E = 1|C2 = 1).  (4)
For the PP C model, we obtain Cheng's measures of causality [2,3]:

ω1 = [Pemp(E = 1|C1 = 1, C2) − Pemp(E = 1|C1 = 0, C2)] / [1 − Pemp(E = 1|C1 = 0, C2)],
ω2 = [Pemp(E = 1|C1, C2 = 1) − Pemp(E = 1|C1, C2 = 0)] / [1 − Pemp(E = 1|C1, C2 = 0)].  (5)
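As a concrete illustration, here is a minimal Python sketch (the function name and the choice to condition on C2 = 0 are ours) that computes both closed-form estimates, equations (4) and (5), from i.i.d. binary trials:

```python
import numpy as np

def estimate_omega1(E, C1, C2):
    """Closed-form ML estimates of cause 1's strength, equations (4) and (5).

    E, C1, C2: equal-length binary arrays of trials; cause 2 is symmetric.
    Assumes genericity: both conditioning events occur in the sample.
    """
    E, C1, C2 = map(np.asarray, (E, C1, C2))
    p1 = E[(C1 == 1) & (C2 == 0)].mean()   # Pemp(E=1 | C1=1, C2=0)
    p0 = E[(C1 == 0) & (C2 == 0)].mean()   # Pemp(E=1 | C1=0, C2=0)
    omega_deltaP = p1                      # eq. (4): Delta-P model
    omega_PC = (p1 - p0) / (1.0 - p0)      # eq. (5): Cheng's causal power
    return omega_deltaP, omega_PC
```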
3 Generalized Linear Rescorla-Wagner
The Rescorla-Wagner model [7] is an alternative way to account for human learning. This
iterative algorithm specifies an update rule for weights. These weights could measure the
strength of a cause, such as the parameters estimated by Maximum Likelihood. Following recent work [3,6], we seek to find relationships between the generalized linear Rescorla-Wagner (GLRW) algorithm and ML estimation.
3.1 GLRW and two special cases
The Rescorla-Wagner algorithm updates weights {~V} using training data {E^µ, ~C^µ}. It is of the form:

~V^{t+1} = ~V^t + ∆~V^t.  (6)
In this paper, we are particularly concerned with two special cases for the choice of the update ∆V:

∆V1 = α1 C1 (E − C1 V1 − C2 V2), ∆V2 = α2 C2 (E − C1 V1 − C2 V2), (basic)  (7)
∆V1 = α1 C1 (1 − C2)(E − V1), ∆V2 = α2 C2 (1 − C1)(E − V2), (variant)  (8)
The first (7) is the basic RW algorithm. The second (8) is a variant of RW with a natural
interpretation – a weight V1 is updated only if one cause is present, C1 = 1, and the other
cause is absent, C2 = 0.
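To make the two rules concrete, the following Python sketch (function names and toy parameters are ours) implements updates (7) and (8), and runs the variant rule on samples from the P C model; the weights fluctuate around (ω1, ω2), as Theorem 2 below makes precise:

```python
import numpy as np

rng = np.random.default_rng(0)

def rw_basic(E, C1, C2, V, alpha=(0.1, 0.1)):
    """One step of the basic Rescorla-Wagner rule, equation (7)."""
    pred = C1 * V[0] + C2 * V[1]
    V[0] += alpha[0] * C1 * (E - pred)
    V[1] += alpha[1] * C2 * (E - pred)
    return V

def rw_variant(E, C1, C2, V, alpha=(0.1, 0.1)):
    """One step of the variant rule, equation (8): V1 moves only when
    C1 is present and C2 absent (and symmetrically for V2)."""
    V[0] += alpha[0] * C1 * (1 - C2) * (E - V[0])
    V[1] += alpha[1] * C2 * (1 - C1) * (E - V[1])
    return V

# Toy run: variant rule on data from the noisy-or (PC) model, eq. (2).
w1, w2, V = 0.8, 0.4, [0.0, 0.0]
for _ in range(20000):
    C1, C2 = rng.integers(0, 2, size=2)
    E = rng.random() < w1*C1 + w2*C2 - w1*w2*C1*C2
    V = rw_variant(E, C1, C2, V)
print(V)   # fluctuates around (w1, w2)
```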
The most general GLRW is of the form:

∆V_i^t = Σ_{j=1}^N V_j^t f_ij(E^t, ~C^t) + g_i(E^t, ~C^t), ∀i,  (9)

where {f_ij(·, ·) : i, j = 1, ..., N} and {g_i(·) : i = 1, ..., N} are functions of the data samples E^µ, ~C^µ.
3.2 GLRW and Stochastic Samples
Our analysis assumes that the data samples {(E^µ, ~C^µ)} are independent identically distributed (i.i.d.) samples from an unknown distribution Pemp(E|~C)P(~C).
In this case, the GLRW becomes stochastic. It defines a distribution on the weights which is updated as follows:

P(~V^{t+1}|~V^t) = ∫ dE^t d~C^t Π_{i=1}^N δ(V_i^{t+1} − V_i^t − ∆V_i^t) P(E^t, ~C^t).  (10)
This defines a Markov chain. If certain conditions are satisfied (see section (4)), the chain will converge to a fixed distribution P*(V). This distribution can be characterized by its expected mean <V>* = Σ_V V P*(V) and its expected covariance Σ* = Σ_V (V − <V>*)(V − <V>*)^T P*(V). In other words, even after convergence the weights will fluctuate about the expected mean <V>*, and the magnitude of the fluctuations will be given by the expected covariance.
3.3 What Does GLRW Converge to?
We now compute the means and covariance of the fixed point distribution P*(~V). We first do this for the general GLRW, equation (9), and then we restrict ourselves to the two special cases, equations (7,8).
Theorem 1. The means ~V* and the covariance Σ* of the fixed point distribution P*(~V), using the GLRW equation (9) and any empirical distribution Pemp(E, ~C), are given by the solutions to the linear equations

Σ_{j=1}^N V_j* Σ_{E,~C} f_ij(E, ~C) Pemp(E, ~C) + Σ_{E,~C} g_i(E, ~C) Pemp(E, ~C) = 0, ∀i,  (11)
and, ∀i, k:

Σ*_ik = Σ_{j,l} Σ*_jl Σ_{E,~C} A_ij(E, ~C) A_kl(E, ~C) Pemp(E, ~C) + Σ_{E,~C} B_i(E, ~C) B_k(E, ~C) Pemp(E, ~C),  (12)

where A_ij(E, ~C) = δ_ij + f_ij(E, ~C) and B_i(E, ~C) = Σ_j V_j* f_ij(E, ~C) + g_i(E, ~C) (here δ_ij is the Kronecker delta function, defined by δ_ij = 1 if i = j and δ_ij = 0 otherwise).

The means have a unique solution provided Σ_{E,~C} Pemp(E, ~C) f_ij(E, ~C) is an invertible matrix.
Proof. We derive the formula for the means ~V* by taking the expectation of the update rule, equation (9), with respect to P*(~V) and Pemp(E, ~C). To calculate the covariances, we express the update rule as:

V_i^{t+1} − V_i* = Σ_j (V_j^t − V_j*) A_ij(E, ~C) + B_i(E, ~C), ∀i,  (13)

with A_ij(E, ~C) and B_i(E, ~C) defined as above. Then we multiply both sides of equation (13) by their transposes (e.g. the left hand side by (V_k^{t+1} − V_k*)) and take the expectation with respect to P*(~V) and Pemp(E, ~C), making use of the result that the expected value of V_j^t − V_j* is zero as t → ∞.
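As a numerical sanity check of equation (11), the sketch below (function names and the example setup are ours) solves the linear system for V* directly; instantiated with the variant rule (8) on P C-model data with uniform causes, it recovers (ω1, ω2):

```python
import numpy as np
from itertools import product

def glrw_fixed_point(f, g, p_emp, N=2):
    """Solve the linear equations (11) for the expected fixed point V*.

    f(i, j, E, C): coefficient functions f_ij of the GLRW, equation (9)
    g(i, E, C):    driving terms g_i
    p_emp:         dict mapping (E, C1, C2) -> empirical probability
    """
    F, b = np.zeros((N, N)), np.zeros(N)
    for (E, *C), p in p_emp.items():
        for i in range(N):
            b[i] += g(i, E, C) * p
            for j in range(N):
                F[i, j] += f(i, j, E, C) * p
    return np.linalg.solve(F, -b)   # equation (11): F V* + b = 0

# Variant rule (8): f_ii = -alpha_i C_i (1 - C_other), g_i = -f_ii * E.
alpha, w = (0.1, 0.1), (0.8, 0.4)
f = lambda i, j, E, C: -(i == j) * alpha[i] * C[i] * (1 - C[1 - i])
g = lambda i, E, C: alpha[i] * C[i] * (1 - C[1 - i]) * E

p_emp = {}
for C1, C2 in product((0, 1), repeat=2):           # uniform causes
    pE = w[0]*C1 + w[1]*C2 - w[0]*w[1]*C1*C2       # PC model, eq. (2)
    p_emp[(1, C1, C2)] = 0.25 * pE
    p_emp[(0, C1, C2)] = 0.25 * (1 - pE)
print(glrw_fixed_point(f, g, p_emp))               # -> [0.8 0.4]
```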
We can apply these results to study the behaviour of the two special cases, equations (7,8),
when the data is generated by either the ∆P or P C model.
First consider the basic RW algorithm (7) when the data is generated by the P∆P model. We can use Theorem 1 to rederive the result that <~V>* = ~ω [3,6], and so basic RW performs ML estimation for the P∆P model. It also follows directly that if the data is generated by the PP C model, then <~V>* ≠ ~ω (although they are related by a nonlinear equation).
Now consider the variant RW, equation (8).
Theorem 2. The expected means of the fixed points of the variant RW equation (8), when the data is generated by probability model PP C(E|~C; ~ω) or P∆P(E|~C; ~ω), are given by:

V_1* = ω1, V_2* = ω2,  (14)

provided Pemp(~C) satisfies genericity conditions so that <C1(1 − C2)><C2(1 − C1)> ≠ 0.
The expected covariances are given by:

Σ11 = ω1(1 − ω1) α1/(2 − α1), Σ22 = ω2(1 − ω2) α2/(2 − α2), Σ12 = Σ21 = 0.  (15)
Proof. This is a direct calculation of the quantities specified in Theorem 1. For example, we calculate the expected value of ∆V1 and ∆V2, first with respect to P(E|~C) and then with respect to P*(V). This gives:

<∆V1>_{P(E|~C)P*(V)} = α1 C1(1 − C2)(ω1 − V_1*),
<∆V2>_{P(E|~C)P*(V)} = α2 C2(1 − C1)(ω2 − V_2*),  (16)

where we have used Σ_V P*(V) V = V*, Σ_E E PP C(E|~C) = ω1 C1 + ω2 C2 − ω1 ω2 C1 C2, and logical relations to simplify the terms (e.g. C1^2 = C1, C1(1 − C1) = 0).

Taking the expectation of <∆V1>_{P(E|~C)P*(V)} with respect to P(C) gives

α1 ω1 <C1(1 − C2)>_{P(C)} − α1 V_1* <C1(1 − C2)> = 0,
α2 ω2 <C2(1 − C1)>_{P(C)} − α2 V_2* <C2(1 − C1)> = 0,  (17)

and the result follows directly, except for the non-generic cases where <C1(1 − C2)> = 0 or <C2(1 − C1)> = 0. These degenerate cases are analyzed separately.
It is perhaps surprising that the same GLRW algorithm can perform ML estimation when
the data is generated by either model P∆P or PP C (and this can be generalized, see section (5)). Moreover, the expected covariance is the same for both models. Observe that the
covariance decreases if we make the update coefficients α1 , α2 of the algorithm small. The
convergence rates are given in the next section.
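A quick Monte Carlo check of the covariance formula (15); the parameter values below are our own illustrative choice:

```python
import numpy as np

rng = np.random.default_rng(1)
w1, w2, a1 = 0.8, 0.4, 0.2
V1, samples = 0.0, []
for t in range(200000):
    C1, C2 = rng.integers(0, 2, size=2)                  # uniform causes
    E = rng.random() < w1*C1 + w2*C2 - w1*w2*C1*C2       # PC model, eq. (2)
    V1 += a1 * C1 * (1 - C2) * (E - V1)                  # variant rule (8)
    if t > 50000:                                        # discard burn-in
        samples.append(V1)
print(np.var(samples))                 # empirical fluctuation of V1
print(w1 * (1 - w1) * a1 / (2 - a1))   # Theorem 2 prediction, ~0.0178
```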
The non-generic cases include the situation studied in [2] where C1 is a background cause that is assumed to be always present, so <C1> = 1. In this case V_1* = ω1, but V_2* is unspecified. It can be shown (Yuille, in preparation) that a nonlinear generalization of RW can perform ML on this problem (but it is easy to check that no GLRW can). An even more ambiguous case occurs when ω1 = 1 (i.e. cause C1 always causes event E); then there is no way to estimate ω2, and Cheng's measure of causality, equation (5), becomes undefined.
4 Convergence of Rescorla-Wagner
We now analyze the convergence of the GLRW algorithm. We obtain conditions for the
algorithm to converge and give the convergence rates. For simplicity, the results will be
illustrated only on the simple models.
Our results are based on the following theorem for the convergence of the state vector
of a stochastic iterative equation. The theorem gives necessary and sufficient conditions
for convergence, shows what the expected state vector converges to, and gives the rate of
convergence.
Theorem 3. Let ~z_{t+1} = A_t ~z_t be an iterative update equation, where ~z is a state vector and the update matrices A_t are i.i.d. samples from P(A). The convergence properties as t → ∞ depend on <A> = Σ_A A P(A). If <A> has a unit eigenvalue with eigenvector ~z* and the next largest eigenvalue has modulus λ < 1, then lim_{t→∞} <~z_t> ∝ ~z* and the rate of convergence is e^{t log λ}. If the moduli of the eigenvalues of <A> are all less than 1, then lim_{t→∞} <~z_t> = 0. If <A> has an eigenvalue with modulus greater than 1, then <~z_t> diverges as t → ∞.
Proof. This is a standard result. To obtain it, write ~z_{t+1} = A_t A_{t−1} ··· A_1 ~z_1, where ~z_1 is the initial condition. Now take the expectation of ~z_{t+1} with respect to the samples {A_t}. By the i.i.d. assumption, this gives <~z_{t+1}> = <A>^t ~z_1. The result follows by linear algebra. Let the eigenvectors and eigenvalues of <A> be {(λ_i, ~e_i)}. Express the initial condition as ~z_1 = Σ_i γ_i ~e_i, where the {γ_i} are coefficients. Then <~z_t> = Σ_i γ_i λ_i^t ~e_i, and the result follows.
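Theorem 3 is easy to verify numerically. The 2×2 example below (our own illustrative choice of P(A), supported on two matrices) averages sampled trajectories and compares the result with the unit eigenvector of <A>:

```python
import numpy as np

rng = np.random.default_rng(4)
# A_t is one of two fixed matrices, each with probability 1/2 (our choice).
A_pool = np.array([[[0.5, 0.2], [0.0, 1.0]],
                   [[0.7, 0.4], [0.0, 1.0]]])
A_mean = A_pool.mean(axis=0)          # <A> = [[0.6, 0.3], [0.0, 1.0]]
print(np.linalg.eigvals(A_mean))      # eigenvalues 1 and 0.6

total = np.zeros(2)
for _ in range(2000):
    z = np.array([1.0, 1.0])          # initial condition z_1
    for _ in range(100):
        z = A_pool[rng.integers(2)] @ z   # one i.i.d. update A_t z_t
    total += z
print(total / 2000)   # ~ (0.75, 1): the unit eigenvector of <A>
```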
We use Theorem 3 to obtain convergence results for the GLRW algorithm. To ensure convergence, we need both the expected covariance and the expected means to converge. Then
Markov’s lemma can be used to bound the fluctuations. (If we just require the expected
means to converge, then the fluctuations of the weights may be infinitely large). This can
be done by a suitable choice of the state vector ~z.
For simplicity of algebra, we demonstrate this for a GLRW algorithm with a single weight. The update rule is V_{t+1} = a_t V_t + b_t, where a_t, b_t are random samples. We define the state vector to be ~z = (V_t^2, V_t, 1).
Theorem 4. Consider the stochastic update rule V_{t+1} = a_t V_t + b_t, where a_t and b_t are samples from distributions P(a) and P(b). Define α1 = Σ_a a^2 P(a), α2 = Σ_a a P(a), β1 = Σ_b b^2 P(b), β2 = Σ_b b P(b), and γ = 2 Σ_{a,b} a b P(a, b). The algorithm converges if, and only if, α1 < 1, α2 < 1. If so, then

lim_{t→∞} <V_t> = <V> = β2/(1 − α2),
lim_{t→∞} <(V_t − <V>)^2> = [β1(1 − α2) + γ β2] / [(1 − α1)(1 − α2)] − β2^2/(1 − α2)^2.

The convergence rate is max{α1, |α2|}^t.
Proof. Define ~z_t = (V_t^2, V_t, 1) and express the update rule in matrix form:

[ V_{t+1}^2 ]   [ a_t^2   2 a_t b_t   b_t^2 ] [ V_t^2 ]
[ V_{t+1}   ] = [ 0       a_t         b_t   ] [ V_t   ]
[ 1         ]   [ 0       0           1     ] [ 1     ]

This is of the form analyzed in Theorem 3 provided we set:

    [ a_t^2   2 a_t b_t   b_t^2 ]            [ α1   γ    β1 ]
A = [ 0       a_t         b_t   ]  and <A> = [ 0    α2   β2 ],
    [ 0       0           1     ]            [ 0    0    1  ]

where α1 = Σ_a a^2 P(a), α2 = Σ_a a P(a), β1 = Σ_b b^2 P(b), β2 = Σ_b b P(b), and γ = 2 Σ_{a,b} a b P(a, b).

The eigenvalues {λ} and eigenvectors {~e} of <A> are:

λ1 = 1, ~e1 ∝ ( [β1(1 − α2) + γ β2] / [(1 − α1)(1 − α2)], β2/(1 − α2), 1 ),
λ2 = α1, ~e2 = (1, 0, 0), λ3 = α2, ~e3 ∝ ( γ/(α2 − α1), 1, 0 ).  (18)

The result follows from Theorem 3.
Observe that if |α2| < 1 but α1 > 1, then <V_t> will converge but the expected variance does not. The fluctuations in the GLRW algorithm will be infinitely large.
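A direct simulation checks the formulas of Theorem 4; the distributions for a_t and b_t below are our own illustrative choice (independent, so γ = 2<a><b>):

```python
import numpy as np

rng = np.random.default_rng(2)
# Our choice: a_t uniform on {0.9, 0.5}, b_t uniform on {0.0, 0.1}.
a_vals, b_vals = np.array([0.9, 0.5]), np.array([0.0, 0.1])
a1, a2 = (a_vals**2).mean(), a_vals.mean()   # alpha_1 = <a^2>, alpha_2 = <a>
b1, b2 = (b_vals**2).mean(), b_vals.mean()   # beta_1  = <b^2>, beta_2  = <b>
gamma = 2 * a2 * b2                          # a_t, b_t independent here

V, hist = 0.0, []
for t in range(100000):
    V = rng.choice(a_vals) * V + rng.choice(b_vals)
    if t > 1000:                             # discard the transient
        hist.append(V)
mean_pred = b2 / (1 - a2)
var_pred = (b1*(1 - a2) + gamma*b2) / ((1 - a1)*(1 - a2)) - mean_pred**2
print(np.mean(hist), mean_pred)              # both ~ 0.1667
print(np.var(hist), var_pred)                # both ~ 0.0077
```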
We can extend Theorem 4 to the variant RW, equation (8). Let P = Pemp; then

β12 = Σ_{E,~C} P(E|~C) P(~C) C1(1 − C2), β21 = Σ_{E,~C} P(E|~C) P(~C) C2(1 − C1),
γ12 = Σ_{E,~C} P(E|~C) P(~C) E C1(1 − C2), γ21 = Σ_{E,~C} P(E|~C) P(~C) E C2(1 − C1).  (19)
If the data is generated by P∆P or PP C, then β12, β21, γ12, γ21 take the same values:

β12 = <C1(1 − C2)>, β21 = <(1 − C1)C2>,
γ12 = ω1 <C1(1 − C2)>, γ21 = ω2 <(1 − C1)C2>.  (20)
Theorem 5. The algorithm specified by equation (8) converges provided λ* = max{|λ2|, |λ3|, |λ4|, |λ5|} < 1, where λ2 = 1 − (2α1 − α1^2)β12, λ3 = 1 − (2α2 − α2^2)β21, λ4 = 1 − α1 β12, λ5 = 1 − α2 β21. The convergence rate is e^{t log λ*}. The expected means and covariances can be calculated from the first eigenvector.
Proof. We define the state vector ~z = (V_1^2, V_2^2, V_1, V_2, 1) and derive the update matrix A from equation (8). The eigenvectors and eigenvalues can be calculated (calculations omitted due to space constraints). The eigenvalues are 1, λ2, λ3, λ4, λ5. The convergence conditions and rates follow from Theorem 3. The expected means and covariances can be calculated from the first eigenvector, which is:

~e1 = ( 2(α1 − α1^2)γ12^2/[(2α1 − α1^2)β12^2] + α1^2 γ12/[(2α1 − α1^2)β12],
        2(α2 − α2^2)γ21^2/[(2α2 − α2^2)β21^2] + α2^2 γ21/[(2α2 − α2^2)β21],
        γ12/β12, γ21/β21, 1 ),  (21)

and this agrees with the calculations given in Theorem 2.
5 Generalization
The results of the previous sections can be generalized to cases where there are more than two causes. For example, we can use the generalization of the P C model to include multiple generative causes ~C and preventative causes ~L [5], extending [2].
The probability distribution for this generalized P C model is:

PP C(E = 1|~C, ~L; ~ω, ~Ω) = {1 − Π_{i=0}^n (1 − ωi Ci)} Π_{j=1}^m (1 − Ωj Lj),  (22)

where there are n + 1 generative causes {Ci} and m preventative causes {Lj}, specified in terms of parameters {ωi} and {Ωj} (constrained to lie between 0 and 1).
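Sampling from this model is immediate; a minimal sketch (the function name and example values are ours):

```python
import numpy as np

rng = np.random.default_rng(3)

def sample_event(C, L, w, W):
    """Sample E from the generalized PC model, equation (22).

    C, w: binary generative causes and strengths (C[0] is the always-on
          background cause, so pass C[0] = 1).
    L, W: binary preventative causes and their strengths.
    """
    p_gen = 1.0 - np.prod(1.0 - np.asarray(w) * np.asarray(C))
    p_prev = np.prod(1.0 - np.asarray(W) * np.asarray(L))
    return int(rng.random() < p_gen * p_prev)

# Background on, one generative cause active, one preventer inactive:
print(sample_event(C=[1, 1], L=[0], w=[0.1, 0.7], W=[0.5]))
```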
We assume that there is a single background cause C0 which is always on (i.e. C0 = 1)
and whose strength ω0 is known (for relaxing this constraint, see Yuille in preparation).
Then it can be shown that the following GLRW algorithm will converge to the ML estimates of the remaining parameters {ωi : i = 1, ..., n} and {Ωj : j = 1, ..., m} of the generalized P C model:
∆V_k^t = C_k { Π_{i=1}^m (1 − L_i) Π_{j=1: j≠k}^n (1 − C_j) } (E − ω0 − (1 − ω0) V_k^t),
∆U_l^t = L_l { Π_{k=1: k≠l}^m (1 − L_k) Π_{j=1}^n (1 − C_j) } (E − ω0 − ω0 U_l^t),  (23)

where {V_k : k = 1, ..., n} and {U_l : l = 1, ..., m} are weights.
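A direct Python transcription of rule (23) may be helpful; the function name and indexing conventions are ours, and note that (23) as written carries no explicit learning rate:

```python
import numpy as np

def glrw_step(E, C, L, V, U, w0):
    """One update of equation (23), a sketch under our indexing assumptions.

    C: binary generative causes C_1..C_n (the background C_0 = 1 enters
       only through its known strength w0); L: binary preventative causes.
    V, U: weight arrays for the generative and preventative causes.
    """
    C, L = np.asarray(C), np.asarray(L)
    no_prevent = np.prod(1 - L)               # all preventers absent
    for k in range(len(V)):
        gate = C[k] * no_prevent * np.prod(np.delete(1 - C, k))
        V[k] += gate * (E - w0 - (1 - w0) * V[k])
    for l in range(len(U)):
        gate = L[l] * np.prod(np.delete(1 - L, l)) * np.prod(1 - C)
        U[l] += gate * (E - w0 - w0 * U[l])
    return V, U
```

Iterating this step on i.i.d. samples from (22) yields a Markov chain on the weights that can be analyzed with the tools of sections (3) and (4).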
The proof is straightforward algebra and is based on the following identity for binary variables: Π_j (1 − Ωj Lj) Π_j (1 − Lj) = Π_j (1 − Lj).
The GLRW algorithm (23) will also perform ML estimation for data generated by other probability distributions which share the same linear terms as the generalized P C model (i.e. the terms linear in the {ωi} and {Ωj}). The convergence conditions and the convergence rates can be calculated using the techniques in section (4).
These results all assume genericity conditions so that none of the generative or preventative causes is either always on or always off (i.e. ruling out cases like [2]).
6 Conclusion
This paper introduced and studied generalized linear Rescorla-Wagner (GLRW) algorithms. We showed that two influential theories, ∆P and P C, for estimating causal effects can
be implemented by the same GLRW; see (8). We obtained convergence results for GLRW, including a classification of the fixed points, the asymptotic fluctuations, and the convergence rates. Our results assume that the inputs to the GLRW are i.i.d. samples from an unknown empirical distribution Pemp(E, ~C). Observe that the fluctuations of GLRW can be removed
by introducing damping coefficients which decrease over time. Stochastic approximation
theory [8] can then be used to give conditions for convergence.
More recent work (Yuille in preparation) clarifies the class of maximum likelihood inference problems that can be “solved” by GLRW and by non-linear GLRW. In particular, we
show that a non-linear RW can perform ML estimation for the non-generic case studied by
Cheng. We also investigate similarities to Kalman filter models [9].
Acknowledgements
I thank Patricia Cheng, Peter Dayan and Yingnian Wu for helpful discussions. Anonymous
referees gave useful feedback that has motivated a follow-up paper. This work was partially
supported by an NSF SLC catalyst grant “Perceptual Learning and Brain Plasticity” NSF
SBE-0350356.
References
[1]. B. A. Spellman. “Conditioning Causality”. In D.R. Shanks, K.J. Holyoak, and D.L.
Medin, (eds). Causal Learning: The Psychology of Learning and Motivation, Vol. 34.
San Diego, California. Academic Press. pp 167-206. 1996.
[2]. P. Cheng. “From Covariance to Causation: A Causal Power Theory”. Psychological
Review, 104, pp 367-405. 1997.
[3]. M. Buehner and P. Cheng. “Causal Induction: The power PC theory versus the Rescorla-Wagner theory”. In Proceedings of the 19th Annual Conference of the Cognitive Science Society. 1997.
[4]. J.B. Tenenbaum and T.L. Griffiths. “Structure Learning in Human Causal Induction”.
Advances in Neural Information Processing Systems 12. MIT Press. 2001.
[5]. D. Danks, T.L. Griffiths, J.B. Tenenbaum. “Dynamical Causal Learning”. Advances in
Neural Information Processing Systems 14. 2003.
[6]. D. Danks. “Equilibria of the Rescorla-Wagner Model”. Journal of Mathematical
Psychology. Vol. 47, pp 109-121. 2003.
[7]. R.A. Rescorla and A.R. Wagner. “A Theory of Pavlovian Conditioning: Variations
in the Effectiveness of Reinforcement and Nonreinforcement”. In A.H. Black and W.F.
Prokasy, eds. Classical Conditioning II: Current Research and Theory. New York.
Appleton-Century-Crofts, pp 64-99. 1972.
[8]. H.J. Kushner and D.S. Clark. Stochastic Approximation for Constrained and Unconstrained Systems. New York. Springer-Verlag. 1978.
[9]. P. Dayan and S. Kakade. “Explaining away in weight space”. In Advances in Neural
Information Processing Systems 13. 2001.