
ipl-uv / rbig_jax

Iterative and Parametric Gaussianization with JAX.

Home Page: https://ipl-uv.github.io/rbig_jax/

License: MIT License

Makefile 0.01% Python 0.45% Jupyter Notebook 99.55%
gaussianization density-estimation rbig information-theory jax sampling generative-model


rbig_jax's Issues

[Algorithm][Parametric] Emerging Convolutions

Probably the only implementation that allows for invertible convolutions with spatial awareness. It is related to the autoregressive methods, so density evaluation will be fast but sampling will be slow.


Resources

[Investigate] "Squeezing Layer" for Marginal Histogram Transformation.

We still have issues with the boundaries of the marginal histogram transformation; they show up most clearly when we do sampling.

Proposal

It would be interesting to see what happens if we add a "squeezing layer" to constrain the domain. Instead of applying the histogram transformation directly on the unbounded input domain, (-inf, inf) -> [0, 1], we could first apply an invertible "squeezing" function, (-inf, inf) -> [0, 1], and then do the histogram transformation, [0, 1] -> [0, 1]. At the very least, we would no longer have to worry about the bounds of the histogram function.
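One minimal way to sketch the proposed squeeze is a sigmoid map (a hypothetical choice; any smooth invertible map onto (0, 1) with a tractable log-det-Jacobian would do):

```python
import jax.numpy as jnp
from jax.scipy.special import expit, logit  # sigmoid and its inverse


def squeeze_forward(x):
    """Invertible 'squeeze' from (-inf, inf) to (0, 1) via the sigmoid."""
    y = expit(x)
    # log|dy/dx| = log sigma(x) + log(1 - sigma(x)), summed over features
    log_det = jnp.sum(jnp.log(y) + jnp.log1p(-y), axis=-1)
    return y, log_det


def squeeze_inverse(y):
    """Inverse map (0, 1) -> (-inf, inf); log-det is the negative of forward."""
    x = logit(y)
    log_det = -jnp.sum(jnp.log(y) + jnp.log1p(-y), axis=-1)
    return x, log_det
```

The histogram transformation would then only ever see inputs in (0, 1), so its bin edges can be fixed in advance.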

[Algorithm][Parametric] Implement the GDN layer

A basic implementation of the GDN (Generalized Divisive Normalization) algorithm. It is mainly used in compression. It features a normalization that can be coupled with a convolution/linear layer.

y[i] = x[i] / sqrt(beta[i] + sum_j(gamma[j, i] * x[j]^2))

Resources


Potential Issues

Somehow we need to constrain a few parameters, e.g., gamma and beta, to stay non-negative. A simple elementwise clip with np.maximum would probably work.
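A hedged sketch of how that could look, combining the GDN form above with a simple elementwise clip on the parameters (the function names and the clipping strategy are illustrative, not the repo's implementation):

```python
import jax.numpy as jnp


def gdn_forward(x, beta, gamma, eps=1e-6):
    """GDN normalization: y_i = x_i / sqrt(beta_i + sum_j gamma[j, i] * x_j**2).

    x: (batch, d) inputs; beta: (d,); gamma: (d, d).
    beta/gamma are clipped elementwise so the denominator stays positive.
    """
    beta = jnp.maximum(beta, eps)    # keep the offset strictly positive
    gamma = jnp.maximum(gamma, 0.0)  # non-negative coupling weights
    norm = jnp.sqrt(beta + (x ** 2) @ gamma)  # (x**2) @ gamma sums over j
    return x / norm
```

With `gamma = 0` and `beta = 1` the layer reduces to the identity, which is a handy sanity check.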

[Algorithm][Parametric] SVD Linear Layer

We currently have the Householder transformation, and we can reuse it for an SVD-parameterized linear layer, W = U S V^T. The rotation matrices (U, V^T) are constrained to be orthogonal, while the diagonal, S, is unconstrained. It is a bit more expensive, but it might help training.
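A rough sketch of the forward/inverse passes, assuming orthogonal `u`/`vt` are supplied externally (e.g., as products of Householder reflections) and the diagonal is parameterized by unconstrained log-singular values, so the log-det-Jacobian is just their sum (all names here are illustrative):

```python
import jax.numpy as jnp


def svd_linear_forward(x, u, log_s, vt):
    """Linear layer with SVD parameterization W = U diag(s) V^T.

    u, vt: orthogonal matrices; log_s: unconstrained log-singular values,
    so s = exp(log_s) > 0 and log|det W| = sum(log_s).
    """
    s = jnp.exp(log_s)
    y = ((x @ u) * s) @ vt
    log_det = jnp.full(x.shape[:1], jnp.sum(log_s))
    return y, log_det


def svd_linear_inverse(y, u, log_s, vt):
    """Exact inverse: W^{-1} = V diag(1/s) U^T (orthogonality gives the transposes)."""
    s = jnp.exp(log_s)
    x = ((y @ vt.T) / s) @ u.T
    log_det = jnp.full(y.shape[:1], -jnp.sum(log_s))
    return x, log_det
```

Because the log-det is a plain sum of parameters, it costs nothing extra at train time, unlike a generic dense matrix where `slogdet` is O(d^3).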


Resources

ITM Notebook

A notebook showcasing information theoretic metrics:

  • Probability & Information
  • Total Correlation
  • Entropy
  • Mutual Information
  • KL-Divergence

Note: we really want to focus on speed and simplicity, e.g., showing how to compute each metric from scratch as well as via the convenience wrapper functions, in order to explain the design decisions.
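For the "from scratch" half, the discrete versions of a few of the listed metrics fit in a handful of lines (a hypothetical sketch, not the notebook's actual code):

```python
import jax.numpy as jnp


def entropy(p, eps=1e-12):
    """Shannon entropy (in nats) of a discrete distribution."""
    p = p / p.sum()
    return -jnp.sum(p * jnp.log(p + eps))


def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) for discrete distributions on the same support."""
    p = p / p.sum()
    q = q / q.sum()
    return jnp.sum(p * (jnp.log(p + eps) - jnp.log(q + eps)))


def mutual_information(pxy, eps=1e-12):
    """I(X; Y) = KL(p(x, y) || p(x) p(y)), from a joint histogram."""
    pxy = pxy / pxy.sum()
    px = pxy.sum(axis=1, keepdims=True)
    py = pxy.sum(axis=0, keepdims=True)
    return jnp.sum(pxy * (jnp.log(pxy + eps) - jnp.log(px * py + eps)))
```

The wrapper functions in the library would then just handle binning and edge cases on top of these primitives.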

[Demo] Info Loss for GaussFlow versus RBIG

It would be nice to see the differences between the information loss/reduction for the GaussFlow algorithm and the IterGauss algorithm.

Outcome: A demo notebook showcasing the info loss between layers for both algorithms.

[Parametric] Remove Dependency on objax

objax is essentially a PyTorch-like, object-oriented layer on top of JAX. But it is a bit limiting when trying to mix and match modules with other JAX code. So I think we should remove the dependency on objax and stick with pure JAX.


Example

This example was taken from jax-flows.

Demo Snippet

```python
import jax.numpy as np
from jax.nn.initializers import orthogonal
from jax.scipy import linalg


def FixedInvertibleLinear():
    """An implementation of an invertible linear layer from
    `Glow: Generative Flow with Invertible 1x1 Convolutions`
    (https://arxiv.org/abs/1807.03039).

    Returns:
        An ``init_fun`` mapping ``(rng, input_dim)`` to a
        ``(params, direct_fun, inverse_fun)`` triplet.
    """

    def init_fun(rng, input_dim, **kwargs):
        W = orthogonal()(rng, (input_dim, input_dim))
        W_inv = linalg.inv(W)
        W_log_det = np.linalg.slogdet(W)[-1]

        def direct_fun(params, inputs, **kwargs):
            outputs = inputs @ W
            log_det_jacobian = np.full(inputs.shape[:1], W_log_det)
            return outputs, log_det_jacobian

        def inverse_fun(params, inputs, **kwargs):
            outputs = inputs @ W_inv
            log_det_jacobian = np.full(inputs.shape[:1], -W_log_det)
            return outputs, log_det_jacobian

        return (), direct_fun, inverse_fun

    return init_fun
```

Deep Dive Demo Notebook

I have notebooks documenting the progression of building an RBIG model, but there is no all-in-one notebook showcasing everything together.

Example components:

  • Marginal Gaussianization
  • Univariate Entropy
  • Information Loss
  • ITMs -> TC, H
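As a taste of the first component, marginal Gaussianization can be sketched with a rank-based empirical CDF followed by the inverse Gaussian CDF (illustrative only; the repo's actual implementation may differ):

```python
import jax.numpy as jnp
from jax.scipy.special import ndtri  # inverse of the standard normal CDF


def marginal_gaussianize(x):
    """Marginally Gaussianize each feature of x (shape: (n_samples, n_features)).

    Each column is mapped through its rank-based empirical CDF into (0, 1),
    then through the inverse Gaussian CDF, yielding roughly standard-normal
    marginals. This is the first building block of an RBIG layer.
    """
    n = x.shape[0]
    # double argsort gives each sample's rank (0..n-1) within its column
    ranks = jnp.argsort(jnp.argsort(x, axis=0), axis=0)
    u = (ranks + 0.5) / n  # empirical CDF values, strictly inside (0, 1)
    return ndtri(u)
```

An RBIG layer would follow this with a rotation (e.g., PCA or random orthogonal) and iterate the pair until the joint distribution is Gaussian.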
