pchalasani / rl-baselines3-zoo

This project forked from dlr-rm/rl-baselines3-zoo

A training framework for Stable Baselines3 reinforcement learning agents, with hyperparameter optimization and pre-trained agents included.

Home Page: https://stable-baselines3.readthedocs.io

License: MIT License

RL Baselines3 Zoo: A Training Framework for Stable Baselines3 Reinforcement Learning Agents

RL Baselines3 Zoo is a training framework for Reinforcement Learning (RL), using Stable Baselines3.

It provides scripts for training and evaluating agents, tuning hyperparameters, plotting results and recording videos.

In addition, it includes a collection of tuned hyperparameters for common environments and RL algorithms, and agents trained with those settings.

We are looking for contributors to complete the collection!

Goals of this repository:

  1. Provide a simple interface to train and enjoy RL agents
  2. Benchmark the different Reinforcement Learning algorithms
  3. Provide tuned hyperparameters for each environment and RL algorithm
  4. Have fun with the trained agents!

This is the SB3 version of the original SB2 rl-zoo.

Train an Agent

The hyperparameters for each environment are defined in hyperparams/algo_name.yml.

If the environment exists in this file, then you can train an agent using:

python train.py --algo algo_name --env env_id

For example (with tensorboard support):

python train.py --algo ppo --env CartPole-v1 --tensorboard-log /tmp/stable-baselines/

Evaluate the agent every 10000 steps using 10 episodes for evaluation (using only one evaluation env):

python train.py --algo sac --env HalfCheetahBulletEnv-v0 --eval-freq 10000 --eval-episodes 10 --n-eval-envs 1

Save a checkpoint of the agent every 100000 steps:

python train.py --algo td3 --env HalfCheetahBulletEnv-v0 --save-freq 100000

Continue training (here, load a pretrained agent for Breakout and continue training for 5000 steps):

python train.py --algo a2c --env BreakoutNoFrameskip-v4 -i rl-trained-agents/a2c/BreakoutNoFrameskip-v4_1/BreakoutNoFrameskip-v4.zip -n 5000

When using off-policy algorithms, you can also save the replay buffer after training:

python train.py --algo sac --env Pendulum-v1 --save-replay-buffer

It will be automatically loaded if present when continuing training.

Plot Scripts

Plot scripts (to be documented, see the "Results" sections in the SB3 documentation):

  • scripts/all_plots.py and scripts/plot_from_file.py for plotting evaluations
  • scripts/plot_train.py for plotting training reward/success

Examples (on the current collection)

Plot training success (y-axis) w.r.t. timesteps (x-axis) with a moving window of 500 episodes for all the Fetch environments with the HER algorithm:

python scripts/plot_train.py -a her -e Fetch -y success -f rl-trained-agents/ -w 500 -x steps

Plot evaluation reward curve for TQC, SAC and TD3 on the HalfCheetah and Ant PyBullet environments:

python3 scripts/all_plots.py -a sac td3 tqc --env HalfCheetahBullet AntBullet -f rl-trained-agents/
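The -w 500 option of plot_train.py smooths the curve with a moving window over episodes. The sketch below shows that kind of smoothing in plain Python, similar in spirit to the window_func helper in SB3's results plotter; the function name window_mean is hypothetical and the zoo's plotting code may differ in details (for example it works on NumPy arrays).

```python
from typing import List, Sequence, Tuple


def window_mean(
    x: Sequence[float], y: Sequence[float], window: int
) -> Tuple[List[float], List[float]]:
    """Moving average of y over `window` consecutive episodes.

    Each averaged value is aligned with the x value at the end of its window,
    so the output has len(y) - window + 1 points.
    """
    if window < 1 or window > len(y):
        raise ValueError("window must be between 1 and len(y)")
    running = sum(y[:window])
    means = [running / window]
    for i in range(window, len(y)):
        # Slide the window by one episode: add the newest value, drop the oldest
        running += y[i] - y[i - window]
        means.append(running / window)
    return list(x[window - 1:]), means
```

For example, window_mean([1, 2, 3, 4, 5], [0, 0, 1, 1, 1], 2) returns ([2, 3, 4, 5], [0.0, 0.5, 1.0, 1.0]): the first window - 1 episodes are consumed before the first smoothed point appears.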

Plot with the rliable library

The RL Zoo integrates some of the rliable library's features. You can find a visual explanation of the tools used by rliable in this blog post.

First, you need to install rliable.

Note: Python 3.7+ is required in that case.

Then export your results to a file using the all_plots.py script (see above):

python scripts/all_plots.py -a sac td3 tqc --env Half Ant -f logs/ -o logs/offpolicy

You can now use the plot_from_file.py script with --rliable, --versus and --iqm arguments:

python scripts/plot_from_file.py -i logs/offpolicy.pkl --skip-timesteps --rliable --versus -l SAC TD3 TQC

Note: you may need to edit plot_from_file.py (in particular the env_key_to_env_id dictionary) and scripts/score_normalization.py, which stores the min and max score for each environment.

Remark: plotting with the --rliable option is usually slow, as confidence intervals need to be computed using bootstrap sampling.
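To make the two notes above concrete, here is a minimal stdlib sketch of the three ingredients: min-max score normalization (assuming the usual (score - min) / (max - min) formula applied with the values stored in score_normalization.py), the interquartile mean (IQM) reported with --iqm, and a percentile bootstrap confidence interval. This is a plain, non-stratified bootstrap, simpler than the stratified bootstrap rliable uses; all function names here are hypothetical. The resampling loop is what makes the --rliable option slow.

```python
import random
import statistics
from typing import Callable, Sequence, Tuple


def normalize_score(score: float, min_score: float, max_score: float) -> float:
    """Min-max normalization: min_score maps to 0.0 and max_score maps to 1.0."""
    return (score - min_score) / (max_score - min_score)


def iqm(scores: Sequence[float]) -> float:
    """Interquartile mean: drop the lowest and highest 25% of runs, average the rest."""
    values = sorted(scores)
    cut = int(0.25 * len(values))
    return statistics.mean(values[cut:len(values) - cut])


def bootstrap_ci(
    scores: Sequence[float],
    statistic: Callable[[Sequence[float]], float],
    n_resamples: int = 2000,
    confidence: float = 0.95,
    seed: int = 0,
) -> Tuple[float, float]:
    """Percentile bootstrap confidence interval for `statistic`."""
    rng = random.Random(seed)
    # Each resample draws len(scores) runs with replacement and recomputes the statistic
    estimates = sorted(
        statistic(rng.choices(scores, k=len(scores))) for _ in range(n_resamples)
    )
    alpha = (1.0 - confidence) / 2.0
    low = estimates[int(alpha * (n_resamples - 1))]
    high = estimates[int((1.0 - alpha) * (n_resamples - 1))]
    return low, high
```

The IQM is less sensitive to a single outlier run than the mean: for scores [1, 2, 3, 4, 5, 6, 10, 100] it only averages [3, 4, 5, 6].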

Custom Environment

The easiest way to add support for a custom environment is to edit utils/import_envs.py and register your environment there. Then, you need to add a section for it in the hyperparameters file (hyperparams/algo.yml).

Enjoy a Trained Agent

Note: to download the repo with the trained agents, you must use git clone --recursive https://github.com/DLR-RM/rl-baselines3-zoo in order to clone the submodule too.

If the trained agent exists, then you can see it in action using:

python enjoy.py --algo algo_name --env env_id

For example, enjoy A2C on Breakout for 5000 timesteps:

python enjoy.py --algo a2c --env BreakoutNoFrameskip-v4 --folder rl-trained-agents/ -n 5000

If you have trained an agent yourself, you need to do:

# exp-id 0 corresponds to the last experiment, otherwise, you can specify another ID
python enjoy.py --algo algo_name --env env_id -f logs/ --exp-id 0

To load the best model (when using an evaluation environment):

python enjoy.py --algo algo_name --env env_id -f logs/ --exp-id 1 --load-best

To load a checkpoint (here the checkpoint name is rl_model_10000_steps.zip):

python enjoy.py --algo algo_name --env env_id -f logs/ --exp-id 1 --load-checkpoint 10000

To load the latest checkpoint:

python enjoy.py --algo algo_name --env env_id -f logs/ --exp-id 1 --load-last-checkpoint
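The options above rely on two naming conventions visible in this README: checkpoints are named rl_model_<steps>_steps.zip, and runs live in folders named <env_id>_<N> (as in rl-trained-agents/a2c/BreakoutNoFrameskip-v4_1/). The stdlib sketch below shows how the latest checkpoint and the latest experiment ID could be picked from those names; the helper names are hypothetical and the zoo's own lookup code may differ.

```python
import re
from typing import Iterable, Optional

CHECKPOINT_PATTERN = re.compile(r"^rl_model_(\d+)_steps\.zip$")


def checkpoint_name(n_steps: int) -> str:
    """File name of the checkpoint saved after n_steps, e.g. rl_model_10000_steps.zip."""
    return f"rl_model_{n_steps}_steps.zip"


def latest_checkpoint(filenames: Iterable[str]) -> Optional[str]:
    """Return the checkpoint file with the largest step count, or None if there is none."""
    best, best_steps = None, -1
    for name in filenames:
        match = CHECKPOINT_PATTERN.match(name)
        # Compare step counts as integers: as strings, "9000" would sort after "10000"
        if match and int(match.group(1)) > best_steps:
            best, best_steps = name, int(match.group(1))
    return best


def latest_exp_id(dirnames: Iterable[str], env_id: str) -> int:
    """Highest N among run folders named f"{env_id}_{N}" (0 if there are none)."""
    ids = [0]
    for name in dirnames:
        prefix, _, suffix = name.rpartition("_")
        if prefix == env_id and suffix.isdigit():
            ids.append(int(suffix))
    return max(ids)
```

For instance, latest_checkpoint(["rl_model_9000_steps.zip", "rl_model_10000_steps.zip", "best_model.zip"]) returns "rl_model_10000_steps.zip".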

Hugging Face Hub Integration

Upload a model to the hub (same syntax as for enjoy.py):

python -m utils.push_to_hub --algo ppo --env CartPole-v1 -f logs/ -orga sb3 -m "Initial commit"

You can choose a custom repo name (default: {algo}-{env_id}) by passing a --repo-name argument.

Download a model from the hub:

python -m utils.load_from_hub --algo ppo --env CartPole-v1 -f logs/ -orga sb3

Hyperparameter yaml syntax

The syntax used in hyperparams/algo_name.yml for setting hyperparameters (and, likewise, the syntax for overwriting hyperparameters on the command line) takes a special form when the argument is a function. See examples in the hyperparams/ directory. For example:

  • Specify a linear schedule for the learning rate:
      learning_rate: lin_0.012486195510232303
  • Specify a different activation function for the network:
      policy_kwargs: "dict(activation_fn=nn.ReLU)"
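A lin_<value> entry stands for a schedule that starts at <value> and decays linearly. The sketch below shows one way to turn such a string into a callable, assuming SB3's convention that schedules receive progress_remaining, which goes from 1.0 at the start of training to 0.0 at the end. The helper names linear_schedule and parse_schedule are illustrative; check the zoo's utils for its actual implementation.

```python
from typing import Callable, Union

Schedule = Callable[[float], float]


def linear_schedule(initial_value: float) -> Schedule:
    """Decay linearly from initial_value (start of training) to 0 (end of training)."""

    def func(progress_remaining: float) -> float:
        # SB3 passes progress_remaining, which goes from 1.0 (start) to 0.0 (end)
        return progress_remaining * initial_value

    return func


def parse_schedule(value: Union[str, float]) -> Union[float, Schedule]:
    """Turn a "lin_<value>" string into a linear schedule; pass constants through."""
    if isinstance(value, str) and value.startswith("lin_"):
        return linear_schedule(float(value[len("lin_"):]))
    return value
```

With learning_rate: lin_0.5, the learning rate would be 0.5 at the start, 0.25 halfway through training and 0.0 at the end.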

Hyperparameter Tuning

We use Optuna for optimizing the hyperparameters. Not all hyperparameters are tuned, and tuning enforces certain default hyperparameter settings that may be different from the official defaults. See utils/hyperparams_opt.py for the current settings for each agent.

Hyperparameters not specified in utils/hyperparams_opt.py are taken from the associated YAML file and fall back to the default values of SB3 if not present.

Note: when using SuccessiveHalvingPruner ("halving"), you must specify --n-jobs > 1.

Budget of 1000 trials with a maximum of 50000 steps:

python train.py --algo ppo --env MountainCar-v0 -n 50000 -optimize --n-trials 1000 --n-jobs 2 \
  --sampler tpe --pruner median

Distributed optimization using a shared database is also possible (see the corresponding Optuna documentation):

python train.py --algo ppo --env MountainCar-v0 -optimize --study-name test --storage sqlite:///example.db

Print and save best hyperparameters of an Optuna study:

python scripts/parse_study.py -i path/to/study.pkl --print-n-best-trials 10 --save-n-best-hyperparameters 10

The default budget for hyperparameter tuning is 500 trials and there is one intermediate evaluation for pruning/early stopping per 100k time steps.
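The arithmetic behind "one intermediate evaluation per 100k time steps", together with the basic idea of the median pruner (--pruner median), can be sketched as follows. This is a simplification, not Optuna's implementation: Optuna's MedianPruner also has warm-up and start-up settings, and the zoo may compute the number of evaluations slightly differently. All function names are hypothetical.

```python
import statistics
from typing import Sequence


def n_evaluations(n_timesteps: int, steps_per_eval: int = 100_000) -> int:
    """One intermediate evaluation per steps_per_eval steps, but at least one."""
    return max(1, n_timesteps // steps_per_eval)


def eval_freq(n_timesteps: int, n_evals: int) -> int:
    """Number of training steps between two intermediate evaluations."""
    return n_timesteps // n_evals


def median_prune(current: float, others_at_same_step: Sequence[float]) -> bool:
    """Simplified median rule: stop a trial whose intermediate reward is below
    the median of the other trials' rewards at the same evaluation step."""
    if not others_at_same_step:
        return False
    return current < statistics.median(others_at_same_step)
```

Under this rule, the MountainCar-v0 command above (-n 50000) gets a single intermediate evaluation, while a 1M-step budget gets 10 evaluations, one every 100000 steps.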

Hyperparameters search space

Note that the default hyperparameters used in the zoo when tuning are not always the same as the defaults provided in stable-baselines3. Consult the latest source code to be sure of these settings. For example:

  • PPO tuning assumes a network architecture with ortho_init = False, though it is True by default. You can change that by updating utils/hyperparams_opt.py.

  • Non-episodic rollout in TD3 and DDPG assumes gradient_steps = train_freq and so tunes only train_freq to reduce the search space.

When working with continuous actions, we recommend enabling gSDE by uncommenting lines in utils/hyperparams_opt.py.

Experiment tracking

We support tracking experiment data such as learning curves and hyperparameters via Weights and Biases.

The following command

python train.py --algo ppo --env CartPole-v1 --track --wandb-project-name sb3

yields a tracked experiment at this URL.

Env normalization

In the hyperparameter file, normalize: True means that the training environment will be wrapped in a VecNormalize wrapper.

Normalization uses the default parameters of VecNormalize, with the exception of gamma, which is set to match that of the agent. These defaults can be overridden in the corresponding hyperparams/algo_name.yml file, e.g.:

  normalize: "{'norm_obs': True, 'norm_reward': False}"
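The rule above (defaults, optional overrides from a dict string, gamma taken from the agent) can be sketched as a small preprocessing step. This is a hypothetical helper, not the zoo's code: it parses the string with ast.literal_eval, and it assumes the agent's gamma, when set, takes precedence over any gamma in the override dict. The resulting kwargs would then be passed to VecNormalize(venv, **kwargs).

```python
import ast
from typing import Any, Dict, Optional, Tuple, Union


def preprocess_normalize(
    normalize: Union[bool, str], agent_gamma: Optional[float]
) -> Tuple[bool, Dict[str, Any]]:
    """Return (use_vec_normalize, kwargs_for_VecNormalize)."""
    kwargs: Dict[str, Any] = {}
    if isinstance(normalize, str):
        # e.g. "{'norm_obs': True, 'norm_reward': False}" becomes a real dict
        kwargs = ast.literal_eval(normalize)
        normalize = True
    if normalize and agent_gamma is not None:
        # Discount rewards with the same gamma as the agent
        kwargs["gamma"] = agent_gamma
    return bool(normalize), kwargs
```

With normalize: True and an agent gamma of 0.99, this yields (True, {"gamma": 0.99}), so every other VecNormalize argument keeps its default.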

Env Wrappers

You can specify in the hyperparameter config one or more wrappers to use around the environment.

For one wrapper:

env_wrapper: gym_minigrid.wrappers.FlatObsWrapper

For multiple wrappers, specify a list:

env_wrapper:
    - utils.wrappers.DoneOnSuccessWrapper:
        reward_offset: 1.0
    - sb3_contrib.common.wrappers.TimeFeatureWrapper

Note that you can easily specify parameters too.
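Each entry is either a dotted import path or a one-key mapping from a dotted path to keyword arguments, and the wrappers are applied in list order. The sketch below resolves such a spec with importlib; make_wrapper_fn is a hypothetical name and the zoo's own helper may differ. Real entries would be gym wrappers; the usage example swaps in two stdlib functions from textwrap purely so the sketch runs without gym installed.

```python
import importlib
from typing import Any, Callable, Dict, List, Union

Entry = Union[str, Dict[str, Dict[str, Any]]]


def import_from_path(path: str) -> Callable[..., Any]:
    """Import "package.module.ClassName" and return ClassName."""
    module_name, _, attr_name = path.rpartition(".")
    return getattr(importlib.import_module(module_name), attr_name)


def make_wrapper_fn(spec: Union[Entry, List[Entry]]) -> Callable[[Any], Any]:
    """Build a function that applies the configured wrappers in list order."""
    entries = spec if isinstance(spec, list) else [spec]
    wrappers = []
    for entry in entries:
        if isinstance(entry, dict):
            # {"module.Wrapper": {"kwarg": value}}: a single key holding the kwargs
            path, kwargs = next(iter(entry.items()))
        else:
            path, kwargs = entry, {}
        wrappers.append((import_from_path(path), kwargs or {}))

    def wrap(env: Any) -> Any:
        for wrapper_cls, kwargs in wrappers:
            env = wrapper_cls(env, **kwargs)
        return env

    return wrap


# Stand-in usage: "textwrap.indent" receives prefix="> " just as
# DoneOnSuccessWrapper would receive reward_offset=1.0
wrap = make_wrapper_fn(["textwrap.dedent", {"textwrap.indent": {"prefix": "> "}}])
```

Here wrap("  a\n  b\n") first dedents the text and then indents it, returning "> a\n> b\n". The callback list shown further below follows the same one-key-mapping convention.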

VecEnvWrapper

You can specify which VecEnvWrapper to use in the config, the same way as for env wrappers (see above), using the vec_env_wrapper key. For instance:

vec_env_wrapper: stable_baselines3.common.vec_env.VecMonitor

Note: VecNormalize is supported separately using the normalize keyword, and VecFrameStack has a dedicated keyword, frame_stack.

Callbacks

Following the same syntax as env wrappers, you can also add custom callbacks to use during training.

callback:
  - utils.callbacks.ParallelTrainCallback:
      gradient_steps: 256

Env keyword arguments

You can specify keyword arguments to pass to the env constructor in the command line, using --env-kwargs:

python enjoy.py --algo ppo --env MountainCar-v0 --env-kwargs goal_velocity:10

Overwrite hyperparameters

You can easily overwrite hyperparameters in the command line, using --hyperparams:

python train.py --algo a2c --env MountainCarContinuous-v0 --hyperparams learning_rate:0.001 policy_kwargs:"dict(net_arch=[64, 64])"

Note: if you want to pass a string, you need to quote it like this: my_string:"'value'"
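The quoting rule makes sense if each key:value pair passed to --hyperparams or --env-kwargs is split on the first colon and the value is evaluated as a Python expression: a bare value then reads as an undefined name, while 'value' is a string literal. The sketch below illustrates that assumption with a restricted eval that only exposes dict; the zoo's own argument parser may differ in details (for example in which names, such as nn, are available).

```python
from typing import Any, Dict, Iterable


def parse_key_values(arguments: Iterable[str]) -> Dict[str, Any]:
    """Parse ["key:value", ...] into a dict, evaluating each value as Python."""
    result: Dict[str, Any] = {}
    for argument in arguments:
        # Split on the first ":" only, so values may contain ":" themselves
        key, _, value = argument.partition(":")
        # Restricted namespace: no builtins except dict()
        result[key] = eval(value, {"__builtins__": {}, "dict": dict})
    return result
```

For example, ["learning_rate:0.001", "policy_kwargs:dict(net_arch=[64, 64])", "my_string:'value'"] parses to {"learning_rate": 0.001, "policy_kwargs": {"net_arch": [64, 64]}, "my_string": "value"}, whereas "my_string:value" raises a NameError.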

Record a Video of a Trained Agent

Record 1000 steps with the latest saved model:

python -m utils.record_video --algo ppo --env BipedalWalkerHardcore-v3 -n 1000

Use the best saved model instead:

python -m utils.record_video --algo ppo --env BipedalWalkerHardcore-v3 -n 1000 --load-best

Record a video of a checkpoint saved during training (here the checkpoint name is rl_model_10000_steps.zip):

python -m utils.record_video --algo ppo --env BipedalWalkerHardcore-v3 -n 1000 --load-checkpoint 10000

Record a Video of a Training Experiment

Apart from recording videos of specific saved models, it is also possible to record a video of a training experiment where checkpoints have been saved.

Record 1000 steps for each checkpoint, latest and best saved models:

python -m utils.record_training --algo ppo --env CartPole-v1 -n 1000 -f logs --deterministic

The previous command will create an mp4 file. To convert this file to gif format as well:

python -m utils.record_training --algo ppo --env CartPole-v1 -n 1000 -f logs --deterministic --gif

Current Collection: 195+ Trained Agents!

Final performance of the trained agents can be found in benchmark.md. To compute them, simply run python -m utils.benchmark.

A list and videos of the trained agents can be found on our Hugging Face page: https://huggingface.co/sb3

NOTE: this is not a quantitative benchmark as it corresponds to only one run (cf issue #38). This benchmark is meant to check algorithm (maximal) performance, find potential bugs and also allow users to have access to pretrained agents.

Atari Games

Seven Atari games from the OpenAI benchmark (NoFrameskip-v4 versions).

RL Algo BeamRider Breakout Enduro Pong Qbert Seaquest SpaceInvaders
A2C ✔️ ✔️ ✔️ ✔️ ✔️ ✔️ ✔️
PPO ✔️ ✔️ ✔️ ✔️ ✔️ ✔️ ✔️
DQN ✔️ ✔️ ✔️ ✔️ ✔️ ✔️ ✔️
QR-DQN ✔️ ✔️ ✔️ ✔️ ✔️ ✔️ ✔️

Additional Atari Games (to be completed):

RL Algo MsPacman Asteroids RoadRunner
A2C ✔️ ✔️ ✔️
PPO ✔️ ✔️ ✔️
DQN ✔️ ✔️ ✔️
QR-DQN ✔️ ✔️ ✔️

Classic Control Environments

RL Algo CartPole-v1 MountainCar-v0 Acrobot-v1 Pendulum-v1 MountainCarContinuous-v0
ARS ✔️ ✔️ ✔️ ✔️ ✔️
A2C ✔️ ✔️ ✔️ ✔️ ✔️
PPO ✔️ ✔️ ✔️ ✔️ ✔️
DQN ✔️ ✔️ ✔️ N/A N/A
QR-DQN ✔️ ✔️ ✔️ N/A N/A
DDPG N/A N/A N/A ✔️ ✔️
SAC N/A N/A N/A ✔️ ✔️
TD3 N/A N/A N/A ✔️ ✔️
TQC N/A N/A N/A ✔️ ✔️
TRPO ✔️ ✔️ ✔️ ✔️ ✔️

Box2D Environments

RL Algo BipedalWalker-v3 LunarLander-v2 LunarLanderContinuous-v2 BipedalWalkerHardcore-v3 CarRacing-v0
ARS ✔️ ✔️
A2C ✔️ ✔️ ✔️ ✔️
PPO ✔️ ✔️ ✔️ ✔️
DQN N/A ✔️ N/A N/A N/A
QR-DQN N/A ✔️ N/A N/A N/A
DDPG ✔️ N/A ✔️
SAC ✔️ N/A ✔️ ✔️
TD3 ✔️ N/A ✔️ ✔️
TQC ✔️ N/A ✔️ ✔️
TRPO ✔️ ✔️

PyBullet Environments

See https://github.com/bulletphysics/bullet3/tree/master/examples/pybullet/gym/pybullet_envs. Similar to the MuJoCo envs but with a free, easy-to-install simulator: pybullet (note that MuJoCo 2.1.0+ is now free as well). We are using the BulletEnv-v0 versions.

Note: these environments are derived from Roboschool and are harder than the MuJoCo versions (see Pybullet issue).

RL Algo Walker2D HalfCheetah Ant Reacher Hopper Humanoid
ARS
A2C ✔️ ✔️ ✔️ ✔️ ✔️
PPO ✔️ ✔️ ✔️ ✔️ ✔️
DDPG ✔️ ✔️ ✔️ ✔️ ✔️
SAC ✔️ ✔️ ✔️ ✔️ ✔️
TD3 ✔️ ✔️ ✔️ ✔️ ✔️
TQC ✔️ ✔️ ✔️ ✔️ ✔️
TRPO ✔️ ✔️ ✔️ ✔️ ✔️

PyBullet Envs (Continued)

RL Algo Minitaur MinitaurDuck InvertedDoublePendulum InvertedPendulumSwingup
A2C
PPO
DDPG
SAC
TD3
TQC

MuJoCo Environments

RL Algo Walker2d HalfCheetah Ant Swimmer Hopper Humanoid
ARS ✔️ ✔️ ✔️ ✔️ ✔️
A2C ✔️ ✔️ ✔️ ✔️ ✔️ ✔️
PPO ✔️ ✔️ ✔️ ✔️ ✔️
DDPG
SAC ✔️ ✔️ ✔️ ✔️ ✔️ ✔️
TD3 ✔️ ✔️ ✔️ ✔️ ✔️ ✔️
TQC ✔️ ✔️ ✔️ ✔️ ✔️ ✔️
TRPO ✔️ ✔️ ✔️ ✔️ ✔️

Robotics Environments

See https://gym.openai.com/envs/#robotics and DLR-RM#71

MuJoCo version: 1.50.1.0 Gym version: 0.18.0

We used the v1 environments.

RL Algo FetchReach FetchPickAndPlace FetchPush FetchSlide
HER+TQC ✔️ ✔️ ✔️ ✔️

Panda robot Environments

See https://github.com/qgallouedec/panda-gym/.

Similar to the MuJoCo Robotics envs but with a free, easy-to-install simulator: pybullet.

We used the v1 environments.

RL Algo PandaReach PandaPickAndPlace PandaPush PandaSlide PandaStack
HER+TQC ✔️ ✔️ ✔️ ✔️ ✔️

To visualize the result, you can pass --env-kwargs render:True to the enjoy script.

MiniGrid Envs

See https://github.com/maximecb/gym-minigrid: a simple, lightweight and fast implementation of the famous gridworld as Gym environments.

RL Algo Empty FourRooms DoorKey MultiRoom Fetch
A2C
PPO
DDPG
SAC
TRPO

There are 19 environment groups (variations for each) in total.

Note that you need to specify --gym-packages gym_minigrid with enjoy.py and train.py, as it is not a standard Gym environment; you also need to install the custom Gym package or put it in your Python path.

pip install gym-minigrid
python train.py --algo ppo --env MiniGrid-DoorKey-5x5-v0 --gym-packages gym_minigrid

This does the same thing as:

import gym_minigrid
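In other words, --gym-packages presumably just imports each listed module before the environment is created, so that the package's module-level environment registrations run. A minimal sketch of that step, with a hypothetical helper name, is:

```python
import importlib
from types import ModuleType
from typing import List, Sequence


def import_gym_packages(packages: Sequence[str]) -> List[ModuleType]:
    """Import each package so that its module-level env registrations run."""
    return [importlib.import_module(name) for name in packages]
```

Calling import_gym_packages(["gym_minigrid"]) has the same effect as the import statement above; the test uses stdlib modules so it runs without gym_minigrid installed.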

Colab Notebook: Try it Online!

You can train agents online using a Colab notebook.

Installation

Stable-Baselines3 PyPi Package

We recommend using the master versions of stable-baselines3 and sb3_contrib.

apt-get install swig cmake ffmpeg
pip install -r requirements.txt

Please see Stable Baselines3 documentation for alternatives.

Docker Images

Build docker image (CPU):

make docker-cpu

GPU:

USE_GPU=True make docker-gpu

Pull built docker image (CPU):

docker pull stablebaselines/rl-baselines3-zoo-cpu

GPU image:

docker pull stablebaselines/rl-baselines3-zoo

Run script in the docker image:

./scripts/run_docker_cpu.sh python train.py --algo ppo --env CartPole-v1

Tests

To run tests, first install pytest, then:

make pytest

Same for type checking with pytype:

make type

Citing the Project

To cite this repository in publications:

@misc{rl-zoo3,
  author = {Raffin, Antonin},
  title = {RL Baselines3 Zoo},
  year = {2020},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/DLR-RM/rl-baselines3-zoo}},
}

Contributing

If you trained an agent that is not present in the RL Zoo, please submit a Pull Request (containing the hyperparameters and the score too).

Contributors

We would like to thank our contributors: @iandanforth, @tatsubori, @Shade5, @mcres, @ernestum
