Cluster Analysis
Mean Shift Clustering
Mean shift is one of the simplest density-based clustering algorithms. Simply put, "mean shift" means "iteratively shifting to the mean": every data point is shifted step by step towards its "regional mean", and the final location each point arrives at determines the cluster it belongs to. The algorithm consists of the following steps:
Step 1. For each data point, create a sliding window centered on it with a specified radius (the bandwidth);
Step 2. Shift each sliding window towards a higher-density region by moving its centroid to the mean of the data points inside the window. Repeat this step until the number of points in the window no longer increases or the centroid stops moving;
Step 3. Select the final windows by merging overlapping ones: when multiple windows overlap, the window containing the most points is kept and the others are merged into it;
Step 4. Assign each data point to the window in which it lies. If a data point lies outside every window, assign it to the nearest window.
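The four steps above can be sketched from scratch in NumPy. This is a minimal illustration under stated assumptions, not scikit-learn's internal implementation: it assumes a flat (uniform) kernel and Euclidean distance, the function name `mean_shift` and its `max_iter`/`tol` parameters are hypothetical, Step 2 stops only when the centroids stop moving, and Step 4 is simplified to "assign each point to the nearest surviving window center".

```python
import numpy as np


def mean_shift(X, bandwidth, max_iter=300, tol=1e-3):
    """Hypothetical flat-kernel mean shift sketch (illustration only)."""
    X = np.asarray(X, dtype=float)

    # Step 1: start one window at every data point
    centers = X.copy()

    # Step 2: move each window centroid to the mean of the points inside it
    for _ in range(max_iter):
        new_centers = centers.copy()
        for i, c in enumerate(centers):
            in_window = np.linalg.norm(X - c, axis=1) <= bandwidth
            if in_window.any():  # guard against an empty window
                new_centers[i] = X[in_window].mean(axis=0)
        shift = np.linalg.norm(new_centers - centers, axis=1).max()
        centers = new_centers
        if shift < tol:  # assumption: stop once the centroids stop moving
            break

    # Step 3: merge overlapping windows, keeping the most populated ones first
    counts = np.array(
        [(np.linalg.norm(X - c, axis=1) <= bandwidth).sum() for c in centers]
    )
    kept = []
    for i in np.argsort(-counts, kind="stable"):
        if all(np.linalg.norm(centers[i] - k) > bandwidth for k in kept):
            kept.append(centers[i])
    kept = np.array(kept)

    # Step 4 (simplified): assign every point to its nearest kept window center
    dists = np.linalg.norm(X[:, None, :] - kept[None, :, :], axis=2)
    labels = dists.argmin(axis=1)
    return kept, labels


# Usage: two well-separated synthetic groups of 50 points each
rng = np.random.default_rng(0)
X = np.vstack([rng.normal(0, 0.5, (50, 2)), rng.normal(5, 0.5, (50, 2))])
centers, labels = mean_shift(X, bandwidth=2.0)
print("number of clusters found:", len(centers))
```

Note that the number of clusters is never passed in: it falls out of how many distinct window centers survive the merge in Step 3.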
Mean shift moves the windows towards higher-density regions by shifting each window's centroid (the center of the sliding window) to the mean of the data points inside that window.
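For a flat kernel, one shift of a window can be written as a formula. The symbols below are introduced here for illustration: $h$ is the bandwidth, $N_h(x)$ is the set of data points inside the window centered at $x$, $m_h(x)$ is their mean, and $t$ counts iterations.

```latex
\begin{aligned}
N_h(x) &= \{\, x_i : \lVert x_i - x \rVert \le h \,\} \\
m_h(x) &= \frac{1}{\lvert N_h(x) \rvert} \sum_{x_i \in N_h(x)} x_i \\
x^{(t+1)} &= m_h\bigl(x^{(t)}\bigr)
\end{aligned}
```

The difference $m_h(x^{(t)}) - x^{(t)}$ is the "mean shift" vector; the iteration stops once it is (close to) zero, that is, once the centroid stops moving.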
So the Mean shift algorithm is very similar to the K-means algorithm: it also works with the mean of the points and performs well only on isolated (well-separated) clusters. There is one significant difference, however: with Mean shift, the number of clusters does not need to be set manually.
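The difference can be seen directly in scikit-learn: `KMeans` requires `n_clusters`, while `MeanShift` only takes a `bandwidth` and reports the number of clusters it found through `cluster_centers_`. The dataset, cluster spread, and bandwidth below are illustrative choices, not values from the original lesson.

```python
from sklearn.cluster import KMeans, MeanShift
from sklearn.datasets import make_blobs

# Four well-separated blobs; random_state is fixed so the run is repeatable
X, _ = make_blobs(
    n_samples=400,
    centers=[[-6, -6], [-6, 6], [6, -6], [6, 6]],
    cluster_std=0.8,
    random_state=0,
)

# K-means: the number of clusters must be chosen up front
kmeans = KMeans(n_clusters=4, n_init=10, random_state=0).fit(X)

# Mean shift: only the window radius is given; the cluster count is an output
ms_model = MeanShift(bandwidth=3).fit(X)
n_found = len(ms_model.cluster_centers_)

print("K-means clusters (set by hand):", kmeans.n_clusters)
print("Mean shift clusters (found):", n_found)
```

Keep in mind that `bandwidth` still has to be chosen: it does not fix the cluster count directly, but it strongly influences how many clusters appear.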
Let's look at an example of using Mean shift clustering in Python:
```python
from sklearn.cluster import MeanShift
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs

# Create a dataset of four blobs and stretch it into anisotropic (elongated) clusters
X, y = make_blobs(n_samples=500, cluster_std=1, centers=4)
transformation = [[0.6, -0.6], [-0.4, 0.8]]
X_aniso = np.matmul(X, transformation)

# Train a Mean shift model on the anisotropic blobs dataset
blobs_clustering = MeanShift(bandwidth=2).fit(X_aniso)

# Plot the transformed data that was actually clustered, next to the real labels
fig, axes = plt.subplots(1, 2)
axes[0].scatter(X_aniso[:, 0], X_aniso[:, 1], c=blobs_clustering.labels_, s=50, cmap='tab20b')
axes[0].set_title('Clustered anisotropic blobs data')
axes[1].scatter(X_aniso[:, 0], X_aniso[:, 1], c=y, s=50, cmap='tab20b')
axes[1].set_title('Real anisotropic blobs data')
plt.show()
```
Let's check how the Mean shift algorithm deals with the moons dataset:
```python
from sklearn.datasets import make_moons
import matplotlib.pyplot as plt
from sklearn.cluster import MeanShift

# Create the moons dataset for clustering
X, y = make_moons(n_samples=500)

# Fit a Mean shift model on the moons dataset
moons_clustering = MeanShift(bandwidth=0.7).fit(X)

# Visualize the found clusters next to the real labels
fig, axes = plt.subplots(1, 2)
axes[0].scatter(X[:, 0], X[:, 1], c=moons_clustering.labels_, s=50, cmap='tab20b')
axes[0].set_title('Clustered moons data')
axes[1].scatter(X[:, 0], X[:, 1], c=y, s=50, cmap='tab20b')
axes[1].set_title('Real moons data')
plt.show()
```
In the code above, we use the MeanShift class to create the model: the bandwidth parameter defines the radius of the sliding window within which the mean is calculated.
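Because the bandwidth is the window radius, it effectively controls how many clusters Mean shift returns. The sketch below (synthetic data and bandwidth values chosen for illustration) fits the same data with a small and a very large radius, then asks scikit-learn's `estimate_bandwidth` helper for a data-driven value. According to the scikit-learn documentation, if `bandwidth` is left as `None`, `MeanShift` estimates it with this same helper.

```python
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets import make_blobs

# Two blobs whose centers are 10 units apart
X, _ = make_blobs(n_samples=300, centers=[[-5, 0], [5, 0]], cluster_std=0.8, random_state=42)

# A small radius keeps the blobs apart; a huge radius swallows everything
n_clusters = {}
for bw in (3.0, 20.0):
    model = MeanShift(bandwidth=bw).fit(X)
    n_clusters[bw] = len(model.cluster_centers_)
    print(f"bandwidth={bw}: {n_clusters[bw]} cluster(s)")

# Data-driven suggestion based on nearest-neighbor distances
bw_est = estimate_bandwidth(X, quantile=0.3, random_state=0)
print("estimated bandwidth:", round(float(bw_est), 3))
```

With a radius larger than the whole dataset, every window contains every point, so all windows converge to the global mean and only one cluster remains.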
Note
In the MeanShift class, you can use the .predict() method to assign new data points to clusters using an already trained model.
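A short usage sketch of `.predict()`: scikit-learn documents that it returns the closest cluster for each sample, so the result should match a manual nearest-center lookup against `cluster_centers_`. The data and the two "new" points are made up for illustration.

```python
import numpy as np
from sklearn.cluster import MeanShift
from sklearn.datasets import make_blobs

# Train on two well-separated blobs
X, _ = make_blobs(n_samples=200, centers=[[0, 0], [10, 10]], cluster_std=0.8, random_state=1)
model = MeanShift(bandwidth=3).fit(X)

# Two unseen points, one near each blob
new_points = np.array([[0.5, -0.5], [9.5, 10.5]])
new_labels = model.predict(new_points)
print("predicted clusters:", new_labels)

# Manual check: index of the nearest cluster center for each new point
dists = np.linalg.norm(new_points[:, None, :] - model.cluster_centers_[None, :, :], axis=2)
nearest = dists.argmin(axis=1)
print("nearest centers:   ", nearest)
```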