Linen Image Models

limo is an easily customizable implementation of image models for Flax Linen, with pretrained variables. Pretrained variables are loaded correctly even after you add or remove modules. The project currently aims to reimplement torchvision.models.

Due to OneDrive storage limits, pretrained variables are currently unavailable. Sorry!

Installation

  1. Install JAX for your environment. See the JAX installation guide for details.
  2. Install limo via pip:
$ pip install git+https://github.com/h-terao/linen-image-models

Usage

Basic usage

To use builtin models and their pretrained variables, take the following steps.

  1. Create model via limo.create_model.
  2. Initialize variables in the standard Flax manner.
  3. Overwrite the initialized variables with pretrained variables using limo.load_pretrained.
import jax
import limo

x = jax.numpy.zeros((1, 224, 224, 3))  # a batch of one 224x224 RGB image (NHWC)

model = limo.create_model("convnext_tiny", num_classes=100)
variables = model.init(jax.random.PRNGKey(0), x)
variables = limo.load_pretrained(variables, "convnext_tiny", pretrained=True)
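# Split the trainable parameters from the remaining collections (e.g., batch statistics).
params = variables["params"]
state = {k: v for k, v in variables.items() if k != "params"}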

# inference mode.
out = model.apply({"params": params, **state}, x)

# train mode.
out, new_state = model.apply(
    {"params": params, **state},
    x,
    rngs={"dropout": jax.random.PRNGKey(0)},
    is_training=True,
    mutable=True,
)

Use builtin models as modules of your model

Call limo.create_model inside your own module to use builtin models as its submodules. To load pretrained variables into such a submodule, give it a name and pass the same name as module_name to limo.load_pretrained. To load variables into a more deeply nested module, join the module names with dots (e.g., f1.f1_child.f1_grandchild).

import jax
from flax import linen
import limo


class Model(linen.Module):

    @linen.compact
    def __call__(self, x, is_training):
        f1 = limo.create_model("convnext_tiny", name="f1")  # Pass name to load variables.
        f2 = limo.create_model("efficientnet_b0", name="f2")
        y = f1(x, is_training) + f2(x, is_training)
        return y

x = jax.numpy.zeros((1, 224, 224, 3))  # a batch of one 224x224 RGB image (NHWC)

model = Model()  # the wrapper module defined above, which owns f1 and f2
variables = model.init(jax.random.PRNGKey(0), x, is_training=False)
variables = limo.load_pretrained(variables, "convnext_tiny", pretrained=True, module_name="f1")
variables = limo.load_pretrained(variables, "efficientnet_b0", pretrained=True, module_name="f2")

# inference mode.
out = model.apply(variables, x, is_training=False)
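
To see which part of a variable tree a dot-joined module_name refers to, the path can be followed through each variable collection. The sketch below assumes Flax's usual layout, in which a submodule's variables are nested under its name in every collection (params, batch_stats, ...). The helper select_module and the toy string leaves are hypothetical stand-ins for illustration, not part of limo.

```python
from collections.abc import Mapping


def select_module(variables, module_name):
    """Hypothetical helper: return, per collection, the sub-tree at `module_name`.

    `module_name` is a dot-joined path such as "f1" or
    "f1.f1_child.f1_grandchild". Collections that do not contain the full
    path (e.g. a module without batch statistics) are left out.
    """
    path = module_name.split(".")
    selected = {}
    for collection, tree in variables.items():
        node = tree
        for key in path:
            if not isinstance(node, Mapping) or key not in node:
                break  # path is missing in this collection
            node = node[key]
        else:
            selected[collection] = node  # every path component was found
    return selected


# Toy variables laid out like a module with submodules "f1" and "f2".
variables = {
    "params": {
        "f1": {"f1_child": {"kernel": "k1"}},
        "f2": {"kernel": "k2"},
    },
    "batch_stats": {"f2": {"mean": "m2"}},
}
print(select_module(variables, "f1.f1_child"))  # {'params': {'kernel': 'k1'}}
print(select_module(variables, "f2"))
# {'params': {'kernel': 'k2'}, 'batch_stats': {'mean': 'm2'}}
```

Skipping collections that lack the path, rather than raising, mirrors the idea that a submodule may own variables in some collections but not others.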

Load your own variables

To load your own variables, use limo.maybe_overwrite_variables. Like limo.load_pretrained, it accepts a module_name option for loading variables into a submodule.

to_load = ...  # your own variables.
variables = limo.maybe_overwrite_variables(variables, to_load)
variables = limo.maybe_overwrite_variables(variables, to_load, module_name="f1")  # load variables to `f1` module.
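
This README does not spell out how limo decides which values to overwrite. One plausible rule, consistent with the claim that pretrained variables still load after modules are added or removed, is to copy a loaded leaf only when its key path and its shape both match the freshly initialized tree. The sketch below implements that rule on plain nested dicts; overwrite_matching and the toy trees (stem, head, adapter, removed) are hypothetical and are not limo's actual implementation.

```python
from collections.abc import Mapping

import numpy as np


def overwrite_matching(target, source):
    """Hypothetical sketch: copy `target`, taking leaves from `source` where possible.

    A leaf is replaced only when the same key path exists in both trees and
    both leaves have the same shape. Keys found only in `target` (e.g. newly
    added modules) keep their initialized values; keys found only in
    `source` (e.g. removed modules) are ignored.
    """
    result = {}
    for key, t_val in target.items():
        if key not in source:
            result[key] = t_val  # nothing to load: keep the initialized value
            continue
        s_val = source[key]
        t_is_tree = isinstance(t_val, Mapping)
        s_is_tree = isinstance(s_val, Mapping)
        if t_is_tree and s_is_tree:
            result[key] = overwrite_matching(t_val, s_val)  # recurse into submodule
        elif not t_is_tree and not s_is_tree and np.shape(t_val) == np.shape(s_val):
            result[key] = s_val  # same path and shape: use the loaded value
        else:
            result[key] = t_val  # structure or shape mismatch: keep initialized
    return result


# Hypothetical toy trees; real variables hold JAX arrays instead of lists.
initialized = {
    "params": {
        "stem": {"kernel": [0.0, 0.0, 0.0]},
        "head": {"kernel": [0.0] * 100},  # 100-class classifier
        "adapter": {"kernel": [0.0]},  # module added to the architecture
    }
}
pretrained = {
    "params": {
        "stem": {"kernel": [1.0, 2.0, 3.0]},
        "head": {"kernel": [1.0] * 1000},  # 1000-class classifier
        "removed": {"kernel": [5.0]},  # module that no longer exists
    }
}
merged = overwrite_matching(initialized, pretrained)
print(merged["params"]["stem"]["kernel"])  # [1.0, 2.0, 3.0]
print(len(merged["params"]["head"]["kernel"]))  # 100
print(sorted(merged["params"]))  # ['adapter', 'head', 'stem']
```

The shape check is what would let a head rebuilt with a different num_classes (such as num_classes=100 in the first example) keep its fresh initialization instead of failing to load.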

Examples

The examples/ directory contains the following examples:

  • ensemble.py: Example of how to use builtin models as modules of your model, and how to load variables into modules of a model.
  • resnet_tsm.py: Example of model customization.
