
Comments (9)

bermanmaxim commented on May 26, 2024
def keras_lovasz_softmax(y_true, y_pred):
    l_pred = K.expand_dims(K.argmax(y_pred, axis=3), -1)
    l_pred = tf.cast(l_pred, y_true.dtype)
    return lovasz_softmax(l_pred, y_true)

Well, that is exactly the problem, isn't it: y_pred should definitely be 4-D (with scores for each class), while y_true should be 3-D (with the ground-truth label for each pixel), so you should use argmax to reduce y_true, not y_pred.

Also, don't forget to first pass y_pred through a softmax if it is unnormalized, as hinted at in the TF example notebook.
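For reference, a minimal sketch of a corrected wrapper along these lines. This is not the repository's code: `lovasz_softmax_stub` is a placeholder standing in for the repo's TF implementation of `lovasz_softmax(probas, labels)`, used only so the snippet is runnable.

```python
import numpy as np

def lovasz_softmax_stub(probas, labels):
    """Placeholder with the repo's (probas, labels) signature.
    The real loss is the TF implementation from this repository."""
    assert probas.ndim == 4 and labels.ndim == 3
    return float(np.mean(probas))  # dummy value, just to make the sketch runnable

def keras_lovasz_softmax(y_true, y_pred):
    # y_true: one-hot labels (batch, H, W, C) -> integer labels (batch, H, W).
    # Argmax on the LABELS is fine: no gradient is taken w.r.t. the labels.
    labels = np.argmax(y_true, axis=-1)
    # Keep y_pred as 4-D class probabilities and pass it FIRST,
    # since the loss was implemented as lovasz_softmax(probas, labels).
    return lovasz_softmax_stub(y_pred, labels)

y_true = np.eye(4)[np.random.randint(0, 4, size=(2, 8, 8))]  # (2, 8, 8, 4) one-hot
y_pred = np.random.rand(2, 8, 8, 4)                          # (2, 8, 8, 4) probas
loss = keras_lovasz_softmax(y_true, y_pred)
```

In a real Keras model, np.argmax would be K.argmax (or tf.argmax) operating on tensors, and the stub would be the actual lovasz_softmax from this repo.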

from lovaszsoftmax.

bermanmaxim commented on May 26, 2024

> The function did initially throw non-differentiable errors for me; however, swapping the input order and making the above modification to the labels enabled me to train (albeit slowly).

PS @ben2789: sorry about the slowness; it is expected in TensorFlow, unfortunately.

bermanmaxim commented on May 26, 2024

Hi,
The lovasz_softmax loss implemented here expects dim(labels) = (batch_size, width, height) and dim(probas) = (batch_size, width, height, n_classes). But the code should also work if dim(labels) = (batch_size, width, height, 1), so there shouldn't be anything to change there.
Looking at the Keras documentation, the problem might simply be the order of the arguments, since I implemented lovasz_softmax(probas, labels) while Keras expects loss(y_true, y_pred). In that case, simply modify the implementation or define a loss function that inverts the arguments before calling lovasz_softmax.
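One way to do that argument inversion generically is a small adapter; a sketch, with a dummy loss standing in for the real lovasz_softmax:

```python
def swap_args(loss_fn):
    """Wrap a loss written as loss_fn(probas, labels) so that Keras can
    call it with its (y_true, y_pred) convention."""
    def keras_loss(y_true, y_pred):
        # Keras passes the labels first; the wrapped loss wants probas first.
        return loss_fn(y_pred, y_true)
    return keras_loss

# Dummy loss that just records the order in which it received its arguments.
dummy_loss = lambda probas, labels: ("probas", probas, "labels", labels)
keras_ready = swap_args(dummy_loss)
```

With the real loss, one would pass lovasz_softmax in place of the dummy and hand swap_args(lovasz_softmax) to model.compile(loss=...).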

rhoef commented on May 26, 2024

Labels and probabilities need to have the same dimensionality so that a loss function can be used in Keras; in other words, the labels need to be in one-hot notation. It's not about the order of the arguments.

bermanmaxim commented on May 26, 2024

Alright. Well, I'm not an expert with TensorFlow, but doing an argmax on the labels should change them from a one-hot encoding into an integer encoding. There is no issue of differentiability, since you do not want to differentiate with respect to the labels.
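A quick NumPy illustration of that one-hot-to-integer round trip via argmax:

```python
import numpy as np

labels_int = np.array([[0, 2], [1, 3]])        # integer encoding, shape (2, 2)
labels_onehot = np.eye(4)[labels_int]          # one-hot encoding, shape (2, 2, 4)
recovered = np.argmax(labels_onehot, axis=-1)  # back to integer encoding
```

The conversion is lossless for valid one-hot labels, and since labels are constants from the optimizer's point of view, the non-differentiability of argmax here is harmless.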

rhoef commented on May 26, 2024

Unfortunately, the answer does not solve the problem I initially described. Anyway, I found another solution.

ben2789 commented on May 26, 2024

I used argmax on the labels in the loss function (I could instead have skipped one-hot encoding them in the data generator); Keras throws no differentiability exception on this. As I understand it (the paper was a bit beyond me in places), argmaxing the output of the softmax layer would cause an error, but modifying the labels should be fine.

The function did initially throw non-differentiable errors for me; however, swapping the input order and making the above modification to the labels enabled me to train (albeit slowly).

rhoef commented on May 26, 2024

Could you share the code?
Argmax on the labels rather than the predictions: I don't see how that makes sense. The output of the network is one-hot encoded (4 classes), and there is no way to change this. It's a standard use case.

The function works fine outside of training. Here are the exception and the wrapper function (which I already had before I wrote this issue!):

Traceback (most recent call last):
  File "train.py", line 170, in <module>
    batch_size=args.batch_size, run_id=args.run_id, test=args.test, loss=args.loss)
  File "train.py", line 100, in train
    callbacks=[checkpoint, tbc, csvl, lrs])
  File "/usr/local/lib/python3.5/dist-packages/keras/engine/training.py", line 1008, in fit
    self._make_train_function()
  File "/usr/local/lib/python3.5/dist-packages/keras/engine/training.py", line 498, in _make_train_function
    loss=self.total_loss)
  File "/usr/local/lib/python3.5/dist-packages/keras/legacy/interfaces.py", line 91, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.5/dist-packages/keras/optimizers.py", line 470, in get_updates
    grads = self.get_gradients(loss, params)
  File "/usr/local/lib/python3.5/dist-packages/keras/optimizers.py", line 91, in get_gradients
    raise ValueError('An operation has `None` for gradient. '
ValueError: An operation has `None` for gradient. Please make sure that all of your ops have a gradient defined (i.e. are differentiable). Common ops without gradient: K.argmax, K.round, K.eval.
def keras_lovasz_softmax(y_true, y_pred):
    l_pred = K.expand_dims(K.argmax(y_pred, axis=3), -1)
    l_pred = tf.cast(l_pred, y_true.dtype)
    return lovasz_softmax(l_pred, y_true)

tom-bu commented on May 26, 2024

In Keras, a loss function receives its arguments in the opposite order, as loss(y_true, y_pred). That's why the non-differentiable gradient error shows up: you're accidentally taking the argmax of the predictions.
