
google-deepmind / dm-haiku


JAX-based neural network library

Home Page: https://dm-haiku.readthedocs.io

License: Apache License 2.0

Starlark 3.16% Python 96.67% Shell 0.17%
machine-learning neural-networks jax deep-learning deep-neural-networks

dm-haiku's Introduction

Haiku: Sonnet for JAX


Important

📣 As of July 2023 Google DeepMind recommends that new projects adopt Flax instead of Haiku. Flax is a neural network library originally developed by Google Brain and now by Google DeepMind. 📣

At the time of writing, Flax has a superset of the features available in Haiku, a larger and more active development team, and more adoption among users outside of Alphabet. Flax has more extensive documentation, more examples, and an active community creating end-to-end examples.

Haiku will remain best-effort supported; however, the project will enter maintenance mode, meaning that development efforts will be focussed on bug fixes and compatibility with new releases of JAX.

New releases will be made to keep Haiku working with newer versions of Python and JAX; however, we will not be adding (or accepting PRs for) new features.

We have significant usage of Haiku internally at Google DeepMind and currently plan to support Haiku in this mode indefinitely.

What is Haiku?

Haiku is a tool
For building neural networks
Think: "Sonnet for JAX"

Haiku is a simple neural network library for JAX developed by some of the authors of Sonnet, a neural network library for TensorFlow.

Documentation on Haiku can be found at https://dm-haiku.readthedocs.io/.

Disambiguation: if you are looking for Haiku the operating system then please see https://haiku-os.org/.

Overview

JAX is a numerical computing library that combines NumPy, automatic differentiation, and first-class GPU/TPU support.

Haiku is a simple neural network library for JAX that enables users to use familiar object-oriented programming models while allowing full access to JAX's pure function transformations.

Haiku provides two core tools: a module abstraction, hk.Module, and a simple function transformation, hk.transform.

hk.Modules are Python objects that hold references to their own parameters, other modules, and methods that apply functions on user inputs.

hk.transform turns functions that use these object-oriented, functionally "impure" modules into pure functions that can be used with jax.jit, jax.grad, jax.pmap, etc.

Why Haiku?

There are a number of neural network libraries for JAX. Why should you choose Haiku?

Haiku has been tested by researchers at DeepMind at scale.

  • DeepMind has reproduced a number of experiments in Haiku and JAX with relative ease. These include large-scale results in image and language processing, generative models, and reinforcement learning.

Haiku is a library, not a framework.

  • Haiku is designed to make specific things simpler: managing model parameters and other model state.
  • Haiku can be expected to compose with other libraries and work well with the rest of JAX.
  • Haiku otherwise is designed to get out of your way - it does not define custom optimizers, checkpointing formats, or replication APIs.

Haiku does not reinvent the wheel.

  • Haiku builds on the programming model and APIs of Sonnet, a neural network library with near universal adoption at DeepMind. It preserves Sonnet's Module-based programming model for state management while retaining access to JAX's function transformations.
  • Haiku APIs and abstractions are as close as reasonable to Sonnet. Many users have found Sonnet to be a productive programming model in TensorFlow; Haiku enables the same experience in JAX.

Transitioning to Haiku is easy.

  • By design, transitioning from TensorFlow and Sonnet to JAX and Haiku is easy.
  • Outside of new features (e.g. hk.transform), Haiku aims to match the API of Sonnet 2. Modules, methods, argument names, defaults, and initialization schemes should match.

Haiku makes other aspects of JAX simpler.

  • Haiku offers a trivial model for working with random numbers. Within a transformed function, hk.next_rng_key() returns a unique rng key.
  • These unique keys are deterministically derived from an initial random key passed into the top-level transformed function, and are thus safe to use with JAX program transformations.

Quickstart

Let's take a look at an example neural network, loss function, and training loop. (For more examples, see our examples directory. The MNIST example is a good place to start.)

import haiku as hk
import jax
import jax.numpy as jnp

def softmax_cross_entropy(logits, labels):
  one_hot = jax.nn.one_hot(labels, logits.shape[-1])
  return -jnp.sum(jax.nn.log_softmax(logits) * one_hot, axis=-1)

def loss_fn(images, labels):
  mlp = hk.Sequential([
      hk.Linear(300), jax.nn.relu,
      hk.Linear(100), jax.nn.relu,
      hk.Linear(10),
  ])
  logits = mlp(images)
  return jnp.mean(softmax_cross_entropy(logits, labels))

loss_fn_t = hk.transform(loss_fn)
loss_fn_t = hk.without_apply_rng(loss_fn_t)

rng = jax.random.PRNGKey(42)
# `input_dataset` is assumed to be an iterator yielding (images, labels) batches.
dummy_images, dummy_labels = next(input_dataset)
params = loss_fn_t.init(rng, dummy_images, dummy_labels)

def update_rule(param, update):
  return param - 0.01 * update

for images, labels in input_dataset:
  grads = jax.grad(loss_fn_t.apply)(params, images, labels)
  params = jax.tree_util.tree_map(update_rule, params, grads)

The core of Haiku is hk.transform. The transform function lets you write neural network functions that rely on parameters (here the weights of the Linear layers) without explicitly writing the boilerplate for initialising those parameters. transform does this by turning the function into a pair of pure functions (as required by JAX): init and apply.

init

The init function, with signature params = init(rng, ...) (where ... are the arguments to the untransformed function), allows you to collect the initial value of any parameters in the network. Haiku does this by running your function, keeping track of any parameters requested through hk.get_parameter (called by e.g. hk.Linear) and returning them to you.

The params object returned is a nested data structure of all the parameters in your network, designed for you to inspect and manipulate. Concretely, it is a mapping of module name to module parameters, where a module parameter is a mapping of parameter name to parameter value. For example:

{'linear': {'b': ndarray(..., shape=(300,), dtype=float32),
            'w': ndarray(..., shape=(28, 300), dtype=float32)},
 'linear_1': {'b': ndarray(..., shape=(100,), dtype=float32),
              'w': ndarray(..., shape=(300, 100), dtype=float32)},
 'linear_2': {'b': ndarray(..., shape=(10,), dtype=float32),
              'w': ndarray(..., shape=(100, 10), dtype=float32)}}

apply

The apply function, with signature result = apply(params, rng, ...), allows you to inject parameter values into your function. Whenever hk.get_parameter is called, the value returned will come from the params you provide as input to apply:

loss = loss_fn_t.apply(params, rng, images, labels)

Note that since the actual computation performed by our loss function doesn't rely on random numbers, passing in an RNG key is unnecessary, so we could also pass in None for the rng argument. (If your computation does use random numbers, passing in None for rng will cause an error to be raised.) In our example above, we ask Haiku to do this for us automatically with:

loss_fn_t = hk.without_apply_rng(loss_fn_t)

Since apply is a pure function we can pass it to jax.grad (or any of JAX's other transforms):

grads = jax.grad(loss_fn_t.apply)(params, images, labels)

Training

The training loop in this example is very simple. One detail to note is the use of jax.tree_util.tree_map to apply the update_rule function (plain SGD with a learning rate of 0.01) across all matching entries in params and grads. The result has the same structure as the previous params and can again be used with apply.
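Here is a self-contained illustration of that tree_map step, in plain JAX with made-up parameter values. The nesting mimics Haiku's module-name to parameter-name layout.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameter tree with the same nesting as Haiku params.
params = {
    "linear": {"w": jnp.ones([2, 3]), "b": jnp.zeros([3])},
    "linear_1": {"w": jnp.ones([3, 1]), "b": jnp.zeros([1])},
}
# Stand-in gradients with the same structure (all ones).
grads = jax.tree_util.tree_map(jnp.ones_like, params)


def update_rule(param, update):
  return param - 0.01 * update


# tree_map walks both trees in lockstep and applies update_rule leaf by leaf,
# so the result has exactly the same structure as `params`.
new_params = jax.tree_util.tree_map(update_rule, params, grads)
```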

Installation

Haiku is written in pure Python, but depends on C++ code via JAX.

Because JAX installation is different depending on your CUDA version, Haiku does not list JAX as a dependency in requirements.txt.

First, follow the JAX installation instructions to install JAX with the relevant accelerator support.

Then, install Haiku using pip:

$ pip install git+https://github.com/deepmind/dm-haiku

Alternatively, you can install via PyPI:

$ pip install -U dm-haiku

Our examples rely on additional libraries (e.g. bsuite). You can install the full set of additional requirements using pip:

$ pip install -r examples/requirements.txt

User manual

Writing your own modules

In Haiku, all modules are a subclass of hk.Module. You can implement any method you like (nothing is special-cased), but typically modules implement __init__ and __call__.

Let's work through implementing a linear layer:

import haiku as hk
import jax.numpy as jnp
import numpy as np

class MyLinear(hk.Module):

  def __init__(self, output_size, name=None):
    super().__init__(name=name)
    self.output_size = output_size

  def __call__(self, x):
    j, k = x.shape[-1], self.output_size
    w_init = hk.initializers.TruncatedNormal(1. / np.sqrt(j))
    w = hk.get_parameter("w", shape=[j, k], dtype=x.dtype, init=w_init)
    b = hk.get_parameter("b", shape=[k], dtype=x.dtype, init=jnp.zeros)
    return jnp.dot(x, w) + b

All modules have a name. When no name argument is passed to the module, its name is inferred from the name of the Python class (for example MyLinear becomes my_linear). Modules can have named parameters that are accessed using hk.get_parameter(param_name, ...). We use this API (rather than just using object properties) so that we can convert your code into a pure function using hk.transform.

When using modules you need to define functions and transform them into a pair of pure functions using hk.transform. See our quickstart for more details about the functions returned from transform:

def forward_fn(x):
  model = MyLinear(10)
  return model(x)

# Turn `forward_fn` into an object with `init` and `apply` methods. By default,
# the `apply` will require an rng (which can be None), to be used with
# `hk.next_rng_key`.
forward = hk.transform(forward_fn)

x = jnp.ones([1, 1])

# When we run `forward.init`, Haiku will run `forward_fn(x)` and collect initial
# parameter values. Haiku requires you pass a RNG key to `init`, since parameters
# are typically initialized randomly:
key = hk.PRNGSequence(42)
params = forward.init(next(key), x)

# When we run `forward.apply`, Haiku will run `forward_fn(x)` and inject parameter
# values from the `params` that are passed as the first argument.  Note that
# models transformed using `hk.transform(f)` must be called with an additional
# `rng` argument: `forward.apply(params, rng, x)`. Use
# `hk.without_apply_rng(hk.transform(f))` if this is undesirable.
y = forward.apply(params, None, x)

Working with stochastic models

Some models may require random sampling as part of the computation. For example, in variational autoencoders with the reparametrization trick, a random sample from the standard normal distribution is needed. For dropout we need a random mask to drop units from the input. The main hurdle in making this work with JAX is in management of PRNG keys.

In Haiku we provide a simple API for maintaining a PRNG key sequence associated with modules: hk.next_rng_key() (or hk.next_rng_keys() for multiple keys):

import haiku as hk
import jax
import jax.numpy as jnp

class MyDropout(hk.Module):

  def __init__(self, rate=0.5, name=None):
    super().__init__(name=name)
    self.rate = rate

  def __call__(self, x):
    key = hk.next_rng_key()
    p = jax.random.bernoulli(key, 1.0 - self.rate, shape=x.shape)
    return x * p / (1.0 - self.rate)

forward = hk.transform(lambda x: MyDropout()(x))

x = jnp.ones([2, 3])

key1, key2 = jax.random.split(jax.random.PRNGKey(42), 2)
params = forward.init(key1, x)
prediction = forward.apply(params, key2, x)

For a more complete look at working with stochastic models, please see our VAE example.

Note: hk.next_rng_key() is not functionally pure, which means you should avoid calling it inside JAX transformations (such as jax.vmap or jax.jit) that are applied within hk.transform. For more information and possible workarounds, please consult the docs on Haiku transforms and the available wrappers for JAX transforms inside Haiku networks.

Working with non-trainable state

Some models may want to maintain some internal, mutable state. For example, in batch normalization a moving average of values encountered during training is maintained.

In Haiku we provide a simple API for maintaining mutable state that is associated with modules: hk.set_state and hk.get_state. When using these functions you need to transform your function using hk.transform_with_state since the signature of the returned pair of functions is different:

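import haiku as hk
import jax
import jax.numpy as jnp

def forward(x, is_training):
  net = hk.nets.ResNet50(1000)
  return net(x, is_training)

forward = hk.transform_with_state(forward)

# Example inputs: an RNG key and a batch containing one 224x224 RGB image.
rng = jax.random.PRNGKey(42)
x = jnp.ones([1, 224, 224, 3])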

# The `init` function now returns parameters **and** state. State contains
# anything that was created using `hk.set_state`. The structure is the same as
# params (e.g. it is a per-module mapping of named values).
params, state = forward.init(rng, x, is_training=True)

# The apply function now takes both params **and** state. Additionally it will
# return updated values for state. In the resnet example this will be the
# updated values for moving averages used in the batch norm layers.
logits, state = forward.apply(params, state, rng, x, is_training=True)

If you forget to use hk.transform_with_state don't worry, we will print a clear error pointing you to hk.transform_with_state rather than silently dropping your state.

Distributed training with jax.pmap

The pure functions returned from hk.transform (or hk.transform_with_state) are fully compatible with jax.pmap. For more details on SPMD programming with jax.pmap, see the JAX documentation for jax.pmap.

One common use of jax.pmap with Haiku is for data-parallel training on many accelerators, potentially across multiple hosts. With Haiku, that might look like this:

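import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np

# `softmax_cross_entropy` and `input_dataset` are as in the quickstart.
def loss_fn(inputs, labels):
  logits = hk.nets.MLP([8, 4, 2])(inputs)
  return jnp.mean(softmax_cross_entropy(logits, labels))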

loss_fn_t = hk.transform(loss_fn)
loss_fn_t = hk.without_apply_rng(loss_fn_t)

# Initialize the model on a single device.
rng = jax.random.PRNGKey(428)
sample_image, sample_label = next(input_dataset)
params = loss_fn_t.init(rng, sample_image, sample_label)

# Replicate params onto all devices.
num_devices = jax.local_device_count()
params = jax.tree_util.tree_map(lambda x: np.stack([x] * num_devices), params)

def make_superbatch():
  """Constructs a superbatch, i.e. one batch of data per device."""
  # Get N batches, then split into list-of-images and list-of-labels.
  superbatch = [next(input_dataset) for _ in range(num_devices)]
  superbatch_images, superbatch_labels = zip(*superbatch)
  # Stack the superbatches to be one array with a leading dimension, rather than
  # a python list. This is what `jax.pmap` expects as input.
  superbatch_images = np.stack(superbatch_images)
  superbatch_labels = np.stack(superbatch_labels)
  return superbatch_images, superbatch_labels

def update(params, inputs, labels, axis_name='i'):
  """Updates params based on performance on inputs and labels."""
  grads = jax.grad(loss_fn_t.apply)(params, inputs, labels)
  # Take the mean of the gradients across all data-parallel replicas.
  grads = jax.lax.pmean(grads, axis_name)
  # Update parameters using SGD or Adam or ... (`my_update_rule` is a
  # placeholder for your optimizer step).
  new_params = my_update_rule(params, grads)
  return new_params

# Build the pmapped update once, then run several training updates.
p_update = jax.pmap(update, axis_name='i')
for _ in range(10):
  superbatch_images, superbatch_labels = make_superbatch()
  params = p_update(params, superbatch_images, superbatch_labels)

For a more complete look at distributed Haiku training, take a look at our ResNet-50 on ImageNet example.
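The key collective in the update above is jax.lax.pmean. This tiny sketch in plain JAX isolates it with made-up "gradients", one row per local device. It also runs on a single CPU device, where the mean is just the row itself.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()
# Hypothetical per-replica gradients: one row per local device.
per_device_grads = jnp.arange(n * 3, dtype=jnp.float32).reshape(n, 3)


def average(g):
  # pmean averages `g` across all replicas along the named axis 'i'.
  return jax.lax.pmean(g, axis_name='i')


averaged = jax.pmap(average, axis_name='i')(per_device_grads)
```

Every replica ends up with the same averaged row, which is what keeps the replicated params in sync after each step.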

Citing Haiku

To cite this repository:

@software{haiku2020github,
  author = {Tom Hennigan and Trevor Cai and Tamara Norman and Lena Martens and Igor Babuschkin},
  title = {{H}aiku: {S}onnet for {JAX}},
  url = {http://github.com/deepmind/dm-haiku},
  version = {0.0.10},
  year = {2020},
}

In this bibtex entry, the version number is intended to be from haiku/__init__.py, and the year corresponds to the project's open-source release.

dm-haiku's People

Contributors

8bitmp3, anukaal, aslanides, chris-chris, conchylicultor, creativeentropy, dependabot[bot], filangelos, froystig, hamzamerzic, hawkinsp, hbq1, ibab, inoryy, joaogui1, kristianholsheimer, lenamartens, madisonmay, marload, mattjj, mknbv, neilgirdhar, rchen152, rerrayne, superbobry, tamaranorman, tomhennigan, tomwardio, trevorcai, yashk2810

dm-haiku's Issues

Remove optix, add optax

The JAX team is going to remove optix from the library now that optax exists, can I open a PR changing all imports from jax.experimental.optix to optax?

I'm curious about the reason why you use JAX.

After you recently open-sourced Haiku and RLax, I started to learn about JAX, and I love these projects.
However, whenever I share these open-source projects with deep learning communities and my friends, they don't seem convinced to use them. I believe you must have solid reasons for choosing JAX for your RL research framework, which I've tried to guess. So far, these are the advantages I could think of.

  • Pure function without side effects
  • Performance enhancement using XLA
  • Lightweight and NEAT
  • Fully utilizing vectorization and parallelization
  • Flexible architecture for a complicated distributed learning system

But I'm still not fully aware of the logic behind it. Would it be possible for you to briefly explain the motivation behind using JAX for your research?

Announcing Elegy!

Hey, I created this framework called Elegy with a couple of friends; it's a Keras-like framework based on Haiku. The experience of creating a higher-level API for training using Haiku has been interesting: we've had to wrap some of Haiku's modules and functions to enable additional features, but the overall experience has been very positive! It would be great if some of these features eventually made their way back into Haiku, if you are interested.

https://poets-ai.github.io/elegy/

Feel free to close the issue.

JAX version in Colab

Right now Colab ships jax 0.1.69, so importing haiku fails.
I believe it would be better to add both jax and jaxlib as normal requirements, so when someone installs they get the correct versions of the dependencies.

Haiku LSTM vs Torch LSTM

I'm trying to figure out the differences between pytorch's LSTM and Haiku's LSTM. In Haiku, the LSTM expects the input to be a rank 1 or rank 2 tensor. However, in pytorch, the input is expected to be a rank 3 tensor. I was hoping to get some clarification on how one would expect to get the same behavior in Haiku's lstms compared to torch's lstms.

Some additional information -- if this is the case, it seems like the LSTM in Haiku is significantly slower than torch. It's very likely that I've made a mistake somewhere here.

Also, would it be correct to say that Haiku's static_unroll + LSTM is the same as PyTorch's LSTM? In that case, isn't Haiku's LSTM technically an LSTM cell?

AttributeError: Can't pickle local object 'transform_with_state.<locals>.init_fn'

For concurrent.futures.ProcessPoolExecutor or multiprocessing.Pool, objects must be pickled to send them to other processes. That doesn't work with transformed Haiku functions:

import pickle

import haiku as hk

def forward(x):
  return hk.Linear(4)(x)

stateful = hk.transform_with_state(forward)
pickle.dumps(stateful)
# AttributeError: Can't pickle local object 'transform_with_state.<locals>.init_fn'

stateless = hk.transform(forward)
pickle.dumps(stateless)
# AttributeError: Can't pickle local object 'without_state.<locals>.init_fn'

How could we pickle haiku models?

Should we take a different approach?

hk.Module should be an abstract Callable

Hey, I use pyright / pylance for type checking, and they are pretty unhappy that hk.Module doesn't define an abstract __call__ method: I get type errors all over the place when writing code that takes arbitrary hk.Modules. Given that most of Haiku is already typed, this would be a nice addition.

extra dimension for jacobian of model, or jacfwd on model.apply

Hi:

I'm computing the Jacobian of my model on every step. The input is a z-dimensional vector and the model's output is (bz,ch,h,w); it's a decoder. I expected the output of the Jacobian to be (bz,ch,h,w,z), and it is when I compute the Jacobian of a single row in my batch. However, when I pass in the whole batch, the dimensions become (bz,ch,h,w,bz,z). Why the extra bz?

I checked the Jacobian output values for the i^th row, and (i,ch,h,w,j,z) is zero for all j != i. And it has the correct values when j = i. I can still use this; however, I have to take the extra step of removing the zero outputs.

Here's my code for the Jacobian:

decoder_jac = hk.transform(lambda z: jacfwd(Decoder())(z))
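The extra bz axis appears because jacfwd differentiates every output example with respect to every input example in the batch. The cross-example blocks are zero. Mapping jacfwd over the batch with jax.vmap yields per-example Jacobians of shape (bz, ch, h, w, z) directly. This sketch is plain JAX. `decoder` is a hypothetical single-example stand-in, and with Haiku the same pattern would be applied to the pure apply function rather than inside hk.transform.

```python
import jax
import jax.numpy as jnp


def decoder(z):
  # Hypothetical single-example decoder: latent (z_dim,) -> output (2, 3, 3).
  w = 0.01 * jnp.arange(2 * 3 * 3 * z.shape[-1], dtype=jnp.float32)
  w = w.reshape(2, 3, 3, z.shape[-1])
  return jnp.tanh(w @ z)


zs = jnp.ones([4, 5])  # batch of 4 latent vectors with z_dim = 5

# Per-example Jacobians: (bz, ch, h, w, z).
per_example = jax.vmap(jax.jacfwd(decoder))(zs)

# Jacobian of the batched function: (bz, ch, h, w, bz, z), block-diagonal.
full = jax.jacfwd(jax.vmap(decoder))(zs)
```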

Jax version upgrade (AttributeError: CallPrimitive)

Using the current version of master 66f9c69 of Haiku, I am getting the following error on Colab

AttributeError                            Traceback (most recent call last)
<ipython-input-3-3a9e6adbfff5> in <module>()
----> 1 import haiku as hk

/usr/local/lib/python3.6/dist-packages/haiku/__init__.py in <module>()
     17 
     18 from haiku import data_structures
---> 19 from haiku import experimental
     20 from haiku import initializers
     21 from haiku import nets

/usr/local/lib/python3.6/dist-packages/haiku/experimental.py in <module>()
     22 from haiku._src.base import custom_getter
     23 from haiku._src.base import ParamContext
---> 24 from haiku._src.dot import to_dot
     25 from haiku._src.lift import lift
     26 from haiku._src.module import profiler_name_scopes

/usr/local/lib/python3.6/dist-packages/haiku/_src/dot.py in <module>()
     23 
     24 from haiku._src import data_structures
---> 25 from haiku._src import module
     26 from haiku._src import utils
     27 import jax

/usr/local/lib/python3.6/dist-packages/haiku/_src/module.py in <module>()
     26 from haiku._src import base
     27 from haiku._src import data_structures
---> 28 from haiku._src import named_call
     29 from haiku._src import utils
     30 import jax.numpy as jnp

/usr/local/lib/python3.6/dist-packages/haiku/_src/named_call.py in <module>()
     29 
     30 # Registering named call as a primitive
---> 31 named_call_p = core.CallPrimitive('named_call')
     32 # named_call is implemented as a plain core.call and only diverges
     33 # under compilation (see named_call_translation_rule)

AttributeError: module 'jax.core' has no attribute 'CallPrimitive'

I believe that's because Haiku now requires jax>=0.1.71, while the version by default on Colab is jax==0.1.69. CallPrimitive was introduced in jax 0.1.71.
https://github.com/google/jax/blob/1545a29e6d69a7b3c7fdf9a49b38004759a9fbfa/jax/core.py#L1106-L1115

To reproduce (inside a Colab):

import jax
print(jax.__version__)  # 0.1.69

!pip install -q git+https://github.com/deepmind/dm-haiku
import haiku as hk

Run !pip install -q --upgrade jax jaxlib first in your Colab to fix this issue.

Feature Request: some way to pass hyperparameters out of transforms

Do you want to hand-tune your models? Most folks don't because it's slow. So we get into hyperparameter optimization.

In Haiku, we define a model like

def forward(x):
   return MyModule(hyperparameters)(x)

model = hk.transform(forward)

Here hyperparameters is a dict or class that records the decisions modules make when they "pull" choices like "number of layers", "latent code dimensionality", "block 2 n_heads", ... We cannot enumerate all possible hyperparameters ahead of time, because some depend on the values of earlier ones and there are simply too many combinations.

Those hyperparameters become "stuck" inside the forward function unless we make hyperparameters a mutable data structure and mutate it inside all our MyModule.init methods. Then the forward function has deeply nested side effects when defined, so you cannot share the decisions across a group of agents because they will mutate shared state. The Haiku tracer also needs to know to ignore these init-time hyperparameter mutations, which I'm assuming it does, but I'm not 100% sure. Mutating hyperparameters at definition time is also confusing for engineers because there are many "init" steps at different times.

How can we make an elegant way for "forward" methods to pull hyperparameters and return their decisions?

Improve `params_dict()` support

module.params_dict() can behave in surprising ways:

import haiku as hk
import jax
import numpy as np

def f(x):
  mod = hk.Linear(8)
  print(mod.params_dict())  # empty during init, full during apply
  sequential = hk.Sequential([mod])
  print(sequential.params_dict())  # always empty
  out = sequential(x)
  print(sequential.params_dict())  # no longer empty
  return out

net = hk.transform(f)
p = net.init(jax.random.PRNGKey(428), np.zeros((2, 3)))
net.apply(p, None, np.zeros((2, 3)))

Prints:

{}
{}
{...}
{...}
{}
{...}

We should clean up & clearly define the desired semantics of params_dict().

Ensure float32 inputs imply float32 outputs when jax_enable_x64=1

import os
os.environ["JAX_ENABLE_X64"] = "1"

import jax
import haiku as hk
import numpy as np

@hk.transform
def f(x):
  return hk.Linear(4)(x)

f32_data = np.zeros((4, 8), dtype=np.float32)

p = f.init(jax.random.PRNGKey(428), f32_data)
print(jax.tree_util.tree_map(lambda t: t.dtype, p))
f32_params = jax.tree_util.tree_map(lambda t: t.astype(np.float32), p)
print(f.apply(f32_params, None, f32_data).dtype)

Prints:

frozendict({
  'linear': frozendict({'b': dtype('float32'), 'w': dtype('float64')}),
})
dtype('float32')

Hopefully the bfloat16 compatibility work means we get most of this for free and that we only need to port the initializers.

New interface for spectral normalization

I noticed earlier today that Haiku has SpectralNormalization -- very cool!

I'm interested in implementing an improved version, which does a much better job estimating the norm for convolutional layers and should converge to the correct answer for any linear operator. The trick is to use auto-diff to calculate the transpose of the linear operator. In contrast, the current implementation is only accurate for dense matrices.

Here's my implementation in pure JAX: https://nbviewer.jupyter.org/gist/shoyer/fa9a29fd0880e2e033d7696585978bfc

My question: how can I implement this in Haiku?

  • It feels like the right way to write this would be as a Module that takes another Module (or function) as an argument, but I don't know of any existing prior art for that. Does that make sense to you?
  • How do I call jax.vjp on a Module? I'm guessing (though to be honest I haven't checked yet) that normal JAX functions would break, given the way that Haiku adds mutable state.
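A sketch of the trick the issue describes: power iteration where jax.vjp supplies the transpose of an arbitrary linear map, so the operator never has to be materialized as a dense matrix. `estimate_operator_norm` is a hypothetical helper, not part of Haiku. It runs in plain JAX outside hk.transform. For a convolution, one would pass a closure over jax.lax.conv_general_dilated with fixed weights.

```python
import jax
import jax.numpy as jnp


def estimate_operator_norm(linear_fn, x_shape, key, num_iters=100):
  """Power iteration for the largest singular value of a linear map."""
  v = jax.random.normal(key, x_shape)
  v = v / jnp.linalg.norm(v)
  for _ in range(num_iters):
    u, vjp_fn = jax.vjp(linear_fn, v)  # u = A v
    (v,) = vjp_fn(u)                   # v = A^T u = A^T A v
    v = v / jnp.linalg.norm(v)
  # With v (approximately) the top right-singular vector, ||A v|| = sigma_max.
  return jnp.linalg.norm(linear_fn(v))


a = jnp.array([[2.0, 0.0], [1.0, 1.0]])
sigma = estimate_operator_norm(lambda x: a @ x, (2,), jax.random.PRNGKey(0))
```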

hk.Embed's embedding_matrix argument can't be supplied a np.ndarray

Seems like unintended behavior.

/usr/local/lib/python3.6/dist-packages/haiku/_src/embed.py in __init__(self, vocab_size, embed_dim, embedding_matrix, w_init, lookup_style, name)
     73     """
     74     super(Embed, self).__init__(name=name)
---> 75     if not embedding_matrix and not (vocab_size and embed_dim):
     76       raise ValueError(
     77           "hk.Embed must be supplied either with an initial `embedding_matrix` "

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
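The error comes from `not embedding_matrix`, which calls bool() on a multi-element array. A sketch of a check that avoids this by testing `is None` explicitly follows. `check_embed_args` is a hypothetical helper mirroring the condition in the traceback above, not Haiku's actual code.

```python
import numpy as np


def check_embed_args(vocab_size=None, embed_dim=None, embedding_matrix=None):
  # Compare against None instead of relying on the array's truth value,
  # which is ambiguous for arrays with more than one element.
  if embedding_matrix is None and not (vocab_size and embed_dim):
    raise ValueError(
        "Supply either `embedding_matrix` or both `vocab_size` and "
        "`embed_dim`.")
  return True
```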

Allow `hk.next_rng_key` calls outside of `hk.transform` with `hk.with_rng`

Currently, there is no easy way to call an initializer outside of hk.transform. This can be worked around, but it discourages using hk.next_rng_key when writing custom initializers, since any code depending on hk.next_rng_key becomes unusable without buying in to everything else. This is a minor annoyance when debugging an initializer, since there is no way to retrieve sampled weights without wrappers. I could imagine this becoming more of a pain in more complex settings that involve writing many initializers (admittedly, whether that is a plausible scenario is a different question).

This is definitely not a high-priority thing, but if there were a way to use hk.with_rng outside of hk.transform, it wouldn't be an issue anymore. If the different parts of the frame stack backend are too interdependent to safely allow hk.with_rng outside of hk.transform, then it's probably not worth the effort, but it never hurts to ask!

`apply` method really slow on CPU

I am implementing the VGG16 architecture, which is pretty straightforward, but when I call init on the transformed function to obtain the initial parameters, the call never finishes.

My code:

import jax
import jax.numpy as np

import haiku as hk

from PIL import Image

cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']

def _make_layers():
    layers = []
    in_channels = 3

    for v in cfg:
        if v == 'M':
            layers.append(hk.MaxPool(window_shape=[2, 2], 
                                     strides=[2, 2],
                                     padding="VALID"))
        else:
            conv2d = hk.Conv2D(v, 
                               kernel_shape=[3, 3], 
                               stride=1, 
                               padding='SAME')
            layers += [conv2d, jax.nn.relu]
            in_channels = v

    return layers

class VGG16(hk.Module):
    
    def __init__(self, num_classes: int = 1000) -> None:
        super(VGG16, self).__init__()
        features = _make_layers()
        classifier = [
                hk.Flatten(),
                hk.Linear(4096), jax.nn.relu,
                hk.Linear(4096), jax.nn.relu,
                hk.Linear(num_classes), jax.nn.softmax
        ]
        self.vgg = hk.Sequential(features + classifier)

    def __call__(self, x):
        return self.vgg(x)

def _forward(image):
    net = VGG16()
    return net(image)

rng = jax.random.PRNGKey(0)

print('Loading image...')
im = Image.open('images/dog.jpeg').resize((224, 224))
im = np.array(im).astype('float32') / 255.

print('hk.transform...')
vgg_forward_fn = hk.transform(_forward)

print('This will never end..')
params = vgg_forward_fn.init(rng, np.expand_dims(im, 0))

I am on Windows 10 using WSL. I have implemented the same architecture using the experimental stax library, and its initialization finishes in about 1 second.

Any idea on what I am doing wrong?

Thanks in advance

Omnistaging (still) breaks LSTM Model

I left an issue on the JAX page here about omnistaging breaking LSTMs, but it was closed because @tomhennigan mentioned it was fixed in a recent PR; it seems like it is not. In summary, JIT compiling the gradient of the loss with respect to an LSTM with static unrolling never finishes compiling. It keeps using more and more RAM until it crashes.

Here is a gist reproducing the issue (with Haiku and JAX at head). The colab can't be run sequentially: you must run the top portion then one of "With OmniStaging" or "Without OmniStaging" (and restart to run the other one).

The code works when I use dynamic_unroll instead of static_unroll, but this is just a temporary workaround as I would like to use static_unroll.

hk.stateful.remat generates excess un-pruneable HLO

jax.remat wraps all of its inputs with _foil_cse.

When we do the state-threading in hk.stateful.remat, the threaded-out state now is the output of _foil_cse. Any downstream uses of this state now access the foil-cse'd param/state, rather than the original.

Example:

import jax
import jax.numpy as jnp

def f(x, ctxt):
  return jnp.sin(x + ctxt[0]), ctxt

@jax.jit
def g(x):
  ctxt = [x + i for i in range(2)]
  x, ctxt = jax.remat(f)(x, ctxt)
  return jnp.sin(x + ctxt[1])

g(1.).block_until_ready()

This results in HLO that looks like:

HloModule jit_g__1.46, is_scheduled=true

ENTRY jit_g__1.46 {
  constant.2 = f32[]{:T(256)} constant(2)
  constant = f32[]{:T(256)} constant(0)
  constant.5 = f32[]{:T(256)} constant(1)
  rng.2 = f32[]{:T(256)} rng(constant, constant.5), distribution=rng_uniform
  compare.2 = pred[]{:T(256)E(32)} compare(rng.2, constant.2), direction=LT
  rng.1 = f32[]{:T(256)} rng(constant, constant.5), distribution=rng_uniform
  compare.1 = pred[]{:T(256)E(32)} compare(rng.1, constant.2), direction=LT
  rng = f32[]{:T(256)} rng(constant, constant.5), distribution=rng_uniform
  compare = pred[]{:T(256)E(32)} compare(rng, constant.2), direction=LT
  parameter.1 = f32[]{:T(256)} parameter(0), parameter_replication={false}
  select = f32[]{:T(256)} select(compare, parameter.1, constant)
  select.1 = f32[]{:T(256)} select(compare.1, parameter.1, constant)
  add = f32[]{:T(256)} add(select, select.1)
  sine = f32[]{:T(256)} sine(add)
  add.6 = f32[]{:T(256)} add(parameter.1, constant.5)
  select.2 = f32[]{:T(256)} select(compare.2, add.6, constant)
  add.43 = f32[]{:T(256)} add(sine, select.2)
  sine.44 = f32[]{:T(256)} sine(add.43)
  ROOT tuple.45 = (f32[]{:T(256)}) tuple(sine.44)
}

Possible solutions:

  1. Reduce the amount of state-motion in & out of stateful_fun, especially during apply.
    • Params are immutable during apply, don't thread them in/out.
    • Pre-split the RNG and populate a new hk.PRNGSequence inside stateful_fun so that RNG doesn't get threaded in/out.
    • Only update state in place when it has actually changed. JAX referential transparency makes this challenging for the case in which Haiku is not jitted but internal functions are via hk.jit.
  2. Rebuild hk.remat on top of hk._src.lift.

Support for custom pytrees

Hello, Haiku team! Thanks a lot for creating Haiku; it's awesome.

I'm interested in sequential probabilistic models. Parameters of probabilistic models are usually constrained; a simple example is a variance, which can only be positive. I gave an example and explanation of constrained parameters in #16 (comment). Pytrees fit the described use case ideally: users can create their own differentiable "vectors", and I would expect Haiku to support these custom structures out of the box. This would let users get actual structures back from transformed functions for printing, debugging, and plotting (and the list could be extended with other academic use cases). Unfortunately, custom differentiable structures don't work at the moment.

Failing example

In [58]: class S(hk.Module):
    ...:   def __init__(self, x, y):
    ...:     super().__init__()
    ...:     # These are parameters:
    ...:     self.x = x
    ...:     self.y = y
    ...:   def __repr__(self):
    ...:     return "RegisteredSpecial(x={}, y={})".format(self.x, self.y)
    ...: def S_flatten(v):
    ...:   children = (v.x, v.y)
    ...:   aux_data = None
    ...:   return (children, aux_data)
    ...: def S_unflatten(aux_data, children):
    ...:   return S(*children)
    ...: register_pytree_node(S, S_flatten, S_unflatten)
    ...:
    ...:
    ...: def function(s):
    ...:   return np.sqrt(s.x**2 * s.y**2)
    ...:
    ...: def loss(x):
    ...:   s = S(1.0, 2.0)
    ...:   a = hk.get_parameter("free_parameter", shape=[], dtype=jnp.float32, init=jnp.zeros)
    ...:   return jnp.sum([function(s) * a * x])
    ...:
    ...: x = jnp.array([2.0])
    ...: forward = hk.transform(loss)
    ...: key = jax.random.PRNGKey(42)
    ...: params = forward.init(key, x)
In [59]: params
Out[59]:
frozendict({
  '~': frozendict({'free_parameter': DeviceArray(0., dtype=float32)}),
})

Thanks

Dealing with conditionally constant state

How could I add a constant state to my haiku module?
Specifically I would want something like this:

class MyModule(hk.Module):
  def __init__(self, output_size, const, name=None):
    super().__init__(name=name)
    if const:
      self.b = hk.const(jnp.ones(output_size))  # hypothetical API: won't be updated by gradients
    else:
      self.b = jnp.zeros(output_size)  # should be updated by gradients

Faithfully reconstruct tree from context

This variation on @tomhennigan's example tries to build a tree of module types.

It assumes the parameter creation order is preserved when flattening the parameter dictionary, which may be incorrect. Alternatively, if the path could be added to context, or if it is possible to recover the path from context, that would support a more satisfying solution. With module names and parameters possibly containing "/", it is not clear to me how to construct the path. What am I missing?

import haiku
import tree  # dm-tree


def init_and_build_module_tree(f):
    """
    Decorated functions build a tree of module types alongside the parameters

    Usage:
      def f(x):
        net = haiku.nets.MLP([300, 100, 10])
        return net(x)

      params, modules = init_and_build_module_tree(f)(rng_key, np.zeros(4))
      params = tree.map_structure(transform_params, params, modules)
    """

    def _init_and_build_module_tree(rng_key, *args, **kwargs):
        module_types = []

        def record_module_type(next_creator, shape, dtype, init, context):
            module_types.append(type(context.module))
            return next_creator(shape, dtype, init)

        def with_creator(*aargs, **kkwargs):
            with haiku.experimental.custom_creator(record_module_type):
                return f(*aargs, **kkwargs)

        params, _ = haiku.transform_with_state(with_creator).init(
            rng_key,
            *args,
            **kwargs
        )

        module_tree = tree.unflatten_as(
            params,
            module_types
        )

        return params, module_tree

    return _init_and_build_module_tree

hk.add_loss

To make it easy for users to create per-layer weight and activity regularizers, plus other losses produced by intermediate layers, it would be very useful if Haiku had an hk.add_loss utility. When called within a transform, it would append a loss to a list that the user could later retrieve as an additional output from apply. I guess this would require an additional flag to hk.transform and friends.

Folder Structure

I've noticed _src has become quite large. I think eventually splitting it up into folders makes more sense. We could have:

  • nn
  • initializers
  • regularizers
  • losses
  • metrics

deepcopy broken for new FlatMapping

Hi there,

I noticed there have been some changes to FlatMapping.

I can imagine that you don't really see the need for deepcopying a FlatMapping as it's supposed to be immutable. But just so you know, deepcopy doesn't work anymore:

from copy import deepcopy

from haiku._src.data_structures import FlatMapping

m = FlatMapping.from_mapping({'foo': 'bar'})
deepcopy(m)  # raises TypeError: can't pickle jaxlib.pytree.PyTreeDef objects

P.S. I stumbled upon this because I'm deriving a subclass with limited mutability from FlatMapping (whose leaves are all DeviceArrays). I'm using deepcopy for my target-network / behavior-policy weights.

I added a custom implementation to my derived class:

class Foo(FlatMapping):
    ...

    def __deepcopy__(self, memo):
        leaves, treedef = self.flatten()
        return self.__class__((deepcopy(leaves), treedef))

Also.. thanks for the speed-up in FlatMapping!
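For the target-network use case, a sketch that sidesteps deepcopy entirely: JAX arrays are immutable, so sharing leaves is usually safe, and if distinct buffers are wanted, jax.tree_util.tree_map with jnp.copy rebuilds the container. This assumes the parameters form an ordinary pytree, as the plain nested dicts used by recent Haiku releases do; the parameter values are made up.

```python
import jax
import jax.numpy as jnp

online_params = {"linear": {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}}
# New containers and new arrays holding the same values.
target_params = jax.tree_util.tree_map(jnp.copy, online_params)
```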

He initialization

The default initialization for linear and convolutional modules seems to be Glorot initialization, but for the commonly used ReLU activation function, He initialization is superior and only requires a quick change to the stddev definition. Should we implement better defaults?
I know that there are many initialization schemes; I only suggest this one because it wouldn't be computationally expensive and would be only a minor code change.

Name change or adjustment?

Hi! I'm one of the developers of Haiku and a board member of Haiku, Inc.

We have trademark on the name Haiku around our open-source operating system, as well as a registered logo mark on our Haiku logo.

Haiku has been in development for nearly 20 years now, and usage of the Haiku name for other open source software projects could create confusion.

Please feel free to reach out to Haiku, Inc. if you have any questions.

https://haiku-os.org
https://haiku-inc.org

Fatal Python error: Aborted when running mnist.py example

Hi there! I've been trying to get familiar with the library by running some examples in the examples/ folder. My environment was set up according to the instructions on https://github.com/google/jax#installation and https://github.com/deepmind/dm-haiku#installation.

On running the mnist.py example with TensorFlow 2.1.0, a Fatal Python error: Aborted occurs. The full error message is as below:

2020-05-11 20:08:28.772679: W tensorflow/stream_executor/platform/default/dso_loader.cc:55] Could not load dynamic library 'libnvinfer.so.6'; dlerror: libnvinfer.so.6: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda/lib64
2020-05-11 20:08:28.772772: W tensorflow/stream_executor/platform/default/dso_loader.cc:55] Could not load dynamic library 'libnvinfer_plugin.so.6'; dlerror: libnvinfer_plugin.so.6: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda/lib64
2020-05-11 20:08:28.772790: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:30] Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
I0511 20:08:29.945827 140124110780224 dataset_info.py:361] Load dataset info from /home/tedmund/tensorflow_datasets/mnist/3.0.1
I0511 20:08:29.947907 140124110780224 dataset_info.py:401] Field info.citation from disk and from code do not match. Keeping the one from code.
I0511 20:08:29.948163 140124110780224 dataset_builder.py:283] Reusing dataset mnist (/home/tedmund/tensorflow_datasets/mnist/3.0.1)
I0511 20:08:29.948284 140124110780224 dataset_builder.py:479] Constructing tf.data.Dataset for split train, from /home/tedmund/tensorflow_datasets/mnist/3.0.1
I0511 20:08:30.869276 140124110780224 dataset_info.py:361] Load dataset info from /home/tedmund/tensorflow_datasets/mnist/3.0.1
I0511 20:08:30.870402 140124110780224 dataset_info.py:401] Field info.citation from disk and from code do not match. Keeping the one from code.
I0511 20:08:30.870607 140124110780224 dataset_builder.py:283] Reusing dataset mnist (/home/tedmund/tensorflow_datasets/mnist/3.0.1)
I0511 20:08:30.870710 140124110780224 dataset_builder.py:479] Constructing tf.data.Dataset for split train, from /home/tedmund/tensorflow_datasets/mnist/3.0.1
I0511 20:08:30.915239 140124110780224 dataset_info.py:361] Load dataset info from /home/tedmund/tensorflow_datasets/mnist/3.0.1
I0511 20:08:30.916428 140124110780224 dataset_info.py:401] Field info.citation from disk and from code do not match. Keeping the one from code.
I0511 20:08:30.916637 140124110780224 dataset_builder.py:283] Reusing dataset mnist (/home/tedmund/tensorflow_datasets/mnist/3.0.1)
I0511 20:08:30.916740 140124110780224 dataset_builder.py:479] Constructing tf.data.Dataset for split test, from /home/tedmund/tensorflow_datasets/mnist/3.0.1
2020-05-11 20:08:32.182307: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_blas.cc:236] failed to create cublas handle: CUBLAS_STATUS_NOT_INITIALIZED
2020-05-11 20:08:32.182349: F external/org_tensorflow/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc:113] Check failed: stream->parent()->GetBlasGemmAlgorithms(&algorithms) 
Fatal Python error: Aborted

Current thread 0x00007f712fd8f740 (most recent call first):
  File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/jaxlib/xla_client.py", line 156 in compile
  File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/jaxlib/xla_client.py", line 576 in Compile
  File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/jax/interpreters/xla.py", line 197 in xla_primitive_callable
  File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/jax/interpreters/xla.py", line 166 in apply_primitive
  File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/jax/core.py", line 199 in bind
  File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/jax/lax/lax.py", line 626 in dot_general
  File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/jax/lax/lax.py", line 564 in dot
  File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/jax/numpy/lax_numpy.py", line 2484 in dot
  File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/haiku/_src/basic.py", line 161 in __call__
  File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/haiku/_src/module.py", line 155 in wrapped
  File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/haiku/_src/basic.py", line 120 in __call__
  File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/haiku/_src/module.py", line 155 in wrapped
  File "mnist.py", line 41 in net_fn
  File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/haiku/_src/transform.py", line 271 in init_fn
  File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/haiku/_src/transform.py", line 106 in init_fn
  File "mnist.py", line 112 in main
  File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/absl/app.py", line 250 in _run_main
  File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/absl/app.py", line 299 in run
  File "mnist.py", line 131 in <module>
Aborted (core dumped)

One workaround I've found is a common fix when using TensorFlow: insert the following code:

from tensorflow.compat.v1 import ConfigProto, InteractiveSession

config = ConfigProto()
config.gpu_options.allow_growth = True
config.gpu_options.per_process_gpu_memory_fraction = 0.7
sess = InteractiveSession(config=config)

However, this kind of defeats the purpose if one is simply trying to use JAX/NumPy instead of TensorFlow. I'm not sure what else I can provide to help; please let me know!
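A JAX-side sketch, assuming the abort comes from JAX and TensorFlow competing for GPU memory (JAX preallocates a large fraction of GPU memory by default): JAX's XLA_PYTHON_CLIENT_* environment variables control that preallocation and must be set before JAX initializes its backend. Hiding the GPU from TensorFlow with tf.config.set_visible_devices([], "GPU") may be another option worth trying, since TensorFlow only feeds data in this example.

```python
import os

# Must be set before JAX initializes its GPU backend (i.e. before importing jax).
# Either disable preallocation entirely...
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# ...or cap the preallocated fraction (the analogue of per_process_gpu_memory_fraction):
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".70"

import jax

print(jax.devices())
```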

Disable __setattr__ in frozendict

There can be some strange behavior in frozendict due to attr assignment:

import haiku as hk

frozen = hk.data_structures.to_immutable_dict({'a': 'foo'})
print(frozen['a'], frozen.a)  # --> foo foo
frozen.a = 'bar'              # should raise error
print(frozen['a'], frozen.a)  # --> foo bar

My suggestion would be to simply disallow any setattr on frozendict.

Also, as a side note, please expose frozendict in the public API. Right now, the only way to access it is through haiku._src.data_structures.frozendict.

MyModule.__init__ runs every time 'apply' is called

import haiku as hk


class MyModule(hk.Module):
    def __init__(self):
        super().__init__(name="MyModule")
        print("MyModule.__init__ ran")

    def __call__(self, x):
        print("MyModule.__call__ ran")
        return x + 1


def forward(x):
    return MyModule()(x)


rng = hk.PRNGSequence(420)
fwd = hk.transform(forward)
params = fwd.init(next(rng), 1)
x = 0
for _ in range(10):
    x = fwd.apply(params, None, x)  # current Haiku: apply(params, rng, *args)

Output:

MyModule.__init__ ran
MyModule.__call__ ran
MyModule.__init__ ran
MyModule.__call__ ran
MyModule.__init__ ran
MyModule.__call__ ran
MyModule.__init__ ran
MyModule.__call__ ran
MyModule.__init__ ran
MyModule.__call__ ran
MyModule.__init__ ran
MyModule.__call__ ran
MyModule.__init__ ran
MyModule.__call__ ran
MyModule.__init__ ran
MyModule.__call__ ran
MyModule.__init__ ran
MyModule.__call__ ran
MyModule.__init__ ran
MyModule.__call__ ran
MyModule.__init__ ran
MyModule.__call__ ran

Is this intended? It eventually broke my module, because it pulled different hyperparameters on different invocations.

Is there a good way to save/load & compress/decompress model weights?

Hey, this is Chris.
I'm using this open-source library for my project.

https://github.com/chris-chris/haiku-scalable-example

Since I'm new to JAX and haiku, I have some questions.

Is there a good way to save/load & compress/decompress & serialize model weights?

  • save/load model (network only or weight only)
  • compress/decompress weights
  • serialize

I think serialization is an important issue on scalability. Can you give me some keywords or hints about this issue?

Thanks!

Missing files in sdist

It appears that the manifest is missing at least one file necessary to build from the sdist for version 0.0.1b0. You're in good company: about 5% of other projects updated in the last year are also missing files.

+ /tmp/venv/bin/pip3 wheel --no-binary dm-haiku -w /tmp/ext dm-haiku==0.0.1b0
Looking in indexes: http://10.10.0.139:9191/root/pypi/+simple/
Collecting dm-haiku==0.0.1b0
  Downloading http://10.10.0.139:9191/root/pypi/%2Bf/f19/fdaf8281b7fb0/dm-haiku-0.0.1b0.tar.gz (121 kB)
    ERROR: Command errored out with exit status 1:
     command: /tmp/venv/bin/python3 -c 'import sys, setuptools, tokenize; sys.argv[0] = '"'"'/tmp/pip-wheel-8pyg10w7/dm-haiku/setup.py'"'"'; __file__='"'"'/tmp/pip-wheel-8pyg10w7/dm-haiku/setup.py'"'"';f=getattr(tokenize, '"'"'open'"'"', open)(__file__);code=f.read().replace('"'"'\r\n'"'"', '"'"'\n'"'"');f.close();exec(compile(code, __file__, '"'"'exec'"'"'))' egg_info --egg-base /tmp/pip-wheel-8pyg10w7/dm-haiku/pip-egg-info
         cwd: /tmp/pip-wheel-8pyg10w7/dm-haiku/
    Complete output (7 lines):
    Traceback (most recent call last):
      File "<string>", line 1, in <module>
      File "/tmp/pip-wheel-8pyg10w7/dm-haiku/setup.py", line 56, in <module>
        install_requires=_parse_requirements('requirements.txt'),
      File "/tmp/pip-wheel-8pyg10w7/dm-haiku/setup.py", line 33, in _parse_requirements
        with open(requirements_txt_path) as fp:
    FileNotFoundError: [Errno 2] No such file or directory: 'requirements.txt'
    ----------------------------------------
ERROR: Command errored out with exit status 1: python setup.py egg_info Check the logs for full command output.

Iterating through hk modules

Let's say I want to iterate through all modules inside a Haiku model and replace all hk.Linear modules with my own custom Module, or monkey-patch some of their properties. Does Haiku currently support something along these lines?

Calling __init__ outside transform

Hey,

I want to create a Haiku compatible library which provides a Keras-like interface called Elegy. I love how easy it is to use Haiku, however, I would like users to code something like this:

module = hk.Sequential([
      hk.Flatten(),
      hk.Linear(300), jax.nn.relu,
      hk.Linear(100), jax.nn.relu,
      hk.Linear(10),
])

model = Model(
      module,
      loss=...,
      metrics=dict(
            accuracy=Accuracy(),  # this is probably an hk.Module as well
      ),
      optimizer=...,
)

Instead of checking whether __init__ is called inside hk.transform, wouldn't it make more sense to check whether __call__ is called inside this context? If that is not possible because you can't intercept __call__, maybe having the user implement something like call or apply would be better?

Importing ABC directly from collections will be removed in Python 3.10

Use collections.abc instead. Also collections.abc.KeysView is used in data_structures.py below but the import statement at top is just collections. Accessing collections.abc with just import collections is not supported.

haiku/_src/utils.py
121:      not isinstance(element, collections.Sequence)):

haiku/_src/data_structures.py
80:class KeysOnlyKeysView(collections.abc.KeysView):

haiku/_src/layer_norm.py
79:    elif (isinstance(axis, collections.Iterable) and

haiku/_src/stateful.py
48:    if isinstance(v, collections.Mapping):
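The fix the report asks for, sketched on hypothetical helpers shaped like the checks listed above: import the submodule explicitly and reference the ABCs through it (collections.abc has been available since Python 3.3).

```python
import collections.abc

def is_sequence(element) -> bool:
    # Replaces isinstance(element, collections.Sequence).
    return isinstance(element, collections.abc.Sequence)

def is_mapping(v) -> bool:
    # Replaces isinstance(v, collections.Mapping).
    return isinstance(v, collections.abc.Mapping)
```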

KeyError in conditional state (solved)

In a sparsely gated mixture of experts where each expert has state (a memory), a KeyError is raised if different experts are activated during apply than were activated during init. To solve this, you can pass an init flag into your custom module; if it is True, use all the experts on that call. If that causes memory issues, you can run them one by one. Just make sure init hits all conditional branches of the modules with state.

(let me know if there's an easier solution)

In haiku, is there something equivalent to flax.nn.Model?

Thanks for creating such a nice library!

My example use case in flax:

nn = ...  # flax.nn.Module
_, nn_params = nn.init(rng_key, data)
model = flax.nn.Model(nn, nn_params)

With this, I can call model(x) and also access the current params via model.params.

For haiku

nn = ... #hk.Module 
nn_params = nn.init(rng_key, data)
# i can do this
partial_fun = lambda data: nn.apply(nn_params, rng_key, data)

I can call partial_fun(x), but because partial_fun is a lambda, I can't access nn_params from it. I'm wondering if there's any workaround to achieve this in Haiku.

For more context, I am trying to integrate haiku with numpyro so that we can convert traditional NN into Bayesian NN. You can see this issue pyro-ppl/numpyro#705 for more context, or this example notebook

Stateful functions and target networks?

Hi there,

I'd like to ask for some advice.

Let's say I've got a function approximator for a q-function that uses hk.{get,set}_state() along with hk.transform_with_state(). This means that my function approximator consists of a triplet func, params, state.

I would like to keep a separate copy of this function approximator, i.e. a target network. This means that I keep a separate copy of the params.

Now my question is, would you recommend I also keep a separate copy of the state? And if so, how do we ensure that the usual smooth updates make sense? (e.g. variance typically can't be updated this way unless you map it to the 2nd moment first)
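For the params, the usual smooth update is a leaf-wise Polyak average, which jax.tree_util.tree_map expresses directly. A sketch with made-up tau and parameter values; it deliberately covers params only, since, as noted above, the same rule is not meaningful for every kind of state (e.g. a variance).

```python
import jax
import jax.numpy as jnp

def soft_update(target, online, tau: float):
    """Polyak averaging: target <- (1 - tau) * target + tau * online, leaf by leaf."""
    return jax.tree_util.tree_map(lambda t, o: (1.0 - tau) * t + tau * o, target, online)

online_params = {"linear": {"w": jnp.ones((2, 2))}}
target_params = {"linear": {"w": jnp.zeros((2, 2))}}
target_params = soft_update(target_params, online_params, tau=0.1)
```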

hk.ResetCore Requires A Batch Dimension in Inputs

Hi Haiku team,

Thanks for open-sourcing such a great library!

It looks like hk.ResetCore requires the leading dimension of the inputs to be the batch dimension: https://github.com/deepmind/dm-haiku/blob/master/haiku/_src/recurrent.py#L638. However, other RNN cores, like LSTM (https://github.com/deepmind/dm-haiku/blob/master/haiku/_src/recurrent.py#L264), do not have such a requirement. In fact, I found this mismatch when I was using an RNN without a batch dimension in the inputs.

Could you perhaps change hk.ResetCore so that it also works with inputs without a batch dimension?

Thanks!
Zeyu

Write Haiku module with custom gradient

Hi Haiku Team! Thank you for all your work on Haiku.

I'm interested in writing a layer that takes a function as an argument and produces custom gradients. (For context, I'm implementing a method to find the stationary point of a function.)

A toy example of my implementation in Pure JAX is below and the full implementation here:

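from functools import partial
from typing import Callable

import jax
import jax.numpy as jnp

@partial(jax.custom_vjp, nondiff_argnums=(1,))
def g(x: jnp.ndarray, fun: Callable):
    return jax.lax.stop_gradient(fun(x))

def g_fwd(x, fun):
    return g(x, fun), (x,)  # residuals as a tuple so g_bwd can unpack them

def g_bwd(fun, res, grad):
    x, = res
    return (fun(x),)  # one cotangent per differentiable input, as a tuple

g.defvjp(g_fwd, g_bwd)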

My question: What is the best way to implement this in Haiku?

I see some approaches:

  • Create a JAX function which takes a haiku.Module as an argument (in place of the function). This currently leads to issues with the Module in the backward pass (it doesn't seem to be transformed).

  • Create a haiku.Module with custom gradients. However, this is a pure function, so creating a Module feels wrong, as it doesn't require parameters (but may require the Haiku state of another module?).

  • Use haiku.to_module(f). In this approach I'd use get_parameters to access the states of the input function and potentially use some naming conventions to make sure I have the correct scope. I imagine this is the best approach (and maybe very similar to the first approach), but I really can't find much documentation on how naming variables or accessing them really works!

Would it be possible to share any prior art on how to implement any of these approaches?

Approach 1
import haiku as hk
import jax
import jax.numpy as jnp
from functools import partial
from typing import Callable

def build_net(output_size):
    def forward_fn(x: jnp.ndarray) -> jnp.ndarray:
        linear = hk.Linear(output_size, name='l1')
        x = linear(x)
        return g(x, linear)
    return forward_fn

b_size, s_size, h_size = 1, 2, 3 
input = jnp.ones((b_size, s_size, h_size))
rng = jax.random.PRNGKey(42)
net = build_net(h_size)
net = hk.transform(net)
params = net.init(rng, input)

def loss_fn(params, rng, x):
    return jnp.sum(net.apply(params, rng, x))

print(jax.grad(loss_fn)(params, rng, input))

Gives:

     15 def g_bwd(fun, res, grad):
     16     input, = res
---> 17     return fun(input)
     18 
     19 g.defvjp(g_fwd, g_bwd)

/usr/local/lib/python3.6/dist-packages/haiku/_src/module.py in wrapped(self, *args, **kwargs)
    299     if not base.frame_stack:
    300       raise ValueError(
--> 301           "All `hk.Module`s must be initialized inside an `hk.transform`.")
    302 

ValueError: All `hk.Module`s must be initialized inside an `hk.transform`.

Raise error when user passes bad window/stride for pooling

Currently, MaxPool and AvgPool expect window shape and stride to be "Same rank as value or int.". I think the layers should check the window shape and stride, make sure the user conforms to this, and throw an error (or perhaps a warning) if they do not.

Coming from other frameworks, users would be used to passing in window shapes/strides with the same rank as the number of spatial dimensions. For example, for a 2D maxpool, users would be used to passing in window_shape=(2, 2), which would resolve to (1, 1, 2, 2) in Haiku which is likely not what the user intended. Throwing an error would force the user to pass in window_shape=(1, 2, 2, 1) or window_shape=2, and would save them debugging time.

Allow creating module instances outside hk.transform

This is as much a question as it is a feature request. What is the reasoning for not allowing a module instance to be created (but not used) outside hk.transform? I took a look at hk.Module and ModuleMetaClass, but I feared my soul would get harvested by the dark forbidden magic involved before I could identify all the API features it permits.

For example, I would have expected this to be possible:

linear = hk.Linear(10)  # currently not allowed

def forward(x):
  return linear(x)

model = hk.transform(forward)

Concretely, I'm curious to know what would have to be sacrificed (if anything) to support this kind of usage? Is it meant to prevent a module instance from being used in two different functions wrapped by two different hk.transform calls?

I wouldn't be surprised if I were missing some nasty side effect if you were to allow module creation outside of hk.transform, but, if not, I think it would be more intuitive to allow this kind of usage.

Merge FlatMappings with DeviceArray as values raises `AttributeError: 'DeviceArray' object has no attribute 'items'`

Hi, I'm trying to do meta learning with some slow and fast weights. From the params returned when calling .init I obtain the slow and fast weights like this

params = f.init(...)
fast_weights = params["fast_weights"]

And just before calling the .apply function I want to merge them (the fast weights will be modified). My first attempt was to use the hk.data_structures.merge method like this

params = hk.data_structures.merge(params, fast_weights)
output = f.apply(params, rng, inputs)

But this raises the exception AttributeError: 'DeviceArray' object has no attribute 'items'. I was wondering whether this behaviour is intended or whether I should use a different approach for what I want to do.

Thanks!

Total Number of Parameters in Haiku Module

I was wondering if there was a simple way to find the total number of parameters in a haiku model, apart from iterating through the different layers and counting them.
