Comments (13)

avital commented on May 22, 2024

Yes, indeed, at the moment XLA builds on GPU aren't fully reproducible; see e.g. google/jax#565. I'll check with the JAX team to learn more.

goingtosleep commented on May 22, 2024

I confirm that for the MNIST example, this issue is solved with the following command:

!export XLA_FLAGS=--xla_gpu_deterministic_reductions && export TF_CUDNN_DETERMINISTIC=1 && echo $XLA_FLAGS, $TF_CUDNN_DETERMINISTIC && python main.py

Results are consistent between 2 runs (on Google Colab):

--xla_gpu_deterministic_reductions, 1
2020-10-28 16:14:10.248318: W tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:39] Overriding allow_growth setting because the TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original config value was 0.
eval epoch: 1, loss: 0.0563, accuracy: 98.14
eval epoch: 2, loss: 0.0561, accuracy: 98.21
eval epoch: 3, loss: 0.0401, accuracy: 98.79
eval epoch: 4, loss: 0.0365, accuracy: 98.89
eval epoch: 5, loss: 0.0359, accuracy: 98.95
eval epoch: 6, loss: 0.0360, accuracy: 98.94
eval epoch: 7, loss: 0.0303, accuracy: 99.16
eval epoch: 8, loss: 0.0418, accuracy: 98.93
eval epoch: 9, loss: 0.0406, accuracy: 99.03
eval epoch: 10, loss: 0.0326, accuracy: 99.18

eval epoch: 1, loss: 0.0563, accuracy: 98.14
eval epoch: 2, loss: 0.0561, accuracy: 98.21
eval epoch: 3, loss: 0.0401, accuracy: 98.79
eval epoch: 4, loss: 0.0365, accuracy: 98.89
eval epoch: 5, loss: 0.0359, accuracy: 98.95
eval epoch: 6, loss: 0.0360, accuracy: 98.94
eval epoch: 7, loss: 0.0303, accuracy: 99.16
eval epoch: 8, loss: 0.0418, accuracy: 98.93
eval epoch: 9, loss: 0.0406, accuracy: 99.03
eval epoch: 10, loss: 0.0326, accuracy: 99.18
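
For reference, the same flags can also be set from inside Python rather than on the shell command line; a minimal sketch (assuming they are set before jax/tensorflow are first imported, otherwise they may be ignored):

import os

# Must happen before importing jax / tensorflow so the backends pick the flags up.
os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_reductions"
os.environ["TF_CUDNN_DETERMINISTIC"] = "1"

import jax  # imported only after the environment is configured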

avital commented on May 22, 2024

Hi @goingtosleep! There's a good chance this is due to TFDS. Here's what would be great: could you dump MNIST entirely to disk once, and then modify the dataset loading code to read from disk? I believe the training runs will then be fully reproducible.
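
Something along these lines should do it (a rough sketch; the npz filename is arbitrary):

import numpy as np
import tensorflow_datasets as tfds

# Materialize the full split as numpy once and write it to disk.
train = tfds.as_numpy(tfds.load('mnist', split='train', batch_size=-1))
np.savez('mnist_train.npz', image=train['image'], label=train['label'])

# In the training script, read it back with no tf.data involved at all.
data = np.load('mnist_train.npz')
train_images, train_labels = data['image'], data['label']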

I'd also like to share this with the TFDS team, as I believe there should be a way to get reproducible dataset loaders, though we haven't yet done that ourselves.

Conchylicultor commented on May 22, 2024

An easy way to test this hypothesis would be to move the input pipeline outside of the train() fn, convert it to numpy, and materialize the generator with list():

import tensorflow_datasets as tfds

train_ds = tfds.load('mnist', split='train')
train_ds = list(tfds.as_numpy(train_ds.cache().batch(128)))
test_ds = tfds.as_numpy(tfds.load('mnist', split='test', batch_size=-1))

def train():
  ...
  for batch in train_ds:
    batch = dict(batch)  # Copy dict before mutating in-place
    ...

train()
train()

That way there is no more tf.data involved, as train_ds is just a plain Python list of numpy batches.

Edit: Could you also share which versions of Flax, TFDS, and Python you're using?

jheek commented on May 22, 2024

On what accelerator did you run this test?

On GPU, reproducibility is never guaranteed because JAX is currently not deterministic on that platform, even when executing the exact same computations.
BTW: we should probably mention this somewhere in the docs, and so should JAX!

On CPU I was unable to reproduce your issue, and TPU results should also be reproducible.

goingtosleep commented on May 22, 2024

Thank you all for the replies.

I used flax-0.0.1a0, jax-0.1.59, jaxlib-0.1.39, and tfds 2.0.0 on a Google Colab P100 instance (Python 3.6.9).

I've done some tests so far:

  • On CPU, with the random seed fixed, the result is reproducible on every run.

  • On GPU (Colab), I use the Keras datasets, which return numpy ndarrays. I then train the model on the first 25,000 data points (due to the memory limit), with no permutation and a single batch only. Test accuracy is calculated on the test set (10,000 data points) as usual. With this setup I believe there is no randomness involved except for weight initialization. The results are as follows:

First run:

Epoch 1 in [17.32s], loss [4.726857], accuracy [34.37%]
Epoch 2 in [0.02s], loss [3.494419], accuracy [70.59%]
Epoch 3 in [0.02s], loss [2.420179], accuracy [84.71%]
Epoch 4 in [0.02s], loss [1.485580], accuracy [88.12%]
Epoch 5 in [0.02s], loss [0.923213], accuracy [91.93%]

Second run:

Epoch 1 in [9.93s], loss [4.726847], accuracy [34.37%]
Epoch 2 in [0.02s], loss [3.494333], accuracy [70.60%]
Epoch 3 in [0.02s], loss [2.420010], accuracy [84.70%]
Epoch 4 in [0.02s], loss [1.485356], accuracy [88.11%]
Epoch 5 in [0.02s], loss [0.923163], accuracy [91.93%]

Third run:

Epoch 1 in [9.81s], loss [4.726859], accuracy [34.40%]
Epoch 2 in [0.02s], loss [3.494327], accuracy [70.59%]
Epoch 3 in [0.02s], loss [2.419940], accuracy [84.70%]
Epoch 4 in [0.02s], loss [1.485139], accuracy [88.11%]
Epoch 5 in [0.02s], loss [0.923007], accuracy [91.94%]

Rerun after restarting the runtime:

Epoch 1 in [19.73s], loss [4.726849], accuracy [34.37%]
Epoch 2 in [0.02s], loss [3.494359], accuracy [70.59%]
Epoch 3 in [0.02s], loss [2.420025], accuracy [84.70%]
Epoch 4 in [0.02s], loss [1.485441], accuracy [88.10%]
Epoch 5 in [0.02s], loss [0.923208], accuracy [91.93%]

I used a different network architecture, so the accuracy and loss can differ from the MNIST example, but the network is deterministic (no dropout), so there is no randomness involved in the training process.

I'm not sure if the differences in the above results are due to float32 precision. I tried

from jax.config import config
config.update("jax_enable_x64", True)

but the weights are still float32; maybe I will test this later. What do you think?
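
For reference, a minimal check of whether x64 mode took effect (a sketch; the flag has to be applied before any arrays are created):

from jax.config import config
config.update("jax_enable_x64", True)  # must run before creating any arrays

import jax.numpy as jnp

x = jnp.ones(3)
print(x.dtype)  # float64 if x64 mode is active, float32 otherwise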

jheek commented on May 22, 2024

Can we close this for now? I think both the XLA and JAX teams are aware of this issue and a fix is in progress.

avital commented on May 22, 2024

If the fix isn't in, I don't think we should close the issue.

marcvanzee commented on May 22, 2024

google/jax#565 is fixed. I've verified that @goingtosleep's code now outputs reproducible results on a TPU, so I am closing this issue.

avital commented on May 22, 2024

Sorry, the fix isn't in. JAX is still not reproducible on GPU. We need to make sure there's an open ticket on the JAX GitHub tracker.

avital commented on May 22, 2024

Hi @goingtosleep -- this is a bit late, but the XLA_FLAGS=--xla_gpu_deterministic_reductions environment variable should work (though perhaps not yet for all operations). I'd be curious to see if this solves the issue.
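
One way to probe this (a sketch, assuming a GPU runtime; scatter-adds with colliding indices are where atomics-based, non-deterministic accumulation historically showed up) is to repeat the same reduction and compare the results bit-for-bit:

import jax
import jax.numpy as jnp

vals = jax.random.normal(jax.random.PRNGKey(0), (1_000_000,))
idx = jnp.zeros(1_000_000, dtype=jnp.int32)  # all values collide into segment 0

r1 = jax.ops.segment_sum(vals, idx, num_segments=1)
r2 = jax.ops.segment_sum(vals, idx, num_segments=1)
print(r1 - r2)  # non-zero differences indicate non-deterministic reductions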

mattjj commented on May 22, 2024

We're working on improving the JAX documentation on this in google/jax#4824. Feedback on that PR is welcome!

mattjj commented on May 22, 2024

It sounds like the --xla_gpu_deterministic_reductions flag is now gone (or it will be when we push an updated jaxlib) because it's now effectively always on by default. So hopefully this will get less surprising...
