
SlimCLR

The official implementation of Slimmable Networks for Contrastive Self-supervised Learning.


News

  • [29/07/2024] Accepted by IJCV. Repo online.

Introduction

Self-supervised learning has made significant progress in pre-training large models, but it struggles with small models. Mainstream solutions to this problem rely mainly on knowledge distillation, a two-stage procedure: first train a large teacher model, then distill it to improve the generalization ability of smaller ones. In this work, we introduce an alternative one-stage solution that obtains pre-trained small models without extra teachers: slimmable networks for contrastive self-supervised learning (SlimCLR). A slimmable network consists of a full network and several weight-sharing sub-networks; it is pre-trained once and yields networks of various sizes, including small ones with low computation costs.

However, interference between the weight-sharing networks causes severe performance degradation in the self-supervised setting, as evidenced by gradient magnitude imbalance and gradient direction divergence. The former means that a small proportion of parameters produce dominant gradients during backpropagation, so the main parameters may not be fully optimized. The latter means that gradient directions are disordered and the optimization process is unstable. To address these issues, we introduce three techniques that make the main parameters produce dominant gradients and keep the outputs of sub-networks consistent: slow start training of sub-networks, online distillation, and loss re-weighting according to model sizes.

Furthermore, we present theoretical results showing that a single slimmable linear layer is sub-optimal during linear evaluation, so a switchable linear probe layer is used instead. We instantiate SlimCLR with typical contrastive learning frameworks and achieve better performance than prior work with fewer parameters and FLOPs.
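The weight-sharing idea and the three training techniques can be illustrated with a minimal, framework-free sketch. It assumes that a sub-network of width `w` keeps the first `w` fraction of each layer's output units, that each network's loss weight is proportional to its width, and that online distillation is a mean-squared error toward the full network's output. These choices, and every name below (`WIDTHS`, `slim_linear`, `distill_loss`, `combine_losses`), are hypothetical illustrations, not the exact formulation used in the paper or in this repository.

```python
from typing import Dict, List, Sequence

# Hypothetical width multipliers: the full network (1.0) plus three sub-networks.
WIDTHS = (1.0, 0.75, 0.5, 0.25)


def slim_linear(x: Sequence[float], weight: List[List[float]], width: float) -> List[float]:
    """Forward pass of one slimmable linear layer.

    Every width shares the same `weight` (rows = output units). A sub-network
    keeps only the first `width * out_features` rows, and zip() makes each row
    read only as many input columns as `x` provides, so a sliced output can
    feed the next sliced layer directly.
    """
    n_out = max(1, round(width * len(weight)))
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight[:n_out]]


def distill_loss(student: Sequence[float], teacher: Sequence[float]) -> float:
    """Online distillation (assumed MSE): pull a sub-network's output toward
    the full network's output. In an autograd framework the teacher output
    would be detached so no gradient flows back through it."""
    if len(student) != len(teacher):
        raise ValueError("student and teacher outputs must have the same length")
    return sum((s - t) ** 2 for s, t in zip(student, teacher)) / len(student)


def combine_losses(losses: Dict[float, float], epoch: int, slow_start_epochs: int) -> float:
    """Aggregate per-width losses into one training objective.

    Slow start: sub-networks join only once `epoch >= slow_start_epochs`,
    so early updates come from the full network (width 1.0) alone.
    Re-weighting (assumed scheme): each active loss is weighted by its width,
    normalized to sum to 1, so larger networks contribute larger gradients.
    """
    active = {w: loss for w, loss in losses.items()
              if w == 1.0 or epoch >= slow_start_epochs}
    total = sum(active.keys())
    return sum(w / total * loss for w, loss in active.items())
```

For example, with per-width losses `{1.0: 2.0, 0.5: 4.0}` and `slow_start_epochs=5`, `combine_losses` returns 2.0 at epoch 0 (only the full network counts) and 8/3 at epoch 10, where the full network carries two thirds of the weight. Proportional-to-width weighting is just one simple way to let the main parameters dominate; the paper's coefficients may differ.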

Features

  • Supervised training, SimCLR, MoCov2, MoCov3, SlimCLR-MoCov2, SlimCLR-MoCov3, visualization of optimization trajectory.
  • Fast data IO with LMDB / MXRecordIO. Fast data augmentation ops with NVIDIA DALI.

Installation

Prepare data

# script to extract ImageNet dataset
# ILSVRC2012_img_train.tar (about 138 GB)
# ILSVRC2012_img_val.tar (about 6.3 GB)
# make sure ILSVRC2012_img_train.tar & ILSVRC2012_img_val.tar are in your current directory

# 1. Extract the training data:

mkdir train && mv ILSVRC2012_img_train.tar train/ && cd train
tar -xvf ILSVRC2012_img_train.tar && rm -f ILSVRC2012_img_train.tar
find . -name "*.tar" | while read NAME ; do mkdir -p "${NAME%.tar}"; tar -xvf "${NAME}" -C "${NAME%.tar}"; rm -f "${NAME}"; done
cd ..

# 2. Extract the validation data and move images to subfolders:

mkdir val && mv ILSVRC2012_img_val.tar val/ && cd val && tar -xvf ILSVRC2012_img_val.tar
wget -qO- https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh | bash
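After both steps, a quick sanity check can catch an interrupted extraction. The sketch below uses only the Python standard library; `class_dirs`, `count_images`, and `check_imagenet` are hypothetical helpers, not part of this repository. A complete ILSVRC2012 extraction should report 1000 class folders in both splits, 1,281,167 training images, and 50,000 validation images.

```python
from pathlib import Path
from typing import Dict, Set


def class_dirs(split_dir: Path) -> Set[str]:
    """Names of the per-class (synset) sub-folders in one split."""
    return {p.name for p in split_dir.iterdir() if p.is_dir()}


def count_images(split_dir: Path) -> int:
    """Count JPEG files exactly one level below the class folders.

    Files left at the split's top level (e.g. the original .tar) are ignored.
    """
    return sum(1 for f in split_dir.glob("*/*")
               if f.is_file() and f.suffix.lower() in (".jpeg", ".jpg"))


def check_imagenet(root: str) -> Dict[str, object]:
    """Summarize ROOT/train and ROOT/val after extraction."""
    train, val = Path(root) / "train", Path(root) / "val"
    train_classes, val_classes = class_dirs(train), class_dirs(val)
    return {
        "train_classes": len(train_classes),
        "val_classes": len(val_classes),
        "same_classes": train_classes == val_classes,  # valprep.sh should mirror train
        "train_images": count_images(train),
        "val_images": count_images(val),
    }
```

Call it on the folder that holds `train` and `val`, e.g. `print(check_imagenet("dataset/imagenet"))`.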

To store ImageNet images in LMDB or MXRecordIO format, see preprocess/folder2lmdb.py or preprocess/im2rec.sh.

Generally, directories are organized as follows:

${ROOT}
├── dataset
│   │
│   ├──imagenet
│   │   ├──train
│   │   └──val    
│   └──coco
│       ├──train2017
│       └──val2017
│
├── code
│   └── SlimCLR
│ 
├── output (save the output of the program)
│
...
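Before launching the scripts, it can help to confirm that this tree exists. A minimal sketch, assuming the relative paths shown above; `EXPECTED_DIRS` and `missing_dirs` are hypothetical names, and the COCO folders are presumably needed only for the Mask R-CNN transfer-learning experiments.

```python
from pathlib import Path
from typing import Iterable, List

# Relative paths copied from the tree above (hypothetical checklist).
EXPECTED_DIRS = (
    "dataset/imagenet/train",
    "dataset/imagenet/val",
    "dataset/coco/train2017",
    "dataset/coco/val2017",
    "code/SlimCLR",
    "output",
)


def missing_dirs(root: str, expected: Iterable[str] = EXPECTED_DIRS) -> List[str]:
    """Return the expected sub-directories of `root` that do not exist yet."""
    return [rel for rel in expected if not (Path(root) / rel).is_dir()]
```

An empty list from `missing_dirs("${ROOT}")` (with the placeholder replaced by the real path) means the layout matches.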

Dependency

Requires Python >= 3.8 and PyTorch >= 1.10. The following commands were tested on a Linux machine with CUDA driver version 525.105.17.

conda create --name slimclr python=3.8.19 -y
conda install pytorch==1.11.0 torchvision==0.12.0 cudatoolkit=11.3 -c pytorch -y
pip install -r requirements.txt 
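To confirm that an environment meets the stated minimums (Python >= 3.8, PyTorch >= 1.10), a small check like the following can be run inside the `slimclr` environment. It is a sketch: `parse_version`, `meets`, and `check_env` are hypothetical helpers, and the only PyTorch calls used are `torch.__version__` and `torch.cuda.is_available()`.

```python
import re
import sys
from typing import Dict, Tuple


def parse_version(v: str) -> Tuple[int, ...]:
    """'1.11.0+cu113' -> (1, 11, 0): keep the leading numeric components only."""
    parts = []
    for piece in v.split("+")[0].split("."):
        m = re.match(r"\d+", piece)
        if m is None:  # stop at suffixes such as 'dev20230801'
            break
        parts.append(int(m.group()))
    return tuple(parts)


def meets(v: str, minimum: str) -> bool:
    """True if version string `v` is at least `minimum`."""
    return parse_version(v) >= parse_version(minimum)


def check_env() -> Dict[str, object]:
    """Report whether Python and (if installed) PyTorch meet the README minimums."""
    report: Dict[str, object] = {"python_ok": sys.version_info >= (3, 8)}
    try:
        import torch
        report["torch"] = str(torch.__version__)
        report["torch_ok"] = meets(str(torch.__version__), "1.10")
        report["cuda_available"] = torch.cuda.is_available()
    except ImportError:
        report["torch"] = None
        report["torch_ok"] = False
    return report
```

With the commands above, `check_env()` should report PyTorch `1.11.0` and `torch_ok: True`.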

Checkpoints

| Model | Epoch | Ckpt | Log |
| --- | --- | --- | --- |
| SlimCLR-MoCov2 | 200 | 48213217 | 48213217_log |
| SlimCLR-MoCoV2-Linear | 200 | 789b1a3b17 | 789b1a3b17_log |
| SlimCLR-MoCov2 | 800 | eda810a6a9 | eda810a6a9_log |
| SlimCLR-MoCoV2-Linear | 800 | 35600f623f | 35600f623f_log |
| SlimCLR-MoCoV2-MaskRCNN | 800 | 78afcc6ae3 | 78afcc6ae3_log |
| SlimCLR-MoCov3 | 300 | 57e298e9cd | 57e298e9cd_log |
| SlimCLR-MoCoV3-Linear | 300 | e35321e95c | e35321e95c_log |

A backup of these checkpoints is at BaiduYunPan.
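The internal layout of these checkpoint files is not documented here, so the following is only a sketch under assumptions: that a checkpoint is a dict that may nest the weights under a key such as `state_dict` or `model`, and that parameter names may carry the `module.` prefix added by PyTorch's `DistributedDataParallel`. Inspect `ckpt.keys()` first and adjust; `strip_prefix` and `extract_state_dict` are hypothetical helpers.

```python
from typing import Any, Dict, Mapping


def strip_prefix(state_dict: Mapping[str, Any], prefix: str = "module.") -> Dict[str, Any]:
    """Remove a wrapper prefix (e.g. from DistributedDataParallel) from parameter names."""
    return {(k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()}


def extract_state_dict(ckpt: Mapping[str, Any]) -> Dict[str, Any]:
    """Return the weights dict, assuming it sits under 'state_dict' or 'model'
    (hypothetical key names) or is the checkpoint itself."""
    for key in ("state_dict", "model"):
        if key in ckpt and isinstance(ckpt[key], Mapping):
            return strip_prefix(ckpt[key])
    return strip_prefix(ckpt)
```

With PyTorch, one would load the file with `torch.load(path, map_location="cpu")`, pass the result through `extract_state_dict`, and call `model.load_state_dict(weights, strict=False)`; the returned `missing_keys` and `unexpected_keys` show whether the assumed naming matches the model.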

Training

Self-supervised Training

  • For MoCoV2 or SlimCLR-MoCov2, refer to
bash scripts/mocov2.sh
  • For MoCoV3 or SlimCLR-MoCov3, refer to
bash scripts/mocov3.sh
  • For Slimmable Networks with SimCLR, refer to
bash scripts/simclr.sh

To run inference only, set the variable test_only=1 in the corresponding shell script.

Supervised Training

  • For supervised training with slimmable ResNet, refer to
bash scripts/slimmable.sh

Transfer Learning

  • For transfer learning with MaskRCNN, refer to
# training with 8 GPUs
bash benchmark/train.sh 8

BibTeX

@article{zhao2024slimclr,
  title={Slimmable Networks for Contrastive Self-supervised Learning},
  author={Zhao, Shuai and Zhu, Linchao and Wang, Xiaohan and Yang, Yi},
  journal={International Journal of Computer Vision},
  year={2024},
}

Acknowledgements

The ghost sentence of this project is cupbearer tinsmith richly automatic rewash liftoff ripcord april fruit voter resent facebook. Check it at https://arxiv.org/abs/2403.15740.
