How to implement a clustering algorithm with K-Means in Python?
Today we are going to dive into the implementation of the K-Means algorithm, specifically using the mini-batch method (MiniBatch K-Means), for effective and efficient clustering. We will use a dataset containing features from 85 different candies. The goal: to obtain a detailed analysis of how to cluster these candies in a meaningful way.
What is the candy dataset?
The candy dataset contains 85 different candies, each described by several features:
- Candy name: Identification of the candy.
- Attributes in composition: Whether it contains chocolate, fruits, etc.
- Sugar percentage: Relative amount of sugar compared to other candies.
- Price percentage: Comparative price with others.
- Public preference: Proportion of times it was chosen in one-to-one comparative tests.
How do we prepare the data in Python?
First we import the necessary libraries and load the data into a pandas DataFrame.
import pandas as pd
from sklearn.cluster import MiniBatchKMeans
# Load the candy data and preview the first 10 rows
df = pd.read_csv('data/Candy.csv')
print(df.head(10))
It is important to look at the data to make sure you have loaded it correctly.
What is MiniBatch K-Means and how does it work?
MiniBatch K-Means is a variant of the traditional K-Means algorithm designed to reduce memory usage and computation time, which makes it practical on resource-constrained machines. Instead of using the whole dataset at every iteration, it updates the cluster centers from small random subsets of the data (mini-batches).
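To make the mini-batch idea concrete, here is a minimal sketch that feeds data to the model one chunk at a time with scikit-learn's `partial_fit`. The data is a synthetic stand-in (not the candy dataset), and the chunk size of 20 and the `random_state` value are arbitrary assumptions chosen for illustration.

```python
import numpy as np
from sklearn.cluster import MiniBatchKMeans

# Synthetic stand-in data (NOT the candy dataset): 200 points, 3 features.
rng = np.random.default_rng(0)
X = rng.random((200, 3))

model = MiniBatchKMeans(n_clusters=4, random_state=0, n_init=3)

# Feed the data in chunks of 20 rows. Each call refines the centroids
# using only that chunk, so the full array is never processed at once.
for start in range(0, len(X), 20):
    model.partial_fit(X[start:start + 20])

print(model.cluster_centers_.shape)  # (4, 3): one center per cluster, one value per feature
```

When you call `fit` instead, the model draws the mini-batches itself, with their size controlled by the `batch_size` parameter.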
How do we configure and train the model?
This time, we are going to configure our model for 4 clusters. This decision is based on the fictitious idea of a store that wants to organize its candies on 4 shelves, based on their similarities.
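# 4 clusters, updating the centers with mini-batches of 8 candies
kmeans = MiniBatchKMeans(n_clusters=4, batch_size=8)
# Train on the numeric features only (the candy name is dropped)
kmeans.fit(df.drop(columns=['candy_name']))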
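Two details are worth sketching here: the model only accepts numeric input, and the mini-batches are drawn at random, so results can vary between runs. The example below is a hedged illustration on a hypothetical toy table (the column values are invented). It selects numeric columns with `select_dtypes` and fixes `random_state` so that two runs produce the same labels.

```python
import pandas as pd
from sklearn.cluster import MiniBatchKMeans

# Hypothetical toy table; names and values are invented for illustration.
toy = pd.DataFrame({
    "candy_name": [f"candy_{i}" for i in range(12)],
    "chocolate": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0],
    "price_percent": [0.9, 0.1, 0.8, 0.2, 0.7, 0.3, 0.85, 0.15, 0.75, 0.25, 0.95, 0.05],
})

# Keep only numeric columns; the text column cannot be used for distances.
X = toy.select_dtypes(include="number")

def fit_labels(seed):
    """Train a fresh model with the given seed and return its cluster labels."""
    model = MiniBatchKMeans(n_clusters=4, batch_size=8, random_state=seed, n_init=3)
    return model.fit(X).labels_

first = fit_labels(42)
second = fit_labels(42)
print((first == second).all())  # same seed, same data: identical labels expected
```

Whether to fix the seed is a design choice: it helps when results need to be reproduced, for example when the shelf assignment must stay stable between runs.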
How do we interpret the results?
Once the model is trained, we obtain:
- Cluster centers: we verify that 4 centers have been created as we want.
print(kmeans.cluster_centers_)
- Cluster predictions: Each candy is categorized into one of the 4 clusters, making it easier to interpret which group a candy most resembles.
# Assign each candy to its nearest cluster center
cluster_labels = kmeans.predict(df.drop(columns=['candy_name']))
# Store the label next to the original data
df['cluster_label'] = cluster_labels
print(df.head())
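Once the labels are attached to the table, one possible way to interpret them is to count how many items fall in each cluster and compare the average feature values per cluster. The sketch below uses a hypothetical miniature table (invented names and values, two clusters instead of four) rather than the real candy data.

```python
import pandas as pd
from sklearn.cluster import MiniBatchKMeans

# Hypothetical miniature table; names and values are invented for illustration.
toy = pd.DataFrame({
    "candy_name": ["A", "B", "C", "D", "E", "F", "G", "H"],
    "chocolate": [1, 1, 0, 0, 1, 0, 1, 0],
    "sugar_percent": [0.7, 0.6, 0.2, 0.3, 0.9, 0.1, 0.8, 0.4],
})

features = toy.drop(columns=["candy_name"])
model = MiniBatchKMeans(n_clusters=2, batch_size=8, random_state=0, n_init=3)
toy["cluster_label"] = model.fit_predict(features)

# How many candies ended up in each cluster
sizes = toy["cluster_label"].value_counts()
# Average feature values per cluster, a rough "profile" of each group
profile = toy.groupby("cluster_label")[["chocolate", "sugar_percent"]].mean()

print(sizes)
print(profile)
```

Comparing the per-cluster averages against `cluster_centers_` is a quick sanity check, since the centers live in the same feature space.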
What's next after clustering?
With the clusters identified, it is possible to:
- Export the results to a file for sharing or future analysis.
- Plot data to visualize the clusters, if we want a more intuitive visual analysis.
df.to_csv('clustered_candy.csv')
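For the plotting option, a minimal sketch with matplotlib could look like the following. It assumes matplotlib is installed and uses synthetic two-feature data as a stand-in. With the real candy table you would need to pick two columns or reduce the features to two dimensions first. The output file name `clusters.png` is hypothetical.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen so the script also runs without a display
import matplotlib.pyplot as plt
import numpy as np
from sklearn.cluster import MiniBatchKMeans

# Synthetic stand-in: 85 rows, 2 features (NOT the candy dataset).
rng = np.random.default_rng(0)
X = rng.random((85, 2))

model = MiniBatchKMeans(n_clusters=4, batch_size=8, random_state=0, n_init=3)
labels = model.fit_predict(X)

# Color each point by its cluster label
fig, ax = plt.subplots()
ax.scatter(X[:, 0], X[:, 1], c=labels, cmap="viridis")
ax.set_xlabel("feature 1 (hypothetical)")
ax.set_ylabel("feature 2 (hypothetical)")
ax.set_title("MiniBatch K-Means clusters")
fig.savefig("clusters.png")  # hypothetical output file name
```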
This K-Means example culminates with the integration of the data and its clusters into a single file, facilitating further analysis. Now it is up to you to explore and continue learning about clustering methods and their applications in different areas!