KnowledgeDistillation Layer (Caffe implementation)

This is a CPU implementation of knowledge distillation in Caffe.
This code is heavily based on softmax_loss_layer.hpp and softmax_loss_layer.cpp.

Please refer to the paper

Hinton, G., Vinyals, O., and Dean, J. Distilling the Knowledge in a Neural Network. arXiv:1503.02531, 2015.

Installation

  1. Install Caffe in your directory $CAFFE
  2. Clone this repository into your directory $ROOT
cd $ROOT
git clone https://github.com/wentianli/knowledge_distillation_caffe.git
  3. Copy the layer files into your Caffe folder
cp $ROOT/knowledge_distillation_layer.hpp $CAFFE/include/caffe/layers
cp $ROOT/knowledge_distillation_layer.cpp $CAFFE/src/caffe/layers
  4. Modify $CAFFE/src/caffe/proto/caffe.proto
    Add an optional KnowledgeDistillationParameter field to LayerParameter:
message LayerParameter {
  ...

  // use the next available layer-specific ID (147 here; make sure it is unused in your caffe.proto)
  optional KnowledgeDistillationParameter knowledge_distillation_param = 147;
}


    Add the KnowledgeDistillationParameter message:

message KnowledgeDistillationParameter {
  optional float temperature = 1 [default = 1];
}
  5. Rebuild Caffe
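
Because caffe.proto has changed, the protobuf sources must be regenerated before compiling; the standard Makefile build does this automatically. A minimal sketch, assuming the usual Makefile setup (adjust -j to your machine):

cd $CAFFE
make all -j8    # regenerates caffe.pb.h/caffe.pb.cc from the modified proto and compiles the new layer
make test -j8   # optional: build the test binaries
make runtest    # optional: run the tests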

Usage

The KnowledgeDistillation layer has one layer-specific parameter, temperature.

The layer takes 2 or 3 input blobs:
bottom[0]: the logits of the student
bottom[1]: the logits of the teacher
bottom[2] (optional): the labels
The logits are first divided by the temperature T, then mapped to probability distributions over classes with the softmax function. The layer computes the KL divergence between the two distributions instead of the cross entropy. The gradients are multiplied by T^2, as suggested in the paper.
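
Concretely, with student logits $z^s$, teacher logits $z^t$, and temperature $T$ (notation introduced here for illustration, not taken from the code), the layer computes

\[
p_i = \frac{\exp(z_i^s/T)}{\sum_j \exp(z_j^s/T)}, \qquad
q_i = \frac{\exp(z_i^t/T)}{\sum_j \exp(z_j^t/T)}
\]
\[
\mathcal{L} = \mathrm{KL}(q \,\|\, p) = \sum_i q_i \log \frac{q_i}{p_i}, \qquad
\frac{\partial \mathcal{L}}{\partial z_i^s} = \frac{1}{T}\,(p_i - q_i)
\]

Multiplying the gradient by $T^2$ yields $T\,(p_i - q_i)$, which keeps the soft-target gradients on roughly the same scale as hard-label gradients when the two losses are combined.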

  1. Common setting in prototxt (2 input blobs are given)
layer {
  name: "KD"
  type: "KnowledgeDistillation"
  bottom: "student_logits"
  bottom: "taecher_logits"
  top: "KL_div"
  include { phase: TRAIN }
  knowledge_distillation_param { temperature: 4 } # usually larger than 1
  loss_weight: 1
}
  2. If you use ignore_label, 3 input blobs should be given
layer {
  name: "KD"
  type: "KnowledgeDistillation"
  bottom: "student_logits"
  bottom: "taecher_logits"
  bottom: "label"
  top: "KL_div"
  include { phase: TRAIN }
  knowledge_distillation_param { temperature: 4 }
  loss_param { ignore_label: 2 }
  loss_weight: 1
}
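
In practice, the paper trains the student on a weighted combination of the distillation loss and the ordinary cross entropy on the true labels. A minimal sketch of such a companion loss layer; the blob names and the weight 0.5 are illustrative, not part of this repo:
layer {
  name: "CE"
  type: "SoftmaxWithLoss"
  bottom: "student_logits"
  bottom: "label"
  top: "CE_loss"
  include { phase: TRAIN }
  loss_weight: 0.5 # the paper suggests a considerably lower weight on the hard-label objective
}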
