The Differentiable Cross-Entropy Method

This repository is by Brandon Amos and Denis Yarats and contains the PyTorch library and source code to reproduce the experiments in our ICML 2020 paper on The Differentiable Cross-Entropy Method. This repository depends on the Limited Multi-Label Projection Layer. Our code provides an implementation of the vanilla cross-entropy method for optimization and our differentiable extension.

  • The core library source code is in dcem/.
  • Our experiments are in exp/, including the regression notebook and the action embedding notebook that produced most of the plots in our paper.
  • Basic usage examples of our code that are not published in our paper are in examples.ipynb.
  • Our slides are available here in pptx and pdf formats.
  • The full LaTeX source code for our paper is in paper/.

Setup

Once you have PyTorch set up, you can install our core code as a package with pip:

pip install git+git://github.com/facebookresearch/dcem.git

This should automatically install the Limited Multi-Label Projection Layer dependency.

Basic usage

Our core cross-entropy method implementation with the differentiable extension is available in dcem. We provide a lightweight wrapper for using CEM and DCEM in the control setting in dcem_ctrl. These can be imported as:

from dcem import dcem, dcem_ctrl

The interface for DCEM is:

dcem(
    f, # Objective to optimize
    nx, # Number of dimensions to optimize over
    n_batch, # Number of elements in the batch
    init_mu, # Initial mean
    init_sigma, # Initial variance
    n_sample, # Number of samples CEM uses in each iteration
    n_elite, # Number of elite CEM candidates in each iteration
    n_iter, # Number of CEM iterations
    temp, # DCEM temperature parameter, set to None for vanilla CEM
    iter_cb, # Iteration callback
)
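As a minimal, hedged sketch of calling this interface (not published code from this repository): the snippet below assumes that f receives a batch of candidates with shape [n_batch, n_sample, nx], returns one cost per candidate with shape [n_batch, n_sample], and that dcem returns the final mean of the sampling distribution; see examples.ipynb and dcem/ for the exact conventions.

import torch
from dcem import dcem

n_batch, nx = 1, 2

# A simple quadratic objective with its minimum at (1, -1).
# Assumed convention: X has shape [n_batch, n_sample, nx] and the
# returned costs have shape [n_batch, n_sample].
goal = torch.tensor([1., -1.])
def f(X):
    return ((X - goal) ** 2).sum(dim=-1)

xhat = dcem(
    f,
    nx=nx,
    n_batch=n_batch,
    init_mu=torch.zeros(n_batch, nx),
    init_sigma=torch.ones(n_batch, nx),
    n_sample=100,
    n_elite=10,
    n_iter=10,
    temp=None,  # None runs vanilla (non-differentiable) CEM
)
print(xhat)  # assumed to be the final mean, approaching [[1., -1.]]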

And our control interface is:

dcem_ctrl(
    obs=obs, # Initial state
    plan_horizon, # Planning horizon for the control problem
    init_mu, # Initial control sequence mean, warm-starting can be done here
    init_sigma, # Initial variance around the control sequence
    n_sample, # Number of samples CEM uses in each iteration
    n_elite, # Number of elite CEM candidates in each iteration
    n_iter, # Number of CEM iterations
    n_ctrl, # Number of control dimensions
    lb, # Lower-bound of the control signal
    ub, # Upper-bound of the control signal
    temp, # DCEM temperature parameter, set to None for vanilla CEM
    rollout_cost, # Function that returns the cost of rolling out a control sequence
    iter_cb, # CEM iteration callback
)
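Below is a hedged sketch of calling dcem_ctrl on a toy problem, not code from this repository: the 1D point-mass dynamics, the target, and the tensor shapes (controls arriving at rollout_cost as [plan_horizon, N, n_ctrl] with one cost per rollout) are assumptions; check dcem/ for the exact conventions and return value.

import torch
from dcem import dcem_ctrl

n_batch, n_ctrl, plan_horizon = 1, 1, 10
obs = torch.zeros(n_batch, 1)  # hypothetical 1D state: position only

# Hypothetical cost: roll a 1D point mass forward under the sampled
# controls and penalize its squared distance to a target position.
# Assumed convention: u has shape [plan_horizon, N, n_ctrl] and the
# returned cost has shape [N], where N is the number of rollouts.
def rollout_cost(u):
    T, N, _ = u.shape
    x = u.new_zeros(N)
    cost = u.new_zeros(N)
    for t in range(T):
        x = x + 0.1 * u[t, :, 0]  # simple Euler step
        cost = cost + (x - 1.0) ** 2
    return cost

out = dcem_ctrl(
    obs=obs,
    plan_horizon=plan_horizon,
    init_mu=0.,
    init_sigma=1.,
    n_sample=100,
    n_elite=10,
    n_iter=10,
    n_ctrl=n_ctrl,
    lb=-1.,
    ub=1.,
    temp=None,  # None runs vanilla CEM; a non-zero temp gives DCEM
    rollout_cost=rollout_cost,
)
# `out` holds the optimized control sequence; see dcem/ for its exact format.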

Simple examples

examples.ipynb provides a light introduction to using our interface for simple optimization and control problems.

2d optimization

We first show how to use DCEM to optimize a 2-dimensional objective.

Next we parameterize that objective and show how DCEM can update the objective to move the minimum to a desired location.
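A hedged sketch of that idea, under the same shape and return-value assumptions as above, and assuming the result of dcem is differentiable with respect to parameters of f when temp is not None: gradient steps on a loss between the DCEM solution and a target location move the objective's minimum.

import torch
from dcem import dcem

n_batch, nx = 1, 2
theta = torch.zeros(nx, requires_grad=True)  # parameterizes the objective's minimum
target = torch.tensor([[0.5, -0.5]])         # where we want the DCEM solution to end up
opt = torch.optim.Adam([theta], lr=1e-1)

for step in range(100):
    f = lambda X: ((X - theta) ** 2).sum(dim=-1)
    xhat = dcem(
        f, nx=nx, n_batch=n_batch,
        init_mu=torch.zeros(n_batch, nx),
        init_sigma=torch.ones(n_batch, nx),
        n_sample=100, n_elite=10, n_iter=10,
        temp=1.,  # a non-None temperature makes the solution differentiable
    )
    loss = ((xhat - target) ** 2).sum()
    opt.zero_grad()
    loss.backward()  # gradients flow back to theta through the soft top-k
    opt.step()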

Pendulum control

We show how to use CEM to solve a pendulum control problem, which can be made differentiable by setting a non-zero temperature for the soft top-k operation.

Reproducing our experimental results

We provide the source code for our cartpole and regression experiments in the exps directory. We do not have plans to open source our PlaNet and PPO experiment. One starting point is to use an existing PyTorch PlaNet implementation such as cross32768/PlaNet_PyTorch with a PyTorch PPO implementation such as ikostrikov/pytorch-a2c-ppo-acktr-gail or an SAC implementation such as denisyarats/pytorch_sac.

1D energy-based regression

The base experimental code for our 1D energy-based regression experiment is in regression.py. After running it, the results can be analyzed with regression-analysis.ipynb, which will reproduce the corresponding plots from our paper.
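For intuition, here is a hedged sketch of how an energy-based regressor can be trained with DCEM using only the dcem interface above; the network, hyperparameters, and tensor shapes are assumptions, and the actual regression.py may differ. An energy E_theta(x, y) is approximately minimized over y with DCEM to produce a prediction, and the regression loss is backpropagated through that inner optimization.

import torch
from torch import nn
from dcem import dcem

# Hypothetical energy network E_theta(x, y) -> scalar energy.
Enet = nn.Sequential(nn.Linear(2, 128), nn.ReLU(), nn.Linear(128, 1))
opt = torch.optim.Adam(Enet.parameters(), lr=1e-3)

def train_step(x, y):
    # x, y: [n_batch, 1] regression inputs and targets.
    n_batch = x.shape[0]

    # Energy over candidate predictions; assumed input shape [n_batch, n_sample, 1].
    def f(Y):
        n_sample = Y.shape[1]
        X = x.unsqueeze(1).expand(-1, n_sample, -1)
        return Enet(torch.cat((X, Y), dim=-1)).squeeze(-1)

    yhat = dcem(
        f, nx=1, n_batch=n_batch,
        init_mu=torch.zeros(n_batch, 1),
        init_sigma=torch.ones(n_batch, 1),
        n_sample=100, n_elite=10, n_iter=10,
        temp=1.,  # differentiable so the regression loss can train E_theta
    )
    loss = ((yhat - y) ** 2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss.item()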

Embedding actions in the cartpole

The base experimental code for our cartpole action embedding experiment is in cartpole_emb.py. After running it, the results can be analyzed with cartpole_emb-analysis.ipynb, which will reproduce the corresponding plots from our paper.
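The idea, sketched below under assumptions (the decoder architecture, placeholder rollout cost, and latent dimension are hypothetical, and the actual cartpole_emb.py may differ), is to learn a decoder from a low-dimensional latent to a full control sequence and to run DCEM over the latent space, so the decoder can be trained end-to-end through the controller.

import torch
from torch import nn
from dcem import dcem

plan_horizon, n_ctrl, nz, n_batch = 20, 1, 2, 1

# Hypothetical decoder from a latent z to a full control sequence.
decoder = nn.Sequential(
    nn.Linear(nz, 128), nn.ReLU(),
    nn.Linear(128, plan_horizon * n_ctrl), nn.Tanh(),
)

def rollout_cost(u):
    # Placeholder for the true cartpole rollout cost; one cost per sequence.
    # u: [..., plan_horizon * n_ctrl]
    return (u ** 2).sum(dim=-1)

# Objective over latents; assumed input shape [n_batch, n_sample, nz].
def f(Z):
    return rollout_cost(decoder(Z))

z_hat = dcem(
    f, nx=nz, n_batch=n_batch,
    init_mu=torch.zeros(n_batch, nz),
    init_sigma=torch.ones(n_batch, nz),
    n_sample=100, n_elite=10, n_iter=10,
    temp=1.,  # differentiable so gradients can reach the decoder
)
u_hat = decoder(z_hat)  # decode the optimized latent into a control sequence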

Citations

If you find this repository helpful in your publications, please consider citing our paper.

@inproceedings{amos2020differentiable,
  title={{The Differentiable Cross-Entropy Method}},
  author={Brandon Amos and Denis Yarats},
  booktitle={ICML},
  year={2020}
}

Licensing

This repository is licensed under the CC BY-NC 4.0 License.
