ipl-uv / rbig_jax

Iterative and Parametric Gaussianization with JAX.

Home Page: https://ipl-uv.github.io/rbig_jax/

License: MIT License
Probably the only implementation that allows for invertible convolutions with spatial awareness. It is related to the autoregressive method, so it will be fast for density estimation but slow for sampling.
They create the exponential convolution transform, which applies the matrix exponential of a convolution and approximates it via a truncated Taylor series. It generalizes most of the Householder-style approaches, e.g. Sylvester Flows.
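A rough sketch of the idea under those assumptions: `conv` stands for any linear, bias-free convolution M, and the series exp(M) x = sum_k M^k x / k! is truncated at a fixed depth (the names and the depth are illustrative, not this repo's code).

```python
import jax.numpy as jnp

# Sketch of the convolution exponential: exp(M) x = sum_k M^k x / k!,
# truncated after `n_terms`. Since exp(M) is always invertible, the
# inverse transform is the same series applied with the negated kernel.
def conv_exp(conv, x, n_terms=8):
    out, term = x, x
    for k in range(1, n_terms):
        term = conv(term) / k  # builds M^k x / k! incrementally
        out = out + term
    return out
```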
Need dedicated ReadTheDocs documentation.
Need a wrapper implementation for mutual information.
Notebooks:
Implement conditional normalizing flows. This requires a `context` argument that is carried through all of the transformations until the base distribution; see the sketch below.
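A minimal sketch of what that could look like, mirroring the `init_fun`/`direct_fun` pattern used elsewhere in this repo; `shift_net` is a placeholder for any context-to-shift function (e.g. a small MLP), not an existing API:

```python
import jax.numpy as jnp

# Hypothetical conditional bijector: forward/inverse take a `context`
# argument that would be threaded through every layer of the flow.
def conditional_shift(shift_net):
    def forward(params, x, context):
        shift = shift_net(params, context)
        # a pure shift is volume-preserving, so log|det J| = 0
        return x + shift, jnp.zeros(x.shape[:1])

    def inverse(params, y, context):
        shift = shift_net(params, context)
        return y - shift, jnp.zeros(y.shape[:1])

    return forward, inverse
```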
We still have issues with the boundaries of the marginal histogram transformation. The problem shows up most clearly during sampling.
Proposal
It would be interesting to see what happens if we add a "squeezing" layer to constrain the domain. Instead of applying the histogram transformation on the unbounded input domain, (-inf, inf) -> [0, 1], we could first apply an invertible "squeezing" function, (-inf, inf) -> [0, 1], and then do the histogram transformation, [0, 1] -> [0, 1]. At the very least, we would not have to worry about the bounds of the histogram function. A sketch follows below.
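A minimal sketch, assuming the logistic sigmoid as the squeezing map (any smooth CDF would do):

```python
import jax.numpy as jnp
from jax.nn import sigmoid

# Invertible squeeze (-inf, inf) -> (0, 1) so the histogram transform
# only ever sees bounded inputs. Function names are illustrative.
def squeeze_forward(x):
    u = sigmoid(x)
    # d sigmoid/dx = u * (1 - u); accumulate log-derivatives over features
    log_det = jnp.sum(jnp.log(u) + jnp.log1p(-u), axis=-1)
    return u, log_det

def squeeze_inverse(u):
    x = jnp.log(u) - jnp.log1p(-u)  # logit
    log_det = -jnp.sum(jnp.log(u) + jnp.log1p(-u), axis=-1)
    return x, log_det
```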
Implement the parametric spline function as an element-wise transformation, i.e. marginal Gaussianization.
A basic implementation of the GDN (Generalized Divisive Normalization) algorithm. It is mainly used in compression and provides a normalization that can be coupled with a convolutional/linear layer.

y[i] = x[i] / sqrt(beta[i] + sum_j(gamma[j, i] * x[j]**2))

Somehow we need to constrain a few parameters, e.g., `gamma` and `beta`. Probably a simple clamp (e.g. `np.maximum`) would work; see the sketch below.
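A minimal sketch of the GDN forward pass under those assumptions, with `beta` of shape (d,) and `gamma` of shape (d, d); the clamp values are arbitrary:

```python
import jax.numpy as jnp

# GDN forward pass, squared-denominator form as in the formula above.
def gdn_forward(x, beta, gamma, eps=1e-6):
    beta = jnp.maximum(beta, eps)    # keep denominators strictly positive
    gamma = jnp.maximum(gamma, 0.0)  # keep the mixing weights nonnegative
    # denom[i] = sqrt(beta[i] + sum_j gamma[j, i] * x[j]**2)
    denom = jnp.sqrt(beta + (x ** 2) @ gamma)
    return x / denom
```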
Need an implementation and demonstration of the KL-Divergence Wrapper.
Notebooks
We currently have the Householder transformation, but we can also use it for an SVD layer, X = U S V^T. The rotation matrices (U, V^T) are constrained to be orthogonal, and the diagonal, S, is unconstrained. It is a bit more expensive, but it might help training; a rough sketch is below.
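A rough sketch of how that could look: U and V are accumulated as products of Householder reflections (hence exactly orthogonal), while the singular values `s` stay unconstrained. Function names are illustrative:

```python
import jax.numpy as jnp

def householder_product(vs):
    """Accumulate H_1 H_2 ... H_k from reflection vectors vs: (k, d)."""
    d = vs.shape[-1]
    W = jnp.eye(d)
    for v in vs:
        v = v / jnp.linalg.norm(v)
        # W (I - 2 v v^T) = W - 2 (W v) v^T
        W = W - 2.0 * jnp.outer(W @ v, v)
    return W

def svd_linear(x, u_vecs, s, v_vecs):
    U = householder_product(u_vecs)
    V = householder_product(v_vecs)
    W = U @ jnp.diag(s) @ V.T
    # log|det W| = sum log|s_i| since U and V are orthogonal
    log_det = jnp.sum(jnp.log(jnp.abs(s)))
    return x @ W.T, log_det
```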
A notebook showcasing information theoretic metrics:
Note: we really want to focus on speed and simplicity, e.g. showing how to do it from scratch as well as with the convenient wrapper functions, to explain the design decisions.
It would be nice to see the differences in information loss/reduction between the GaussFlow algorithm and the IterGauss algorithm.
Outcome: A demo notebook showcasing the info loss between layers for both algorithms.
Need an implementation and demonstration of total correlation (a sketch of the defining identity is below).
Notebooks
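For reference, total correlation is defined by TC(X) = sum_i H(X_i) - H(X). A hedged sketch with a simple histogram plug-in estimator for the 1D entropies; the estimator and bin count are assumptions, not this library's implementation:

```python
import jax.numpy as jnp
from jax.scipy.special import xlogy

def marginal_entropy(x, bins=64):
    counts, edges = jnp.histogram(x, bins=bins)
    p = counts / counts.sum()
    # discrete entropy plus log bin-width ~ differential entropy
    return -jnp.sum(xlogy(p, p)) + jnp.log(edges[1] - edges[0])

def total_correlation(X, joint_entropy):
    # joint_entropy must come from elsewhere (e.g. accumulated over layers)
    marginals = sum(marginal_entropy(X[:, i]) for i in range(X.shape[1]))
    return marginals - joint_entropy
```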
objax is basically a PyTorch-like layer on top of JAX, but it is a bit limiting when trying to mix and match components. So I think we should remove the dependency on objax and stick with pure JAX.
This example was taken from jax-flows.
```python
import jax.numpy as np
from jax.nn.initializers import orthogonal
from jax.numpy import linalg

# Factory wrapper (name illustrative) added so the snippet is self-contained.
def InvertibleLinear():
    def init_fun(rng, input_dim, **kwargs):
        # fixed (non-trainable) random orthogonal weight matrix
        W = orthogonal()(rng, (input_dim, input_dim))
        W_inv = linalg.inv(W)
        W_log_det = np.linalg.slogdet(W)[-1]

        def direct_fun(params, inputs, **kwargs):
            outputs = inputs @ W
            # log|det J| is the same for every sample in the batch
            log_det_jacobian = np.full(inputs.shape[:1], W_log_det)
            return outputs, log_det_jacobian

        def inverse_fun(params, inputs, **kwargs):
            outputs = inputs @ W_inv
            log_det_jacobian = np.full(inputs.shape[:1], -W_log_det)
            return outputs, log_det_jacobian

        return (), direct_fun, inverse_fun

    return init_fun
```
I have notebooks documenting the progression of building an RBIG model, but there is no all-in-one notebook showcasing everything. Example components:
Tried running RBIG on CIFAR data; it does not converge even after a long time. Tried with 10, 20, 30, etc. layers; only after 50 layers does it converge.
There is no guarantee that the bisection search will converge. We should also have a max-layers stopping criterion.
This can be one of the layers used in the iterative method. It can be done by splitting a random key at every layer and initializing the rotation matrix for each layer; see the sketch below.
Source:
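A minimal sketch, assuming the QR decomposition of a Gaussian matrix as the per-layer rotation initializer (helper names are illustrative, not the package API):

```python
import jax
import jax.numpy as jnp

def random_rotation(key, dim):
    # QR of a Gaussian matrix yields a random orthogonal matrix
    Q, R = jnp.linalg.qr(jax.random.normal(key, (dim, dim)))
    return Q * jnp.sign(jnp.diag(R))  # fix column signs for a Haar sample

def init_rotations(key, dim, n_layers):
    # split the key once per layer so each rotation is independent
    keys = jax.random.split(key, n_layers)
    return [random_rotation(k, dim) for k in keys]
```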