
Posterior Sampling for Deep Reinforcement Learning

Implementation of Posterior Sampling for Deep Reinforcement Learning (PSDRL) in PyTorch.

@inproceedings{sasso2023posterior,
  title     = {Posterior Sampling for Deep Reinforcement Learning},
  author    = {Sasso, Remo and Conserva, Michelangelo and Rauber, Paulo},
  booktitle = {International Conference on Machine Learning},
  year      = {2023}
}

Overview

PSDRL is the first truly scalable approximation of Posterior Sampling for Reinforcement Learning (PSRL) that retains its model-based essence. On the Atari benchmark, PSDRL significantly outperforms previous state-of-the-art attempts at scaling up posterior sampling, such as Bootstrapped DQN + Priors and Successor Uncertainties, while being strongly competitive with the state-of-the-art DreamerV2 agent in both sample efficiency and computational efficiency.

PSDRL maps high-dimensional observations to a low-dimensional continuous latent state using an autoencoder (a) that enables predicting transitions in latent state space for any given action using a recurrent transition model (b).

Continuous Latent Space Transition Model
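Schematically, the two components can be sketched as follows (a toy NumPy illustration; the dimensions, parameter names, and the simple recurrent update are assumptions for exposition, not the repository's actual architecture):

```python
import numpy as np

rng = np.random.default_rng(0)
latent_dim, hidden_dim, n_actions = 32, 64, 4

# (a) Autoencoder (encoder half): map a 64x64 grayscale observation
# to a low-dimensional continuous latent state.
W_enc = rng.normal(scale=0.01, size=(latent_dim, 64 * 64))

def encode(obs):  # obs: (64, 64) array in [0, 1]
    return np.tanh(W_enc @ obs.ravel())

# (b) Recurrent transition model: predict the next latent state from the
# current latent state, a one-hot action, and a recurrent hidden state.
W_h = rng.normal(scale=0.1, size=(hidden_dim, hidden_dim + latent_dim + n_actions))
W_out = rng.normal(scale=0.1, size=(latent_dim, hidden_dim))

def transition(z, a, h):
    one_hot = np.eye(n_actions)[a]
    h_next = np.tanh(W_h @ np.concatenate([h, z, one_hot]))
    return W_out @ h_next, h_next  # predicted next latent, new hidden state

z = encode(rng.uniform(size=(64, 64)))
z_next, h = transition(z, a=2, h=np.zeros(hidden_dim))
```

Rolling `transition` forward from a single encoded state yields imagined trajectories entirely in latent space, which is what makes planning cheap.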

PSDRL represents uncertainty through a Bayesian neural network that maintains a distribution over the parameters of the last layer of the transition model, which allows PSDRL to sample a model of the environment. Planning w.r.t. the sampled model is carried out with a value network that is fitted using predictions from the sampled model, thereby approximating the optimal policy w.r.t. the sampled model. The agent then collects data by acting greedily w.r.t. the current sampled model and value network.
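The last-layer distribution can be illustrated with conjugate Bayesian linear regression, which admits exact posterior sampling (a sketch under an assumed unit Gaussian prior and fixed noise variance, not the repository's exact implementation):

```python
import numpy as np

rng = np.random.default_rng(0)

# Features produced by the shared transition-network body for N transitions.
N, D = 200, 16
Phi = rng.normal(size=(N, D))
true_w = rng.normal(size=D)
y = Phi @ true_w + 0.1 * rng.normal(size=N)

# Gaussian posterior over the last-layer weights, with prior w ~ N(0, I)
# and observation noise variance sigma2 (both assumed here).
sigma2 = 0.01
precision = np.eye(D) + Phi.T @ Phi / sigma2
cov = np.linalg.inv(precision)
mean = cov @ Phi.T @ y / sigma2

# Each posterior sample corresponds to one sampled model of the environment.
w_sample = rng.multivariate_normal(mean, cov)
```

Drawing a fresh `w_sample` at each sampling step gives a different plausible transition model, which the value network is then fitted against.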

By acting greedily w.r.t. different sampled models, the exploration of the agent is naturally driven by uncertainty over models of the environment. An example of trajectories predicted with different sampled models can be found below. Although each trajectory starts from the same initial state and uses identical parameters for the neural network components, the sampled models produce remarkably diverse predictions.

For further details, results, and comparisons see the research paper.

Instructions

Install the dependencies:

pip install -r requirements.txt

You can run the PSDRL agent by calling the main.py file, which accepts a configuration file (in YAML format) and the name of the Atari game. For example, you can run the PSDRL agent on Pong with the parameters from the paper as follows:

python src/main.py --config="src/config.yaml" --env="Pong"

You can set a fixed seed with an additional parameter, e.g. --seed 42.

Training can be monitored with TensorBoard:

tensorboard --logdir=src/logdir

Environments

The repository includes the Atari games. If you wish to test the algorithm on different environments, you can add them to the init_env function in the utils.py file.
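An extension might look like the following sketch (all names here, including the init_env signature, the ATARI_GAMES set, and MyGridWorld, are illustrative stubs; check the actual code in utils.py):

```python
# Hypothetical sketch of extending init_env in utils.py: the existing
# function dispatches on the Atari game name, and a new environment is
# added as one more branch. All names below are illustrative stubs.
ATARI_GAMES = {"Pong", "Qbert", "Freeway"}

class MyGridWorld:
    """Stub for a custom environment with a gym-like interface."""
    def reset(self):
        # Return a 64x64 observation, matching what the agent expects.
        return [[0.0] * 64 for _ in range(64)]

def make_atari(name):
    raise NotImplementedError("stands in for the repository's Atari setup")

def init_env(env_name):
    if env_name in ATARI_GAMES:
        return make_atari(env_name)
    if env_name == "MyGridWorld":  # new branch for the custom environment
        return MyGridWorld()
    raise ValueError(f"Unknown environment: {env_name}")

env = init_env("MyGridWorld")
```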

The implementation targets environments with visual observations that are grayscale, scaled to the range [0, 1], and of dimension 64x64, so please take this into account when using new environments (see preprocess_image in utils.py, which the agent applies to all inputs).
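A minimal preprocessing function matching these expectations could look like this (the repository's preprocess_image may differ, e.g. in the grayscale conversion and resize method):

```python
import numpy as np

def preprocess(frame):
    """Convert an RGB uint8 frame of shape (H, W, 3) to a grayscale
    64x64 float array in [0, 1], as the agent expects."""
    gray = frame.mean(axis=2)  # naive grayscale via channel averaging
    h, w = gray.shape
    # Nearest-neighbour resize to 64x64; a real implementation might use
    # cv2.resize or torchvision transforms instead.
    rows = np.arange(64) * h // 64
    cols = np.arange(64) * w // 64
    small = gray[rows][:, cols]
    return small / 255.0

# Example: a random Atari-sized frame (210x160 RGB).
frame = np.random.randint(0, 256, size=(210, 160, 3), dtype=np.uint8)
obs = preprocess(frame)
```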

If you wish to test the algorithm on environments with vector-valued observations, you can either implement a different architecture for the autoencoder (see representation.py) or remove the autoencoder altogether.

Feel free to reach out if you need any help.

Runtime

The implementation found in this repository runs on a single GPU and takes about 8 and 15 hours per 1M environment steps in Atari on an NVIDIA A100 and NVIDIA V100 GPU, respectively. The table below shows the expected runtime for an A100 GPU.

| Game | Runtime |
| --- | --- |
| Freeway | 8h53m $\pm$ 0m |
| Qbert | 7h39m $\pm$ 43m |
| Enduro | 9h35m $\pm$ 15m |
| Asterix | 7h44m $\pm$ 24m |
| Seaquest | 8h31m $\pm$ 2m |
| Pong | 7h58m $\pm$ 42m |
| Hero | 9h26m $\pm$ 2m |
| Average | 8h31m $\pm$ 44m |

See Appendix E of the paper for a comparison with the baselines.

Tips

You can track additional metrics by calling the add_scalars function of the Logger object.
