Home Page: https://sites.google.com/andrew.cmu.edu/roll

License: MIT License


ROLL: Visual Self-Supervised Reinforcement Learning with Object Reasoning

This is the official implementation of the paper "ROLL: Visual Self-Supervised Reinforcement Learning with Object Reasoning", published at the Conference on Robot Learning (CoRL) 2020. [arXiv], [Project Page]
Authors: Yufei Wang*, Gautham Narayan*, Xingyu Lin, Brian Okorn, David Held (* indicates equal contribution)

Instructions

  1. Create the conda environment
conda env create -f environment.yml
  2. Activate the conda environment
source prepare.sh
  3. Install our customized multiworld environments
cd multiworld
pip install .

(Note: this requires that you have MuJoCo installed on your system.)
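To quickly verify the install, you can try importing the package (a sanity check of ours, not a script from the repo):

# Sanity check: confirm the customized multiworld package is importable.
import multiworld
print("multiworld installed at:", multiworld.__file__)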

  4. Download the data we collected for pre-training the scene-VAE and the object-VAE/LSTM:
    https://drive.google.com/drive/folders/1oUYtB72Xn6n7okHgToY_onEk9NMIxfG4?usp=sharing
    Please download the entire data directory and move it inside this code repo, so that we have the following structure:
ROLL-release/data/local/goals
ROLL-release/data/local/pre-train-lstm
ROLL-release/data/local/pre-train-vae
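After moving the data, a quick layout check like the following (our own snippet, not from the repo; run it from the repo root) can confirm everything landed in place:

# Sanity check: verify the downloaded data directories exist.
import os
for path in ["data/local/goals", "data/local/pre-train-lstm", "data/local/pre-train-vae"]:
    print(path, "->", "OK" if os.path.isdir(path) else "MISSING")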
  5. Download the pre-trained U-Net models
    https://drive.google.com/drive/folders/17_WKs7FDLfyYEtsSJAIWbIuNYGspdQ2j?usp=sharing
    Please download all 4 models and move them to ROLL-release/segmentation/, i.e., we will have the following structure:
ROLL-release/segmentation/pytorchmodel_sawyer_door
ROLL-release/segmentation/pytorchmodel_sawyer_hurdle
ROLL-release/segmentation/pytorchmodel_sawyer_pickup
ROLL-release/segmentation/pytorchmodel_sawyer_push
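To check that the checkpoints downloaded intact, you can try deserializing one. This is a hedged sketch of ours; whether each file holds a full pickled model or a state dict depends on how it was saved:

# Sanity check: try deserializing a downloaded U-Net checkpoint (run from the repo root).
# Loading a fully pickled model may additionally require the repo's model classes to be importable.
import torch
ckpt = torch.load("segmentation/pytorchmodel_sawyer_hurdle", map_location="cpu")
print(type(ckpt))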

We also provide an example script for training a U-Net model from scratch on the hurdle-top environment.
Please refer to segmentation/ReadMe.md.

  6. Test the above steps:
python ROLL/launch_files/launch_lstm_sawyerhurdle.py

If the code runs correctly, two GIFs that visualize the learning process should be created shortly.

All the logs will be saved at data/local/debug/{exp-prefix-detailed-date}
Specifically, there will be:
- variant.json, which stores all the hyper-parameters;
- progress.csv, which stores all the logging information during training;
- itr_{epoch}.pkl and params.pkl, which store the trained policy, VAEs, environments, and all data necessary for rerunning the policy after training;
- various debugging plots/GIFs for analysis.
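Since progress.csv is a plain CSV file, you can inspect it with pandas. This is our own snippet, not part of the repo, and the exact column names depend on what the logger recorded:

# Inspect the training log: list the logged metrics and show the last few rows.
import pandas as pd
df = pd.read_csv("data/local/debug/{exp-prefix-detailed-date}/progress.csv")  # fill in your run's directory
print(df.columns.tolist())
print(df.tail())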

  7. Running ROLL:
python ROLL/launch_files/launch_ROLL_sawyerpush.py --no-debug # Puck Pushing
python ROLL/launch_files/launch_ROLL_sawyerhurdlemiddle.py --no-debug # Puck Pushing Hurdle-Bottom
python ROLL/launch_files/launch_ROLL_sawyerhurdle.py --no-debug # Puck Pushing Hurdle-Top
python ROLL/launch_files/launch_ROLL_sawyerdoor.py --no-debug # Door Opening
python ROLL/launch_files/launch_ROLL_sawyerpickup.py --no-debug # Object Pickup

All logs will be dumped at data/local/{exp-prefix-detailed-date}
You can use viskit data/local/{exp-prefix-detailed-date} to view the learning progress on a local port.

  8. Running Skewfit:
python examples/skewfit/launch_skewfit_sawyer_push.py --no-debug # Puck Pushing
python examples/skewfit/launch_skewfit_sawyer_push_hurdle.py --no-debug # Puck Pushing Hurdle-{Top, Bottom}
python examples/skewfit/launch_skewfit_sawyer_door.py --no-debug # Door Opening
python examples/skewfit/launch_sawyer_pickup.py --no-debug # Object Pickup

All logs will be dumped at data/local/{exp-prefix-detailed-date}.
You can use viskit data/local/{exp-prefix-detailed-date} to view the learning progress on a local port.

Pre-trained models

We provide a pre-trained model of ROLL for each task (each includes a pre-trained scene-VAE, object-VAE/LSTM, and a pre-trained policy). These should have been downloaded automatically along with the pre-training data in step 4.
The pre-trained models are at data/local/pre-trained-models/{env-name}/params.pkl, along with a GIF visualization of their performance.
To create the GIF yourself, simply run

python scripts/run_goal_conditioned_policy.py --dir {dir-to-env-params.pkl} --gpu

A GIF named visual.gif should soon be dumped in the same directory as the params.pkl file.
E.g., visuals for the trained model on the hurdle-bottom and hurdle-top push environments are shown below.
[GIFs: left, hurdle-bottom puck pushing; right, hurdle-top puck pushing]
In the GIFs:

  • The 1st row is the segmented goal image
  • The 2nd row is the image observation of the trained policy's execution
  • The 3rd row is the corresponding segmented object image
  • The 4th row is the scene-VAE reconstruction
  • The 5th row is the object-VAE reconstruction.
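Separately, if you only want to peek inside a snapshot rather than roll it out, something like the following works, assuming params.pkl was written with torch.save as in rlkit-style loggers (an assumption on our part):

# Peek inside a snapshot. Assumes a dict-style snapshot saved with torch.save;
# unpickling requires the repo's classes to be importable, so run from the repo root.
import torch
snapshot = torch.load("data/local/pre-trained-models/puck-push-hurdle-bottom/params.pkl", map_location="cpu")
print(sorted(snapshot.keys()))  # e.g. the policy, VAE, and env entries mentioned above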

Train ROLL with pre-trained scene-VAE, object-VAE and LSTM

The default behaviour of ROLL is to retrain the scene-VAE, object-VAE and LSTM from scratch. This can take a while to run, so we also provide the option of running ROLL with pre-trained VAE models. The pre-trained models are included at data/local/pre-trained-models/{env-name}/params.pkl.
To use pre-trained models, simply change the vae_path variable under skewfit_variant in the launch files. E.g., to run ROLL on Puck-Pushing-Hurdle-Bottom with pre-trained VAEs/LSTM, the vae_path variable should be set to data/local/pre-trained-models/puck-push-hurdle-bottom/ (see line 31 of ROLL/launch_files/launch_ROLL_sawyerhurdlemiddle.py), as in the excerpt below.
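For concreteness, here is a hypothetical excerpt of such a launch-file edit (every field other than vae_path is illustrative):

# Hypothetical launch-file excerpt: point skewfit_variant at the pre-trained models.
skewfit_variant = dict(
    # ... other hyper-parameters unchanged ...
    vae_path="data/local/pre-trained-models/puck-push-hurdle-bottom/",
)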

Change segmentation method

Ideally, ROLL should work with any segmentation code that removes the static background and the robot arm. As stated in the paper, in this work we mainly use OpenCV background subtraction and a U-Net for these two tasks; any other segmentation method that does the same should work.
We have tried to write the code in a modular way so that it is easy to swap in another segmentation method. If you want to change the segmentation method, there are a few places in the code you will need to change:

  • Implement your new segmentation method inside segmentation/segment_image.py. Say it is named new_segment_func (see the sketch after this list).
  • In ROLL/LSTM_wrapped_env.py, import new_segment_func and change self.segment_func to new_segment_func. In the function segment_obs, pass in the parameters required by your segmentation function.
  • If your new segmentation method does not need the OpenCV background subtraction, you can comment out the code from line 174 to line 202 in skewfit_full_experiments_LSTM.py -- that is the code that pre-trains the OpenCV background subtraction module.
    That being said, we have not fully tested whether the code works perfectly with another segmentation method. Feel free to open an issue or send an email to [email protected] if you encounter a bug or are unsure how exactly to implement this.
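As a concrete illustration, here is a minimal sketch of what new_segment_func could look like, assuming a naive background-subtraction scheme and an HxWx3 uint8 input image; the signature is our assumption, so match whatever segment_obs actually passes:

# segmentation/segment_image.py -- hypothetical drop-in segmentation function.
# The signature and the subtraction logic are our assumptions, not the repo's API.
import numpy as np

def new_segment_func(image, background=None, threshold=30, **kwargs):
    """Zero out pixels that match a stored same-shape background frame."""
    if background is None:
        return image  # no reference frame; nothing to subtract
    diff = np.abs(image.astype(np.int16) - background.astype(np.int16)).sum(axis=-1)
    mask = diff > threshold  # True on pixels that differ from the background
    return image * mask[..., None].astype(image.dtype)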

Citation

If you find this codebase useful in your research, please consider citing:

@inproceedings{corl2020roll,
 title={ROLL: Visual Self-Supervised Reinforcement Learning with Object Reasoning},
 author={Wang, Yufei and Narasimhan, Gautham and Lin, Xingyu and Okorn, Brian and Held, David},
 booktitle={Conference on Robot Learning},
 year={2020}
}
