

Multi-view contrastive learning for online knowledge distillation

This project provides the official implementation of Multi-view Contrastive Learning for Online Knowledge Distillation (MCL-OKD), together with unofficial implementations of several representative Online Knowledge Distillation (OKD) methods: DML, CL-ILR, ONE, and OKDDip.

We evaluate the OKD methods with several representative image classification networks as backbones: DenseNet-40-12, ResNet-32, VGG-16, ResNet-110, and HCGNet-A1.

Installation

Requirements

Ubuntu 18.04 LTS

Python 3.8

CUDA 11.1

PyTorch 1.6.0

Create three folders: ./data, ./result, and ./checkpoint.
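For example, from the root of the repository:

# create the data, result, and checkpoint directories
mkdir -p ./data ./result ./checkpoint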

Perform experiments on the CIFAR-100 dataset

Dataset

CIFAR-100: download

Unzip the archive into the ./data folder.
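If the original download link is unavailable, the dataset can also be fetched from the official CIFAR page; this source and archive layout are assumptions rather than something taken from this repository, so check that the training scripts find the data where they expect it:

# download the python version of CIFAR-100 into ./data and extract it there
wget -P ./data https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz
tar -xzf ./data/cifar-100-python.tar.gz -C ./data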

Training for baseline

python main_cifar_baseline.py --arch densenetd40k12 --gpu 0 
python main_cifar_baseline.py --arch resnet32 --gpu 0 
python main_cifar_baseline.py --arch vgg16 --gpu 0 
python main_cifar_baseline.py --arch resnet110 --gpu 0 
python main_cifar_baseline.py --arch hcgnet_A1 --gpu 0 

Training by DML

python main_cifar_dml.py --arch dml_densenetd40k12 --gpu 0
python main_cifar_dml.py --arch dml_resnet32 --gpu 0
python main_cifar_dml.py --arch dml_vgg16 --gpu 0
python main_cifar_dml.py --arch dml_resnet110 --gpu 0
python main_cifar_dml.py --arch dml_hcgnet_A1 --gpu 0

Training by CL-ILR

python main_cifar_cl_ilr.py --arch cl_ilr_densenetd40k12 --gpu 0 
python main_cifar_cl_ilr.py --arch cl_ilr_resnet32 --gpu 0 
python main_cifar_cl_ilr.py --arch cl_ilr_vgg16 --gpu 0 
python main_cifar_cl_ilr.py --arch cl_ilr_resnet110 --gpu 0 
python main_cifar_cl_ilr.py --arch cl_ilr_hcgnet_A1 --gpu 0

Training by ONE

python main_cifar_one.py --arch one_densenetd40k12 --gpu 0
python main_cifar_one.py --arch one_resnet32 --gpu 0
python main_cifar_one.py --arch one_vgg16 --gpu 0
python main_cifar_one.py --arch one_resnet110 --gpu 0
python main_cifar_one.py --arch one_hcgnet_A1 --gpu 0

Training by OKDDip

python main_cifar_okddip.py --arch okddip_densenetd40k12 --gpu 0
python main_cifar_okddip.py --arch okddip_resnet32 --gpu 0
python main_cifar_okddip.py --arch okddip_vgg16 --gpu 0
python main_cifar_okddip.py --arch okddip_resnet110 --gpu 0
python main_cifar_okddip.py --arch okddip_hcgnet_A1 --gpu 0

Training by MCL-OKD

python main_cifar_mcl_okd.py --arch mcl_okd_densenetd40k12 --nce_k 256 --gpu 0
python main_cifar_mcl_okd.py --arch mcl_okd_resnet32 --nce_k 256 --gpu 0
python main_cifar_mcl_okd.py --arch mcl_okd_vgg16 --nce_k 16384 --gpu 0
python main_cifar_mcl_okd.py --arch mcl_okd_resnet110 --nce_k 256 --gpu 0
python main_cifar_mcl_okd.py --arch mcl_okd_hcgnet_A1 --nce_k 16384 --gpu 0

Top-1 error rate (%) on CIFAR-100:

| Model | Params | FLOPs | Baseline | DML (Ens) | CL-ILR (Ens) | ONE (Ens) | OKDDip (Ens) | MCL-OKD (Ens) |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| DenseNet-40-12 | 0.19M | 0.07G | 29.17 | 27.34 (26.02) | 27.38 (26.19) | 29.01 (28.67) | 28.75 (27.51) | 26.04 (23.55) |
| ResNet-32 | 0.47M | 0.07G | 28.91 | 24.92 (22.97) | 25.40 (24.03) | 25.74 (24.03) | 25.76 (23.73) | 24.52 (22.00) |
| VGG-16 | 14.77M | 0.31G | 25.18 | 24.14 (23.27) | 23.58 (22.96) | 25.22 (25.12) | 24.86 (24.52) | 23.11 (22.36) |
| ResNet-110 | 1.17M | 0.17G | 23.62 | 21.51 (19.12) | 21.16 (18.66) | 22.19 (20.23) | 21.05 (19.40) | 20.39 (18.29) |
| HCGNet-A1 | 1.10M | 0.15G | 22.46 | 18.98 (17.86) | 19.04 (18.35) | 22.30 (21.64) | 21.54 (20.97) | 18.72 (17.54) |
  • Ens: ensemble performance obtained by retaining all peer networks and combining their predictions (one common ensemble rule is sketched below).
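The ensemble rule is not spelled out here; as an assumption, a common choice with m peer networks is to average their softmax outputs and predict the class with the highest averaged probability:

$$
p_{\text{ens}}(x) = \frac{1}{m}\sum_{i=1}^{m} \operatorname{softmax}\big(z_i(x)\big),
\qquad
\hat{y}(x) = \arg\max_{c}\,\big[p_{\text{ens}}(x)\big]_{c}
$$

where z_i(x) denotes the logits of the i-th peer network; averaging logits before the softmax is an equally common variant.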

Perform experiments on the ImageNet dataset

Dataset preparation

  • Download the ImageNet dataset to YOUR_IMAGENET_PATH and move the validation images into class-labeled subfolders (one way to do this is sketched after the folder layout below).

  • Create the ./data folder (if it does not already exist) and add a symlink to the ImageNet dataset inside it:

$ ln -s PATH_TO_YOUR_IMAGENET ./data/

Folder of ImageNet Dataset:

data/ImageNet
├── train
└── val
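One common way to produce the class-labeled validation subfolders (an assumption, not part of this repository; the helper script is the one referenced by the official PyTorch ImageNet example) is:

# run inside the validation folder; the script moves each image into its class subfolder
cd ./data/ImageNet/val
wget https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh
bash valprep.sh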

Training for baseline

python main_imagenet_baseline.py --arch resnet34 --gpu 0

Training by MCL-OKD

python main_imagenet_mcl_okd.py --arch mcl_okd_resnet34 --gpu 0

Top-1 error rate (%) on ImageNet:

| Model | Baseline | MCL-OKD | MCL-OKD (Ens) |
| --- | --- | --- | --- |
| ResNet-34 | 25.43 | 24.64 | 23.26 |

Ablation study on the CIFAR-100 dataset

Top-1 error rate (%):

| Model | Baseline | +MCL | +MCL+DOT (MCL-OKD) |
| --- | --- | --- | --- |
| DenseNet-40-12 | 29.17 | 28.07 | 26.04 |
| ResNet-32 | 28.91 | 27.29 | 24.52 |
| VGG-16 | 25.18 | 23.86 | 23.11 |
| ResNet-110 | 23.62 | 21.65 | 20.39 |
| HCGNet-A1 | 22.46 | 20.76 | 18.72 |
  • MCL: the multi-view contrastive learning loss.

  • DOT: the loss for distillation from the online teacher.
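Read together with the ablation columns above, the overall training objective can be sketched as the usual cross-entropy term plus these two losses; the form below, and the weights λ₁ and λ₂ in particular, are an illustrative assumption rather than values taken from this repository:

$$
\mathcal{L}_{\text{MCL-OKD}} = \mathcal{L}_{\text{CE}} + \lambda_1\,\mathcal{L}_{\text{MCL}} + \lambda_2\,\mathcal{L}_{\text{DOT}}
$$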
