Comments (12)

fchollet commented on May 14, 2024

Absolutely! I think it's a pretty hard issue. The problem is very non-obvious: everything looks fine piecewise. TF and torch train well, but JAX trains very poorly (although it does still train). The code is here:

https://gist.github.com/fchollet/f0c84ecbed8441e54820df8366a5a629

from keras-core.

qlzh727 commented on May 14, 2024

OK, I was able to verify that my fix in #888 resolves this issue.

See https://colab.corp.google.com/drive/1z_QDD0uX9ApLJdFTxhHYxJNejUdgowkG#scrollTo=kxjMd749C1nA. I will send a PR very soon.

fchollet commented on May 14, 2024

I verified that the two backends give identical forward-pass numerics for the model.

The saving issue was unrelated, and I've now fixed it.

So there must be some difference in either the initializers or the optimization process.

fchollet commented on May 14, 2024

I ruled out initialization, so it can really only be an optimizer issue or a trainer issue.

fchollet commented on May 14, 2024

It trains fine in torch as well.

shivance commented on May 14, 2024

Hi @fchollet, can I contribute here?

ariG23498 commented on May 14, 2024

While working on the code, the "tensorflow" backend threw an error.

Here is the Gist to reproduce the error: https://gist.github.com/ariG23498/b8b4c0912a0a19dfe2ef8b29b3160943

fchollet commented on May 14, 2024

The error message is basically telling you that cuDNN can't be compiled with XLA. It's somewhat tricky to solve on our side: either we disable jit_compile if there's a cuDNN-enabled layer, or we don't use cuDNN if we detect that we're tracing for XLA.
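The two options in that comment can be sketched as two small policy functions. This is an illustration of the trade-off only, not Keras code: `LayerSpec`, `uses_cudnn`, `resolve_jit_compile`, and `select_rnn_kernel` are hypothetical names invented for this sketch.

```python
from dataclasses import dataclass


@dataclass
class LayerSpec:
    # Hypothetical stand-in for a layer; `uses_cudnn` is an assumed flag,
    # not a real Keras attribute.
    name: str
    uses_cudnn: bool = False


def resolve_jit_compile(layers, requested="auto"):
    """Option 1: disable XLA compilation when any layer relies on cuDNN."""
    has_cudnn = any(layer.uses_cudnn for layer in layers)
    if requested == "auto":
        return not has_cudnn
    # An explicit True is downgraded to False when a cuDNN layer is present.
    return bool(requested) and not has_cudnn


def select_rnn_kernel(cudnn_available, tracing_for_xla):
    """Option 2: keep XLA on, but skip the cuDNN kernel while tracing for XLA."""
    if cudnn_available and not tracing_for_xla:
        return "cudnn"
    return "generic"


model_layers = [
    LayerSpec("embedding"),
    LayerSpec("lstm", uses_cudnn=True),
    LayerSpec("dense"),
]
print(resolve_jit_compile(model_layers))  # False
print(select_rnn_kernel(cudnn_available=True, tracing_for_xla=True))  # generic
```

Option 1 keeps the fast cuDNN kernel but gives up XLA for the whole model; option 2 keeps XLA but falls back to a slower generic kernel for the affected layers.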

fchollet commented on May 14, 2024

To work around this, you can just pass jit_compile=False to compile() when using TF.
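A minimal sketch of that workaround, assuming Keras 3 (the successor to keras-core, imported as `keras`) with TensorFlow installed. The LSTM model and random data are placeholders, not the model from the gist.

```python
import os

# Select the TensorFlow backend before Keras is imported.
os.environ["KERAS_BACKEND"] = "tensorflow"

import numpy as np
import keras

# Placeholder model: an LSTM is the kind of layer that may dispatch to cuDNN on GPU.
model = keras.Sequential([
    keras.Input(shape=(10, 8)),
    keras.layers.LSTM(16),
    keras.layers.Dense(1),
])

# Disable XLA compilation so the training step is not traced for XLA.
model.compile(optimizer="adam", loss="mse", jit_compile=False)

x = np.random.rand(32, 10, 8).astype("float32")
y = np.random.rand(32, 1).astype("float32")
history = model.fit(x, y, epochs=1, verbose=0)
```

The cost of this workaround is losing XLA compilation for this model's train step.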

ariG23498 commented on May 14, 2024

Right! Now it trains with the TF backend.

It seems the same issue persists when loading the saved model. Do I also need to pass compile=False when loading the saved model (when using TensorFlow as the backend)?

fchollet commented on May 14, 2024

I think you can just set jit_compile = False on the model.
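For the loading case, here is a minimal sketch of that suggestion. It assumes Keras 3 with TensorFlow installed and that `jit_compile` is a settable attribute on the loaded model, as the comment suggests; the model, the `model.keras` file name, and the data are placeholders.

```python
import os
import tempfile

# Select the TensorFlow backend before Keras is imported.
os.environ["KERAS_BACKEND"] = "tensorflow"

import numpy as np
import keras

# Placeholder model compiled with default settings, then saved and reloaded.
model = keras.Sequential([
    keras.Input(shape=(10, 8)),
    keras.layers.LSTM(16),
    keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss="mse")

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "model.keras")
    model.save(path)
    reloaded = keras.models.load_model(path)

# Set the flag on the reloaded model before calling fit(), so the
# training function is built without XLA.
reloaded.jit_compile = False

x = np.random.rand(32, 10, 8).astype("float32")
y = np.random.rand(32, 1).astype("float32")
history = reloaded.fit(x, y, epochs=1, verbose=0)
```

Setting the attribute before the first `fit()` call matters in this sketch: the training function is built on first use and may not be rebuilt if the flag changes afterwards.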

fchollet commented on May 14, 2024

@ariG23498 I have fixed it in commit c9bce12; please check that it works for you.