
Model-Based Reinforcement Learning via Latent-Space Collocation

Oleh Rybkin*1, Chuning Zhu*1, Anusha Nagabandi2, Kostas Daniilidis1, Igor Mordatch3, Sergey Levine4
(* equal contribution)

1University of Pennsylvania
2Covariant
3Google Brain
4UC Berkeley

This is a TF2 implementation of our latent-space collocation (LatCo) agent for model-based reinforcement learning. LatCo is a visual model-based reinforcement learning method that solves long-horizon tasks by optimizing sequences of latent states, instead of optimizing actions directly as done in shooting methods such as Visual Foresight and PlaNet. Optimizing latent states allows LatCo to quickly discover the high-reward region and construct effective plans even for complex multi-stage tasks that shooting methods fail to solve.
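
To make the idea concrete, below is a toy, self-contained sketch of collocation-style planning. Everything in it is a stand-in: hand-written linear dynamics, a quadratic reward, and plain gradient ascent on an annealed soft dynamics penalty. The actual agent (latco.py) instead plans in the latent space of a learned RSSM and uses a constrained formulation with a Gauss-Newton solver.

# Toy illustration of latent collocation (NOT the repo's implementation).
# Decision variables are the latent states AND the actions; the dynamics
# enter only as a soft penalty, so the optimizer can first move the latent
# trajectory into the high-reward region and then make it dynamically
# feasible as the penalty weight is annealed upward.
import numpy as np

GOAL = np.array([5.0, 5.0])  # hypothetical high-reward latent state

def reward_grad(z):
    # Gradient of a hypothetical quadratic reward -||z - GOAL||^2.
    return -2.0 * (z - GOAL)

def plan(z0, horizon=10, iters=2000, lr=0.01):
    zs = np.zeros((horizon, 2))    # latent states z_1 .. z_H
    acts = np.zeros((horizon, 2))  # actions a_0 .. a_{H-1}
    for it in range(iters):
        lam = 2.0 ** (it // 500)                # anneal penalty: 1, 2, 4, 8
        prev = np.vstack([z0[None], zs[:-1]])   # z_0 .. z_{H-1}
        resid = zs - (prev + acts)              # violation of z_{t+1} = z_t + a_t
        # Gradient ascent on sum_t r(z_t) - lam * sum_t ||resid_t||^2.
        g_z = reward_grad(zs) - 2.0 * lam * resid
        g_z[:-1] += 2.0 * lam * resid[1:]       # each z_t also enters the next residual
        g_a = 2.0 * lam * resid                 # actions only enter the penalty
        zs += lr * g_z
        acts += lr * g_a
    return zs, acts

zs, acts = plan(np.zeros(2))
print("final latent:", zs[-1])  # close to GOAL, with small dynamics violation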

Instructions

Setting up repo

git clone https://github.com/zchuning/latco.git
cd latco
git submodule update --init --recursive

Dependencies

  • Install MuJoCo
  • Install the dependencies:
pip install --user numpy==1.19.2 cloudpickle==1.2.2 tensorflow-gpu==2.2.0 tensorflow_probability==0.10.0
pip install --user gym imageio pandas pyyaml matplotlib
pip install --user mujoco-py   # optional: needed for Sparse MetaWorld and Pointmass
pip install --user dm-control  # optional: needed for DM Control
pip install --user -e metaworldv2

The code was tested on Ubuntu 20.04 with cuDNN 7.6.5, CUDA 10.1, and Python 3.8.

Commands

Train the LatCo agent on the Reaching task:

python train.py --collect_sparse_reward True --use_sparse_reward True --task mw_SawyerReachEnvV2 --planning_horizon 30 --mpc_steps 30 --agent latco --logdir logdir/mw_reach/latco/0

Evaluate the LatCo agent:

python eval.py --collect_sparse_reward True --use_sparse_reward True --task mw_SawyerReachEnvV2 --planning_horizon 30 --mpc_steps 30 --agent latco --logdir logdir/mw_reach/latco/0 --logdir_eval logdir_eval/mw_reach/latco/0 --n_eval_episodes 10

To run the offline + fine-tuning experiments, download the released offline data for the Hammer and Thermos tasks. Then train the LatCo agent on the Thermos task with offline data (the data path is specified by --offline_dir):

python train.py --prefill 0 --action_repeat 2 --collect_sparse_reward True --use_sparse_reward True --task mw_SawyerStickPushEnvV2 --planning_horizon 50 --mpc_steps 25 --agent latco --logdir logdir/mw_thermos/latco/0 --offline_dir logdir/mw_thermos/offline/episodes

For convenience, we include a script that automatically generates training and evaluation commands with the hyperparameters from the paper. To generate a training command, run:

python gencmd.py --task mw_reach --method latco

To generate an evaluation command, run:

python gencmd.py --task mw_reach --method latco --eval True

--task can be one of mw_reach, mw_button, mw_window, mw_drawer, mw_push, mw_hammer, mw_thermos, dmc_reacher, dmc_cheetah, dmc_quadruped. --method can be one of latco, planet, mppi, shooting_gd, shooting_gn, platco. To replicate the ablation experiments, set --method to one of latco_no_constraint, latco_no_relax, latco_first_order, image_colloc. To run the dense-reward MetaWorld tasks, add --dense_mw True.

Generate plots:

python plot.py --indir ./logdir --outdir ./plots --xaxis step --yaxis train/success --bins 3e4

Tensorboard:

tensorboard --logdir ./logdir

Troubleshooting

By default, MuJoCo rendering for Sparse MetaWorld uses glfw. With EGL rendering, GPU 0 is used by default. You can change these settings with the following environment variables:

export MUJOCO_RENDERER='egl'
export GL_DEVICE_ID=1

If you get ValueError: numpy.ndarray size changed, may indicate binary incompatibility. Expected 88 from C header, got 80 from PyObject, chances are your mujoco-py was not installed properly. You can fix it by reinstalling:

pip uninstall mujoco-py
pip install mujoco-py --no-cache-dir --no-binary :all: --no-build-isolation

Code structure

  • latco.py contains the LatCo agent. It inherits planning_agent.
  • planners/probabilistic_latco.py contains the Gaussian LatCo agent. It inherits planning_agent.
  • planners/gn_solver.py is a Gauss-Newton optimizer that exploits the block-tridiagonal structure of the Jacobian to speed up computation (a toy illustration of this structure appears after this list).
  • base_agent.py is a barebones agent with an RSSM model, modified from the Dreamer agent. planning_agent.py inherits base_agent and is inherited by all methods in planners.
  • planners contains planning agents such as shooting CEM, shooting GD, and a few other variants.
  • envs/sparse_metaworld.py is the wrapper for the Sparse MetaWorld benchmark. The benchmark itself is a submodule metaworldv2.
  • envs/pointmass/pointmass_prob_env.py is the pointmass lottery task.
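
For intuition about the speedup in gn_solver.py, here is a small self-contained sketch of a block Thomas solve for a block-tridiagonal linear system, which runs in time linear in the horizon rather than cubic. It only illustrates the structure being exploited; it is not the repo's solver.

# Block Thomas algorithm for a block-tridiagonal system M x = d, where M
# has diagonal blocks A[t], sub-diagonal blocks B[t] (coupling step t to
# t-1), and super-diagonal blocks C[t] (coupling step t to t+1).
import numpy as np

def solve_block_tridiag(A, B, C, d):
    T = len(A)
    Cp = np.zeros_like(C)  # eliminated super-diagonal blocks
    dp = np.zeros_like(d)  # eliminated right-hand sides
    Cp[0] = np.linalg.solve(A[0], C[0])
    dp[0] = np.linalg.solve(A[0], d[0])
    for t in range(1, T):  # forward elimination, one block row at a time
        denom = A[t] - B[t] @ Cp[t - 1]
        if t < T - 1:
            Cp[t] = np.linalg.solve(denom, C[t])
        dp[t] = np.linalg.solve(denom, d[t] - B[t] @ dp[t - 1])
    x = np.zeros_like(d)
    x[-1] = dp[-1]
    for t in range(T - 2, -1, -1):  # back substitution
        x[t] = dp[t] - Cp[t] @ x[t + 1]
    return x

# Quick check against a dense solve on a random well-conditioned system.
rng = np.random.default_rng(0)
T, n = 5, 3
A = rng.normal(size=(T, n, n)) + 10 * np.eye(n)  # diagonally dominant
B = rng.normal(size=(T, n, n))
C = rng.normal(size=(T, n, n))
d = rng.normal(size=(T, n))
M = np.zeros((T * n, T * n))
for t in range(T):
    M[t*n:(t+1)*n, t*n:(t+1)*n] = A[t]
    if t > 0:
        M[t*n:(t+1)*n, (t-1)*n:t*n] = B[t]
    if t < T - 1:
        M[t*n:(t+1)*n, (t+1)*n:(t+2)*n] = C[t]
x = solve_block_tridiag(A, B, C, d)
assert np.allclose(M @ x.ravel(), d.ravel())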

Using Sparse MetaWorld environments

WARNING! By default, the Sparse MetaWorld environments output the dense reward; the sparse reward is in info['success']. A simple standalone environment that outputs the sparse reward directly is available here; please see the usage instructions at that link.
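
For example, a minimal wrapper that surfaces the sparse reward as the step reward might look like the sketch below. The class name and the gym-style step signature are assumptions; only the info['success'] field comes from this repo's documentation.

# Hedged sketch: expose info['success'] as the step reward.
class SparseRewardWrapper:
    def __init__(self, env):
        self._env = env

    def __getattr__(self, name):
        # Delegate everything else (action_space, render, ...) to the env.
        return getattr(self._env, name)

    def step(self, action):
        obs, dense_reward, done, info = self._env.step(action)
        return obs, float(info['success']), done, info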

Adding new environments

See wrappers.py, as well as envs/sparse_metaworld.py and envs/pointmass/pointmass_prob_env.py, for examples of how to add new environments. We follow a simple gym-like interface from the Dreamer repo, sketched below.
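
As a rough sketch, a new environment typically exposes dict-valued observations with an 'image' key, a gym-style action space, and gym-style reset/step methods. The class and helper names below are hypothetical; check wrappers.py and the existing envs for the exact contract this codebase expects.

# Hedged sketch of a Dreamer-style environment; names are illustrative.
import gym
import numpy as np

class MyTaskEnv:
    @property
    def action_space(self):
        return gym.spaces.Box(-1.0, 1.0, (4,), dtype=np.float32)

    def reset(self):
        return {'image': self._render()}

    def step(self, action):
        reward, done, info = 0.0, False, {'success': 0.0}
        return {'image': self._render()}, reward, done, info

    def _render(self):
        return np.zeros((64, 64, 3), dtype=np.uint8)  # placeholder frame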

Note that you can also train our agent entirely offline on an existing dataset of episodes. To do this, point --offline_dir at the dataset and set --pretrain to the desired number of training steps, as in the example below.
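
For example, a fully offline run on the Thermos data might look like the following (the --pretrain value and the logdir path are illustrative; the other flags mirror the Thermos command above):

python train.py --prefill 0 --action_repeat 2 --collect_sparse_reward True --use_sparse_reward True --task mw_SawyerStickPushEnvV2 --planning_horizon 50 --mpc_steps 25 --agent latco --logdir logdir/mw_thermos/latco_offline/0 --offline_dir logdir/mw_thermos/offline/episodes --pretrain 100000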

Bibtex

If you find this code useful, please cite:

@inproceedings{rybkin2021latco,
  title={Model-Based Reinforcement Learning via Latent-Space Collocation},
  author={Rybkin, Oleh and Zhu, Chuning and Nagabandi, Anusha and Daniilidis, Kostas and Mordatch, Igor and Levine, Sergey},
  booktitle={Proceedings of the 38th International Conference on Machine Learning},
  year={2021}
}

Acknowledgements

This codebase was built on top of Dreamer.
