Comments (6)
Thanks for bringing this up. I believe this issue should have been fixed by Roman's work on jit-compilation in batching. I've tried your repro in Colab and it seems to work, but let me know if any problems persist.
from neural-tangents.
Thanks, but it didn't solve it.
Colab code
Ah I see, thanks! So this appears to work if you remove the jit around get_network, interesting! We'll look into it soon, but perhaps not until after ICLR.
One point that we should make clearer in the docs, especially once we have docs, is that we have found that applying jit to batch has poor memory characteristics. We think this is an issue on the JAX / XLA end, but haven't had time to pursue it and get them a simplified repro. If you run into OOM errors, it might be better to jit before batching, though we expect this situation to be temporary.
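To make the two orderings concrete, here is a minimal sketch in plain JAX. It does not use Neural Tangents: `toy_relu_nngp` is a hypothetical stand-in for an analytic `kernel_fn` (the closed-form NNGP kernel of a Dense -> Relu -> Dense network without biases, assuming 1/sqrt(fan_in) weight scaling), and `blocked` is a hypothetical, much-simplified analogue of `nt.batch` that evaluates the kernel one row block at a time in a Python loop. It only shows where `jax.jit` sits in each ordering; it does not reproduce the memory behavior described above.

```python
import jax
import jax.numpy as jnp


def toy_relu_nngp(x1, x2, w_std):
    """Hypothetical stand-in for an analytic kernel_fn: NNGP kernel of
    Dense(W_std=w_std) -> Relu -> Dense(W_std=w_std), no biases."""
    d = x1.shape[1]
    # Covariances of the first-layer pre-activations.
    k12 = w_std**2 * (x1 @ x2.T) / d
    k11 = w_std**2 * jnp.sum(x1**2, axis=1) / d
    k22 = w_std**2 * jnp.sum(x2**2, axis=1) / d
    norms = jnp.sqrt(jnp.outer(k11, k22))
    cos_t = jnp.clip(k12 / norms, -1.0, 1.0)
    theta = jnp.arccos(cos_t)
    # Arc-cosine formula for E[relu(u) relu(v)], times the readout variance.
    return w_std**2 * norms * (jnp.sin(theta) + (jnp.pi - theta) * cos_t) / (2 * jnp.pi)


def blocked(kernel_fn, batch_size):
    """Hypothetical, simplified analogue of nt.batch: evaluate kernel_fn on
    row blocks of x1 in a Python loop and stitch the blocks together."""
    def batched_fn(x1, x2, w_std):
        n1 = x1.shape[0]
        if n1 % batch_size:
            raise ValueError("batch_size must divide the number of rows of x1")
        blocks = [kernel_fn(x1[i:i + batch_size], x2, w_std)
                  for i in range(0, n1, batch_size)]
        return jnp.concatenate(blocks, axis=0)
    return batched_fn


# jit-before-batch: only the per-block kernel is compiled; the loop stays in Python.
jit_then_batch = blocked(jax.jit(toy_relu_nngp), batch_size=4)

# jit-of-batch: the whole loop is traced into a single XLA computation.
batch_then_jit = jax.jit(blocked(toy_relu_nngp, batch_size=4))

x1 = jax.random.normal(jax.random.PRNGKey(0), (8, 3))
x2 = jax.random.normal(jax.random.PRNGKey(1), (5, 3))
k_ref = toy_relu_nngp(x1, x2, 1.5)
k_a = jit_then_batch(x1, x2, 1.5)
k_b = batch_then_jit(x1, x2, 1.5)
```

Both orderings compute the same kernel; the difference is only in what XLA gets to see at once, which is where the memory characteristics mentioned above can diverge.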
I understand, thanks!
Do you have an idea how I can vmap the batch over different networks?
It looks like the latest JAX version has this fixed, and both jit and vmap work! Here's an adapted example from above: https://colab.research.google.com/gist/romanngg/ffdd9a41fdf5479eaeac95772c259d27/jit_or_vmap_of_network.ipynb
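The sketch below is not the code from the linked notebook; it is a minimal plain-JAX illustration of the same structure, a vmap over several weight scales on top of a jitted, batched kernel. `toy_relu_nngp` (a hypothetical closed-form NNGP kernel for Dense -> Relu -> Dense, no biases) and `blocked` (a hypothetical, simplified analogue of `nt.batch`) are assumptions made for illustration. `in_axes=(None, None, 0)` maps only over `w_std`, so every network shares the same inputs.

```python
import jax
import jax.numpy as jnp


def toy_relu_nngp(x1, x2, w_std):
    """Hypothetical analytic kernel: NNGP of Dense -> Relu -> Dense, no biases."""
    d = x1.shape[1]
    k12 = w_std**2 * (x1 @ x2.T) / d
    k11 = w_std**2 * jnp.sum(x1**2, axis=1) / d
    k22 = w_std**2 * jnp.sum(x2**2, axis=1) / d
    norms = jnp.sqrt(jnp.outer(k11, k22))
    cos_t = jnp.clip(k12 / norms, -1.0, 1.0)
    theta = jnp.arccos(cos_t)
    # Arc-cosine formula for E[relu(u) relu(v)], times the readout variance.
    return w_std**2 * norms * (jnp.sin(theta) + (jnp.pi - theta) * cos_t) / (2 * jnp.pi)


def blocked(kernel_fn, batch_size):
    """Hypothetical, simplified analogue of nt.batch (row blocks of x1)."""
    def batched_fn(x1, x2, w_std):
        blocks = [kernel_fn(x1[i:i + batch_size], x2, w_std)
                  for i in range(0, x1.shape[0], batch_size)]
        return jnp.concatenate(blocks, axis=0)
    return batched_fn


# jit of the batched kernel, with vmap over the weight scale on top.
kernel_per_w = jax.vmap(jax.jit(blocked(toy_relu_nngp, batch_size=4)),
                        in_axes=(None, None, 0))

w_stds = jnp.array([0.5, 1.0, 1.5, 2.0])
x1 = jax.random.normal(jax.random.PRNGKey(0), (8, 3))
x2 = jax.random.normal(jax.random.PRNGKey(1), (5, 3))
kernels = kernel_per_w(x1, x2, w_stds)  # one (8, 5) kernel per W_std
```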
Please note that I am also unsure how reasonable it is to apply jit to batch, especially with a vmap on top. You may also want to try these options:
a) vmap / pmap on get_network, with no batch or jit anywhere inside the function.
b) a simple Python for loop over the W_stds, calling the jitted get_network, with either no batch or jit inside, or with batch and no jit inside.
etc.
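Options a) and b) can be sketched in plain JAX as follows. `toy_relu_nngp` is a hypothetical stand-in for the kernel computed by get_network (a closed-form NNGP kernel for Dense -> Relu -> Dense, no biases), and no batching is used inside it. The variant of b) that keeps batch inside is omitted here.

```python
import jax
import jax.numpy as jnp


def toy_relu_nngp(x1, x2, w_std):
    """Hypothetical stand-in for get_network's kernel (NNGP, no biases)."""
    d = x1.shape[1]
    k12 = w_std**2 * (x1 @ x2.T) / d
    k11 = w_std**2 * jnp.sum(x1**2, axis=1) / d
    k22 = w_std**2 * jnp.sum(x2**2, axis=1) / d
    norms = jnp.sqrt(jnp.outer(k11, k22))
    cos_t = jnp.clip(k12 / norms, -1.0, 1.0)
    theta = jnp.arccos(cos_t)
    # Arc-cosine formula for E[relu(u) relu(v)], times the readout variance.
    return w_std**2 * norms * (jnp.sin(theta) + (jnp.pi - theta) * cos_t) / (2 * jnp.pi)


w_stds = jnp.array([0.5, 1.0, 1.5, 2.0])
x1 = jax.random.normal(jax.random.PRNGKey(0), (8, 3))
x2 = jax.random.normal(jax.random.PRNGKey(1), (5, 3))

# Option a): vmap over w_std, with no batching or jit inside the kernel function.
kernels_a = jax.vmap(toy_relu_nngp, in_axes=(None, None, 0))(x1, x2, w_stds)

# Option b): a plain Python loop over w_std, calling a jitted kernel function.
# Every w has the same shape, so the function is compiled only once.
jitted = jax.jit(toy_relu_nngp)
kernels_b = jnp.stack([jitted(x1, x2, w) for w in w_stds])
```

`jax.pmap` accepts the same `in_axes` argument but maps over devices, so with pmap the length of `w_stds` would need to fit the available device count.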
I suggest this because the purpose of batch is to sacrifice parallelization in order to reduce memory cost. Therefore it seems redundant, or at least sub-optimal, to first call batch and then parallelize it with vmap or jit afterwards. But I'm no expert on how these JAX optimizations work, so I may very well be wrong here...
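A back-of-the-envelope illustration of that trade-off, under a purely hypothetical assumption: suppose the kernel computation materializes an intermediate of shape (rows, n2, width) in float32, where rows is n1 without batching and batch_size with it. The numbers below are illustrative, not measurements of Neural Tangents.

```python
def peak_intermediate_gb(n1, n2, width, batch_size=None, bytes_per_float=4):
    """Hypothetical estimate (decimal GB) of a (rows, n2, width) intermediate.

    rows is n1 without batching, or batch_size when computing in blocks.
    """
    rows = n1 if batch_size is None else batch_size
    return rows * n2 * width * bytes_per_float / 1e9


full = peak_intermediate_gb(10_000, 10_000, 1_000)                       # all rows at once
per_block = peak_intermediate_gb(10_000, 10_000, 1_000, batch_size=100)  # one block at a time
sequential_steps = 10_000 // 100  # blocks that now run one after another
print(full, per_block, sequential_steps)
```

Batching cuts the assumed peak by the factor n1 / batch_size, but turns one large parallel computation into that many sequential steps; wrapping the whole thing back into a single vmap or jit gives the compiler the chance to undo exactly that trade.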
Closing, please let me know if there are still related issues!