This package is a JAX-based implementation of Conditional Flow Matching (CFM), an approach to generative modelling based on continuous normalizing flows. Its API closely follows that of the TorchCFM library, so that TorchCFM users who want to migrate to JAX have an easy transition.
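
For readers new to CFM, the core training step can be sketched in a few lines of JAX. This is a minimal illustration of the common variant with independent source/target pairs and a linear interpolation path, not this package's API: the names `sample_conditional_path` and `cfm_loss`, the `sigma` parameter, and the toy data are all assumptions made for the example.

```python
import jax
import jax.numpy as jnp


def sample_conditional_path(key, x0, x1, sigma=0.0):
    """Draw t, x_t and the target velocity u_t for a batch of pairs (x0, x1).

    Hypothetical helper: uses the linear path
    x_t = (1 - t) * x0 + t * x1 + sigma * eps with target u_t = x1 - x0.
    """
    key_t, key_eps = jax.random.split(key)
    # One time value per batch element, uniform in [0, 1).
    t = jax.random.uniform(key_t, (x0.shape[0],))
    # Reshape t so it broadcasts against the trailing dimensions of x0.
    t_b = t.reshape(-1, *([1] * (x0.ndim - 1)))
    eps = jax.random.normal(key_eps, x0.shape)
    xt = (1.0 - t_b) * x0 + t_b * x1 + sigma * eps
    ut = x1 - x0
    return t, xt, ut


def cfm_loss(velocity_fn, t, xt, ut):
    """Mean squared error between predicted and target velocities."""
    return jnp.mean((velocity_fn(t, xt) - ut) ** 2)


key = jax.random.PRNGKey(0)
key_x0, key_x1, key_path = jax.random.split(key, 3)
x0 = jax.random.normal(key_x0, (8, 2))        # source samples (noise)
x1 = jax.random.normal(key_x1, (8, 2)) + 3.0  # stand-in for data samples
t, xt, ut = sample_conditional_path(key_path, x0, x1)

# Toy velocity field that always predicts zero; a real model would be a
# neural network trained by minimising this loss.
loss = cfm_loss(lambda t, x: jnp.zeros_like(x), t, xt, ut)
```

In practice the velocity field is a trainable network and the loss is minimised with a gradient-based optimiser; sampling then amounts to integrating the learned velocity field from `t = 0` to `t = 1`.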
This repository is currently under construction and thus may not be bug-free or complete at this point.
To install JAX-CFM, clone this repository and run pip install . in an environment with Python >= 3.10.
If you intend to contribute or run the examples, consider installing the optional dependencies as well (e.g. pip install .[dev] or pip install .[examples]).
Eventually, the goal is to make the package available on PyPI.
JAX-CFM relies on ott-jax for Optimal Transport-related tasks and uses equinox and jaxtyping for API design, type annotations, and type checking.