kaiyuyue / mgd

Matching Guided Distillation (ECCV 2020)

Home Page: https://kaiyuyue.com/mgd

License: MIT License

Matching Guided Distillation

Project Webpage | Paper | Zhihu Blog [知乎]

Introduction

This implementation is based on the official PyTorch ImageNet training code and supports two training modes: DataParallel (DP) and DistributedDataParallel (DDP). MGD for object detection is also re-implemented in Detectron2 as an external project.

[Figure: introfig]

Note: T: teacher feature tensors. S: student feature tensors. d_p: distance function for distillation. C_i: i-th channel.

BibTeX

@inproceedings{eccv20mgd,
    title     = {Matching Guided Distillation},
    author    = {Yue, Kaiyu and Deng, Jiangfan and Zhou, Feng},
    booktitle = {European Conference on Computer Vision (ECCV)},
    year      = {2020}
}

Software Version Used for Paper

  • Python - 3.7
  • PyTorch - 1.5.0 with torchvision - 0.6.0
  • Detectron2 Tree - 369a57d333

Quick & Easy Start

As an example, we use ResNet-50 as the teacher to distill ResNet-18, as shown in the figure below.

Note: models are from torchvision.

0. Install Dependencies

Install OR-Tools by running pip install ortools.

1. Expose Intermediate Features

This function exposes the intermediate features and the final output logits. Simply copy the body of the original forward function and return any tensors you want to use for distillation. Reference.

def extract_feature(self, x, preReLU=False):
    ...

    feat3 = self.layer3(x) # we expose layer3 output

    x = self.layer4(feat3)

    ...

    if not preReLU:
        feat3 = F.relu(feat3)

    return [feat3], x
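
The same pattern can be sketched end to end on a toy network. This is only an illustration under stated assumptions: TinyNet and its layer names are hypothetical and are not part of MGD or torchvision. What it shares with the snippet above is the method name, the preReLU flag, and the return convention (a list of features plus the logits).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyNet(nn.Module):
    """Hypothetical two-stage network, used only to illustrate the pattern."""

    def __init__(self, num_classes=10):
        super().__init__()
        # Each stage ends with BN and no ReLU, so its raw output is pre-ReLU.
        self.stage1 = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8))
        self.stage2 = nn.Sequential(nn.Conv2d(8, 16, 3, padding=1), nn.BatchNorm2d(16))
        self.fc = nn.Linear(16, num_classes)

    def extract_feature(self, x, preReLU=False):
        feat = self.stage1(x)            # tensor exposed for distillation
        x = self.stage2(F.relu(feat))    # the normal forward path continues
        x = F.adaptive_avg_pool2d(F.relu(x), 1).flatten(1)
        logits = self.fc(x)

        if not preReLU:
            feat = F.relu(feat)          # return the post-ReLU feature instead

        return [feat], logits


net = TinyNet().eval()
with torch.no_grad():
    feats, logits = net.extract_feature(torch.randn(2, 3, 32, 32), preReLU=True)
```

With preReLU=True the exposed feature may contain negative values. With preReLU=False it is clamped at zero before being returned, while the logits are unaffected.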

2. Expose BN

The function exposes BN layers before the distillation position. Reference.

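def get_bn_before_relu(self):
    # The last BN of a block is bn3 in Bottleneck and bn2 in BasicBlock.
    if isinstance(self.layer1[0], Bottleneck):
        bn3 = self.layer3[-1].bn3
    elif isinstance(self.layer1[0], BasicBlock):
        bn3 = self.layer3[-1].bn2
    else:
        # a bare `raise` has no active exception here, so raise explicitly
        raise TypeError('ResNet unknown block error !!!')

    return [bn3]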

3. Indicate Channel Number

The function tells MGD the channel number of the intermediate feature maps. Reference.

def get_channel_num(self):
    # layer3 of ResNet-50 (Bottleneck blocks) outputs 1024 channels
    return [1024]
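
For torchvision-style ResNets, the width of layer3 is 256 multiplied by the block's expansion factor (1 for BasicBlock, 4 for Bottleneck), which is where 1024 comes from for ResNet-50. The sketch below works through that arithmetic. layer3_channels is a hypothetical helper, not part of MGD, and it assumes the student reports its own layer3 width in the same way.

```python
# Expansion factors used by torchvision's ResNet blocks.
EXPANSION = {"BasicBlock": 1, "Bottleneck": 4}
LAYER3_BASE_WIDTH = 256  # base width of layer3 in torchvision ResNets


def layer3_channels(block_name):
    """Return the number of output channels of layer3 for a given block type."""
    return LAYER3_BASE_WIDTH * EXPANSION[block_name]


teacher_channels = [layer3_channels("Bottleneck")]  # ResNet-50
student_channels = [layer3_channels("BasicBlock")]  # ResNet-18
print(teacher_channels, student_channels)  # [1024] [256]
```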

4. Build Model

import mgd.builder

t_net = resnet50() # teacher model
s_net = resnet18() # student model

d_net = mgd.builder.MGDistiller(
    t_net,
    s_net,
    ignore_inds=[],
    reducer='amp',
    sync_bn=False,
    with_kd=True,
    preReLU=True,
    distributed=False, # DP mode: False | DDP mode: True
    det=False # True when working within Detectron2
)

5. Add MGD Steps In Training Procedure

Reference.

# initialize MGD parameters once before training starts
mgd_update(train_loader, d_net)

# training loop
for epoch in range(total_epochs):

    # UPDATE_FREQ is a user-chosen number of epochs between updates
    if (epoch + 1) % UPDATE_FREQ == 0:
        mgd_update(train_loader, d_net)
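
To see which epochs trigger an update under this condition, the schedule can be listed directly. mgd_update_epochs is a hypothetical helper written for this illustration. It only reproduces the modulo test from the loop and does not include the one-off call made before training starts.

```python
def mgd_update_epochs(total_epochs, update_freq):
    """Return the 0-based epochs at which the training loop calls mgd_update."""
    return [epoch for epoch in range(total_epochs) if (epoch + 1) % update_freq == 0]


print(mgd_update_epochs(total_epochs=10, update_freq=3))  # [2, 5, 8]
```

Because the test uses epoch + 1, an update frequency of 1 triggers an update in every epoch, including the first.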

MGD In Tasks

Classification | Object Detection | Unsupervised Learning.

Acknowledgements

We learned from and reused parts of the code from the following projects, and we thank their authors for these excellent works:

License

MIT. See LICENSE for details.


mgd's Issues

BNs for student features in MGDistiller

I noticed that MGDistiller applies BNs to the student features. Why are these additional batch normalizations needed? What happens if the distillation loss is computed on the student features directly?

self.model has no attribute module in d2/train_net.py

I ran your code and hit this error. self.model is an MGDistiller or SMDistiller object, and neither of them has a 'module' attribute. After replacing self.model.module with self.model, the code works. Is that correct?
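
For context, PyTorch's DataParallel and DistributedDataParallel wrappers expose the wrapped model as .module, whereas an unwrapped model has no such attribute. A minimal sketch of a lookup that tolerates both cases follows. unwrap_model and the two stand-in classes are hypothetical and are not taken from d2/train_net.py. Whether the model is supposed to be wrapped at that point is not stated in the issue.

```python
def unwrap_model(model):
    """Return model.module if the model is wrapped (DP/DDP), else the model itself."""
    return getattr(model, "module", model)


class Distiller:  # stand-in for an unwrapped distiller object
    pass


class Wrapper:  # stand-in for a DataParallel-style wrapper
    def __init__(self, module):
        self.module = module


bare = Distiller()
wrapped = Wrapper(bare)
print(unwrap_model(bare) is bare, unwrap_model(wrapped) is bare)  # True True
```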

Absolute Max Pooling

def amp(self, i, t_feats, s_feats, margins):
    """
    Absolute Max Pooling for channels reduction.
    """
    b, sc, h, w = s_feats[i].shape
    _, tc, _, _ = t_feats[i].shape

    groups = tc // sc

    t = []
    m = []
    for c in range(0, tc, sc):
        if c == (tc // sc) * sc and self.shave:
            continue

        t.append(t_feats[i][:, self.guided_inds[i][c:c+sc].detach(), :, :])
        m.append(margins[:, self.guided_inds[i][c:c+sc].detach(), :, :])

    t = torch.stack(t, dim=2)
    m = torch.stack(m, dim=2)

    t = t.reshape(b, sc, groups, -1)
    m = m.reshape(1, sc, groups, -1)

    t_inds = torch.argmax(t, dim=2)

    t = t.gather(2, t_inds.unsqueeze(2))
    m = m.mean(dim=2)

    t = t.reshape(b, sc, h, w)
    m = m.reshape(1, sc, 1, 1)

For Absolute Max Pooling, shouldn't t be passed through abs first?
The line t_inds = torch.argmax(t, dim=2) should be changed to t_inds = torch.argmax(torch.abs(t), dim=2).
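
The difference between the two index computations only shows up when the teacher features contain negative values, which can happen with pre-ReLU features. The toy tensor below is made up for illustration and follows the (b, sc, groups, h*w) layout from the snippet above. It demonstrates the question raised here and does not confirm the behaviour the authors intended.

```python
import torch

# One sample, one student channel, two teacher candidates (groups), one pixel.
t = torch.tensor([0.5, -2.0]).reshape(1, 1, 2, 1)

# Signed argmax selects 0.5, the largest signed value.
signed_inds = torch.argmax(t, dim=2)
signed = t.gather(2, signed_inds.unsqueeze(2))

# Argmax over magnitudes selects -2.0, and gather keeps its sign.
abs_inds = torch.argmax(torch.abs(t), dim=2)
absmax = t.gather(2, abs_inds.unsqueeze(2))

print(signed.item(), absmax.item())  # 0.5 -2.0
```

On non-negative inputs, for example post-ReLU features, both variants select the same element.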

'RetinaNet' object has no attribute 'get_channel_num'

Hi author,

Thanks a lot for sharing the nice work!
For the detection part (RetinaNet-R50 -> RetinaNet-R18), I ran the command:
An error shows up in mgd/builder.py at line 61:

line61: self.t_channels = self.t_net.get_channel_num()
AttributeError: 'RetinaNet' object has no attribute 'get_channel_num'
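
The traceback shows the distiller calling get_channel_num() on the teacher network. A quick way to check a model before building the distiller is to list which of the three methods from the Quick & Easy Start steps it lacks. missing_mgd_methods and PlainModel are hypothetical names for this sketch. It assumes, without confirmation from the source, that the detection models need the same three methods as the classification models.

```python
REQUIRED_METHODS = ("extract_feature", "get_bn_before_relu", "get_channel_num")


def missing_mgd_methods(model):
    """List the three hook methods from the README that `model` does not implement."""
    return [name for name in REQUIRED_METHODS if not callable(getattr(model, name, None))]


class PlainModel:  # stand-in for a model that defines only one of the hooks
    def extract_feature(self, x, preReLU=False):
        return [x], x


print(missing_mgd_methods(PlainModel()))  # ['get_bn_before_relu', 'get_channel_num']
```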
