Giter Site home page Giter Site logo

yatenglg / focal-loss-pytorch Goto Github PK

View Code? Open in Web Editor NEW
428.0 5.0 113.0 567 KB

全中文注释.(The loss function of retinanet based on pytorch).(You can use it on one-stage detection task or classifical task, to solve data imbalance influence).用于one-stage目标检测算法,提升检测效果.你也可以在分类任务中使用该损失函数,解决数据不平衡问题.

Home Page: https://ptorch.com/news/253.html

Python 43.99% Jupyter Notebook 56.01%

focal-loss-pytorch's Introduction

GIthub使用指北:

1.想将项目拷贝到自己帐号下就fork一下.

2.持续关注项目更新就star一下

3.watch是设置接收邮件提醒的.

jupyter-notebook用法例子 请见:由于Github是国外网站,加载会稍慢

retinanet的实现请见:Retinanet-pytorch


pytorch 实现 focal loss

retinanet论文损失函数

实现过程简易明了,全中文备注.

参数说明

  • alpha参数,是类别损失权重。

    用于调节各类别对损失的影响,具体作用与torch实现的CrossEntropyLoss中的weight参数一致。

    你可以输入一个float,比如0.25,则最终的alpha将是[0.25, 0.75, 0.75, 0.75, ...],这种情况一般用于目标检测,用来抑制背景类对损失的影响;你也可以直接输入一个列表,直接为每一类指定损失权重。

  • gamma参数,是难易度系数,也是focal loss不同于交叉熵的最大区别。

    用于调整训练过程中难识别样本与易识别样本对损失的影响

参数设置

通常情况下,设置好num_classes直接调用就可以了。

  • alhpa参数

    可以参考各类别样本数据量比例,来设置alpha参数。

    但更建议的是,进行多次训练:

    1. 初次训练时,可以将alpha设置为一个值全为1的列表,使各个类别平等的去影响损失。

    2. 测试结果后,针对想提高的类别,给予一个较其他类大的权重值,加大该类对损失的影响,继续训练模型,使模型在训练时更倾向于该类。例如5分类任务中,设置alpha=[1, 1, 2, 3, 1],加大第三类、第四类对损失的影响,提高这两类的分类精度。

  • gamma参数

    gamma参数只推荐设置为2

交叉熵损失

cross_empty

带平衡因子的交叉熵

α-cross_empty

Focal损失

加入 (1-pt)γ 平衡难易样本的权重,通过γ缩放因子调整,retinanet默认γ=2

focal loss

带平衡因子的Focal损失

论文中最终为带平衡因子的focal loss, 本项目实现的也是这个版本

α-focal loss

最终retinanet的效果

不同γ 值收敛效果

focal loss_效果

retinanet与其他检测模型对比

retinanet对比图

focal-loss-pytorch's People

Contributors

yatenglg avatar

Stargazers

 avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar

Watchers

 avatar  avatar  avatar  avatar  avatar

focal-loss-pytorch's Issues

老哥,第40行是不是应该用reshape啊?

老哥的代码看起来很舒服啊,但是用的时候第40行preds = preds.view(-1,preds.size(-1))报错了,改成reshape就通过了。正在使用,你看看是不是这样改,我也怕等等训爆炸了2333

typos in your readme

你好,

一个小提醒,在你的ReadMe文件中,RetinaNet, 你写成了RetainNet.
谢谢.
祝愉快!

关于α参数

论文中说 α for class 1 and 1−α for class −1,class 1和class -1应该分别指的是正样本和负样本(背景类),可我看您的实现好像是反的?

这个报错该怎么解决?

self.alpha = self.alpha.gather(0,labels.view(-1))

RuntimeError: Invalid index in gather at C:\w\1\s\windows\pytorch\aten\src\TH/generic/THTensorEvenMoreMath.cpp:657

关于背景类

博主您好!
我研读了您focalloss.py的代码,如果label是otherwise而不是1,我看代码好像并没有计算相应的损失值。
请问如果label是otherwise是不需要计算损失吗?

代码bug

self.alpha = self.alpha.gather(0,labels.view(-1))

因为上一阶段self.alpha被赋值后self.alpha数值变了。gather后的结果就会有问题。
另外如果这样写,样本维度变化,如果后一批样本比前一批样本维度大,会报错 :Invalid index in gather
建议改成
alpha = self.alpha.gather(0,labels.view(-1))
...
loss = torch.mul(alpha, loss.t())

NER中num_class怎么设置?

您好,感谢您的分享。想请教一下,在NER任务中,比如有['O','B-person','I-person','[PAD]','[CLS]','[SEP]']这几个标签,num_class应该设置多少?还有就是NER中preds和labels的输入是?可以举个例子么?

原始论文中提到,应该给样本数量较少的类别较小的权重

您好,非常感谢您的工作,我最近在看Focal Loss的时候,发现论文中提到,应该给原始类别较少的类别较小的权重,即设背景类是0,目标类是1的话,那么应该令$\alpha=[0.75, 0.25]$才对,这是否和您写的代码有出入呢?

还望不吝赐教,不知道是您写的有问题,还是我理解错了,感谢!

alpha bug

Hi, yatengLG
对于focal_loss中
self.alpha = self.alpha.gather(0,labels.view(-1))

假设多分类num_classes=10,batch_size=8
第一轮self.alpha.size=10, 第二轮self.alpha=8,这好像是bug吧

当batch size是1的时候报错

在代码中把batchsize设置为1,在第二个batch的时候,运行过 self.alpha = self.alpha.gather(0, labels.view(-1))
这句代码之后,会报越界错误,具体表现为所有tensor的值都会变成Unable to get repr for <class 'torch.Tensor'>。应该怎么更改呢?

alpha如何设置?

你好,我在做图片5分类,数目分别是A类280张 、C类 313 、D类1801 、G类 326和N类 2157,想解决数据不平衡问题,我的alpha和gamma怎么设置呢?期待您的回复!

数值不稳定问题

由于数值溢出问题,当preds中错误类别的数值绝对值特别大的时候,loss容易计算得到inf

建议不要采用softmax+log,采用log_softmax+exp会更好。

preds_softmax = F.softmax(preds, dim=1) # 这里并没有直接使用log_softmax, 因为后面会用到softmax的结果(当然你也可以使用log_softmax,然后进行exp操作)
preds_logsoft = torch.log(preds_softmax)

图像分割使用问题

您好,打扰您了,想请教一下该损失函数能用于图像分割吗,我直接使用会报关于维度的错误。

请问如何设计正则化来防止focal loss带来的过拟合会比较好?可以帮忙给出一个样例参考吗?

你好,非常感谢你的代码贡献。我在现在这样简单的优化设定下过拟合很严重:
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=1e-5)
scheduler = StepLR(optimizer, step_size=40, gamma=0.1)
看到你之前的回复说需要设置正则化来抗过拟合,请问可以帮忙给出一个正则化的方式吗?非常感谢!

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.