
harryjo97 / riemannian-diffusion-mixture-torch


PyTorch implementation for "Generative Modeling on Manifolds Through Mixture of Riemannian Diffusion Processes" (ICML 2024).

Home Page: https://arxiv.org/abs/2310.07216



Riemannian Diffusion Mixture

This repo contains a PyTorch implementation for the paper Generative Modeling on Manifolds Through Mixture of Riemannian Diffusion Processes.

The official JAX implementation is available in the riemannian-diffusion-mixture repository.

Why Riemannian Diffusion Mixture?

  • Simple design of the generative process as a mixture of Riemannian bridge processes, which, unlike previous denoising approaches, does not require heat kernel estimation.
  • Geometric interpretation of the mixture process as the weighted mean of tangent directions on the manifold.
  • Scales to higher dimensions with significantly faster training compared to previous diffusion models.
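To make the second point concrete, the toy sketch below computes a weighted mean of tangent directions on the unit sphere: from a current point x, it applies the Riemannian log map toward several candidate endpoints and averages the resulting tangent vectors. The function names and the hand-picked weights are hypothetical; how the weights are obtained in the paper is not reproduced here, and this is not the repo's implementation.

```python
import torch


def sphere_log(x, y, eps=1e-7):
    """Riemannian log map on the unit sphere: tangent vector at x pointing toward y."""
    cos_theta = (x * y).sum(-1, keepdim=True).clamp(-1 + eps, 1 - eps)
    theta = torch.acos(cos_theta)  # geodesic distance between x and y
    # Remove the component of y along x (projection onto the tangent space at x),
    # then rescale so the vector's length equals the geodesic distance.
    v = y - cos_theta * x
    return theta * v / v.norm(dim=-1, keepdim=True).clamp_min(eps)


def mixture_direction(x, endpoints, weights):
    """Weighted mean of tangent directions from x toward each endpoint.

    x: (d,) unit vector; endpoints: (k, d) unit vectors;
    weights: (k,) non-negative, summing to 1 (hypothetical, supplied by hand here).
    """
    logs = sphere_log(x.expand_as(endpoints), endpoints)  # (k, d), all tangent at x
    return (weights.unsqueeze(-1) * logs).sum(0)          # (d,), still tangent at x


x = torch.tensor([0.0, 0.0, 1.0])                          # north pole
endpoints = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
direction = mixture_direction(x, endpoints, torch.tensor([0.5, 0.5]))
```

Because every log-map vector lies in the tangent space at x, any convex combination of them does too, which is what makes the "weighted mean of tangent directions" well defined.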

Dependencies

Create an environment with Python 3.9.0 and PyTorch 2.0.0, then install the requirements with the following commands:

pip install -r requirements.txt
conda install -c conda-forge cartopy python-kaleido
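A quick sanity check of the environment can be sketched as follows. It only compares the running interpreter and the installed torch distribution against the versions named above; the helper name check_environment is hypothetical and not part of the repo.

```python
import sys
from importlib.metadata import PackageNotFoundError, version


def check_environment(python=(3, 9), torch_version="2.0"):
    """Compare the running Python and installed PyTorch with the expected versions."""
    report = {"python_ok": tuple(sys.version_info[:2]) == python}
    try:
        installed = version("torch")  # reads package metadata without importing torch
        report["torch"] = installed
        report["torch_ok"] = installed.startswith(torch_version)
    except PackageNotFoundError:
        report["torch"] = None
        report["torch_ok"] = False
    return report


print(check_environment())
```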

Manifolds

The following manifolds are supported in this repo:

  • Euclidean
  • Hypersphere
  • Torus
  • Hyperboloid
  • Triangular mesh
  • Special orthogonal group

To implement a new manifold, add a Python file that defines its geometry to /geomstats/geometry.

Please refer to the existing files in geomstats/geometry for examples.
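The base classes and method names expected by geomstats/geometry are not listed in this README, so the following is only a standalone sketch of the operations a manifold definition typically provides (exponential map, logarithm map, distance, uniform sampling), using the flat torus as an example. The class FlatTorus and its methods are hypothetical and do not follow the repo's interface.

```python
import math

import torch


class FlatTorus:
    """Hypothetical minimal geometry for the flat torus [-pi, pi)^n (illustration only)."""

    def __init__(self, dim):
        self.dim = dim

    @staticmethod
    def wrap(x):
        # Map angles back to the fundamental domain [-pi, pi).
        return torch.remainder(x + math.pi, 2 * math.pi) - math.pi

    def exp(self, x, v):
        # Exponential map: move along the tangent vector v, then wrap around.
        return self.wrap(x + v)

    def log(self, x, y):
        # Logarithm map: shortest signed displacement from x to y per coordinate.
        return self.wrap(y - x)

    def dist(self, x, y):
        # Geodesic distance: Euclidean norm of the log map.
        return self.log(x, y).norm(dim=-1)

    def random_uniform(self, n_samples):
        # Uniform samples on [-pi, pi)^dim.
        return (torch.rand(n_samples, self.dim) * 2 - 1) * math.pi


torus = FlatTorus(dim=1)
x, y = torch.tensor([3.0]), torch.tensor([-3.0])
v = torus.log(x, y)  # crosses the +pi/-pi seam instead of going the long way round
```

The exp and log maps being exact inverses (up to wrapping) is the main property such a definition should preserve.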

Running Experiments

This repo supports experiments on the following datasets:

  • Protein datasets (General, Glycine, Proline, and Pre-Pro) and RNA
  • High-dimensional tori

Please refer to riemannian-diffusion-mixture for running experiments on Earth and climate science datasets, triangular mesh datasets, and hyperboloid datasets.

1. Dataset preparations

For experiments on the protein datasets, create the .tsv files in the /data/top500 directory with the following commands:

cd data/top500
bash batch_download.sh -f list_file.txt -p
python get_torsion_angle.py

For experiments on the RNA dataset, create the .tsv files in the /data/rna directory with the following commands:

cd data/rna
bash batch_download.sh -f list_file.txt -p
python get_torsion_angles.py
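Torsion angles are periodic, so each row of the prepared data can be viewed as a point on a torus. The column layout and units of the generated .tsv files are not documented in this README; the sketch below assumes angles in degrees, one dihedral per column, and the helper name torsion_to_torus is hypothetical.

```python
import math

import torch


def torsion_to_torus(angles_deg):
    """Convert torsion angles in degrees to radians wrapped into [-pi, pi).

    angles_deg: (N, k) float tensor, one dihedral angle per column (assumed layout).
    """
    radians = torch.deg2rad(angles_deg)
    return torch.remainder(radians + math.pi, 2 * math.pi) - math.pi


points = torsion_to_torus(torch.tensor([[270.0, -90.0], [360.0, 45.0]]))
```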

2. Configurations

The configurations are provided in the config/ directory in YAML format.

3. Experiments

CUDA_VISIBLE_DEVICES=0 python main.py -m \
    experiment=<exp> \
    seed=0,1,2,3,4 \
    n_jobs=5

where <exp> is the name of one of the experiment configs in config/experiment/*.yaml.

For example,

CUDA_VISIBLE_DEVICES=0 python main.py -m \
    experiment=rna \
    seed=0,1,2,3,4 \
    n_jobs=5
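The -m flag together with comma-separated values such as seed=0,1,2,3,4 matches Hydra's multirun syntax. Assuming main.py is a Hydra application (not stated in this README), the basic sweeper launches one job per combination of comma-separated values. The simplified sketch below reproduces that expansion; expand_sweep is a hypothetical helper, not part of the repo or of Hydra.

```python
from itertools import product


def expand_sweep(overrides):
    """Expand comma-separated override values into one override list per job.

    A simplified model of a Cartesian-product sweep over key=v1,v2,... overrides.
    """
    keys, choices = [], []
    for item in overrides:
        key, value = item.split("=", 1)
        keys.append(key)
        choices.append(value.split(","))
    return [[f"{k}={v}" for k, v in zip(keys, combo)] for combo in product(*choices)]


jobs = expand_sweep(["experiment=rna", "seed=0,1,2,3,4", "n_jobs=5"])
for job in jobs:
    print(" ".join(job))
```

With five seeds this yields five runs, which is why n_jobs=5 lets them execute in parallel.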

To run experiments on high-dimensional tori, use experiment=htori with n=$DIM, where $DIM denotes the dimension of the torus.
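To sweep over several torus dimensions, one option is to generate the command for each dimension. The sketch below only builds command strings from the overrides named in this README (experiment=htori, n, seed, n_jobs); the helper htori_command and the chosen dimensions are hypothetical.

```python
import shlex


def htori_command(dim, seeds=(0, 1, 2, 3, 4), n_jobs=5, device=0):
    """Build the shell command for the high-dimensional torus experiment."""
    args = [
        "python", "main.py", "-m",
        "experiment=htori",
        f"n={dim}",
        "seed=" + ",".join(str(s) for s in seeds),
        f"n_jobs={n_jobs}",
    ]
    # shlex.join quotes any argument containing shell-unsafe characters.
    return f"CUDA_VISIBLE_DEVICES={device} " + shlex.join(args)


for dim in (10, 50, 100):  # example dimensions only
    print(htori_command(dim))
```

Each printed line can be pasted into a shell or passed to subprocess.run(..., shell=True).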

Citation

If you find the provided code or our paper useful in your work, we kindly request that you cite our work.

@inproceedings{jo2024riemannian,
  author    = {Jaehyeong Jo and
               Sung Ju Hwang},
  title     = {Generative Modeling on Manifolds Through Mixture of Riemannian Diffusion Processes},
  booktitle = {International Conference on Machine Learning},
  year      = {2024},
}

Acknowledgments

Our code builds upon geomstats. We thank the authors of Riemannian Score-Based Generative Modelling and Riemannian Flow Matching for their work.
