Comments (10)
Support for extracting module info in a creator has landed 😄 Here's an example colab using it to extract all info to a dict outside the function: https://colab.research.google.com/drive/1tt9ifYFsxvSSXaFAz_Oq59Im8QY4S16o
Using it inside a transformed function is documented here: https://dm-haiku.readthedocs.io/en/latest/api.html#haiku.experimental.custom_creator
from dm-haiku.
@sharadmv has done a lot of thinking about probabilistic programming in JAX (outside of Haiku) and might have some useful input for us here.
After initialization, a researcher needs information about the passed bijector for various reasons, such as monitoring or debugging an algorithm. Does that make sense?
Absolutely.
Also, I don't really like the self.s = hk.get_parameter("struct", (), None, init=lambda *_: S(x, y)) line, that looks hacky and I would prefer to have a function for getting a structure, e.g. hk.get_structure("name", getter="").
Agreed that it is ugly looking. I like your suggestion, though I think we should probably call this get_parameter_tree to make it clear that it is strongly related to get_parameter (e.g. s = hk.get_parameter_tree("s", init=lambda: S(a, b))). I'm happy to add that, and will close out this issue with a commit later today.
Is there anything else in Haiku getting in your way for this type of research?
from dm-haiku.
Hey @mattwescott and @awav, sorry for the delay implementing this. Before adding to core I want to think carefully about how it will interact with JAX transforms, especially when those transforms are used inside a haiku transformed function (e.g. via hk.jit).
For now you should be able to use this without needing changes in Haiku by adding the following utility function in your code and using it in your modules (it is slightly ugly since it adds a "Box" type around your type, but otherwise this should unblock you):
from typing import Any, NamedTuple

import haiku as hk

class Box(NamedTuple):
  value: Any
  # Fake scalar shape so that `get_parameter` accepts a Box as a parameter.
  shape = property(fget=lambda _: ())

def get_parameter_tree(name, init):
  return hk.get_parameter(name, [], init=lambda *_: Box(init())).value
You can use it as so:
>>> def f():
...   p = get_parameter_tree("w", lambda: (jnp.ones([]), jnp.zeros([])))
...   return p
>>> hk.transform(f, apply_rng=True).init(None)
frozendict({
  '~': frozendict({
    'w': Box(value=(DeviceArray(1., dtype=float32), DeviceArray(0., dtype=float32))),
  }),
})
On a related note, for some parameter transformations it is useful to know the type of the corresponding module. Is this accessible in haiku without adding type information to module names?
It isn't right now. The closest we have is hk.experimental.custom_creator, which allows you to intercept parameter creation. One thing people have been using this for at DeepMind is to stash all the initializers for their parameters:
>>> inits = {}
>>> def creator(next_getter, name, shape, dtype, init):
...   inits[name] = init
...   return next_getter(name, shape, dtype, init)
>>> f = lambda: hk.nets.MLP([300, 100, 10])(jnp.ones([1, 1]))
>>> f = hk.transform(f, apply_rng=True)
>>> with hk.experimental.custom_creator(creator):
...   f.init(jax.random.PRNGKey(42))
>>> inits
{'mlp/~/linear_0/w': <haiku._src.initializers.TruncatedNormal at 0x7f28476df5f8>,
 'mlp/~/linear_0/b': <function jax.numpy.lax_numpy.zeros>,
 'mlp/~/linear_1/w': <haiku._src.initializers.TruncatedNormal at 0x7f283125b048>,
 'mlp/~/linear_1/b': <function jax.numpy.lax_numpy.zeros>,
 'mlp/~/linear_2/w': <haiku._src.initializers.TruncatedNormal at 0x7f28476df358>,
 'mlp/~/linear_2/b': <function jax.numpy.lax_numpy.zeros>}
I could imagine extending this custom getter to also pass the module as well as the init function, then you could keep a copy of type(module). Would that be useful for you?
from dm-haiku.
@tomhennigan, I found out that flax has support for dataclasses, and it covers a big part of what I needed. I haven't tried it with haiku, but I believe it should work with haiku out of the box. JAX ought to work with dataclasses implicitly, but it looks like it cannot, at least not without flax. Do you have plans for doing a similar thing?
from dm-haiku.
Hi @awav, thanks for trying Haiku!
I think there are two Haiku assumptions that you are challenging here:
- Parameters are always jnp.ndarray instances (we assume this when checking the return value of get_parameter but otherwise this is not a hard requirement).
- Modules are always temporary objects that are deleted when transformed functions return (by returning a module from a function you violate this, although calling repr or getting non-computed properties should work). This is a fairly hard requirement, we work hard in Haiku to make sure when you use transform that the result is pure (wrt. Haiku API calls) and this would not be the case if modules existed outside transform (there would then need to be a global scope for them to find parameters/state).
I think we can make this work, concretely I would suggest:
- Use NamedTuple to define S (no need to register it as a custom pytree then 😄).
- Separate the data structure from the module (e.g. we have S and SModule).
- We use SModule to create S instances and use get_parameter to mark them as parameters.
- As a temporary workaround we make S have a shape property so it looks like a parameter (we could relax this in Haiku).
Putting that all together:
import jax
import jax.numpy as jnp
import haiku as hk
from typing import NamedTuple

class S(NamedTuple):
  x: jnp.ndarray
  y: jnp.ndarray

  @property
  def shape(self):
    # Hack to workaround the fact that `get_parameter` checks tensor shapes.
    return ()

class SModule(hk.Module):
  def __init__(self, x, y, name=None):
    super().__init__(name=name)
    self.s = hk.get_parameter("struct", (), None, init=lambda *_: S(x, y))

  def __call__(self, x, a):
    return jnp.sqrt(self.s.x ** 2 * self.s.y ** 2) * x * a

def loss(x):
  s = SModule(1.0, 2.0)
  a = hk.get_parameter("free", shape=(), dtype=jnp.float32, init=jnp.ones)
  y = s(x, a)
  return jnp.sum(y)

loss = hk.transform(loss)
x = jnp.array([2.0])
key = jax.random.PRNGKey(42)
params = loss.init(key, x)
jax.grad(loss.apply)(params, x)
Output:
frozendict({
  's_module': frozendict({
    'struct': S(x=DeviceArray(2., dtype=float32), y=DeviceArray(1., dtype=float32)),
  }),
  '~': frozendict({'free': DeviceArray(0., dtype=float32)}),
})
If this looks good then I'm happy to make a change to get_parameter to support parameters that are trees (e.g. we only check the shape if the result of get_parameter is an ndarray instance).
WDYT?
from dm-haiku.
@tomhennigan, for a very simple case, the namedtuple approach will work. However, the main challenge is the implementation of transformed parameters.
The parameter with a constraint would look like this:
import jax.numpy as jnp
# Assumption: `tfp` refers to the JAX substrate of TensorFlow Probability.
from tensorflow_probability.substrates import jax as tfp

class Parameter:
  def __init__(self, init_constrained_value: jnp.ndarray, constraint: tfp.bijectors.Bijector):
    # NOTE: Compute gradients w.r.t. this unconstrained value!!!
    self._unconstrained_value = constraint.inverse(init_constrained_value)
    self._constraint = constraint

  # NOTE: Convert the value in unconstrained space to the value in constrained space.
  def constrained_value(self):
    return self._constraint.forward(self._unconstrained_value)

  def __call__(self):
    return self.constrained_value()

def loss(x):
  p = Parameter(1.0, tfp.bijectors.Exp())
  return jnp.square(p())

def loss_complex(x):
  class ProbModel:
    def __init__(self):
      self.variance = Parameter(1.0, tfp.bijectors.Exp())

    def __call__(self, x):
      pass

  m = ProbModel()
  return m(x)
After initialization, a researcher needs information about the passed bijector for various reasons, such as monitoring or debugging an algorithm. Does that make sense?
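As a minimal, dependency-free sketch of the constrained-parameter idea, the snippet below swaps the TFP bijector for a hand-written ExpBijector. That class and the attribute names are hypothetical stand-ins chosen for illustration; only the forward/inverse pattern is taken from the discussion. The value is stored in unconstrained space, and the bijector stays attached so it can be inspected after initialization.

```python
import math

class ExpBijector:
  """Hypothetical stand-in for an exp bijector: maps reals to positive reals."""

  def forward(self, x):
    return math.exp(x)

  def inverse(self, y):
    return math.log(y)

class Parameter:
  def __init__(self, init_constrained_value, constraint):
    # Store the value in unconstrained space; an optimizer would update this.
    self.unconstrained_value = constraint.inverse(init_constrained_value)
    # Keep the bijector so it can be inspected for monitoring or debugging.
    self.constraint = constraint

  def constrained_value(self):
    # Map back to constrained space on every read.
    return self.constraint.forward(self.unconstrained_value)

variance = Parameter(2.0, ExpBijector())
# An update may move the unconstrained value anywhere on the real line...
variance.unconstrained_value -= 5.0
# ...while the constrained value remains strictly positive.
positive_value = variance.constrained_value()
```

Because the optimizer only ever touches the unconstrained value, no projection step is needed to keep the variance positive.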
Also, I don't really like the self.s = hk.get_parameter("struct", (), None, init=lambda *_: S(x, y)) line, that looks hacky and I would prefer to have a function for getting a structure, e.g. hk.get_structure("name", getter="").
from dm-haiku.
@tomhennigan your get_parameter_tree proposal would be useful to me. On a related note, for some parameter transformations it is useful to know the type of the corresponding module. Is this accessible in haiku without adding type information to module names?
from dm-haiku.
@tomhennigan thanks for the examples.
I could imagine extending this custom getter to also pass the module as well as the init function
This would be great, so much cleaner!
from dm-haiku.
Would it be impractical to instead intercept module creation? With a mapping from module names to types, one could use tree.flatten_with_path_up_to for straightforward type-dependent transformations of the parameter tree.
Either approach would likely be sufficient for me to adopt Haiku.
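A rough sketch of what a type-dependent transformation could look like once a name-to-type mapping exists. It uses plain dictionaries and lists instead of tree.flatten_with_path_up_to, and every name here (the parameter tree, the module names, the "PositiveScale" type and transform_params) is a hypothetical example rather than anything Haiku provides.

```python
import math

# Hypothetical parameter tree keyed by module name.
params = {
    "mlp/~/linear_0": {"w": [0.5, -0.5], "b": [0.0]},
    "scale": {"log_sigma": [0.0, 1.0]},
}

# Hypothetical mapping from module name to module type, e.g. recorded
# while intercepting module or parameter creation.
module_types = {"mlp/~/linear_0": "Linear", "scale": "PositiveScale"}

# Per-type transformations; module types without an entry are left unchanged.
transforms = {"PositiveScale": lambda leaf: [math.exp(v) for v in leaf]}

def transform_params(params, module_types, transforms):
  """Applies the transformation registered for each module's type."""
  out = {}
  for module_name, module_params in params.items():
    fn = transforms.get(module_types[module_name], lambda leaf: leaf)
    out[module_name] = {k: fn(v) for k, v in module_params.items()}
  return out

constrained = transform_params(params, module_types, transforms)
```

Here only the parameters owned by the "PositiveScale" module are exponentiated; the linear layer's parameters pass through untouched.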
from dm-haiku.
whoa, that struct.dataclass is cool, and would solve headaches of passing modules to functions and getting "not a JAX type" errors
from dm-haiku.