Absolutely! It's a pretty hard issue, I think. The problem is very non-obvious: everything looks fine piecewise. TF and torch train well, but JAX trains very poorly (though it does still train). The code is here:
https://gist.github.com/fchollet/f0c84ecbed8441e54820df8366a5a629
from keras-core.
OK, I was able to verify that my fix in #888 resolves this issue.
See https://colab.corp.google.com/drive/1z_QDD0uX9ApLJdFTxhHYxJNejUdgowkG#scrollTo=kxjMd749C1nA. I will send a PR very soon.
I verified that the two backends give identical forward-pass numerics for the model.
The saving issue was unrelated, and I've now fixed it.
So the difference must come from either the initializers or the optimization process.
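A minimal sketch of this kind of cross-backend check, assuming you have already run the same model, with the same weights, on the same input batch under each backend and exported the outputs as NumPy arrays (for example with `np.save`). The helper `compare_forward_outputs` and its tolerances are hypothetical, not part of Keras:

```python
import numpy as np


def compare_forward_outputs(out_a, out_b, rtol=1e-5, atol=1e-6):
    """Hypothetical helper: compare two forward-pass outputs (e.g. JAX vs TF).

    Returns the maximum absolute difference and whether the two outputs
    agree within the given relative/absolute tolerances.
    """
    a = np.asarray(out_a, dtype=np.float64)
    b = np.asarray(out_b, dtype=np.float64)
    if a.shape != b.shape:
        raise ValueError(f"Shape mismatch: {a.shape} vs {b.shape}")
    max_abs_diff = float(np.max(np.abs(a - b)))
    return max_abs_diff, bool(np.allclose(a, b, rtol=rtol, atol=atol))


# Stand-in arrays; in practice these would come from np.load(...) per backend.
jax_out = np.array([[0.1, 0.9], [0.7, 0.3]], dtype=np.float32)
tf_out = jax_out + np.float32(1e-7)
print(compare_forward_outputs(jax_out, tf_out))
```

If the outputs match like this, the forward pass is ruled out and any training gap has to come from something that only runs during training.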
I've ruled out initialization, so it can only really be an optimizer or trainer issue.
It trains fine in torch as well.
Hi @fchollet, can I contribute here?
While working on the code, the TensorFlow backend threw an error.
Here is a Gist that reproduces it: https://gist.github.com/ariG23498/b8b4c0912a0a19dfe2ef8b29b3160943
Basically, the error message says that cuDNN ops can't be compiled with XLA. This is somewhat tricky to solve on our side: should we disable `jit_compile` when the model contains a cuDNN-enabled layer, or avoid using cuDNN when we detect that we're tracing for XLA?
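The first option could look roughly like the sketch below. It is purely hypothetical: `LayerInfo`, its `uses_cudnn` flag, and `resolve_jit_compile` are illustrative stand-ins, not Keras APIs, and the layer names are made up. The second option (skipping cuDNN while tracing for XLA) would live inside the layer implementation instead.

```python
from dataclasses import dataclass


@dataclass
class LayerInfo:
    # Hypothetical stand-in for a layer; `uses_cudnn` is NOT a real Keras attribute.
    name: str
    uses_cudnn: bool = False


def resolve_jit_compile(layers, jit_compile=True):
    """Option 1: turn XLA compilation off if any layer would dispatch to cuDNN."""
    if jit_compile and any(layer.uses_cudnn for layer in layers):
        return False
    return jit_compile


layers = [
    LayerInfo("embedding"),
    LayerInfo("recurrent", uses_cudnn=True),  # hypothetical cuDNN-backed layer
    LayerInfo("dense"),
]
print(resolve_jit_compile(layers))  # False: one layer is flagged as cuDNN-backed
```

The trade-off is that this option gives up XLA for the whole model whenever a single cuDNN layer is present, whereas the second option keeps XLA and gives up the cuDNN kernel instead.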
To work around this, you can just pass `jit_compile=False` to `compile()` when using TF.
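A minimal sketch of this workaround, assuming Keras 3 (the `keras` package that keras-core became) with TensorFlow installed. The small model and random data are placeholders, not the model from the gist:

```python
import os

# Select the backend before Keras is imported.
os.environ["KERAS_BACKEND"] = "tensorflow"

import numpy as np
import keras

# Placeholder model; substitute the model from the reproduction gist.
model = keras.Sequential([
    keras.Input(shape=(8,)),
    keras.layers.Dense(4, activation="relu"),
    keras.layers.Dense(1),
])

# The workaround: skip XLA compilation so cuDNN-backed layers are never
# traced for XLA.
model.compile(optimizer="adam", loss="mse", jit_compile=False)

x = np.random.rand(32, 8).astype("float32")
y = np.random.rand(32, 1).astype("float32")
history = model.fit(x, y, epochs=1, verbose=0)
```

With `jit_compile=False`, Keras still wraps the train step in `tf.function` (unless `run_eagerly=True`); it only skips the XLA compilation step.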
Right! Now it trains with the TF backend.
It seems the same issue persists when loading the saved model. Do I also need to pass `compile=False` when loading the saved model (with TensorFlow as the backend)?
I think you can just set `model.jit_compile = False` on the loaded model.
@ariG23498 I have fixed it in this commit: c9bce12. Please check that it works for you.
Related Issues (20)
- Problem with `run_output_asserts`. Bug?
- tflite cannot fuse `BatchNormalization` in Keras Core as effectively as in the original Keras
- Improve Documentation on KERAS_BACKEND
- How can we use `ops.reshape` when the `new_shape` is obtained from `ops.shape`?
- GSOC Program Organization (question)
- Constant Initializer supports only Scalar values
- Add support for RaggedTensors to JAX and Pytorch Backend
- Saving broken between tf.keras and Keras Core
- Ensure workflow reliability by hash-pinning GitHub Actions
- Run actions.yml with read-only permissions
- Segmentation metrics
- Use of PyTorch loss functions inside Keras
- Is possible serialize models which use torch functions?
- Expose `Operation` without `src` import hacks
- v0.1.6 bug: AttributeError: 'GPT2CausalLM' object has no attribute 'compiled'
- Adam with amsgrad=True + JAX backend is broken
- Inconsistent type handling between backends
- Matmul - tensorflow does not broadcast/expand dimensions correctly
- Casting dtype in losses.Loss base class
- TypeError: copy() got an unexpected keyword argument 'overwrite'