Giter Site home page Giter Site logo

retinanet-pytorch's Introduction

Retinanet:目标检测模型在Pytorch当中的实现


目录

  1. 仓库更新 Top News
  2. 性能情况 Performance
  3. 所需环境 Environment
  4. 文件下载 Download
  5. 预测步骤 How2predict
  6. 训练步骤 How2train
  7. 评估步骤 How2eval
  8. 参考资料 Reference

Top News

2022-04:进行了大幅度的更新,支持step、cos学习率下降法、支持adam、sgd优化器选择、支持学习率根据batch_size自适应调整、新增图片裁剪。支持多GPU训练,新增各个种类目标数量计算。

2021-10:进行了大幅度的更新,增加了大量注释、增加了大量可调整参数、对代码的组成模块进行修改、增加fps、视频预测、批量预测等功能。

性能情况

训练数据集 权值文件名称 测试数据集 输入图片大小 mAP 0.5:0.95 mAP 0.5
VOC07+12 retinanet_resnet50.pth VOC-Test07 600x600 - 81.56

所需环境

torch==1.2.0

文件下载

训练所需的retinanet_resnet50.pth和主干的权值可以在百度云下载。
链接: https://pan.baidu.com/s/1Qal7lmN3aV0ZHscB_1OmrA
提取码: ckv8

VOC数据集下载地址如下,里面已经包括了训练集、测试集、验证集(与测试集一样),无需再次划分:
链接: https://pan.baidu.com/s/1-1Ej6dayrx3g0iAA88uY5A
提取码: ph32

训练步骤

a、训练VOC07+12数据集

  1. 数据集的准备
    本文使用VOC格式进行训练,训练前需要下载好VOC07+12的数据集,解压后放在根目录

  2. 数据集的处理
    修改voc_annotation.py里面的annotation_mode=2,运行voc_annotation.py生成根目录下的2007_train.txt和2007_val.txt。

  3. 开始网络训练
    train.py的默认参数用于训练VOC数据集,直接运行train.py即可开始训练。

  4. 训练结果预测
    训练结果预测需要用到两个文件,分别是retinanet.py和predict.py。我们首先需要去retinanet.py里面修改model_path以及classes_path,这两个参数必须要修改。
    model_path指向训练好的权值文件,在logs文件夹里。
    classes_path指向检测类别所对应的txt。

    完成修改后就可以运行predict.py进行检测了。运行后输入图片路径即可检测。

b、训练自己的数据集

  1. 数据集的准备
    本文使用VOC格式进行训练,训练前需要自己制作好数据集,
    训练前将标签文件放在VOCdevkit文件夹下的VOC2007文件夹下的Annotation中。
    训练前将图片文件放在VOCdevkit文件夹下的VOC2007文件夹下的JPEGImages中。

  2. 数据集的处理
    在完成数据集的摆放之后,我们需要利用voc_annotation.py获得训练用的2007_train.txt和2007_val.txt。
    修改voc_annotation.py里面的参数。第一次训练可以仅修改classes_path,classes_path用于指向检测类别所对应的txt。
    训练自己的数据集时,可以自己建立一个cls_classes.txt,里面写自己所需要区分的类别。
    model_data/cls_classes.txt文件内容为:

cat
dog
...

修改voc_annotation.py中的classes_path,使其对应cls_classes.txt,并运行voc_annotation.py。

  1. 开始网络训练
    训练的参数较多,均在train.py中,大家可以在下载库后仔细看注释,其中最重要的部分依然是train.py里的classes_path。
    classes_path用于指向检测类别所对应的txt,这个txt和voc_annotation.py里面的txt一样!训练自己的数据集必须要修改!
    修改完classes_path后就可以运行train.py开始训练了,在训练多个epoch后,权值会生成在logs文件夹中。

  2. 训练结果预测
    训练结果预测需要用到两个文件,分别是retinanet.py和predict.py。在retinanet.py里面修改model_path以及classes_path。
    model_path指向训练好的权值文件,在logs文件夹里。
    classes_path指向检测类别所对应的txt。

    完成修改后就可以运行predict.py进行检测了。运行后输入图片路径即可检测。

预测步骤

a、使用预训练权重

  1. 下载完库后解压,在百度网盘下载权值,放入model_data,运行predict.py,输入
img/street.jpg
  1. 在predict.py里面进行设置可以进行fps测试和video视频检测。

