MILAN: Masked Image Pretraining on Language Assisted Representation

This repository contains the PyTorch implementation of the paper MILAN: Masked Image Pretraining on Language Assisted Representation.

  • This repo was built upon the MAE repo. Installation and preparation follow that repo.

Prepare the dataset

  • Extract the training data:
mkdir train && mv ILSVRC2012_img_train.tar train/ && cd train
tar -xvf ILSVRC2012_img_train.tar && rm -f ILSVRC2012_img_train.tar
find . -name "*.tar" | while read NAME ; do mkdir -p "${NAME%.tar}"; tar -xvf "${NAME}" -C "${NAME%.tar}"; rm -f "${NAME}"; done
cd ..
  • Extract the validation data and move the images to subfolders:
mkdir val && mv ILSVRC2012_img_val.tar val/ && cd val && tar -xvf ILSVRC2012_img_val.tar
wget -qO- https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh | bash
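
Once both archives are extracted, a quick check of the directory layout can catch a failed extraction before a long training run starts. The sketch below is not part of the repository: the helper name `count_class_dirs` is hypothetical, and it assumes the standard ImageNet-1K layout in which `train/` and `val/` each hold one subfolder per class.

```python
from pathlib import Path


def count_class_dirs(root):
    """Count class subfolders and image files under each ImageNet split.

    Hypothetical helper, not part of the MILAN repo. Assumes the layout
    produced by the commands above: <root>/train/<class>/ and <root>/val/<class>/.
    """
    summary = {}
    for split in ("train", "val"):
        split_dir = Path(root) / split
        if not split_dir.is_dir():
            summary[split] = (0, 0)  # split is missing entirely
            continue
        class_dirs = [d for d in split_dir.iterdir() if d.is_dir()]
        n_images = sum(
            1 for d in class_dirs for f in d.iterdir() if f.is_file()
        )
        summary[split] = (len(class_dirs), n_images)
    return summary


if __name__ == "__main__":
    # /dataset/imagenet is the --data_path used in the pretraining example.
    for split, (n_classes, n_images) in count_class_dirs("/dataset/imagenet").items():
        print(f"{split}: {n_classes} class folders, {n_images} images")
```

For ImageNet-1K, each split should report 1000 class folders; the standard release contains 1,281,167 training images and 50,000 validation images.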

Pretraining

Example of applying MILAN to pretrain ViT-B/16 on ImageNet-1K using 8 GPUs:

python -m torch.distributed.launch --nproc_per_node=8 --use_env main_pretrain.py \
    --model mae_vit_base_patch16 \
    --batch_size 256 \
    --accum_iter 2 \
    --mask_ratio 0.75 \
    --epochs 400 \
    --warmup_epochs 40 \
    --blr 1.5e-4 \
    --weight_decay 0.05 \
    --data_path /dataset/imagenet \
    --output_dir ./milan_vit_base_pretrain_400epoch_useclip_changedecoder_attnmask \
    --log_dir ./milan_vit_base_pretrain_400epoch_useclip_changedecoder_attnmask \
    --use_clip \
    --change_decoder \
    --attn_mask
  • The available --model choices are listed in models_milan.py.
  • Effective batch size is 256 (--batch_size per GPU) * 2 (--accum_iter gradient accumulation) * 8 (GPUs) = 4096.
  • Effective learning rate is 1.5e-4 (--blr base learning rate) * 4096 (effective batch size) / 256 = 2.4e-3.
  • --mask_ratio: fraction of patches to mask (0.75 removes 75% of the patches).
  • --epochs: total number of pretraining epochs; --warmup_epochs: number of learning rate warmup epochs.
  • We apply a --weight_decay of 0.05 during pretraining.
  • We use the ViT-B/16 CLIP image encoder obtained from here to produce the reconstruction targets during pretraining.
  • --change_decoder: switch to the prompting decoder.
  • --attn_mask: switch to the semantic-aware masking strategy.
  • Training for 400 epochs takes about 39 hours on eight 40GB A100 GPUs.
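
The batch-size and learning-rate arithmetic in the notes above can be checked with a few lines of Python. This is a sketch for illustration, not code from the repository, and the function names are hypothetical. The patch count additionally assumes a 224x224 input, which the text does not state, and that the number of visible patches is truncated to an integer.

```python
def effective_batch_size(batch_size, accum_iter, num_gpus):
    """Per-GPU batch size times gradient-accumulation steps times GPU count."""
    return batch_size * accum_iter * num_gpus


def effective_lr(blr, eff_batch_size):
    """Linear scaling rule from the notes above: lr = blr * eff_batch_size / 256."""
    return blr * eff_batch_size / 256


def visible_patches(mask_ratio, img_size=224, patch_size=16):
    """Patches left visible to the encoder after masking.

    img_size=224 is an assumption; patch_size=16 comes from ViT-B/16.
    """
    num_patches = (img_size // patch_size) ** 2
    return int(num_patches * (1 - mask_ratio))


if __name__ == "__main__":
    eff_bs = effective_batch_size(batch_size=256, accum_iter=2, num_gpus=8)
    print(eff_bs)                        # 4096
    print(effective_lr(1.5e-4, eff_bs))  # approximately 2.4e-3
    print(visible_patches(0.75))         # 49 of 196 patches
```

When running on fewer GPUs, raising --accum_iter by the same factor keeps the effective batch size, and therefore the effective learning rate, unchanged.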

Finetuning and linear probing

  • Example of finetuning the pretrained ViT-B/16 on ImageNet-1K:
OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 main_finetune.py \
    --accum_iter 1 \
    --batch_size 128 \
    --model vit_base_patch16 \
    --finetune ./milan_vit_base_pretrain_400epoch_useclip_changedecoder_attnmask/checkpoint-399.pth \
    --epochs 100 \
    --blr 1e-4 --layer_decay 0.65 \
    --weight_decay 0.05 --drop_path 0.1 --mixup 0.8 --cutmix 1.0 --reprob 0.25 \
    --dist_eval --data_path /data/imagenet \
    --output_dir ./milan_vit_base_finetune_pretrain400epochuseclipchangedecoderattnmask \
    --log_dir ./milan_vit_base_finetune_pretrain400epochuseclipchangedecoderattnmask \
    --global_pool
  • Example of performing linear probing on the pretrained ViT-B/16:
OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 main_linprobe.py \
    --accum_iter 1 \
    --batch_size 2048 \
    --model vit_base_patch16 \
    --cls_token \
    --finetune ./milan_vit_base_pretrain_400epoch_useclip_changedecoder_attnmask/checkpoint-399.pth \
    --epochs 100 \
    --blr 0.05 \
    --weight_decay 0.0 \
    --dist_eval --data_path /data/imagenet \
    --output_dir ./milan_vit_base_linearprobe_pretrain400epochuseclipchangedecoderattnmask \
    --log_dir ./milan_vit_base_linearprobe_pretrain400epochuseclipchangedecoderattnmask
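
To see what --blr and --layer_decay imply for the finetuning run, the sketch below computes the peak learning rate and the per-layer multipliers. It is an illustration, not the repository's code, and the function name is hypothetical. It assumes the same linear scaling rule as in pretraining and the layer-wise decay scheme of the MAE codebase this repo builds on, in which layer group `i` is scaled by `layer_decay ** (num_layers - i)`, with 12 transformer blocks for ViT-B.

```python
def layer_lr_scales(layer_decay, num_blocks=12):
    """Learning-rate multipliers, from the embedding layer up to the head.

    Assumed scheme (as in the MAE codebase): index 0 is the patch embedding,
    indices 1..num_blocks are the transformer blocks, and the last index is
    the classification head, which keeps the full learning rate.
    """
    num_layers = num_blocks + 1
    return [layer_decay ** (num_layers - i) for i in range(num_layers + 1)]


if __name__ == "__main__":
    # Values from the finetuning command: 128 per GPU, 1 accumulation step, 8 GPUs.
    eff_batch_size = 128 * 1 * 8           # 1024
    peak_lr = 1e-4 * eff_batch_size / 256  # approximately 4e-4
    scales = layer_lr_scales(layer_decay=0.65)
    print(f"peak lr: {peak_lr:.1e}")
    print(f"first block lr: {peak_lr * scales[1]:.2e}")
    print(f"head lr: {peak_lr * scales[-1]:.1e}")
```

Under these assumptions the head trains at the full rate, while the embedding layer receives a rate a few hundred times smaller (0.65^13 is roughly 0.0037). Applying the same scaling rule to the linear-probing command gives 0.05 * 16384 / 256 = 3.2.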

Checkpoints

We provide the pretrained ViT-B/16 and ViT-L/16 checkpoints.

                        ViT-Base    ViT-Large
Pretrained checkpoint   download    download

Citation

If you find this repository helpful, please consider citing:

@article{MILAN2022,
  title   = {MILAN: Masked Image Pretraining on Language Assisted Representation},
  author  = {Hou, Zejiang and Sun, Fei and Chen, Yen-Kuang and Xie, Yuan and Kung, Sun-Yuan},
  journal = {arXiv preprint arXiv:2208.06049},
  year    = {2022},
}
