Using Clustering On Real Data | Basic Clustering Algorithms

Using Clustering On Real Data

We have covered four clustering algorithms and seen how they work on toy datasets. Now let's apply these methods to a real-life problem with real data.

We will use the Iris dataset, which consists of 50 samples from each of three Iris species (Iris setosa, Iris virginica, and Iris versicolor). Four features were measured for each sample: the length and width of the sepals and of the petals, in centimeters. The task is to determine the species from these features: we will cluster the data and assume that each cluster represents one Iris species.
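To make the dataset description concrete, here is a minimal sketch that loads Iris through scikit-learn and prints its shape, feature names, species names, and the number of samples per species. The variable name `samples_per_species` is only illustrative.

```python
import numpy as np
from sklearn.datasets import load_iris

# Load the full dataset object (features, labels, and metadata)
iris = load_iris()

# 150 samples, 4 features
print(iris.data.shape)

# Names of the four measured features and of the three species
print(iris.feature_names)
print(iris.target_names)

# Count how many samples belong to each species label (0, 1, 2)
samples_per_species = np.bincount(iris.target)
print(samples_per_species)
```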

To keep the visualizations easy to read, we will use only two features for clustering: sepal length and petal length. Let's look at our data:

    from sklearn.datasets import load_iris
    import matplotlib.pyplot as plt

    # Load the data and keep only sepal length (column 0) and petal length (column 2)
    X_iris, y_iris = load_iris(return_X_y=True)
    X_iris = X_iris[:, [0, 2]]

    # Color each point by its true species
    plt.scatter(X_iris[:, 0], X_iris[:, 1], c=y_iris, cmap='tab20b')
    plt.title('Iris data')
    plt.xlabel('Length of sepals')
    plt.ylabel('Length of petals')
    plt.show()

Let's cluster the data with K-means and compare the result with the true species labels:

    from sklearn.datasets import load_iris
    from sklearn.cluster import KMeans
    import matplotlib.pyplot as plt

    # Keep only sepal length (column 0) and petal length (column 2)
    X_iris, y_iris = load_iris(return_X_y=True)
    X_iris = X_iris[:, [0, 2]]

    # n_init is set explicitly because its default depends on the scikit-learn version;
    # random_state makes the run reproducible
    kmeans = KMeans(n_clusters=3, n_init=10, random_state=0).fit(X_iris)

    # Left: true species, right: K-means cluster labels
    fig, axes = plt.subplots(1, 2)
    fig.set_size_inches(10, 5)
    axes[0].scatter(X_iris[:, 0], X_iris[:, 1], c=y_iris, cmap='tab20b')
    axes[0].set_title('Real clusters')
    axes[1].scatter(X_iris[:, 0], X_iris[:, 1], c=kmeans.labels_, cmap='tab20b')
    axes[1].set_title('Clusters with K-means')
    for ax in axes:
        ax.set(xlabel='Length of sepals', ylabel='Length of petals')
    plt.show()
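Besides comparing the two plots by eye, the cluster labels can be cross-tabulated against the true species. The sketch below is one possible way to do this, using scikit-learn's `contingency_matrix`; the `n_init` and `random_state` values are assumptions chosen only to make the run reproducible.

```python
from sklearn.cluster import KMeans
from sklearn.datasets import load_iris
from sklearn.metrics.cluster import contingency_matrix

# Same two features as in the plots: sepal length and petal length
X_iris, y_iris = load_iris(return_X_y=True)
X_iris = X_iris[:, [0, 2]]

kmeans = KMeans(n_clusters=3, n_init=10, random_state=0).fit(X_iris)

# Rows are the true species, columns are the K-means clusters;
# each cell counts the samples that fall into that (species, cluster) pair
table = contingency_matrix(y_iris, kmeans.labels_)
print(table)
```

Cluster numbers are arbitrary, so the large counts do not have to lie on the diagonal. What matters is whether each row is concentrated in a single column.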

Now let's try the agglomerative algorithm:

    from sklearn.datasets import load_iris
    from sklearn.cluster import AgglomerativeClustering
    import matplotlib.pyplot as plt

    # Keep only sepal length (column 0) and petal length (column 2)
    X_iris, y_iris = load_iris(return_X_y=True)
    X_iris = X_iris[:, [0, 2]]

    # The number of clusters is set manually to match the three species
    agglomerative = AgglomerativeClustering(n_clusters=3).fit(X_iris)

    # Left: true species, right: agglomerative cluster labels
    fig, axes = plt.subplots(1, 2)
    fig.set_size_inches(10, 5)
    axes[0].scatter(X_iris[:, 0], X_iris[:, 1], c=y_iris, cmap='tab20b')
    axes[0].set_title('Real clusters')
    axes[1].scatter(X_iris[:, 0], X_iris[:, 1], c=agglomerative.labels_, cmap='tab20b')
    axes[1].set_title('Clusters with Agglomerative')
    for ax in axes:
        ax.set(xlabel='Length of sepals', ylabel='Length of petals')
    plt.show()

Note

We mentioned in the Agglomerative Clustering chapter that we can manually define the number of clusters. Here we used this ability because we have information about the number of target clusters (3 species of Iris = 3 clusters).
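For contrast, the sketch below fits the model twice: once with a fixed number of clusters, as in this chapter, and once letting a distance threshold decide where merging stops. The threshold value 5.0 is an arbitrary assumption for illustration, not a recommended setting.

```python
from sklearn.cluster import AgglomerativeClustering
from sklearn.datasets import load_iris

# Same two features as above: sepal length and petal length
X_iris, _ = load_iris(return_X_y=True)
X_iris = X_iris[:, [0, 2]]

# Option 1: fix the number of clusters in advance (used in this chapter)
fixed = AgglomerativeClustering(n_clusters=3).fit(X_iris)

# Option 2: leave n_clusters unset and stop merging at a linkage distance.
# The value 5.0 is a hypothetical choice for this example.
by_threshold = AgglomerativeClustering(n_clusters=None, distance_threshold=5.0).fit(X_iris)

# n_clusters_ reports how many clusters each fitted model ended up with
print(fixed.n_clusters_, by_threshold.n_clusters_)
```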

Next, let's use the Mean shift clustering algorithm:

    from sklearn.datasets import load_iris
    from sklearn.cluster import MeanShift
    import matplotlib.pyplot as plt

    # Keep only sepal length (column 0) and petal length (column 2)
    X_iris, y_iris = load_iris(return_X_y=True)
    X_iris = X_iris[:, [0, 2]]

    # Mean shift finds the number of clusters itself; bandwidth is set manually
    mean_shift = MeanShift(bandwidth=2).fit(X_iris)

    # Left: true species, right: Mean shift cluster labels
    fig, axes = plt.subplots(1, 2)
    fig.set_size_inches(10, 5)
    axes[0].scatter(X_iris[:, 0], X_iris[:, 1], c=y_iris, cmap='tab20b')
    axes[0].set_title('Real clusters')
    axes[1].scatter(X_iris[:, 0], X_iris[:, 1], c=mean_shift.labels_, cmap='tab20b')
    axes[1].set_title('Clusters with Mean shift')
    for ax in axes:
        ax.set(xlabel='Length of sepals', ylabel='Length of petals')
    plt.show()

Finally, let's try to use DBSCAN:

    from sklearn.datasets import load_iris
    from sklearn.cluster import DBSCAN
    import matplotlib.pyplot as plt

    # Keep only sepal length (column 0) and petal length (column 2)
    X_iris, y_iris = load_iris(return_X_y=True)
    X_iris = X_iris[:, [0, 2]]

    # eps and min_samples are set manually; points labeled -1 are treated as noise
    dbscan = DBSCAN(eps=1, min_samples=10).fit(X_iris)

    # Left: true species, right: DBSCAN cluster labels
    fig, axes = plt.subplots(1, 2)
    fig.set_size_inches(10, 5)
    axes[0].scatter(X_iris[:, 0], X_iris[:, 1], c=y_iris, cmap='tab20b')
    axes[0].set_title('Real clusters')
    axes[1].scatter(X_iris[:, 0], X_iris[:, 1], c=dbscan.labels_, cmap='tab20b')
    axes[1].set_title('Clusters with DBSCAN')
    for ax in axes:
        ax.set(xlabel='Length of sepals', ylabel='Length of petals')
    plt.show()

Note

In the code above, we set the algorithm parameters manually (eps and min_samples for DBSCAN, bandwidth for Mean shift). In real tasks, finding good values for these parameters requires additional techniques (cross-validation, grid search, etc.).
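A very simple form of such a search is to sweep a few candidate values and record how many clusters each one produces. The sketch below does this for `bandwidth` and `eps`; the candidate lists and the helper `count_clusters` are assumptions made for illustration, and `estimate_bandwidth` is scikit-learn's heuristic helper, which this chapter does not otherwise use.

```python
import numpy as np
from sklearn.cluster import DBSCAN, MeanShift, estimate_bandwidth
from sklearn.datasets import load_iris

# Same two features as above: sepal length and petal length
X_iris, _ = load_iris(return_X_y=True)
X_iris = X_iris[:, [0, 2]]


def count_clusters(labels):
    """Number of distinct clusters, ignoring DBSCAN's noise label (-1)."""
    return int(np.sum(np.unique(labels) != -1))


# Hypothetical candidate values for the Mean shift bandwidth
bandwidths = [0.5, 1.0, 1.5, 2.0]
mean_shift_counts = {
    bw: count_clusters(MeanShift(bandwidth=bw).fit(X_iris).labels_)
    for bw in bandwidths
}

# Hypothetical candidate values for the DBSCAN eps (min_samples kept at 10)
eps_values = [0.2, 0.3, 0.5, 1.0]
dbscan_counts = {
    eps: count_clusters(DBSCAN(eps=eps, min_samples=10).fit(X_iris).labels_)
    for eps in eps_values
}

# Heuristic bandwidth estimate based on pairwise distances in the data
suggested_bandwidth = estimate_bandwidth(X_iris, quantile=0.2)

print('Mean shift clusters per bandwidth:', mean_shift_counts)
print('DBSCAN clusters per eps:', dbscan_counts)
print('Estimated bandwidth:', suggested_bandwidth)
```

Counting clusters only shows how sensitive the result is to each parameter. Choosing between the candidates still requires a quality metric.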

The visualizations show that K-means and Agglomerative clustering find three clusters that roughly match the species. Mean shift and DBSCAN, with the parameters chosen above, do not separate three distinct clusters.
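The same comparison can be checked numerically by counting the distinct labels each model assigns. This is a minimal sketch using the parameters from this chapter; the `n_init` and `random_state` arguments for K-means are assumptions added for reproducibility.

```python
import numpy as np
from sklearn.cluster import DBSCAN, AgglomerativeClustering, KMeans, MeanShift
from sklearn.datasets import load_iris

# Same two features as above: sepal length and petal length
X_iris, _ = load_iris(return_X_y=True)
X_iris = X_iris[:, [0, 2]]

models = {
    'K-means': KMeans(n_clusters=3, n_init=10, random_state=0),
    'Agglomerative': AgglomerativeClustering(n_clusters=3),
    'Mean shift': MeanShift(bandwidth=2),
    'DBSCAN': DBSCAN(eps=1, min_samples=10),
}

cluster_counts = {}
for name, model in models.items():
    labels = model.fit(X_iris).labels_
    # Count distinct labels, skipping -1, which DBSCAN uses for noise points
    cluster_counts[name] = int(np.sum(np.unique(labels) != -1))
    print(f'{name}: {cluster_counts[name]} clusters')
```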

Thus, each task needs its own approach: the choice of algorithm, the selection of parameters, and so on. We also need metrics with which to evaluate the quality of clustering. Cluster plots are not the best indicator, for two reasons:

  1. Plots cannot adequately show the cluster assignment for multivariate data (data with more than 3 features can't be visualized directly);
  2. Plots can reveal algorithms that give very poor results (like DBSCAN and Mean shift in the example above). But when the results are good, it is very difficult to tell from a plot which clustering is better (like K-means and Agglomerative in the example above).

We will talk about evaluating the quality of clustering in the next section.

Section 2. Chapter 8