poets-ai / elegy
A High Level API for Deep Learning in JAX
Home Page: https://poets-ai.github.io/elegy/
License: MIT License
The print callbacks for live monitoring significantly slow down training. For example, my GPU utilization drops from 99% to <85% with one of them, compared to verbose=2 or verbose=4. I think I understand why: printing to stdout acts as a .block_until_ready() for the otherwise asynchronous JAX values in the logs.
I see a few options for what could be done:
- Increase the interval value in the Progbar class to mitigate this effect, and maybe pass it through to .fit()
- Delay the callbacks by one step, so that the host waits on the previous step's logs while the next step is already running:
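for epoch in range(epochs):
    x = get_data()
    async_y = train_step(x)  # dispatched asynchronously, returns immediately
    for step in range(1, steps_per_epoch):
        x = get_data()
        future_y = train_step(x)  # queue the next step first...
        callbacks.on_batch_end(async_y)  # ...then block until the previous one is ready
        async_y = future_y
    callbacks.on_batch_end(async_y)  # flush the last step of the epoch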
(this would also improve performance for slow data generators)
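The one-step-delayed logging idea can be sketched in plain JAX. This assumes a jitted train_step (the toy quadratic loss below is hypothetical) and an on_batch_end callback that converts the loss to a Python float, which is the point where the host has to wait. Because the callback for step k-1 runs only after step k has been dispatched, the device keeps working while the host logs:

```python
import jax
import jax.numpy as jnp


@jax.jit
def train_step(params, x):
    # Hypothetical toy step: one SGD update on a quadratic loss.
    loss, grad = jax.value_and_grad(lambda p: jnp.mean((x * p - 1.0) ** 2))(params)
    return params - 0.1 * grad, loss


def run_epoch(params, batches, on_batch_end):
    pending = None  # (step, loss) of the previous iteration, not logged yet
    for step, x in enumerate(batches):
        params, loss = train_step(params, x)  # returns immediately (async dispatch)
        if pending is not None:
            # Blocks only on the previous step's loss; the current step keeps running.
            on_batch_end(*pending)
        pending = (step, loss)
    if pending is not None:
        on_batch_end(*pending)  # flush the last step of the epoch
    return params


logs = []
batches = [jnp.full((4,), float(i)) for i in range(1, 4)]
params = run_epoch(
    jnp.float32(0.0), batches, lambda step, loss: logs.append((step, float(loss)))
)
```

Every step is still logged exactly once; only the timing of the host-side wait moves by one step.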
Elegy uses optax, which is really nice, but I am really missing a way to monitor the learning rate during training and maybe log it to TensorBoard.
Moreover, it's sometimes inconvenient to compute the right parameters for optax schedules because they use steps, not epochs. For example, when I want the learning rate to drop at epoch 10, I first need to convert it to the correct step from the size of the dataset and the batch size.
I don't know how to solve the first issue, but for the second one it would be nice to extend the elegy.Model API with parameters like per_epoch_schedule= and per_step_schedule=, to which I would pass the optax schedules, which would then be updated at the end of each epoch or step respectively. Even better if I could specify something like per_step_schedule="cosine_decay" or per_epoch_schedule=['0.1@50%', '0.01@90%'], which would automatically compute the correct step or epoch for me.
Describe the bug
It is not possible to run the "Getting Started" colab notebook as is without getting exceptions.
The notebook is found at https://colab.research.google.com/github/poets-ai/elegy/blob/master/docs/getting-started.ipynb and linked to in the "Getting Started" chapter of the documentation.
When calling model.summary, the following exception occurs:
>>> model.summary(X_train[:64])
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-5-fd239b14293e> in <module>()
----> 1 model.summary(X_train[:64])
16 frames
/usr/local/lib/python3.6/dist-packages/jax/core.py in concrete_aval(x)
829 handler = pytype_aval_mappings.get(typ)
830 if handler: return handler(x)
--> 831 raise TypeError(f"{type(x)} is not a valid JAX type")
832
833
TypeError: <class 'elegy.nn.flatten.Flatten'> is not a valid JAX type
When calling model.fit, the following exception occurs:
>>> history = model.fit(
...     x=X_train,
...     y=y_train,
...     epochs=100,
...     steps_per_epoch=200,
...     batch_size=64,
...     validation_data=(X_test, y_test),
...     shuffle=True,
...     callbacks=[elegy.callbacks.ModelCheckpoint("model", save_best_only=True)],
... )
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-6-3eecffeab807> in <module>()
7 validation_data=(X_test, y_test),
8 shuffle=True,
----> 9 callbacks=[elegy.callbacks.ModelCheckpoint("model", save_best_only=True)],
10 )
13 frames
/usr/local/lib/python3.6/dist-packages/jax/_src/numpy/lax_numpy.py in _check_arraylike(fun_name, *args)
296 if not _arraylike(arg))
297 msg = "{} requires ndarray or scalar arguments, got {} at position {}."
--> 298 raise TypeError(msg.format(fun_name, type(arg), pos))
299
300 def _check_no_float0s(fun_name, *args):
TypeError: prod requires ndarray or scalar arguments, got <class 'tuple'> at position 0.
Minimal code to reproduce
The exceptions can also be reproduced by running the following snippet of code in the Colab:
import numpy as np
import jax.numpy as jnp
import jax
import elegy
import optax


class MLP(elegy.Module):
    """Standard LeNet-300-100 MLP network."""

    def __init__(self, n1: int = 300, n2: int = 100, **kwargs):
        super().__init__(**kwargs)
        self.n1 = n1
        self.n2 = n2

    def call(self, image: jnp.ndarray) -> jnp.ndarray:
        image = image.astype(jnp.float32) / 255.0
        mlp = elegy.nn.sequential(
            elegy.nn.Flatten(),
            elegy.nn.Linear(self.n1),
            jax.nn.relu,
            elegy.nn.Linear(self.n2),
            jax.nn.relu,
            elegy.nn.Linear(10),
        )
        return mlp(image)


X_train = np.random.uniform(0, 256, size=(60000, 28, 28)).astype(np.uint8)
y_train = np.random.randint(0, 10, size=60000).astype(np.uint8)
X_test = np.random.uniform(0, 256, size=(60000, 28, 28)).astype(np.uint8)
y_test = np.random.randint(0, 10, size=60000).astype(np.uint8)

model = elegy.Model(
    module=MLP(n1=300, n2=100),
    loss=[
        elegy.losses.SparseCategoricalCrossentropy(from_logits=True),
        elegy.regularizers.GlobalL2(l=1e-4),
    ],
    metrics=elegy.metrics.SparseCategoricalAccuracy(),
    optimizer=optax.adam(1e-3),
)

model.summary(X_train[:64])

history = model.fit(
    x=X_train,
    y=y_train,
    epochs=100,
    steps_per_epoch=200,
    batch_size=64,
    validation_data=(X_test, y_test),
    shuffle=True,
    callbacks=[elegy.callbacks.ModelCheckpoint("model", save_best_only=True)],
)
Expected behavior
The expected output of model.summary is the summary table:
╒═════════════════════╤═══════════════════════╤═════════════════════╤═════════════════╕
│ Layer │ Outputs Shape │ Trainable │ Non-trainable │
│ │ │ Parameters │ Parameters │
╞═════════════════════╪═══════════════════════╪═════════════════════╪═════════════════╡
│ Inputs │ (64, 28, 28) uint8 │ 0 │ 0 │
├─────────────────────┼───────────────────────┼─────────────────────┼─────────────────┤
│ flatten (Flatten) │ (64, 784) float32 │ 0 │ 0 │
├─────────────────────┼───────────────────────┼─────────────────────┼─────────────────┤
│ linear (Linear) │ (64, 300) float32 │ 235,500 942.0 KB │ 0 │
├─────────────────────┼───────────────────────┼─────────────────────┼─────────────────┤
│ relu │ (64, 300) float32 │ 0 │ 0 │
├─────────────────────┼───────────────────────┼─────────────────────┼─────────────────┤
│ linear_1 (Linear) │ (64, 100) float32 │ 30,100 120.4 KB │ 0 │
├─────────────────────┼───────────────────────┼─────────────────────┼─────────────────┤
│ relu_1 │ (64, 100) float32 │ 0 │ 0 │
├─────────────────────┼───────────────────────┼─────────────────────┼─────────────────┤
│ linear_2 (Linear) │ (64, 10) float32 │ 1,010 4.0 KB │ 0 │
├─────────────────────┼───────────────────────┼─────────────────────┼─────────────────┤
│ Outputs (MLP) │ (64, 10) float32 │ 0 │ 0 │
╘═════════════════════╧═══════════════════════╧═════════════════════╧═════════════════╛
Total Parameters: 266,610 1.1 MB
Trainable Parameters: 266,610 1.1 MB
Non-trainable Parameters: 0
and the expected output of model.fit is the training progress:
...
Epoch 99/100
200/200 [==============================] - 1s 4ms/step - l2_regularization_loss: 0.0452 - loss: 0.0662 - sparse_categorical_accuracy: 0.9928 - sparse_categorical_crossentropy_loss: 0.0210 - val_l2_regularization_loss: 0.0451 - val_loss: 0.1259 - val_sparse_categorical_accuracy: 0.9766 - val_sparse_categorical_crossentropy_loss: 0.0808
Epoch 100/100
200/200 [==============================] - 1s 4ms/step - l2_regularization_loss: 0.0450 - loss: 0.0610 - sparse_categorical_accuracy: 0.9953 - sparse_categorical_crossentropy_loss: 0.0161 - val_l2_regularization_loss: 0.0447 - val_loss: 0.1093 - val_sparse_categorical_accuracy: 0.9795 - val_sparse_categorical_crossentropy_loss: 0.0646
Library Info
Elegy version: 0.2.2
Jax version: 0.2.4
OS info: Colab machine with GPU accelerator
Use jax.nn.log_softmax
Hey @cgarciae @charlielito, I'm trying to add CosineSimilarity to elegy, but for this I want to use l2_normalize from either jax or jax.numpy. TensorFlow has tf.math.l2_normalize, so can I use that? Or do I have to implement my own l2_normalize in JAX?
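If no ready-made helper is available, an l2_normalize is short to write in jax.numpy. The sketch below mirrors the formula tf.math.l2_normalize uses, x * rsqrt(max(sum(x**2), epsilon)), and builds a cosine similarity on top of it. The function names are chosen here for illustration and are not an existing Elegy API:

```python
import jax
import jax.numpy as jnp


def l2_normalize(x, axis=-1, epsilon=1e-12):
    """Scale x to unit L2 norm along `axis`, following tf.math.l2_normalize.

    The epsilon floor keeps all-zero vectors at zero instead of producing NaN.
    """
    square_sum = jnp.sum(jnp.square(x), axis=axis, keepdims=True)
    return x * jax.lax.rsqrt(jnp.maximum(square_sum, epsilon))


def cosine_similarity(y_true, y_pred, axis=-1):
    """Cosine similarity along `axis`, a value in [-1, 1]."""
    return jnp.sum(l2_normalize(y_true, axis) * l2_normalize(y_pred, axis), axis=axis)
```

Note that the Keras CosineSimilarity loss returns the negative of this quantity, so that minimizing the loss maximizes the similarity.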
Remove tensorflow related stuff from docs.
In TF2, this callback creates two files, one for training and one for validation.
I believe a FAQ page may help keep the documentation simple. Some details like "How to define and monitor a learning rate schedule" are super useful but get in your way in the API Reference. A FAQ seems like the perfect place for these snippets with useful information.
Use model.save to serialize parameters as a callback.
Be able to define a Keras-like loss with just a string, like loss="mse".
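One way such string support could work is a small alias table that resolves strings to loss callables before the model is built, while passing callables through unchanged. LOSS_ALIASES, resolve_loss and the two loss functions below are hypothetical names, written in plain jax.numpy for illustration:

```python
import jax.numpy as jnp


def mean_squared_error(y_true, y_pred):
    return jnp.mean(jnp.square(y_true - y_pred), axis=-1)


def mean_absolute_error(y_true, y_pred):
    return jnp.mean(jnp.abs(y_true - y_pred), axis=-1)


# Hypothetical registry: Keras-style alias -> loss function.
LOSS_ALIASES = {
    "mse": mean_squared_error,
    "mean_squared_error": mean_squared_error,
    "mae": mean_absolute_error,
    "mean_absolute_error": mean_absolute_error,
}


def resolve_loss(loss):
    """Return a callable for `loss`, which may be a string alias or a callable."""
    if isinstance(loss, str):
        try:
            return LOSS_ALIASES[loss.lower()]
        except KeyError:
            raise ValueError(
                f"Unknown loss {loss!r}; expected one of {sorted(LOSS_ALIASES)}"
            ) from None
    return loss
```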
mkdocs.yml
MeanSquaredError
Accuracy
MeanSquaredError
SoftmaxCrossEntropy
SigmoidCrossEntropy
Is there any template we should follow when creating a PR for a new loss function or a new metric? For example, does the loss have to inherit from some base class, etc.?
For new contributors it may be helpful to have instructions on building the docs, so that one can try changes out before opening a PR. If this is not the expected workflow, we still need such a description so that contributors are aware of the intended workflow.
Right now we define a model function like this:
import haiku as hk
import jax
import jax.numpy as jnp


def model_fn(image) -> jnp.ndarray:
    """Standard LeNet-300-100 MLP network."""
    image = image.astype(jnp.float32) / 255.0
    mlp = hk.Sequential(
        [
            hk.Flatten(),
            hk.Linear(300),
            jax.nn.relu,
            hk.Linear(100),
            jax.nn.relu,
            hk.Linear(10),
        ]
    )
    return mlp(image)
and use it like this:
import elegy
from jax.experimental import optix  # later superseded by the standalone optax package

model = elegy.Model(
    model_fn=model_fn,
    loss=lambda: elegy.losses.SparseCategoricalCrossentropy(from_logits=True),
    aux_losses=lambda: elegy.regularizers.GlobalL2(l=1e-5),
    metrics=lambda: elegy.metrics.SparseCategoricalAccuracy(),
    optimizer=optix.rmsprop(0.001),
)
This has the problem that there is no name scope around these operations, which could become a problem later if you want to reuse or extend this model. An alternative is to define a Module and force users to instantiate it inside a lambda:
class MLP(hk.Module):
    def __call__(self, image) -> jnp.ndarray:
        """Standard LeNet-300-100 MLP network."""
        image = image.astype(jnp.float32) / 255.0
        mlp = hk.Sequential(
            [
                hk.Flatten(),
                hk.Linear(300),
                jax.nn.relu,
                hk.Linear(100),
                jax.nn.relu,
                hk.Linear(10),
            ]
        )
        return mlp(image)
model = elegy.Model(
    module=lambda: MLP(),
    loss=lambda: elegy.losses.SparseCategoricalCrossentropy(from_logits=True),
    aux_losses=lambda: elegy.regularizers.GlobalL2(l=1e-5),
    metrics=lambda: elegy.metrics.SparseCategoricalAccuracy(),
    optimizer=optix.rmsprop(0.001),
)
This is a bit more similar to how you do it in, e.g., skorch.
It also has the added benefit that, in the future, users who inherit from our own elegy.Module could get a partial classmethod for free that removes the weird "lambda:" prefix. We could in fact extend this strategy to losses and metrics as well:
model = elegy.Model(
    module=MLP.partial(),
    loss=elegy.losses.SparseCategoricalCrossentropy.partial(from_logits=True),
    aux_losses=elegy.regularizers.GlobalL2.partial(l=1e-5),
    metrics=elegy.metrics.SparseCategoricalAccuracy.partial(),
    optimizer=optix.rmsprop(0.001),
)
Lambdas have the downside of not being serializable, so we might have to resort to something like this anyway.
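The partial idea can be sketched with functools.partial: the classmethod binds the constructor arguments now and returns a zero-argument factory that builds a fresh instance each time it is called. Partializable and MyLoss are hypothetical names. As a side note on the serialization concern above, functools.partial objects wrapping module-level classes with picklable arguments can be pickled, unlike lambdas:

```python
import functools


class Partializable:
    """Hypothetical mixin that gives subclasses a `partial` classmethod."""

    @classmethod
    def partial(cls, *args, **kwargs):
        # Bind constructor arguments now; construct later (e.g. inside a
        # Haiku transform) by calling the returned factory with no arguments.
        return functools.partial(cls, *args, **kwargs)


class MyLoss(Partializable):
    """Hypothetical stand-in for a loss class such as SparseCategoricalCrossentropy."""

    def __init__(self, from_logits=False):
        self.from_logits = from_logits


factory = MyLoss.partial(from_logits=True)
first = factory()
second = factory()
```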
Create an elegy.Module that inherits from hk.Module and overrides __call__ to capture output shapes.
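A dependency-free sketch of the idea: a base class whose __call__ delegates to a user-defined call method and records the shape and dtype of every output array, which is the information a summary table needs. To keep it runnable without Haiku, the hypothetical Module below does not subclass hk.Module, although in the proposal it would:

```python
import jax
import jax.numpy as jnp


class Module:
    """Hypothetical stand-in for the proposed elegy.Module.

    Subclasses implement `call`; `__call__` runs it and records the
    output shapes and dtypes so a summary can be built afterwards.
    """

    def __init__(self, name=None):
        self.name = name or type(self).__name__.lower()
        self.recorded_outputs = []

    def call(self, *args, **kwargs):
        raise NotImplementedError

    def __call__(self, *args, **kwargs):
        outputs = self.call(*args, **kwargs)
        # Record (shape, dtype) for every array leaf of the (possibly nested) outputs.
        self.recorded_outputs.append(
            jax.tree_util.tree_map(lambda a: (a.shape, a.dtype), outputs)
        )
        return outputs


class Flatten(Module):
    def call(self, x):
        return x.reshape(x.shape[0], -1)


flatten = Flatten()
out = flatten(jnp.zeros((64, 28, 28)))
```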
MIT -> Apache since TF is Apache.
The docs for elegy.nn.Sequential are not clear to me. The example in the docs has nothing to do with this module:
https://poets-ai.github.io/elegy/api/nn/Sequential/#elegy.nn.sequential_module.Sequential
Is your feature request related to a problem? Please describe.
In CPU-only mode it's not clear whether parallelization across CPUs is happening and whether it is controllable.
It would also be necessary to scale it to multiple nodes.
Describe the solution you'd like
Using something like horovod, dask, ray or even byte-ps would be a proper way to implement it.
Additional context
Multi-GPU is another way to make that happen, but I would start with multi-CPU.
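On the JAX side, one relevant knob is the XLA flag --xla_force_host_platform_device_count, which exposes the host CPU as several XLA devices that pmap can then target. A minimal single-node sketch (the device count of 4 and the function name are arbitrary; this is a JAX-level mechanism, not an Elegy API, and it does not cover multiple nodes):

```python
import os

# Ask XLA for 4 virtual CPU devices. This only takes effect if it is set
# before JAX initializes its backends, i.e. before JAX is first used.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")

import functools

import jax
import jax.numpy as jnp

print(jax.devices())  # shows which devices JAX will use
n = jax.local_device_count()


@functools.partial(jax.pmap, axis_name="devices")
def sharded_mean(x):
    # Each device reduces its own shard, then pmean averages across devices.
    return jax.lax.pmean(jnp.mean(x), axis_name="devices")


x = jnp.arange(n * 8, dtype=jnp.float32).reshape(n, 8)
result = sharded_mean(x)  # one identical value per device
```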
Currently there are some low-level hooks and context managers, like elegy.hooks_context, that are exposed but not documented.
I'm trying to add a new loss, MeanSquaredLogarithmicError, to elegy. I've created mean_squared_logarithmic_error.py and also mean_squared_logarithmic_error_test.py. But when running pytest, I get the errors
AttributeError: module 'elegy.losses' has no attribute 'MeanSquaredLogarithmicError'
and
AttributeError: module 'elegy.losses' has no attribute 'mean_squared_logarithmic_error'
PS: I've also defined the necessary imports in __init__.py.
Please see the attached image.
cc: @Davidnet @charlielito @cgarciae @haruiz
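For reference, the quantity such a loss usually computes, following the Keras definition, can be sketched in jax.numpy as below. The function name matches the one in the error message, but the clipping epsilon of 1e-7 is an assumption borrowed from the Keras backend default, and this is not Elegy's actual implementation:

```python
import jax.numpy as jnp

EPSILON = 1e-7  # assumption: mirrors the Keras backend epsilon


def mean_squared_logarithmic_error(y_true, y_pred):
    """MSLE over the last axis: mean((log(1 + y_pred) - log(1 + y_true)) ** 2)."""
    y_true = jnp.asarray(y_true, dtype=jnp.float32)
    y_pred = jnp.asarray(y_pred, dtype=jnp.float32)
    # Clip to a small positive value so log1p never sees negative inputs.
    first_log = jnp.log1p(jnp.maximum(y_pred, EPSILON))
    second_log = jnp.log1p(jnp.maximum(y_true, EPSILON))
    return jnp.mean(jnp.square(first_log - second_log), axis=-1)
```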
train_on_batch
test_on_batch
predict_on_batch
fit
evaluate
predict
Yeah yeah I know I should be using a case-sensitive file-system, but regardless these should be renamed IMHO:
$ git clone --depth=1 https://github.com/poets-ai/elegy
Cloning into 'elegy'...
remote: Enumerating objects: 311, done.
remote: Counting objects: 100% (311/311), done.
remote: Compressing objects: 100% (262/262), done.
remote: Total 311 (delta 76), reused 132 (delta 42), pack-reused 0
Receiving objects: 100% (311/311), 533.67 KiB | 676.00 KiB/s, done.
Resolving deltas: 100% (76/76), done.
warning: the following paths have collided (e.g. case-sensitive paths
on a case-insensitive filesystem) and only one from the same
colliding group is in the working tree:
'docs/api/losses/Huber.md'
'docs/api/losses/huber.md'
'docs/api/metrics/Accuracy.md'
'docs/api/metrics/accuracy.md'
'docs/api/metrics/F1.md'
'docs/api/metrics/f1.md'
'docs/api/metrics/Precision.md'
'docs/api/metrics/precision.md'
'docs/api/metrics/Recall.md'
'docs/api/metrics/recall.md'
'docs/api/metrics/Reduce.md'
'docs/api/metrics/reduce.md'
'docs/api/nn/Sequential.md'
'docs/api/nn/sequential.md'
Hey @alexander-g!
Now that the ResNet* architectures are on master, I think we should add them to the documentation. This is done by adding the __all__ constant to elegy/nets/__init__.py and including "resnet" in the list, and then adding the __all__ constant to elegy/nets/resnet.py and including all the "ResNet*" names. Apart from that, basic docs are missing from the ResNet class and its __init__.
To run the documentation locally you can execute:
scripts/run-docs
This will compile the documentation with mkdocs and spin up a server.
Hi @cgarciae, do you think it's possible to implement pure jittable training functions in Elegy? I.e. a function that neither depends on nor modifies global state:
logs, new_parameters, new_state = train_fn(inputs, old_parameters, old_state)
I see several reasons to have something like that:
I've taken a look at the module.jit() function; the _jit_fn() without the wrapper seems to come close to what I want, but not exactly, because it modifies global state. Does it really require the states_tuple and statics parameters? The first one is literally not used, and the second one does not seem to get modified anywhere, or does it? Can we get rid of .set_parameters() in model_base.py and the global RNG key?
As a start, it would maybe be enough to have a clean_up() function that gets called after tracing the jitted training function, or a context manager, which is basically what wrapper() already is. It would of course be better if it were completely pure.
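To make the requested signature concrete, here is a sketch of a fully pure training step in plain JAX: parameters, the step counter and the RNG key all come in as arguments and go out as return values, so jax.jit can trace it without touching any global state. The linear model, plain SGD update and state layout are assumptions for illustration, not Elegy internals:

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


@jax.jit
def train_fn(inputs, params, state):
    x, y = inputs
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    new_params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
    # The RNG key is threaded through the state instead of living globally;
    # dropout_rng would be handed to stochastic layers in a real model.
    rng, dropout_rng = jax.random.split(state["rng"])
    new_state = {"step": state["step"] + 1, "rng": rng}
    logs = {"loss": loss}
    return logs, new_params, new_state


params = {"w": jnp.zeros(2), "b": jnp.zeros(())}
state = {"step": jnp.zeros((), jnp.int32), "rng": jax.random.PRNGKey(0)}
x, y = jnp.ones((8, 2)), jnp.full((8,), 3.0)
logs1, params, state = train_fn((x, y), params, state)
logs2, params, state = train_fn((x, y), params, state)
```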
The parameters.h5 file saved by Model.save is way larger than the model parameters. The factor is not fixed and varies from 3 to 5 depending on the model.
import pickle
import elegy
import numpy as np
model = elegy.Model(elegy.nets.resnet.ResNet18())
model.predict(np.zeros([1, 224, 224, 3]))  # first call initializes the parameters
print("Size of model parameters:", len(pickle.dumps(model.get_parameters())) / 1e6, "MB")
model.save("tmp")
with open("tmp/parameters.h5", "rb") as f:
    print("Size of parameters.h5:", len(f.read()) / 1e6, "MB")
Output:
Size of model parameters: 46.848071 MB
Size of parameters.h5: 238.33064 MB
Additionally the function creates an annoying, unrelated warning:
FutureWarning: The Panel class is removed from pandas. Accessing it from the top-level namespace will also be removed in the next version
elif _pandas and isinstance(level, (pd.DataFrame, pd.Series, pd.Panel)):
Version: Latest commit in master
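When comparing file sizes, pickle.dumps also counts pickling overhead, so a tighter baseline is the raw byte count of the arrays themselves. A small sketch, where parameter_nbytes is a hypothetical helper, that sums nbytes over every leaf of a parameter pytree:

```python
import jax
import numpy as np


def parameter_nbytes(params):
    """Total raw bytes of every array leaf in a (nested) parameter pytree."""
    return sum(np.asarray(leaf).nbytes for leaf in jax.tree_util.tree_leaves(params))


params = {
    "linear": {
        "w": np.zeros((784, 300), dtype=np.float32),
        "b": np.zeros(300, dtype=np.float32),
    }
}
total = parameter_nbytes(params)
```

For example, a 784x300 float32 weight plus a 300-element bias is 942,000 bytes of payload, and a saved file much larger than this number points to storage overhead rather than to the parameters themselves.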
Create an example using TPUs for training. It can be in Colab.
One of the selling points of JAX is the pmap transformation, but best practices for actually making your training loop parallelizable are still confusing. What is Elegy's story around multi-GPU training? Is it possible to get a PyTorch Lightning-like API, i.e. a single argument to model.fit?
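For context, the basic JAX data-parallel pattern that such an option would need to wrap looks roughly like the sketch below: parameters are replicated with a leading device axis, each device receives one shard of the batch, and gradients are averaged with jax.lax.pmean so that every replica applies the same update. The linear model, learning rate and shapes are illustrative assumptions, not Elegy's API:

```python
import functools

import jax
import jax.numpy as jnp


def loss_fn(w, x, y):
    return jnp.mean((x @ w - y) ** 2)


@functools.partial(jax.pmap, axis_name="batch")
def parallel_step(w, x, y):
    # Each device computes gradients on its own shard of the batch.
    loss, grad = jax.value_and_grad(loss_fn)(w, x, y)
    # Average gradients and loss across devices so all replicas stay in sync.
    grad = jax.lax.pmean(grad, axis_name="batch")
    loss = jax.lax.pmean(loss, axis_name="batch")
    return w - 0.1 * grad, loss


n = jax.local_device_count()
w = jnp.broadcast_to(jnp.zeros(3), (n, 3))  # replicated parameters, one copy per device
x = jnp.ones((n, 4, 3))  # global batch of n * 4 examples split into n shards
y = jnp.ones((n, 4))
w, loss = parallel_step(w, x, y)
```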
Describe the bug
The accuracy reported by Model.fit() or Model.evaluate() is different from the one computed manually via (model.predict(x).argmax(-1) == y).mean()
Minimal code to reproduce
Colab Notebook
Expected behavior
Both should be the same.
Library Info
Latest commit in master 19ec87b
Hey @cgarciae @charlielito, is it possible to add a community example repo for elegy, where the community can create their own elegy examples as Colab notebooks and then contribute them?