fab-torch's Introduction

Flow Annealed Importance Sampling Bootstrap (FAB)

Overview

Normalizing flows can approximate complicated Boltzmann distributions of physical systems. However, current methods for training flows either suffer from mode-seeking behavior, use samples from the target generated beforehand by expensive MCMC simulations, or use stochastic losses that have very high variance. We tackle this challenge by augmenting flows with annealed importance sampling (AIS) and minimizing the mass-covering $\alpha$-divergence with $\alpha = 2$, which minimizes importance weight variance. Our method, Flow AIS Bootstrap (FAB), uses AIS to generate samples in regions where the flow is a poor approximation of the target, facilitating the discovery of new modes.
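To make the link between the $\alpha = 2$ divergence and importance weight variance concrete: for normalised densities $p$ and $q$ with weights $w = p(x)/q(x)$, we have $\mathbb{E}_q[w] = 1$, so $\mathrm{Var}_q[w] = \int p(x)^2 / q(x)\,dx - 1$. The integral is the quantity that the $\alpha = 2$ divergence penalises, up to a constant factor that depends on the normalisation convention. The following is a minimal sketch and not code from this repository; the one-dimensional Gaussian target and proposal, and all function names, are assumptions chosen for illustration.

```python
import math
import random


def log_normal_pdf(x, std):
    """Log density of a zero-mean Gaussian with standard deviation `std`."""
    return -0.5 * (x / std) ** 2 - math.log(std * math.sqrt(2.0 * math.pi))


def second_moment_of_weights(proposal_std, n_samples=200_000, seed=0):
    """Monte Carlo estimate of E_q[w^2] with w = p(x) / q(x).

    Illustrative assumption: target p = N(0, 1), proposal q = N(0, proposal_std^2).
    """
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        x = rng.gauss(0.0, proposal_std)  # sample from the proposal q
        log_w = log_normal_pdf(x, 1.0) - log_normal_pdf(x, proposal_std)
        total += math.exp(2.0 * log_w)
    return total / n_samples


def exact_second_moment(proposal_std):
    """Closed form of the integral of p^2 / q for the two Gaussians above.

    Only finite when proposal_std^2 > 1/2.
    """
    s2 = proposal_std ** 2
    return s2 / math.sqrt(2.0 * s2 - 1.0)


estimate = second_moment_of_weights(2.0)
exact = exact_second_moment(2.0)
weight_variance = estimate - 1.0  # Var_q[w] = E_q[w^2] - 1 for normalised p and q
print(f"E_q[w^2]: estimate={estimate:.3f}, exact={exact:.3f}")
print(f"Var_q[w] estimate: {weight_variance:.3f}")
```

In this toy setting a proposal that matches the target exactly (standard deviation 1) gives a second moment of 1 and therefore zero weight variance, while any mismatch pushes the second moment above 1.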

In this repository, we implement FAB and provide the code to reproduce our experiments. For more details about our method and the results of our experiments, please read our paper.

Methods of Installation

The package can be installed via pip by navigating to the repository directory and running:

pip install --upgrade .

To run the alanine dipeptide experiments, you will also need to install the OpenMM library and openmmtools. This can be done via conda:

conda install -c conda-forge openmm openmmtools

Experiments

Gaussian Mixture Model

For this problem we use a mixture of 40 two-dimensional Gaussian distributions. This allows for easy visualisation of the various methods for training the flow. We provide a colab notebook with an example of training a flow on the GMM problem, comparing FAB to training a flow with KL divergence minimisation. The notebook runs in about 10 minutes and provides a clear visualisation of how FAB is able to discover new modes and fit them.
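As a rough picture of what such a target looks like in code, the sketch below evaluates the log density of an equally weighted mixture of 40 two-dimensional Gaussians using the log-sum-exp trick. It is not the target class from this repository: the function names, the uniform range for the component means, the unit covariances, and the equal mixture weights are all assumptions made for illustration.

```python
import math
import random


def make_gmm_means(n_components=40, dim=2, loc_range=40.0, seed=0):
    """Draw component means uniformly from [-loc_range, loc_range]^dim (assumed setup)."""
    rng = random.Random(seed)
    return [[rng.uniform(-loc_range, loc_range) for _ in range(dim)]
            for _ in range(n_components)]


def gmm_log_prob(x, means, std=1.0):
    """Log density of an equally weighted mixture of isotropic Gaussians."""
    dim = len(x)
    log_norm = -dim * math.log(std * math.sqrt(2.0 * math.pi))
    component_log_probs = [
        log_norm - 0.5 * sum(((xi - mi) / std) ** 2 for xi, mi in zip(x, mean))
        for mean in means
    ]
    # log-sum-exp over components, shifted by the maximum for numerical stability
    top = max(component_log_probs)
    log_sum = top + math.log(sum(math.exp(c - top) for c in component_log_probs))
    return log_sum - math.log(len(means))


means = make_gmm_means()
log_prob_at_first_mode = gmm_log_prob(means[0], means)
print(f"log p at the first component mean: {log_prob_at_first_mode:.3f}")
```

Because the modes are well separated relative to their width, a flow trained with a mode-seeking objective can fit a few components and miss the rest, which is the failure the plots in this section are meant to expose.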

To run the experiment for FAB with a prioritised replay buffer (for the first seed), use the following command:

python experiments/gmm/run.py training.use_buffer=True training.prioritised_buffer=True

To run the full set of experiments, see the README for the GMM experiments.

The plot below shows samples from various trained models, with the GMM problem target contours in the background.

Gaussian Mixture Model samples vs contours

Many Well distribution

The Many Well distribution is made up of multiple repeats of the Double Well distribution, from the original Boltzmann generators paper.
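The product structure can be sketched as follows: the unnormalised log density of a Many Well target is the sum of Double Well log densities over consecutive pairs of coordinates. This is an illustrative sketch rather than the repository's implementation; the function names and the energy coefficients `a`, `b`, `c` are assumptions and may differ from the values used in the experiments.

```python
def double_well_log_prob(x1, x2, a=-0.5, b=-6.0, c=1.0):
    """Unnormalised log density of a 2D double well (assumed coefficients).

    The first coordinate has a quartic, bimodal energy; the second is Gaussian.
    """
    energy = a * x1 + b * x1 ** 2 + c * x1 ** 4 + 0.5 * x2 ** 2
    return -energy


def many_well_log_prob(x):
    """Sum of independent double wells over consecutive coordinate pairs."""
    if len(x) % 2 != 0:
        raise ValueError("Many Well needs an even number of dimensions.")
    return sum(double_well_log_prob(x[i], x[i + 1]) for i in range(0, len(x), 2))


dim = 32
n_modes = 2 ** (dim // 2)  # each double well contributes two modes
x = [1.7, 0.0] * (dim // 2)  # a point near one mode of every well
print(f"{dim}D Many Well: {n_modes} modes, log p~(x) = {many_well_log_prob(x):.2f}")
```

Since each pair contributes two modes independently, the number of modes grows exponentially with the dimension, which is what makes the higher-dimensional problem hard for mode-seeking objectives.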

We provide a colab notebook comparing FAB to training a flow via KL divergence minimisation on the 6-dimensional Many Well problem, where the difference between the two methods is apparent after a short (<10 min) training period.

To run the experiment for FAB with a prioritised replay buffer (for the first seed) on the 32-dimensional Many Well problem, use the following command:

python experiments/many_well/run.py training.use_buffer=True training.prioritised_buffer=True

To run the full set of experiments, see the README for the Many Well experiments.

The plot below shows samples for our model (FAB) vs training a flow by reverse KL divergence minimisation, with the Many Well problem target contours in the background. This visualisation shows the marginal pairs of the distributions for the first four elements of x.

Many Well distribution FAB vs training by KL divergence minimisation

Alanine dipeptide

In our final experiment, we approximate the Boltzmann distribution of alanine dipeptide in an implicit solvent. Alanine dipeptide is a molecule with 22 atoms and a popular model system. The molecule is visualized in the figure below. The right figure shows the probability density of the dihedral angle $\phi$, comparing the ground truth, which was obtained with a molecular dynamics (MD) simulation, with the models trained with our method and with maximum likelihood on MD samples.

Alanine dipeptide and its dihedral angles; Comparison of probability densities

Furthermore, we compared the Ramachandran plots of the different methods in the following figure.

Ramachandran plot of alanine dipeptide

To reproduce our experiment, use the experiments/aldp/train.py script. The respective configuration files are located in experiments/aldp/config. We used the seeds 0, 1, and 2 in our runs.

The data used to evaluate our models and to train the flow model with maximum likelihood is provided on Zenodo. If you want to use the configuration files in experiments/aldp/config as is, you should put the data in the experiments/aldp/data folder.

About the code

The main FAB loss can be found in core.py, and we provide a simple training loop to train a flow with this loss (or other flow-loss combinations that meet the spec) in train.py. The FAB training algorithm with the prioritised buffer can be found in train_with_prioritised_buffer.py. Additionally, we provide the code for running the SNR/dimensionality analysis, with p and q set to independent Gaussians, in the fab-jax repository. For training the CRAFT model on the GMM problem, we forked the Annealed Flow Transport repository; this fork may be used for training the CRAFT model.
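For readers unfamiliar with the AIS component that FAB builds on, the sketch below shows generic annealed importance sampling in one dimension. It is not the code in core.py: it anneals from a Gaussian proposal to the target itself along a geometric path with random-walk Metropolis transitions, whereas the FAB paper describes an AIS target of $p^2/q$ and the repository has its own transition operators. The densities, step size, schedule, and all names are assumptions made for illustration.

```python
import math
import random


def log_q(x, std=3.0):
    """Log density of the normalised Gaussian proposal (stand-in for the flow)."""
    return -0.5 * (x / std) ** 2 - math.log(std * math.sqrt(2.0 * math.pi))


def log_p_unnorm(x):
    """Unnormalised bimodal target: unit-variance bumps at -3 and +3."""
    a = -0.5 * (x - 3.0) ** 2
    b = -0.5 * (x + 3.0) ** 2
    top = max(a, b)
    return top + math.log(math.exp(a - top) + math.exp(b - top))


def ais_chain(rng, n_steps=10, step_size=1.0):
    """Run one AIS chain; return the final sample and its log importance weight."""
    betas = [k / n_steps for k in range(n_steps + 1)]
    x = rng.gauss(0.0, 3.0)  # initial sample from the proposal q
    log_w = 0.0
    for beta_prev, beta in zip(betas[:-1], betas[1:]):
        # Weight update for moving from the previous to the next annealed density.
        log_w += (beta - beta_prev) * (log_p_unnorm(x) - log_q(x))

        def log_target(y, beta=beta):
            # Geometric interpolation between q (beta = 0) and p (beta = 1).
            return (1.0 - beta) * log_q(y) + beta * log_p_unnorm(y)

        # One random-walk Metropolis step that leaves the annealed density invariant.
        proposal = x + rng.gauss(0.0, step_size)
        accept_prob = math.exp(min(0.0, log_target(proposal) - log_target(x)))
        if rng.random() < accept_prob:
            x = proposal
    return x, log_w


rng = random.Random(0)
chains = [ais_chain(rng) for _ in range(4000)]
# The mean AIS weight is an unbiased estimate of the target's normalising constant.
z_estimate = sum(math.exp(log_w) for _, log_w in chains) / len(chains)
true_z = 2.0 * math.sqrt(2.0 * math.pi)
print(f"Z estimate: {z_estimate:.3f} (true value {true_z:.3f})")
```

The same weights can be used to reweight the final samples, which is how AIS supplies samples from regions where the proposal alone is a poor fit.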

Normalizing Flow Libraries

We offer a simple wrapper that allows for various normalising flow libraries to be plugged into this repository. The main library we rely on is normflows.

Citation

If you use this code in your research, please cite it as:

Laurence I. Midgley, Vincent Stimper, Gregor N. C. Simm, Bernhard Schölkopf, José Miguel Hernández-Lobato. Flow Annealed Importance Sampling Bootstrap. arXiv preprint arXiv:2208.01893, 2022.

Bibtex

@article{Midgley2022,
  title={Flow {A}nnealed {I}mportance {S}ampling {B}ootstrap},
  author={Laurence I. Midgley and Vincent Stimper and Gregor N. C. Simm and Bernhard Sch\"olkopf and Jos{\'e} Miguel Hern{\'a}ndez-Lobato},
  journal={arXiv preprint arXiv:2208.01893},
  year={2022}
}

fab-torch's People

Contributors

vincentstimper, lollcat, thargreaves
