
Comments (2)

Hzzone commented on September 26, 2024

Thanks for your interest.

  1. I have rewritten the k-means of sklearn in PyTorch, referring to https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html and https://github.com/scikit-learn/scikit-learn/blob/844b4be24/sklearn/cluster/_kmeans.py#L51. Here is the main code of my implementation:
import numpy as np
import torch
import tqdm
from ..__base__ import BasicClustering, pairwise_euclidean, pairwise_cosine
from .kmeans_plus_plus import _kmeans_plusplus


class PyTorchKMeans(BasicClustering):
    def __init__(self,
                 metric='euclidean',
                 init='k-means++',
                 random_state=0,
                 n_clusters=8,
                 n_init=10,
                 max_iter=300,
                 tol=1e-4,
                 distributed=False,
                 num_chunk=1,
                 verbose=True):
        super().__init__(n_clusters=n_clusters,
                         init=init,
                         random_state=random_state,
                         n_init=n_init,
                         max_iter=max_iter,
                         tol=tol,
                         verbose=verbose,
                         distributed=distributed)
        self.distance_metric = {'euclidean': pairwise_euclidean, 'cosine': pairwise_cosine}[metric]
        # self.distance_metric = lambda a, b: torch.cdist(a, b, p=2.)
        if isinstance(self.init, (np.ndarray, torch.Tensor)): self.n_init = 1
        self.num_chunk = num_chunk

    def initialize(self, X: torch.Tensor, random_state: int):
        num_samples = len(X)
        if isinstance(self.init, str):
            g = torch.Generator()
            g.manual_seed(random_state)
            if self.init == 'random':
                indices = torch.randperm(num_samples, generator=g)[:self.n_clusters]
                init_state = X[indices]
            elif self.init == 'k-means++':
                init_state, _ = _kmeans_plusplus(X,
                                                 random_state=random_state,
                                                 n_clusters=self.n_clusters,
                                                 pairwise_distance=self.distance_metric)
                # init_state = X[torch.randperm(num_samples, generator=g)[0]].unsqueeze(0)
                # for k in range(1, self.n_clusters):
                #     d = torch.min(self.distance_metric(X, init_state), dim=1)[0]
                #     init_state = torch.cat([init_state, X[torch.argmax(d)].unsqueeze(0)], dim=0)
            else:
                raise NotImplementedError
        elif isinstance(self.init, (np.ndarray, torch.Tensor)):
            init_state = self.init.to(X)
        else:
            raise NotImplementedError
        labels = torch.argmin(self.distance_metric(X, init_state), dim=1)

        return init_state, labels

    def fit_predict(self, X: torch.Tensor):

        # sklearn-style tolerance: scale tol by the mean per-feature variance of X
        tol = torch.mean(torch.var(X, dim=0)) * self.tol

        min_inertia, best_states, best_labels = float('Inf'), None, None

        random_states = torch.arange(self.n_init * self.world_size) + self.random_state
        random_states = random_states[self.rank:len(random_states):self.world_size]
        # g = torch.Generator()
        # g.manual_seed(self.random_state)
        # random_states = torch.randperm(10000, generator=g)[:self.n_init * self.world_size]
        # random_states = random_states[self.rank:self.n_init * self.world_size:self.world_size]

        for n_init in range(self.n_init):
            random_state = int(random_states[n_init])
            old_state, old_labels = self.initialize(X, random_state=random_state)

            labels = old_labels

            progress_bar = tqdm.tqdm(total=self.max_iter, disable=not ((self.rank == 0) and self.verbose))

            for n_iter in range(self.max_iter):

                # https://discuss.pytorch.org/t/groupby-aggregate-mean-in-pytorch/45335/7
                # n_samples = X.size(0)
                # weight = torch.zeros(self.n_clusters, n_samples, dtype=X.dtype, device=X.device)  # L, N
                # weight[labels, torch.arange(n_samples)] = 1
                # weight = F.normalize(weight, p=1, dim=1)  # l1 normalization
                # state = torch.mm(weight, X)  # L, F
                state = torch.zeros(self.n_clusters, X.size(1), dtype=X.dtype, device=X.device)
                counts = torch.zeros(self.n_clusters, dtype=X.dtype, device=X.device) + 1e-6
                classes, classes_counts = torch.unique(labels, return_counts=True)
                counts[classes] = classes_counts.to(X)
                state.index_add_(0, labels, X)
                state = state / counts.view(-1, 1)

                # d = self.distance_metric(X, state)
                # inertia, labels = d.min(dim=1)
                # inertia = inertia.sum()
                labels, inertia = self.predict(X, state)

                if inertia < min_inertia:
                    min_inertia = inertia
                    best_states, best_labels = state, labels

                if self.verbose:
                    progress_bar.set_description(
                        f'nredo {n_init + 1}/{self.n_init:02d}, iteration {n_iter:03d} with inertia {inertia:.2f}')
                    progress_bar.update(n=1)

                center_shift = self.distance_metric(old_state, state).diag()

                if torch.equal(labels, old_labels):
                    # First check the labels for strict convergence.
                    if self.verbose:
                        print(f"Converged at iteration {n_iter}: strict convergence.")
                    break
                else:
                    # center_shift = self.distance_metric(old_state, state).diag().sum()
                    # No strict convergence, check for tol based convergence.
                    # center_shift_tot = (center_shift ** 2).sum()
                    center_shift_tot = center_shift.sum()
                    if center_shift_tot <= tol:
                        if self.verbose:
                            print(
                                f"Converged at iteration {n_iter}: center shift "
                                f"{center_shift_tot} within tolerance {tol} "
                                f"and min inertia {min_inertia.item()}."
                            )
                        break

                old_labels[:] = labels
                old_state = state
            progress_bar.close()

        min_inertia, best_labels, best_states = self.distributed_sync(min_inertia, best_labels, best_states)

        if self.verbose:
            print(f"Final min inertia {min_inertia.item()}.")

        self.cluster_centers_ = best_states
        return best_labels

    def predict(self, X: torch.Tensor, cluster_centers_=None):
        if cluster_centers_ is None:
            cluster_centers_ = self.cluster_centers_
        if self.num_chunk > 1:
            # Chunk the centers along the cluster dimension to bound the peak
            # memory of the N x K pairwise-distance matrix.
            split_centers = cluster_centers_.chunk(self.num_chunk, dim=0)
            class_indices = torch.arange(cluster_centers_.size(0), device=X.device).long()
            split_indices = class_indices.chunk(self.num_chunk, dim=0)
            labels, inertia = [], []
            for centers, indices in zip(split_centers, split_indices):
                d = self.distance_metric(X, centers)
                chunk_inertia, chunk_labels = d.min(dim=1)
                labels.append(indices[chunk_labels])  # map back to global cluster ids
                inertia.append(chunk_inertia)
            labels = torch.stack(labels, dim=1)    # N x num_chunk
            inertia = torch.stack(inertia, dim=1)  # N x num_chunk
            best = torch.argmin(inertia, dim=1)
            rows = torch.arange(labels.size(0), device=X.device)
            labels = labels[rows, best]
            inertia = inertia[rows, best].sum()
        else:
            d = self.distance_metric(X, cluster_centers_)
            inertia, labels = d.min(dim=1)
            inertia = inertia.sum()
        return labels, inertia


if __name__ == '__main__':
    torch.cuda.set_device(1)
    clustering_model = PyTorchKMeans(metric='cosine',
                                     init='k-means++',
                                     random_state=0,
                                     n_clusters=1000,
                                     n_init=10,
                                     max_iter=300,
                                     tol=1e-4,
                                     distributed=False,
                                     verbose=True)
    X = torch.randn(1280000, 256).cuda()
    clustering_model.fit_predict(X)
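
Note that pairwise_euclidean, pairwise_cosine, and _kmeans_plusplus are imported from modules that are not shown above. A minimal sketch of what such helpers might compute (my assumption, not the exact code of those modules):

import torch
import torch.nn.functional as F


def pairwise_euclidean(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Euclidean distance between every row of a and every row of b.
    return torch.cdist(a, b, p=2.0)


def pairwise_cosine(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Cosine distance (1 - cosine similarity) between all row pairs.
    a = F.normalize(a, dim=1)
    b = F.normalize(b, dim=1)
    return 1.0 - a @ b.T


def _kmeans_plusplus(X: torch.Tensor, random_state: int, n_clusters: int,
                     pairwise_distance=pairwise_euclidean):
    # k-means++ seeding: pick the first center at random, then sample each
    # further center with probability proportional to its distance to the
    # nearest already-chosen center.
    g = torch.Generator()
    g.manual_seed(random_state)
    indices = [int(torch.randint(len(X), (1,), generator=g))]
    centers = X[indices]
    for _ in range(n_clusters - 1):
        d = pairwise_distance(X, centers).min(dim=1)[0]
        idx = int(torch.multinomial((d / d.sum()).cpu(), 1, generator=g))
        indices.append(idx)
        centers = torch.cat([centers, X[idx].unsqueeze(0)], dim=0)
    return centers, torch.tensor(indices)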

The key features of this k-means are spherical clustering, distributed execution, and k-means++ initialization. In my experiments, this implementation can be better than sklearn and even outperformed faiss.
Hope this helps you.

Another spherical k-means implementation can be found at https://github.com/facebookresearch/swav/blob/06b1b7cbaf6ba2a792300d79c7299db98b93b7f9/main_swav.py#L354.
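
For intuition, here is a minimal single-restart spherical k-means loop in the same spirit (a sketch, not the SwAV code itself): assignments use cosine similarity, and the centroids are projected back onto the unit sphere after every update.

import torch
import torch.nn.functional as F


def spherical_kmeans(X: torch.Tensor, n_clusters: int, n_iter: int = 10) -> torch.Tensor:
    # On the unit sphere, cosine similarity is just a dot product.
    X = F.normalize(X, dim=1)
    centroids = X[torch.randperm(len(X))[:n_clusters]]
    for _ in range(n_iter):
        # E-step: assign each point to its most similar centroid.
        labels = (X @ centroids.T).argmax(dim=1)
        # M-step: average the members, then renormalize onto the sphere.
        for k in range(n_clusters):
            members = X[labels == k]
            if len(members) > 0:
                centroids[k] = members.mean(dim=0)
        centroids = F.normalize(centroids, dim=1)
    return labels
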
You can also use faiss, which can be very fast:

import numpy as np
import torch
import torch.nn.functional as F
import torch.distributed as dist

try:
    import faiss
except ImportError:
    print('faiss is not installed')
from .__base__ import BasicClustering


class FaissKMeans(BasicClustering):
    def __init__(self,
                 metric='euclidean',
                 n_clusters=8,
                 n_init=10,
                 max_iter=300,
                 random_state=1234,
                 distributed=False,
                 verbose=True):
        super().__init__(n_clusters=n_clusters,
                         n_init=n_init,
                         max_iter=max_iter,
                         distributed=distributed,
                         verbose=verbose)

        if metric == 'euclidean':
            self.spherical = False
        elif metric == 'cosine':
            self.spherical = True
        else:
            raise NotImplementedError
        self.random_state = random_state

    def apply_pca(self, X, dim):
        # Reduce the features to dim dimensions with faiss PCA.
        d = X.shape[1]
        if self.spherical:
            X = F.normalize(X, dim=1)
        X = np.ascontiguousarray(X.cpu().numpy().astype(np.float32))
        mat = faiss.PCAMatrix(d, dim)
        mat.train(X)
        return mat.apply_py(X)

    def fit_predict(self, input: torch.Tensor):
        n, d = input.shape

        if self.spherical:
            # Normalize once so that L2 distances rank samples like cosine distances.
            input = F.normalize(input, dim=1)

        if input.is_cuda:
            device = input.device.index
        else:
            device = -1

        X = np.ascontiguousarray(input.cpu().numpy().astype(np.float32))

        random_states = torch.arange(self.world_size) + self.random_state
        random_state = random_states[self.rank]
        if device > -1:
            # faiss implementation of k-means
            clus = faiss.Clustering(d, self.n_clusters)

            # Change faiss seed at each k-means so that the randomly picked
            # initialization centroids do not correspond to the same feature ids
            # from an epoch to another.
            # clus.seed = np.random.randint(1234)
            clus.seed = int(random_state)

            clus.niter = self.max_iter
            clus.max_points_per_centroid = 10000000
            clus.min_points_per_centroid = 10
            clus.spherical = self.spherical
            clus.nredo = self.n_init
            clus.verbose = self.verbose
            res = faiss.StandardGpuResources()
            flat_config = faiss.GpuIndexFlatConfig()
            flat_config.useFloat16 = False
            flat_config.device = device
            index = faiss.GpuIndexFlatL2(res, d, flat_config)

            # perform the training
            clus.train(X, index)
            D, I = index.search(X, 1)
        else:
            clus = faiss.Kmeans(d=d,
                                k=self.n_clusters,
                                niter=self.max_iter,
                                nredo=self.n_init,
                                verbose=self.verbose,
                                spherical=self.spherical,
                                seed=int(random_state))
            clus.train(X)
            # for each sample, find the cluster distance and assignment
            D, I = clus.index.search(X, 1)

        best_labels = torch.from_numpy(I.flatten()).to(input.device)
        min_inertia = torch.from_numpy(D.flatten()).to(input).sum()
        if isinstance(clus, faiss.Kmeans):
            # faiss.Kmeans already exposes the centroids as a numpy array
            best_states = clus.centroids.reshape(self.n_clusters, d)
        else:
            best_states = faiss.vector_to_array(clus.centroids).reshape(self.n_clusters, d)
        best_states = torch.from_numpy(best_states).to(input)

        min_inertia, best_labels, best_states = self.distributed_sync(min_inertia, best_labels, best_states)

        if self.verbose:
            print(f"Final min inertia {min_inertia.item()}.")

        self.cluster_centers_ = best_states
        return best_labels


if __name__ == '__main__':
    dist.init_process_group(backend='nccl', init_method='env://')
    torch.cuda.set_device(dist.get_rank())
    X = torch.randn(1280, 256).cuda()
    clustering_model = FaissKMeans(metric='euclidean',
                                   n_clusters=10,
                                   n_init=2,
                                   max_iter=1,
                                   random_state=1234,
                                   distributed=True,
                                   verbose=True)
    clustering_model.fit_predict(X)
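
Since the example initializes the process group with init_method='env://', it is meant to be launched with one process per GPU via torchrun, e.g. torchrun --nproc_per_node=2 your_script.py (the script name is a placeholder); each rank then runs its redos with a different seed, and distributed_sync presumably keeps the result with the globally smallest inertia.
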
  2. You can implement PS by sampling noise from a Gaussian distribution, as sketched below.
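
For example, a minimal sketch of such a perturbation (the function name and the noise scale sigma are my assumptions, not values from the paper):

import torch
import torch.nn.functional as F


def perturb(z: torch.Tensor, sigma: float = 0.001) -> torch.Tensor:
    # Add isotropic Gaussian noise to the normalized features and
    # project the result back onto the unit sphere.
    noise = torch.randn_like(z) * sigma
    return F.normalize(z + noise, dim=1)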


Hzzone commented on September 26, 2024

Hi!

I have uploaded my PyTorch implementations of k-means and GMM at https://github.com/Hzzone/torch_clustering.
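
A usage sketch along the lines of that repository (assuming it exposes a PyTorchKMeans class similar to the code above; check the repository README for the exact API):

import torch
from torch_clustering import PyTorchKMeans

features = torch.randn(10000, 256).cuda()
clustering_model = PyTorchKMeans(metric='cosine', init='k-means++',
                                 n_clusters=100, n_init=10, max_iter=300)
pseudo_labels = clustering_model.fit_predict(features)
centers = clustering_model.cluster_centers_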

The code of this paper will be released soon!

