Comments (13)
Yes, indeed at the moment XLA builds on GPU aren't fully reproducible, e.g. google/jax#565. I'll check with the JAX team to learn more.
from flax.
I confirm that for the MNIST example, this issue is solved. With the following command:

```shell
!export XLA_FLAGS=--xla_gpu_deterministic_reductions && export TF_CUDNN_DETERMINISTIC=1 && echo $XLA_FLAGS, $TF_CUDNN_DETERMINISTIC && python main.py
```

results are consistent between 2 runs (on Google Colab):

--xla_gpu_deterministic_reductions, 1
2020-10-28 16:14:10.248318: W tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:39] Overriding allow_growth setting because the TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original config value was 0.
First run:
eval epoch: 1, loss: 0.0563, accuracy: 98.14
eval epoch: 2, loss: 0.0561, accuracy: 98.21
eval epoch: 3, loss: 0.0401, accuracy: 98.79
eval epoch: 4, loss: 0.0365, accuracy: 98.89
eval epoch: 5, loss: 0.0359, accuracy: 98.95
eval epoch: 6, loss: 0.0360, accuracy: 98.94
eval epoch: 7, loss: 0.0303, accuracy: 99.16
eval epoch: 8, loss: 0.0418, accuracy: 98.93
eval epoch: 9, loss: 0.0406, accuracy: 99.03
eval epoch: 10, loss: 0.0326, accuracy: 99.18

Second run:
eval epoch: 1, loss: 0.0563, accuracy: 98.14
eval epoch: 2, loss: 0.0561, accuracy: 98.21
eval epoch: 3, loss: 0.0401, accuracy: 98.79
eval epoch: 4, loss: 0.0365, accuracy: 98.89
eval epoch: 5, loss: 0.0359, accuracy: 98.95
eval epoch: 6, loss: 0.0360, accuracy: 98.94
eval epoch: 7, loss: 0.0303, accuracy: 99.16
eval epoch: 8, loss: 0.0418, accuracy: 98.93
eval epoch: 9, loss: 0.0406, accuracy: 99.03
eval epoch: 10, loss: 0.0326, accuracy: 99.18
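Note that these environment variables must be in place before jax/jaxlib initializes its GPU backend. When working from a notebook rather than a shell, one option is to set them in Python before the first `import jax` (a sketch, not from the thread; setting them after JAX has already run a computation has no effect):

```python
import os

# Must happen before the first `import jax` in a fresh process,
# because XLA reads these variables when the backend is initialized.
os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_reductions"
os.environ["TF_CUDNN_DETERMINISTIC"] = "1"

# import jax  # only import JAX after the flags are in place
```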
Hi @goingtosleep! There's a good chance this is due to TFDS. Here's what would be great: Could you dump MNIST entirely to disk once, and then modify the dataset loading code to read from disk? I believe then the training runs will be fully reproducible.
I'd also like to share this with the TFDS team, as I believe there should be a way to get reproducible dataset loaders, though we haven't yet done that ourselves.
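A minimal sketch of that dump-once workflow, using synthetic arrays as stand-ins for the real TFDS download (the `mnist_dump.npz` filename and the truncated example count are illustrative, not from the thread):

```python
import numpy as np

# Stand-ins for what tfds.as_numpy() would yield; the real MNIST train
# split has 60_000 examples, truncated here to keep the sketch small.
train_images = np.zeros((1000, 28, 28, 1), dtype=np.uint8)
train_labels = np.zeros((1000,), dtype=np.int64)

# Dump the dataset to disk once...
np.savez("mnist_dump.npz", images=train_images, labels=train_labels)

# ...then have the training script read plain numpy arrays back,
# taking tf.data (and any nondeterminism in it) out of the loop.
data = np.load("mnist_dump.npz")
images, labels = data["images"], data["labels"]
```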
An easy way to test this hypothesis would be to move the data pipeline outside of the train() function, convert the datasets to numpy, and materialize the generator with list():

```python
import tensorflow_datasets as tfds

train_ds = tfds.load('mnist', split='train')
train_ds = list(tfds.as_numpy(train_ds.cache().batch(128)))
test_ds = tfds.as_numpy(tfds.load('mnist', split='test', batch_size=-1))

def train():
    ...
    for batch in train_ds:
        batch = dict(batch)  # Copy dict before mutating in-place
        ...

train()
train()
```

That way there is no more tf.data involved, as train_ds is just a List[Dict[str, np.ndarray]].
Edit: Could you also share which versions of Flax, TFDS and Python you're using?
On what accelerator did you run this test?
On GPU, reproducibility is never guaranteed because JAX is currently not deterministic on that platform, even when executing the exact same computations.
BTW: we should probably mention this somewhere in the docs, and so should JAX!
On CPU I was unable to reproduce your issue, and TPU results should also be reproducible.
Thank you all for the replies.
I used flax-0.0.1a0, jax-0.1.59, jaxlib-0.1.39, tfds 2.0.0, on a Google Colab P100 instance (Python 3.6.9).
I've done some tests so far:
- On CPU, with the random seed fixed, the result is reproducible every run.
- On GPU (Colab), I use keras datasets, which return numpy ndarrays. Then I train the model on the first 25_000 data points (due to memory limit), with no permutation and 1 batch only. Test accuracy is calculated on the test set (10_000 data points) as usual. With this setup I believe there is no randomness involved except for weight initialization. The results are as follows:
First run:
Epoch 1 in [17.32s], loss [4.726857], accuracy [34.37%]
Epoch 2 in [0.02s], loss [3.494419], accuracy [70.59%]
Epoch 3 in [0.02s], loss [2.420179], accuracy [84.71%]
Epoch 4 in [0.02s], loss [1.485580], accuracy [88.12%]
Epoch 5 in [0.02s], loss [0.923213], accuracy [91.93%]
Second run:
Epoch 1 in [9.93s], loss [4.726847], accuracy [34.37%]
Epoch 2 in [0.02s], loss [3.494333], accuracy [70.60%]
Epoch 3 in [0.02s], loss [2.420010], accuracy [84.70%]
Epoch 4 in [0.02s], loss [1.485356], accuracy [88.11%]
Epoch 5 in [0.02s], loss [0.923163], accuracy [91.93%]
Third run:
Epoch 1 in [9.81s], loss [4.726859], accuracy [34.40%]
Epoch 2 in [0.02s], loss [3.494327], accuracy [70.59%]
Epoch 3 in [0.02s], loss [2.419940], accuracy [84.70%]
Epoch 4 in [0.02s], loss [1.485139], accuracy [88.11%]
Epoch 5 in [0.02s], loss [0.923007], accuracy [91.94%]
Rerun after restarting the runtime:
Epoch 1 in [19.73s], loss [4.726849], accuracy [34.37%]
Epoch 2 in [0.02s], loss [3.494359], accuracy [70.59%]
Epoch 3 in [0.02s], loss [2.420025], accuracy [84.70%]
Epoch 4 in [0.02s], loss [1.485441], accuracy [88.10%]
Epoch 5 in [0.02s], loss [0.923208], accuracy [91.93%]
I used a different network architecture so the accuracy and loss could be different from the MNIST example, but the network is deterministic (no dropout), so no randomness involved here in the training process.
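To rule out weight initialization as the source of the drift, one option is to fingerprint the parameters right after init and compare the digest across runs: identical digests mean the divergence must come from the training computation itself. A sketch (the `params` array here is a hypothetical numpy stand-in for the real model weights):

```python
import hashlib
import numpy as np

def fingerprint(params: np.ndarray) -> str:
    """Return a hex digest of the exact bytes of an array."""
    return hashlib.sha256(np.ascontiguousarray(params).tobytes()).hexdigest()

# Hypothetical stand-in for freshly initialized weights.
params = np.random.default_rng(seed=0).normal(size=(784, 10)).astype(np.float32)

# A second "run" with the same seed yields bit-identical parameters,
# so the digests match exactly.
params2 = np.random.default_rng(seed=0).normal(size=(784, 10)).astype(np.float32)
assert fingerprint(params) == fingerprint(params2)
```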
I'm not sure if the differences in the above results are due to float32 precision. I tried

```python
from jax.config import config
config.update("jax_enable_x64", True)
```

but the weights are still float32; maybe I will test this later. What do you think?
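The float32 hypothesis is plausible: non-deterministic GPU reductions change the order in which terms are summed, and in float32 the summation order changes the result. A minimal numpy illustration (not from the thread):

```python
import numpy as np

x = np.array([1e8, 1.0, -1e8], dtype=np.float32)

# Left-to-right: 1e8 + 1 rounds back to 1e8 (the 1 is absorbed
# because float32 spacing near 1e8 is 8), then + (-1e8) gives 0.
left_to_right = (x[0] + x[1]) + x[2]

# Reordered: 1e8 + (-1e8) cancels exactly first, so the 1 survives.
reordered = (x[0] + x[2]) + x[1]

print(left_to_right, reordered)  # 0.0 1.0
```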
Can we close this for now? I think both the XLA and JAX teams are aware of this issue and the fix is in progress.
If the fix isn't in, I don't think we should close the issue.
google/jax#565 is fixed. I've verified that @goingtosleep's code now outputs reproducible results on a TPU, so I am closing this issue.
Sorry, the fix isn't in. JAX is still not reproducible on GPU. We need to make sure there's an open ticket on the JAX GitHub tracker.
Hi @goingtosleep -- this is a bit late, but the XLA_FLAGS=--xla_gpu_deterministic_reductions environment variable should work (though perhaps not yet for all operations). I'd be curious to see if this solves the issue.
We're working on improving the JAX documentation on this in google/jax#4824. Feedback on that PR is welcome!
It sounds like the --xla_gpu_deterministic_reductions flag is now gone (or it will be when we push an updated jaxlib), because it's now effectively always on by default. So hopefully this will get less surprising...