
EMAN: Exponential Moving Average Normalization for Self-supervised and Semi-supervised Learning

This is a PyTorch implementation of the EMAN paper (https://arxiv.org/abs/2101.08482). It supports three popular self-supervised and semi-supervised learning techniques: MoCo, BYOL, and FixMatch.
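At the heart of the method, the teacher (EMA) model's BatchNorm statistics are themselves exponentially averaged from the student instead of being recomputed from batches. Below is a minimal sketch of that update rule in plain PyTorch; the function name and exact buffer handling are illustrative, not the repository's implementation.

```python
import torch

@torch.no_grad()
def eman_update(student: torch.nn.Module, teacher: torch.nn.Module, m: float = 0.999):
    """Exponential-moving-average update of the teacher from the student.

    Unlike plain EMA (parameters only), EMAN also averages the BatchNorm
    buffers (running_mean / running_var); the teacher's BN layers are kept
    in eval mode so they normalize with these EMA statistics rather than
    with batch statistics.
    """
    for p_s, p_t in zip(student.parameters(), teacher.parameters()):
        p_t.mul_(m).add_(p_s, alpha=1.0 - m)
    for b_s, b_t in zip(student.buffers(), teacher.buffers()):
        if b_t.dtype.is_floating_point:
            b_t.mul_(m).add_(b_s, alpha=1.0 - m)
        else:
            b_t.copy_(b_s)  # e.g. num_batches_tracked (integer counter)
```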

If you use the code/models/results of this repository, please cite:

@inproceedings{cai21eman,
  author  = {Zhaowei Cai and Avinash Ravichandran and Subhransu Maji and Charless Fowlkes and Zhuowen Tu and Stefano Soatto},
  title   = {Exponential Moving Average Normalization for Self-supervised and Semi-supervised Learning},
  booktitle = {CVPR},
  year  = {2021}
}

Install

First, install PyTorch and torchvision. We have tested with version 1.7.1, but other versions should also work, e.g., 1.5.1.

$ conda install pytorch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2 cudatoolkit=10.1 -c pytorch

Also install other dependencies.

$ pip install pandas faiss-gpu

Data Preparation

Download the ImageNet dataset following the official PyTorch ImageNet training code, with the standard folder structure expected by torchvision's datasets.ImageFolder. For the semi-supervised learning experiments, please also download the ImageNet index files. The file structure should look like this (a sketch of how an index file might be consumed follows the listing):

$ tree data
imagenet
├── train
│   ├── class1
│   │   └── *.jpeg
│   ├── class2
│   │   └── *.jpeg
│   └── ...
├── val
│   ├── class1
│   │   └── *.jpeg
│   ├── class2
│   │   └── *.jpeg
│   └── ...
└── indexes
    └── *_index.csv
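The exact schema of the *_index.csv files is defined by the repository. As a rough, hypothetical sketch, assuming each file lists the relative paths of the selected training images in a single column named `Index` (an assumption, not the repo's documented schema), a 10% ImageNet subset could be materialized like this:

```python
import os
import pandas as pd
from torchvision import datasets

# Hypothetical sketch: restrict an ImageFolder to the images named in an
# index CSV. The column name "Index" is assumed, not taken from the repo.
root = "data/imagenet"
train_dir = os.path.join(root, "train")
index = pd.read_csv(os.path.join(root, "indexes", "train_10p_index.csv"))
keep = set(index["Index"])

train_set = datasets.ImageFolder(train_dir)
train_set.samples = [
    (path, target) for path, target in train_set.samples
    if os.path.relpath(path, train_dir) in keep
]
```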

Training

To do self-supervised pre-training of MoCo-v2 with EMAN for 200 epochs, run:

python main_moco.py \
  --arch MoCoEMAN --backbone resnet50_encoder \
  --epochs 200 --warmup-epoch 10 \
  --moco-t 0.2 --cos \
  --dist-url 'tcp://localhost:10001' --multiprocessing-distributed --world-size 1 --rank 0 \
  /path/to/imagenet

To do self-supervised pre-training of BYOL with EMAN for 200 epochs, run:

python main_byol.py \
  --arch BYOLEMAN --backbone resnet50_encoder \
  --lr 1.8 -b 512 --wd 0.000001 \
  --byol-m 0.98 \
  --epochs 200 --cos --warmup-epoch 10 \
  --dist-url 'tcp://localhost:10001' --multiprocessing-distributed --world-size 1 --rank 0 \
  /path/to/imagenet
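Here --byol-m 0.98 is the starting EMA momentum. BYOL conventionally ramps this momentum towards 1.0 with a cosine schedule over training, and the momentum_update(cur_iter, max_iter) method referenced in the issues below suggests this repository does the same. A sketch of that schedule, under that assumption:

```python
import math

def byol_momentum(base_m: float, cur_iter: int, max_iter: int) -> float:
    """Cosine ramp of the EMA momentum from base_m (e.g. 0.98, the
    --byol-m flag above) up to 1.0 by the end of training, as in the
    BYOL paper."""
    return 1.0 - (1.0 - base_m) * (math.cos(math.pi * cur_iter / max_iter) + 1.0) / 2.0
```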

To do semi-supervised training of FixMatch with EMAN for 100 epochs, run:

python main_fixmatch.py \
  --arch FixMatch --backbone resnet50_encoder \
  --eman \
  --lr 0.03 \
  --epochs 100 --schedule 60 80 \
  --warmup-epoch 5 \
  --trainindex_x train_10p_index.csv --trainindex_u train_90p_index.csv \
  --dist-url 'tcp://localhost:10001' --multiprocessing-distributed --world-size 1 --rank 0 \
  /path/to/imagenet
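For reference, the core FixMatch objective: confident predictions on weakly augmented unlabeled images become hard pseudo-labels for their strongly augmented views; with EMAN, the weak-view predictions come from the EMA teacher rather than the student. A minimal sketch (the 0.95 threshold is the FixMatch default, not necessarily this repository's setting):

```python
import torch
import torch.nn.functional as F

def fixmatch_unlabeled_loss(logits_weak, logits_strong, threshold=0.95):
    """Sketch of the FixMatch unlabeled loss: weak-view predictions that
    clear the confidence threshold become pseudo-labels; only those
    samples contribute to the strong-view cross-entropy."""
    probs = torch.softmax(logits_weak.detach(), dim=-1)
    conf, pseudo = probs.max(dim=-1)
    mask = (conf >= threshold).float()
    per_sample = F.cross_entropy(logits_strong, pseudo, reduction="none")
    return (per_sample * mask).mean()
```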

Linear Classification and Finetuning

With a pre-trained model, to train a supervised linear classifier on frozen features/weights (e.g., from MoCo) on 10% of ImageNet, run:

python main_lincls.py \
  -a resnet50 \
  --lr 30.0 \
  --epochs 50 --schedule 30 40 \
  --eval-freq 5 \
  --trainindex train_10p_index.csv \
  --model-prefix encoder_q \
  --pretrained /path/to/model_best.pth.tar \
  --dist-url 'tcp://localhost:10001' --multiprocessing-distributed --world-size 1 --rank 0 \
  /path/to/imagenet
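The --model-prefix encoder_q flag selects which sub-network's weights to load from the checkpoint. Roughly, this amounts to filtering the saved state dict by key prefix; the helper below is a hypothetical illustration, not the repository's loader, and the exact checkpoint key layout may differ:

```python
import torch

def load_pretrained_backbone(model, ckpt_path, prefix="encoder_q"):
    """Hypothetical helper: keep only checkpoint weights saved under
    '<prefix>.' (optionally wrapped in DistributedDataParallel's
    'module.' prefix) and strip the prefix so the keys match a plain
    torchvision resnet50 state dict."""
    state = torch.load(ckpt_path, map_location="cpu")["state_dict"]
    cleaned = {}
    for k, v in state.items():
        if k.startswith("module."):
            k = k[len("module."):]
        if k.startswith(prefix + "."):
            cleaned[k[len(prefix) + 1:]] = v
    # strict=False: the classifier head is trained from scratch
    return model.load_state_dict(cleaned, strict=False)
```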

To fine-tune the self-supervised pre-trained model on 10% of ImageNet, with different learning rates for the pre-trained backbone and the last classification layer, run:

python main_cls.py \
  -a resnet50 \
  --lr 0.001 --lr-classifier 0.1 \
  --epochs 50 --schedule 30 40 \
  --eval-freq 5 \
  --trainindex train_10p_index.csv \
  --model-prefix encoder_q \
  --self-pretrained /path/to/model_best.pth.tar \
  --dist-url 'tcp://localhost:10001' --multiprocessing-distributed --world-size 1 --rank 0 \
  /path/to/imagenet

For BYOL, change to --model-prefix online_net.backbone. For the best performance, follow the learning rate settings in Section 5.2 of the paper.
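The split learning rates above correspond to two optimizer parameter groups. A minimal sketch of that pattern with a torchvision ResNet-50 (fc. is the torchvision classifier head; the repository's own argument handling may differ):

```python
import torch
import torchvision

model = torchvision.models.resnet50()

# Two parameter groups: the pre-trained backbone gets a small lr (--lr),
# the freshly initialized classifier head a larger one (--lr-classifier).
backbone = [p for n, p in model.named_parameters() if not n.startswith("fc.")]
classifier = [p for n, p in model.named_parameters() if n.startswith("fc.")]
optimizer = torch.optim.SGD(
    [{"params": backbone, "lr": 0.001},
     {"params": classifier, "lr": 0.1}],
    momentum=0.9, weight_decay=1e-4,
)
```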

Models

Our pre-trained ResNet-50 models can be downloaded below; acc@k% IN is the top-1 accuracy when evaluated with k% of the ImageNet labels:

| name | epochs | acc@1% IN | acc@10% IN | acc@100% IN | model |
|------|--------|-----------|------------|-------------|-------|
| MoCo-EMAN | 200 | 48.9 | 60.5 | 67.7 | download |
| MoCo-EMAN | 800 | 55.4 | 64.0 | 70.1 | download |
| MoCo-2X-EMAN | 200 | 56.8 | 65.7 | 72.3 | download |
| BYOL-EMAN | 200 | 55.1 | 66.7 | 72.2 | download |

License

This project is under the CC-BY-NC 4.0 license. See LICENSE for details.


Issues

Asking for the experimental settings

Hi,

I have a question about the experimental setup in Table 7 of the paper (i.e., the comparison with other semi-supervised models).
What experimental settings were used in the comparison with the state-of-the-art methods?
Is it the set of hyper-parameters mentioned in this paper (batch size 64, 100 epochs), the set used in FixMatch (batch size 1024, 300 epochs), or perhaps another set?
Could you provide this information?
Many thanks.

Learning rate schedule for 3x FixMatch

Thank you for your amazing work!
Could you provide the learning rate schedule for 3x FixMatch? Do you simply drop the lr by 10x at epochs 180 and 240?

Mapping momentum update of mean and variance in code

Hi,

Thank you for uploading the code.
I understand that EMAN is implemented in the momentum_update(self, cur_iter, max_iter) method of the BYOLEMAN class.
Since the minibatch mean (mu) and variance (sigma^2) are not trainable parameters of the BN layer in PyTorch, do the minibatch mean and variance get the momentum update via the running_mean and running_var buffers of the BatchNorm layer?

Regards
