
CMT.pytorch

CMT on ImageNet-1K Classification

Model      Top-1 Acc.  Log     Ckpt
CMT-Ti     79.0%       github  github
CMT-XS     81.8%       github  github
CMT-Small  83.5%       github  github
CMT-Base   84.5%       github  github

Setup

- python==3.6
- cuda==10.0

# other pytorch/timm versions may also work

pip install torch==1.7.0 torchvision==0.8.1;
pip install timm==0.3.2;
pip install torchprofile;

# build apex

cd /your_path_to/apex-master/;
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
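To confirm that the environment (including the apex CUDA extensions) was installed correctly, a minimal Python check like the following can be run; it only assumes the packages listed above:

# sanity_check.py: minimal environment check for the packages installed above
import torch
import torchvision
import timm
import torchprofile            # installed above for MACs profiling
from apex import amp           # raises ImportError if apex did not build correctly

print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("torchvision:", torchvision.__version__)
print("timm:", timm.__version__)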

Data preparation

Download and extract ImageNet train and val images from http://image-net.org/. The directory structure is:

│path/to/imagenet/
├──train/
│  ├── n01440764
│  │   ├── n01440764_10026.JPEG
│  │   ├── n01440764_10027.JPEG
│  │   ├── ......
│  ├── ......
├──val/
│  ├── n01440764
│  │   ├── ILSVRC2012_val_00000293.JPEG
│  │   ├── ILSVRC2012_val_00002138.JPEG
│  │   ├── ......
│  ├── ......
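Before launching training, the layout can be verified with torchvision's ImageFolder. A minimal sketch (the root path below is a placeholder):

# check_imagenet.py: verify the ImageNet folder layout (root path is a placeholder)
import os
from torchvision.datasets import ImageFolder

root = "/your_path_to/imagenet"
for split in ("train", "val"):
    dataset = ImageFolder(os.path.join(root, split))
    print(split, ":", len(dataset.classes), "classes,", len(dataset), "images")
# Expected for ImageNet-1K: 1000 classes, ~1.28M train images and 50,000 val images.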

Training

To train CMT-Tiny on ImageNet-1K on a single node with 8 GPUs:

python -m torch.distributed.launch --nproc_per_node=8 train.py --data-path /your_path_to/imagenet/ --output_dir /your_path_to/output/ --model cmt_ti --batch-size 256 --apex-amp --input-size 160 --weight-decay 0.05 --drop-path 0.1 --epochs 800 --test_freq 100 --test_epoch 760 --warmup-lr 1e-7 --warmup-epochs 20 --lr 8e-4 --min-lr 1e-5 --no-model-ema

To train CMT-XS on ImageNet-1K on a single node with 8 GPUs:

python -m torch.distributed.launch --nproc_per_node=8 train.py --data-path /your_path_to/imagenet/ --output_dir /your_path_to/output/ --model cmt_xs --batch-size 256 --apex-amp --input-size 192 --weight-decay 0.04 --drop-path 0.08 --epochs 400 --test_freq 100 --test_epoch 360 --warmup-lr 1e-6 --warmup-epochs 20 --lr 7e-4 --min-lr 2e-5 --model-ema-decay 0.9998

To train CMT-Small on ImageNet-1K on a single node with 8 GPUs:

python -m torch.distributed.launch --nproc_per_node=8 train.py --data-path /your_path_to/imagenet/ --output_dir /your_path_to/output/ --model cmt_s --batch-size 128 --apex-amp --input-size 224 --weight-decay 0.05 --drop-path 0.1 --epochs 300 --test_freq 100 --test_epoch 260 --warmup-lr 1e-7 --warmup-epochs 20

To train CMT-Base on ImageNet-1K on a single node with 8 GPUs:

python -m torch.distributed.launch --nproc_per_node=8 train.py --data-path /your_path_to/imagenet/ --output_dir /your_path_to/output/ --model cmt_b --batch-size 64 --apex-amp --input-size 256 --weight-decay 0.05 --drop-path 0.25 --epochs 300 --test_freq 100 --test_epoch 260 --warmup-lr 1e-6 --min-lr 2e-5 --warmup-epochs 20
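The torchprofile package installed above can be used to check parameter counts and MACs at the input resolution each variant is trained with (160/192/224/256). A minimal sketch; the model import below is hypothetical, so substitute the actual constructor defined in this repo:

# profile_cmt.py: rough params/MACs check at the training resolution
# The "from cmt import cmt_ti" import is hypothetical; use the constructor defined in this repo.
import torch
from torchprofile import profile_macs
from cmt import cmt_ti

model = cmt_ti().eval()
dummy = torch.randn(1, 3, 160, 160)   # CMT-Ti is trained at 160x160 (see the command above)
with torch.no_grad():
    macs = profile_macs(model, dummy)
params = sum(p.numel() for p in model.parameters())
print("params: %.1fM, MACs: %.2fG" % (params / 1e6, macs / 1e9))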

Acknowledgement

This repo is based on DeiT and pytorch-image-models.

Citation

If you find this project useful in your research, please consider citing:

@article{guo2021cmt,
  title={Cmt: Convolutional neural networks meet vision transformers},
  author={Guo, Jianyuan and Han, Kai and Wu, Han and Xu, Chang and Tang, Yehui and Xu, Chunjing and Wang, Yunhe},
  journal={arXiv preprint arXiv:2107.06263},
  year={2021}
}

License

License: MIT

Other implementations

Keras implementation by @leondgarse

cmt.pytorch's People

Contributors

ggjy


cmt.pytorch's Issues

Wrong link for cmt_xs.pth

Thanks for sharing these!
I think your released cmt_xs.pth is actually a duplicate of cmt_small.pth. Would it be convenient to upload a new one?
Btw, I've also run some training earlier with a slightly different structure: keras_cv_attention_models/cmt. The main difference is that I'm using an individual relative positional embedding for each attention block, as in BEiT. With 305 epochs, cmt_t at 160 input reaches 78.94% accuracy, which I think is not bad considering the parameters and FLOPs.

loading pretrained model

Hello, may I ask how to load the pretrained model and then train it on my own classification task? After reading the relevant code, I was not able to modify it successfully. Could you tell me in which part to make the changes?
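A minimal sketch of one way to load a released checkpoint and swap in a new classification head. It assumes a DeiT-style checkpoint (weights stored under a "model" key), since this repo is based on DeiT; the model constructor import and the "head" attribute name are assumptions to be checked against the repo's model definition:

# load_pretrained.py: load a released checkpoint and replace the classifier head
# Assumes a DeiT-style checkpoint dict with weights under the "model" key;
# the "from cmt import cmt_s" import and the "head" attribute are assumptions.
import torch
import torch.nn as nn
from cmt import cmt_s

model = cmt_s()
checkpoint = torch.load("cmt_small.pth", map_location="cpu")
state_dict = checkpoint.get("model", checkpoint)   # handle wrapped and plain state dicts
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print("missing:", missing, "unexpected:", unexpected)

# Fine-tuning on a custom task: replace the final head with a fresh layer.
num_classes = 10                                    # set to your own number of classes
model.head = nn.Linear(model.head.in_features, num_classes)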
