
This repository is a fork of petrmokrov/large-scale-wasserstein-gradient-flows.


Source code for Large-Scale Wasserstein Gradient Flows (NeurIPS 2021)

Home Page: https://arxiv.org/abs/2106.00736

License: MIT License


Large-Scale Wasserstein Gradient Flows

This repository contains the code and experimental results of the NeurIPS 2021 paper Large-Scale Wasserstein Gradient Flows by Petr Mokrov, Alexander Korotin, Lingxiao Li, Aude Genevay, Justin Solomon and Evgeny Burnaev. We approximate gradient flows, in particular diffusion processes governed by the Fokker-Planck equation, using the JKO scheme modelled via Input Convex Neural Networks (ICNNs). We conduct experiments demonstrating that our approach works in different scenarios and machine learning applications.
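To give a flavour of the key building block: an ICNN represents a convex function by taking non-negative combinations of convex, non-decreasing activations applied to affine maps of the input. The minimal pure-Python sketch below (layer sizes, weights, and the single hidden layer are illustrative choices, not the paper's architecture) builds such a function and checks convexity numerically:

```python
import math
import random

random.seed(0)

def softplus(v):
    """Convex, non-decreasing activation."""
    return math.log1p(math.exp(v))

D, H = 2, 3
# W_x: unconstrained weights on the raw input
W_x = [[random.uniform(-1, 1) for _ in range(D)] for _ in range(H)]
b = [random.uniform(-1, 1) for _ in range(H)]
# w_z: non-negative output weights -- this constraint is what makes the network convex
w_z = [abs(random.uniform(-1, 1)) for _ in range(H)]

def icnn(x):
    """Scalar convex function of x: non-negative mix of convex units."""
    hidden = [softplus(sum(W_x[i][j] * x[j] for j in range(D)) + b[i])
              for i in range(H)]
    return sum(w_z[i] * hidden[i] for i in range(H))

# Numerical convexity check along random segments:
# f(lam*x + (1-lam)*y) <= lam*f(x) + (1-lam)*f(y)
for _ in range(100):
    x = [random.uniform(-3, 3) for _ in range(D)]
    y = [random.uniform(-3, 3) for _ in range(D)]
    lam = random.random()
    mid = [lam * xj + (1 - lam) * yj for xj, yj in zip(x, y)]
    assert icnn(mid) <= lam * icnn(x) + (1 - lam) * icnn(y) + 1e-9
```

Deeper ICNNs repeat this pattern, constraining the weights acting on previous hidden layers to be non-negative while leaving the direct input connections unconstrained.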

Citation

If you find this repository or the ideas presented in our paper useful, please consider citing our paper.

@article{mokrov2021large,
  title={Large-scale wasserstein gradient flows},
  author={Mokrov, Petr and Korotin, Alexander and Li, Lingxiao and Genevay, Aude and Solomon, Justin M and Burnaev, Evgeny},
  journal={Advances in Neural Information Processing Systems},
  volume={34},
  year={2021}
}

Prerequisites

It is highly recommended to use a GPU to launch our experiments. The list of required Python libraries can be found in ./requirements.txt. One can install the libraries via the following command:

> pip install -r requirements.txt

Related repositories

Experiments

All our experiments can be launched via the ./script.py script. The experiments use config files in the ./configs/ directory, which define the hyperparameters of the experiments. See our paper for details.

The results of the experiments are saved to the ./results directory and can be visualized using the ./W2JKO_results.ipynb notebook. All images representing our experiments are stored in the ./images directory.

Convergence to Stationary Solution

We test whether our gradient flow, which approximates an advection-diffusion process, manages to converge to the stationary distribution.

Quantitative comparison

To reproduce the quantitative comparison of our method with particle-based methods, run the following:

> python ./script.py conv_comp_dim_[dimensionality] --method [method] --device [device]

Use D = 2, 4, 6, 8, 10, 12 for the dimensionality option; the available methods are ICNN_jko, EM_sim_1000, EM_sim_10000, EM_sim_50000. Additionally, one can consider the EM_ProxRec_400 and EM_ProxRec_1000 methods. The device option makes sense only for the ICNN_jko method.

In particular, the command below launches the quantitative comparison experiment for dimension D = 8 using our method on the cuda:0 device:

> python ./script.py conv_comp_dim_8 --method ICNN_jko --device cuda:0

The results for all dimensions are presented in the image below:

[Figure: quantitative comparison results for all dimensions]

Qualitative comparison

The qualitative comparison results can be reproduced via the following command:

> python ./script.py conv_mix_gauss_dim_[dimensionality] --device [device]

The dimensionality can be either D = 13 or D = 32. The comparison between the fitted and true stationary distributions for D = 32 is shown below:

[Figure: fitted vs. true stationary distribution for D = 32]

Modelling Ornstein-Uhlenbeck processes

We model advection-diffusion processes with special quadratic-form potentials, which have a closed-form solution for the marginal process distribution at each observation time.
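For intuition, the 1-D Ornstein-Uhlenbeck process dX = θ(μ − X)dt + σ dW has a Gaussian marginal at time t with mean μ + (x₀ − μ)e^(−θt) and variance σ²(1 − e^(−2θt))/(2θ). The sketch below (all parameters are illustrative, not the experiment's settings) compares this closed form with a plain Euler-Maruyama simulation, the family of schemes behind the EM_sim_* baselines:

```python
import math
import random

random.seed(1)

# 1-D Ornstein-Uhlenbeck: dX = theta*(mu - X) dt + sigma dW
theta, mu, sigma = 1.0, 0.0, 1.0
x0, t, n_steps, n_paths = 2.0, 0.9, 100, 5000
dt = t / n_steps

# Closed-form Gaussian marginal at time t
mean_true = mu + (x0 - mu) * math.exp(-theta * t)
var_true = sigma ** 2 / (2 * theta) * (1 - math.exp(-2 * theta * t))

# Euler-Maruyama simulation of many independent paths
samples = []
for _ in range(n_paths):
    x = x0
    for _ in range(n_steps):
        x += theta * (mu - x) * dt + sigma * math.sqrt(dt) * random.gauss(0, 1)
    samples.append(x)

mean_em = sum(samples) / n_paths
var_em = sum((s - mean_em) ** 2 for s in samples) / (n_paths - 1)
print(mean_true, mean_em)  # the EM estimate should track the closed form
```

The gap between the empirical and analytic moments shrinks as the step size decreases and the number of paths grows, which is exactly the trade-off the EM_sim_1000 / EM_sim_10000 / EM_sim_50000 baselines sweep over.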

To launch Ornstein-Uhlenbeck experiment run the command below:

> python ./script.py ou_vary_dim_freq --method [method] --device [device]

The available options for method are ICNN_jko, EM_sim_1000, EM_sim_10000, EM_sim_50000, EM_ProxRec_10000, dual_jko.

The obtained divergence between the true and fitted distributions at t = 0.9 s:

[Figure: divergence between true and fitted distributions at t = 0.9 s]

Unnormalized Posterior Sampling

Given a prior distribution of model parameters and a conditional data distribution, we model the posterior distribution of the parameters by establishing it as the stationary distribution of a gradient flow.
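For concreteness, in a Bayesian classification task of the logistic-regression kind, the unnormalized log-posterior that drives such a flow is the sum of the log-prior and the log-likelihood. A minimal sketch with a made-up toy dataset (not one of the benchmark datasets below):

```python
import math

# Toy dataset: 2-D features and binary labels (illustrative only)
X = [[0.5, 1.0], [-1.0, 0.3], [1.5, -0.7], [-0.2, -1.1]]
y = [1, 0, 1, 0]

def log_unnorm_posterior(w, prior_var=1.0):
    """log p(w | data) up to an additive constant:
    Gaussian prior on the weights + Bernoulli (logistic) likelihood."""
    log_prior = -sum(wi * wi for wi in w) / (2 * prior_var)
    log_lik = 0.0
    for xi, yi in zip(X, y):
        z = sum(wi * xij for wi, xij in zip(w, xi))
        log_p1 = -math.log1p(math.exp(-z))  # log sigmoid(z), numerically stable
        log_p0 = -math.log1p(math.exp(z))   # log(1 - sigmoid(z))
        log_lik += log_p1 if yi == 1 else log_p0
    return log_prior + log_lik

print(log_unnorm_posterior([0.0, 0.0]))  # = 4 * log(0.5), about -2.7726
```

Only this unnormalized density is needed: the normalizing constant of the posterior never has to be computed, which is what makes gradient-flow sampling attractive here.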

The experiments with different benchmark datasets can be run as follows:

> python ./script.py [dataset]_data_posterior --device [device]

The supported datasets are: covtype, diabetis, german, splice, banana, waveform, ringnorm, twonorm, image.

Nonlinear filtering

We model the predictive distribution at the final time moment of a latent, highly nonlinear diffusion process X, given noisy observations obtained at specific time moments.

To reproduce our results run the command:

> python ./script.py filtering --method [method] --device [device]

The available methods are: ICNN_jko, dual_jko, bbf_100, bbf_1000, bbf_10000, bbf_50000.
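For reference, a bootstrap particle filter (which we take the bbf_* baselines with their particle counts to denote; the dynamics, observation model, and all parameters below are made up for illustration) estimates such a predictive distribution by propagating particles through the diffusion and resampling them at each observation time:

```python
import math
import random

random.seed(2)

def drift(x):
    """Toy nonlinear drift of the latent process (illustrative)."""
    return math.sin(x)

dt, sigma, obs_noise = 0.05, 0.5, 0.3
observations = {20: 0.4, 40: 0.9, 60: 0.7}  # step index -> noisy observation

n_particles = 1000
particles = [0.0] * n_particles  # all paths start at x0 = 0

for step in range(1, 61):
    # Propagate each particle one Euler-Maruyama step of the diffusion
    particles = [x + drift(x) * dt + sigma * math.sqrt(dt) * random.gauss(0, 1)
                 for x in particles]
    if step in observations:
        yk = observations[step]
        # Gaussian observation likelihood as weights, then multinomial resampling
        weights = [math.exp(-(yk - x) ** 2 / (2 * obs_noise ** 2))
                   for x in particles]
        particles = random.choices(particles, weights=weights, k=n_particles)

# The surviving particles approximate the predictive distribution at the final time
predictive_mean = sum(particles) / n_particles
print(predictive_mean)
```

Increasing the particle count (as in bbf_100 through bbf_50000) tightens the Monte-Carlo approximation at a proportional computational cost.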

The obtained discrepancy between the fitted methods and the ground-truth method (Chang & Cooper numerical integration) is presented below:

[Figure: discrepancy between fitted methods and the Chang & Cooper ground truth]

Credits

Contributors: petrmokrov, iamalexkorotin
