juliuskunze / jaxnet
Concise deep learning for JAX
License: Apache License 2.0
When entries of the dict provided as the reuse argument to parameters_from/init_parameters have conflicting values (i.e. their submodules overlap), resolve the conflict by overriding in order.
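A minimal sketch of "overriding in order", assuming parameters are represented as nested dicts keyed by submodule name (jaxnet's real parameter representation differs). merge_reuse and _merge are hypothetical helpers: entries are applied in the order given, so a later entry wins wherever its submodules overlap an earlier one.

```python
def _merge(base, override):
    """Recursively merge two nested parameter dicts; values in `override` win."""
    merged = dict(base)  # copy so the caller's dicts are never mutated
    for name, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(name), dict):
            merged[name] = _merge(merged[name], value)
        else:
            merged[name] = value
    return merged


def merge_reuse(reuse_entries):
    """Resolve overlapping reuse entries (hypothetical nested-dict format) in order."""
    resolved = {}
    for entry in reuse_entries:
        resolved = _merge(resolved, entry)
    return resolved


first = {"encoder": {"dense": {"kernel": 1.0, "bias": 0.0}}, "decoder": {"bias": 5.0}}
second = {"encoder": {"dense": {"kernel": 2.0}}}  # overlaps first's encoder/dense
resolved = merge_reuse([first, second])
```

Only the overlapping leaf (encoder/dense/kernel) is taken from the later entry; everything else from the earlier entry survives.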
JAXnet currently depends on jax==0.1.41.
See test_tuple_input and test_dict_input. The solution should look similar to what jit does to prepare arguments.
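Roughly, jit prepares arguments by flattening arbitrary nested containers (tuples, lists, dicts) into a flat list of arrays plus a tree structure, and rebuilding the structure inside the traced function. A minimal sketch of that pattern with jax.tree_util; flatten_call is a hypothetical helper, and how jaxnet would wire this into parameter initialization is not shown.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


def flatten_call(fn, *args):
    """Call `fn` on nested tuple/dict inputs via a function over flat leaves (hypothetical)."""
    leaves, treedef = tree_util.tree_flatten(args)

    def flat_fn(*flat_leaves):
        # Rebuild the original tuple/dict structure inside the traced function.
        return fn(*tree_util.tree_unflatten(treedef, flat_leaves))

    return jax.jit(flat_fn)(*leaves), len(leaves)


def model(inputs):
    # Takes a dict containing a tuple, like test_dict_input / test_tuple_input.
    a, b = inputs["pair"]
    return a + b + inputs["bias"]


inputs = {"pair": (jnp.ones(3), jnp.ones(3)), "bias": jnp.zeros(3)}
out, n_leaves = flatten_call(model, inputs)
```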
Fix test_scan_unparametrized_cell
Allow save(state, path) according to the design laid out in google/jax#1278.
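The design in google/jax#1278 is not reproduced here. As a stand-in, a minimal sketch of a save/load pair, assuming the state is a picklable nested structure of NumPy arrays (device arrays would first be converted with numpy.asarray). save and load are hypothetical functions, not jaxnet's API.

```python
import os
import pickle
import tempfile

import numpy as np


def save(state, path):
    """Serialize a nested optimizer state to `path` (hypothetical API)."""
    with open(path, "wb") as f:
        pickle.dump(state, f)


def load(path):
    """Inverse of `save` (hypothetical API)."""
    with open(path, "rb") as f:
        return pickle.load(f)


state = {"step": 3, "params": {"dense": (np.ones((2, 2)), np.zeros(2))}}
path = os.path.join(tempfile.mkdtemp(), "state.pkl")
save(state, path)
restored = load(path)
```

Since jaxnet's core already imports dill, dill.dump/dill.load could replace pickle if states ever contain lambdas.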
See test_parametrized_jit.
Handling parameters in JAX can get annoying, but what concerns me even more is handling PRNG keys. JAX has done a lot of great work to build a very strong PRNG system, but splitting and managing random keys by hand is messy and error-prone: it's alarmingly easy to accidentally reuse a PRNG key. It would be great to have a system analogous to @parametrized and parameter(), but for random keys and seeds.
I envision an API providing something like @random and rng():
@random
def my_func(x):
    W = jax.random.normal(rng(), shape=(2, 2))
    b = jax.random.exponential(rng(), shape=(2,))
    return W @ x + b
Then ~magic~ happens, after which we get a function like:
import jax

def my_func(x, rng):
    # rng must be a PRNG key (e.g. jax.random.PRNGKey(0)); each split yields a fresh subkey
    rng0, rng = jax.random.split(rng)
    W = jax.random.normal(rng0, shape=(2, 2))
    rng1, rng = jax.random.split(rng)
    b = jax.random.exponential(rng1, shape=(2,))
    return W @ x + b
Add a batching rule to the parametrized primitive.
Including a Conv layer in the mnist example
import time
import jax.numpy as np
import numpy.random as npr
from jax.random import PRNGKey
from jaxnet import Sequential, parametrized, Dense, relu, logsoftmax, optimizers, Conv, flatten

def _one_hot(x, k, dtype=np.float32):
    """Create a one-hot encoding of x of size k."""
    return np.array(x[:, None] == np.arange(k), dtype)

def mnist():
    import tensorflow_datasets as tfds
    dataset = tfds.load("mnist:1.0.0")
    images = lambda d: np.reshape(np.float32(d['image']) / 256, (-1, 28, 28, 1))
    labels = lambda d: _one_hot(d['label'], 10)
    train = next(tfds.as_numpy(dataset['train'].shuffle(50000).batch(50000)))
    test = next(tfds.as_numpy(dataset['test'].batch(10000)))
    return images(train), labels(train), images(test), labels(test)

predict = Sequential(
    Conv(32, (5, 5)), relu, flatten,
    Dense(500), relu,
    Dense(10), logsoftmax)
results in an out-of-memory error on a GPU Colab runtime during apply_from (init_parameters is fine).
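One common workaround, offered as an assumption rather than a diagnosis of this issue: Conv(32, (5, 5)) produces 32 feature maps of up to 28×28 per image, on the order of 20,000 to 25,000 floats, so a single forward pass over all 50,000 training images needs several gigabytes for that layer alone. Evaluating in mini-batches bounds peak memory. A minimal NumPy sketch; apply_in_batches is a hypothetical helper, and fn stands for whatever call computes the network output on a batch.

```python
import numpy as np


def apply_in_batches(fn, inputs, batch_size=1000):
    """Apply `fn` to `inputs` in chunks along axis 0 and concatenate the results.

    Peak memory then scales with `batch_size` instead of len(inputs).
    """
    outputs = [fn(inputs[i:i + batch_size]) for i in range(0, len(inputs), batch_size)]
    return np.concatenate(outputs, axis=0)


def fake_predict(x):
    # Stand-in for the network: any function mapping (N, 28, 28, 1) -> (N, 10).
    return x.reshape(len(x), -1)[:, :10]


images = np.random.default_rng(0).random((2500, 28, 28, 1), dtype=np.float32)
out = apply_in_batches(fake_predict, images, batch_size=1000)
```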
The RNN code needs to be repaired after the update to jax 0.1.42.
Optimizing an RNN fails with
NotImplementedError: Forward-mode differentiation rule for 'while' not implemented
This can be verified by running the OCR/RNN after removing the break statement.
scan from jax already allows training, but we have to use a custom version of scan to allow parametrization. Adding a custom differentiation rule for _scan_apply should fix this.
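For reference, plain jax.lax.scan supports reverse-mode differentiation, which is why it already allows training. A minimal sketch of a tanh RNN unrolled with lax.scan and differentiated with jax.grad; the cell, shapes and loss are hypothetical, and this does not touch jaxnet's custom _scan_apply.

```python
import jax
import jax.numpy as jnp


def rnn_loss(params, xs):
    """Sum of hidden states of a tanh RNN unrolled with lax.scan."""
    W_h, W_x, b = params

    def cell(h, x):
        h = jnp.tanh(W_h @ h + W_x @ x + b)
        return h, h  # (new carry, per-step output)

    h0 = jnp.zeros(W_h.shape[0])
    _, hs = jax.lax.scan(cell, h0, xs)
    return jnp.sum(hs)


key_h, key_x, key_data = jax.random.split(jax.random.PRNGKey(0), 3)
params = (0.1 * jax.random.normal(key_h, (4, 4)),
          0.1 * jax.random.normal(key_x, (4, 3)),
          jnp.zeros(4))
xs = jax.random.normal(key_data, (10, 3))  # 10 time steps, 3 features each

grads = jax.grad(rnn_loss)(params, xs)  # reverse-mode through scan
```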
I tried jaxnet's mnist_classifier.py, which corresponds to JAX's mnist_classifier.py, with num_epochs increased to 1000.
The jaxnet version failed with an out-of-memory error, although JAX's mnist_classifier.py finished without any errors.
See TODO in test_parametrized_jit.
See test_parameter_sharing_between_multiple_parents. Optimization is incorrect in this case since parameters are duplicated, not shared.
As explained here.
This applies not only to parameters but to any input-independent submodule calls.
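Why duplication breaks optimization, as a minimal sketch that does not use jaxnet: when one weight is used by two parents, its gradient is the sum of both uses. If the weight is instead stored twice, each copy receives only its own share of the gradient, and after a single SGD step the copies disagree, so they no longer behave as one parameter. The losses and values are hypothetical.

```python
import jax


def shared_loss(w, x1, x2):
    # One weight used by two "parents".
    return w * x1 + w * x2


def duplicated_loss(params, x1, x2):
    # The same weight stored twice, as if each parent owned a copy.
    return params["parent1"] * x1 + params["parent2"] * x2


x1, x2, lr, w = 1.0, 3.0, 0.1, 0.5

# Shared: gradient is x1 + x2 = 4.0, giving one updated value.
w_shared = w - lr * jax.grad(shared_loss)(w, x1, x2)

# Duplicated: each copy receives only its own term (1.0 and 3.0).
g = jax.grad(duplicated_loss)({"parent1": w, "parent2": w}, x1, x2)
w_dup = {k: w - lr * g[k] for k in g}
```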
I've simplified the code I'm running down to this minimal example:
import jax
from jax import nn
import jax.numpy as jnp
import jaxnet
import jaxnet.optimizers

cell = jaxnet.Dense(1)

@jaxnet.parametrized
def step_fn(carry, x):
    y = cell(carry)
    return y, y

@jaxnet.parametrized
def loss(init):
    carry, y = jax.lax.scan(step_fn, init=init, xs=None, length=10)
    return y.sum()

opt = jaxnet.optimizers.Adam()
opt_state = opt.init(loss.init_parameters(jnp.zeros([1, 1]), key=jax.random.PRNGKey(0)))
for i in range(5):
    opt_state = opt.update(loss.apply, opt_state, jnp.zeros([1, 1]), jit=True)
Initially I get this error:
/usr/local/lib/python3.6/dist-packages/jax/util.py in split_dict(dct, names)
83 dct = dict(dct)
84 lst = [dct.pop(name) for name in names]
---> 85 assert not dct
86 return lst
87
AssertionError:
This is because unroll is missing from the list of kwargs in _custom_cell_scan_impl. If I add it and pass its value through to scan_p, I then get this error:
/usr/local/lib/python3.6/dist-packages/jaxnet/core.py in next_parameters_for(self, primitive)
557 return parameters
558
--> 559 parameters = self.parameters[self._index]
560 self._index += 1
561 self.global_parameters_by_primitive[primitive] = parameters
IndexError: tuple index out of range
It appears that the parameters of the Dense cell aren't being added:
> opt.get_parameters(opt_state)
loss(step_fn=step_fn())
Any idea how to fix this?
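Independently of a jaxnet fix, the same model can be written in plain JAX by passing the Dense parameters explicitly and closing over them in the step function, so lax.scan and jax.grad see them as ordinary inputs. A minimal sketch under stated assumptions: the (kernel, bias) layout, the hypothetical init_dense initializer, and a plain SGD step standing in for Adam are not jaxnet's API.

```python
import jax
import jax.numpy as jnp


def init_dense(key, in_dim, out_dim):
    """Hypothetical Dense initializer returning (kernel, bias)."""
    return (0.1 * jax.random.normal(key, (in_dim, out_dim)), jnp.zeros(out_dim))


def loss(params, init):
    kernel, bias = params

    def step_fn(carry, _):
        # The cell closes over `params`, so they are traced inputs of `loss`.
        y = carry @ kernel + bias
        return y, y

    _, ys = jax.lax.scan(step_fn, init, xs=None, length=10)
    return ys.sum()


@jax.jit
def sgd_step(params, init):
    lr = 0.01  # plain SGD as a stand-in for Adam
    grads = jax.grad(loss)(params, init)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)


init = jnp.zeros((1, 1))
params = init_dense(jax.random.PRNGKey(0), 1, 1)
initial_loss = float(loss(params, init))
for _ in range(5):
    params = sgd_step(params, init)
```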
I.e., test_wavenet is too slow.
The unbatched PixelCNN example fails with
Exception: Can't lift Traced<ShapedArray(float32[32,32,1]):JaxprTrace(level=6/0)> to JaxprTrace(level=5/0)
while initializing parameters in down_shifted_conv during trace_to_jaxpr. An older version of the network was working on an unbatched input in jaxnet==0.1.4 and jax==0.1.41, so this is probably a bug in the tracing logic of jaxnet.core.
Selecting a GPU via CUDA_VISIBLE_DEVICES after importing jaxnet does not work:
import os
import jaxnet
os.environ["CUDA_VISIBLE_DEVICES"] = "2"  # => not working
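A likely cause, stated as an assumption: CUDA reads CUDA_VISIBLE_DEVICES when the GPU backend is initialized, so the variable is safest set at the very top of the script, before jax or jaxnet is imported. Setting it in the shell (CUDA_VISIBLE_DEVICES=2 python train.py) avoids the ordering problem entirely. A minimal sketch of the ordering:

```python
import os

# Set before jax/jaxnet initialize CUDA; "2" exposes only the GPU with CUDA index 2.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

import jax  # noqa: E402  (import jaxnet would also go after this line)

devices = jax.devices()  # falls back to CPU on machines without that GPU
```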
Mish is a novel activation function proposed in this paper.
It has shown promising results so far and has been adopted in several packages.
All benchmarks, analysis, and links to official package implementations can be found in this repository.
It would be nice to have Mish as an option within the activation function group.
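For reference, Mish is defined as mish(x) = x · tanh(softplus(x)), where softplus(x) = ln(1 + e^x). A minimal sketch as an elementwise JAX function in the style of relu; exposing it in jaxnet under the name mish is an assumption.

```python
import jax
import jax.numpy as jnp


def mish(x):
    """Mish activation: x * tanh(softplus(x)), applied elementwise."""
    # jax.nn.softplus computes log(1 + exp(x)) in a numerically stable way.
    return x * jnp.tanh(jax.nn.softplus(x))


x = jnp.array([-2.0, 0.0, 1.0, 20.0])
y = mish(x)
```

Being a plain elementwise function, it could presumably be passed to Sequential in the same position as relu.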
Figure: comparison of Mish with other conventional activation functions in a SEResNet-50 on CIFAR-10.
While running examples/resnet50.py (I apparently needed to change loss.init_parameters(*next(batches)) to loss.init_parameters(rng_key, *next(batches)) on line 100), I get
RuntimeError: Internal: Unable to launch convolution with type forward and algorithm (0, none)
Completely mysterious to me, but maybe obvious to @juliuskunze?
Hi there,
When trying to import jaxnet with import jaxnet, I get the following error:
~/miniconda3/envs/jaxnet/lib/python3.7/site-packages/jaxnet/core.py in <module>
6 import dill
7 import jax
----> 8 from jax import lax, random, unzip2, safe_zip, safe_map, partial, raise_to_shaped, tree_flatten, \
9 tree_unflatten, flatten_fun_nokwargs, jit, curry
10 from jax.abstract_arrays import ShapedArray
ImportError: cannot import name 'unzip2' from 'jax'
Library versions:
Python 3.7.6 (conda)
jax 0.1.67
jaxlib 0.1.47
jaxnet 0.2.5
Thanks for this library!
The weight norm implementations dense and conv_or_conv_transpose from the pixelcnn example need custom batching rules and initialization.
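For context, weight normalization reparametrizes a kernel as w = g · v / ‖v‖, and it is commonly paired with a data-dependent initialization: on the first batch, g and b are chosen so that each output unit has zero mean and unit variance. That initialization depends on batch statistics, which is what makes it interact with batching. A minimal NumPy sketch for a dense layer; the function names and eps are hypothetical, and the custom batching rule itself is not shown.

```python
import numpy as np


def weight_norm_dense(x, v, g, b):
    """Dense layer with weight normalization: w = g * v / ||v|| per output unit."""
    # v has shape (in_dim, out_dim); each column (one output unit) is normalized.
    w = g * v / np.linalg.norm(v, axis=0)
    return x @ w + b


def data_dependent_init(x, v, eps=1e-8):
    """Choose g, b so outputs on batch `x` have zero mean and unit variance per unit."""
    t = x @ (v / np.linalg.norm(v, axis=0))  # pre-activations with g=1, b=0
    mean, std = t.mean(axis=0), t.std(axis=0)
    g = 1.0 / (std + eps)
    b = -mean * g
    return g, b


rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(256, 8))  # first data batch
v = rng.normal(size=(8, 4))                        # direction parameters
g, b = data_dependent_init(x, v)
out = weight_norm_dense(x, v, g, b)
```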