Clustering
Clustering data is an interesting topic because it aims to group data in such a way that each cluster contains data points that look similar and/or share some common properties. An example use case is the problem of Market Segmentation, where a business would like to take different actions for different types of customers. One approach would be to make assumptions about the customers and create rule-based groups; a better approach is to use cluster analysis and derive data-driven groups.
Mixture models are probabilistic models that can be used to cluster data. The main idea/assumption is that our data are generated from a mixture of distributions rather than just one. Each of those distributions is associated with one of the clusters that represent our data. There are different mixture models out there, with the most popular being the Gaussian Mixture Model, where, as the name suggests, each of the components is a Gaussian distribution.
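To make this generative view concrete, here is a tiny sketch (with made-up mixing coefficients, means and standard deviations, all chosen just for illustration) that simulates data from a two-component one-dimensional Gaussian mixture by first drawing the latent component and then sampling from the corresponding Gaussian:

import numpy as np

rng = np.random.default_rng(0)
pi = [0.3, 0.7]                       # mixing coefficients (made up)
mu = [-2.0, 3.0]                      # component means (made up)
sigma = [1.0, 0.5]                    # component standard deviations (made up)
z = rng.choice(2, size=1000, p=pi)    # latent variable: which component generated each point
x = rng.normal(np.take(mu, z), np.take(sigma, z))   # sample each point from its chosen Gaussian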
Expectation Maximization
If we knew beforehand the grouping of our data, then it would be easy to define a mixture model: for each of the clusters we would fit the corresponding distribution and compute its parameters, e.g. for a Gaussian Mixture Model, the mean and covariance. On the other hand, if we knew the model parameters, then we would be able to group the data accordingly.
The problem here is that we have neither the grouping nor the model parameters. In order to overcome this, we use an approach called Expectation Maximization. During the expectation round we compute the expected grouping of the data, while in the maximization round we update the model parameters by finding the values that maximize the likelihood (i.e. the probability that the observations came from those parameters).
Gaussian Mixture Model
Now let's see a few more details. We have observations $x_1, \dots, x_N$ that we would like to cluster into $K$ groups. In the case of a Gaussian Mixture Model we have $K$ components, and each of them is described by its mean $\mu_k$ and covariance $\Sigma_k$. Our model also has component priors $\pi_k$ that sum to one. The grouping that we do not know is represented as latent variables $z_n$, e.g. $z_n = 2$ means that distribution 2 is responsible for generating $x_n$.
The density function of our model is

$$p(x) = \sum_{k=1}^{K} \pi_k \, \mathcal{N}(x \mid \mu_k, \Sigma_k)$$

Our goal now is to find the parameters of the model, for all components. One popular approach is Maximum Likelihood Estimation, i.e. we take the parameters that maximize the likelihood:

$$p(X \mid \pi, \mu, \Sigma) = \prod_{n=1}^{N} \sum_{k=1}^{K} \pi_k \, \mathcal{N}(x_n \mid \mu_k, \Sigma_k)$$
Setting the derivative of the log-likelihood with respect to the means to zero, we end up with the following formula for the means (and similar ones for the mixing coefficients and covariances):

$$\mu_k = \frac{1}{N_k} \sum_{n=1}^{N} \gamma(z_{nk}) \, x_n, \qquad N_k = \sum_{n=1}^{N} \gamma(z_{nk})$$

Here $\gamma(z_{nk})$ is the posterior of the latent variables, i.e. the probability that component $k$ generated the data point $x_n$. As we said before, if we knew these probabilities it would be easy to compute the parameters, and if we knew the parameters it would be easy to compute these probabilities.
Therefore we employ an iterative approach where each iteration consists of two rounds.
Expectation – Compute the latent posteriors:

$$\gamma(z_{nk}) = \frac{\pi_k \, \mathcal{N}(x_n \mid \mu_k, \Sigma_k)}{\sum_{j=1}^{K} \pi_j \, \mathcal{N}(x_n \mid \mu_j, \Sigma_j)}$$
Maximization – Use the maximizers of the log-likelihood to update the model parameters:

$$\mu_k^{\text{new}} = \frac{1}{N_k} \sum_{n=1}^{N} \gamma(z_{nk}) \, x_n, \qquad \Sigma_k^{\text{new}} = \frac{1}{N_k} \sum_{n=1}^{N} \gamma(z_{nk}) \, (x_n - \mu_k^{\text{new}})(x_n - \mu_k^{\text{new}})^T, \qquad \pi_k^{\text{new}} = \frac{N_k}{N}$$
As a recap, the training of such a model goes as follows: we randomly initialize the parameters, we alternate between the expectation and maximization rounds, and we stop when the log-likelihood no longer increases (or the increase is smaller than a threshold).
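To make the recipe concrete, below is a minimal NumPy sketch of this EM loop (not the code used later in the post; the initialization scheme, the small regularization term on the covariances and the convergence threshold are my own choices):

import numpy as np
from scipy.stats import multivariate_normal

def em_gmm(X, K, n_iter=100, tol=1e-6):
    N, D = X.shape
    rng = np.random.default_rng(0)
    # random initialization: K data points as means, the sample covariance for every
    # component, and uniform mixing coefficients
    mu = X[rng.choice(N, size=K, replace=False)]
    sigma = np.array([np.cov(X.T) + 1e-6 * np.eye(D) for _ in range(K)])
    pi = np.full(K, 1.0 / K)
    prev_ll = -np.inf
    for _ in range(n_iter):
        # Expectation: posterior responsibilities gamma(z_nk)
        dens = np.stack([pi[k] * multivariate_normal.pdf(X, mu[k], sigma[k])
                         for k in range(K)], axis=1)        # shape (N, K)
        gamma = dens / dens.sum(axis=1, keepdims=True)
        # Maximization: plug the responsibilities into the update formulas above
        Nk = gamma.sum(axis=0)                              # effective counts per component
        mu = (gamma.T @ X) / Nk[:, None]
        for k in range(K):
            diff = X - mu[k]
            sigma[k] = (gamma[:, k, None] * diff).T @ diff / Nk[k] + 1e-6 * np.eye(D)
        pi = Nk / N
        # stop when the log-likelihood no longer increases (beyond a small threshold)
        ll = np.log(dens.sum(axis=1)).sum()
        if ll - prev_ll < tol:
            break
        prev_ll = ll
    return mu, sigma, pi, gamma

For high-dimensional data such as the raw MNIST vectors used later, a practical implementation would compute the responsibilities in log-space to avoid underflow; the scikit-learn code below does that internally.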
In order to derive update rules for mixture models of other distributions, we need to compute the posterior of the latent variables for the expectation step, and to maximize the corresponding log-likelihood with respect to the parameters of the distribution.
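As a small worked example of this recipe (notation as above; only the component densities change), a mixture of Bernoulli distributions over $D$ binary features uses components $p(x \mid \mu_k) = \prod_{i=1}^{D} \mu_{ki}^{x_i} (1 - \mu_{ki})^{1 - x_i}$, and the two rounds become:

$$\gamma(z_{nk}) = \frac{\pi_k \, p(x_n \mid \mu_k)}{\sum_{j=1}^{K} \pi_j \, p(x_n \mid \mu_j)}, \qquad \mu_k^{\text{new}} = \frac{1}{N_k} \sum_{n=1}^{N} \gamma(z_{nk}) \, x_n, \qquad \pi_k^{\text{new}} = \frac{N_k}{N}$$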
Clustering Digits
In order to demonstrate the aforementioned model, we will use the MNIST dataset to cluster digits. The following Python code fits a 4-component GMM to the digits 0, 1, 2 and 3. Plotting the means of the 4 distributions, we clearly see that each cluster represents a different digit.
from __future__ import division
import numpy as np
from mnist import read, show
from sklearn import mixture

if __name__ == '__main__':
    img_gen = read("training")
    all_0123 = [img for img in img_gen if (img[0] <= 3)]
    # vectorize imgs
    all_vec = [img[1].reshape((784)) for img in all_0123]
    # fit gmm with 4 components
    g = mixture.GMM(n_components=4)
    g.fit(all_vec)
    # have a look at the means
    means = g.means_
    for k in range(4):
        show(means[k, :].reshape((28, 28)))
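A side note on the library call: newer versions of scikit-learn replaced mixture.GMM with mixture.GaussianMixture. A rough equivalent, which also shows how to get a hard cluster assignment per image (assuming the all_vec list built in the script above), would be:

import numpy as np
from sklearn.mixture import GaussianMixture

X = np.asarray(all_vec)                   # assumes all_vec from the script above
# diagonal covariances keep the 784-dimensional fit tractable (and match the old GMM default)
g = GaussianMixture(n_components=4, covariance_type='diag', random_state=0)
labels = g.fit(X).predict(X)              # hard cluster assignment per image
show(g.means_[0].reshape((28, 28)))       # inspect one of the learned means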
References
[1] Bishop, Christopher (2006). Pattern Recognition and Machine Learning. New York: Springer.
[2] MNIST dataset: http://yann.lecun.com/exdb/mnist/
[3] Loading MNIST in Python: https://gist.github.com/akesling/5358964