
ofakd's Introduction

One-for-All: Bridge the Gap Between Heterogeneous Architectures in Knowledge Distillation

Official PyTorch implementation of OFA-KD, from the following paper:
One-for-All: Bridge the Gap Between Heterogeneous Architectures in Knowledge Distillation
Zhiwei Hao, Jianyuan Guo, Kai Han, Yehui Tang, Han Hu, Yunhe Wang, Chang Xu

This paper studies using heterogeneous teacher and student models for knowledge distillation (KD) and proposes a one-for-all KD framework (OFA-KD). In this framework, intermediate features are projected into an aligned latent space to discard architecture-specific information, and an adaptive target enhancement scheme is proposed to prevent the student from being disturbed by irrelevant information.
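The adaptive target enhancement idea can be illustrated with a small, framework-free sketch. It reflects our reading and is not the repository's implementation, which lives in ./distillers. For one sample with student probabilities p_s, teacher probabilities p_t and ground-truth class c_hat, we assume the target-class term is weighted by (1 + p_t[c_hat])^γ. As a further assumption, every other class c is weighted by p_t[c]^γ. We also assume the `--ofa-eps` values in the training commands correspond to γ. The function names `softmax` and `ofa_style_loss` are hypothetical.

```python
import math


def softmax(logits, temperature=1.0):
    """Numerically stable softmax over a list of logits."""
    scaled = [z / temperature for z in logits]
    top = max(scaled)
    exps = [math.exp(z - top) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def ofa_style_loss(student_logits, teacher_logits, target, gamma=1.0, temperature=1.0):
    """Per-sample cross-entropy against enhanced teacher weights (ASSUMED form).

    target class c_hat : weight (1 + p_t[c_hat]) ** gamma
    other classes c    : weight p_t[c] ** gamma
    """
    p_s = softmax(student_logits, temperature)
    p_t = softmax(teacher_logits, temperature)
    loss = 0.0
    for c, (ps, pt) in enumerate(zip(p_s, p_t)):
        onehot = 1.0 if c == target else 0.0
        loss -= (pt + onehot) ** gamma * math.log(ps)
    return loss


if __name__ == "__main__":
    student, teacher = [2.0, 0.5, -1.0], [1.0, 3.0, 0.0]
    for g in (1.0, 1.5):
        print(g, round(ofa_style_loss(student, teacher, target=1, gamma=g), 4))
```

Under this assumed form, γ = 1 reduces to hard-label cross-entropy plus cross-entropy against the teacher distribution. With γ > 1, the target-class weight grows faster when the teacher is confident, and small non-target probabilities are damped.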

If you find this project useful in your research, please cite:

@inproceedings{hao2023ofa,
  author    = {Zhiwei Hao and Jianyuan Guo and Kai Han and Yehui Tang and Han Hu and Yunhe Wang and Chang Xu},
  title     = {One-for-All: Bridge the Gap Between Heterogeneous Architectures in Knowledge Distillation},
  booktitle = {Advances in Neural Information Processing Systems},
  year      = {2023}
}

Usage

First, clone the repository locally:

git clone https://github.com/Hao840/OFAKD.git

Then, install PyTorch and timm 0.6.5:

conda install -c pytorch pytorch torchvision
pip install timm==0.6.5

Our results are produced with torch==1.10.2+cu113 torchvision==0.11.3+cu113 timm==0.6.5. Other versions might also work.
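Because other versions "might also work", a version mismatch is only a hint and not an error. The sketch below uses hypothetical helper names and compares the installed packages with the versions listed above, using the standard-library `importlib.metadata`. It ignores the local build label, such as `+cu113`, that CUDA builds of torch and torchvision carry.

```python
from importlib import metadata

# Versions from the README; the +cu113 local labels are ignored when comparing.
EXPECTED = {"torch": "1.10.2", "torchvision": "0.11.3", "timm": "0.6.5"}


def base_version(version):
    """Drop a PEP 440 local label such as '+cu113'."""
    return version.split("+", 1)[0]


def installed_versions(names):
    """Map each distribution name to its installed version, or None if missing."""
    found = {}
    for name in names:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            found[name] = None
    return found


def version_mismatches(installed, expected=EXPECTED):
    """Return {name: (installed, expected)} for every package that differs."""
    bad = {}
    for name, want in expected.items():
        have = installed.get(name)
        if have is None or base_version(have) != base_version(want):
            bad[name] = (have, want)
    return bad


if __name__ == "__main__":
    print(version_mismatches(installed_versions(EXPECTED)))
```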

Data preparation

Download and extract ImageNet train and val images from http://image-net.org/. The directory structure is:

│path/to/imagenet/
├──train/
│  ├── n01440764
│  │   ├── n01440764_10026.JPEG
│  │   ├── n01440764_10027.JPEG
│  │   ├── ......
│  ├── ......
├──val/
│  ├── n01440764
│  │   ├── ILSVRC2012_val_00000293.JPEG
│  │   ├── ILSVRC2012_val_00002138.JPEG
│  │   ├── ......
│  ├── ......
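Before launching a long run, you can sanity-check the layout with a short script. This is a minimal sketch using only the standard library, and `check_imagenet_layout` is a hypothetical helper. It checks for the structure shown above: a `train/` and a `val/` folder, each holding one sub-folder per class (such as `n01440764`), no loose image files, and the same class folders in both splits.

```python
import tempfile
from pathlib import Path


def check_imagenet_layout(root, splits=("train", "val")):
    """Return a list of human-readable problems; an empty list means the layout looks right."""
    root = Path(root)
    problems = []
    class_sets = {}
    for split in splits:
        split_dir = root / split
        if not split_dir.is_dir():
            problems.append(f"missing {split}/")
            continue
        entries = list(split_dir.iterdir())
        class_sets[split] = {p.name for p in entries if p.is_dir()}
        loose = [p for p in entries if p.is_file()]
        if loose:
            problems.append(f"{split}/ has {len(loose)} file(s) outside class folders")
        if not class_sets[split]:
            problems.append(f"{split}/ has no class folders")
    # Only compare class folders when every split exists.
    if len(class_sets) == len(splits) and len(set(map(frozenset, class_sets.values()))) > 1:
        problems.append("class folders differ between splits")
    return problems


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        for split in ("train", "val"):
            class_dir = Path(tmp, split, "n01440764")
            class_dir.mkdir(parents=True)
            (class_dir / "example.JPEG").touch()
        print(check_imagenet_layout(tmp))  # → []
```

Call it as `check_imagenet_layout("/path/to/imagenet")` before training.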

Training on ImageNet

To train a resnet18 student with a DeiT-T teacher on ImageNet on a single node with 8 GPUs, run:

python -m torch.distributed.launch --nproc_per_node=8 train.py /path/to/imagenet --config configs/imagenet/cnn.yaml --model resnet18 --teacher deit_tiny_patch16_224 --teacher-pretrained /path/to/teacher_checkpoint --distiller ofa --ofa-eps 1.5

Other results can be reproduced following similar commands by modifying:

--config: configuration of the training strategy.

--model: student model architecture.

--teacher: teacher model architecture.

--teacher-pretrained: path to the checkpoint of the pretrained teacher model.

--distiller: which KD algorithm to use.

For information about other tunable parameters, please refer to train.py.
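When sweeping several student/teacher pairs, it can help to generate the launch command instead of editing it by hand. The sketch below uses a hypothetical helper, `build_launch_cmd`, and only the flags that appear in this README. Any extra flag, such as `--ofa-eps` or `--num-classes`, is passed through unchanged.

```python
import shlex


def build_launch_cmd(data_path, config, model, teacher, teacher_ckpt,
                     distiller, nproc_per_node=8, extra_args=None):
    """Assemble the README's launch command as an argv list (hypothetical helper)."""
    cmd = [
        "python", "-m", "torch.distributed.launch",
        f"--nproc_per_node={nproc_per_node}",
        "train.py", data_path,
        "--config", config,
        "--model", model,
        "--teacher", teacher,
        "--teacher-pretrained", teacher_ckpt,
        "--distiller", distiller,
    ]
    # Extra flags are appended in insertion order, e.g. {"--ofa-eps": 1.5}.
    for flag, value in (extra_args or {}).items():
        cmd += [flag, str(value)]
    return cmd


if __name__ == "__main__":
    cmd = build_launch_cmd("/path/to/imagenet", "configs/imagenet/cnn.yaml",
                           "resnet18", "deit_tiny_patch16_224",
                           "/path/to/teacher_checkpoint", "ofa",
                           extra_args={"--ofa-eps": 1.5})
    print(shlex.join(cmd))
```

The argument list can also be passed directly to `subprocess.run`. The order of named flags does not matter to an argparse-style parser.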

Training on CIFAR-100

To train a resnet18 student with a Swin-T teacher on CIFAR-100 on a single node with 8 GPUs, run:

python -m torch.distributed.launch --nproc_per_node=8 train.py /path/to/cifar100 --config configs/cifar/cnn.yaml --model resnet18 --teacher swin_tiny_patch4_window7_224 --teacher-pretrained /path/to/teacher_checkpoint --num-classes 100 --distiller ofa --ofa-eps 1.0

Pretrained teacher models can be found here

Custom usage

KD algorithm: create a new KD algorithm by following the examples in the ./distillers folder.

Model architecture: create a new model architecture by following the examples in the ./custom_model folder. If intermediate features of the new model are required for KD, rewrite its forward() method following the examples in the ./custom_forward folder.

Acknowledgement

This repository is built on the timm and mdistiller libraries.

ofakd's People

Contributors

hao840


ofakd's Issues

Reproducing CIFAR-100 - Training Student and Teacher Models from Scratch

Hello,

I'm having some challenges reproducing the values in Table 2 (CIFAR-100). Hoping to get some feedback on how to fix this:

(1) Train from Scratch:
Could you please advise on the correct way to use the framework to train the student from scratch? I tried the following for ResNet18, but got very high results (95%+) that differ from the results reported in the paper.

Config used is the default. Here is the command I'm using:
python -m torch.distributed.launch --nproc_per_node=1 train.py /home/me/kd/cifar100 --config configs/cifar/cnn.yaml --model resnet18 --teacher swin_tiny_patch4_window7_224 --teacher-pretrained models/swin_tiny_patch4_window7_224.pth --distiller Vanilla

Is this correct? Please correct me if I'm wrong, but I'm using the Vanilla distiller to train the student from scratch on the hard labels and ignoring the teacher.

(2) Fine-Tuned Teacher
Could you share the sources and configs you used to train the teacher models from scratch (Swin-T, ViT-S, Mixer-B/16)?

(3) Release Configs
If it is not possible to release the models, is it possible for you to release the configs that were used to train the models? It would be immensely helpful for the debugging process.

Looking forward to your reply.

Thanks!

How do I combine FitNet when using OFA?

Thanks to the authors for their contributions. I am having trouble reproducing the approach that combines OFA with FitNet, e.g., with ResNet-50 as the teacher and DeiT-T as the student. I don't know how to set up the training and wonder if the authors could give some specific guidance. Thank you very much for your help.
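For reference, a FitNet-style hint term projects a student feature with a learned linear regressor and penalizes its squared distance to the teacher feature. The framework-free sketch below uses mean squared error, as many FitNet implementations do. The names `project`, `hint_loss` and `total_loss` are hypothetical, and the loss weights are placeholders to tune, not values from the paper. For a CNN teacher and a ViT student, the features would first need to be pooled or reshaped to comparable flat vectors. That step is architecture-specific and is omitted here.

```python
def project(feature, regressor):
    """Apply a linear regressor (one row per teacher dimension) to a flat student feature."""
    return [sum(w * f for w, f in zip(row, feature)) for row in regressor]


def hint_loss(student_feat, teacher_feat, regressor):
    """FitNet-style hint: mean squared error between projected student and teacher features."""
    projected = project(student_feat, regressor)
    if len(projected) != len(teacher_feat):
        raise ValueError("regressor output size must match the teacher feature size")
    return sum((p - t) ** 2 for p, t in zip(projected, teacher_feat)) / len(teacher_feat)


def total_loss(ce, ofa, hint, ofa_weight=1.0, hint_weight=1.0):
    """Hypothetical weighted sum; the weights are placeholders, not paper settings."""
    return ce + ofa_weight * ofa + hint_weight * hint


if __name__ == "__main__":
    identity = [[1.0, 0.0], [0.0, 1.0]]
    print(hint_loss([1.0, 2.0], [1.0, 4.0], identity))  # → 2.0
```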

dataset

Hello, can this distillation method be used for time-series models? The dataset I want to process is related to weather prediction. Can it be used for this?

Different from your paper.

Hi @Hao840 ,

I think there may be some discrepancies. You use DeiT, MobileNetV2, and ResMLP in your paper, but I only find BEiT and MobileNetV1 in your code. Since I'm not familiar with these architectures, could you add them to the code or give me some instructions?

Can not reproduce MobileNetV2 on CIFAR-100

Hello Author,

Thank you for providing this excellent work.

However, I could not find the details regarding the implementation of MobileNetV2 in the code you provided. Could you please specify which version of the MobileNetV2 model you used? Additionally, could you provide the command to run it?

If possible, I would also appreciate it if you could share the related experiment log files.

Thank you again.

Accuracy in cifar100

Hi @Hao840 ,

I reproduced the KD results on CIFAR-100 (teacher: ConvNeXt-T, 89.96), but they are lower than those in your paper. For example, when the student is DeiT-T, acc@1 is 70.78, compared with the reported 72.99. I suspect the training settings may be incorrect when the student is ViT-based. Could you share the "args.yaml" and log file?

Can you provide more information on the details of the training?

Thanks to the authors for their excellent work. However, I ran into problems reproducing the distillation methods documented in the paper and could not reach the reported accuracy. Could the authors provide the hyperparameters of each distillation method used during training?

Reproduced results on ImageNet are too low

Hi, @Hao840 ,

I trained a ResNet18 student with DeiT-T as the teacher on ImageNet-1k using your latest code, but I am having trouble reproducing the results in your paper.

I followed the instructions in the README and trained the model on 2 V100 GPUs. However, as shown in the figure below, the model only reached 0.1% accuracy at epoch 15, far below the results in your released log.

[Figure: screenshot of the training results]

Additionally, I noticed that even the vanilla model did not achieve good results on ImageNet, reaching only 0.08% accuracy by epoch 8. This suggests there might be some incorrect training configurations in the code (although I can reproduce the results on CIFAR-100). Do you have any suggestions about this problem?

Thank you for your great work.
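For context, uniform random guessing over the 1,000 ImageNet-1k classes gives exactly 0.1% expected top-1 accuracy. Accuracies of 0.1% and 0.08% are therefore at chance level, which points to a model that is not learning at all rather than one that is learning slowly. The hypothetical helper below computes that baseline.

```python
def chance_topk_accuracy(num_classes, k=1):
    """Expected top-k accuracy (in %) of uniform random guessing."""
    return 100.0 * min(k, num_classes) / num_classes


if __name__ == "__main__":
    print(chance_topk_accuracy(1000))       # ImageNet-1k top-1
    print(chance_topk_accuracy(1000, k=5))  # ImageNet-1k top-5
    print(chance_topk_accuracy(100))        # CIFAR-100 top-1
```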

About the checkpoints

Dear authors,

Thanks for your impressive work. Is it possible to release the checkpoints of teacher and student in Table 1? Looking forward to your reply. Thanks again.

Inquiry about CKA Implementation

Hello,

I've been exploring your work on model comparisons and I'm particularly interested in the Centered Kernel Alignment (CKA) method you've mentioned.

Would it be possible for you to share the code snippet? Any additional guidance on using CKA with different neural network models would also be greatly appreciated.

Thank you for considering my request. I am looking forward to your response.
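For readers with the same question, here is a minimal sketch of standard linear CKA (Kornblith et al., 2019). This is not the authors' code. X and Y hold two representations of the same n examples, one row per example, for instance flattened features from a teacher layer and a student layer on the same batch. After centering each feature column, linear CKA is ||YᵀX||²_F / (||XᵀX||_F · ||YᵀY||_F). Plain lists are used for readability; for real feature maps, the same formula is usually written with array operations.

```python
import math


def _center_columns(X):
    """Subtract each feature's mean over the n examples (rows)."""
    n = len(X)
    means = [sum(row[j] for row in X) / n for j in range(len(X[0]))]
    return [[v - m for v, m in zip(row, means)] for row in X]


def _cross(A, B):
    """Return A^T B for row-major matrices A (n x p) and B (n x q)."""
    return [[sum(a_row[i] * b_row[j] for a_row, b_row in zip(A, B))
             for j in range(len(B[0]))] for i in range(len(A[0]))]


def _fro_sq(M):
    """Squared Frobenius norm."""
    return sum(v * v for row in M for v in row)


def linear_cka(X, Y):
    """Linear CKA between two representations of the same examples."""
    if len(X) != len(Y):
        raise ValueError("X and Y must describe the same examples (same number of rows)")
    Xc, Yc = _center_columns(X), _center_columns(Y)
    num = _fro_sq(_cross(Yc, Xc))  # ||Y^T X||_F^2
    den = math.sqrt(_fro_sq(_cross(Xc, Xc))) * math.sqrt(_fro_sq(_cross(Yc, Yc)))
    if den == 0.0:
        raise ValueError("a representation is constant across examples")
    return num / den
```

CKA is invariant to isotropic scaling and orthogonal transforms of either representation, so it can compare layers with different widths.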

custom dataset other than CIFAR/ImageNet

Hello,
I was able to run your code successfully.

Now I want to use the OFA loss on a custom dataset. How should I approach this?
Should I train the teacher from scratch, or fine-tune the teacher model on this dataset first and then distill?
Which hyperparameters should I focus on?

Your suggestions as the author would be very helpful.

Issue with training of student

Hi, the work is fantastic!
I was trying my hand at distillation using your code.
Here is what I have done so far. I need your guidance to move forward, please.

  1. I used the pretrained deit_tiny_patch16_224 model linked in your repository and fine-tuned it on CIFAR-100 for around 50 epochs, reaching approximately 76% accuracy, with the following code:

    import os
    import tarfile
    import urllib.request

    from torchvision import datasets, transforms

    # Define the URL for the CIFAR-100 dataset
    cifar100_url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"

    # Define the target directory for the CIFAR-100 dataset
    target_dir = '/content/datset'

    # Create the target directory if it doesn't exist
    os.makedirs(target_dir, exist_ok=True)

    # Download the CIFAR-100 archive
    tar_file_path = os.path.join(target_dir, 'cifar-100-python.tar.gz')
    urllib.request.urlretrieve(cifar100_url, tar_file_path)

    # Extract the contents of the tar file
    with tarfile.open(tar_file_path, 'r:gz') as tar:
        tar.extractall(target_dir)

    # Define the transformations for the images
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762)),  # normalize the images
    ])

    # Load the CIFAR-100 training dataset
    train_dataset = datasets.CIFAR100(root=target_dir, train=True, download=False, transform=transform)

    # Save preprocessed images as JPG files for visualization
    save_dir_train = os.path.join(target_dir, 'preprocessed_images', 'train')
    os.makedirs(save_dir_train, exist_ok=True)

    for i, (image, label) in enumerate(train_dataset):
        image = transforms.ToPILImage()(image)
        image.save(os.path.join(save_dir_train, f'image_{i}label{label}.jpg'))

    # Load the CIFAR-100 testing dataset
    test_dataset = datasets.CIFAR100(root=target_dir, train=False, download=False, transform=transform)

    # Save preprocessed images as JPG files for visualization
    save_dir_test = os.path.join(target_dir, 'preprocessed_images', 'test')
    os.makedirs(save_dir_test, exist_ok=True)

    for i, (image, label) in enumerate(test_dataset):
        image = transforms.ToPILImage()(image)
        image.save(os.path.join(save_dir_test, f'image_{i}label{label}.jpg'))

  2. Then I tried the distillation with the fine-tuned model produced by the step above:
     !python -m torch.distributed.launch --nproc_per_node=1 train.py /content/datset/preprocessed_images/ --config configs/cifar/cnn.yaml --model resnet18 --teacher deit_tiny_patch16_224 --teacher-pretrained /content/fine_tuned_deit_model.pth --distiller ofa --ofa-eps 1.5
  3. The issue is that the test accuracy is 0% from the very first epoch. That's strange, and it keeps me from making further progress.
  4. I have attached the relevant screenshots for reference.
  5. Hoping to hear from you soon.


Training problem on CIFAR100

I loaded the pretrained weights "deit_tiny_patch16_224-a1311bcf.pth", but got the error "size mismatch for head.weight: copying a param with shape torch.Size([1000, 192]) from checkpoint, the shape in current model is torch.Size([100, 192])", even though I set "--num-classes" to 100.
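The error arises because the ImageNet checkpoint has a 1000-way classifier head, while a model built with `--num-classes 100` has a 100-way head. One common workaround, which is not taken from this repository, is to drop the checkpoint entries whose shapes do not match the model. You then load the rest non-strictly, and the new head is trained from scratch. `drop_mismatched` is a hypothetical helper. The `shape_of` argument lets the same code work with tensors (whose `.shape` is a tuple-like `torch.Size`) or with plain shape tuples.

```python
def drop_mismatched(checkpoint, model_state, shape_of=lambda t: tuple(t.shape)):
    """Split a checkpoint into entries that fit the model and keys whose shapes differ."""
    kept, dropped = {}, []
    for key, value in checkpoint.items():
        if key in model_state and shape_of(value) != shape_of(model_state[key]):
            dropped.append(key)  # e.g. head.weight: (1000, 192) vs (100, 192)
        else:
            kept[key] = value
    return kept, dropped
```

With PyTorch, the sketch would be `kept, dropped = drop_mismatched(state_dict, model.state_dict())` followed by `model.load_state_dict(kept, strict=False)`. Some checkpoints nest the weights under an extra key, so inspect the loaded dict first.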

Can not reproduce some results on CIFAR-100.

Hello Author,

Thank you for providing this excellent work.

However, I could not reproduce the OFA performance for these architecture pairs on CIFAR-100:

  1. Mixer-B/16 -- DeiT-T
  2. Mixer-B/16 -- Swin-P
  3. ConvNeXt-T -- Swin-P
  4. ConvNeXt-T -- ResMLP-S12
  5. Swin-T -- ResMLP-S12

Could you provide the command to run it? And if possible, I would also appreciate it if you could share the related experiment log files.

Thank you again.

How to get the training results on CIFAR-100 in the paper?

Can you tell me the specific training strategy for the ViT-S teacher?
The best result I got is 88.50. How did you get the 92.04 in the paper?

Top-1 accuracy on CIFAR-100 (%):

Teacher   Mine    Paper
Swin-T    87.01   89.26
ViT-S     88.50   92.04

FitNet for ViT-based models

Hi @Hao840

I found that FitNet is not supported for ViT-based models, but it is essential for reproducing the hint-based results. Could you provide it?


Best,
Guopeng.

Training hyperparameters for ViT-based models

Hi, I asked you before for the training hyperparameters needed to reproduce your experimental results, and I looked up the issue you pointed me to. However, I couldn't find the hyperparameter settings for training ViT-based models with KD.

I'm trying to reproduce the ConvNeXt-T to DeiT-T results for OFA-KD and KD. I was able to reproduce the ConvNeXt-T teacher model's accuracy, but I can't match the accuracy for:

  • DeiT-T from scratch (yours: 68.00%, mine: 67.00%)
  • ConvNeXt-T to DeiT-T KD (yours: 72.99%, mine: 67.94%)

Can you please share the training hyperparameters for the above experiments?

Thank you again for your great work.

An error of CRD

Hi, @Hao840 ,

I encountered an error when running the code for CRD. Could you tell me how to fix it?

Can you provide the train from scratch checkpoint of Swin-N on ImageNet?

Hi, @Hao840

Thanks for your impressive work.

I am interested in conducting experiments on OFA for downstream tasks, which require the initial training checkpoint for the student model. However, the checkpoints for Swin-N have not been publicly released, and it is costly to train a new one for me.

Would it be possible for you to provide the required checkpoint? Looking forward to your reply.
Thanks

Reproduced Results

Hi, @Hao840 ,

I trained ResNet18 on CIFAR-100 with cross-entropy loss using your latest code, but I cannot reproduce the 74.01 reported in your paper: both of my runs reach about 77.

open discussions

Hi, @Hao840 ,

I want to know if all the models in this paper are trained from scratch.
If yes, have you tried fine-tuning some existing student models with your method? Does it work?
If not, are there any other special operations?

The motivation for this question is that distilling existing teacher models into existing student models is more time-saving and practical.
