
elegy's People

Contributors

abhinavsp0730, alexander-g, anvelezec, cgarciae, charlielito, ciroye, davidnet, github-actions[bot], haruiz, lkhphuc, samuelmarks, sebasarango1180, sooheon, soumik12345, srcolinas, vladdoster

elegy's Issues

Verbosity options 1 and 3 slow down training

The print callbacks for live monitoring significantly slow down training. For example, my GPU utilization drops from 99% to below 85% with one of them, compared to verbose=2 or verbose=4. I think I understand why: printing to stdout acts as a .block_until_ready() on the otherwise asynchronous JAX values in the logs.
I see a few options for what could be done:

  • add this as a hint to the documentation
  • increase the default interval value in the Progbar class to mitigate this effect and maybe pass it through to .fit()
  • pipeline the training loop to something like this:
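for e in range(epochs):
  x = get_data()
  async_y = train_step(x)  # dispatched asynchronously
  for step in range(1, steps_per_epoch):
    x = get_data()
    future_y = train_step(x)  # dispatch the next step before blocking

    callbacks.on_batch_end(async_y)  # blocks until the previous step is ready
    async_y = future_y
  callbacks.on_batch_end(async_y)  # report the last step of the epoch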

(this would also improve performance for slow data generators)
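
The pipelining idea can be tried out in isolation. The following is only a sketch under assumptions: train_step is a hypothetical stand-in for a real training step, and run_epoch and on_batch_end are illustrative names, not Elegy's API. It dispatches step k+1 before reading the result of step k, so the host only blocks on a value whose successor is already queued.

```python
import jax
import jax.numpy as jnp


@jax.jit
def train_step(x):
    # Hypothetical stand-in for a real training step: returns a scalar "loss".
    return jnp.mean(x ** 2)


def run_epoch(batches, on_batch_end):
    """Dispatch the next step before reading the previous step's result."""
    batches = iter(batches)
    pending = train_step(next(batches))  # dispatched asynchronously
    for x in batches:
        upcoming = train_step(x)  # queue the next step first
        on_batch_end(float(pending))  # float() blocks until the value is ready
        pending = upcoming
    on_batch_end(float(pending))  # report the final step


losses = []
run_epoch([jnp.full((4,), float(i)) for i in range(3)], losses.append)
print(losses)
```

Whether this recovers the lost GPU utilization would have to be measured; it only removes the host-side wait between consecutive steps.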

[Feature Request] Monitoring learning rates

Elegy uses optax, which is really nice, but I am missing a way to monitor the learning rate during training and maybe log it to TensorBoard.
Moreover, it's sometimes inconvenient to compute the right parameters for optax schedules because they use steps and not epochs. For example, when I want the learning rate to drop at epoch 10, I first need to convert that epoch to the corresponding step using the size of the dataset and the batch size.

I don't know how to solve the first issue, but for the second one it would be nice to extend the elegy.Model API with parameters like per_epoch_schedule= and per_step_schedule=, to which I would pass the optax schedules, which would then be updated at the end of each epoch or step respectively. Even better if I could specify something like per_step_schedule="cosine_decay" or per_epoch_schedule=['0.1@50%', '0.01@90%'], which would automatically compute the correct step or epoch for me.
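
Until something like that exists, the epoch-to-step conversion can be wrapped in a small helper. This is a sketch only: epoch_to_step and epoch_boundaries are hypothetical names, and it assumes every epoch iterates over the full dataset and keeps the last partial batch (if steps_per_epoch is passed to fit explicitly, that value should be used instead).

```python
import math


def epoch_to_step(epoch, dataset_size, batch_size):
    # Assumes the last, partially filled batch still counts as a step.
    steps_per_epoch = math.ceil(dataset_size / batch_size)
    return epoch * steps_per_epoch


def epoch_boundaries(scales_by_epoch, dataset_size, batch_size):
    """Turn {epoch: scale} into the {step: scale} form a step-based schedule expects."""
    return {
        epoch_to_step(epoch, dataset_size, batch_size): scale
        for epoch, scale in scales_by_epoch.items()
    }


# Drop the learning rate by 10x at epoch 10 for 60000 samples and batch size 64.
boundaries = epoch_boundaries({10: 0.1}, dataset_size=60000, batch_size=64)
print(boundaries)
```

In recent optax versions a dictionary of this shape can be passed as boundaries_and_scales to optax.piecewise_constant_schedule.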

Exceptions in "Getting Started" colab notebook

Describe the bug
It is not possible to run the "Getting Started" colab notebook as is without getting exceptions.

The notebook is found at https://colab.research.google.com/github/poets-ai/elegy/blob/master/docs/getting-started.ipynb and linked to in the "Getting Started" chapter of the documentation.

When calling model.summary the following exception occurs:

>>> model.summary(X_train[:64])

---------------------------------------------------------------------------

TypeError                                 Traceback (most recent call last)

<ipython-input-5-fd239b14293e> in <module>()
----> 1 model.summary(X_train[:64])

16 frames

/usr/local/lib/python3.6/dist-packages/jax/core.py in concrete_aval(x)
    829     handler = pytype_aval_mappings.get(typ)
    830     if handler: return handler(x)
--> 831   raise TypeError(f"{type(x)} is not a valid JAX type")
    832 
    833 

TypeError: <class 'elegy.nn.flatten.Flatten'> is not a valid JAX type

When calling model.fit the following exception occurs:

>>> history = model.fit(
......     x=X_train,
......     y=y_train,
......     epochs=100,
......     steps_per_epoch=200,
......     batch_size=64,
......     validation_data=(X_test, y_test),
......     shuffle=True,
......     callbacks=[elegy.callbacks.ModelCheckpoint("model", save_best_only=True)],
...... )

---------------------------------------------------------------------------

TypeError                                 Traceback (most recent call last)

<ipython-input-6-3eecffeab807> in <module>()
      7     validation_data=(X_test, y_test),
      8     shuffle=True,
----> 9     callbacks=[elegy.callbacks.ModelCheckpoint("model", save_best_only=True)],
     10 )

13 frames

/usr/local/lib/python3.6/dist-packages/jax/_src/numpy/lax_numpy.py in _check_arraylike(fun_name, *args)
    296                     if not _arraylike(arg))
    297     msg = "{} requires ndarray or scalar arguments, got {} at position {}."
--> 298     raise TypeError(msg.format(fun_name, type(arg), pos))
    299 
    300 def _check_no_float0s(fun_name, *args):

TypeError: prod requires ndarray or scalar arguments, got <class 'tuple'> at position 0.

Minimal code to reproduce
The exceptions can also be reproduced by running the following snippet of code in the Colab:

import numpy as np
import jax.numpy as jnp
import jax
import elegy
import optax


class MLP(elegy.Module):
    """Standard LeNet-300-100 MLP network."""
    def __init__(self, n1: int = 300, n2: int = 100, **kwargs):
        super().__init__(**kwargs)
        self.n1 = n1
        self.n2 = n2
    
    def call(self, image: jnp.ndarray) -> jnp.ndarray:
        
        image = image.astype(jnp.float32) / 255.0
        
        mlp = elegy.nn.sequential(
            elegy.nn.Flatten(),
            elegy.nn.Linear(self.n1),
            jax.nn.relu,
            elegy.nn.Linear(self.n2),
            jax.nn.relu,
            elegy.nn.Linear(10),
        )
        
        return mlp(image)


X_train = np.random.uniform(0, 256, size=(60000, 28, 28)).astype(np.uint8)
y_train = np.random.randint(0, 10, size=60000).astype(np.uint8)

X_test = np.random.uniform(0, 256, size=(60000, 28, 28)).astype(np.uint8)
y_test = np.random.randint(0, 10, size=60000).astype(np.uint8)

model = elegy.Model(
    module=MLP(n1=300, n2=100),
    loss=[
        elegy.losses.SparseCategoricalCrossentropy(from_logits=True),
        elegy.regularizers.GlobalL2(l=1e-4),
    ],
    metrics=elegy.metrics.SparseCategoricalAccuracy(),
    optimizer=optax.adam(1e-3),
)


model.summary(X_train[:64])

history = model.fit(
    x=X_train,
    y=y_train,
    epochs=100,
    steps_per_epoch=200,
    batch_size=64,
    validation_data=(X_test, y_test),
    shuffle=True,
    callbacks=[elegy.callbacks.ModelCheckpoint("model", save_best_only=True)],
)

Expected behavior
The expected output of model.summary is the summary table:

╒═════════════════════╤═══════════════════════╤═════════════════════╤═════════════════╕
│ Layer               │ Outputs Shape         │ Trainable           │   Non-trainable │
│                     │                       │ Parameters          │      Parameters │
╞═════════════════════╪═══════════════════════╪═════════════════════╪═════════════════╡
│ Inputs              │ (64, 28, 28)  uint8   │ 0                   │               0 │
├─────────────────────┼───────────────────────┼─────────────────────┼─────────────────┤
│ flatten   (Flatten) │ (64, 784)     float32 │ 0                   │               0 │
├─────────────────────┼───────────────────────┼─────────────────────┼─────────────────┤
│ linear    (Linear)  │ (64, 300)     float32 │ 235,500    942.0 KB │               0 │
├─────────────────────┼───────────────────────┼─────────────────────┼─────────────────┤
│ relu                │ (64, 300)     float32 │ 0                   │               0 │
├─────────────────────┼───────────────────────┼─────────────────────┼─────────────────┤
│ linear_1  (Linear)  │ (64, 100)     float32 │ 30,100     120.4 KB │               0 │
├─────────────────────┼───────────────────────┼─────────────────────┼─────────────────┤
│ relu_1              │ (64, 100)     float32 │ 0                   │               0 │
├─────────────────────┼───────────────────────┼─────────────────────┼─────────────────┤
│ linear_2  (Linear)  │ (64, 10)      float32 │ 1,010      4.0 KB   │               0 │
├─────────────────────┼───────────────────────┼─────────────────────┼─────────────────┤
│ Outputs   (MLP)     │ (64, 10)      float32 │ 0                   │               0 │
╘═════════════════════╧═══════════════════════╧═════════════════════╧═════════════════╛
Total Parameters:          266,610  1.1 MB
Trainable Parameters:      266,610  1.1 MB
Non-trainable Parameters:  0

and the expected output of model.fit is training progress:

...
 
Epoch 99/100
200/200 [==============================] - 1s 4ms/step - l2_regularization_loss: 0.0452 - loss: 0.0662 - sparse_categorical_accuracy: 0.9928 - sparse_categorical_crossentropy_loss: 0.0210 - val_l2_regularization_loss: 0.0451 - val_loss: 0.1259 - val_sparse_categorical_accuracy: 0.9766 - val_sparse_categorical_crossentropy_loss: 0.0808
Epoch 100/100
200/200 [==============================] - 1s 4ms/step - l2_regularization_loss: 0.0450 - loss: 0.0610 - sparse_categorical_accuracy: 0.9953 - sparse_categorical_crossentropy_loss: 0.0161 - val_l2_regularization_loss: 0.0447 - val_loss: 0.1093 - val_sparse_categorical_accuracy: 0.9795 - val_sparse_categorical_crossentropy_loss: 0.0646

Library Info
Elegy version: 0.2.2
Jax version: 0.2.4
OS info: Colab machine with GPU accelerator

l2_normalize

Hey @cgarciae @charlielito, I'm trying to add CosineSimilarity to elegy. For this I need an l2_normalize in either jax or jax.numpy. TensorFlow has tf.math.l2_normalize, so can I use that, or do I have to implement my own l2_normalize in jax?
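
For reference, a self-written version is short. The sketch below follows the formula documented for tf.math.l2_normalize, x / sqrt(max(sum(x**2), epsilon)), using only jax.numpy and jax.lax; the cosine_similarity helper is an illustrative assumption about how it might be used, not existing elegy code.

```python
import jax.numpy as jnp
from jax import lax


def l2_normalize(x, axis=-1, epsilon=1e-12):
    # x / sqrt(max(sum(x**2), epsilon)), computed along `axis`.
    square_sum = jnp.sum(jnp.square(x), axis=axis, keepdims=True)
    return x * lax.rsqrt(jnp.maximum(square_sum, epsilon))


def cosine_similarity(y_true, y_pred, axis=-1):
    # Dot product of the two L2-normalized inputs along `axis`.
    return jnp.sum(l2_normalize(y_true, axis) * l2_normalize(y_pred, axis), axis=axis)


unit = l2_normalize(jnp.array([3.0, 4.0]))
orthogonal = cosine_similarity(jnp.array([1.0, 0.0]), jnp.array([0.0, 1.0]))
```

The epsilon guard keeps all-zero inputs from producing a division by zero.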

Fix Docs

Remove TensorFlow-related content from the docs.

  • callbacks
  • data
  • losses
  • metrics

Create FAQ page

I believe a FAQ page may help keep the documentation simple. Some details like "How to define and monitor a learning rate schedule" are super useful but get in your way in the API Reference. A FAQ seems like a perfect place for these snippets of useful information.

Specific Requirements for losses and metrics

Is there a template that we should refer to when creating a PR for a new loss function or a new metric? For example, does the loss have to inherit from some base class?

Add how to build the docs instructions

For new contributors it may be helpful to have instructions on building the docs, so that one can try something out before making the PR. If this is not the expected workflow, we need a description of the intended workflow so that contributors are aware of it.

[RFC] How to properly define the model function?

Right now we define a model function like this:

def model_fn(image) -> jnp.ndarray:
    """Standard LeNet-300-100 MLP network."""
    image = image.astype(jnp.float32) / 255.0

    mlp = hk.Sequential(
        [
            hk.Flatten(),
            hk.Linear(300),
            jax.nn.relu,
            hk.Linear(100),
            jax.nn.relu,
            hk.Linear(10),
        ]
    )
    return mlp(image)

and use it like this:

model = elegy.Model(
    model_fn=model_fn,
    loss=lambda: elegy.losses.SparseCategoricalCrossentropy(from_logits=True),
    aux_losses=lambda: elegy.regularizers.GlobalL2(l=1e-5),
    metrics=lambda: elegy.metrics.SparseCategoricalAccuracy(),
    optimizer=optix.rmsprop(0.001),
)

The problem is that there is no name scope around these operations, which could cause trouble later if you want to reuse or extend this model somehow. An alternative is to define a Module and force users to instantiate it inside a lambda:

class MLP(hk.Module):
    def __call__(self, image) -> jnp.ndarray:
        """Standard LeNet-300-100 MLP network."""
        image = image.astype(jnp.float32) / 255.0

        mlp = hk.Sequential(
            [
                hk.Flatten(),
                hk.Linear(300),
                jax.nn.relu,
                hk.Linear(100),
                jax.nn.relu,
                hk.Linear(10),
            ]
        )
        return mlp(image)


model = elegy.Model(
    module=lambda: MLP(),
    loss=lambda: elegy.losses.SparseCategoricalCrossentropy(from_logits=True),
    aux_losses=lambda: elegy.regularizers.GlobalL2(l=1e-5),
    metrics=lambda: elegy.metrics.SparseCategoricalAccuracy(),
    optimizer=optix.rmsprop(0.001),
)

This is a bit more similar to how you do it in e.g. skorch.

It also has the added benefit that in the future users could get a partial classmethod for free, removing the weird lambda:, if they inherit from our own elegy.Module. We could in fact extend this strategy to losses and metrics as well:

model = elegy.Model(
    module=MLP.partial(),
    loss=elegy.losses.SparseCategoricalCrossentropy.partial(from_logits=True),
    aux_losses=elegy.regularizers.GlobalL2.partial(l=1e-5),
    metrics=elegy.metrics.SparseCategoricalAccuracy.partial(),
    optimizer=optix.rmsprop(0.001),
)

Lambdas have the downside of not being serializable, so we might have to resort to something like this anyway.
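
A partial classmethod of this kind could be as small as the following sketch. The Deferred mixin and the stand-in GlobalL2 class are hypothetical and only illustrate the mechanism; they are not elegy's classes.

```python
import functools


class Deferred:
    """Hypothetical mixin that adds a `partial` classmethod to its subclasses."""

    @classmethod
    def partial(cls, *args, **kwargs):
        # Returns a zero-argument factory instead of an instance.
        return functools.partial(cls, *args, **kwargs)


class GlobalL2(Deferred):
    # Stand-in for a regularizer; only stores its coefficient.
    def __init__(self, l=1e-4):
        self.l = l


factory = GlobalL2.partial(l=1e-5)  # nothing is constructed yet
regularizer = factory()  # construction is deferred until the call
print(type(regularizer).__name__, regularizer.l)
```

Unlike a lambda, a functools.partial over a class defined at module level can be pickled with the standard pickle module, which speaks to the serialization concern.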

Implement Model.summary

Create elegy.Module that inherits from hk.Module and overrides __call__ to capture output shapes.
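
The capturing idea can be sketched without Haiku. In the example below, Module, Flatten, Linear and summary_log are all hypothetical stand-ins written with NumPy; the point is only that __call__ wraps the user-defined call and records each output shape.

```python
import numpy as np


class Module:
    """Hypothetical base class (not hk.Module): __call__ wraps `call`."""

    summary_log = []  # shared list of (layer name, output shape)

    def __call__(self, *args, **kwargs):
        outputs = self.call(*args, **kwargs)
        Module.summary_log.append((type(self).__name__, outputs.shape))
        return outputs


class Flatten(Module):
    def call(self, x):
        return x.reshape(x.shape[0], -1)


class Linear(Module):
    def __init__(self, units):
        self.units = units

    def call(self, x):
        # Zero weights are enough here; only the shapes matter.
        w = np.zeros((x.shape[-1], self.units), dtype=x.dtype)
        return x @ w


x = np.zeros((64, 28, 28), dtype=np.float32)
y = Linear(300)(Flatten()(x))
print(Module.summary_log)
```

A real implementation would also need to record parameter counts and reset the log between calls.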

[Feature Request] Add support for multi-node multi-cpu training

Is your feature request related to a problem? Please describe.
In CPU-only mode it's not clear whether parallelization across CPUs is happening or whether it is controllable.
It would also be necessary to scale it to multiple nodes.

Describe the solution you'd like
Using something like horovod, dask, ray or even byte-ps would be a proper way to implement it.

Additional context
Multi-GPU is another way to make that happen, but I would start with multi-CPU.

Implement Callback API

  • Port Callback API from Keras
  • Implement History structure
  • Implement Metrics callback (maybe this isn't a full callback)

Need some help for contributing new losses.

I'm trying to add a new loss, MeanSquaredLogarithmicError, to elegy. I've created mean_squared_logarithmic_error.py
and also mean_squared_logarithmic_error_test.py. But when running pytest, I get the errors
AttributeError: module 'elegy.losses' has no attribute 'MeanSquaredLogarithmicError'
and
AttributeError: module 'elegy.losses' has no attribute 'mean_squared_logarithmic_error'
PS: I've also added the necessary imports to __init__.py.

Please see the attached image.
cc: @Davidnet @charlielito @cgarciae @haruiz
Screenshot_20200924_192606

Document Model

  • fit
  • evaluate
  • predict
  • train_on_batch
  • test_on_batch
  • predict_on_batch
  • save
  • load
  • full_state
  • clear_state

[Bug] Case sensitivity

Yeah yeah I know I should be using a case-sensitive file-system, but regardless these should be renamed IMHO:

$ git clone --depth=1 https://github.com/poets-ai/elegy
Cloning into 'elegy'...
remote: Enumerating objects: 311, done.
remote: Counting objects: 100% (311/311), done.
remote: Compressing objects: 100% (262/262), done.
remote: Total 311 (delta 76), reused 132 (delta 42), pack-reused 0
Receiving objects: 100% (311/311), 533.67 KiB | 676.00 KiB/s, done.
Resolving deltas: 100% (76/76), done.
warning: the following paths have collided (e.g. case-sensitive paths
on a case-insensitive filesystem) and only one from the same
colliding group is in the working tree:

  'docs/api/losses/Huber.md'
  'docs/api/losses/huber.md'
  'docs/api/metrics/Accuracy.md'
  'docs/api/metrics/accuracy.md'
  'docs/api/metrics/F1.md'
  'docs/api/metrics/f1.md'
  'docs/api/metrics/Precision.md'
  'docs/api/metrics/precision.md'
  'docs/api/metrics/Recall.md'
  'docs/api/metrics/recall.md'
  'docs/api/metrics/Reduce.md'
  'docs/api/metrics/reduce.md'
  'docs/api/nn/Sequential.md'
  'docs/api/nn/sequential.md'
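
Collisions like these can be caught before they reach a case-insensitive checkout. A minimal sketch, assuming the list of tracked paths is available (for example from git ls-files); find_case_collisions is a hypothetical helper name.

```python
from collections import defaultdict


def find_case_collisions(paths):
    """Group paths that differ only by letter case."""
    groups = defaultdict(list)
    for path in paths:
        groups[path.lower()].append(path)
    return [sorted(group) for group in groups.values() if len(group) > 1]


tracked = [
    "docs/api/losses/Huber.md",
    "docs/api/losses/huber.md",
    "docs/index.md",
]
collisions = find_case_collisions(tracked)
print(collisions)
```

Run as part of CI, a non-empty result could fail the build.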

Add documentation page for ResNet

Hey @alexander-g!

Now that the ResNet* architectures are on master I think we should add them to the documentation. This is done by adding the __all__ constant to elegy/nets/__init__.py and including "resnet" in the list, and then adding the __all__ constant to elegy/nets/resnet.py and including all the "ResNet*" names. Apart from that, basic docs are missing from the ResNet class and its __init__.

To run the documentation locally you can execute:

scripts/run-docs

This will compile the documentation with mkdocs and spin up a server.

Pure Training Functions

Hi @cgarciae, do you think it's possible to implement pure jittable training functions in Elegy, i.e. a function that neither depends on nor modifies global state:
logs, new_parameters, new_state = train_fn(inputs, old_parameters, old_state)

I see several reasons to have something like that:

  1. It's simply in the spirit of Jax
  2. Custom training loops
  3. I am planning to create a Jax interpreter based on Vulkan. Jax has the nice feature that it translates all operations to XLA primitives, one would only need to implement them in other backends. This would enable running on non-Nvidia GPUs and even embedded hardware like Raspberry Pi or smartphones. I've already done some tests and it seems to work, however this would require the function to be pure and jittable.

I've taken a look at the module.jit() function. The _jit_fn() without the wrapper seems to come close to what I want, but not exactly, because it modifies global state. Does it really require the states_tuple and statics parameters? The first one is literally not used, and the second one does not seem to get modified anywhere, or does it? Can we get rid of .set_parameters() in model_base.py and the global RNG key?

As a start, it might be enough to have a clean_up() function that gets called after tracing the jitted training function, or a context manager, which is basically what wrapper() already is. It would of course be better if it were completely pure.
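
To make the request concrete, here is a sketch of what such a pure function could look like in plain JAX, using the signature from above. The linear model, the SGD update with a fixed learning rate, and the step counter kept in state are all assumptions for illustration; none of this is Elegy code.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


@jax.jit
def train_fn(inputs, old_parameters, old_state):
    # Pure: reads only its arguments and returns new values, no global state.
    x, y = inputs
    loss, grads = jax.value_and_grad(loss_fn)(old_parameters, x, y)
    new_parameters = jax.tree_util.tree_map(
        lambda p, g: p - 0.1 * g, old_parameters, grads
    )
    new_state = {"step": old_state["step"] + 1}
    logs = {"loss": loss}
    return logs, new_parameters, new_state


params = {"w": jnp.zeros((2, 1)), "b": jnp.zeros((1,))}
state = {"step": jnp.array(0)}
inputs = (jnp.ones((4, 2)), jnp.ones((4, 1)))
logs, new_params, new_state = train_fn(inputs, params, state)
```

Because the inputs are never mutated, params still holds the old values after the call, which is what makes custom training loops and alternative backends straightforward.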

[Bug] Size of parameters.h5 too large

The parameters.h5 file saved by Model.save is far larger than the model parameters. The factor is not fixed and varies between 3 and 5 depending on the model.

import pickle

import numpy as np
import elegy

model = elegy.Model(elegy.nets.resnet.ResNet18())
model.predict(np.zeros([1,224,224,3]))  # one forward pass before reading the parameters
print('Size of model parameters:', len(pickle.dumps(model.get_parameters()))/1e6, 'MB' )

model.save('tmp')
print('Size of parameters.h5:', len(open('tmp/parameters.h5', 'rb').read())/1e6, 'MB')

Output:

Size of model parameters: 46.848071 MB
Size of parameters.h5: 238.33064 MB

Additionally the function creates an annoying, unrelated warning:

FutureWarning: The Panel class is removed from pandas. Accessing it from the top-level namespace will also be removed in the next version
  elif _pandas and isinstance(level, (pd.DataFrame, pd.Series, pd.Panel)):

Version: Latest commit in master
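
As a cross-check that does not depend on pickle overhead, the raw size of a parameter tree can be computed by summing the bytes of its leaves. This is a sketch with a hypothetical helper name, params_nbytes, applied to a small hand-built tree rather than the ResNet18 parameters.

```python
import jax
import numpy as np


def params_nbytes(params):
    """Sum the raw array bytes over all leaves of a parameter tree."""
    return sum(np.asarray(leaf).nbytes for leaf in jax.tree_util.tree_leaves(params))


# Hand-built example: a 784 -> 300 linear layer in float32.
params = {
    "linear": {
        "w": np.zeros((784, 300), dtype=np.float32),
        "b": np.zeros((300,), dtype=np.float32),
    }
}
size = params_nbytes(params)
print(size / 1e3, "KB")
```

Comparing this number against the size of parameters.h5 would show how much of the file is something other than raw parameter data.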

Multi-gpu with pmap docs

One of the selling points of jax is the pmap transformation, but best practices for actually making your training loop parallelizable are still confusing. What is elegy's story around multi-GPU training? Is it possible to get a pytorch-lightning-like API, with a single argument to model.fit?
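
For context, the bare pmap pattern that such an API would have to wrap looks roughly like this. It is a sketch in plain JAX with a toy linear model and a fixed-rate SGD update, both assumptions for illustration: parameters are replicated along a leading device axis, each device gets its own shard of the batch, and gradients are averaged with lax.pmean.

```python
import functools

import jax
import jax.numpy as jnp


def loss_fn(w, x, y):
    return jnp.mean((x @ w - y) ** 2)


@functools.partial(jax.pmap, axis_name="devices")
def parallel_step(w, x, y):
    grads = jax.grad(loss_fn)(w, x, y)
    # Average the gradients across devices so all replicas stay in sync.
    grads = jax.lax.pmean(grads, axis_name="devices")
    return w - 0.1 * grads


n = jax.local_device_count()  # 1 on a plain CPU machine
w = jnp.zeros((n, 3))  # one parameter copy per device
x = jnp.ones((n, 8, 3))  # one batch shard per device
y = jnp.ones((n, 8))
new_w = parallel_step(w, x, y)
print(new_w.shape)
```

The same code runs on a single device, where the leading axis has length 1.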
