Comments (2)
Thanks for your interest.
- I have rewritten sklearn's k-means in PyTorch, referring to https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html and https://github.com/scikit-learn/scikit-learn/blob/844b4be24/sklearn/cluster/_kmeans.py#L51. Here is the main code of my implementation:
import numpy as np
import torch
import tqdm
from ..__base__ import BasicClustering, pairwise_euclidean, pairwise_cosine
from .kmeans_plus_plus import _kmeans_plusplus
class PyTorchKMeans(BasicClustering):
    """K-means clustering on PyTorch tensors.

    Supports a Euclidean or cosine pairwise distance, ``'random'`` /
    ``'k-means++'`` / explicit-array initialisation, several restarts
    (``n_init``) and an optional chunked nearest-center search.

    NOTE(review): ``self.init``, ``self.n_clusters``, ``self.n_init``,
    ``self.max_iter``, ``self.tol``, ``self.verbose``, ``self.random_state``,
    ``self.rank``, ``self.world_size`` and ``self.distributed_sync`` are not
    defined in this class; they are presumably provided by
    ``BasicClustering`` -- confirm against the base class.
    """

    def __init__(self,
                 metric='euclidean',
                 init='k-means++',
                 random_state=0,
                 n_clusters=8,
                 n_init=10,
                 max_iter=300,
                 tol=1e-4,
                 distributed=False,
                 num_chunk=1,
                 verbose=True):
        """Configure the clustering model.

        Args:
            metric: ``'euclidean'`` or ``'cosine'``; selects the pairwise
                distance function. Any other value raises ``KeyError``.
            init: ``'random'``, ``'k-means++'``, or an array/tensor of
                initial centers.
            random_state: base seed; restart ``i`` uses a seed derived
                from it.
            n_clusters: number of clusters.
            n_init: number of restarts (forced to 1 for explicit centers).
            max_iter: maximum number of iterations per restart.
            tol: relative tolerance used for the center-shift stop test.
            distributed: forwarded to ``BasicClustering``.
            num_chunk: number of chunks the centers are split into when
                searching for the nearest center (1 disables chunking).
            verbose: print progress and convergence messages.
        """
        super().__init__(n_clusters=n_clusters,
                         init=init,
                         random_state=random_state,
                         n_init=n_init,
                         max_iter=max_iter,
                         tol=tol,
                         verbose=verbose,
                         distributed=distributed)
        self.distance_metric = {'euclidean': pairwise_euclidean,
                                'cosine': pairwise_cosine}[metric]
        # Explicit initial centers are deterministic, so restarts are pointless.
        if isinstance(self.init, (np.ndarray, torch.Tensor)):
            self.n_init = 1
        self.num_chunk = num_chunk

    def initialize(self, X: torch.Tensor, random_state: int):
        """Pick the initial centers and assign every sample to its nearest one.

        Args:
            X: data to cluster; indexed by sample along dim 0.
            random_state: seed for the ``'random'`` / ``'k-means++'`` init.

        Returns:
            ``(init_state, labels)``: the initial centers and, for each
            sample, the index of the closest initial center.

        Raises:
            NotImplementedError: for an unknown init string or init type.
        """
        if isinstance(self.init, str):
            if self.init == 'random':
                g = torch.Generator()
                g.manual_seed(random_state)
                num_samples = len(X)
                indices = torch.randperm(num_samples, generator=g)[:self.n_clusters]
                init_state = X[indices]
            elif self.init == 'k-means++':
                init_state, _ = _kmeans_plusplus(X,
                                                 random_state=random_state,
                                                 n_clusters=self.n_clusters,
                                                 pairwise_distance=self.distance_metric)
            else:
                raise NotImplementedError
        elif isinstance(self.init, (np.ndarray, torch.Tensor)):
            # as_tensor makes NumPy arrays usable too: ndarray has no ``.to``.
            init_state = torch.as_tensor(self.init).to(X)
        else:
            raise NotImplementedError
        labels = torch.argmin(self.distance_metric(X, init_state), dim=1)
        return init_state, labels

    def fit_predict(self, X: torch.Tensor):
        """Run k-means on ``X`` and return the labels of the best restart.

        The centers with the lowest inertia seen over all restarts and
        iterations are stored in ``self.cluster_centers_``.

        Args:
            X: 2-D data tensor (``X.size(1)`` is used as the feature size).

        Returns:
            The label tensor returned by ``self.distributed_sync``.
        """
        # Scale the tolerance by the mean per-feature variance of the data.
        tol = torch.mean(torch.var(X, dim=0)) * self.tol
        min_inertia, best_states, best_labels = float('Inf'), None, None
        # One distinct seed per (restart, rank); each rank takes a strided slice.
        random_states = torch.arange(self.n_init * self.world_size) + self.random_state
        random_states = random_states[self.rank:len(random_states):self.world_size]
        for n_init in range(self.n_init):
            random_state = int(random_states[n_init])
            old_state, old_labels = self.initialize(X, random_state=random_state)
            labels = old_labels
            progress_bar = tqdm.tqdm(total=self.max_iter,
                                     disable=not ((self.rank == 0) and self.verbose))
            for n_iter in range(self.max_iter):
                # Group-by mean of the samples per label, see
                # https://discuss.pytorch.org/t/groupby-aggregate-mean-in-pytorch/45335/7
                state = torch.zeros(self.n_clusters, X.size(1), dtype=X.dtype, device=X.device)
                # The 1e-6 floor avoids a division by zero for empty clusters.
                counts = torch.zeros(self.n_clusters, dtype=X.dtype, device=X.device) + 1e-6
                classes, classes_counts = torch.unique(labels, return_counts=True)
                counts[classes] = classes_counts.to(X)
                state.index_add_(0, labels, X)
                state = state / counts.view(-1, 1)

                labels, inertia = self.predict(X, state)
                if inertia < min_inertia:
                    min_inertia = inertia
                    best_states, best_labels = state, labels
                if self.verbose:
                    progress_bar.set_description(
                        f'nredo {n_init + 1}/{self.n_init:02d}, iteration {n_iter:03d} with inertia {inertia:.2f}')
                    progress_bar.update(n=1)

                if torch.equal(labels, old_labels):
                    # First check the labels for strict convergence.
                    if self.verbose:
                        print(f"Converged at iteration {n_iter}: strict convergence.")
                    break
                else:
                    # No strict convergence, check for tol based convergence.
                    # Only the diagonal (old center i vs. new center i) matters.
                    center_shift = self.distance_metric(old_state, state).diag()
                    center_shift_tot = center_shift.sum()
                    if center_shift_tot <= tol:
                        if self.verbose:
                            print(
                                f"Converged at iteration {n_iter}: center shift "
                                f"{center_shift_tot} within tolerance {tol} "
                                f"and min inertia {min_inertia.item()}."
                            )
                        break
                old_labels[:] = labels
                old_state = state
            progress_bar.close()
        min_inertia, best_labels, best_states = self.distributed_sync(min_inertia, best_labels, best_states)
        if self.verbose:
            print(f"Final min inertia {min_inertia.item()}.")
        self.cluster_centers_ = best_states
        return best_labels

    def predict(self, X: torch.Tensor, cluster_centers_=None):
        """Assign each sample of ``X`` to its nearest center.

        Args:
            X: samples to assign.
            cluster_centers_: centers to use; defaults to
                ``self.cluster_centers_``.

        Returns:
            ``(labels, inertia)``: index of the nearest center per sample
            and the sum over samples of the distance to that center.
        """
        if cluster_centers_ is None:
            cluster_centers_ = self.cluster_centers_
        if self.num_chunk > 1:
            # Split the *centers* (dim 0) so that only an
            # (n_samples, chunk_size) distance matrix is alive at a time.
            labels, inertia = [], []
            offset = 0
            for centers in cluster_centers_.chunk(self.num_chunk, dim=0):
                d = self.distance_metric(X, centers)
                chunk_inertia, chunk_labels = d.min(dim=1)
                # Shift chunk-local indices to global center indices.
                labels.append(chunk_labels + offset)
                inertia.append(chunk_inertia)
                offset += centers.size(0)
            labels = torch.stack(labels, dim=1)    # (n_samples, n_chunks)
            inertia = torch.stack(inertia, dim=1)  # (n_samples, n_chunks)
            inertia, best_chunk = inertia.min(dim=1)
            labels = labels.gather(1, best_chunk.unsqueeze(1)).squeeze(1)
            inertia = inertia.sum()
        else:
            d = self.distance_metric(X, cluster_centers_)
            inertia, labels = d.min(dim=1)
            inertia = inertia.sum()
        return labels, inertia
