umangjpatel / kerax

Keras-like APIs for JAX framework

License: The Unlicense

Topics: deeplearning-framework, automatic-differentiation, python3, deep-learning, deep-neural-networks, google, numpy, pandas, matplotlib, tqdm, beginner-friendly, jax, kerax

kerax's Introduction

Kerax

Keras-like APIs for the JAX library.

Features

  • Enables high-performance machine learning research by building on JAX.
  • Built-in support for popular optimization algorithms and activation functions.
  • Runs seamlessly on CPU, GPU and even TPU, with no manual configuration required.
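The quickstart below stacks Relu and Sigmoid activation layers. Kerax's own implementations are not shown in this README, so the following is only a plain-Python sketch of the element-wise math these two functions are generally understood to compute; relu and sigmoid here are illustrative names, not kerax's API.

```python
import math


def relu(x: float) -> float:
    """Rectified linear unit: passes positive values, zeroes out the rest."""
    return max(0.0, x)


def sigmoid(x: float) -> float:
    """Logistic sigmoid 1 / (1 + e^-x), split by sign so math.exp never overflows."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)  # x < 0, so z is in (0, 1)
    return z / (1.0 + z)


print([relu(v) for v in (-2.0, 0.0, 3.5)])  # [0.0, 0.0, 3.5]
print(sigmoid(0.0))  # 0.5
```

Splitting sigmoid by the sign of x keeps it numerically safe for very large positive or negative inputs.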

Quickstart

Code

from kerax.datasets import binary_tiny_mnist
from kerax.layers import Dense, Relu, Sigmoid
from kerax.losses import BCELoss
from kerax.metrics import binary_accuracy
from kerax.models import Sequential
from kerax.optimizers import SGD

data = binary_tiny_mnist.load_dataset(batch_size=200)
model = Sequential([Dense(100), Relu, Dense(1), Sigmoid])
model.compile(loss=BCELoss, optimizer=SGD(step_size=0.003), metrics=[binary_accuracy])
model.fit(data=data, epochs=10)
model.save(file_name="model")

interp = model.get_interpretation()
interp.plot_losses()
interp.plot_accuracy()

Output

WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Epoch 10: 100%|██████████| 10/10 [00:02<00:00,  3.82it/s, train_loss : 0.192 :: valid_loss : 0.202 :: train_binary_accuracy : 1.000 :: valid_binary_accuracy : 1.000]

[Figures: Quickstart code loss curves; Quickstart code accuracy curves]
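The quickstart compiles the model with BCELoss and the binary_accuracy metric. As a rough, framework-free sketch of what a mean binary cross-entropy and a thresholded binary accuracy compute (the names bce_loss and binary_accuracy_sketch, the 0.5 threshold, and the clipping epsilon are assumptions, not kerax's definitions):

```python
import math
from typing import Sequence


def bce_loss(y_true: Sequence[float], y_prob: Sequence[float], eps: float = 1e-7) -> float:
    """Mean binary cross-entropy; probabilities are clipped to [eps, 1 - eps] to avoid log(0)."""
    total = 0.0
    for t, p in zip(y_true, y_prob):
        p = min(max(p, eps), 1.0 - eps)
        total += -(t * math.log(p) + (1.0 - t) * math.log(1.0 - p))
    return total / len(y_true)


def binary_accuracy_sketch(
    y_true: Sequence[float], y_prob: Sequence[float], threshold: float = 0.5
) -> float:
    """Fraction of examples whose thresholded prediction matches the 0/1 label."""
    hits = sum(int((p >= threshold) == bool(t)) for t, p in zip(y_true, y_prob))
    return hits / len(y_true)


labels = [1, 0, 1, 0]
probs = [0.9, 0.2, 0.6, 0.4]  # e.g. outputs of the final Sigmoid layer
print(binary_accuracy_sketch(labels, probs))  # 1.0
print(bce_loss(labels, probs))
```

This also illustrates why the quickstart output can report a binary accuracy of 1.000 alongside a non-zero loss: every prediction lands on the correct side of the threshold, but the probabilities are not all exactly 0 or 1.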

Documentation (Coming soon...)

Developer's Notes

This project is developed and maintained by Umang Patel.

kerax's People

Contributors

dependabot[bot], umangjpatel

kerax's Issues

DataBunch-like API

Description
Implement a databunch-like API for loading image data.

Solution
Use the PyTorch and TensorFlow data-loading libraries.

Additional context
Refer to jax.readthedocs.io for more info.
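At its core, a databunch-like loader shuffles the examples and cuts them into mini-batches. The quickstart's load_dataset(batch_size=200) suggests kerax already does this for its built-in dataset. Below is a framework-free sketch of that step; make_batches and its seeded shuffle are hypothetical, and the issue's proposal to reuse PyTorch or TensorFlow loaders would replace it.

```python
import random
from typing import Iterator, List, Sequence, Tuple, TypeVar

X = TypeVar("X")
Y = TypeVar("Y")


def make_batches(
    xs: Sequence[X],
    ys: Sequence[Y],
    batch_size: int,
    shuffle: bool = True,
    seed: int = 0,
) -> Iterator[Tuple[List[X], List[Y]]]:
    """Yield (inputs, labels) mini-batches; the last batch may be smaller."""
    if len(xs) != len(ys):
        raise ValueError("xs and ys must have the same length")
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    order = list(range(len(xs)))
    if shuffle:
        random.Random(seed).shuffle(order)  # seeded, so runs are reproducible
    for start in range(0, len(order), batch_size):
        chunk = order[start:start + batch_size]
        yield [xs[i] for i in chunk], [ys[i] for i in chunk]


batches = list(make_batches(list(range(10)), [2 * i for i in range(10)], batch_size=4))
print([len(bx) for bx, _ in batches])  # [4, 4, 2]
```

Shuffling index positions, rather than the data itself, keeps each input paired with its label.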

Regression loss functions missing

Feature request description
Loss functions for regression problems are missing.

Solution proposed
Implement loss functions like MSE, RMSE, etc.

Additional context
Refer to the docs.fast.ai, PyTorch and Keras documentation for implementation details.
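The two losses the issue names have standard definitions: MSE is the mean of the squared errors, and RMSE is its square root. The following is a minimal plain-Python sketch; mse and rmse are illustrative names, not an existing kerax API.

```python
import math
from typing import Sequence


def mse(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Mean squared error: average of (target - prediction)^2."""
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("inputs must be non-empty and the same length")
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def rmse(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Root mean squared error: same units as the targets."""
    return math.sqrt(mse(y_true, y_pred))


print(mse([0.0, 0.0], [3.0, 4.0]))   # 12.5
print(rmse([0.0, 0.0], [3.0, 4.0]))
```

RMSE is often easier to report because it is in the same units as the targets, while MSE is the simpler quantity to differentiate during training.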

Constant accuracy in CNN

Description
Accuracy stays constant and losses remain large when using the CNN APIs.

Expected behaviour
Accuracy should increase while the loss decreases.

JAX imports not working

Description
Importing JAX raises a NotImplementedError.

Expected behavior
jax.numpy should be usable in the same way as the original NumPy package.

Proposed Solution
Replace the JAX computation engine with PyTorch.

CNN APIs

Description
Implement CNN layers and the relevant APIs.

Solution
Use the LAX APIs.

Additional context
Refer to jax.readthedocs.io for more info.
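To show what a CNN layer built on LAX would compute, here is a naive single-channel sketch of a 2D convolution in the deep-learning sense: a cross-correlation with no kernel flip, "VALID" padding, and stride 1. JAX's lax.conv works on batched, multi-channel arrays and is far faster. This sketch covers only one image and one filter, and conv2d_valid is a hypothetical name.

```python
from typing import List, Sequence

Matrix = Sequence[Sequence[float]]


def conv2d_valid(image: Matrix, kernel: Matrix) -> List[List[float]]:
    """Slide the kernel over the image (no padding, stride 1) and sum the products."""
    ih, iw = len(image), len(image[0])
    kh, kw = len(kernel), len(kernel[0])
    oh, ow = ih - kh + 1, iw - kw + 1  # output shrinks by kernel size - 1
    if oh <= 0 or ow <= 0:
        raise ValueError("kernel must not be larger than the image")
    out = [[0.0] * ow for _ in range(oh)]
    for i in range(oh):
        for j in range(ow):
            acc = 0.0
            for di in range(kh):
                for dj in range(kw):
                    # No kernel flip: this is cross-correlation, as in most DL libraries.
                    acc += image[i + di][j + dj] * kernel[di][dj]
            out[i][j] = acc
    return out


img = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
print(conv2d_valid(img, [[1, 0], [0, -1]]))  # [[-4.0, -4.0], [-4.0, -4.0]]
```

A 3x3 input with a 2x2 kernel gives a 2x2 output, which is the shape arithmetic a CNN layer's init function has to reproduce.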

Stax-like layers API

Description
Refactor the layers API to follow the Stax style.

Solution
Write every layer as an (init_fun, apply_fun) pair.

Support for saving and loading models

Is your feature request related to a problem? Please describe.
After training large models, we should be able to save them and reload them for further experiments.

Describe the solution you'd like
Models should be able to save (serialize) and load (deserialize) the following:

  • Model parameters
  • Layers in models

Additional context
Refer to the Flax serialization logic for more info. It serializes JAX PyTrees to the MessagePack format and deserializes them back.
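The two items listed above, the layer configuration and the parameters, can be stored together in one payload. MessagePack needs a third-party package, so this sketch substitutes the standard-library json module. The function names and the layer-spec dict format are hypothetical. One caveat is that JSON turns tuples into lists, so the original PyTree structure is not restored automatically. This is one reason Flax's from_bytes takes a target structure to restore into.

```python
import json
from pathlib import Path
from typing import Any, List, Tuple


def model_to_bytes(layer_specs: List[dict], params: Any) -> bytes:
    """Serialize layer configuration and parameters into one UTF-8 JSON payload."""
    return json.dumps({"layers": layer_specs, "params": params}).encode("utf-8")


def model_from_bytes(data: bytes) -> Tuple[List[dict], Any]:
    """Inverse of model_to_bytes; tuples in the original params come back as lists."""
    payload = json.loads(data.decode("utf-8"))
    return payload["layers"], payload["params"]


def save_model(path: str, layer_specs: List[dict], params: Any) -> None:
    Path(path).write_bytes(model_to_bytes(layer_specs, params))


def load_model(path: str) -> Tuple[List[dict], Any]:
    return model_from_bytes(Path(path).read_bytes())


layers = [{"type": "Dense", "units": 2}, {"type": "Relu"}]
params = [[[[0.1, -0.2]], [0.0, 0.0]], []]  # [(weights, biases), no params for Relu]
print(model_from_bytes(model_to_bytes(layers, params)) == (layers, params))  # True
```

Python's json writes floats using their shortest round-trip representation, so parameter values survive a save/load cycle exactly.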

Optimizers

Description
Add optimization algorithms.

Solution
Use the JAX experimental module.
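The quickstart's SGD(step_size=0.003) points to the simplest update rule, p ← p − step_size · g. Below is a plain-Python sketch of that rule and of classical momentum in the form JAX's example optimizers use: v ← mass · v + g, then p ← p − step_size · v. The function names, flat-list parameters, and default mass=0.9 are assumptions for illustration.

```python
from typing import List, Tuple


def sgd_update(params: List[float], grads: List[float], step_size: float) -> List[float]:
    """Plain gradient descent: p <- p - step_size * g."""
    return [p - step_size * g for p, g in zip(params, grads)]


def momentum_update(
    params: List[float],
    grads: List[float],
    velocity: List[float],
    step_size: float,
    mass: float = 0.9,
) -> Tuple[List[float], List[float]]:
    """Momentum: v <- mass * v + g, then p <- p - step_size * v."""
    velocity = [mass * v + g for v, g in zip(velocity, grads)]
    params = [p - step_size * v for p, v in zip(params, velocity)]
    return params, velocity


# Minimize f(x) = x^2 (gradient 2x): each SGD step multiplies x by 0.8.
x = [1.0]
for _ in range(50):
    x = sgd_update(x, [2.0 * x[0]], step_size=0.1)
print(x[0])  # close to 0
```

Starting momentum from a zero velocity makes its first step identical to plain SGD. After that, the accumulated velocity speeds up progress along directions where gradients agree.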

Multinomial / Multiclass classification implementation

Description of feature request
Multiclass / multinomial classification problems are not supported.

Solution proposed
Implement the softmax function.

Additional context
Refer to the docs.fast.ai, PyTorch and Keras documentation for implementation details.
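Softmax turns a vector of class scores (logits) into probabilities. Multiclass training usually pairs it with a categorical cross-entropy loss. The following is a numerically stable plain-Python sketch that subtracts the maximum logit before exponentiating (the log-sum-exp trick). softmax and categorical_cross_entropy are illustrative names here, not kerax functions.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """exp(z_i) / sum_j exp(z_j), shifted by max(z) so exp cannot overflow."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def categorical_cross_entropy(label_index: int, logits: Sequence[float]) -> float:
    """-log softmax(logits)[label], computed as logsumexp(logits) - logits[label]."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label_index]


print(softmax([0.0, 0.0]))                     # [0.5, 0.5]
print(softmax([1000.0, 1000.0]))               # [0.5, 0.5], no overflow
print(categorical_cross_entropy(0, [0.0, 0.0]))  # log(2)
```

Computing the loss directly from logits, rather than taking the log of softmax outputs, avoids log(0) when one class probability underflows.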