
aberhu / knowledge-distillation-zoo

Pytorch implementation of various Knowledge Distillation (KD) methods.

Python 89.48% Shell 10.52%
distillation kd kd-methods knowledge-distillation knowledge-transfer model-compression teacher-student

knowledge-distillation-zoo's People

Contributors

aberhu


knowledge-distillation-zoo's Issues

About the temperature parameter in the AT method

I noticed that in the original paper and code of AT, the method AT+KD has a temperature parameter. However, it is not present in this code repository. Could you please clarify if this absence will have any impact?
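For reference, a rough sketch of what the AT+KD objective from the original paper looks like (a temperature-scaled KD term added on top of the attention-transfer term); the weights and names below are illustrative, not taken from this repository:

import torch.nn.functional as F

def at_kd_loss(out_s, out_t, feats_s, feats_t, target, T=4.0, alpha=0.9, beta=1000.0):
	# cross-entropy on the student's logits
	cls_loss = F.cross_entropy(out_s, target)
	# temperature-scaled KD term: this is where T enters in AT+KD
	kd_loss = F.kl_div(F.log_softmax(out_s / T, dim=1),
	                   F.softmax(out_t / T, dim=1),
	                   reduction='batchmean') * T * T
	# attention transfer: match L2-normalized spatial attention maps at matched stages
	def attention(fm):                        # fm: (N, C, H, W)
		return F.normalize(fm.pow(2).mean(dim=1).flatten(1))
	at_loss = sum((attention(fs) - attention(ft)).pow(2).mean()
	              for fs, ft in zip(feats_s, feats_t))
	return cls_loss + alpha * kd_loss + beta * at_loss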

About CRD Loss

Thanks for sharing. About the CRD loss: why is the student/teacher anchor (memory bank) initialized as a random constant? Hope you can reply. Thanks.
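For context, the usual CRD implementation keeps one memory bank of embeddings per network; the banks start as random values (they are only placeholders) and each slot is then overwritten with a momentum-averaged copy of the real embedding whenever its sample appears in a batch. A rough sketch of that pattern, with illustrative names:

import math
import torch
import torch.nn.functional as F

class CRDMemory:
	# two memory banks (student and teacher embeddings), randomly initialized
	# and updated in place with a momentum rule as real features come in
	def __init__(self, feat_dim, n_data, momentum=0.5):
		stdv = 1.0 / math.sqrt(feat_dim / 3.0)
		self.memory_s = torch.rand(n_data, feat_dim) * 2 * stdv - stdv
		self.memory_t = torch.rand(n_data, feat_dim) * 2 * stdv - stdv
		self.momentum = momentum

	def update(self, feat_s, feat_t, idx):
		# idx: dataset indices of the current batch
		for mem, feat in ((self.memory_s, feat_s), (self.memory_t, feat_t)):
			slot = mem[idx] * self.momentum + feat.detach() * (1 - self.momentum)
			mem[idx] = F.normalize(slot, dim=1)

So the random constants only matter for negatives whose slots have not been refreshed yet; they are gradually replaced by real embeddings as training progresses.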

nan values when training student network (logits)

When running the logits KD method, I received nan values. I instrumented the class below and realized that the student network didn't learn at all.
Update: this repo is partially based on this implementation; its original author also used log_softmax.

import torch.nn as nn
import torch.nn.functional as F

class Logits(nn.Module):
	'''
	Do Deep Nets Really Need to be Deep?
	http://papers.nips.cc/paper/5484-do-deep-nets-really-need-to-be-deep.pdf
	'''
	def __init__(self):
		super(Logits, self).__init__()

	def forward(self, out_s, out_t):
		# MSE between raw student and teacher logits
		loss = F.mse_loss(out_s, out_t)
		# I printed out_s here: the values quickly diverged to nan
		return loss

Figure: the tensors from out_s quickly diverged to nan.

If anyone has had a similar issue, please let me know.
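In case it helps someone debugging the same thing, a quick way to catch where the first nan appears is to assert on the tensors right after the forward pass (a generic debugging snippet, not code from this repo):

import torch

def check_finite(name, tensor):
	# stop as soon as the first nan/inf shows up instead of training through it
	if not torch.isfinite(tensor).all():
		raise RuntimeError(f"{name} contains nan/inf values")

# inside the training loop, right after the student/teacher forward passes:
# check_finite("out_s", out_s)
# check_finite("out_t", out_t)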

Result comparison

Thank you for the open-source project.
In the results comparison, I can't understand why the resnet20 (teacher) to resnet20 (student) setting performs so much better than the other distillation settings. My guess is that when the teacher model is huge, the student can't learn much more information.

Question on FitNet (Hint)

Thanks for your code!
From what I know, FitNet uses both the "hints" loss and the KD loss, but it seems like the criterion here only uses the hint module as the criterion.

Is there something I might be missing? Thanks!
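For reference, my understanding of the original FitNet recipe is two stages: first the lower part of the student is trained with the hint (regression) loss only, then the whole student is trained with cross-entropy plus temperature-scaled KD. A rough sketch of the two objectives, with illustrative names:

import torch.nn.functional as F

def fitnet_stage1_loss(hint_s, hint_t, regressor):
	# stage 1: regress the (projected) student hint features onto the teacher's guided features
	return F.mse_loss(regressor(hint_s), hint_t)

def fitnet_stage2_loss(out_s, out_t, target, T=4.0, lam=0.9):
	# stage 2: train the full student with cross-entropy plus temperature-scaled KD
	ce = F.cross_entropy(out_s, target)
	kd = F.kl_div(F.log_softmax(out_s / T, dim=1),
	              F.softmax(out_t / T, dim=1),
	              reduction='batchmean') * T * T
	return (1 - lam) * ce + lam * kd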

About lambda_st

Thank you very much for your project; it has helped me a lot. But I have some doubts about train_st.py.

cls_loss = criterionCls(output_s, target)
st_loss = criterionST(F.log_softmax(output_s/args.T, dim=1),
                      F.softmax(output_t/args.T, dim=1)) * (args.T*args.T) / img.size(0)
st_loss = st_loss * args.lambda_st
loss = cls_loss + st_loss

Here lambda_st is similar to a relaxation factor. But I see the following formula elsewhere.

cls_loss = criterionCls(output_s, target)
st_loss = criterionST(F.log_softmax(output_s/args.T, dim=1),
                      F.softmax(output_t/args.T, dim=1)) * (args.T*args.T) / img.size(0)

loss = (1 - args.lambda_st) * cls_loss + st_loss * args.lambda_st

Could you tell me the difference in impact between the two?

https://github.com/PolarisShi/distillation/blob/8290b4b2138410cb4f1f39ecf2411f7dbb2a676a/student2.py#L218-L223
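For what it's worth, the two weightings only differ by an overall scale: dividing the second form, (1 - lambda_st) * cls_loss + lambda_st * st_loss, by (1 - lambda_st) gives cls_loss + (lambda_st / (1 - lambda_st)) * st_loss, i.e. the first form with a rescaled weight (and a rescaled effective learning rate). A toy numerical check with made-up loss values:

# toy check that the two weightings agree up to a global scale (hypothetical numbers)
lam = 0.7
cls_loss, st_loss = 1.3, 0.4

form_a = cls_loss + (lam / (1 - lam)) * st_loss       # "relaxation factor" form
form_b = (1 - lam) * cls_loss + lam * st_loss         # convex-combination form

assert abs(form_b - (1 - lam) * form_a) < 1e-9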

Additional Log in the softmax function

Hi,

I just have a question: why did you use an additional log before the softmax function in the st.py loss on one input but not the other?

def forward(self, out_s, out_t):
	# KD loss: KL divergence between the softened student and teacher
	# distributions, scaled by T^2
	loss = F.kl_div(F.log_softmax(out_s/self.T, dim=1),
	                F.softmax(out_t/self.T, dim=1),
	                reduction='batchmean') * self.T * self.T

	return loss
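For what it's worth, this matches PyTorch's convention rather than anything method-specific: F.kl_div expects its first argument already in log space and its second as plain probabilities. A tiny standalone check of that convention (not code from this repo):

import torch
import torch.nn.functional as F

p_s = F.softmax(torch.randn(2, 5), dim=1)   # stand-in "student" probabilities
p_t = F.softmax(torch.randn(2, 5), dim=1)   # stand-in "teacher" probabilities

# F.kl_div(input, target) treats `input` as log-probabilities and `target` as probabilities
a = F.kl_div(p_s.log(), p_t, reduction='batchmean')
b = (p_t * (p_t.log() - p_s.log())).sum(dim=1).mean()
print(torch.allclose(a, b))  # True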

How to use CRD loss for training on large scale dataset

Hi @AberHu ,

Thank you very much for your work.

I am working on using the CRD loss to train my face recognition model, since your work and the original paper show that CRD loss combined with KL-divergence distillation is currently better than the other methods. But I found that it needs two memory buffers, which makes it infeasible when the dataset is really huge. So I wonder if there is a lighter way to implement this.

Hope for your reply. Thanks.
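For a rough sense of why the buffers blow up (back-of-the-envelope numbers, not measurements): CRD keeps one embedding per training sample per network, so the two memory banks grow linearly with the dataset size:

# rough memory estimate for CRD's two memory banks in fp32 (hypothetical dataset size)
n_samples = 5_000_000     # e.g. a large face-recognition training set
feat_dim = 128            # a typical CRD embedding dimension
bytes_per_buffer = n_samples * feat_dim * 4
print(f"per buffer: {bytes_per_buffer / 2**30:.2f} GiB, "
      f"both buffers: {2 * bytes_per_buffer / 2**30:.2f} GiB")
# per buffer: ~2.4 GiB, both buffers: ~4.8 GiB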

About WRN

Could you provide the code for WRN (Wide ResNet)? Thanks.
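In the meantime, here is a minimal sketch of a Wide ResNet basic block in the usual pre-activation BN-ReLU-Conv style; this is a generic implementation for reference, not the network used for the reported results:

import torch.nn as nn
import torch.nn.functional as F

class WideBasic(nn.Module):
	# pre-activation basic block: (BN-ReLU-Conv3x3) twice, plus a projection shortcut when shapes change
	def __init__(self, in_planes, planes, stride=1):
		super(WideBasic, self).__init__()
		self.bn1 = nn.BatchNorm2d(in_planes)
		self.conv1 = nn.Conv2d(in_planes, planes, 3, stride=stride, padding=1, bias=False)
		self.bn2 = nn.BatchNorm2d(planes)
		self.conv2 = nn.Conv2d(planes, planes, 3, stride=1, padding=1, bias=False)
		self.shortcut = (nn.Conv2d(in_planes, planes, 1, stride=stride, bias=False)
		                 if stride != 1 or in_planes != planes else nn.Identity())

	def forward(self, x):
		out = self.conv1(F.relu(self.bn1(x)))
		out = self.conv2(F.relu(self.bn2(out)))
		return out + self.shortcut(x)

WRN-28-10 stacks (28 - 4) / 6 = 4 such blocks per stage, with stage widths 160, 320, and 640 after an initial 16-channel convolution.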

train_init

What does the function "train_init" do? Why is there one training function "train_init" and another training function "train"?
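Not speaking for the author, but a common pattern in KD codebases is that methods such as FitNet, FSP, or AB need an initialization stage in which the student's intermediate features are fitted to the teacher's before normal training starts; "train_init" usually implements that first stage and "train" the second. A rough, illustrative sketch of that pattern (not taken from this repo):

import torch

# illustrative pseudocode of a typical train_init / train split (not this repo's code)
def train_init(loader, net_s, net_t, hint_criterion, optimizer):
	# stage 1: fit the student's intermediate features to the teacher's,
	# ignoring the classification objective entirely
	for img, _ in loader:
		feat_s, _ = net_s(img)                 # assumes the net returns (features, logits)
		with torch.no_grad():
			feat_t, _ = net_t(img)
		loss = hint_criterion(feat_s, feat_t)
		optimizer.zero_grad(); loss.backward(); optimizer.step()

def train(loader, net_s, net_t, criterionCls, criterionKD, optimizer):
	# stage 2: ordinary training with cross-entropy plus the distillation term
	for img, target in loader:
		feat_s, out_s = net_s(img)
		with torch.no_grad():
			feat_t, out_t = net_t(img)
		loss = criterionCls(out_s, target) + criterionKD(feat_s, feat_t)
		optimizer.zero_grad(); loss.backward(); optimizer.step()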

Implementation of the methods

Hi, thanks for your contribution. I am studying knowledge distillation at the moment; would you mind if I created a pull request adding some other competitive distillation methods implemented in PaddlePaddle, TensorFlow, or other frameworks?

Thanks again for your good work on knowledge distillation!

About DML.

Has anyone reproduced the result of WRN-28-10 in DML? I found that the accuracy is 80.75 in the original paper, while in DML it is 78.69.
