artemmavrin / focal-loss

TensorFlow implementation of focal loss

Home Page: https://focal-loss.readthedocs.io

License: Apache License 2.0

Languages: Python 98.50%, Makefile 1.50%
Topics: deep-learning, tensorflow, keras, loss-functions

focal-loss's Introduction

Focal Loss


TensorFlow implementation of focal loss [1]: a loss function generalizing binary and multiclass cross-entropy loss that down-weights easy, well-classified examples so that training focuses on hard-to-classify ones.

The focal_loss package provides functions and classes that can be used as off-the-shelf replacements for tf.keras.losses functions and classes, respectively.

# Typical tf.keras API usage
import tensorflow as tf
from focal_loss import BinaryFocalLoss

model = tf.keras.Model(...)
model.compile(
    optimizer=...,
    loss=BinaryFocalLoss(gamma=2),  # Used here like a tf.keras loss
    metrics=...,
)
history = model.fit(...)

The focal_loss package includes the functions

  • binary_focal_loss
  • sparse_categorical_focal_loss

and wrapper classes

  • BinaryFocalLoss (use like tf.keras.losses.BinaryCrossentropy)
  • SparseCategoricalFocalLoss (use like tf.keras.losses.SparseCategoricalCrossentropy)
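The functional forms can also be called directly on label and prediction tensors. A minimal sketch (the tensors and gamma value below are purely illustrative; the full set of keyword arguments is described in the documentation):

import tensorflow as tf
from focal_loss import binary_focal_loss, sparse_categorical_focal_loss

# Binary case: y_true holds 0/1 labels, y_pred holds predicted probabilities
y_true = tf.constant([0, 1, 1], dtype=tf.float32)
y_pred = tf.constant([0.1, 0.7, 0.9], dtype=tf.float32)
per_example_loss = binary_focal_loss(y_true, y_pred, gamma=2)

# Multiclass case: y_true holds integer class indices, y_pred holds a
# probability distribution over classes for each example
y_true_mc = tf.constant([0, 2])
y_pred_mc = tf.constant([[0.8, 0.1, 0.1],
                         [0.2, 0.2, 0.6]])
per_example_loss_mc = sparse_categorical_focal_loss(y_true_mc, y_pred_mc, gamma=2)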

Documentation is available at Read the Docs: https://focal-loss.readthedocs.io

(Figure: focal loss plot)

Installation

The focal_loss package can be installed using the pip utility. For the latest version, install directly from the package's GitHub page:

pip install git+https://github.com/artemmavrin/focal-loss.git

Alternatively, install a recent release from the Python Package Index (PyPI):

pip install focal-loss

Note. To install the project for development (e.g., to make changes to the source code), clone the project repository from GitHub and run make dev:

git clone https://github.com/artemmavrin/focal-loss.git
cd focal-loss
# Optional but recommended: create and activate a new environment first
make dev

This will additionally install the requirements needed to run tests, check code coverage, and produce documentation.

References

[1] T. Lin, P. Goyal, R. Girshick, K. He, and P. Dollár. Focal loss for dense object detection. IEEE Transactions on Pattern Analysis and Machine Intelligence, 2018. (DOI) (arXiv preprint)

focal-loss's People

Contributors

artemmavrin

focal-loss's Issues

SparseCategoricalFocalLoss with class_weight?

Hm... maybe this is a weird question.

But can SparseCategoricalFocalLoss be used together with class_weight (defined, for example, in the model.fit function)? Or is that a bad idea?

Thanks for the answer. :)
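For context, Keras's class_weight argument to model.fit scales each example's loss by the weight assigned to its true class, so mechanically it composes with the focal term; whether the combined re-weighting is what you want is a modelling question. A minimal sketch with made-up data and weights:

import numpy as np
import tensorflow as tf
from focal_loss import SparseCategoricalFocalLoss

# Toy data: 4 features, 3 classes (purely illustrative)
x_train = np.random.rand(32, 4).astype("float32")
y_train = np.random.randint(0, 3, size=(32,))

model = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation="relu", input_shape=(4,)),
    tf.keras.layers.Dense(3, activation="softmax"),
])
model.compile(optimizer="adam", loss=SparseCategoricalFocalLoss(gamma=2))

# class_weight multiplies each sample's (already focal-modulated) loss
# by the weight of its true class
model.fit(x_train, y_train, epochs=1, class_weight={0: 1.0, 1: 2.0, 2: 0.5})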

installing focal loss leads to tensorflow stopping

Just to notify others: when I installed focal-loss using pip install focal-loss, I couldn't import the tensorflow or keras libraries.
To avoid this happening to you, please note that you have to add --user to the installation command, as below:

pip install focal-loss --user

SparseCategoricalFocalLoss: ignore class?

Hi,

TensorFlow's implementation of categorical cross-entropy has accepted a parameter called ignore_class (see here) since v2.10. In principle, it can be useful for NER tasks. It would be nice if such a parameter appeared in the focal loss calculation as well, so that an apples-to-apples comparison can be carried out.
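In the meantime, one possible workaround (a sketch, not part of this package's API) is to zero out the contribution of ignored positions via per-example sample weights:

import tensorflow as tf
from focal_loss import SparseCategoricalFocalLoss

IGNORE_CLASS = 0  # hypothetical index of the label value to ignore

loss_fn = SparseCategoricalFocalLoss(gamma=2)

y_true = tf.constant([0, 1, 2])  # the first example carries the ignored label
y_pred = tf.constant([[0.7, 0.2, 0.1],
                      [0.1, 0.8, 0.1],
                      [0.2, 0.2, 0.6]])

# Zero weight for positions labeled with the ignored class
sample_weight = tf.cast(tf.not_equal(y_true, IGNORE_CLASS), tf.float32)
loss = loss_fn(y_true, y_pred, sample_weight=sample_weight)

Note that the default reduction still averages over the full batch, so this only approximates true ignore_class semantics.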

alpha weight for negative loss

Hi,

From line 512 onward of _binary_focal_loss.py:

def _binary_focal_loss_from_probs(labels, p, gamma, pos_weight,
                                  label_smoothing):
    """Compute focal loss from probabilities.

    Parameters
    ----------
    labels : tensor-like
        Tensor of 0's and 1's: binary class labels.

    p : tf.Tensor
        Estimated probabilities for the positive class.

    gamma : float
        Focusing parameter.

    pos_weight : float or None
        If not None, losses for the positive class will be scaled by this
        weight.

    label_smoothing : float or None
        Float in [0, 1]. When 0, no smoothing occurs. When positive, the binary
        ground truth labels `y_true` are squeezed toward 0.5, with larger values
        of `label_smoothing` leading to label values closer to 0.5.

    Returns
    -------
    tf.Tensor
        The loss for each example.
    """
    # Predicted probabilities for the negative class
    q = 1 - p

    # For numerical stability (so we don't inadvertently take the log of 0)
    p = tf.math.maximum(p, _EPSILON)
    q = tf.math.maximum(q, _EPSILON)

    # Loss for the positive examples
    pos_loss = -(q ** gamma) * tf.math.log(p)
    if pos_weight is not None:
        pos_loss *= pos_weight

    # Loss for the negative examples
    neg_loss = -(p ** gamma) * tf.math.log(q)

    # Combine loss terms
    if label_smoothing is None:
        labels = tf.dtypes.cast(labels, dtype=tf.bool)
        loss = tf.where(labels, pos_loss, neg_loss)
    else:
        labels = _process_labels(labels=labels, label_smoothing=label_smoothing,
                                 dtype=p.dtype)
        loss = labels * pos_loss + (1 - labels) * neg_loss

    return loss

With the negative loss calculation:

# Loss for the negative examples
neg_loss = -(p ** gamma) * tf.math.log(q)

Shouldn't there be a '1-pos_weight' weighting factor just like for the pos_loss in the lines above it?
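For reference, the alpha-balanced focal loss in the paper scales the positive term by alpha and the negative term by 1 - alpha. A sketch of that variant in terms of the p, q, and gamma defined in the snippet above (this is not what the package currently does, and alpha is a hypothetical parameter in [0, 1]):

# Hypothetical alpha-balanced variant: alpha plays the role of pos_weight
# and (1 - alpha) scales the negative term
pos_loss = -alpha * (q ** gamma) * tf.math.log(p)
neg_loss = -(1 - alpha) * (p ** gamma) * tf.math.log(q)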

Unknown y_true tensor

Hi, I'm trying to use the categorical focal loss in my image segmentation task. The dataset is defined with a tf.data.Dataset object and the model is defined with a Keras Model. The model is compiled like this:

loss_gamma = [0.5, 1., ...]
model.compile(
    optimizer=tf.keras.optimizers.Adam(lr=lr),
    loss=SparseCategoricalFocalLoss(gamma=loss_gamma),
...)
model.fit(...)

While training on the segmentation task, an assertion error is raised because the y_true tensor is unknown.
https://github.com/artemmavrin/focal-loss/blob/master/src/focal_loss/_categorical_focal_loss.py#L136-L141
How do I define the true tensor? In my task, the true tensor has shape (BATCH, HEIGHT, WIDTH). My virtual environment is on Ubuntu 18.04 with TensorFlow 2.2.0.
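One common workaround when label shapes come out of a tf.data pipeline as unknown is to pin them explicitly in a map step, e.g. with tf.ensure_shape. A sketch; the HEIGHT and WIDTH values and the dataset variable are placeholders for your pipeline:

import tensorflow as tf

HEIGHT, WIDTH = 256, 256  # placeholder spatial dimensions

def set_shapes(image, label):
    # Attach the static shape information TensorFlow could not infer
    image = tf.ensure_shape(image, (HEIGHT, WIDTH, 3))
    label = tf.ensure_shape(label, (HEIGHT, WIDTH))
    return image, label

# dataset = dataset.map(set_shapes).batch(batch_size)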

Support mixed precision

Right now the loss is hard-coded to use float32, which means smaller data types cannot be used. Perhaps it would be possible to add an argument to the wrappers to allow the dtype to be configured?

In any case this library is incredibly useful, thank you for it!
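Until the loss supports other dtypes natively, one possible workaround (a sketch, not part of the package) is a thin subclass that casts the model's lower-precision predictions back to float32 before the loss is computed:

import tensorflow as tf
from focal_loss import BinaryFocalLoss

class Float32BinaryFocalLoss(BinaryFocalLoss):
    """Hypothetical wrapper: cast predictions to float32 before the loss."""

    def call(self, y_true, y_pred):
        return super().call(y_true, tf.cast(y_pred, tf.float32))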

Sparse tensor input

Hello,
I'm trying to replace tf.losses.SparseCategoricalCrossentropy with your loss, but I think it doesn't accept sparse inputs.
The model is a simple U-Net segmentation network with a softmax end.
Can you point me in the right direction on how to solve this?
Thank you
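If "sparse" here means integer class labels per pixel (as with tf.keras.losses.SparseCategoricalCrossentropy) rather than a tf.SparseTensor, then SparseCategoricalFocalLoss is intended as the drop-in counterpart. A minimal sketch for a softmax-ended segmentation model:

import tensorflow as tf
from focal_loss import SparseCategoricalFocalLoss

model = tf.keras.Model(...)  # your U-Net with a softmax output layer
model.compile(
    optimizer="adam",
    # y_true: integer class indices per pixel; y_pred: per-pixel softmax output
    loss=SparseCategoricalFocalLoss(gamma=2, from_logits=False),
)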

SparseCategoricalFocalLoss not included with the installation

Running on an Apple M1 machine with a miniforge installation. SparseCategoricalFocalLoss is not included in the pip installation:
['BinaryFocalLoss',
 '__author__',
 '__author_email__',
 '__builtins__',
 '__cached__',
 '__copyright__',
 '__description__',
 '__doc__',
 '__file__',
 '__license__',
 '__loader__',
 '__name__',
 '__package__',
 '__path__',
 '__spec__',
 '__url__',
 '__version__',
 '_focal_loss',
 'binary_focal_loss',
 'utils']
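A quick way to check which release actually got installed (if SparseCategoricalFocalLoss is missing, upgrading with pip install --upgrade focal-loss, or installing from GitHub as described above, should bring it in):

import focal_loss

print(focal_loss.__version__)  # installed release
print(dir(focal_loss))         # SparseCategoricalFocalLoss should be listed here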
