Transferable Reinforcement Learning via Generalized Occupancy Models

Chuning Zhu¹, Xinqi Wang¹, Tyler Han¹, Simon Shaolei Du¹, Abhishek Gupta¹

¹University of Washington

This is a JAX implementation of Generalized Occupancy Models (GOMs). A GOM is an unsupervised reinforcement learning method that models the distribution of all possible outcomes, where an outcome is a discounted sum of state-dependent cumulants. The outcome model is paired with a readout policy that produces an action to realize a particular outcome. Assuming rewards depend linearly on cumulants, transferring to a downstream task reduces to fitting the reward weights by linear regression and solving a simple optimization problem for the best achievable outcome.
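To make the linearity assumption concrete, here is a minimal sketch in plain Python. It is not code from this repository: the helper names, the toy cumulants, the weights `w`, and the discount `gamma` are all made up for illustration. It shows that when each reward is a linear function of the cumulant, the discounted return of a trajectory equals the same linear function applied to its outcome (the discounted cumulant sum).

```python
def discounted_cumulant_sum(cumulants, gamma):
    """Return sum_t gamma^t * phi(s_t) for one trajectory of cumulant vectors."""
    dim = len(cumulants[0])
    outcome = [0.0] * dim
    discount = 1.0
    for phi in cumulants:
        for i in range(dim):
            outcome[i] += discount * phi[i]
        discount *= gamma
    return outcome


def dot(a, b):
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


# Toy trajectory: three states, each with a 2-dimensional cumulant (hypothetical values).
cumulants = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
w = [2.0, -1.0]  # hypothetical reward weights, r(s) = w . phi(s)
gamma = 0.5

# Outcome of the trajectory: the discounted sum of its cumulants.
psi = discounted_cumulant_sum(cumulants, gamma)

# Discounted return computed the usual way, from per-step rewards.
rewards = [dot(w, phi) for phi in cumulants]
discounted_return = sum(gamma**t * r for t, r in enumerate(rewards))

print(psi)                # [1.25, 0.75]
print(dot(w, psi))        # 1.75
print(discounted_return)  # 1.75
```

Because the return is linear in the outcome, comparing outcomes under a new reward only requires the weight vector, not a new rollout.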

Instructions

Setting up repo

git clone https://github.com/WEIRDLabUW/gom

Install Dependencies

pip install -r requirements.txt

D4RL Experiments

To train GOMs on D4RL datasets and adapt to the default tasks, run the following commands:

# Antmaze
python train.py env_id=antmaze-umaze-v2 exp_id=benchmark seed=0
python train.py env_id=antmaze-umaze-diverse-v2 exp_id=benchmark seed=0
python train.py env_id=antmaze-medium-diverse-v2 exp_id=benchmark seed=0
python train.py env_id=antmaze-medium-play-v2 exp_id=benchmark seed=0
python train.py env_id=antmaze-large-diverse-v2 exp_id=benchmark seed=0
python train.py env_id=antmaze-large-play-v2 exp_id=benchmark seed=0

# Kitchen
python train.py --config-name atrl_kitchen.yaml env_id=kitchen-partial-v0 exp_id=benchmark seed=0
python train.py --config-name atrl_kitchen.yaml env_id=kitchen-mixed-v0 exp_id=benchmark seed=0

To adapt a trained GOM to a new downstream reward, relabel the dataset with the new reward function (e.g. by adding an environment wrapper and modifying the dataset class), then run the following command, changing env_id to match:

python train_w.py env_id=antmaze-medium-diverse-v2 exp_id=benchmark seed=0

This will load the pretrained outcome model and readout policy and perform linear regression to fit the new rewards.
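The relabel-then-regress step can be sketched as follows. This is an illustration under stated assumptions, not the logic of train_w.py: `relabel_rewards`, `fit_reward_weights`, and the linear `true_w` reward are hypothetical, the dataset is a random stand-in using D4RL-style `observations` and `rewards` keys, and the raw observations stand in for the learned cumulants.

```python
import numpy as np


def relabel_rewards(dataset, reward_fn):
    """Return a shallow copy of the dataset with rewards recomputed by reward_fn."""
    relabeled = dict(dataset)
    relabeled["rewards"] = np.array(
        [reward_fn(obs) for obs in dataset["observations"]]
    )
    return relabeled


def fit_reward_weights(cumulants, rewards):
    """Least-squares fit of w in rewards ~= cumulants @ w."""
    w, *_ = np.linalg.lstsq(cumulants, rewards, rcond=None)
    return w


# Stand-in dataset (random observations, all-zero original rewards).
rng = np.random.default_rng(0)
dataset = {
    "observations": rng.normal(size=(256, 4)),
    "rewards": np.zeros(256),
}

# Hypothetical new reward, chosen to be exactly linear so the fit is exact.
true_w = np.array([1.0, -2.0, 0.5, 0.0])
new_dataset = relabel_rewards(dataset, lambda obs: float(obs @ true_w))

# Assumption: observations play the role of the cumulants phi(s).
phi = new_dataset["observations"]
w_hat = fit_reward_weights(phi, new_dataset["rewards"])
print(np.round(w_hat, 3))
```

In practice the new reward is rarely exactly linear in the cumulants, so the fitted weights are a least-squares approximation and the residual indicates how well the linearity assumption holds for that task.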

Preference antmaze experiments

To run the preference antmaze experiments, install D4RL with the custom antmaze environment from this repository. Then, download the accompanying dataset from this link and place it in data/d4rl under the project root directory. Run the following commands to train on each preference mode. Alternatively, train on only one mode and adapt to the other mode using the adaptation script.

# Go Up
python train.py env_id=multimodal-antmaze-0 exp_id=benchmark seed=0 planning.planner=random_shooting
# Go Right
python train.py env_id=multimodal-antmaze-1 exp_id=benchmark seed=0 planning.planner=random_shooting
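The commands above select planning.planner=random_shooting. As the term is generally used, random shooting samples a batch of candidates, scores each one, and keeps the best. The sketch below shows that generic idea only; it is not the planner in this repository. `sample_outcome`, `linear_score`, and the weights `w` are hypothetical stand-ins for sampling from a learned outcome model and scoring with fitted reward weights.

```python
import random


def random_shooting(sample_candidate, score, num_candidates, rng):
    """Sample num_candidates candidates and return the best one with its score."""
    best, best_score = None, float("-inf")
    for _ in range(num_candidates):
        candidate = sample_candidate(rng)
        candidate_score = score(candidate)
        if candidate_score > best_score:
            best, best_score = candidate, candidate_score
    return best, best_score


w = [1.0, 0.0]  # hypothetical reward weights


def sample_outcome(rng):
    """Hypothetical stand-in for sampling an outcome from a learned model."""
    return [rng.uniform(-1.0, 1.0), rng.uniform(-1.0, 1.0)]


def linear_score(outcome):
    """Score an outcome by its inner product with the reward weights."""
    return sum(wi * oi for wi, oi in zip(w, outcome))


rng = random.Random(0)
best_outcome, best_value = random_shooting(sample_outcome, linear_score, 1000, rng)
print(best_outcome, best_value)
```

Random shooting makes no assumptions about the shape of the candidate distribution, which suits settings where several distinct outcomes are achievable from the same state.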

Roboverse experiments

To run the roboverse experiments, download the roboverse dataset from this link and place the files in data/roboverse under the project root directory. Use one of the following commands to train a GOM.

python train.py --config-name atrl_roboverse.yaml env_id=roboverse-pickplace-v0 exp_id=benchmark seed=0
python train.py --config-name atrl_roboverse.yaml env_id=roboverse-doubledraweropen-v0 exp_id=benchmark seed=0
python train.py --config-name atrl_roboverse.yaml env_id=roboverse-doubledrawercloseopen-v0 exp_id=benchmark seed=0

Bibtex

If you find this code useful, please cite:

@article{zhu2024gom,
    author    = {Zhu, Chuning and Wang, Xinqi and Han, Tyler and Du, Simon Shaolei and Gupta, Abhishek},
    title     = {Transferable Reinforcement Learning via Generalized Occupancy Models},
    journal   = {arXiv preprint},
    year      = {2024},
}
