
ida-det's Introduction

IDa-Det: An Information Discrepancy-aware Distillation for 1-bit Detectors

PyTorch implementation of our paper "IDa-Det: An Information Discrepancy-aware Distillation for 1-bit Detectors", accepted at ECCV 2022.

Tips

If you run into any problems, please contact the first author (Email: [email protected]).

Our code borrows heavily from DeFeat (https://github.com/ggjy/DeFeat.pytorch/) and is based on MMDetection (https://github.com/open-mmlab/mmdetection).

Environments

  • Python 3.7
  • MMDetection 2.x
  • This repo uses: mmdet-v2.0, mmcv-0.5.6, CUDA 10.1

Get Started

  • sh script.sh

Update

We have simplified and optimized the code. IDa-Det now plugs directly into the original DeFeat project, and the training cost is about 30% lower than in the old version.

VOC Results

Pretrained model is here: GoogleDrive

Notes:

  • Faster RCNN based model
  • Batch: sample_per_gpu x gpu_num

Model       Batch  Lr schd  box AP  Model        Log
R101        4x2    0.01     81.9    GoogleDrive
R101-BiR18  4x1    0.004    76.9    GoogleDrive
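
As a small illustration of the Batch note above, the sketch below turns a "sample_per_gpu x gpu_num" entry into the total number of images per iteration. The helper name total_batch_size is hypothetical and is not part of this repo.

```python
def total_batch_size(batch: str) -> int:
    """Parse a 'sample_per_gpu x gpu_num' entry such as '4x2'.

    Hypothetical helper for illustration only; not part of the repo.
    """
    samples_per_gpu, gpu_num = (int(part) for part in batch.lower().split("x"))
    return samples_per_gpu * gpu_num


if __name__ == "__main__":
    # Batch entries taken from the VOC results table.
    for model, batch in [("R101", "4x2"), ("R101-BiR18", "4x1")]:
        print(model, total_batch_size(batch))
```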

If you find this work useful in your research, please consider citing:

@inproceedings{xu2022ida,
  title={IDa-Det: An Information Discrepancy-Aware Distillation for 1-Bit Detectors},
  author={Xu, Sheng and Li, Yanjing and Zeng, Bohan and Ma, Teli and Zhang, Baochang and Cao, Xianbin and Gao, Peng and L{\"u}, Jinhu},
  booktitle={European Conference on Computer Vision},
  pages={346--361},
  year={2022},
  organization={Springer}
}

ida-det's People

Contributors

stevetsui, yanjingli0202


ida-det's Issues

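Pretrained weights for the teacher model

Hello, my training currently fails with: OSError: /tmp/ImageNet-pretrained/faster-rcnn-v2-374.pth is not a checkpoint file
I don't know which pretrained weights this teacher model refers to. Could you tell me?

Getting the correct computation cost

I ran get_flop.py from your code to compute it, but the result seems wrong. How did you obtain the computation cost reported in the paper?

About the pretrained weights for the teacher model

Since the pretrained weights for the teacher model are not provided, I trained one myself using your code. However, distillation fails with the error below. Did I do something wrong?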

2023-02-12 12:01:28,423 - mmdet - INFO - workflow: [('train', 1)], max: 12 epochs
kd_decay rate: 1.0
/home/micro/users/zjl/IDa-Det-main/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py:159: UserWarning: This overload of addcmul is deprecated:
addcmul(Tensor input, Number value, Tensor tensor1, Tensor tensor2, *, Tensor out)
Consider using one of the following signatures instead:
addcmul(Tensor input, Tensor tensor1, Tensor tensor2, *, Number value, Tensor out) (Triggered internally at /opt/conda/conda-bld/pytorch_1607370172916/work/torch/csrc/utils/python_arg_parser.cpp:882.)
gx = torch.addcmul(px, 1, pw, dx) # gx = px + pw * dx
dict_keys(['cls_score', 'bbox_pred', 'bbox_targets', 'sampling', 'neck', 'backbone', 'img_metas', 'gt_bboxes', 'gt_labels'])
Traceback (most recent call last):
  File "tools/train_entropy.py", line 206, in <module>
    main()
  File "tools/train_entropy.py", line 194, in main
    train_detector_entropy(
  File "/home/micro/users/zjl/IDa-Det-main/mmdet/multi_distillation/detector_entropy.py", line 511, in train_detector_entropy
    runner.run(data_loaders, cfg.workflow, cfg.total_epochs, kd_cfg=cfg.model.hint_adapt)
  File "/home/micro/users/zjl/IDa-Det-main/mmcv-0.5.1/mmcv/runner/runner_kd.py", line 405, in run
    epoch_runner(data_loaders[i], **kwargs)
  File "/home/micro/users/zjl/IDa-Det-main/mmcv-0.5.1/mmcv/runner/runner_kd.py", line 303, in train
    outputs = self.batch_processor(
  File "/home/micro/users/zjl/IDa-Det-main/mmdet/multi_distillation/detector_entropy.py", line 211, in batch_processor_entropy
    proposal_list_s = head_det_s['proposal_list']
KeyError: 'proposal_list'

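Two questions about details

Hello Doctor, I have the following questions and look forward to your answers:
1. I noticed that your code uses Shared2FCBBoxHead, a detection head with a relatively large parameter count and computation cost. Why didn't you use BBoxHead, which has fewer parameters and less computation?
2. Compared with the full-precision network, your binarized network uses a smaller learning rate and batch size. Why is that?

Computing the computation cost of IDa-Det

As the title says: how do I compute GMac accurately? When I ran get_flop.py from your code, the computation cost seemed too large. Do I need to write code that defines the modules myself and divides by 64?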
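
For context on the "divide by 64" question, here is a minimal sketch of a convention that is common in the binary-network literature (used, for example, by Bi-Real Net): binary operations are divided by 64 and added to the full-precision operations. Whether the paper's numbers were produced exactly this way is not confirmed by this page, and the function names and the two example layers below are hypothetical.

```python
def conv_macs(c_in, c_out, kernel, h_out, w_out, groups=1):
    """Multiply-accumulate count of one 2D convolution (bias ignored)."""
    return (c_in // groups) * c_out * kernel * kernel * h_out * w_out


def effective_ops(layers):
    """Combine full-precision and 1-bit MACs as full + binary / 64.

    `layers` is an iterable of (is_binary, spec) pairs, where `spec`
    holds the keyword arguments of conv_macs. The 1/64 factor is the
    assumed convention, not something read from get_flop.py.
    """
    full = sum(conv_macs(**spec) for is_binary, spec in layers if not is_binary)
    binary = sum(conv_macs(**spec) for is_binary, spec in layers if is_binary)
    return full + binary / 64


if __name__ == "__main__":
    # Hypothetical layers: a full-precision 7x7 stem and one 1-bit 3x3 conv.
    layers = [
        (False, dict(c_in=3, c_out=64, kernel=7, h_out=112, w_out=112)),
        (True, dict(c_in=64, c_out=64, kernel=3, h_out=56, w_out=56)),
    ]
    print(f"{effective_ops(layers) / 1e9:.4f} G effective operations")
```

A generic FLOP counter treats every convolution as full precision, so its total for a 1-bit detector will be larger than a figure computed with this convention.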

About how to train and test

Hello,

Thanks for your contribution.

I have some questions about how to train and test your code on the COCO dataset.

Have you considered writing a training and testing guide?

Best wishes

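Multi-GPU training

As the title says: the paper mentions using four 2080Ti GPUs. How do I launch multi-GPU training?

Why use the SGD optimizer?

Hello Doctor, thank you very much for sharing your code and for your earlier guidance. Most binarized networks use the Adam optimizer, and some papers argue that Adam outperforms SGD.
However, I noticed that the config files in your code use the SGD optimizer. Did experiments show this to be better? I look forward to your guidance.

Training config files for COCO

I noticed that the paper reports training results on COCO. Could you share the config files for the teacher and student networks on COCO?

About changing the backbone of idanet

Hello, is idanet suitable for backbones other than resnet, for example Transformer-based backbones such as ViT, Swin, and PVT?

Measuring model inference time or FPS

MMDetection uses tools/analysis_tools/benchmark.py to measure FPS, but I could not find it in your code. How should I measure FPS, or otherwise compute the inference speed to find out how many images per second the model can process on a GPU?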
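
One generic way to estimate FPS without a dedicated benchmark script is to time repeated forward passes after a few warm-up iterations. The sketch below is framework-agnostic and is not taken from this repo: measure_fps and its parameters are hypothetical names, and run_once stands for whatever callable runs inference on a single image.

```python
import time


def measure_fps(run_once, num_iters=50, num_warmup=5, sync=None):
    """Return iterations per second of `run_once`.

    `sync` is an optional callable that blocks until pending work is
    finished. GPU kernels run asynchronously, so on a GPU something
    like torch.cuda.synchronize should be passed here.
    """
    for _ in range(num_warmup):  # warm-up runs are not timed
        run_once()
    if sync is not None:
        sync()
    start = time.perf_counter()
    for _ in range(num_iters):
        run_once()
    if sync is not None:
        sync()
    elapsed = time.perf_counter() - start
    return num_iters / elapsed


if __name__ == "__main__":
    # Dummy workload standing in for a single-image forward pass.
    print(f"{measure_fps(lambda: sum(range(10000))):.1f} iterations/s")
```

With PyTorch, run_once would typically wrap the model call in torch.no_grad() with the model in eval mode.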

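Installation error

Hello Doctor, I cannot run the program and hope you can help.
I have already run setup.py for both mmcv and mmdet, but I still get the error shown below: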

Traceback (most recent call last):
  File "tools/train.py", line 176, in <module>
    main()
  File "tools/train.py", line 71, in main
    cfg = Config.fromfile(args.config)
  File "/home/micro/users/zjl/DeFeat.pytorch-main/mmcv-0.5.6/mmcv/utils/config.py", line 165, in fromfile
    cfg_dict, cfg_text = Config._file2dict(filename)
  File "/home/micro/users/zjl/DeFeat.pytorch-main/mmcv-0.5.6/mmcv/utils/config.py", line 84, in _file2dict
    filename = osp.abspath(osp.expanduser(filename))
  File "/home/micro/anaconda3/envs/ida/lib/python3.8/posixpath.py", line 231, in expanduser
    path = os.fspath(path)
TypeError: expected str, bytes or os.PathLike object, not NoneType
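
The traceback shows that args.config was None when it reached Config.fromfile, which usually means no config path was given on the command line. How tools/train.py defines its arguments is not shown here, so the sketch below is only an assumption-laden illustration of failing early with a clearer message. parse_args and resolve_config are hypothetical helpers.

```python
import argparse
import os.path as osp


def parse_args(argv=None):
    """Hypothetical parser with the config path as a required positional."""
    parser = argparse.ArgumentParser(description="Train a detector")
    parser.add_argument("config", help="path to the train config file")
    return parser.parse_args(argv)


def resolve_config(path):
    """Validate the config path before it is handed to Config.fromfile."""
    if path is None:
        raise ValueError("no config file given; pass the config path on the command line")
    path = osp.abspath(osp.expanduser(path))
    if not osp.isfile(path):
        raise FileNotFoundError(f"config file not found: {path}")
    return path


if __name__ == "__main__":
    args = parse_args()
    print(resolve_config(args.config))
```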