if __name__ == '__main__':
    # Demo run: spherical (cosine) k-means with k-means++ init on random data.
    # Uses CUDA device index 1, exactly as configured below.
    torch.cuda.set_device(1)
    demo_options = dict(
        metric='cosine',
        init='k-means++',
        random_state=0,
        n_clusters=1000,
        n_init=10,
        max_iter=300,
        tol=1e-4,
        distributed=False,
        verbose=True,
    )
    clustering_model = PyTorchKMeans(**demo_options)
    X = torch.randn(1280000, 256).cuda()
    clustering_model.fit_predict(X)
The key features of this k-means are spherical (cosine) clustering, distributed execution, and k-means++ initialization. In my experiments, this implementation can be better than sklearn, and it outperformed faiss.
Hope this can help you.
Another possible spherical k-means clustering can be found at https://github.com/facebookresearch/swav/blob/06b1b7cbaf6ba2a792300d79c7299db98b93b7f9/main_swav.py#L354.
You can also use faiss, which can be very fast:
import numpy as np
import torch
import torch.nn.functional as F
import torch.distributed as dist
try:
    import faiss
except ImportError:
    # faiss is optional at import time; only a missing package is tolerated,
    # any other failure while importing it is allowed to surface.
    print('faiss not installed')
from .__base__ import BasicClustering
class FaissKMeans(BasicClustering):
    """K-means clustering backed by faiss (GPU or CPU).

    With ``metric='cosine'`` the data are L2-normalised and faiss runs in
    spherical mode.

    NOTE(review): ``self.n_clusters``, ``self.n_init``, ``self.max_iter``,
    ``self.verbose``, ``self.rank``, ``self.world_size`` and
    ``self.distributed_sync`` are not defined in this class; they are
    presumably provided by ``BasicClustering`` -- confirm against the base
    class.
    """

    def __init__(self,
                 metric='euclidean',
                 n_clusters=8,
                 n_init=10,
                 max_iter=300,
                 random_state=1234,
                 distributed=False,
                 verbose=True):
        """Configure the clustering model.

        Args:
            metric: ``'euclidean'`` or ``'cosine'`` (spherical k-means).
            n_clusters: number of clusters.
            n_init: number of faiss restarts (``nredo``).
            max_iter: number of faiss iterations (``niter``).
            random_state: base seed; the current rank is added to it.
            distributed: forwarded to ``BasicClustering``.
            verbose: forwarded to faiss and used for the final message.

        Raises:
            NotImplementedError: for any other ``metric``.
        """
        super().__init__(n_clusters=n_clusters,
                         n_init=n_init,
                         max_iter=max_iter,
                         distributed=distributed,
                         verbose=verbose)
        if metric == 'euclidean':
            self.spherical = False
        elif metric == 'cosine':
            self.spherical = True
        else:
            raise NotImplementedError
        self.random_state = random_state

    def apply_pca(self, X, dim):
        """Project ``X`` to ``dim`` dimensions with a faiss PCA matrix.

        Args:
            X: 2-D tensor or array of samples.
            dim: output dimensionality.

        Returns:
            The projected samples as returned by the faiss PCA transform.
        """
        X = torch.as_tensor(X).float()
        if self.spherical:
            X = F.normalize(X, dim=1)
        # faiss works on contiguous float32 NumPy arrays.
        X = np.ascontiguousarray(X.detach().cpu().numpy(), dtype=np.float32)
        _, d = X.shape
        mat = faiss.PCAMatrix(d, dim)
        # The Python wrapper takes the array only; it derives n from its shape.
        mat.train(X)
        return mat.apply_py(X)

    def fit_predict(self, input: torch.Tensor):
        """Cluster ``input`` with faiss and return one label per sample.

        The centroids are stored in ``self.cluster_centers_``.

        Args:
            input: 2-D tensor of samples; a CUDA tensor selects the GPU
                implementation on the tensor's device, otherwise the CPU
                implementation is used.

        Returns:
            The label tensor returned by ``self.distributed_sync``.
        """
        _, d = input.shape
        # Normalise first, *then* convert, so spherical mode really sees
        # unit-norm rows.
        X = F.normalize(input, dim=1) if self.spherical else input
        X = np.ascontiguousarray(X.detach().cpu().numpy(), dtype=np.float32)
        device = input.device.index if input.is_cuda else -1

        # A distinct seed per rank so the ranks explore different inits.
        seed = int(self.random_state + self.rank)

        if device > -1:
            # faiss implementation of k-means
            clus = faiss.Clustering(d, self.n_clusters)
            clus.seed = seed
            clus.niter = self.max_iter
            clus.max_points_per_centroid = 10000000
            clus.min_points_per_centroid = 10
            clus.spherical = self.spherical
            clus.nredo = self.n_init
            clus.verbose = self.verbose

            res = faiss.StandardGpuResources()
            # spherical / nredo / verbose are clustering options (set on
            # ``clus`` above), not GpuIndexFlatConfig options.
            flat_config = faiss.GpuIndexFlatConfig()
            flat_config.useFloat16 = False
            flat_config.device = device
            index = faiss.GpuIndexFlatL2(res, d, flat_config)

            # perform the training
            clus.train(X, index)
            distances, assignments = index.search(X, 1)
            # Clustering.centroids is a faiss float vector.
            centroids = faiss.vector_to_array(clus.centroids)
        else:
            clus = faiss.Kmeans(d=d,
                                k=self.n_clusters,
                                niter=self.max_iter,
                                nredo=self.n_init,
                                verbose=self.verbose,
                                spherical=self.spherical,
                                seed=seed)
            clus.train(X)
            # for each sample, find cluster distance and assignments
            distances, assignments = clus.index.search(X, 1)
            # Kmeans.centroids is already a NumPy array.
            centroids = clus.centroids

        best_labels = torch.from_numpy(assignments.flatten()).to(input.device)
        min_inertia = torch.from_numpy(distances.flatten()).to(input).sum()
        best_states = np.asarray(centroids, dtype=np.float32).reshape(self.n_clusters, d)
        best_states = torch.from_numpy(best_states).to(input)

        min_inertia, best_labels, best_states = self.distributed_sync(min_inertia, best_labels, best_states)
        if self.verbose:
            print(f"Final min inertia {min_inertia.item()}.")
        self.cluster_centers_ = best_states
        return best_labels