b、使用自己训练的权重

  1. 按照训练步骤训练。
  2. 在retinanet.py文件里面,在如下部分修改model_path和classes_path使其对应训练好的文件;model_path对应logs文件夹下面的权值文件,classes_path是model_path对应分的类
_defaults = {
    #--------------------------------------------------------------------------#
    #   使用自己训练好的模型进行预测一定要修改model_path和classes_path!
    #   model_path指向logs文件夹下的权值文件,classes_path指向model_data下的txt
    #   如果出现shape不匹配,同时要注意训练时的model_path和classes_path参数的修改
    #--------------------------------------------------------------------------#
    "model_path"        : 'model_data/retinanet_resnet50.pth',
    "classes_path"      : 'model_data/voc_classes.txt',
    #---------------------------------------------------------------------#
    #   输入图片的大小
    #---------------------------------------------------------------------#
    "input_shape"       : [600, 600],
    #---------------------------------------------------------------------#
    #   用于选择所使用的模型的版本
    #   0、1、2、3、4
    #   resnet18, resnet34, resnet50, resnet101, resnet152
    #---------------------------------------------------------------------#
    "phi"               : 2,
    #---------------------------------------------------------------------#
    #   只有得分大于置信度的预测框会被保留下来
    #---------------------------------------------------------------------#
    "confidence"        : 0.5,
    #---------------------------------------------------------------------#
    #   非极大抑制所用到的nms_iou大小
    #---------------------------------------------------------------------#
    "nms_iou"           : 0.3,
    #---------------------------------------------------------------------#
    #   该变量用于控制是否使用letterbox_image对输入图像进行不失真的resize,
    #   在多次测试后,发现关闭letterbox_image直接resize的效果更好
    #---------------------------------------------------------------------#
    "letterbox_image"   : True,
    #---------------------------------------------------------------------#
    #   
    #---------------------------------------------------------------------#
    "cuda"              : True
}
  1. 运行predict.py,输入
img/street.jpg
  1. 在predict.py里面进行设置可以进行fps测试和video视频检测。

评估步骤

a、评估VOC07+12的测试集

  1. 本文使用VOC格式进行评估。VOC07+12已经划分好了测试集,无需利用voc_annotation.py生成ImageSets文件夹下的txt。
  2. 在retinanet.py里面修改model_path以及classes_path。model_path指向训练好的权值文件,在logs文件夹里。classes_path指向检测类别所对应的txt。
  3. 运行get_map.py即可获得评估结果,评估结果会保存在map_out文件夹中。

b、评估自己的数据集

  1. 本文使用VOC格式进行评估。
  2. 如果在训练前已经运行过voc_annotation.py文件,代码会自动将数据集划分成训练集、验证集和测试集。如果想要修改测试集的比例,可以修改voc_annotation.py文件下的trainval_percent。trainval_percent用于指定(训练集+验证集)与测试集的比例,默认情况下 (训练集+验证集):测试集 = 9:1。train_percent用于指定(训练集+验证集)中训练集与验证集的比例,默认情况下 训练集:验证集 = 9:1。
  3. 利用voc_annotation.py划分测试集后,前往get_map.py文件修改classes_path,classes_path用于指向检测类别所对应的txt,这个txt和训练时的txt一样。评估自己的数据集必须要修改。
  4. 在retinanet.py里面修改model_path以及classes_path。model_path指向训练好的权值文件,在logs文件夹里。classes_path指向检测类别所对应的txt。
  5. 运行get_map.py即可获得评估结果,评估结果会保存在map_out文件夹中。

Reference

https://github.com/pierluigiferrari/ssd_keras
https://github.com/kuhung/SSD_keras
https://github.com/qqwweee/keras-yolo3/
https://github.com/Cartucho/mAP

retinanet-pytorch's People

Contributors

bubbliiiing avatar

Stargazers

 avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar

Watchers

 avatar  avatar  avatar

retinanet-pytorch's Issues

anchor生成的问题

为什么生成anchor时这里左上角和右上角为y1,x1,y2,x2?

boxes = np.vstack((yv - anchor_size_y_2, xv - anchor_size_x_2,
yv + anchor_size_y_2, xv + anchor_size_x_2))

这种是运行成功了么,楼主

(base) root@autodl-container-b2e2118052-826a86c4:~/retinanet-pytorch-master# python train.py
Load weights model_data/retinanet_resnet50.pth.

