TAMING THE CURSE OF DIMENSIONALITY: DISCRETE INTEGRATION BY HASHING AND OPTIMIZATION

Stefano Ermon*, Carla P. Gomes*, Ashish Sabharwal+, and Bart Selman*
*Cornell University
+IBM Watson Research Center
ICML 2013
High-dimensional integration
• High-dimensional integrals in statistics, ML, physics
  • Expectations / model averaging
  • Marginalization
  • Partition function / rank models / parameter learning
• Curse of dimensionality:
  [Figure: hypercubes of side L in 1, 2, 3, 4, …, n dimensions, with volumes L, L^2, L^3, L^4, …, L^n]
  • Quadrature involves a weighted sum over an exponential number of items (e.g., units of volume)
Discrete Integration
[Figure: 2^n items; size visually represents weight]
• We are given
  • A set of 2^n items
  • Non-negative weights w
• Goal: compute the total weight
• Compactly specified weight function:
  • factored form (Bayes net, factor graph, CNF, …)
  • potentially a Turing machine
• Example 1: n = 2 dimensions, sum over 4 items with weights 5, 0, 2, 1
  • Goal: compute 5 + 0 + 2 + 1 = 8
• Example 2: n = 100 dimensions, sum over 2^100 ≈ 10^30 items (intractable)
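To make the problem concrete, here is a minimal brute-force sketch. The function name `brute_force_sum` and the encoding of items as bit-vector tuples are our own illustrative choices. It computes the total weight exactly by enumerating all 2^n items, and that exponential cost is what makes Example 2 intractable.

```python
from itertools import product

def brute_force_sum(weight, n):
    """Exact discrete integral: add weight(x) over all 2**n bit vectors x."""
    return sum(weight(x) for x in product((0, 1), repeat=n))

# Example 1: n = 2. Which bit vector gets which of the weights 5, 0, 2, 1
# is a hypothetical choice; the total does not depend on it.
table = {(0, 0): 5, (0, 1): 0, (1, 0): 2, (1, 1): 1}
print(brute_force_sum(table.__getitem__, 2))  # 8

# A compactly specified weight: w(x) = 2 ** (number of ones in x).
# Its exact sum is 3 ** n, but enumeration still touches all 2 ** n items.
print(brute_force_sum(lambda x: 2 ** sum(x), 3))  # 27
```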
Hardness
[Figure: complexity classes ordered from easy to hard: P, NP, PH, P^#P, PSPACE, EXP]
• 0/1 weights case:
  • Is there at least one “1”? → SAT
  • How many “1”s? → #SAT
  • NP-complete vs. #P-complete: counting is much harder
• General weights:
  • Find the heaviest item (combinatorial optimization)
  • Sum the weights (discrete integration)
• This Work: Approximate Discrete Integration via Optimization
• Combinatorial optimization (MIP, Max-SAT, CP) is also often fast in practice:
  • Relaxations / bounds
  • Pruning
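The two questions differ only in replacing max by sum. In the sketch below, a CNF formula defines 0/1 weights. The heaviest weight answers SAT and the total weight answers #SAT, both computed by brute force here. The helper names and the DIMACS-style signed-literal clause encoding are our own choices.

```python
from itertools import product

def cnf_weight(clauses):
    """0/1 weight: w(x) = 1 iff x satisfies every clause.
    A clause is a list of signed 1-based literals, e.g. -1 means (not x1)."""
    def w(x):
        return int(all(any((x[abs(lit) - 1] == 1) == (lit > 0) for lit in clause)
                       for clause in clauses))
    return w

def heaviest(w, n):
    """Optimization: for 0/1 weights this is 1 iff the formula is satisfiable."""
    return max(w(x) for x in product((0, 1), repeat=n))

def total(w, n):
    """Integration: for 0/1 weights this counts the satisfying assignments."""
    return sum(w(x) for x in product((0, 1), repeat=n))

w = cnf_weight([[1, 2], [-1, 3]])   # (x1 or x2) and (not x1 or x3)
print(heaviest(w, 3), total(w, 3))  # 1 4
```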
Previous approaches: Sampling
Idea:
  • Randomly select a region
  • Count within this region
  • Scale up appropriately
Advantage:
  • Quite fast
Drawback:
  • Robustness: can easily under- or over-estimate
  • Scalability in sparse spaces: e.g., 10^60 items with non-zero weight out of 10^300 means the region must be much larger than 10^240 to “hit” one
  • Can be partially mitigated using importance sampling
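For contrast with WISH, here is a sketch of the simplest version of the sampling idea: plain uniform Monte Carlo. That variant is our choice, since the slide does not fix one, and the function name is illustrative. The sketch draws items uniformly, averages their weights, and scales by 2^n. The sparse case reproduces the robustness problem at small scale: the estimate is either 0 or far too large.

```python
import random

def uniform_sampling_estimate(weight, n, samples, seed=0):
    """Plain Monte Carlo: mean weight of uniformly random items, times 2**n."""
    rng = random.Random(seed)
    total = 0
    for _ in range(samples):
        x = tuple(rng.getrandbits(1) for _ in range(n))   # uniform item in {0,1}^n
        total += weight(x)
    return total / samples * 2 ** n

# Dense case: w(x) = 2 ** sum(x) has exact total 3 ** 10 = 59049 for n = 10.
print(uniform_sampling_estimate(lambda x: 2 ** sum(x), 10, 20000))

# Sparse case: a single item of weight 1 among 2 ** 20 items (exact total 1).
# The estimate is 0 unless that item is drawn, and then at least 2 ** 20 / 1000.
needle = (1,) * 20
print(uniform_sampling_estimate(lambda x: int(x == needle), 20, 1000))
```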
Previous approaches: Variational methods
Idea:
  • For exponential families, use convexity
  • Variational formulation (optimization)
  • Solve approximately (using message-passing techniques)
Advantage:
  • Quite fast
Drawback:
  • Objective function is defined indirectly
  • Cannot represent the domain of optimization compactly
  • The domain needs to be approximated (BP, MF)
  • Typically no guarantees
A new approach: WISH
Suppose items are sorted by weight.
[Figure: CDF-style plot of how many items have weight at least b, versus b; endpoints b_0 = 100, b_1 = 70, b_2 = 9, b_3 = 5, b_4 = 2 with geometrically increasing bin sizes]
• The area under the curve equals the total weight we want to compute.
• How to estimate it? Divide it into slices and sum up.
• Geometrically divide the y axis; the endpoints b_i are the 2^i-largest weights (quantiles)
• Given the endpoints b_i, we have a 2-approximation: the area in each slice can be bounded within a factor of 2
• This also works if we only have approximations M_i of b_i
• How do we estimate the b_i?
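This sketch checks the slicing argument numerically on the 16 items in the slides' figures. Their weights (100, 70, 60, three 9s, seven 5s, and three 2s) are read off the figure residue. The sketch assumes exactly 2^n items and exact endpoints b_i, and the helper names are hypothetical. The items ranked 2^i + 1 through 2^{i+1} are 2^i in number, each with weight between b_{i+1} and b_i. This gives the two bounds below, and the upper bound is at most twice the lower one. The lower bound is the WISH return value with M_i = b_i.

```python
def quantile_endpoints(weights):
    """b_i = the 2**i-th largest weight, for i = 0..n (needs exactly 2**n items)."""
    s = sorted(weights, reverse=True)
    n = len(s).bit_length() - 1
    assert len(s) == 2 ** n, "expects exactly 2**n items"
    return [s[2 ** i - 1] for i in range(n + 1)]

def slice_bounds(b):
    """Bounds on the total weight from endpoints b_0 >= b_1 >= ... >= b_n."""
    n = len(b) - 1
    lower = b[0] + sum(b[i + 1] * 2 ** i for i in range(n))   # slice i >= 2**i * b_{i+1}
    upper = b[0] + sum(b[i] * 2 ** i for i in range(n))       # slice i <= 2**i * b_i
    return lower, upper

weights = [100, 70, 60, 9, 9, 9] + [5] * 7 + [2] * 3   # the 16 items from the figures
b = quantile_endpoints(weights)
lo, hi = slice_bounds(b)
print(b)                      # [100, 70, 9, 5, 2]
print(lo, sum(weights), hi)   # 224 298 416
```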
Estimating the endpoints (quantiles) b_i
• Hash the 2^n items into 2^i buckets, then look at a single bucket.
• Find the heaviest weight w_i in that bucket.
• Example: for i = 2, hash 16 items into 2^2 = 4 buckets
  [Figure: 16 items with weights 100, 70, 60, three 9s, seven 5s and three 2s spread over 4 buckets; the inspected bucket gives w_i = 9]
• INTUITION. Repeat several times. If w_i is often found to be larger than w*, then with high probability there are at least 2^i items with weight larger than w*.
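Here is a small simulation of this intuition. It makes the simplifying assumption of a fully random hash, in which each item independently lands in one of 2^i buckets. WISH itself uses the parity hash A x + b (mod 2). The simulation measures how often the heaviest weight in the inspected bucket falls between b_{i+2} and b_{i-2}. The data (1024 distinct weights) and the names are hypothetical.

```python
import random

def bucket_max(weights, i, rng):
    """Send each item to one of 2**i buckets uniformly at random (fully random
    hash, for intuition only); return the heaviest weight in bucket 0, or None."""
    bucket = [w for w in weights if rng.randrange(2 ** i) == 0]
    return max(bucket, default=None)

def hit_rate(weights, i, trials, seed=0):
    """Fraction of trials where the bucket max lies in [b_{i+2}, b_{i-2}]."""
    assert i >= 2 and 2 ** (i + 2) <= len(weights)
    s = sorted(weights, reverse=True)
    lo, hi = s[2 ** (i + 2) - 1], s[2 ** (i - 2) - 1]   # b_{i+2}, b_{i-2}
    rng = random.Random(seed)
    hits = 0
    for _ in range(trials):
        m = bucket_max(weights, i, rng)
        hits += m is not None and lo <= m <= hi
    return hits / trials

weights = list(range(1, 2 ** 10 + 1))   # 1024 distinct weights
print(hit_rate(weights, i=4, trials=500))
```

For i = 4, the expected hit rate is (15/16)^3 · (1 − (15/16)^61) ≈ 0.81. That value is computed directly from the fully random model and is comfortably above 0.5.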
Hashing and Optimization
• Hash into 2^i buckets, then look at a single bucket
• With probability > 0.5:
  • There is nothing from the small set (vanishes)
  • There is something from the larger set (survives)
[Figure: items sorted by increasing weight, with endpoints b_0, …, b_{i-2}, b_i, b_{i+2}, … and geometrically increasing bin sizes; the small set is the 2^{i-2} = 2^i/4 heaviest items, the larger set is the 2^{i+2} = 4·2^i heaviest items, 16 times larger]
• Something from the larger set is likely to be in the bucket, so if we take a max, it will be in the range [b_{i+2}, b_{i-2}]
• Remember items are sorted, so max picks the “rightmost” item…
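The “probability > 0.5” claim follows from two standard inequalities. This sketch makes three assumptions: 2 ≤ i ≤ n − 2; each item lands in the inspected bucket with probability 2^{-i}; and bucket memberships of distinct items are pairwise independent. S_small and S_large are our own names for the numbers of small-set and larger-set items in the bucket.

```latex
% Markov: few of the 2^{i-2} heaviest items reach the bucket
\Pr[S_{\mathrm{small}} \ge 1] \le \mathbb{E}[S_{\mathrm{small}}] = 2^{i-2}\cdot 2^{-i} = \tfrac{1}{4}

% Pairwise independence: variances of the indicators add
\mathbb{E}[S_{\mathrm{large}}] = 2^{i+2}\cdot 2^{-i} = 4, \qquad
\operatorname{Var}[S_{\mathrm{large}}] = 2^{i+2}\, 2^{-i}\,(1 - 2^{-i}) < 4

% Chebyshev: the larger set rarely vanishes
\Pr[S_{\mathrm{large}} = 0] \le \Pr\bigl[\,|S_{\mathrm{large}} - 4| \ge 4\,\bigr]
  \le \frac{\operatorname{Var}[S_{\mathrm{large}}]}{16} < \tfrac{1}{4}

% Union bound
\Pr[S_{\mathrm{small}} = 0 \ \text{and}\ S_{\mathrm{large}} \ge 1] > 1 - \tfrac{1}{4} - \tfrac{1}{4} = \tfrac{1}{2}
```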
Universal Hashing
• Represent each item as an n-bit vector x
• Randomly generate A in {0,1}^{i×n}, b in {0,1}^i
• Then A x + b (mod 2) is:
  • Uniform
  • Pairwise independent
• Bucket content is implicitly defined by the solutions of A x = b (mod 2) (parity constraints)
• Max w(x) subject to A x = b (mod 2) is in the range [b_{i+2}, b_{i-2}] “frequently”
• Repeat several times: the median is in the desired range with high probability
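This sketch makes the two properties concrete. It implements h(x) = A x + b (mod 2) and checks pairwise independence exhaustively for tiny sizes. For every pair of distinct items x ≠ y, the pair (h(x), h(y)) should be uniform over all 4^i bucket pairs as (A, b) ranges over all 2^{i·n + i} choices. The bucket WISH inspects is the set of x with h(x) = 0, which over GF(2) is the same as A x = b (mod 2). The helper names are hypothetical.

```python
import itertools
from collections import Counter

def parity_hash(A, b, x):
    """h(x) = A x + b (mod 2); A is a list of i rows of n bits, b has i bits."""
    return tuple((sum(a_j * x_j for a_j, x_j in zip(row, x)) + b_k) % 2
                 for row, b_k in zip(A, b))

def all_hashes(i, n):
    """Enumerate every (A, b) with A in {0,1}^{i x n} and b in {0,1}^i."""
    for a_bits in itertools.product((0, 1), repeat=i * n):
        A = [a_bits[r * n:(r + 1) * n] for r in range(i)]
        for b in itertools.product((0, 1), repeat=i):
            yield A, b

def is_pairwise_independent(i, n):
    """Exhaustive check: for all x != y, (h(x), h(y)) is uniform over 4**i pairs."""
    items = list(itertools.product((0, 1), repeat=n))
    hashes = list(all_hashes(i, n))
    for x, y in itertools.combinations(items, 2):
        counts = Counter((parity_hash(A, b, x), parity_hash(A, b, y))
                         for A, b in hashes)
        if len(counts) != 4 ** i or len(set(counts.values())) != 1:
            return False
    return True

print(is_pairwise_independent(i=2, n=3))  # True
```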
WISH: Integration by Hashing and Optimization
WISH (Weighted Integrals and Sums By Hashing)
• T = log(n/δ)
• For i = 0, …, n    (outer loop over the n+1 endpoints b_i of the n slices)
  • For t = 1, …, T    (repeat T times, i.e., O(log n) times for fixed δ)
    • Sample uniformly A in {0,1}^{i×n}, b in {0,1}^i    (hash into 2^i buckets)
    • w_i^t = max w(x) subject to A x = b (mod 2)    (find the heaviest item)
  • M_i = Median(w_i^1, …, w_i^T)    (M_i estimates the 2^i-largest weight b_i)
• Return M_0 + Σ_{i=0}^{n-1} M_{i+1} 2^i    (sum up the estimated area of each vertical slice)
The algorithm requires only O(n log n) optimizations for a sum over 2^n items.
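Here is a compact sketch of the loop above, for intuition only. The inner optimization, max w(x) subject to A x = b (mod 2), is done by brute-force enumeration. That is a stand-in for the MAP / MIP / Max-SAT solvers the method is meant to call, so the sketch only runs for small n. Two further choices are assumptions: T is rounded up from the natural logarithm, and an empty bucket (an inconsistent parity system) contributes weight 0. The names `satisfies` and `wish` are hypothetical.

```python
import itertools
import math
import random
import statistics

def satisfies(A, b, x):
    """True iff A x = b (mod 2), i.e. x lies in the inspected bucket."""
    return all(sum(a * xj for a, xj in zip(row, x)) % 2 == bk
               for row, bk in zip(A, b))

def wish(weight, n, delta=0.1, seed=0):
    """Sketch of WISH over {0,1}^n with a brute-force stand-in optimizer."""
    rng = random.Random(seed)
    T = max(1, math.ceil(math.log(n / delta)))    # assumed: natural log, rounded up
    items = list(itertools.product((0, 1), repeat=n))
    M = []
    for i in range(n + 1):                        # endpoints b_0, ..., b_n
        estimates = []
        for _ in range(T):
            # Sample A in {0,1}^{i x n} and b in {0,1}^i uniformly: 2**i buckets.
            A = [[rng.getrandbits(1) for _ in range(n)] for _ in range(i)]
            b = [rng.getrandbits(1) for _ in range(i)]
            # Brute-force stand-in for: max w(x) subject to A x = b (mod 2).
            # An empty bucket counts as weight 0 (assumption of this sketch).
            estimates.append(max((weight(x) for x in items if satisfies(A, b, x)),
                                 default=0))
        M.append(statistics.median(estimates))    # M_i estimates b_i
    # M_0 + sum_{i=0}^{n-1} M_{i+1} 2^i: estimated area of each vertical slice
    return M[0] + sum(M[i + 1] * 2 ** i for i in range(n))

print(wish(lambda x: 2 ** sum(x), n=6))   # exact sum is 3 ** 6 = 729
```

Each call performs (n + 1)·T optimizations, matching the O(n log n) count. The constant in T behind the stated accuracy guarantee may differ from this sketch's plain log(n/δ).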
Visual working of the algorithm
• How it works:
  [Figure: the function to be integrated, then the same function under 1, 2, 3, … random parity constraints (n times), each level repeated log(n) times]
• Estimate = mode M_0 + median M_1 × 1 + median M_2 × 2 + median M_3 × 4 + …
Accuracy Guarantees
• Theorem 1: With probability at least 1 − δ (e.g., 99.9%), WISH computes a 16-approximation of a sum over 2^n items (discrete integral) by solving Θ(n log n) optimization instances.
  • Example: partition function by solving Θ(n log n) MAP queries
• Theorem 2: The approximation factor can be improved to (1 + ε) by adding extra variables.
  • Example: factor 2 approximation with 4n variables
• Byproduct: we also obtain an 8-approximation of the tail distribution (CDF) with high probability
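One standard way to get the improvement in Theorem 2 is sketched below. This is our reading, not a construction the slide states. Let Z denote the original sum (our notation). Take k independent copies of the n variables and multiply the weights; the new sum is then Z^k. Running WISH on the resulting kn-variable problem and taking a k-th root shrinks the factor 16 to 16^{1/k}. Setting k = 4 matches the slide's example of factor 2 with 4n variables.

```latex
% k independent copies x^1, ..., x^k of the n variables (kn variables in total)
w^{(k)}(x^1,\dots,x^k) = \prod_{j=1}^{k} w(x^j)
\quad\Longrightarrow\quad
Z^{(k)} = \sum_{x^1,\dots,x^k} \prod_{j=1}^{k} w(x^j) = \Bigl(\sum_{x} w(x)\Bigr)^{k} = Z^{k}

% A 16-approximation Z^k/16 <= \hat{Z}^{(k)} <= 16 Z^k gives
\frac{Z}{16^{1/k}} \le \bigl(\hat{Z}^{(k)}\bigr)^{1/k} \le 16^{1/k}\, Z,
\qquad k = 4 \ (4n \text{ variables}) \Rightarrow 16^{1/4} = 2,
\qquad k \ge \frac{\ln 16}{\ln(1+\varepsilon)} \Rightarrow 16^{1/k} \le 1 + \varepsilon
```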
Key features
• Strong accuracy guarantees
• Can plug in any combinatorial optimization tool
• Bounds on the optimization translate to bounds on the sum
  • Stop early and get a lower bound (anytime)
  • (LP, SDP) relaxations give upper bounds
  • Extra constraints can make the optimization harder or easier
• Massively parallel (independent optimizations)
• Remark: WISH is faster than enumeration only when the combinatorial optimization is efficient (faster than brute force).
Experimental results
• Approximate the partition function of undirected graphical models by solving MAP queries (find the most likely state)
  • Normalization constant to evaluate probability, rank models
• MAP inference on a graphical model augmented with random parity constraints
  [Figure: original graphical model plus parity check nodes enforcing A x = b (mod 2)]
• Toulbar2 (branch & bound) solver for MAP inference
  • Augmented with Gauss-Jordan filtering to efficiently handle the parity constraints (linear equations over a field)
• Run in parallel using > 600 cores
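This sketch shows the linear-algebra core behind handling parity constraints: Gauss-Jordan elimination over GF(2) on the augmented system [A | b]. The elimination detects inconsistency and lists the 2^{n − rank} solutions. It is a standalone illustration with hypothetical helper names, not the Gauss-Jordan filtering integrated into Toulbar2. For small n, enumerating only the solution set instead of all 2^n items is enough to evaluate one bucket's maximum.

```python
import itertools

def gauss_jordan_gf2(A, b):
    """Reduce A x = b (mod 2) to reduced row echelon form.
    Returns (pivot rows, pivot columns), or None if the system is inconsistent."""
    n = len(A[0]) if A else 0
    rows = [list(r) + [bk] for r, bk in zip(A, b)]     # augmented matrix [A | b]
    pivots, r = [], 0
    for c in range(n):
        p = next((k for k in range(r, len(rows)) if rows[k][c]), None)
        if p is None:
            continue                                    # no pivot: c is a free column
        rows[r], rows[p] = rows[p], rows[r]
        for k in range(len(rows)):
            if k != r and rows[k][c]:
                rows[k] = [u ^ v for u, v in zip(rows[k], rows[r])]   # XOR = add mod 2
        pivots.append(c)
        r += 1
    if any(row[-1] for row in rows[r:]):                # a row reading 0 = 1
        return None
    return rows[:r], pivots

def solutions_gf2(A, b, n):
    """Yield every x in {0,1}^n with A x = b (mod 2)."""
    reduced = gauss_jordan_gf2(A, b)
    if reduced is None:
        return
    rows, pivots = reduced
    free = [c for c in range(n) if c not in pivots]
    for values in itertools.product((0, 1), repeat=len(free)):
        x = [0] * n
        for c, v in zip(free, values):
            x[c] = v
        for row, c in zip(rows, pivots):                # pivot var = rhs + free terms
            x[c] = (row[-1] + sum(row[f] * x[f] for f in free)) % 2
        yield tuple(x)

def max_in_bucket(weight, A, b, n):
    """max w(x) subject to A x = b (mod 2); 0 if the bucket is empty (assumption)."""
    return max((weight(x) for x in solutions_gf2(A, b, n)), default=0)

print(sorted(solutions_gf2([[1, 1, 0], [0, 1, 1]], [1, 0], 3)))  # [(0, 1, 1), (1, 0, 0)]
```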
Sudoku
• How many ways are there to fill a valid sudoku square?
• Sum over 9^81 ~ 10^77 possible squares (items)
• w(x) = 1 if it is a valid square, w(x) = 0 otherwise
• Accurate solution within seconds:
  • 1.634×10^21 (WISH estimate) vs 6.671×10^21 (true count)
Random Cliques Ising Models
[Figure: partition function and MAP query vs. strength of the interactions]
• Other methods fall way out of the error band
• The very small error band is the 16-approximation range
Model ranking - MNIST
• Use the partition function estimate to rank models (data likelihood)
[Figure: visually, a better model for handwritten digits]
• WISH ranks them correctly. Mean-field and BP do not.
Conclusions
• Discrete integration reduced to a small number of optimization instances
• Strong (probabilistic) accuracy guarantees by universal hashing
• Can leverage fast combinatorial optimization packages
• Works well in practice
• Future work:
  • Extension to continuous integrals
  • Further approximations in the optimization [UAI-13]
  • Coding theory / Parity check codes / Max-Likelihood Decoding
  • LP relaxations
  • Sampling from high-dimensional probability distributions?