This project forked from twangnh/distilling-object-detectors

Implementations of CVPR 2019 paper Distilling Object Detectors with Fine-grained Feature Imitation

License: MIT License

Implementation of our CVPR 2019 paper Distilling Object Detectors with Fine-grained Feature Imitation

15% performance boost of student model

We propose a general distillation approach for anchor-based object detection models that uses the knowledge of a large teacher model to obtain an enhanced small student model. The approach is orthogonal to, and can be combined with, other model compression methods such as quantization and pruning. The key observation behind vanilla knowledge distillation is that the inter-class discrepancy of prediction confidence reveals how a cumbersome model tends to generalize (e.g., how much confidence the model puts on the cat label when the input is actually a dog). Our idea is that the inter-location discrepancy of feature responses near objects also reveals how a large detector tends to generalize (e.g., how the model's response differs across different near-object anchor locations).
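To make the idea concrete, here is a minimal NumPy sketch of a masked feature-imitation loss. It is an illustration under stated assumptions, not the code in this repository: it assumes the student and teacher feature maps already have the same shape, that a binary mask marking near-object locations is given, and that the loss is normalized by the number of masked locations. The name `imitation_loss` is hypothetical.

```python
import numpy as np

def imitation_loss(student_feat, teacher_feat, mask):
    """Masked squared error between student and teacher feature maps.

    student_feat, teacher_feat: arrays of shape (C, H, W).
    mask: array of shape (H, W), 1 at near-object locations, 0 elsewhere.
    """
    diff = (student_feat - teacher_feat) ** 2   # per-element squared error, (C, H, W)
    masked = diff * mask[None, :, :]            # keep only near-object locations
    n_pos = max(float(mask.sum()), 1.0)         # assumption: normalize by mask size
    return float(masked.sum() / (2.0 * n_pos))

# Toy example: only the center location of a 3x3 map is "near an object".
teacher = np.ones((2, 3, 3))
student = np.zeros((2, 3, 3))
mask = np.zeros((3, 3))
mask[1, 1] = 1.0
loss = imitation_loss(student, teacher, mask)
```

Because of the mask, background locations contribute nothing to the loss, so the student only imitates the teacher's response where objects are nearby.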

We release the code for distilling a ShuffleNet-based detector and a VGG11-based Faster R-CNN. This repository implements Faster R-CNN imitation based on pytorch-faster-rcnn. Check Distilling-ShuffleDet for the TensorFlow code of the ShuffleNet-based detector imitation.

🔥Updating🔥

TODO

We have accumulated the following to-do list, which we hope to complete in the near future.

  • Still to come:
    • Add more models (ResNet-FRCNN, FPN-FRCNN).
    • Implement SSD model distillation.

Distilling VGG11-FRCNN

Requirements: PyTorch 0.4.0, Python 2

Preparation

1 Clone the repository

First, clone the code:

git clone https://github.com/twangnh/Distilling-Object-Detectors

Then, create a folder:

cd Distilling-Object-Detectors && mkdir data

2 Data preparation

  • PASCAL_VOC 07+12: Please follow the instructions in py-faster-rcnn to prepare the VOC datasets; any other guide works as well. After downloading the data, create softlinks in the folder data/. The prepared directory looks like data/VOCdevkit2007/VOC2007/...
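The softlink step can be sketched in Python as follows. This is only an illustration: the helper name `link_dataset` and the source path in the usage comment are hypothetical, and the link name `VOCdevkit2007` is taken from the directory layout shown above.

```python
import os

def link_dataset(src_dir, data_dir="data", link_name="VOCdevkit2007"):
    """Create data_dir/link_name as a symlink to an already-downloaded dataset."""
    if not os.path.isdir(data_dir):
        os.makedirs(data_dir)
    link_path = os.path.join(data_dir, link_name)
    # Do nothing if the link (or a real directory) is already in place.
    if not os.path.lexists(link_path):
        os.symlink(os.path.abspath(src_dir), link_path)
    return link_path

# Example (hypothetical download location):
# link_dataset("/path/to/VOCdevkit")
```

Linking instead of copying keeps a single copy of the dataset on disk when several projects share it.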

3 Download the ImageNet-pretrained model and the trained VGG16-FRCNN teacher model

Download the ImageNet-pretrained VGG11 model from GoogleDrive and put it into data/pretrained_model/

Download the trained VGG16-FRCNN model from GoogleDrive and put it into data/VGG16-FRCNN/

Train

Currently only a batch size of 1 is supported.

python trainval_net_sup.py --dataset pascal_voc --net vgg11 --bs 1 --nw 2 --lr 3e-3 --lr_decay_step 5 --cuda --s 1 --gpu 0
[session 1][epoch  1][iter    0/10022] loss: 13.4238, loss_sup: 0.0000, lr: 3.00e-03
			fg/bg=(15/241), time cost: 0.307381
			rpn_cls: 0.7839, rpn_box: 0.4312, rcnn_cls: 12.0130, rcnn_box 0.1957 
[session 1][epoch  1][iter  100/10022] loss: 2.1172, loss_sup: 0.0000, lr: 3.00e-03
			fg/bg=(15/241), time cost: 17.871297
			rpn_cls: 0.2372, rpn_box: 0.0492, rcnn_cls: 0.5382, rcnn_box 0.1475 
[session 1][epoch  1][iter  200/10022] loss: 2.3993, loss_sup: 0.0000, lr: 3.00e-03
			fg/bg=(27/229), time cost: 17.885193
			rpn_cls: 0.0451, rpn_box: 0.3003, rcnn_cls: 2.5547, rcnn_box 0.5216 
[session 1][epoch  1][iter  300/10022] loss: 1.6754, loss_sup: 0.0000, lr: 3.00e-03
			fg/bg=(21/235), time cost: 17.856990
			rpn_cls: 0.2837, rpn_box: 0.2542, rcnn_cls: 1.1131, rcnn_box 0.2073 
[session 1][epoch  1][iter  400/10022] loss: 1.6178, loss_sup: 0.1145, lr: 3.00e-03
			fg/bg=(23/233), time cost: 17.976755
			rpn_cls: 0.3597, rpn_box: 0.0106, rcnn_cls: 0.7343, rcnn_box 0.2363 
[session 1][epoch  1][iter  500/10022] loss: 1.4362, loss_sup: 9.6434, lr: 3.00e-03
			fg/bg=(32/224), time cost: 17.911143
			rpn_cls: 0.1783, rpn_box: 0.0235, rcnn_cls: 0.4522, rcnn_box 0.3731 
[session 1][epoch  1][iter  600/10022] loss: 1.3638, loss_sup: 8.4568, lr: 3.00e-03
			fg/bg=(18/238), time cost: 18.024369
			rpn_cls: 0.4774, rpn_box: 0.1143, rcnn_cls: 0.4781, rcnn_box 0.1663 

Training should progress as above, where loss_sup is the imitation loss and the first 400 steps are warm-up steps with no imitation (i.e., loss_sup = 0). Models will be saved in ./temp/vgg11/pascal_voc/xxx.pth
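To check that a run follows this pattern, the log can be parsed with a few lines of Python. The sketch below is not part of the repository; it assumes the log format shown in the sample output above, and the names `parse_log` and `first_imitation_step` are hypothetical.

```python
import re

# Matches lines such as:
# [session 1][epoch  1][iter  400/10022] loss: 1.6178, loss_sup: 0.1145, lr: 3.00e-03
LOG_LINE = re.compile(
    r"\[session\s+(\d+)\]\[epoch\s+(\d+)\]\[iter\s+(\d+)/\s*(\d+)\]\s+"
    r"loss: ([\d.]+), loss_sup: ([\d.]+), lr: ([\d.eE+-]+)"
)

def parse_log(text):
    """Return a list of (iteration, loss, loss_sup) tuples from training output."""
    records = []
    for match in LOG_LINE.finditer(text):
        records.append(
            (int(match.group(3)), float(match.group(5)), float(match.group(6)))
        )
    return records

def first_imitation_step(records):
    """Iteration of the first logged line with a non-zero loss_sup, or None."""
    for iteration, _, loss_sup in records:
        if loss_sup > 0.0:
            return iteration
    return None
```

Applied to the sample output above, the first non-zero loss_sup appears at the line logged for iteration 400, which matches the 400 warm-up steps.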

Train without imitation (baseline training)

Currently only a batch size of 1 is supported.

python trainval_net_sup.py --dataset pascal_voc --net vgg11 --bs 1 --nw 2 --lr 3e-3 --lr_decay_step 5 --cuda --s 1 --gpu 0 --tfi True

Models will be saved in ./temp/vgg11/pascal_voc/xxx.pth. Note that the imitation loss weight and the number of warm-up steps can be further tuned with --ilw and --ws.
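How these two settings interact can be pictured with the sketch below. It is an assumption-laden illustration rather than the training script's actual logic: it assumes the imitation loss is simply added to the detection loss with a fixed weight once the warm-up steps are over. The function name `total_loss` and the default weight of 1.0 are placeholders; only the 400 warm-up steps come from the training log above.

```python
def total_loss(detection_loss, loss_sup, step, warmup_steps=400, imitation_weight=1.0):
    """Combine detection and imitation losses with a hard warm-up cutoff.

    Assumption: during warm-up the imitation term is switched off entirely;
    afterwards it is added with a constant weight.
    """
    if step < warmup_steps:
        return detection_loss
    return detection_loss + imitation_weight * loss_sup

# During warm-up only the detection loss is optimized.
early = total_loss(2.0, 10.0, step=100)
# After warm-up the weighted imitation term is added.
late = total_loss(2.0, 10.0, step=400, imitation_weight=0.01)
```

A larger weight pushes the student harder toward the teacher's features, while a longer warm-up gives the student time to stabilize before imitation starts.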

Test

python test_net.py --dataset pascal_voc --net vgg11 --checksession 1 --checkepoch 2 --checkpoint 10021 --cuda --gpu 0

Change checksession, checkepoch, and checkpoint to test a specific model.

model   #GPUs batch size learning_rate(lr) lr_decay max_epoch mAP ckpt
VGG-16     1 1 1e-3 5   7   70.1 GoogleDrive
VGG-11     1 4 3e-3 15   59.6 GoogleDrive
VGG-11-I    8 16 3e-3  8   15 67.6 +8.0 GoogleDrive

Models at max_epoch are reported.

The numbers differ from those in the paper because they come from independent runs of the algorithm.

Test with trained model

Download the trained model from the GoogleDrive link, then run

python test_net.py --dataset pascal_voc --net vgg11 --load_name ./path_to/xxx.pth --cuda --gpu 0

Distilling ShuffleDet

...

We have implemented a single-layer, one-stage toy object detector in TensorFlow, with multi-GPU training and cross-GPU batch normalization. Check Distilling-ShuffleDet for the code.

Models   FLOPs (G) Params (M) car               pedestrian        cyclist           mAP   ckpt
                              Easy  Mod   Hard  Easy  Mod   Hard  Easy  Mod   Hard
1x       5.1       1.6        85.7  74.3  65.8  63.2  55.6  50.6  69.7  51.0  49.1  62.8  GoogleDrive
0.5x     1.5       0.53       81.6  71.7  61.2  59.4  52.3  45.5  59.7  43.5  42.0  57.4  GoogleDrive
0.5x-I   1.5       0.53       84.9  72.9  64.1  60.7  53.3  47.2  69.0  46.2  44.9  60.4  GoogleDrive
                              +3.3  +1.2  +2.9  +1.3  +1.0  +1.7  +9.3  +2.7  +2.9  +3.0
0.25x    0.67      0.21       67.2  56.6  47.5  54.7  48.4  42.1  49.1  33.3  32.9  48.0  GoogleDrive
0.25x-I  0.67      0.21       76.6  62.3  54.6  56.8  48.2  42.6  56.6  37.3  36.5  52.4  GoogleDrive
                              +9.4  +5.7  +7.1  +2.1  -0.2  +0.5  +7.5  +4.0  +3.6  +4.4

Models with the highest mAP are reported for both the baseline and the distilled models.

Note that the numbers differ from those in the paper because they come from independent runs of the algorithm and because we have migrated from single-GPU training to multi-GPU training with a larger batch size.
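The gain rows in the table are absolute differences in mAP points between a distilled model and its baseline. As a small worked example, the sketch below recomputes the mAP gain and also shows the corresponding relative gain; the helper name `gain` is hypothetical, and the relative figure is our own arithmetic on the table values rather than a number reported by the authors.

```python
def gain(baseline_map, distilled_map):
    """Return (absolute gain in mAP points, relative gain in percent)."""
    absolute = distilled_map - baseline_map
    relative = 100.0 * absolute / baseline_map
    return round(absolute, 1), round(relative, 1)

# mAP columns of the 0.5x and 0.25x rows (baseline vs. distilled "-I" model).
half = gain(57.4, 60.4)
quarter = gain(48.0, 52.4)
```

The absolute values reproduce the +3.0 and +4.4 entries in the mAP column.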

Distilling YoloV2

A third-party implementation of distilling YOLOv2 on WIDER FACE (code not available yet, but very easy to implement).

Model    size   easy         medium       hard
YOLOv2   190MB  87.2         74.6         36.0
0.25x    12MB   78.2         69.8         35.6
0.25x-I  12MB   83.9 +5.7    74.9 +5.1    38.5 +2.9
0.15x    4.4MB  69.7         61.1         29.7
0.15x-I  4.4MB  79.3 +9.6    67.0 +5.9    32.0 +2.3

Citation

@inproceedings{wang2019distilling,
  title={Distilling Object Detectors With Fine-Grained Feature Imitation},
  author={Wang, Tao and Yuan, Li and Zhang, Xiaopeng and Feng, Jiashi},
  booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
  pages={4933--4942},
  year={2019}
}

License

The code and the models are MIT licensed, as found in the LICENSE file.
