Learning Structured Prediction Models: A Large Margin Approach
Ben Taskar (U.C. Berkeley), with Vassil Chatalbashev, Carlos Guestrin, Daphne Koller, Michael Collins, Dan Klein, Chris Manning
"Don't worry, Howard. The big questions are multiple choice."

Structured prediction examples
- Handwriting recognition: x = pixel images, y = "brace" (sequential structure)
- Object segmentation: x = 3D scan, y = per-point labels (spatial structure)
- Natural language parsing: x = "The screen was a sea of red", y = parse tree (recursive structure)
- Disulfide connectivity prediction: x = RSCCPCYWGGCPWGQNCYPEGCSGPKV, y = bonding pattern (combinatorial structure)

Outline
- Structured prediction models
  - Geometric view
  - Sequences (CRFs), trees (CFGs), associative Markov networks (special MRFs), matchings
- Structured model polytopes
  - Linear programming inference
- Structured large margin estimation
  - Min-max formulation; application: 3D object segmentation
  - Certificate formulation; application: disulfide connectivity prediction

Structured models
- Scoring function over a space of feasible outputs
- Mild assumption: the score is a linear combination of features, w^T f(x, y)

Chain Markov Net (aka CRF*)
- P(y|x) ∝ ∏_i φ(x_i, y_i) ∏_i φ(y_i, y_{i+1})
- φ(x_i, y_i) = exp{ w · f(x_i, y_i) },  φ(y_i, y_{i+1}) = exp{ w · f(y_i, y_{i+1}) }
- Example features: f(y, y') = I(y = 'z', y' = 'a'),  f(x, y) = I(x_p = 1, y = 'z')
- *Lafferty et al. '01

Chain Markov Net (aka CRF*), cont.
- P(y|x) ∝ ∏_i φ(x_i, y_i) ∏_i φ(y_i, y_{i+1}) = exp{ w^T f(x, y) }
- w = [ ..., w_k, ... ],  f(x, y) = [ ..., Σ_i f_k(x_i, y_i), ..., Σ_i f_k(y_i, y_{i+1}), ... ]
- Local indicator features become global counts: f(x, y) = #(y = 'z', y' = 'a'),  f(x, y) = #(x_p = 1, y = 'z')

Associative Markov Nets
- Node potentials φ_i(y_i), edge potentials φ_ij(y_i, y_j)
- Point features: spin-images, point height
- Edge features: length of edge, edge orientation
- "Associative" restriction on the edge potentials

PCFG
- Features are rule and lexical counts: #(NP -> DT NN), ..., #(PP -> IN NP), ..., #(NN -> 'sea')

Disulfide bonds: non-bipartite matching
- RSCCPCYWGGCPWGQNCYPEGCSGPKV
- [figure: the six cysteines (1-6) as graph nodes, with the disulfide bonds forming a perfect matching]
- Fariselli & Casadio '01, Baldi et al. '04

Scoring function
- Each candidate bond (pair of cysteines) gets a score
- String features: residues, physical properties

Structured models, cont.
- Mild assumption: linear scoring function
- Another mild assumption: MAP inference can be written as a linear program
- LP inference for chains, trees, associative Markov nets, bipartite matchings, ...

Markov Net Inference LP
- [figure: the inference LP over node and edge marginal variables, with an example 0/1 solution vector]
- Has integral solutions y for chains and trees
- Gives an upper bound for general networks

Associative MN Inference LP
- "Associative" restriction
- For K = 2 labels, solutions are always integral (optimal)
- For K > 2, within a factor of 2 of optimal
- Constraint matrix A is linear in the number of nodes and edges, regardless of the tree-width

Other Inference LPs
- Context-free parsing: dynamic programs
- Bipartite matching: network flow
- Many other combinatorial problems

Outline (recap): next, structured large margin estimation (min-max formulation; certificate formulation)

Learning w
- Training example (x, y*)
- Probabilistic approach: maximize conditional likelihood
- Problem: computing the partition function Z_w(x) is #P-complete

Geometric example
- Training data (see figure)
- Goal: learn w such that w^T f(x, y*) points the "right" way

OCR example
- We want: argmax_word w^T f(x, word) = "brace"
- Equivalently:
  w^T f(x, "brace") > w^T f(x, "aaaaa")
  w^T f(x, "brace") > w^T f(x, "aaaab")
  ...
  w^T f(x, "brace") > w^T f(x, "zzzzz")
- That is exponentially many constraints!
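The argmax itself never requires enumerating the exponentially many words: because the chain score w^T f(x, y) decomposes into per-position and per-transition terms, it can be maximized by dynamic programming (Viterbi). A minimal sketch, where the node and transition score arrays stand in for the dot products w · f and are filled with made-up values:

```python
import numpy as np

def viterbi(node_scores, edge_scores):
    """Maximize sum_i node_scores[i, y_i] + sum_i edge_scores[y_i, y_{i+1}]
    over all label sequences y, by dynamic programming.

    node_scores: (n, K) array, e.g. w . f(x_i, y_i) for each position/label
    edge_scores: (K, K) array, e.g. w . f(y_i, y_{i+1}) for each label pair
    Returns (best_score, best_sequence).
    """
    n, K = node_scores.shape
    score = node_scores[0].copy()           # best score of a prefix ending in each label
    backptr = np.zeros((n, K), dtype=int)
    for i in range(1, n):
        # candidate[prev, cur] = best prefix ending in prev, extended by transition + node term
        candidate = score[:, None] + edge_scores + node_scores[i][None, :]
        backptr[i] = candidate.argmax(axis=0)
        score = candidate.max(axis=0)
    # trace back the argmax sequence
    y = [int(score.argmax())]
    for i in range(n - 1, 0, -1):
        y.append(int(backptr[i][y[-1]]))
    return float(score.max()), y[::-1]

# Toy usage: 5 positions, 26 letter labels, random scores standing in for w . f
rng = np.random.default_rng(0)
print(viterbi(rng.normal(size=(5, 26)), rng.normal(size=(26, 26))))
```

The learning problem, however, still has to cope with one margin constraint per alternative output, which is what the next slides address.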
Large margin estimation
- Given a training example (x, y*), we want: w^T f(x, y*) ≥ w^T f(x, y) + ℓ(y*, y) for all y
- Maximize the margin
- Mistake-weighted margin: ℓ(y*, y) = number of mistakes in y
- *Taskar et al. '03

Large margin estimation: two routes
- Brute force enumeration of all constraints, or
- Min-max formulation: 'plug in' the linear program used for inference

Min-max formulation
- Assume linear (Hamming) loss, so ℓ(y*, y) decomposes over positions
- The inner maximization max_y [ w^T f(x, y) + ℓ(y*, y) ] becomes the inference LP with loss-augmented coefficients

Min-max formulation, cont.
- By strong LP duality, replace the inner LP by its dual
- Minimize jointly over w and the dual variables z: a single compact QP

Min-max formulation: scope
- Produces a compact QP for:
  - Low-treewidth Markov networks
  - Associative Markov networks
  - Context-free grammars
  - Bipartite matchings
  - Any problem with compact LP inference

3D Mapping
- Data provided by Michael Montemerlo & Sebastian Thrun
- Sensors: laser range finder, GPS, IMU
- Labels: ground, building, tree, shrub
- Training: 30 thousand points; testing: 3 million points

Segmentation results
- Hand-labeled 180K test points
  Model    Accuracy
  SVM      68%
  V-SVM    73%
  M^3N     93%

Fly-through [video]

Certificate formulation
- Non-bipartite matchings: O(n^3) combinatorial algorithm, but no polynomial-size inference LP known
- Spanning trees: likewise, no polynomial-size LP known, yet a simple certificate of optimality exists
- Intuition: verifying optimality is easier than optimizing
- Use a compact optimality condition of y* with respect to the scoring function instead of an inference LP

Certificate for non-bipartite matching
- Alternating cycle: every other edge is in the matching
- Augmenting alternating cycle: the edges not in the matching score more, in total, than the edges in the matching
- Negate the scores of edges not in the matching; an augmenting alternating cycle is then exactly a negative-length alternating cycle
- Matching is optimal ⇔ no negative alternating cycles (Edmonds '65)

Certificate for non-bipartite matching, cont.
- Pick any node r as root; let d_j = length of the shortest alternating path from r to j
- Triangle inequality: d_j ≤ d_i + length(i, j)
- Theorem: no negative-length alternating cycle ⇔ such a distance function d exists
- This can be expressed as linear constraints: O(n) distance variables, O(n^2) constraints

Certificate formulation: scope
- Produces a compact QP for:
  - Spanning trees
  - Non-bipartite matchings
  - Any problem with a compact optimality condition
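The fact the certificate leans on, that a graph has no negative cycle exactly when there is a distance function satisfying pairwise linear inequalities, is the same one behind Bellman-Ford potentials. A minimal sketch on an ordinary directed graph (the talk specializes it to alternating paths in the matching graph); the toy edges and lengths are made up for illustration:

```python
def negative_cycle_certificate(n, edges):
    """Bellman-Ford style check: return a distance function d with
    d[v] <= d[u] + length(u, v) for every edge if one exists
    (certifying there is no negative cycle), else None.

    edges: list of (u, v, length) with nodes 0..n-1.
    """
    d = [0.0] * n                      # start from 0 everywhere (a "virtual source")
    for _ in range(n - 1):             # n-1 relaxation rounds suffice if no negative cycle
        changed = False
        for u, v, length in edges:
            if d[u] + length < d[v]:
                d[v] = d[u] + length
                changed = True
        if not changed:
            break
    # one more pass: if anything still improves, a negative cycle exists
    for u, v, length in edges:
        if d[u] + length < d[v]:
            return None
    return d

# Toy usage: a 4-node graph with no negative cycle, so a certificate d is returned
edges = [(0, 1, 2.0), (1, 2, -1.0), (2, 3, 1.0), (3, 0, 0.5), (0, 2, 3.0)]
print(negative_cycle_certificate(4, edges))
```

In the learning QP, the analogous O(n^2) linear inequalities on the distance variables are added as constraints, so feasibility of d directly certifies optimality of the target matching y*.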
Disulfide connectivity prediction
- Dataset: Swiss-Prot protein database, release 39; 446 sequences (4-50 cysteines)
- Features: window profiles (size 9) around each candidate pair
- Two modes: bonded state known / unknown
- Comparison (Fariselli & Casadio '01, Baldi et al. '04): SVM-trained weights (ignoring matching constraints during learning) and the DAG recursive neural network [Baldi et al. '04]
- Our model: max-margin matching using an RBF kernel
- Training: off-the-shelf LP/QP solver, CPLEX (~1 hour)

Known bonded state (precision / accuracy, 4-fold cross-validation)
  Bonds   SVM           DAG RNN [Baldi et al. '04]   Max-margin matching
  2       0.63 / 0.63   0.74 / 0.74                  0.77 / 0.77
  3       0.51 / 0.38   0.61 / 0.51                  0.62 / 0.52
  4       0.34 / 0.12   0.44 / 0.27                  0.51 / 0.36
  5       0.31 / 0.07   0.41 / 0.11                  0.43 / 0.16

Unknown bonded state (precision / recall / accuracy, 4-fold cross-validation)
  Bonds   DAG RNN [Baldi et al. '04]   Max-margin matching
  2       0.49 / 0.59 / 0.40           0.57 / 0.59 / 0.44
  3       0.45 / 0.50 / 0.32           0.48 / 0.52 / 0.28
  4       0.37 / 0.36 / 0.15           0.39 / 0.40 / 0.14
  5       0.31 / 0.28 / 0.03           0.31 / 0.33 / 0.07

Formulation summary
- Brute force enumeration
- Min-max formulation: 'plug in' a convex program for inference
- Certificate formulation: directly guarantee optimality of y*

Estimation
- Generative P(x, y) vs. discriminative P(y|x) vs. margin-based
- Local models, P(z) = ∏_i P(z_i | z_{pa(i)}): HMMs, PCFGs (generative); MEMMs (discriminative)
- Global models, P(z) = (1/Z) ∏_c ψ(z_c): MRFs (generative); CRFs (discriminative)
- Margin-based: the large-margin estimation in this talk

Omissions
- Formulation details: kernels, multiple examples, slacks for the non-separable case
- Approximate learning of intractable models: general MRFs, learning to cluster
- Structured generalization bounds
- Scalable algorithms (no QP solver needed):
  - Structured SMO (works for chains, trees)
  - Structured EG (works for chains, trees)
  - Structured PG (works for chains, matchings, AMNs, ...)

Current work
- Learning approximate energy functions: protein folding, physical processes
- Semi-supervised learning: hidden variables, mixing labeled and unlabeled data
- Discriminative structure learning: using sparsifying priors

Conclusion
- Two general techniques for structured large-margin estimation
- Exact, compact, convex formulations
- Allow efficient use of kernels
- Tractable when other estimation methods are not
- Structured generalization bounds
- Efficient learning algorithms
- Empirical success on many domains
- Papers at http://www.cs.berkeley.edu/~taskar

Extra slides

Duals and kernels
- The kernel trick works: scoring functions (log-potentials) can use kernels
- The same holds for the certificate formulation

Handwriting recognition
- Length: ~8 characters; letters: 16x8 pixels
- 10-fold train/test: 5000/50000 letters, 600/6000 words
- Models: multiclass SVMs (*Crammer & Singer '01), CRFs, M^3 nets, with raw pixels, quadratic kernel, and cubic kernel
- [bar chart: average per-character test error for each model/kernel]
- 45% error reduction over linear CRFs; 33% error reduction over multiclass SVMs

Hypertext classification
- WebKB dataset: four CS department websites, 1300 pages / 3500 links
- Classify each page: faculty, course, student, project, other
- Train on three universities, test on the fourth
- [bar chart: test error for SVMs, RMNs (*Taskar et al. '02), and M^3Ns (relaxed dual, loopy belief propagation)]
- 53% error reduction over SVMs; 38% error reduction over RMNs

Projected gradient
- [figure: sequence of iterates y^k projected onto the feasible set]
- Projecting y' onto the constraints: a min-cost convex flow problem for Markov nets and matchings
- Convergence: same as steepest gradient
- Conjugate gradient also possible (two-metric projection)

Min-cost flow for Markov chains
- [figure: source s and sink t connected to a chain of a-z label nodes at each position]
- Capacities = C; edges from node s and to node t have cost 0

Min-cost flow for bipartite matchings
- [figure: source s, sink t, and the bipartite edge set]
- Capacities = C; edges from node s and to node t have cost 0

CFG chart
- A CNF tree is a set of two types of parts: constituents (A, s, e) and CF-rules (A -> B C, s, m, e)

CFG inference LP
- [figure: inference LP over chart parts, with inside and outside constraints]
- Has integral solutions y for trees
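Since a CNF parse is just a set of constituent and CF-rule parts, the highest-scoring tree can be found by a CKY-style dynamic program over the chart, which is the combinatorial problem the inference LP above encodes. A minimal sketch, with made-up rule and lexical scores standing in for the weighted feature counts w^T f:

```python
from collections import defaultdict

def cky_max_score(words, unary_scores, binary_scores, root="S"):
    """Score of the best CNF parse of `words`.

    unary_scores:  dict (A, word)  -> score of the lexical part A -> word
    binary_scores: dict (A, B, C)  -> score of the rule part A -> B C
    A tree's score is the sum of its parts' scores, matching the chart
    decomposition into constituent and CF-rule parts.
    """
    n = len(words)
    chart = defaultdict(lambda: float("-inf"))   # (start, end, symbol) -> best score
    for i, w in enumerate(words):                # length-1 spans from lexical parts
        for (A, word), s in unary_scores.items():
            if word == w:
                chart[i, i + 1, A] = max(chart[i, i + 1, A], s)
    for span in range(2, n + 1):                 # longer spans from binary rules
        for start in range(0, n - span + 1):
            end = start + span
            for mid in range(start + 1, end):
                for (A, B, C), s in binary_scores.items():
                    left, right = chart[start, mid, B], chart[mid, end, C]
                    if left > float("-inf") and right > float("-inf"):
                        total = s + left + right
                        if total > chart[start, end, A]:
                            chart[start, end, A] = total
    return chart[0, n, root]

# Toy usage with made-up scores standing in for w . f counts
unary = {("DT", "the"): 1.0, ("NN", "screen"): 1.0, ("VB", "glows"): 1.0}
binary = {("NP", "DT", "NN"): 2.0, ("S", "NP", "VB"): 1.5}
print(cky_max_score(["the", "screen", "glows"], unary, binary))   # 6.5
```

The same chart quantities are what become the LP variables, which is why the parsing LP has integral optima for trees.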