Successful Load Key: ['backbone_net.model.conv1.weight', 'backbone_net.model.bn1.weight', 'backbone_net.model.bn1.bias', 'backbone_net.model.bn1.running_mean', 'backbone_net.model.bn1.running_var', 'backbone_net.model.bn1.num_batches_tracked', 'backbone_net.model.layer1.0.conv1.weight', 'backbone_net.model.layer1.0.bn1.weight', 'backbone_net.model.layer1.0.bn1.bias', 'backbone_net.model.layer1.0.bn1.running_mean', 'backbone_net.model.layer1.0.bn1.running_var', 'backbone_net.model.layer1.0.bn1.num_batches_tracked', ' ……
Successful Load Key Num: 354

Fail To Load Key: [] ……
Fail To Load Key num: 0

但是有一次我运行之后,迭代五次就爆了错误
Get map.
0%| | 0/4952 [00:03<?, ?it/s]
Traceback (most recent call last):
File "train.py", line 515, in
epoch_step, epoch_step_val, gen, gen_val, UnFreeze_Epoch, Cuda, fp16, scaler, save_period, save_dir, local_rank)
File "/root/retinanet-pytorch-master/utils/utils_fit.py", line 105, in fit_one_epoch
eval_callback.on_epoch_end(epoch + 1, model_train)
File "/root/retinanet-pytorch-master/utils/callbacks.py", line 196, in on_epoch_end
self.get_map_txt(image_id, image, self.class_names, self.map_out_path)
File "/root/retinanet-pytorch-master/utils/callbacks.py", line 143, in get_map_txt
results = non_max_suppression(torch.cat([outputs, classification], axis=-1), self.input_shape,
TypeError: cat() got an unexpected keyword argument 'axis'

mAP复现问题

bubbliiing大哥,我做了大概一周您的这个project,但是目前我用您提供的权重复现不了81.56%的检测精度,我就是用的您的代码啊。
我想再次确认一下,600x600pixel的尺度能够获得这个精度吧。

计算fps值报错

您好,我用您的代码训练完,也测得了对应的map指标。但是在计算fps值的时候报了如下错误 TypeError: non_max_suppression() got multiple values for argument 'conf_thres',麻烦请问下您有遇到过这个错误吗? 我的环境是ubuntu下torch1.8,谢谢了。

有关改为Resnet101/152的问题

image

我把代码中的phi改成了101代表的数字,并且也在model_data中放入了resnet101的模型,请问这个是什么原因啊,我还需要改哪里吗

数据集问题

你好博主,我的数据集是关于医学图像的没有那种voc的标注xml文件。只有png格式的和两个描述的csv文件,可以直接训练吗,还是需要转到voc的格式。🙏🙏

关于预训练模型

您好,请问有resnet34的预训练模型吗?为什么我使用在其它地方找到的resnet34的预训练模型对网络进行训练,会出现如下报错信息呢:
pretrained_dict = {k: v for k, v in pretrained_dict.items() if np.shape(model_dict[k]) == np.shape(v)}
KeyError: 'conv1.weight'

num_class=20

您好,我想问一下为什么针对voc数据集,网络预测的分类不是21个?你后面算focal_loss不看背景吗

torch版本问题

大佬推荐的是torch=1.2.1,请问可以用更高版本的torch吗?

训练mAP问题

b导,我在使用retinaNet训练时,使用主干网络作为预训练权重,然后pretrain设置为true,结果显示:
Only tensors or tuples of tensors can be output from traced functions ...Error occurs, No graph saved
然后训练的每10代验证集mAP一直是0.00几,这是哪里出现问题如何解决呢?

FileNotFoundError: [Errno 2] No such file or directory: '288'

您好, 我在按照所有步骤正常执行后,运行train.py提示这样的错误。
File "C:\Users\TitanV\AppData\Roaming\Python\Python36\site-packages\PIL\Image.py", line 2766, in open
fp = builtins.open(filename+'.jpg', "rb")
FileNotFoundError: [Errno 2] No such file or directory: '288'

修改nms

大神您好,请问如果想要改nms的话要改哪些地方呢?

有关加入cascade的问题

您好,我在训练自己的数据集时,发现了生成box与gt存在一点差距的问题,boss让我加入cascade来解决。但是我看了下cascade的代码,感觉脑子好混乱,里面还涉及了roipooling、RoIAlign等等。请问如果我要在这个retina中加入cascade的话,要在哪里改,大概的思路是什么,谢谢!

预测结果

您好,请问为什么检测图片的时候一个目标会预测出两个类呢?

计算map

你好,请问为什么计算map时,召回率与精确度都一样,但map的结果却相差了10呢?

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.