
Comments (6)

sschoenholz commented on May 19, 2024

Thanks for bringing this up. I believe this issue should have been fixed by Roman's work on jit-compilation in batching. I've tried your repro in colab and it seems to work, but let me know if any problems persist.

ravidziv commented on May 19, 2024

Thanks, but it didn't solve it.
Colab code

sschoenholz commented on May 19, 2024

Ah I see, thanks! So this appears to work if you remove the jit around get_network, interesting! We'll look into it soon, but perhaps not until after ICLR.

One point that we should make clearer in the docs (once we have docs) is that applying jit to batch has shown poor memory characteristics. We think this is an issue on the JAX / XLA side, but we haven't had time to put together a simplified repro for them. If you run into OOM errors, it may be better to jit before batch, though we expect this situation to be temporary.
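For reference, here is a minimal sketch of the jit-before-batch ordering mentioned above; the architecture, input shapes, and batch size are placeholder assumptions rather than the original repro.

```python
import jax
import jax.numpy as jnp
import neural_tangents as nt
from neural_tangents import stax

# Placeholder architecture and inputs (not the original repro).
_, _, kernel_fn = stax.serial(
    stax.Dense(512, W_std=1.5, b_std=0.05),
    stax.Relu(),
    stax.Dense(1, W_std=1.5, b_std=0.05))

x1 = jnp.ones((128, 10))
x2 = jnp.ones((64, 10))

# jit-before-batch: compile the per-batch kernel computation first, then let
# nt.batch split the inputs into chunks to bound peak memory.
nngp_fn = jax.jit(lambda a, b: kernel_fn(a, b, 'nngp'))
batched_nngp_fn = nt.batch(nngp_fn, batch_size=32)
nngp = batched_nngp_fn(x1, x2)  # shape (128, 64)

# The opposite ordering, jax.jit(nt.batch(kernel_fn, batch_size=32)),
# is the one reported above to have poor memory behavior.
```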

ravidziv commented on May 19, 2024

I understand, thanks!
Do you have an idea of how I can vmap the batch for different networks?

romanngg commented on May 19, 2024

It looks like the latest JAX version has this fixed, and both jit and vmap work! Here's an adapted example from above: https://colab.research.google.com/gist/romanngg/ffdd9a41fdf5479eaeac95772c259d27/jit_or_vmap_of_network.ipynb
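For reference, a minimal sketch of the jit-and-vmap-of-network pattern being discussed; the get_network helper, architecture, and inputs below are illustrative assumptions and the linked notebook may differ.

```python
import jax
import jax.numpy as jnp
from neural_tangents import stax

x1 = jnp.ones((16, 10))
x2 = jnp.ones((8, 10))

def get_network(W_std):
  # Build the infinite-width kernel for a given weight std and evaluate it.
  # Placeholder architecture.
  _, _, kernel_fn = stax.serial(
      stax.Dense(512, W_std=W_std, b_std=0.05),
      stax.Relu(),
      stax.Dense(1, W_std=W_std, b_std=0.05))
  return kernel_fn(x1, x2, 'nngp')

W_stds = jnp.linspace(0.5, 2.0, 8)

# jit of vmap over the different weight stds; stacks the kernels to shape (8, 16, 8).
nngp_kernels = jax.jit(jax.vmap(get_network))(W_stds)
```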

Please note that I am also unsure how reasonable it is to apply jit to batch, especially with a vmap on top. You may also want to try options such as:
a) vmap / pmap on get_network, with no batch or jit anywhere inside the function.
b) a simple Python for loop over the W_stds of the jitted get_network, with no batch or jit inside, or with batch and no jit inside (see the sketch below).
etc.
I suggest this because the purpose of batch is to sacrifice parallelization in order to reduce memory cost, so it seems redundant / sub-optimal to first call batch and then parallelize it with vmap or jit afterwards. But I'm no expert on how these JAX optimizations work, so I may well be wrong here...
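A minimal sketch of option (b), under the same kind of placeholder assumptions as above (architecture, inputs, and the range of W_stds are illustrative):

```python
import jax
import jax.numpy as jnp
from neural_tangents import stax

x1 = jnp.ones((16, 10))
x2 = jnp.ones((8, 10))

@jax.jit
def get_network(W_std):
  # Same placeholder architecture as above; no batch or nested jit inside.
  _, _, kernel_fn = stax.serial(
      stax.Dense(512, W_std=W_std, b_std=0.05),
      stax.Relu(),
      stax.Dense(1, W_std=W_std, b_std=0.05))
  return kernel_fn(x1, x2, 'nngp')

# Plain Python loop over the weight stds; each call reuses the compiled
# computation because the argument shape never changes.
nngp_kernels = [get_network(w) for w in jnp.linspace(0.5, 2.0, 8)]
```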

romanngg commented on May 19, 2024

Closing, please let me know if there are still related issues!
