alphadl / knowledge-distillation

This project is forked from tejasgodambe/knowledge-distillation


Transfer knowledge from a large DNN or an ensemble of DNNs into a small DNN

License: GNU General Public License v3.0

Languages: Python 16.27%, Shell 22.83%, Perl 60.89%

knowledge-distillation's Introduction

Knowledge distillation

This code is an implementation of the paper "Distilling the Knowledge in a Neural Network" (https://arxiv.org/pdf/1503.02531.pdf).

Abstract

A very simple way to improve the performance of almost any machine learning algorithm is to train many different models on the same data and then to average their predictions [3]. Unfortunately, making predictions using a whole ensemble of models is cumbersome and may be too computationally expensive to allow deployment to a large number of users, especially if the individual models are large neural nets. Caruana and his collaborators [1] have shown that it is possible to compress the knowledge in an ensemble into a single model which is much easier to deploy and we develop this approach further using a different compression technique. We achieve some surprising results on MNIST and we show that we can significantly improve the acoustic model of a heavily used commercial system by distilling the knowledge in an ensemble of models into a single model. We also introduce a new type of ensemble composed of one or more full models and many specialist models which learn to distinguish fine-grained classes that the full models confuse. Unlike a mixture of experts, these specialist models can be trained rapidly and in parallel.
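
The core quantity behind this recipe is the teacher's output distribution softened by a temperature before it is used as a training target for the student. Below is a minimal NumPy sketch of that temperature softmax; the function name is illustrative and not necessarily the one used in softmax_with_temp.py (described under Scripts).

```python
import numpy as np

def softmax_with_temperature(logits, T=1.0):
    """Softmax over the last axis with the logits divided by temperature T.

    T > 1 gives a softer distribution that exposes the teacher's relative
    probabilities for the wrong classes ("dark knowledge")."""
    z = np.asarray(logits, dtype=np.float64) / T
    z -= z.max(axis=-1, keepdims=True)          # for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

# The same logits at two temperatures: T=1 is close to a hard label,
# T=5 keeps useful information about the non-target classes.
logits = np.array([4.0, 1.0, 0.1])
print(softmax_with_temperature(logits, T=1.0))
print(softmax_with_temperature(logits, T=5.0))
```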

Scripts

  1. train_student_DNN.py
    This Python script is the main script. In it, the user has to provide the input directories, such as the alignments, HMM, and data folders. The output directory is where the weights after each epoch and the final weights (in both hdf5 and txt formats) are saved. This script also defines the DNN architecture and sets the DNN configuration parameters.

  2. dataGenerator_teacher.py
    This Python script is the generator that provides batches to Keras' fit_generator while training the teacher DNN.

  3. dataGenerator_student.py
    This Python script is the generator that provides batches to Keras' fit_generator while training the student DNN.

  4. custom_crossentropy.py
    This Python script contains the custom cross-entropy loss used for training the student DNN (see the sketch after this list).

  5. softmax_with_temp.py
    This Python script contains the implementation of the softmax function with a temperature parameter.

  6. saveModel.py
    This Python script converts DNN weights from hdf5 to txt format.
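
Putting these pieces together, the student is trained against both the teacher's temperature-softened posteriors and the hard ground-truth labels. The sketch below shows one way to wire this up in Keras, assuming the generator concatenates the soft and hard targets along the label axis. All names, layer sizes, and hyperparameter values are illustrative, and for simplicity the same temperature-softened output feeds both loss terms, so this is not necessarily how train_student_DNN.py, custom_crossentropy.py, or dataGenerator_student.py implement it.

```python
from tensorflow import keras
import tensorflow.keras.backend as K
import numpy as np

NUM_CLASSES = 1024   # e.g. number of senone targets; dataset-specific
TEMPERATURE = 2.0    # distillation temperature (a tunable hyperparameter)
ALPHA = 0.5          # weight on the soft-target term of the loss

def softmax_with_temperature(logits):
    # Softmax over the logits divided by the temperature.
    return K.softmax(logits / TEMPERATURE)

def kd_crossentropy(y_true, y_pred):
    """Combined distillation loss.

    y_true is assumed to hold [teacher soft targets | one-hot hard labels],
    i.e. 2 * NUM_CLASSES columns; y_pred are the student's
    temperature-softened probabilities."""
    soft_targets = y_true[:, :NUM_CLASSES]
    hard_labels = y_true[:, NUM_CLASSES:]
    soft_loss = keras.losses.categorical_crossentropy(soft_targets, y_pred)
    hard_loss = keras.losses.categorical_crossentropy(hard_labels, y_pred)
    # Hinton et al. scale the soft term by T^2 so its gradients stay
    # comparable in magnitude to those of the hard-label term.
    return ALPHA * (TEMPERATURE ** 2) * soft_loss + (1.0 - ALPHA) * hard_loss

# A small fully connected student; the real architecture and feature
# dimensionality are set in train_student_DNN.py.
student = keras.Sequential([
    keras.layers.Dense(1024, activation='relu', input_shape=(440,)),
    keras.layers.Dense(1024, activation='relu'),
    keras.layers.Dense(NUM_CLASSES),
    keras.layers.Activation(softmax_with_temperature),
])
student.compile(optimizer='adam', loss=kd_crossentropy)

def student_batch_generator(feats, teacher_posteriors, hard_labels, batch_size=256):
    """Yields (features, [soft | hard] targets) batches indefinitely, in the
    style expected by Keras' fit_generator; a stand-in for dataGenerator_student.py."""
    n = len(feats)
    while True:
        order = np.random.permutation(n)
        for start in range(0, n - batch_size + 1, batch_size):
            batch = order[start:start + batch_size]
            targets = np.hstack([teacher_posteriors[batch], hard_labels[batch]])
            yield feats[batch], targets
```

Training would then call student.fit_generator (or fit with a generator in newer Keras versions) on batches from the student generator; at test time the temperature is set back to 1.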

knowledge-distillation's People

Contributors

tejasgodambe

Watchers

Liang Ding, paper2code-bot
