jaanli / variational-autoencoder

Variational autoencoder implemented in tensorflow and pytorch (including inverse autoregressive flow)

Home Page: https://jaan.io/what-is-variational-autoencoder-vae-tutorial/

License: MIT License

Python 100.00%
vae machine-learning tensorflow variational-autoencoder variational-inference probabilistic-graphical-models deep learning deep-learning deep-neural-networks pytorch unsupervised-learning autoregressive-neural-networks

Variational Autoencoder in TensorFlow and PyTorch

Reference implementation for a variational autoencoder in TensorFlow and PyTorch.

I recommend the PyTorch version. It includes an example of a more expressive variational family, the inverse autoregressive flow.

Variational inference is used to fit the model to binarized MNIST handwritten digits images. An inference network (encoder) is used to amortize the inference and share parameters across datapoints. The likelihood is parameterized by a generative network (decoder).
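For reference, the objective that this kind of model maximizes can be written as below. The symbols are assumptions introduced here and are not notation from this README: x is a binarized image, z is the latent variable, q_phi(z | x) is the inference network (encoder), p_theta(x | z) is the generative network (decoder), and p(z) is the prior.

```latex
\log p_\theta(x) \;\ge\;
\mathbb{E}_{q_\phi(z \mid x)}\big[\log p_\theta(x \mid z)\big]
- \mathrm{KL}\big(q_\phi(z \mid x)\,\|\,p(z)\big)
\;=\; \mathrm{ELBO}(x)
```

The "ELBO estimate" values in the training logs are Monte Carlo estimates of this lower bound. Amortization means the encoder parameters phi are shared across datapoints rather than fit separately for each image.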

Blog post: https://jaan.io/what-is-variational-autoencoder-vae-tutorial/

PyTorch implementation

(anaconda environment is in environment-jax.yml)

Importance sampling is used to estimate the marginal likelihood on Hugo Larochelle's Binary MNIST dataset. The final marginal likelihood on the test set was -97.10 nats, which is comparable to published numbers.

$ python train_variational_autoencoder_pytorch.py --variational mean-field --use_gpu --data_dir $DAT --max_iterations 30000 --log_interval 10000
Step 0          Train ELBO estimate: -558.027   Validation ELBO estimate: -384.432      Validation log p(x) estimate: -355.430  Speed: 2.72e+06 examples/s
Step 10000      Train ELBO estimate: -111.323   Validation ELBO estimate: -109.048      Validation log p(x) estimate: -103.746  Speed: 2.64e+04 examples/s
Step 20000      Train ELBO estimate: -103.013   Validation ELBO estimate: -107.655      Validation log p(x) estimate: -101.275  Speed: 2.63e+04 examples/s
Step 29999      Test ELBO estimate: -106.642    Test log p(x) estimate: -100.309
Total time: 2.49 minutes
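
The "log p(x) estimate" columns come from importance sampling. As a rough illustration of that estimator (not the repository's code), here is a minimal sketch on a toy one-dimensional model where the exact answer is known. The model (prior N(0, 1), likelihood N(x; z, 1)) and all function names are assumptions made for this example; the actual scripts use a Bernoulli likelihood over MNIST pixels.

```python
import math
import random


def log_normal(x, mean, var):
    """Log density of a univariate Gaussian N(x; mean, var)."""
    return -0.5 * (math.log(2.0 * math.pi * var) + (x - mean) ** 2 / var)


def log_mean_exp(values):
    """Numerically stable log of the mean of exp(values)."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values) / len(values))


def estimate_log_px(x, q_mean, q_var, num_samples, rng):
    """Importance-sampling estimate of log p(x) with proposal q = N(q_mean, q_var).

    Toy model (an assumption for this sketch): p(z) = N(0, 1), p(x | z) = N(z, 1).
    """
    log_weights = []
    for _ in range(num_samples):
        z = rng.gauss(q_mean, math.sqrt(q_var))
        log_joint = log_normal(z, 0.0, 1.0) + log_normal(x, z, 1.0)
        log_q = log_normal(z, q_mean, q_var)
        log_weights.append(log_joint - log_q)  # log of p(x, z) / q(z)
    return log_mean_exp(log_weights)


rng = random.Random(0)
x = 1.3
# For this toy model the marginal is known in closed form: p(x) = N(0, 2).
exact = log_normal(x, 0.0, 2.0)
# Proposal equal to the exact posterior N(x / 2, 1 / 2): every weight equals p(x).
tight = estimate_log_px(x, x / 2.0, 0.5, 10, rng)
# Proposal equal to the prior: noisier, but approaches log p(x) with many samples.
loose = estimate_log_px(x, 0.0, 1.0, 5000, rng)
print(exact, tight, loose)
```

The closer the proposal is to the true posterior, the fewer samples the estimate needs, which is one reason a more expressive variational family can also tighten the reported log p(x) estimate.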

Using a non mean-field, more expressive variational posterior approximation (inverse autoregressive flow, https://arxiv.org/abs/1606.04934), the test marginal log-likelihood improves to -95.33 nats:

$ python train_variational_autoencoder_pytorch.py --variational flow
step:   0       train elbo: -578.35
step:   0               valid elbo: -407.06     valid log p(x): -367.88
step:   10000   train elbo: -106.63
step:   10000           valid elbo: -110.12     valid log p(x): -104.00
step:   20000   train elbo: -101.51
step:   20000           valid elbo: -105.02     valid log p(x): -99.11
step:   30000   train elbo: -98.70
step:   30000           valid elbo: -103.76     valid log p(x): -97.71
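
To make the flow idea concrete, below is a minimal sketch of a single inverse autoregressive flow step, following the gated update described in the paper linked above. It is not the repository's implementation: the strictly lower-triangular linear map stands in for a masked autoregressive network, and `autoregressive_linear`, `iaf_step`, and the `params` layout are hypothetical names chosen for this example.

```python
import math
import random


def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))


def autoregressive_linear(z, weights, bias):
    """Output i depends only on z[0], ..., z[i-1] (strictly lower-triangular map)."""
    return [bias[i] + sum(weights[i][j] * z[j] for j in range(i))
            for i in range(len(z))]


def iaf_step(z, params):
    """One IAF step: z_new = gate * z + (1 - gate) * m, with gate = sigmoid(s).

    Because m and s are autoregressive in z, the Jacobian of z_new with respect
    to z is lower triangular with diagonal `gate`, so its log-determinant is
    simply the sum of log(gate).
    """
    m = autoregressive_linear(z, params["w_m"], params["b_m"])
    s = autoregressive_linear(z, params["w_s"], params["b_s"])
    gate = [sigmoid(v) for v in s]
    z_new = [g * zi + (1.0 - g) * mi for g, zi, mi in zip(gate, z, m)]
    log_det = sum(math.log(g) for g in gate)
    return z_new, log_det


rng = random.Random(0)
dim = 4
params = {
    "w_m": [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(dim)],
    "b_m": [rng.gauss(0.0, 1.0) for _ in range(dim)],
    "w_s": [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(dim)],
    "b_s": [rng.gauss(0.0, 1.0) for _ in range(dim)],
}

# Start from a standard normal sample and track its log density through the flow.
z0 = [rng.gauss(0.0, 1.0) for _ in range(dim)]
log_q0 = sum(-0.5 * (math.log(2.0 * math.pi) + v * v) for v in z0)
z1, log_det = iaf_step(z0, params)
log_q1 = log_q0 - log_det  # change of variables
print(z1, log_q1)
```

The cheap log-determinant is what makes the flow practical: the density of the transformed sample is available without computing a full Jacobian, so the ELBO can still be estimated by sampling.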

JAX implementation

Using JAX (anaconda environment is in environment-jax.yml) gives a 3x speedup over PyTorch:

$ python train_variational_autoencoder_jax.py --variational mean-field 
Step 0          Train ELBO estimate: -566.059   Validation ELBO estimate: -565.755      Validation log p(x) estimate: -557.914  Speed: 2.56e+11 examples/s
Step 10000      Train ELBO estimate: -98.560    Validation ELBO estimate: -105.725      Validation log p(x) estimate: -98.973   Speed: 7.03e+04 examples/s
Step 20000      Train ELBO estimate: -109.794   Validation ELBO estimate: -105.756      Validation log p(x) estimate: -97.914   Speed: 4.26e+04 examples/s
Step 29999      Test ELBO estimate: -104.867    Test log p(x) estimate: -96.716
Total time: 0.810 minutes

Inverse autoregressive flow in JAX:

$ python train_variational_autoencoder_jax.py --variational flow 
Step 0          Train ELBO estimate: -727.404   Validation ELBO estimate: -726.977      Validation log p(x) estimate: -713.389  Speed: 2.56e+11 examples/s
Step 10000      Train ELBO estimate: -100.093   Validation ELBO estimate: -106.985      Validation log p(x) estimate: -99.565   Speed: 2.57e+04 examples/s
Step 20000      Train ELBO estimate: -113.073   Validation ELBO estimate: -108.057      Validation log p(x) estimate: -98.841   Speed: 3.37e+04 examples/s
Step 29999      Test ELBO estimate: -106.803    Test log p(x) estimate: -97.620
Total time: 2.350 minutes

(The difference between the mean-field and inverse autoregressive flow results may be due to several factors, chief among them the lack of convolutions in this implementation. The paper, https://arxiv.org/pdf/1606.04934.pdf, uses residual blocks to get the ELBO closer to -80 nats.)

Generating the GIFs

  1. Run python train_variational_autoencoder_tensorflow.py
  2. Install imagemagick (Homebrew on macOS: https://formulae.brew.sh/formula/imagemagick or Chocolatey on Windows: https://community.chocolatey.org/packages/imagemagick.app)
  3. Go to the directory where the jpg files are saved, and run the imagemagick command to generate the .gif: convert -delay 20 -loop 0 *.jpg latent-space.gif

TODO (help needed - feel free to send a PR!)

  • add multiple GPU / TPU option
  • add jaxtyping support for the PyTorch and JAX implementations :) for runtime type checking (using @beartype decorators)