if __name__ == '__main__':
    # Demo run: one process per GPU, launched with env:// rendezvous.
    dist.init_process_group(backend='nccl', init_method='env://')
    torch.cuda.set_device(dist.get_rank())
    X = torch.randn(1280, 256).cuda()
    demo_options = dict(
        metric='euclidean',
        n_clusters=10,
        n_init=2,
        max_iter=1,
        random_state=1234,
        distributed=True,
        verbose=True,
    )
    clustering_model = FaissKMeans(**demo_options)
    clustering_model.fit_predict(X)
- You can implement PS by sampling noise from a Gaussian distribution.
from propos.
Hi!
I have uploaded my pytorch implementations of kmeans and GMM at https://github.com/Hzzone/torch_clustering.
The code for this paper will be updated soon!
from propos.
Related Issues (17)
- About Performance on Imagenetdogs HOT 4
- Could u please share the config file for STL? HOT 2
- Got 65% ACC for BYOL on ImageNetdogs HOT 2
- About cifar10 performance on SimSiam HOT 5
- How to change parameters when num_device =2 HOT 1
- How to train ProPos on the STL-10 dataset?
- what does the memory data loader do in the basic_template.py? HOT 1
- Could you provide the pre-trained models for ProPos on ImageNet-1k, ImageNet-50/100/200 HOT 17
- How to obtain the clustering result after training? HOT 3
- how to reproduce pcl HOT 2
- Can't Reproduce Result in CIFAR-20 HOT 14
- How to generate pseudo label for unlabeled data in STL-10 and use them in PSL term? HOT 1
- I can't get the results in the paper using the pretrained models you provided HOT 7
- How to train Propos on STL-10 dataset HOT 1
- Results on CIFAR10 with ResNet34 HOT 1
- Training speed of ImageNet-10 HOT 2
Recommend Projects
-
React
A declarative, efficient, and flexible JavaScript library for building user interfaces.
-
Vue.js
🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.
-
Typescript
TypeScript is a superset of JavaScript that compiles to clean JavaScript output.
-
TensorFlow
An Open Source Machine Learning Framework for Everyone
-
Django
The Web framework for perfectionists with deadlines.
-
Laravel
A PHP framework for web artisans
-
D3
Bring data to life with SVG, Canvas and HTML. 📊📈🎉
-
Recommend Topics
-
javascript
JavaScript (JS) is a lightweight interpreted programming language with first-class functions.
-
web
Some thing interesting about web. New door for the world.
-
server
A server is a program made to process requests and deliver data to clients.
-
Machine learning
Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.
-
Visualization
Some thing interesting about visualization, use data art
-
Game
Some thing interesting about game, make everyone happy.
Recommend Org
-
Facebook
We are working to build community through open source technology. NB: members must have two-factor auth.
-
Microsoft
Open source projects and samples from Microsoft.
-
Google
Google ❤️ Open Source for everyone.
-
Alibaba
Alibaba Open Source for everyone
-
D3
Data-Driven Documents codes.
-
Tencent
China tencent open source team.
from propos.