Comments (32)
💯 to this change. It aligns the mental model with TF2's `tf.Module` and PyTorch's `nn.Module` a lot more, and both of these have converged to where they are now after many years of mistakes, so this is a good thing.
> (NOTE: We may choose to keep the safeguarding behavior of `.shared()`, which makes it hard to copy and paste code that accidentally reuses modules. We can achieve that by having modules raise an error by default when `__call__` is invoked a second time, unless `.shared()` was called on the module instance first.)
Please don't. Reusing an instance is the common, intuitive, frictionless way of sharing weights; this would just add annoying overhead to avoid a mistake that, frankly, I have never encountered. An explicit `:share` method was how it was done in Torch7; it was annoying and painful, and it no longer exists in PyTorch.
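For concreteness, here is a minimal plain-Python sketch of the guard described in the note: an instance raises on its second call unless `.shared()` was called first. `Module`, `Scale`, and `forward` are hypothetical names for illustration, not Flax's actual API. The last lines show the extra opt-in step this reply objects to; without the guard, calling the same instance twice would be the whole sharing mechanism.

```python
class Module:
    """Toy base class with a hypothetical 'call once unless shared' guard."""

    def __init__(self):
        self._num_calls = 0
        self._is_shared = False

    def shared(self):
        # Opt in to reuse; returning self allows `Scale(3.0).shared()`.
        self._is_shared = True
        return self

    def __call__(self, x):
        if self._num_calls > 0 and not self._is_shared:
            raise RuntimeError(
                f"{type(self).__name__} was called twice; use .shared() to reuse it")
        self._num_calls += 1
        return self.forward(x)

    def forward(self, x):
        raise NotImplementedError


class Scale(Module):
    """Multiplies its input by a weight owned by this instance."""

    def __init__(self, weight):
        super().__init__()
        self.weight = weight

    def forward(self, x):
        return self.weight * x


once = Scale(2.0)
print(once(1.5))  # 3.0
try:
    once(1.0)  # second call without .shared() is rejected
except RuntimeError as err:
    print("error:", err)

tied = Scale(3.0).shared()
print(tied(tied(1.0)))  # 9.0: both calls use the same weight
```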
Regarding the `__init__` vs. `__call__` separation, I don't think it makes good code impossible, so if someone creates monster hydra code because of it, that's probably the author's fault, not the library's. Using dataclasses (or `attr.s`) for this is an interesting idea. However, what is usually done in `__init__` is just normalizing convenience parameters, for example allowing the filter size to be passed either as `(3,3)` or as `3` and turning `3` into `(3,3)` in `__init__`, so that `__call__` is cleaner to read. With that in mind, you can really skip reading `__init__`. I think this is a good thing.
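With dataclasses, the same kind of normalization could live in `__post_init__`. A minimal sketch, assuming a frozen dataclass; `Conv`, `ndim`, and `num_taps` are illustrative names, not Flax's `nn.Conv`:

```python
from dataclasses import dataclass
from typing import Tuple, Union


@dataclass(frozen=True)
class Conv:
    """Hypothetical conv config that accepts `kernel_size=3` or `kernel_size=(3, 3)`."""
    features: int
    kernel_size: Union[int, Tuple[int, ...]]
    ndim: int = 2  # number of spatial dimensions (an assumption of this sketch)

    def __post_init__(self):
        # Normalize once, so everything downstream can assume a tuple.
        if isinstance(self.kernel_size, int):
            ks = (self.kernel_size,) * self.ndim
        else:
            ks = tuple(self.kernel_size)
        # frozen=True blocks normal assignment; object.__setattr__ is the usual escape hatch.
        object.__setattr__(self, "kernel_size", ks)

    def num_taps(self) -> int:
        # Example of downstream code that relies on kernel_size being a tuple.
        taps = 1
        for k in self.kernel_size:
            taps *= k
        return taps


print(Conv(16, 3).kernel_size)       # (3, 3)
print(Conv(16, [3, 5]).kernel_size)  # (3, 5)
```

Because `Conv(16, 3)` and `Conv(16, (3, 3))` normalize to the same fields, they also compare equal, so the rest of the module only ever sees one form.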
Finally, I think you could make an even more convincing case with modules that have more than just the obvious `__call__`, like the VAE example here, which is currently not trivial to understand: I either have to do a lot of guesswork about Flax internals or go back and read all the docs. With your proposal (as in PyTorch), it could be much more straightforward.
from flax.
Nice, I like this change. It is a good start.
However, if you are making such a breaking change, this feels too conservative.
Core Issues:
- This function still violates Pythonic conventions. `nn.Dense` seemingly makes mutable changes to some internal state buffer that is invisible to the user and not transparent in the syntax. (I know this happens in TF, but Flax should be better.)
```python
def __call__(self, x):
    x = nn.Dense(features=16)(x)
    x = nn.relu(x)
    x = nn.Dense(features=16)(x)
    return x
```
Does this mean the following?
```python
def __call__(self, x):
    x = nn.Dense(self, features=16)(x)
    x = nn.relu(x)
    x = nn.Dense(self, features=16)(x)
    return x
```
(Or, alternatively, PyTorch / Sonnet 2 syntax, which both do this better.)
- Params are still treated differently than layers, and they use string-based naming, which seems dangerous and tempting to abuse.
```python
bias = self.param('bias', (self.features,), self.bias_init)
```
Why not this instead?
```python
bias = nn.Param(self, (self.features,), self.bias_init)
```
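A rough plain-Python sketch of the explicit style suggested here, where both submodules and parameters receive their parent as an argument instead of registering through hidden state or string names. Every name below (`Module`, `Param`, `Dense`, `MLP`, `constant`) is hypothetical; this is not how Flax works, and for simplicity it creates submodules in `__init__` rather than inline in `__call__`.

```python
class Module:
    """Toy base: children and params attach to an explicitly passed parent."""

    def __init__(self, parent=None):
        self.children = []
        self.params = []
        if parent is not None:
            parent.children.append(self)  # ownership is visible at the call site


class Param:
    def __init__(self, parent, shape, init):
        self.shape = shape
        self.value = init(shape)
        parent.params.append(self)  # no string name needed


def constant(c):
    """Initializer returning nested lists filled with `c`."""
    def init(shape):
        if len(shape) == 1:
            return [c] * shape[0]
        return [init(shape[1:]) for _ in range(shape[0])]
    return init


class Dense(Module):
    def __init__(self, parent, in_features, features):
        super().__init__(parent)
        self.kernel = Param(self, (in_features, features), constant(0.1))
        self.bias = Param(self, (features,), constant(0.0))

    def __call__(self, x):
        # out[j] = sum_i x[i] * kernel[i][j] + bias[j]
        return [sum(xi * row[j] for xi, row in zip(x, self.kernel.value)) + b
                for j, b in enumerate(self.bias.value)]


class MLP(Module):
    def __init__(self, parent=None):
        super().__init__(parent)
        self.hidden = Dense(self, 4, 16)
        self.out = Dense(self, 16, 16)

    def __call__(self, x):
        return self.out([max(v, 0.0) for v in self.hidden(x)])  # ReLU in between


mlp = MLP()
y = mlp([1.0, 2.0, 3.0, 4.0])
print(len(mlp.children), len(mlp.hidden.params), len(y))  # 2 2 16
```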
from flax.
I would be in favor of this change, as the proposed way of creating new modules is better aligned with my default mental model: layers are classes and particular instances of these layers (with associated weights and parameters) are objects that process the data by being called. I also fully agree that variable sharing will become more intuitive.
Besides, I actually see a separation between `__init__` and `__call__` as a potential win. Conceptually, I imagine that `__init__` should take static parameters, like the number of channels, and `__call__` should take the actual data being processed. Currently, these different kinds of parameters are all mixed together.
from flax.
> Clearly, in `apply_a` we expect the Dense parameters to be shared when we call `apply_a` multiple times on an instance of `MyModule`, but not between iterations of the loop. But what if we take the `dense_factory` from `apply_a` and turn it into a method (`_dense_factory`)? A seemingly innocent refactor will now cause all the Dense modules in `apply_b` to be shared.
My expectation from reading this code is that all `Dense` parameters in both examples would be unshared. If you want to use the same parameters, you need to use the same `Dense` object.
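Under plain Python object semantics, the refactor in the quoted example changes nothing: both the nested factory and the method construct a fresh `Dense` on every call, so all six layers are distinct objects and nothing is shared. A minimal sketch with a hypothetical toy `Dense` that owns its weight (not Flax's `nn.Dense`):

```python
import itertools

_ids = itertools.count(1)


class Dense:
    """Toy layer: each instance owns a distinct stand-in weight."""

    def __init__(self, features):
        self.features = features
        self.weight = float(next(_ids))

    def __call__(self, x):
        return self.weight * x


class MyModule:
    def apply_a(self, x):
        def inner_dense_factory():
            return Dense(123)
        layers = [inner_dense_factory() for _ in range(3)]
        for layer in layers:
            x = layer(x)
        return x, layers

    def apply_b(self, x):
        layers = [self._dense_factory() for _ in range(3)]
        for layer in layers:
            x = layer(x)
        return x, layers

    def _dense_factory(self):
        return Dense(123)


m = MyModule()
_, layers_a = m.apply_a(1.0)
_, layers_b = m.apply_b(1.0)
# Three distinct objects in each case: identity, not the call site, decides sharing.
print(len({id(l) for l in layers_a}), len({id(l) for l in layers_b}))  # 3 3
```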
from flax.
Speaking as a teacher, the XLA docs scare me. They are very jargon-heavy. It would be like asking NumPy students to read the BLAS docs.
from flax.
Thanks @avital!
I really like the new API; thanks for putting the work into it and being direct about the tradeoffs. I will definitely be using it for my next project (probably without `@nn.compact`, but that is totally okay if they are compatible).
I found this helpful: https://colab.research.google.com/github/google/flax/blob/master/docs/notebooks/linen_intro.ipynb
from flax.
> I would be in favor of this change, as the proposed way of creating new modules is better aligned with my default mental model: layers are classes and particular instances of these layers (with associated weights and parameters) are objects that process the data by being called. I also fully agree that variable sharing will become more intuitive.

Yes. One caveat is that while layers have parameters on them, those parameters will be immutable, and you'd still need to mutate your parameters at the top level rather than within your module. This is due to our desire to allow you to use vanilla transformations such as `jit` and `pmap`, which don't work with mutations.
> Besides, I actually see a separation between `__init__` and `__call__` as a potential win. Conceptually, I imagine that `__init__` should admit static parameters, like number of channels, and `__call__` should admit actual data that is processed. Currently, these different types of parameters are all mixed together.

Yes. The issue is that by simply letting people use `__init__` and `__call__` arbitrarily, you often end up with things like this, where you really have to move up and down many times to fully follow the flow of the module's forward pass. Hence, the restriction imposed by using dataclasses encourages `__init__` to be as dumb as possible.
from flax.
Thanks for this proposal! I agree with the other comments:
- How `__init__` and `__call__` might separate responsibilities (user-created monster hydras notwithstanding, @lucasb-eyer)
- Removing `.shared()`. I understand the rationale for keeping it (one less thing to debug), but in this case it makes sense to opt for less friction over more safety if that's the common user expectation. If we really wanted to go the extra mile, we could provide some Flax linting utilities (FL201: Did you mean to reuse a module?) :)
> Yes. One caveat is that while layers have parameters on them, those parameters will be immutable and you'd still need to mutate your parameters at the top-level rather than within your module. This is due to our desire to allow you to use vanilla transformations such as `jit` and `pmap` which don't work with mutations.
Do you mean passing modules directly into `jit`? One of the things I tried to do away with during my weekend excursion was `flax.nn.Model`, given the constraint that `flax.nn.Module` must be immutable. The solution was not great: have an instance method that returns a new `flax.nn.Module` when you update parameters or state.
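The pattern described here (an instance method that returns a new module instead of mutating the current one) can be sketched with a frozen dataclass and `dataclasses.replace`. `Linear` and `apply_gradients` are hypothetical names, and the gradients are passed in by hand rather than computed; making such an object work with `jax.jit` would additionally require registering it as a pytree, which is not shown.

```python
from dataclasses import dataclass, replace
from typing import Tuple


@dataclass(frozen=True)
class Linear:
    """Hypothetical immutable module whose parameters are dataclass fields."""
    weight: Tuple[float, ...]
    bias: float

    def __call__(self, x):
        return sum(w * xi for w, xi in zip(self.weight, x)) + self.bias

    def apply_gradients(self, grad_weight, grad_bias, lr=0.1):
        # No mutation: return a new instance with updated fields.
        new_weight = tuple(w - lr * g for w, g in zip(self.weight, grad_weight))
        return replace(self, weight=new_weight, bias=self.bias - lr * grad_bias)


m0 = Linear(weight=(1.0, 2.0), bias=0.5)
m1 = m0.apply_gradients(grad_weight=(1.0, 1.0), grad_bias=1.0)
print(m0.weight, m1.weight)  # m0 is unchanged; m1 holds the updated values
```

The caller rebinds the name (`model = model.apply_gradients(...)`) at the top level, which matches the "mutate at the top level, not inside the module" constraint mentioned earlier in the thread.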
from flax.
In general I really like the look of this! I think it would be a significant improvement/simplification of Flax's mental model.
👍 for eliminating the use of `__new__` in Modules.
👍 for eliminating `.partial()`.
👍 for eliminating `.shared()`. I don't think we need the safeguard; it is quite common to intentionally reuse models in neural net code.
👍 for encouraging the use of dataclasses (in particular, `@dataclass(frozen=True)` to enforce immutability).
👎 for requiring dataclasses and not allowing `__init__` methods to be written explicitly. Even if this were possible to enforce in a clean way (I have my doubts), sometimes `__init__` can be a nice way to write this, as @lucasb-eyer writes in #208 (comment).
👍 for the proposed transition plan, which looks quite practical.
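A quick illustration of the `@dataclass(frozen=True)` point: assigning to a field raises `dataclasses.FrozenInstanceError`, changes go through `dataclasses.replace`, and a frozen dataclass with the default `eq=True` also gets a field-based `__hash__`. `Dense` here is a hypothetical config-only stand-in, not Flax's class.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Dense:
    """Hypothetical hyperparameter-only module config."""
    features: int
    use_bias: bool = True


layer = Dense(features=16)

try:
    layer.features = 32  # in-place mutation is rejected
except dataclasses.FrozenInstanceError as err:
    print("rejected:", err)

wider = dataclasses.replace(layer, features=32)  # copy with one field changed
print(layer.features, wider.features)  # 16 32

# Equal configs hash equally, so they can serve as dictionary or cache keys.
print(hash(layer) == hash(Dense(16)))  # True
```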
from flax.
One question arises: how does this change affect (if at all) the way we initialize Flax models? Do we still stick with `Module.init` and `call` methods, except that these are now normal methods instead of class methods?
from flax.
👍 to everything said by @lucasb-eyer, @srush, and @shoyer. I think having separate `__init__` and `__call__` is actually a huge net positive. It allows people to just think in Python instead of "thinking in Flax" like we have to do with TF.
FWIW, I don't see the hydra thing as much of a disadvantage. In many cases it requires people to be more explicit, and you can see what's going on in the submodule itself instead of it being hidden in implicit, behind-the-scenes work. It also makes it easier to access model attributes from outside the module if you want to hack things later, say in a Colab notebook.
Also, I think it's great to allow access to `__call__` directly, rather than redirecting to some other function like `apply`. I'm running into challenges with this in Keras at the moment: I'm trying to work around some aspects of its forced programming model, but it's inflexible when I only have access to `call` and not `__call__`, and it requires digging deep into the Keras base layer code, which is a mess. Let's not make the same mistake in Flax.
from flax.
This should also help with my confusion in #16 (comment), where not calling `partial` before `create_by_shape` results in the model being created with different parameters from those it was trained with.
from flax.
Can you add an example of how this would work with an equivalent to `module_method`?
from flax.
> (Or alternatively pytorch / sonnet 2 syntax which both do this better)
@srush Could you kindly clarify what you mean by this?
Is this just a reference to how PyTorch / Sonnet 2 use explicit attribute assignment for submodules, e.g., `self.dense = nn.Dense(features=16)`?
This does make module hierarchies, and when mutation is happening, very clear. The downside is that layers get specified in `__init__`, which is separated from where they are used.
from flax.
> Is this just a reference to how PyTorch / Sonnet 2 use explicit attribute assignment for submodules, e.g., `self.dense = nn.Dense(features=16)`?
> This does make module hierarchies, and when mutation is happening, very clear. The downside is that layers get specified in `__init__`, which is separated from where they are used.
@shoyer To be clear, I think a lot of people actually consider that an upside. It separates creation/ownership from usage, so it's much clearer when reuse is happening, and it's easier to access submodules from outside the class itself for more creative routing of shared parameters.
The mental overhead of a little boilerplate is a small price to pay for such explicit clarity and Python-native interaction paradigms (using Python's built-in object attributes vs. some behind-the-scenes implicit naming scheme).
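A small sketch of the PyTorch-style pattern being discussed: submodules are created and owned in `__init__`, used in `__call__`, and, because they are ordinary attributes, can be reached from outside the class to route shared parameters. `Dense`, `Encoder`, and `TiedDecoder` are hypothetical toy classes, not any library's API.

```python
class Dense:
    """Toy layer with a kernel of shape (in_features, features)."""

    def __init__(self, in_features, features):
        self.kernel = [[0.5] * features for _ in range(in_features)]

    def __call__(self, x):
        n_out = len(self.kernel[0])
        return [sum(xi * row[j] for xi, row in zip(x, self.kernel)) for j in range(n_out)]


class Encoder:
    def __init__(self):
        self.dense = Dense(2, 3)  # creation/ownership

    def __call__(self, x):
        return self.dense(x)  # usage


class TiedDecoder:
    def __init__(self, dense):
        self.dense = dense  # reuse is explicit: the caller hands over an existing object

    def __call__(self, h):
        # Transposed use of the shared kernel: out[i] = sum_j h[j] * kernel[i][j]
        return [sum(hj * kij for hj, kij in zip(h, row)) for row in self.dense.kernel]


enc = Encoder()
dec = TiedDecoder(enc.dense)  # route the encoder's weights from outside the class
h = enc([1.0, 1.0])
print(h, dec(h))  # [1.0, 1.0, 1.0] [1.5, 1.5]
```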
from flax.
> To be clear, I think a lot of people actually consider that an upside. It separates creation/ownership from usage, so it's much clearer when reuse is happening, and easier to access submodules from outside the class itself for more creative routing of shared parameters.
Absolutely, these are all real advantages. On the other hand, I've also had cases where separating the initialization and use of layers made my code harder to read and modify, because two different parts of the code need to be kept in sync. You also can't use input shapes to determine the shapes of variables. It is not clear to me (personally) which is better or worse in general; it may depend on the context.
Keras lets you write things both ways, which is convenient for users, but of course imposes an even higher cost in terms of complexity.
For JAX, there is one additional consideration: whether the module abstraction is amenable to functional transformations, one of the core strengths of JAX. My understanding is that this is hard to do with Python's mutable object model.
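To make the shape-inference point concrete: when a layer is declared inline and created lazily, its kernel shape can come from the first input it sees, which is hard to do when everything must be declared in `__init__` up front. A hypothetical toy sketch (`LazyDense` is not a Flax or Keras class); note that it mutates itself on the first call, which is exactly the kind of hidden state that is awkward under JAX's functional transformations.

```python
class LazyDense:
    """Hypothetical toy layer: only `features` is given up front; the kernel
    shape (in_features, features) is inferred from the first input."""

    def __init__(self, features):
        self.features = features
        self.kernel = None  # unknown until the input size is known

    def __call__(self, x):
        if self.kernel is None:
            in_features = len(x)  # shape comes from the data, not the constructor
            self.kernel = [[1.0] * self.features for _ in range(in_features)]
        return [sum(xi * row[j] for xi, row in zip(x, self.kernel))
                for j in range(self.features)]


layer = LazyDense(features=3)
print(layer([1.0, 2.0]))  # [3.0, 3.0, 3.0]; kernel created with shape (2, 3)
print(len(layer.kernel))  # 2
```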
from flax.
I consider inline initialization a Keras design flaw. It mixes functional and structural concerns and makes it very hard to reason about, document, and analyze modules.
However, whether or not you agree with this, the fact that it is causing the library to have ill-defined semantics, with very minimal benefits ("less scrolling up"?), should be a red flag that it may be a problem.
from flax.
I think we should not use words like "normal" or "Pythonic". They are really vague statements that essentially refer to similarity with existing programming paradigms that are common in the Python world. We shouldn't strive to please the status quo.
I think the points raised by @srush are important. Although sharing becomes clearer with explicit construction, it still isn't quite like an object that owns its parameters.
Consider the following example:
```python
class MyModule(flax.Module):

    def apply_a(self, x):
        def inner_dense_factory():
            return nn.Dense(123)
        for i in range(3):
            x = inner_dense_factory()(x)
        return x

    def apply_b(self, x):
        for i in range(3):
            x = self._dense_factory()(x)
        return x

    def _dense_factory(self):
        return nn.Dense(123, self.my_fancy_init)
```
Clearly, in `apply_a` we expect the Dense parameters to be shared when we call `apply_a` multiple times on an instance of `MyModule`, but not between iterations of the loop. But what if we take the `dense_factory` from `apply_a` and turn it into a method (`_dense_factory`)? A seemingly innocent refactor will now cause all the Dense modules in `apply_b` to be shared.
Of course, we could add annotation trickery to distinguish between module methods that have a scope and "inline methods". But the mental model is still significantly more complex than plain old Python objects.
from flax.
Perhaps I am missing something, but I don't really understand the example above. The implied semantics feel really complicated to me, as state seems to bind to functions in a way I cannot trace.
Btw, I don't know if it is helpful, but here is a proof of concept of the sort of pure world I like (not saying Flax needs to go this way).
https://github.com/srush/parallax
```python
from dataclasses import dataclass

# Everything is immutable.
module = dataclass(frozen=True, repr=False)

@module
class Dense(Module):
    # All parameter-holders are explicitly declared.
    weight: Parameter
    bias: Parameter

    # Setup replaces __init__; it creates shapes and binds lazy initializers.
    @staticmethod
    def setup(in_size, out_size):
        return Dense.init(
            weight=Parameter.setup((out_size, in_size), init.xavier_normal_),
            bias=Parameter.setup((out_size,), init.normal_))

    # Forward is just like standard PyTorch.
    def forward(self, input):
        return self.weight @ input + self.bias
```
"Sharing" would require manually splitting the parameter into two parts, like this:
```python
@module
class BinaryNetwork(Module):
    # No difference between modules and parameters.
    dense1: Dense
    dense2: Dense
    dense3: Dense
    dropout: Dropout

    @staticmethod
    def setup(input_size, hidden_size):
        return BinaryNetwork.init(
            dense1=Dense.setup(input_size, hidden_size),
            dense2=Dense.setup(hidden_size, hidden_size),
            dense3=Dense.setup(hidden_size, 1),
            dropout=Dropout.setup(rate=0.2),
        )

    def forward(self, input):
        # Standard usage works out of the box.
        x = torch.tanh(self.dense1(input))
        # Stochastic modules (already have a random seed).
        x = self.dropout(x)
        # Shared params / recurrence require a split (like RNG).
        dense2_a, dense2_b = self.dense2.split(2)
        x = torch.tanh(dense2_a(x))
        x = torch.tanh(dense2_b(x))
        return torch.sigmoid(self.dense3(torch.tanh(x)))
```
from flax.
Yep, I was about to say the same as @shoyer: the example is convoluted, but we are creating a new `Dense` object each time, so I would definitely not expect weight sharing. Any sharing in that code would be weird magic happening under the hood, which is very confusing.
from flax.
@srush I fail to see how your example semantically differs from plain PyTorch `nn` code. It's "create object at init, use object to apply at forward" semantics; the remaining differences from plain PyTorch `nn` look mostly like syntax to me. (Edit: not saying this is bad, I like PyTorch `nn`.)
from flax.
@lucasb-eyer Sorry, I should have explained better. The fact that it looks like PyTorch syntax is a red herring; unlike PyTorch, the implementation is pure/immutable.
It's "create declarative skeleton at init, (engine fills in tensors), (engine distributes RNG to module), use objects statelessly to apply at forward"
```python
layer = BinaryNetwork.setup(5, 10)

# Initialize parameters -> stateful, hidden
rng = rng_state()
layer = layer.initialize(rng)

for i in range(10):
    rng = rng_state()
    layer = layer.init_state(rng, mode="train")
    grads = grad(layer.forward)(x)  # don't shadow `grad` itself
    layer = layer.update(lambda a, b: a + b, grads)
```
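The "declarative skeleton, then fill in, then use statelessly" flow described above can be approximated in plain Python with frozen dataclasses. This is a hypothetical sketch loosely modeled on the names in the snippet (`setup`, `initialize`, `update`); it is not parallax's implementation, it uses the standard-library `random.Random` in place of a JAX PRNG key, and it applies a hand-written step instead of a real gradient.

```python
import random
from dataclasses import dataclass, replace
from typing import Optional, Tuple

Matrix = Tuple[Tuple[float, ...], ...]


@dataclass(frozen=True)
class Dense:
    """Hypothetical pure module: a shape-only skeleton until `initialize` fills in values."""
    in_size: int
    out_size: int
    weight: Optional[Matrix] = None  # shape (out_size, in_size) once initialized

    @staticmethod
    def setup(in_size, out_size):
        # Declarative: record shapes only, no values yet.
        return Dense(in_size, out_size)

    def initialize(self, rng: random.Random):
        # Returns a new, filled-in module; the skeleton itself is untouched.
        weight = tuple(tuple(rng.gauss(0.0, 1.0) for _ in range(self.in_size))
                       for _ in range(self.out_size))
        return replace(self, weight=weight)

    def forward(self, x):
        # Stateless: reads fields, never writes them.
        return [sum(w * xi for w, xi in zip(row, x)) for row in self.weight]

    def update(self, fn, other: Matrix):
        # Combine parameters elementwise with a same-shaped tree, e.g. a step.
        weight = tuple(tuple(fn(a, b) for a, b in zip(r1, r2))
                       for r1, r2 in zip(self.weight, other))
        return replace(self, weight=weight)


skeleton = Dense.setup(3, 2)
layer0 = skeleton.initialize(random.Random(0))
step = tuple(tuple(-0.01 for _ in range(3)) for _ in range(2))  # stand-in for a gradient step
layer1 = layer0.update(lambda a, b: a + b, step)
```

Each stage returns a new value, so the same seed reproduces the same module and older versions stay valid after an update.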
from flax.
I see, yeah, I was missing the "use it" code; I should've checked your repo. My personal opinion is that classes are the wrong concept for building something pure/immutable/functional.
A few colleagues and I have an internal codebase built on JAX that uses Flax in a completely pure/functional way, and the Flax team was open to some design changes to make using Flax that way possible and nice. I think it is actually very close to your example code. We just made a simplified version of it public, see here: https://github.com/google-research/big_transfer/tree/master/bit_jax
However, all of this pure, pretty, neat, readable stuff goes to 💣 💩 ⚡ the moment you want to add BatchNorm :)
from flax.
Nice, I will check it out. Maybe what needs to happen is for the JAX community to just have an `nn.functional` module like PyTorch's, so that different module systems can use the same layers.
@lucasb-eyer I am still stuck on one point that keeps bothering me about all these solutions: when you read the code below, what are the internal/informal semantics going on in your head? In particular: Where do you imagine that name is stored? Do you believe this code knows it is in an object? Do you have a type for `x` in your head? How do you reason about whether this line of code knows if it is the first or last time it is called? Could this code be tested independently of its system?
```python
x = nn.Dense(x, num_classes, name="conv_head", kernel_init=nn.initializers.zeros)
```
Until I can answer these questions, I just can't imagine this will be the final state of a reliable module system.
from flax.
But `jax.lax` and `jax.numpy` pretty much correspond to `nn.functional` :) The next step is deciding how bookkeeping of variables/parameters happens, and that is where all the frameworks' opinions differ (and mine differs again, and so does yours).
Regarding your second paragraph, I agree that the line has too much magic (also, where are the Dense's w/b tracked? A global collection maybe? 😨). And my understanding is that @avital's proposal in the OP is exactly about reducing this magic and, effectively, being closer to "plain Python" or PyTorch semantics.
from flax.
> but `jax.lax` and `jax.numpy` pretty much correspond to `nn.functional` :)

That's not true: `nn.functional` is a set of clean functional NN implementations of conv/dense/RNN/etc. that could be used with any module system, and none of that is in `jax.lax` or `jax.numpy`: https://pytorch.org/docs/stable/nn.functional.html

> The next step is deciding how bookkeeping of variables/parameters happens

I agree. That's what I'm interested in.

> And my understanding is that @avital's proposal in the OP is exactly about reducing this magic and, effectively, being closer to "plain python" or PyTorch semantics.

It gets halfway there; I'm arguing it needs to be really solved.
from flax.
@lucasb-eyer Very neat paper though!
from flax.
> That's not true, `nn.functional` is clean functional nn implementations of conv/dense/rnn/etc that could be used with any module system, none of that is in `jax.lax` or `jax.numpy`

Not true either. `jax.lax` has a pretty powerful implementation of conv (`jax.lax.conv_general_dilated`), and similarly for pooling (`jax.lax.reduce_window`) and linear layers (`jax.lax.dot_general`).
I was about to concede it's missing an RNN, but there is actually none in `nn.functional` either. The only remaining non-trivial entry of `nn.functional` that is missing from `jax.{nn,lax,numpy}` is `ctc_loss`, and I'm sure jaxers would happily accept a PR for `jax.nn.ctc_loss`. So I maintain my point that `torch.nn.functional` ≈ `jax.{nn,lax,numpy}`.
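To make the correspondence concrete, here is a sketch that reproduces 2x2 max pooling, a linear layer in PyTorch's weight layout, and a basic "valid" convolution with the three `jax.lax` primitives named above. The shapes and layouts are choices made for this example, and it is not a claim about full coverage of `torch.nn.functional`.

```python
import jax.numpy as jnp
from jax import lax

# 2x2 max pooling with stride 2 via lax.reduce_window (NHWC layout chosen here).
x = jnp.arange(16.0).reshape(1, 4, 4, 1)
pooled = lax.reduce_window(
    x, -jnp.inf, lax.max,
    window_dimensions=(1, 2, 2, 1),
    window_strides=(1, 2, 2, 1),
    padding="VALID",
)

# A linear layer in PyTorch's weight layout, y = inputs @ weight.T + bias, via
# lax.dot_general: contract axis 1 of `inputs` with axis 1 of `weight`, no batch dims.
inputs = jnp.ones((2, 3))
weight = jnp.arange(12.0).reshape(4, 3)  # (out_features, in_features)
bias = jnp.zeros(4)
linear = lax.dot_general(inputs, weight, dimension_numbers=(((1,), (1,)), ((), ()))) + bias

# A 3x3 "valid" convolution (cross-correlation, as in PyTorch) via
# lax.conv_general_dilated; with dimension_numbers=None the layouts are NCHW / OIHW.
img = jnp.ones((1, 1, 4, 4))
kernel = jnp.ones((1, 1, 3, 3))
conv = lax.conv_general_dilated(img, kernel, window_strides=(1, 1), padding="VALID")

print(pooled[0, :, :, 0].tolist())  # [[5.0, 7.0], [13.0, 15.0]]
print(linear[0].tolist())           # [3.0, 12.0, 21.0, 30.0]
print(conv.shape)                   # (1, 1, 2, 2)
```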
> It gets halfway there, I'm arguing it needs to be really solved.
I went back to read it, and I actually agree with the points in your first comment in this thread.
Thanks :)
from flax.
Oh well, now I feel silly. It does seem like the lax functions are just much more general than the PyTorch implementations. I honestly never found `reduce_window` on my own (the docstring "Wraps XLA's ReduceWindow operator" doesn't really help). The Stax implementation does make it clear, though.
from flax.
No worries. `jax.lax` is extremely powerful; I like its API a lot (it reminds me of BLAS, but for the XLA era), and it is criminally under-documented!
from flax.
FYI (you may already know this), most of the ops in lax (things like `reduce_window`) are documented in more detail here. I guess we ought to copy more of those docs over to JAX.
from flax.
It's been a while, and sorry for not posting more in this thread. We've gone through a major API redesign aligned with the goals originally described in this thread.
Our new Linen API came out of many user group discussions, trying to find a solution that empowers our users while staying relatively simple and exposing the full power of JAX.
All of our examples have been ported, and multiple large projects have transitioned using our upgrade guide, so now we're making it the official API.
Please check it out! Please ask any questions or suggest improvements on our discussion board.
The old `flax.nn` API is being deprecated.
from flax.