
pyOMT: A Pytorch implementation of Adaptive Monte Carlo Optimal Transport Algorithm

This work is licensed under a Creative Commons Attribution-NonCommercial 4.0 International License.

The optimal transport (OT) problem arises whenever one wants to transform one distribution into another in an optimal way. Examples include computing measure-preserving maps between surfaces or volumes, matching two histograms, and generating realistic images from a given dataset in deep learning.
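To make the histogram-matching example concrete: in one dimension, with two equal-size samples and a convex cost such as the squared distance, an optimal matching simply pairs the sorted values (the monotone rearrangement). The sketch below illustrates only that equal-size 1-D case; `match_1d` and `transport_cost` are hypothetical helper names, not part of pyOMT.

```python
def match_1d(source, target):
    """Optimal matching between two equal-size 1-D samples.

    For a convex cost such as (s - t) ** 2, pairing the k-th smallest
    source value with the k-th smallest target value is optimal.
    """
    if len(source) != len(target):
        raise ValueError("this sketch assumes equal-size samples")
    return list(zip(sorted(source), sorted(target)))


def transport_cost(pairs):
    # Total squared-distance cost of a matching given as (s, t) pairs.
    return sum((s - t) ** 2 for s, t in pairs)


if __name__ == "__main__":
    pairs = match_1d([0.9, 0.1, 0.5], [2.0, 1.0, 3.0])
    print(pairs)  # [(0.1, 1.0), (0.5, 2.0), (0.9, 3.0)]
    print(transport_cost(pairs))
```

Unequal histograms (different bin weights) need the quantile-function form of the same idea, which this sketch does not cover.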

The Adaptive Monte Carlo Optimal Transport algorithm tackles potentially high-dimensional semi-discrete OT problems in a scalable way. It minimizes a convex energy defined through the Brenier potential, and the minimizer induces the optimal transport map from a continuous distribution to an empirical distribution. The energy is minimized by gradient descent; at each iteration, its gradient is estimated with Monte Carlo integration.
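The energy and its Monte Carlo gradient can be sketched as follows, assuming the source measure mu is uniform on the unit cube [0,1]^d and the targets y_i carry weights nu_i. With heights h, the Brenier potential is u_h(x) = max_i (<x, y_i> + h_i), the energy is E(h) = integral of u_h dmu - sum_i nu_i h_i, and dE/dh_i = mu(W_i(h)) - nu_i, where W_i(h) is the cell of points at which the i-th term attains the maximum. The Monte Carlo step estimates mu(W_i(h)) as the fraction of random samples that fall in each cell. This is plain gradient descent with a fixed step size; pyOMT's actual optimizer, sampling scheme, and function names may differ, and `solve_heights` and the other helpers here are hypothetical.

```python
import random


def power_cell_index(x, targets, h):
    """Return the i maximizing <x, y_i> + h_i, i.e. the cell W_i(h) containing x."""
    best_i, best_val = 0, float("-inf")
    for i, (y, h_i) in enumerate(zip(targets, h)):
        val = sum(a * b for a, b in zip(x, y)) + h_i
        if val > best_val:
            best_i, best_val = i, val
    return best_i


def estimate_cell_masses(targets, h, n_samples, rng):
    """Monte Carlo estimate of mu(W_i(h)), with mu assumed uniform on [0, 1]^d."""
    d = len(targets[0])
    counts = [0] * len(targets)
    for _ in range(n_samples):
        x = [rng.random() for _ in range(d)]
        counts[power_cell_index(x, targets, h)] += 1
    return [c / n_samples for c in counts]


def solve_heights(targets, weights, n_iters=200, n_samples=1000, lr=0.2, seed=0):
    """Fixed-step gradient descent on E(h) using Monte Carlo gradient estimates."""
    rng = random.Random(seed)
    h = [0.0] * len(targets)
    for _ in range(n_iters):
        masses = estimate_cell_masses(targets, h, n_samples, rng)
        grad = [m - w for m, w in zip(masses, weights)]  # dE/dh_i = mu(W_i) - nu_i
        h = [h_i - lr * g for h_i, g in zip(h, grad)]
        mean_h = sum(h) / len(h)  # E is unchanged by adding a constant to h
        h = [h_i - mean_h for h_i in h]
    return h


if __name__ == "__main__":
    targets = [(0.2, 0.2), (0.8, 0.2), (0.5, 0.8)]
    h = solve_heights(targets, [1 / 3] * 3)
    print(estimate_cell_masses(targets, h, 20000, random.Random(1)))
```

A cell that holds too much mass gets its height lowered, which shrinks it, so with three equally weighted targets the printed cell masses should each end up close to 1/3.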

One application of the OT solver is generating new samples from a dataset. Intuitively, generation means producing new samples near the empirical distribution formed by the dataset samples. Once the optimal transport map from a continuous prior distribution (e.g. uniform or Gaussian) to this empirical distribution has been computed, new samples can be generated by sampling from the prior and mapping each sample to the dataset distribution with the OT map.
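A minimal sketch of the sampling-and-mapping step, assuming a uniform prior on [0,1]^d and heights h already computed by a semi-discrete OT solver (the zeros below are placeholders, not a solved result). The semi-discrete OT map sends each prior sample x to the target y_i that maximizes <x, y_i> + h_i, so on its own it only returns existing dataset points; the extended semi-discrete OT of the AE-OT paper referenced below is designed to produce samples between them, which this sketch does not attempt. `ot_map` and `generate` are hypothetical names.

```python
import random


def ot_map(x, targets, h):
    """Semi-discrete OT map: send x to the target y_i maximizing <x, y_i> + h_i."""
    scores = [sum(a * b for a, b in zip(x, y)) + h_i for y, h_i in zip(targets, h)]
    return targets[scores.index(max(scores))]


def generate(n, targets, h, seed=0):
    # Sample the assumed uniform prior on [0, 1]^d and push each sample through the map.
    rng = random.Random(seed)
    d = len(targets[0])
    return [ot_map([rng.random() for _ in range(d)], targets, h) for _ in range(n)]


targets = [(0.2, 0.2), (0.8, 0.2), (0.5, 0.8)]
h = [0.0, 0.0, 0.0]  # placeholder heights; in practice these come from the OT solver
samples = generate(5, targets, h)
```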

[Figure: AE-OT framework]

Reference

@inproceedings{An2020AE-OT:,
  title={AE-OT: A NEW GENERATIVE MODEL BASED ON EXTENDED SEMI-DISCRETE OPTIMAL TRANSPORT},
  author={Dongsheng An and Yang Guo and Na Lei and Zhongxuan Luo and Shing-Tung Yau and Xianfeng Gu},
  booktitle={International Conference on Learning Representations},
  year={2020},
  url={https://openreview.net/pdf?id=HkldyTNYwH}
}

Implementation

The code is developed in PyTorch for better integration with deep learning frameworks and is intended for research purposes only. Please open an issue or email me at yangg20111 (at) gmail (dot) com if you have any problems with the code. Suggestions are also welcome.

Dependencies

  1. Python=3.6 (or above)
  2. PyTorch=1.3.0 (or above)
  3. NumPy=1.17.4 (or above)
  4. Matplotlib=3.1.0 (or above)
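One way to check these minimum versions before running the demos is sketched below. It reads installed versions with the standard-library `importlib.metadata` module (which itself needs Python 3.8 or newer) and compares them using a deliberately simple parser that ignores pre-release and local tags; `check_dependencies` and `meets_minimum` are hypothetical helpers, not part of pyOMT.

```python
import sys
from importlib import metadata

# Minimums copied from the dependency list above; "torch" is PyTorch's distribution name.
MINIMUM_VERSIONS = {"torch": "1.3.0", "numpy": "1.17.4", "matplotlib": "3.1.0"}


def version_tuple(v):
    """'1.13.1+cu117' -> (1, 13, 1); keeps only the leading digits of each part."""
    parts = []
    for piece in v.split("+")[0].split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    return tuple(parts)


def meets_minimum(installed, required):
    a, b = version_tuple(installed), version_tuple(required)
    n = max(len(a), len(b))
    # Pad with zeros so that "3.1" compares equal to "3.1.0".
    return a + (0,) * (n - len(a)) >= b + (0,) * (n - len(b))


def check_dependencies():
    report = {"python": sys.version_info[:2] >= (3, 6)}
    for pkg, required in MINIMUM_VERSIONS.items():
        try:
            report[pkg] = meets_minimum(metadata.version(pkg), required)
        except metadata.PackageNotFoundError:
            report[pkg] = False  # not installed
    return report


if __name__ == "__main__":
    print(check_dependencies())
```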

Demos

Generation examples on simple measures (i.e. toy sets).

  • Code:

    python demo1.py

  • Results (figures): 8Gaussians, 25Gaussians, SwissRoll
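The three toy sets named above are widely used 2-D benchmarks. The generators below follow their commonly used definitions (Gaussian mixtures centered on a circle and on a 5x5 grid, and a 2-D projection of the Swiss roll); the radius, spacing, and noise levels are assumptions and may not match the values used in demo1.py.

```python
import math
import random


def eight_gaussians(n, radius=2.0, std=0.02, seed=0):
    # Eight isotropic Gaussians with centers evenly spaced on a circle.
    rng = random.Random(seed)
    centers = [(radius * math.cos(k * math.pi / 4), radius * math.sin(k * math.pi / 4))
               for k in range(8)]
    out = []
    for _ in range(n):
        cx, cy = rng.choice(centers)
        out.append((cx + rng.gauss(0, std), cy + rng.gauss(0, std)))
    return out


def twenty_five_gaussians(n, spacing=1.0, std=0.05, seed=0):
    # Twenty-five isotropic Gaussians with centers on a 5x5 grid around the origin.
    rng = random.Random(seed)
    centers = [(spacing * i, spacing * j) for i in range(-2, 3) for j in range(-2, 3)]
    out = []
    for _ in range(n):
        cx, cy = rng.choice(centers)
        out.append((cx + rng.gauss(0, std), cy + rng.gauss(0, std)))
    return out


def swiss_roll(n, noise=0.0, seed=0):
    # 2-D Swiss roll: the spiral (t cos t, t sin t) with t drawn from [1.5*pi, 4.5*pi).
    rng = random.Random(seed)
    out = []
    for _ in range(n):
        t = 1.5 * math.pi * (1 + 2 * rng.random())
        out.append((t * math.cos(t) + rng.gauss(0, noise),
                    t * math.sin(t) + rng.gauss(0, noise)))
    return out
```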

Generating human face images with the AE-OT framework.

This demo shows the application of the adaptive Monte Carlo OT solver to image generation. The dataset used here is CelebA_crop_resize_64, which contains ~200,000 human face images at 64x64 resolution. With the OT solver, an unlimited number of realistic face images can be generated.

  • Model training and generating:

    python demo2.py --data_root_train path-to-your-training-dataset --data_root_test path-to-your-test-dataset

  • Generating with pre-trained models:

    Download the pre-trained models here, extract the files to the "results" folder, and run:

    python demo2.py --generate_feature --decode_feature --data_root_train path-to-your-training-dataset --data_root_test path-to-your-test-dataset

    Alternatively, to use only the pre-trained AE model and run the OT solver yourself:

    python demo2.py --train_ot --generate_feature --decode_feature --data_root_train path-to-your-training-dataset --data_root_test path-to-your-test-dataset

  • Generated images: [Figure: generated face images]
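The three `demo2.py` invocations above differ only in which stage flags they pass. The hypothetical helper below assembles each command line from the flags shown in this README; the mode names and example paths are illustrative, and no other `demo2.py` options are assumed.

```python
import shlex

# Stage flags copied from the demo2.py commands above; the mode names are our own labels.
MODES = {
    "train": [],
    "pretrained": ["--generate_feature", "--decode_feature"],
    "pretrained_ae_only": ["--train_ot", "--generate_feature", "--decode_feature"],
}


def demo2_command(mode, train_root, test_root):
    """Build the argv list for one of the demo2.py invocations."""
    if mode not in MODES:
        raise ValueError(f"unknown mode {mode!r}; expected one of {sorted(MODES)}")
    return (["python", "demo2.py"] + MODES[mode]
            + ["--data_root_train", train_root, "--data_root_test", test_root])


if __name__ == "__main__":
    cmd = demo2_command("pretrained_ae_only", "/data/celeba/train", "/data/celeba/test")
    print(shlex.join(cmd))  # shlex.join needs Python 3.8+
```

The resulting list can be passed to `subprocess.run(cmd, check=True)` from the repository root.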
