SforAiDl / KD_Lib

A PyTorch knowledge distillation library for benchmarking and extending works in the domains of knowledge distillation, pruning, and quantization.

Home Page: https://kd-lib.readthedocs.io/

License: MIT License

Python 98.78% Makefile 1.22%
knowledge-distillation model-compression pruning quantization pytorch deep-learning-library machine-learning data-science benchmarking algorithm-implementations

kd_lib's Introduction

KD-Lib

A PyTorch model compression library containing easy-to-use methods for knowledge distillation, pruning, and quantization

Installation

From source (recommended)

git clone https://github.com/SforAiDl/KD_Lib.git
cd KD_Lib
python setup.py install

From PyPI

pip install KD-Lib

Example usage

To implement the most basic version of knowledge distillation from Distilling the Knowledge in a Neural Network and plot loss curves:

import torch
import torch.optim as optim
from torchvision import datasets, transforms
from KD_Lib.KD import VanillaKD

# This part is where you define your datasets, dataloaders, models and optimizers

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "mnist_data",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        ),
    ),
    batch_size=32,
    shuffle=True,
)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "mnist_data",
        train=False,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        ),
    ),
    batch_size=32,
    shuffle=True,
)

teacher_model = <your model>
student_model = <your model>

teacher_optimizer = optim.SGD(teacher_model.parameters(), 0.01)
student_optimizer = optim.SGD(student_model.parameters(), 0.01)

# Now, this is where KD_Lib comes into the picture

distiller = VanillaKD(teacher_model, student_model, train_loader, test_loader, 
                      teacher_optimizer, student_optimizer)  
distiller.train_teacher(epochs=5, plot_losses=True, save_model=True)    # Train the teacher network
distiller.train_student(epochs=5, plot_losses=True, save_model=True)    # Train the student network
distiller.evaluate(teacher=False)                                       # Evaluate the student network
distiller.get_parameters()                                              # A utility function to get the number of 
                                                                        # parameters in the teacher and the student network
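If you do not have your own teacher and student networks handy, the models packaged with KD_Lib (used in the next example) can stand in for the placeholders above. A minimal sketch for MNIST (1 input channel, 10 classes), assuming the same ResNet constructor signature shown in the DML example below:

from KD_Lib.models import ResNet18, ResNet50

params = [4, 4, 4, 4, 4]
teacher_model = ResNet50(params, 1, 10)    # larger network as the teacher
student_model = ResNet18(params, 1, 10)    # smaller network as the student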

To train a cohort of two student models in an online fashion using the framework in Deep Mutual Learning and log training details to Tensorboard:

import torch
import torch.optim as optim
from torchvision import datasets, transforms
from KD_Lib.KD import DML
from KD_Lib.models import ResNet18, ResNet50          # To use models packaged in KD_Lib

# Define your datasets, dataloaders, models and optimizers

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "mnist_data",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        ),
    ),
    batch_size=32,
    shuffle=True,
)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "mnist_data",
        train=False,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        ),
    ),
    batch_size=32,
    shuffle=True,
)

student_params = [4, 4, 4, 4, 4]
student_model_1 = ResNet50(student_params, 1, 10)
student_model_2 = ResNet18(student_params, 1, 10)

student_cohort = [student_model_1, student_model_2]

student_optimizer_1 = optim.SGD(student_model_1.parameters(), 0.01)
student_optimizer_2 = optim.SGD(student_model_2.parameters(), 0.01)

student_optimizers = [student_optimizer_1, student_optimizer_2]

# Now, this is where KD_Lib comes into the picture 

distiller = DML(student_cohort, train_loader, test_loader, student_optimizers, log=True, logdir="./logs")

distiller.train_students(epochs=5)
distiller.evaluate()
distiller.get_parameters()

Methods Implemented

Some benchmark results can be found in the logs file.

Paper / Method Link Repository (KD_Lib/)
Distilling the Knowledge in a Neural Network https://arxiv.org/abs/1503.02531 KD/vision/vanilla
Improved Knowledge Distillation via Teacher Assistant https://arxiv.org/abs/1902.03393 KD/vision/TAKD
Relational Knowledge Distillation https://arxiv.org/abs/1904.05068 KD/vision/RKD
Distilling Knowledge from Noisy Teachers https://arxiv.org/abs/1610.09650 KD/vision/noisy
Paying More Attention To The Attention https://arxiv.org/abs/1612.03928 KD/vision/attention
Revisit Knowledge Distillation: a Teacher-free Framework https://arxiv.org/abs/1909.11723 KD/vision/teacher_free
Mean Teachers are Better Role Models https://arxiv.org/abs/1703.01780 KD/vision/mean_teacher
Knowledge Distillation via Route Constrained Optimization https://arxiv.org/abs/1904.09149 KD/vision/RCO
Born Again Neural Networks https://arxiv.org/abs/1805.04770 KD/vision/BANN
Preparing Lessons: Improve Knowledge Distillation with Better Supervision https://arxiv.org/abs/1911.07471 KD/vision/KA
Improving Generalization Robustness with Noisy Collaboration in Knowledge Distillation https://arxiv.org/abs/1910.05057 KD/vision/noisy
Distilling Task-Specific Knowledge from BERT into Simple Neural Networks https://arxiv.org/abs/1903.12136 KD/text/BERT2LSTM
Deep Mutual Learning https://arxiv.org/abs/1706.00384 KD/vision/DML
The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks https://arxiv.org/abs/1803.03635 Pruning/lottery_tickets
Regularizing Class-wise Predictions via Self-knowledge Distillation https://arxiv.org/abs/2003.13964 KD/vision/CSKD

Please cite our pre-print if you find KD-Lib useful in any way :)

@misc{shah2020kdlib,
  title={KD-Lib: A PyTorch library for Knowledge Distillation, Pruning and Quantization}, 
  author={Het Shah and Avishree Khare and Neelay Shah and Khizir Siddiqui},
  year={2020},
  eprint={2011.14691},
  archivePrefix={arXiv},
  primaryClass={cs.LG}
}

kd_lib's People

Contributors

ashwinvaswani, avishreekh, dependabot[bot], het-shah, khizirsiddiqui, neelays


kd_lib's Issues

Update README.rst

Currently, the readme has nothing in it. Add installation instructions for the library and also make a checklist of papers we have currently implemented.

Clean directories

Clear up every implementation. Now that every implementation has been refactored, clean up the files that are no longer necessary.

Benchmarking Pruning and Quantization

We also need to benchmark the Lottery Tickets pruning algorithm and the quantization algorithms. The models used for this would be the student networks discussed in #105 (ResNet18, MobileNet v2, ShuffleNet v2).

Pruning (benchmark up to 40, 50, and 60% pruned weights; a plain magnitude-pruning baseline sketch follows after these lists)

  • Lottery Tickets

Quantization

  • Static
  • QAT
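Neither of the items above is KD_Lib's Lottery Tickets implementation, but as a baseline for the pruning benchmarks, a minimal sketch of plain global magnitude pruning at the target sparsities with torch.nn.utils.prune; student_model stands for any trained torch.nn.Module:

import copy

import torch.nn as nn
import torch.nn.utils.prune as prune

def magnitude_prune(model: nn.Module, amount: float) -> nn.Module:
    # Prune the smallest `amount` fraction of conv/linear weights globally by L1 magnitude.
    params = [(m, "weight") for m in model.modules() if isinstance(m, (nn.Conv2d, nn.Linear))]
    prune.global_unstructured(params, pruning_method=prune.L1Unstructured, amount=amount)
    for module, name in params:
        prune.remove(module, name)  # bake the pruning masks into the weights
    return model

for amount in (0.4, 0.5, 0.6):
    pruned_student = magnitude_prune(copy.deepcopy(student_model), amount)
    # ...evaluate pruned_student here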

Pruning

Description

Pruning is usually the first step taken to compress neural networks.

Modifications

None as of now; we might find some later on.

Optuna

For now, we could add tutorials on using Optuna with KD-Lib; later on, we can integrate it into the core of the library. A rough sketch follows.
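As a starting point for such a tutorial, a hedged sketch that tunes the two learning rates of the VanillaKD example above with Optuna. It assumes train_loader and test_loader as defined in the README example, that fresh models are built per trial (make_teacher and make_student are hypothetical factories), and that distiller.evaluate(teacher=False) returns the student accuracy; these assumptions should be checked against the current API.

import optuna
import torch.optim as optim
from KD_Lib.KD import VanillaKD

def objective(trial):
    teacher_model = make_teacher()    # hypothetical factory functions: build fresh models each trial
    student_model = make_student()

    teacher_lr = trial.suggest_float("teacher_lr", 1e-4, 1e-1, log=True)
    student_lr = trial.suggest_float("student_lr", 1e-4, 1e-1, log=True)

    distiller = VanillaKD(
        teacher_model, student_model, train_loader, test_loader,
        optim.SGD(teacher_model.parameters(), teacher_lr),
        optim.SGD(student_model.parameters(), student_lr),
    )
    distiller.train_teacher(epochs=1, plot_losses=False, save_model=False)
    distiller.train_student(epochs=1, plot_losses=False, save_model=False)
    return distiller.evaluate(teacher=False)   # assumed to return the student accuracy

study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=20)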

Paper: Revisit Knowledge Distillation: A Teacher Free Framework

Description

Proposes that one of the major reasons KD works is because of the presence of soft logits for the student to learn from instead of hard labels. Builds upon this to come up with two teacher-free frameworks - 

1. Virtual teacher - Generates soft labels during student training, with a high probability (~0.9) given to the correct class and the rest distributed uniformly amongst the remaining classes (see the sketch below).

2. Self training - Trains a copy of the student itself to be the teacher, and uses this copy to perform regular KD on the student.
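For illustration, a minimal sketch of how the virtual teacher's soft labels could be generated (a hypothetical helper, not KD_Lib's implementation), with ~0.9 on the correct class and the remainder spread uniformly:

import torch

def virtual_teacher_labels(labels, num_classes, correct_prob=0.9):
    # Probability `correct_prob` for the true class; the rest is shared uniformly by the other classes.
    soft = torch.full((labels.size(0), num_classes), (1.0 - correct_prob) / (num_classes - 1))
    soft.scatter_(1, labels.unsqueeze(1), correct_prob)
    return soft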

Paper: Subclass Distillation

Description

The teacher is made to divide each class into many subclasses that it invents during supervised training. The student is then trained to match the subclass probabilities. The intuition is that this provides more fine-grained knowledge and helps the student learn faster.
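A hedged sketch of the core idea, assuming (as in the paper) that a class's probability is the sum of its subclass probabilities; the shapes and names below are illustrative only:

import torch.nn.functional as F

def subclass_to_class_probs(subclass_logits, num_classes, subclasses_per_class):
    # The teacher emits one logit per (class, subclass) pair; summing over subclasses recovers class probabilities.
    p = F.softmax(subclass_logits, dim=1)                               # (N, num_classes * subclasses_per_class)
    return p.view(-1, num_classes, subclasses_per_class).sum(dim=2)     # (N, num_classes)

# The student would be trained to match the full subclass distribution (e.g. with a KL term on the
# (N, num_classes * subclasses_per_class) probabilities), not just the folded class probabilities.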

[Paper] Regularizing Class-wise Predictions via Self-knowledge Distillation

Description

Deep neural networks with millions of parameters may suffer from poor generalization due to overfitting. To mitigate the issue, the authors propose a new regularization method that penalizes the predictive distribution between similar samples. In particular, they distill the predictive distribution between different samples of the same label during training. This results in regularizing the dark knowledge (i.e., the knowledge on wrong predictions) of a single network (i.e., a self-knowledge distillation) by forcing it to produce more meaningful and consistent predictions in a class-wise manner. Consequently, it mitigates overconfident predictions and reduces intra-class variations. Experimental results on various image classification tasks demonstrate that the simple yet powerful method can significantly improve not only the generalization ability but also the calibration performance of modern convolutional neural networks.
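A hedged sketch of the class-wise regularizer (a hypothetical helper; the temperature value and the detaching of the second sample are assumptions, not necessarily what KD/vision/CSKD does):

import torch.nn.functional as F

def cskd_term(logits_a, logits_b, temperature=4.0):
    # Push sample A's predictive distribution toward that of another sample B with the same label.
    # B's logits are detached so the regularization gradient only flows through A's branch.
    p_b = F.softmax(logits_b.detach() / temperature, dim=1)
    log_p_a = F.log_softmax(logits_a / temperature, dim=1)
    return F.kl_div(log_p_a, p_b, reduction="batchmean") * (temperature ** 2)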

Pip install "stable" doesn't work

Running pip install KD-Lib does not correctly install the library and leads to import errors (tested on Colab here).

Installing manually from source does work (as #82 mentions); maybe the README should be updated to state that only the manual install works?

Regardless, thanks for the cool library!

Evaluators missing in models

Some of the modules, such as TAKD, mean teacher, etc., are missing functions for evaluation and validation.

Edit: Did not see base class implementations for it. Closing the issue.

Restructure text-based models and utilities

I think we should have a separate sub-directory for text models and associated utilities. What do you think @Het-Shah?

We could look at these tasks, to begin with:

  • Add train, evaluate and distill functions for different models to text.utils
  • Move LSTM class to text.models
  • Move Bert2LSTM to text (after correcting the existing mistakes)

Benchmarking KD

We need to benchmark the following algorithms on three datasets (MNIST, CIFAR10, CIFAR100). This is so that we are sure that our implementations are fairly accurate on most datasets.

We also need to ensure that the distillation works with a variety of student networks. @Het-Shah has suggested that we report results on ResNet18, MobileNet v2 and ShuffleNet v2 as student networks. ResNet50 can be the teacher network for all the distillations.

  • VanillaKD
  • TAKD
  • Noisy Teacher
  • Attention
  • BANN
  • Bert2lstm
  • RCO
  • Messy Collab
  • Soft Random
  • CSKD
  • DML
  • Self-training
  • Virtual Teacher
  • RKD Loss
  • KA/ProbShift
  • KA/LabelSmoothReg

If you wish to work on any of the above algorithms, just mention the algorithms in the discussion.

Documentation

Description

Undocumented parts:

  • Parameters to init of BaseClass are undocumented.
  • calculate_kd_loss() for every class other than BaseClass.
  • NoisyTeacher, original, and attention all require params in addition to the BaseClass params.
  • Citations to papers in class definitions (?)
  • ResNet and other models

This is obviously not a priority; we can do these in between other new paper implementations.

Implement Knowledge distillation by Functional Mapping

Description

  • The paper focuses on two important aspects of knowledge distillation: consistency and patience.
  • For consistency ("function matching"), the authors argue that knowledge distillation should not just match predictions on the target data; the support of the data distribution should be widened. They use mixup augmentation (or out-of-domain data) to interpolate between data points, so that the student matches the teacher's function across the data distribution rather than only at the training samples (see the mixup sketch below).
  • The other component of the knowledge distillation training recipe is patience: knowledge distillation benefits from long training schedules.
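As a reference for the mixup component mentioned above, a generic mixup sketch (not the paper's full function-matching recipe); both teacher and student would then be evaluated on the same mixed inputs and matched with a distillation loss:

import torch

def mixup(x, alpha=1.0):
    # Blend each input with a randomly chosen partner from the same batch.
    lam = torch.distributions.Beta(alpha, alpha).sample()
    perm = torch.randperm(x.size(0))
    return lam * x + (1 - lam) * x[perm]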

Dynamic Quantization

Description

Reduce the size of a model after training so that the specified layers store their weights as INT8 instead of floating point.
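For reference, PyTorch's built-in post-training dynamic quantization already does this for the listed layer types; a generic sketch (not necessarily how it would be wrapped in KD_Lib), where trained_model stands for any trained torch.nn.Module:

import torch
import torch.nn as nn

# Weights of the listed layer types are stored as INT8 and dequantized on the fly at inference time;
# activations stay in floating point.
quantized_model = torch.quantization.quantize_dynamic(
    trained_model, {nn.Linear, nn.LSTM}, dtype=torch.qint8
)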

Update setup.py

The requirements list that is passed into install_requires in setup() is empty.
It should be populated from requirements_dev.txt, shouldn't it?
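One possible way to do this (a sketch only, trimmed to the essentials; the metadata fields would come from the existing setup.py):

from setuptools import find_packages, setup

# Read the pinned requirements into install_requires, skipping blanks and comments.
with open("requirements_dev.txt") as f:
    requirements = [line.strip() for line in f if line.strip() and not line.startswith("#")]

setup(
    name="KD-Lib",
    packages=find_packages(),
    install_requires=requirements,
)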

Born Again Neural Networks

Description

So far, most of the papers focus on transferring knowledge from models with more parameters to models with fewer parameters. The experiments in this paper train student models with identical parameter counts and, surprisingly, they outperform the teacher models.
It also introduces two distillation objectives: (i) Confidence-Weighted by Teacher Max (CWTM) and (ii) Dark Knowledge with Permuted Predictions (DKPP).

This was a pretty expensive experiment, roughly USD 50k (https://danieltakeshi.github.io/2018/05/27/bann/); I am not sure if we will be able to reproduce the results in the paper.

NameError: name 'best_student_id' is not defined

I am working with the demo script for DML given in the README and get the following error after training finishes:
NameError: name 'best_student_id' is not defined

This is the code for the models and training:

student_model_1 = Shallow(img_size=28, hidden_size=100, num_classes=10, num_channels=1)
student_model_2 = Shallow(img_size=28, hidden_size=100, num_classes=10, num_channels=1)

student_cohort = [student_model_1, student_model_2]

student_optimizer_1 = optim.SGD(student_model_1.parameters(), 0.01)
student_optimizer_2 = optim.SGD(student_model_2.parameters(), 0.01)

student_optimizers = [student_optimizer_1, student_optimizer_2]

distiller = DML(student_cohort, train_loader, test_loader, student_optimizers, log=True, logdir="./Logs")
distiller.train_students(epochs=5)

Output:

Training students...
Accuracy: 0.9131
Accuracy: 0.9152
Epoch: 1, Loss: 1218.8985595703125, Accuracy: 0.87275
Accuracy: 0.9379
Accuracy: 0.9368
Epoch: 2, Loss: 567.33447265625, Accuracy: 0.9462166666666667
Accuracy: 0.9466
Accuracy: 0.9448
Epoch: 3, Loss: 459.7298889160156, Accuracy: 0.9591333333333333
Accuracy: 0.9526
Accuracy: 0.9544
Epoch: 4, Loss: 400.6322021484375, Accuracy: 0.9670666666666666
Accuracy: 0.957
Accuracy: 0.9584
Epoch: 5, Loss: 359.12127685546875, Accuracy: 0.9721666666666666

---------------------------------------------------------------------------

NameError                                 Traceback (most recent call last)

<ipython-input-12-e08730181774> in <module>()
----> 1 distiller.train_students(epochs=5)

/content/KD_Lib/KD_Lib/KD/vision/DML/dml.py in train_students(self, epochs, plot_losses, save_model, save_model_path)
    138         if save_model:
    139             print(
--> 140                 f"The best student model is the model number {best_student_id+1} in the cohort"
    141             )
    142             torch.save(self.best_student.state_dict(), save_model_path)

NameError: name 'best_student_id' is not defined

RuntimeError: only batches of spatial targets supported (3D tensors) but got targets of dimension: 4

Traceback (most recent call last):
File "tools/train_kd_lib.py", line 98, in
distiller.train_teacher(epochs=5, plot_losses=True, save_model=True) # Train the teacher network
File "/home/xxx/anaconda3/envs/tp/lib/python3.8/site-packages/KD_Lib/KD/common/base_class.py", line 119, in train_teacher
loss = self.ce_fn(out, label)
File "/home/xxx/anaconda3/envs/tp/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/xxx/anaconda3/envs/tp/lib/python3.8/site-packages/torch/nn/modules/loss.py", line 1047, in forward
return F.cross_entropy(input, target, weight=self.weight,
File "/home/xxx/anaconda3/envs/tp/lib/python3.8/site-packages/torch/nn/functional.py", line 2693, in cross_entropy
return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
File "/home/xxx/anaconda3/envs/tp/lib/python3.8/site-packages/torch/nn/functional.py", line 2390, in nll_loss
ret = torch._C._nn.nll_loss2d(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
RuntimeError: only batches of spatial targets supported (3D tensors) but got targets of dimension: 4

Paper: Paying More Attention To The Attention - Improving the Performance of CNNs via Attention Transfer

Description

The paper proposes a mechanism to transfer knowledge in the form of attention (both gradient-based and activation-based spatial attention maps). The approach provides a significant increase in the performance of ResNets and non-ResNets.

Modifications

Not yet any? 

I have only read the abstract so far; just to confirm, this paper lies within the KD domain, right? It seems like one of the approaches that can be combined with other KD techniques, for example using the relational KD loss function along with, maybe, teacher-assisted KD.
If this is a completely different topic, please close this issue.

Restructuring KD_Lib

From the discussion in #57, should I start with the restructuring?
Also, the KD algorithms used for image classification in the papers could be applied to other classification tasks, such as tabular classification, as long as dataloaders are used. So should we put them under a vision folder, or would a different name be better?

Relational Knowledge Distillation

Description

RKD transfers mutual relations of data examples instead. The authors propose distance-wise and angle-wise distillation losses that penalize structural differences in relations. Excitingly, in metric learning, it allows students to outperform their teachers, achieving state-of-the-art results on standard benchmark datasets.

Modifications

Can't think of any presently. Will update if I do.

Making a pipeline for Pruning, Quantization and Knowledge Distillation

Currently, the user has to import and run all three things independently. Having a pipeline will streamline the entire workflow for the end user.

The user should be able to add pipelines like [KD, Pruning] or [KD, Quantization].

We can discuss how we want to go about doing this; a rough sketch of one possible interface is below.
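Purely as a discussion starter, a hypothetical sketch of what such an interface could look like (none of these names exist in KD_Lib today):

class Pipeline:
    """Run compression stages one after another on the same student model."""

    def __init__(self, stages):
        self.stages = stages  # e.g. [distiller, pruner] or [distiller, quantizer]

    def run(self, epochs=5):
        for stage in self.stages:
            stage.run(epochs=epochs)  # each stage would expose a common run() method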

DML Loss function

  • Paper: Deep Mutual Learning
  • Paper Link:

Description

Hi. According to the paper, the total loss objective should be the sum of the cross-entropy and the Kullback-Leibler divergence, but I found that you have used MSE loss. Is there a specific reason for this?
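For reference, a sketch of the objective as described in the paper (cross-entropy plus the average KL divergence toward each peer's predictions); this is not what the library currently implements:

import torch.nn.functional as F

def dml_loss(own_logits, peer_logits_list, labels):
    # Supervised loss plus mimicry of each peer in the cohort, as in the Deep Mutual Learning paper.
    loss = F.cross_entropy(own_logits, labels)
    for peer_logits in peer_logits_list:
        loss = loss + F.kl_div(
            F.log_softmax(own_logits, dim=1),
            F.softmax(peer_logits.detach(), dim=1),
            reduction="batchmean",
        ) / len(peer_logits_list)
    return loss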
Modifications

Device parameter of Base class

The Base class has a parameter device for choosing between CPU and GPU. However, there are a few issues with respect to this:

  1. There is no check for the validity of the device specified. For example, I could specify a device "kd" which does not exist.
  2. There is no check for GPU availability when "cuda" is specified as a parameter. The class should default to the CPU when CUDA is not available (see the sketch below).
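A small hedged sketch of the checks being asked for (a hypothetical helper, not current KD_Lib code):

import torch

def resolve_device(device):
    # torch.device() rejects invalid strings such as "kd"; fall back to CPU when CUDA is unavailable.
    device = torch.device(device)
    if device.type == "cuda" and not torch.cuda.is_available():
        print("CUDA requested but not available, falling back to CPU")
        return torch.device("cpu")
    return device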

Distributed Training

We need to add support for distributed training; for now, we can directly make use of PyTorch DDP. Let me know if anyone wants to take this up.

import error

Hello
an error occurs when I run the basic version of the code:
from KD_Lib import DML
from KD_Lib import ResNet18, ResNet50
from KD_Lib import VanillaKD

ImportError: cannot import name 'VanillaKD' from 'KD_Lib' (/usr/local/lib/python3.7/dist-packages/KD_Lib/__init__.py)
How can I handle this error?

thanks
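For reference, the README examples above import from the subpackages rather than the top-level package, so (once the library is installed from source) the following imports should work:

from KD_Lib.KD import VanillaKD, DML
from KD_Lib.models import ResNet18, ResNet50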
