Python implementations of clustering algorithms applied on the probability simplex domain (e.g. clustering of softmax predictions from Black-Box source models).

License: MIT


Clustering Softmax Predictions

Updates

Paper

If you find this code useful for your research, please cite our paper:

@article{ch2022sc,
  title={Simplex Clustering via sBeta with Applications to Online Adjustments of Black-Box Predictions},
  author={Chiaroni, Florent and Boudiaf, Malik and Mitiche, Amar and Ben Ayed, Ismail},
  journal={arXiv preprint arXiv:2208.00287},
  year={2022}
}

Abstract

We explore clustering the softmax predictions of deep neural networks and introduce a novel probabilistic clustering method, referred to as k-sBetas. In the general context of clustering discrete distributions, existing methods have focused on exploring distortion measures tailored to simplex data, such as the KL divergence, as alternatives to the standard Euclidean distance. We provide a general maximum a posteriori (MAP) perspective of clustering distributions, which emphasizes that the statistical models underlying the existing distortion-based methods may not be descriptive enough. Instead, we optimize a mixed-variable objective measuring the conformity of data within each cluster to the introduced sBeta density function, whose parameters are constrained and estimated jointly with binary assignment variables. Our versatile formulation approximates a variety of parametric densities for modeling simplex data, and enables control of the cluster-balance bias. This yields highly competitive performance for unsupervised adjustments of black-box model predictions in a variety of scenarios. Our code and comparisons with the existing simplex-clustering approaches along with our introduced softmax-prediction benchmarks are publicly available: https://github.com/fchiaroni/Clustering_Softmax_Predictions.
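To make the setting concrete, here is a toy NumPy sketch of KL K-means, one of the distortion-based baselines discussed above. This is not the repository's implementation; the function and variable names are illustrative. It uses the standard facts that KL(p || c) differs from the cross term -sum_j p_j log c_j only by a per-point constant, and that the arithmetic mean of assigned points minimizes the summed KL(p || c) objective for a centroid.

```python
import numpy as np

def kl_kmeans(probs, k, n_iters=50, seed=0, eps=1e-12):
    """Toy KL K-means over points on the probability simplex.

    probs: (n, d) array of softmax predictions (rows sum to 1).
    Assignment minimizes KL(p || centroid); the centroid update that
    minimizes this objective is the arithmetic mean of assigned rows.
    """
    rng = np.random.default_rng(seed)
    centroids = probs[rng.choice(len(probs), size=k, replace=False)]
    labels = np.zeros(len(probs), dtype=np.int64)
    for _ in range(n_iters):
        # KL(p || c) = sum_j p_j log(p_j / c_j); the p log p part is
        # constant per point, so the argmin only needs the cross term.
        cross = -probs @ np.log(centroids.T + eps)   # shape (n, k)
        labels = cross.argmin(axis=1)
        for c in range(k):
            mask = labels == c
            if mask.any():                           # keep old centroid if empty
                centroids[c] = probs[mask].mean(axis=0)
    return labels, centroids
```

For example, softmax-like inputs can be simulated with Dirichlet draws concentrated near different simplex vertices; the centroids returned remain valid distributions (rows summing to 1).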

Pre-requisites

  • Python 3.9.4
  • numpy 1.22.0
  • scikit-learn 0.24.1
  • scikit-learn-extra 0.2.0 (for k-medoids only)
  • PyTorch 1.11.0 (for GPU-based k-sBetas only)
  • CUDA 11.3 (for GPU-based k-sBetas only)

You can install all the pre-requisites with:

$ cd <root_dir>
$ pip install -r requirements.txt

Datasets

The comparisons are performed on the following datasets:

Note that we used the source models implemented in https://github.com/tim-learn/SHOT to generate these real-world softmax-prediction datasets.

Implemented clustering models

The script compare_softmax_preds_clustering.py compares the following clustering algorithms:

Running the code

You can select the methods to compare by editing the config file ./configs/select_methods_to_compare.py.

Compare clustering approaches on SVHN to MNIST dataset:

$ cd <root_dir>
$ python compare_softmax_preds_clustering.py --dataset SVHN_to_MNIST

Compare clustering approaches on VISDA-C dataset:

$ cd <root_dir>
$ python compare_softmax_preds_clustering.py --dataset VISDA_C

Compare clustering approaches on highly imbalanced iVISDA-Cs datasets:

$ cd <root_dir>
$ python compare_softmax_preds_clustering.py --dataset iVISDA_Cs

Run only k-sBetas (GPU-based):

$ cd <root_dir>/clustering_methods
$ python k_sbetas_GPU.py --dataset SVHN_to_MNIST
$ python k_sbetas_GPU.py --dataset VISDA_C
$ python k_sbetas_GPU.py --dataset iVISDA_Cs

Results

Table 1: Accuracy scores

(Acc)                 SVHN to MNIST   VISDA-C   iVISDA-Cs
argmax                     69.8         53.1       44.2
K-means                    68.9         47.9       39.3
KL K-means                 75.5         51.2       41.8
GMM                        67.6         45.7       37.0
K-medians                  68.8         40.0       36.9
K-medoids                  71.3         46.8       40.4
K-modes                    71.3         31.1       29.9
K-Betas                    41.2         24.9       27.2
k-sBetas (proposed)        76.5         56.0       46.8

Table 2: mean IoU scores

(mIoU)                SVHN to MNIST   VISDA-C   iVISDA-Cs
argmax                     54.3         32.5       22.7
K-means                    55.7         34.6       24.2
KL K-means                 62.1         37.3       24.9
GMM                        55.0         30.1       20.4
K-medians                  56.0         29.6       22.4
K-medoids                  57.5         33.7       22.5
K-modes                    56.2         24.3       18.4
K-Betas                    25.4         14.0       14.1
k-sBetas (proposed)        63.6         39.0       26.9
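Reporting accuracy for unsupervised clustering requires mapping cluster indices to class labels. A standard way to do this is Hungarian matching over the confusion matrix; the sketch below illustrates the idea and is not necessarily the exact evaluation code used in this repository.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def clustering_accuracy(y_true, y_pred):
    """Best-match accuracy between cluster indices and class labels.

    Builds a confusion matrix with clusters as rows and classes as
    columns, then finds the one-to-one cluster->class mapping that
    maximizes the number of agreements (Hungarian algorithm).
    """
    n = max(y_true.max(), y_pred.max()) + 1
    cm = np.zeros((n, n), dtype=np.int64)
    for t, p in zip(y_true, y_pred):
        cm[p, t] += 1
    rows, cols = linear_sum_assignment(cm, maximize=True)
    return cm[rows, cols].sum() / len(y_true)
```

For instance, a clustering that is a pure permutation of the true labels scores 1.0 under this metric even though the raw cluster indices disagree everywhere.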

Recommendations

  • The most appropriate value for the "delta" parameter of k-sBetas may change depending on the dataset's distribution. We recommend selecting delta using a validation set.
  • On small-scale datasets, the biased formulation of k-sBetas may be more stable.
  • On large-scale imbalanced datasets, the unbiased formulation provides better results.
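Selecting delta on a validation set, as recommended above, amounts to a small grid search. The sketch below shows the shape of such a loop; `cluster_fn` is a hypothetical stand-in for the repository's k-sBetas routine (not its actual API), and it is assumed to return predictions already mapped to class labels.

```python
import numpy as np

def select_delta(probs_val, labels_val, cluster_fn, deltas):
    """Grid-search delta on a labelled validation split.

    cluster_fn(probs, delta) -> predicted labels is a hypothetical
    interface standing in for the k-sBetas entry point; predictions
    are assumed to be already mapped to class labels.
    """
    best_delta, best_acc = None, -1.0
    for delta in deltas:
        pred = cluster_fn(probs_val, delta)
        acc = float((pred == labels_val).mean())
        if acc > best_acc:
            best_delta, best_acc = delta, acc
    return best_delta, best_acc
```

The chosen delta is then reused, unchanged, when clustering the unlabelled target predictions.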
