Introduction to Matlab & Data Analysis
Lecture 9: Unsupervised Learning & Clustering
Eran Eden, Weizmann 2008 ©
Some of the slides in this lecture are based on the Matlab Statistics Toolbox user's guide.

Unsupervised & supervised learning

Two ways to learn: from rules, or from examples.
Rule-based inference: "An enemy is an animal with sharp teeth, larger than 20 cm, that barks... A friend is smaller than 20 cm, or larger than 20 cm with feathers..."
Training from examples: instead of stating the rules, we show the learner labeled examples of enemies and friends and let it infer the rules.

What is supervised learning?
Each example consists of an input vector (feature vector) and a desired output (label):

    1.0   2.0   1.1   2.0    1   (Enemy)
    1.5   1.0   4.0   0.0    1   (Enemy)
    1.5   1.0   4.0   0.0    1   (Enemy)
    0.1   1.0   1.0   3.0    0   (Friend)
    0.3   1.0   1.0   2.0    0   (Friend)
    0.4   1.0   1.0   2.0    0   (Friend)

The goal is to find a function that maps between the inputs and the desired output.

What is unsupervised learning?
The same examples, but without labels:

    1.0   2.0   1.1   2.0
    1.5   1.0   4.0   0.0
    1.5   1.0   4.0   0.0    (furry animals)
    0.1   1.0   1.0   3.0
    0.3   1.0   1.0   2.0
    0.4   1.0   1.0   2.0    (feathery animals)

The goal is to cluster the data, i.e. to form "natural groupings".

Supervised versus unsupervised learning
Supervised learning: the examples are labeled (Enemy / Friend).
Unsupervised learning: the examples are unlabeled; the algorithm must discover the groups (furry animals versus feathery animals) on its own.

Why do we need clustering?
Examples:
Market segmentation - often we do not know the labeling in advance and wish to learn the structure from the data, as in the stock market.
Gene clustering - "guilt by association": genes with similar expression profiles across conditions are grouped together.
[Figure: expression matrix of genes (rows) by conditions (columns), with three gene clusters marked Cluster I, Cluster II, and Cluster III.]

Clustering techniques
Clustering methods divide into two families: partitional and hierarchical.

K-means algorithm
K-means is a partitional clustering method. The k-means algorithm partitions the data into k mutually exclusive clusters.
[Figures: the same 2-D point cloud (Feature 1 vs. Feature 2) partitioned into K = 2 clusters and into K = 3 clusters.]
How does it work?

The goal of the k-means algorithm is to minimize the variance within every cluster.
Formal definition: minimize the total intra-cluster variance

    $\sum_{i=1}^{K} \sum_{x_j \in S_i} d(x_j, \mu_i)$

where $S_i$ is the i-th cluster (i = 1, 2, ..., K), $\mu_i$ is the centroid of all the points in cluster $S_i$, and d is a distance function.
[Figure: two clusterings of the same data set, one an optimal solution and one a suboptimal solution.]

If we knew the cluster assignment of each point, we could easily compute the centroid positions.
If we knew the centroid positions, we could easily assign each point to a cluster.
But we know neither of them.

K-means algorithm description:
1) Choose the number of clusters, K.
2) Randomly choose initial positions for the K centroids.
3) Assign each point to the "nearest centroid" (depends on the distance measure).
4) Calculate the intra-cluster variance.
5) Re-compute the centroid positions.
6) If the solution has converged (the intra-cluster variance did not change), stop; otherwise repeat steps 3-5.
[Figures: successive iterations of the assign / re-compute steps on a 2-D data set with K = 3.]
A minimal sketch of these steps in Matlab appears below.
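The following is only a minimal sketch of these steps, not the Statistics Toolbox implementation; the data matrix X, the convergence tolerance, and the random initialization are illustrative assumptions.

% Minimal k-means sketch (illustrative only, not the toolbox's kmeans).
X = randn(100, 2);                        % assumed example data: 100 points in 2-D
K = 3;
perm = randperm(size(X, 1));
centroids = X(perm(1:K), :);              % random initial centroid positions
prevVar = Inf;
while true
    % Assign each point to the nearest centroid (squared Euclidean distance)
    D = zeros(size(X, 1), K);
    for i = 1:K
        D(:, i) = sum((X - repmat(centroids(i, :), size(X, 1), 1)).^2, 2);
    end
    [minD, idx] = min(D, [], 2);
    totalVar = sum(minD);                 % total intra-cluster variance
    if abs(prevVar - totalVar) < 1e-10    % converged: the variance did not change
        break
    end
    prevVar = totalVar;
    % Re-compute each centroid as the mean of the points assigned to it
    for i = 1:K
        if any(idx == i)                  % (empty clusters are not handled in this sketch)
            centroids(i, :) = mean(X(idx == i, :), 1);
        end
    end
end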
K-means: other things we need to consider
How should we choose K?
What types of distance measures can we use, and how do we choose between them? For example: Euclidean distance, a cosine-based distance 1 - cos(theta) between sample vectors, and more.
Does the algorithm guarantee convergence to an optimal solution? It is only guaranteed to converge to a local optimum. Can you think of strategies for dealing with this?

Before we learn how to do K-means in Matlab, let's look at some real data.
In the 1920's, botanists collected measurements of the
1) sepal length
2) sepal width
3) petal length
4) petal width
of 150 iris flowers, 50 from each of three species (setosa, versicolor, virginica).
The measurements became known as Fisher's iris data.

Fisher's iris data
>> load fisheriris
>> size(meas)
ans =
   150     4
>> meas
[Excerpt of meas: each row (Sample #1, Sample #2, ...) holds the four feature values (Feature #1 through Feature #4) of one flower.]
>> size(species)
ans =
   150     1
>> species
species is a 150-by-1 cell array of strings holding the species name of each sample: 'setosa', 'versicolor', or 'virginica', 50 of each.

Exploring correlations in Fisher's iris data
param_names = {'sepal length', 'sepal width', 'petal length', 'petal width'};
gplotmatrix(meas);
text([.05 .30 .55 .80], [-0.1, -0.1, -0.1, -0.1], param_names, 'FontSize', 12);
text([-0.12, -0.12, -0.12, -0.12], [0.80 0.55 0.30 0.05], param_names, 'FontSize', 12, 'Rotation', 90);
The scatter-plot matrix shows that the petal length and width are highly correlated.
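To quantify what the scatter-plot matrix suggests, the pairwise correlations can be computed directly; a short sketch using the built-in corrcoef (the indices 3 and 4 refer to the petal length and petal width columns):

load fisheriris
R = corrcoef(meas);   % 4-by-4 matrix of pairwise correlations between the features
R(3, 4)               % correlation of petal length with petal width (close to 1)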
Visualizing Fisher's iris data
% 3-D visualization
plot3(meas(:, 1), meas(:, 2), meas(:, 3), 'o');
xlabel('Sepal Length'); ylabel('Sepal Width'); zlabel('Petal Length');
view(-137, 10); grid on
[Figure: 3-D scatter plot of sepal length, sepal width and petal length.]

K-means using Matlab
Doing K-means in Matlab is simple:
[cidx2, cmeans2] = kmeans(meas, 2);
cidx2 holds the cluster index each sample belongs to, cmeans2 holds the cluster centroids, meas is the data, and 2 is K. By default, kmeans uses squared Euclidean distance.

Displaying the algorithm iterations:
[cidx2, cmeans2] = kmeans(meas, 2, 'display', 'iter');
  iter  phase      sum
     1      1   615.104
     2      1   421.975
     3      1   270.577
     4      1   195.266
     5      1   157.738
     6      1   152.348
     7      2   152.348
7 iterations, total sum of distances = 152.348

Clustering visualization
ptsymb = {'bs', 'r^', 'md', 'go', 'c+'};
% Plot cluster points
for i = 1:2
    clust = (cidx2 == i);
    plot3(meas(clust,1), meas(clust,2), meas(clust,3), ptsymb{i});
    hold on
end
% Plot cluster centroids
plot3(cmeans2(:,1), cmeans2(:,2), cmeans2(:,3), 'ko');
hold off
xlabel('Sepal Length'); ylabel('Sepal Width'); zlabel('Petal Length');
view(-137, 10); grid on
title('Iris data clustered with K-means where K = 2')
Notice that the clustering is done in 4 dimensions, but the visualization uses only 3 of them.
[Figure: the two clusters (Cluster 1 and Cluster 2) in the 3-D feature space, with the centroids marked.]
Because the upper cluster is spread out, three of its points are closer to the centroid of the lower cluster than to that of the upper cluster, and are therefore assigned to the lower cluster.

Increasing the number of clusters
>> [cidx3, cmeans3] = kmeans(meas, 3, 'display', 'iter');
  iter  phase   num      sum
     1      1   150    147.9
     2      1     4  145.177
     3      1     4  143.924
     4      1     3   143.61
     5      1     1  143.542
     6      1     2  143.414
     7      1     2  143.023
     8      1     2  142.823
     9      1     1  142.786
    10      1     1  142.754
    11      2     1  142.754
11 iterations, total sum of distances = 142.754

Clustering visualization
for i = 1:3
    clust = (cidx3 == i);
    plot3(meas(clust,1), meas(clust,2), meas(clust,3), ptsymb{i});
    hold on
end
plot3(cmeans3(:,1), cmeans3(:,2), cmeans3(:,3), 'ko');
hold off
xlabel('Sepal Length'); ylabel('Sepal Width'); zlabel('Petal Length');
view(-137, 10); grid on
[Figure: the three clusters in the 3-D feature space.]
Poor clustering: this solution is suboptimal!

Avoiding local minima using a replicates strategy
>> [cidx3, cmeans3, sumd3] = kmeans(meas, 3, 'replicates', 5, 'display', 'final')
 8 iterations, total sum of distances = 142.754
11 iterations, total sum of distances = 78.8514
 4 iterations, total sum of distances = 78.8514
 8 iterations, total sum of distances = 78.8514
 7 iterations, total sum of distances = 142.754
There are two different final solutions. kmeans returns the best one.
[Figure: the three clusters found by the best of the five replicates.]

Using a different distance measure
[cidx_cos, cmeans_cos] = kmeans(meas, 3, 'dist', 'cos');
We can use a cosine-based distance (one minus the cosine of the angle between sample vectors) as the distance measure between samples.
[Figure: the three clusters found with the cosine-based distance.]

Which distance measure is more "suitable" for clustering the iris data?
We know the label of each sample, so we can compare the clusters discovered by kmeans to the actual flower types.
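One compact way to make this comparison is to cross-tabulate the cluster indices against the species labels; a sketch using crosstab (the lecture's own approach, which counts misses directly, follows next):

% Cross-tabulate k-means cluster indices against the known species labels.
% cidx_cos is assumed to be the output of kmeans(meas, 3, 'dist', 'cos') above.
[tbl, chi2, p, labels] = crosstab(cidx_cos, species);
tbl   % rows: clusters, columns: species; counts outside the dominant cell of each row are misclustered samples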
Note: usually in unsupervised learning we do NOT know the labels of the samples.

% Testing the clustering accuracy
figure
for i = 1:3
    clust = find(cidx_cos == i);
    plot3(meas(clust,1), meas(clust,2), meas(clust,3), ptsymb{i});
    hold on
end
xlabel('Sepal Length'); ylabel('Sepal Width'); zlabel('Petal Length');
view(-137, 10); grid on
org_idx(strcmp(species, 'versicolor')) = 1;
org_idx(strcmp(species, 'setosa')) = 2;
org_idx(strcmp(species, 'virginica')) = 3;
miss = find(cidx_cos ~= org_idx');
plot3(meas(miss,1), meas(miss,2), meas(miss,3), 'k*');
legend({'setosa', 'versicolor', 'virginica'}, 1);
hold off

Cosine-based distance: 5 misses.
[Figure: the three clusters (setosa, versicolor, virginica) in the 3-D feature space, with the 5 misclustered samples marked by black stars.]
Euclidean-based distance: 14 misses.
[Figure: the same plot for the Euclidean clustering, with 14 misclustered samples marked.]

How to choose K
We need a quantitative method to assess the quality of a clustering.
The silhouette value of a point measures how similar the point is to the points in its own cluster, compared to the points in other clusters.
Formal definition:

    s(i) = (b(i) - a(i)) / max(a(i), b(i))

where a(i) is the average distance of point i to the other points in its own cluster A; d(i, C) is the average distance of point i to the points in cluster C; and b(i) is the minimal d(i, C) over all clusters C other than A.

Silhouette values range from -1 to +1:
close to +1: the object is well classified;
close to 0: the object lies on the border between two clusters;
close to -1: the object is classified wrongly.
The silhouette coefficient is the average silhouette value over all points. It is a quantitative measure that can assess the quality of a clustering.

How to choose K
To demonstrate the utility of the silhouette coefficient, we can test it on synthetic data for which we know the number of clusters:
x1 = randn(1, 100);
y1 = randn(1, 100);
scatter(x1, y1, 25, [1 0 0], 'filled');
hold on
x2 = randn(1, 100) + 3;
y2 = randn(1, 100) + 3;
scatter(x2, y2, 25, [0 1 0], 'filled');
x3 = randn(1, 100) + 8;
y3 = randn(1, 100);
scatter(x3, y3, 25, [0 0 1], 'filled');
hold off
We know that K is 3.
[Figure: scatter plot of the three synthetic clouds of points.]

We run the k-means algorithm for different values of K:
>> x = [x1, x2, x3]; y = [y1, y2, y3]; data = [x', y'];
>> [cidx2, cmeans2] = kmeans(data, 2, 'replicates', 10);
>> [cidx3, cmeans3] = kmeans(data, 3, 'replicates', 10);
>> [cidx4, cmeans4] = kmeans(data, 4, 'replicates', 10);
[Figure: the resulting partitions for K = 2, K = 3, and K = 4.]

Computing the silhouette plots:
>> [silh2, h] = silhouette(data, cidx2);
>> mean(silh2)
ans =
    0.7856
[Figure: silhouette plot for K = 2; the points with low silhouette values are poorly clustered.]

>> [silh3, h] = silhouette(data, cidx3);
>> mean(silh3)
ans =
    0.8109
[Figure: silhouette plot for K = 3.]

>> [silh4, h] = silhouette(data, cidx4);
>> mean(silh4)
ans =
    0.6935
[Figure: silhouette plot for K = 4.]

The optimal silhouette value is achieved when K = 3!
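The same computation can be repeated over a range of K values to produce the curve shown next; a sketch, reusing the synthetic data matrix built above:

% Mean silhouette value as a function of K (sketch).
Ks = 2:8;
meanSilh = zeros(size(Ks));
for j = 1:numel(Ks)
    cidx = kmeans(data, Ks(j), 'replicates', 10);   % cluster with K = Ks(j)
    silh = silhouette(data, cidx);                   % silhouette value of every point
    meanSilh(j) = mean(silh);                        % the silhouette coefficient
end
plot(Ks, meanSilh, '-o');
xlabel('K'); ylabel('Mean silhouette value');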
[Figure: mean silhouette value as a function of K, for K = 2 through 8; the maximum is reached at K = 3.]

Hierarchical clustering
Clustering methods divide into partitional methods (such as K-means) and hierarchical methods, which can be agglomerative (bottom-up) or divisive (top-down). Hierarchical clustering lets you investigate the grouping in your data simultaneously over a variety of scales.

Algorithm description:
1) Determine the distance between each pair of points.
For m points there are m * (m - 1) / 2 different pairs. Different types of distances can be used (Euclidean, correlation, etc.). For example, for five points:

         1     2     3     4     5
    1   0.0   2.9   1.0   3.0   3.0
    2   2.9   0.0   2.5   3.4   2.5
    3   1.0   2.5   0.0   2.1   2.1
    4   3.0   3.4   2.1   0.0   1.0
    5   3.0   2.5   2.1   1.0   0.0

2) Iteratively group the points into a binary hierarchical tree (linkage).
Connect the closest pair of points (or groups) and re-compute the distance matrix.
[Figure: dendrogram of the five points; the height of each link is the distance at which the pair of points was connected.]

3) Cut the hierarchical tree into clusters.
[Figure: the dendrogram cut at a chosen height, producing the clusters.]

Hierarchical clustering, other things we need to consider
Types of linkage:
Single linkage clustering ("nearest neighbor"): the distance between two groups is defined as the distance between the closest pair of points, one from each group.
Complete linkage clustering ("farthest neighbor"): the distance between two groups is defined as the distance between the most distant pair of points from the two groups.
Average linkage clustering: the distance between two clusters is defined as the average of the distances between all pairs of points from opposite clusters.

Where to cut the tree:
Cutting at an arbitrary height.
Cutting at inconsistent links: compare the height of each link in the tree with the heights of the links below it. If they are approximately equal, the link exhibits a high level of consistency; there is no distinct division between the objects joined at this level of the hierarchy. If the heights differ, the link is said to be inconsistent with respect to the links below it, which indicates the border of a natural division in the data set. (For the formal definitions see the toolbox help.)

Hierarchical clustering using Matlab
Load the iris data:
>> load fisheriris
1) Compute the distances between each pair of samples:
>> euc_dist = pdist(meas, 'euclidean');
>> size(euc_dist)
ans =
     1   11175
>> euc_dist
    0.7810    0.3606    0.6708    0.9487    0.5831    1.0677 . . .

2) Iteratively group the points into a binary hierarchical tree:
>> clust_tree_euc = linkage(euc_dist, 'average');
Other linkage types include 'complete', 'single', 'median', and more.
>> size(clust_tree_euc)
ans =
   149     3
>> clust_tree_euc
    75    54    0.2000
   173    70    0.2000
   174    79    0.2110
    98    90    0.2189
   182   160    0.2196
   179   171    0.2225
   . . .
The first two columns hold the indices of the pair of clusters joined at each step (indices above 150 refer to clusters formed in earlier rows), and the third column holds the distance between the joined pair.
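A side note on the pdist output above: pdist stores the m*(m-1)/2 pairwise distances as a single row vector. If you prefer to view them as a symmetric distance matrix, like the 5-by-5 example earlier, squareform converts between the two forms (a short sketch):

% Convert the condensed pdist vector into a full symmetric distance matrix.
D = squareform(euc_dist);   % 150-by-150 matrix; D(i,j) is the distance between samples i and j
size(D)                     % ans = 150 150
D(1:5, 1:5)                 % distances among the first five samples (zeros on the diagonal)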
Hierarchical clustering using Matlab
Visualize the hierarchy tree (dendrogram):
[h, nodes] = dendrogram(clust_tree_euc, 0);
set(gca, 'TickDir', 'out', 'TickLength', [.002 0], 'XTickLabel', []);
[Figure: dendrogram of all 150 iris samples, built with average linkage on Euclidean distances.]

3) Cutting the hierarchical tree into clusters
For example, use an arbitrary cutoff that generates 3 clusters:
hidx = cluster(clust_tree_euc, 'MaxClust', 3);
[Figure: the same dendrogram, cut so that three clusters remain.]

Visualizing the actual data and clusters:
for i = 1:max(hidx)
    clust = find(hidx == i);
    plot3(meas(clust,1), meas(clust,2), meas(clust,3), ptsymb{i});
    hold on
end
hold off
xlabel('Sepal Length'); ylabel('Sepal Width'); zlabel('Petal Length');
view(-137, 10); grid on
[Figure: the three hierarchical clusters in the 3-D feature space.]

Summary of the steps we performed in order to generate the hierarchical clustering:
euc_dist = pdist(meas, 'euclidean');
clust_tree_euc = linkage(euc_dist, 'average');
hidx = cluster(clust_tree_euc, 'MaxClust', 3);
The following single command gives exactly the same result:
hidx = clusterdata(meas, 'maxclust', 3, 'linkage', 'average')

Using 'single' linkage instead will generate a different dendrogram, in which two distinct clusters appear.
[Figure: dendrogram built with single linkage; two distinct clusters are visible.]
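As discussed earlier, the tree can also be cut at inconsistent links rather than by forcing a fixed number of clusters. A sketch of how this can be done with the toolbox functions; the cutoff value 1.2 is only an illustrative choice:

% Cut the tree by link inconsistency instead of a fixed number of clusters.
I = inconsistent(clust_tree_euc);                    % inconsistency coefficient of every link
hidx_inc = cluster(clust_tree_euc, 'cutoff', 1.2);   % 'cutoff' uses the inconsistency criterion by default
numel(unique(hidx_inc))                              % number of clusters produced by this cutoff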