rhfeiyang / ppot

Official implementation of 'P$^2$OT: Progressive Partial Optimal Transport for Deep Imbalanced Clustering'. (Accepted by ICLR 2024)

License: Other

Python 100.00%
clustering machine-learning optimal-transport imbalanced-clustering

ppot's Introduction

P2OT: Progressive Partial Optimal Transport for Deep Imbalanced Clustering

By Chuyu Zhang*, Hui Ren*, and Xuming He (* indicates equal contribution)

This repo contains the PyTorch implementation of our paper. (Accepted by ICLR 2024)

Installation

git clone https://github.com/rhfeiyang/PPOT.git
cd PPOT
conda env create -f environment.yml

Training

Setup

Follow the steps below to set up the datasets:

  • Change the file paths to the datasets in utils/mypath.py, e.g. /path/to/cifar100.

Our experiments cover the following datasets: CIFAR100, ImageNet-R and iNaturalist18. The code builds the imbalanced versions of these datasets automatically.

Train model

To train on a given dataset, specify the --train_db_name and --val_db_name arguments. For example:

# For cifar100 (imbalance ratio 100):
python train.py --train_db_name cifar_im --val_db_name cifar_im --imbalance_ratio 0.01 --num_classes 100 --num_heads 2 --output_dir experiment/PPOT/cifar100/ckpts
# For imagenet-r:
python train.py --train_db_name imagenet-r_im --val_db_name imagenet-r_im --num_classes 200 --num_heads 1 --output_dir experiment/PPOT/imagenet-r/ckpts
# For iNature100, 500, 1000:
python train.py --train_db_name iNature_im --val_db_name iNature_im --num_classes 100 --num_heads 1 --output_dir experiment/PPOT/inature100/ckpts
python train.py --train_db_name iNature_im --val_db_name iNature_im --num_classes 500 --num_heads 1 --output_dir experiment/PPOT/inature500/ckpts
python train.py --train_db_name iNature_im --val_db_name iNature_im --num_classes 1000 --num_heads 1 --output_dir experiment/PPOT/inature1000/ckpts
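
The cifar100 command pairs "imbalance ratio 100" with --imbalance_ratio 0.01, which suggests that the rarest class keeps about 1/100 as many images as the most frequent one. The sketch below shows one common way to turn such a ratio into per-class sample counts. The exponential profile, the helper name long_tailed_counts and the figure of 500 images per class are assumptions made for illustration; the repository's own dataset builder may work differently.

```python
def long_tailed_counts(max_per_class, num_classes, imbalance_ratio):
    """Per-class sample counts that decay exponentially from max_per_class
    down to about max_per_class * imbalance_ratio (hypothetical helper)."""
    if num_classes == 1:
        return [max_per_class]
    counts = []
    for idx in range(num_classes):
        # frac runs from 0.0 (most frequent class) to 1.0 (rarest class)
        frac = idx / (num_classes - 1)
        counts.append(max(1, round(max_per_class * imbalance_ratio ** frac)))
    return counts


counts = long_tailed_counts(max_per_class=500, num_classes=100, imbalance_ratio=0.01)
print(counts[0], counts[-1])  # sizes of the most and least frequent classes
```

Under these assumptions the class sizes shrink smoothly from the head class to the tail class, so the ratio between the two matches the requested imbalance.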

Remarks

We use two heads (each with the same number of clusters) on CIFAR100 and a single head on the other datasets. To try overclustering, for example with heads of 100 and 200 clusters, set --num_heads 2 --num_classes 100 200. With overclustering, only the first head is used for evaluation.
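
As a rough illustration of the remark above, the following sketch shows how a single --num_classes value could be shared by all heads while a list gives one cluster count per head. It is not taken from train.py: the helper resolve_head_sizes, the defaults and the parsing details are assumptions, and only the flag names come from the commands in this README.

```python
import argparse


def resolve_head_sizes(num_heads, num_classes):
    """Return one cluster count per head (hypothetical helper)."""
    if len(num_classes) == 1:
        # A single value is shared by every head, e.g. 2 heads of 100 clusters.
        return num_classes * num_heads
    if len(num_classes) != num_heads:
        raise ValueError("expected one --num_classes value, or one per head")
    return list(num_classes)


parser = argparse.ArgumentParser()
parser.add_argument("--num_heads", type=int, default=1)
parser.add_argument("--num_classes", type=int, nargs="+", default=[100])

# Overclustering example from the remark: heads of 100 and 200 clusters.
args = parser.parse_args("--num_heads 2 --num_classes 100 200".split())
head_sizes = resolve_head_sizes(args.num_heads, args.num_classes)
eval_head = 0  # per the remark, only the first head is used for evaluation
print(head_sizes, head_sizes[eval_head])
```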

Evaluation

To evaluate, run eval.py instead of train.py with the same arguments. Models are loaded from the directory given by --output_dir. For example:

# For cifar100 (imbalance ratio 100):
python eval.py --train_db_name cifar_im --val_db_name cifar_im --imbalance_ratio 0.01 --num_classes 100 --num_heads 2 --output_dir experiment/PPOT/cifar100/ckpts
# For imagenet-r:
python eval.py --train_db_name imagenet-r_im --val_db_name imagenet-r_im --num_classes 200 --num_heads 1 --output_dir experiment/PPOT/imagenet-r/ckpts
# For iNature100, 500, 1000:
python eval.py --train_db_name iNature_im --val_db_name iNature_im --num_classes 100 --num_heads 1 --output_dir experiment/PPOT/inature100/ckpts
python eval.py --train_db_name iNature_im --val_db_name iNature_im --num_classes 500 --num_heads 1 --output_dir experiment/PPOT/inature500/ckpts
python eval.py --train_db_name iNature_im --val_db_name iNature_im --num_classes 1000 --num_heads 1 --output_dir experiment/PPOT/inature1000/ckpts
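
Since the training and evaluation commands differ only in the script name, the repeated iNature runs can be generated rather than typed by hand. The sketch below only builds and prints the command lines listed in this README. The helper build_command is hypothetical and not part of the repository; the commands could be passed to subprocess.run if you want to launch them from Python.

```python
import shlex


def build_command(script, db_name, num_classes, num_heads, output_dir,
                  imbalance_ratio=None):
    """Assemble one command line using the flags shown in this README."""
    cmd = ["python", script, "--train_db_name", db_name, "--val_db_name", db_name]
    if imbalance_ratio is not None:
        cmd += ["--imbalance_ratio", str(imbalance_ratio)]
    cmd += ["--num_classes", str(num_classes),
            "--num_heads", str(num_heads),
            "--output_dir", output_dir]
    return cmd


# Train, then evaluate, each iNature subset with the same arguments.
for n in (100, 500, 1000):
    out_dir = f"experiment/PPOT/inature{n}/ckpts"
    for script in ("train.py", "eval.py"):
        print(shlex.join(build_command(script, "iNature_im", n, 1, out_dir)))
```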

License

This software is released under a Creative Commons license that allows personal and research use only. For a commercial license, please contact the authors. You can view a license summary here.

Citation

@misc{zhang2024p2ot,
      title={P$^2$OT: Progressive Partial Optimal Transport for Deep Imbalanced Clustering}, 
      author={Chuyu Zhang and Hui Ren and Xuming He},
      year={2024},
      eprint={2401.09266},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

ppot's People

Contributors

rhfeiyang, kleinzcy


ppot's Issues

Some doubts about the method

Hi, I have some doubts about the method.

  • How sensitive is the proposed method to the assumption of a uniform prior distribution in the KL divergence constraint for pseudo-label generation?
  • Can you provide additional insights or empirical evidence supporting the choice of gradually increasing the fraction of high-confidence samples (ρ) and its impact on model performance?
  • Could you elaborate on the specific efficiency gains achieved by the proposed solver in solving the unbalanced optimal transport problem?
  • Why are experiments not conducted on ImageNet datasets, especially when discussing pre-trained models? What considerations led to the choice of datasets, and are there potential limitations introduced by this selection?
  • In the context of imbalanced datasets, have you considered discussing the limitations of clustering accuracy (ACC), normalized mutual information (NMI), and F1-score as evaluation metrics? Are there other metrics or analyses that could provide a more comprehensive evaluation?
  • How sensitive is the proposed method's performance to the choice of hyperparameters such as λ, ϵ, and initial ρ? Have you conducted a sensitivity analysis or ablation study to explore these aspects?
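
For readers unfamiliar with the idea of gradually increasing ρ raised in the questions above, here is a minimal sketch of one ramp-up shape that is common in semi-supervised learning, a sigmoid-style curve rising from roughly an initial value to 1. Whether the repository uses exactly this form is an assumption, as are the function name rho_schedule, the constant 5.0 and the initial value 0.1.

```python
import math


def rho_schedule(step, total_steps, rho_init=0.1):
    """Sigmoid-style ramp-up from about rho_init to 1.0 (illustrative only)."""
    progress = min(max(step / total_steps, 0.0), 1.0)
    return rho_init + (1.0 - rho_init) * math.exp(-5.0 * (1.0 - progress) ** 2)


# The selected fraction grows slowly at first and reaches 1.0 at the end.
for step in (0, 250, 500, 750, 1000):
    print(step, round(rho_schedule(step, 1000), 3))
```

A schedule like this makes the sensitivity question concrete: changing rho_init or the steepness constant shifts how early low-confidence samples start to receive pseudo-labels.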
