Source code for humancompatible.explain.glance.clustering.kmeans

from ..base import ClusteringMethod
from sklearn.cluster import KMeans


class KMeansMethod(ClusteringMethod):
    """
    Implementation of a clustering method using KMeans.

    This class provides an interface to apply scikit-learn's KMeans
    clustering to a dataset. The estimator is exposed as ``self.model`` and
    is always configured from ``num_clusters`` and ``random_seed``.
    """

    def __init__(self, num_clusters: int, random_seed: int):
        """
        Initializes the KMeansMethod class.

        Parameters
        ----------
        num_clusters : int
            The number of clusters to form as well as the number of
            centroids to generate.
        random_seed : int
            A seed for the random number generator to ensure reproducibility.
        """
        self.num_clusters = num_clusters
        self.random_seed = random_seed
        # Build the (still unfitted) estimator from the stored settings right
        # away, so that ``self.model`` never carries scikit-learn's default
        # configuration instead of the one the caller asked for.
        self.model = self._build_model()

    def _build_model(self):
        """Return a new, unfitted KMeans estimator configured from this instance."""
        return KMeans(
            n_clusters=self.num_clusters,
            n_init=10,
            random_state=self.random_seed,
        )

    def fit(self, data) -> None:
        """
        Fits the KMeans model on the provided dataset.

        A fresh estimator is created on every call, so the current values of
        ``num_clusters`` and ``random_seed`` are the ones used, and nothing
        from a previous fit is reused.

        Parameters
        ----------
        data : array-like or sparse matrix, shape (n_samples, n_features)
            Training instances to cluster.

        Returns
        -------
        None
        """
        self.model = self._build_model()
        self.model.fit(data)

    def predict(self, instances):
        """
        Predicts the nearest cluster each sample in the provided data belongs to.

        ``fit`` must have been called first; until then ``self.model`` is an
        unfitted estimator and the call is rejected by scikit-learn.

        Parameters
        ----------
        instances : array-like or sparse matrix, shape (n_samples, n_features)
            New data to predict.

        Returns
        -------
        labels : array, shape (n_samples,)
            Index of the cluster each sample belongs to.
        """
        return self.model.predict(instances)