This project forked from lcswillems/torch-ac

Recurrent and multi-process PyTorch implementation of deep reinforcement Actor-Critic algorithms A2C and PPO

License: MIT License

Python 100.00%

torch-ac's Introduction

PyTorch Actor-Critic deep reinforcement learning algorithms: A2C and PPO

The torch_ac package contains PyTorch implementations of two Actor-Critic deep reinforcement learning algorithms, A2C and PPO.

Note: An example of how to use this package is given in the rl-starter-files repository. More details below.

Features

  • Recurrent policies
  • Reward shaping
  • Handles observation spaces that are tensors or dicts of tensors
  • Handles discrete action spaces
  • Observation preprocessing
  • Multiprocessing
  • CUDA

Installation

pip3 install torch-ac

Note: If you want to modify the torch-ac algorithms, install an editable clone instead:

git clone https://github.com/lcswillems/torch-ac.git
cd torch-ac
pip3 install -e .

Package components overview

A brief overview of the components of the package:

  • torch_ac.A2CAlgo and torch_ac.PPOAlgo classes for A2C and PPO algorithms
  • torch_ac.ACModel and torch_ac.RecurrentACModel abstract classes for non-recurrent and recurrent actor-critic models
  • torch_ac.DictList class for making dictionaries of lists list-indexable and hence batch-friendly

Package components details

The most important components of the package are detailed here.

torch_ac.A2CAlgo and torch_ac.PPOAlgo have 2 methods:

  • __init__, which may take, among other parameters:
    • acmodel, an actor-critic model, i.e. an instance of a class inheriting from either torch_ac.ACModel or torch_ac.RecurrentACModel.
    • preprocess_obss, a function that transforms a list of observations into a list-indexable object X (e.g. a PyTorch tensor). The default preprocess_obss function converts observations into a PyTorch tensor.
    • reshape_reward, a function that takes an observation obs, the action action taken, the reward reward received and the terminal status done as parameters, and returns a new reward. By default, the reward is not reshaped.
    • recurrence, the number of timesteps over which the gradient is backpropagated. This number is only taken into account if a recurrent model is used; it must divide the num_frames_per_agent parameter and, for PPO, the batch_size parameter.
  • update_parameters, which first collects experiences, then updates the parameters, and finally returns logs.
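
As a minimal sketch of the reshape_reward and recurrence parameters described above, the snippet below defines a reward-reshaping function with the stated signature and a small check of the divisibility constraint. The scaling factor, the helper name check_recurrence and the numbers used are illustrative assumptions, not part of torch_ac.

```python
def reshape_reward(obs, action, reward, done):
    """Return a new reward; here the raw reward is simply scaled (illustrative choice)."""
    return 20 * reward


def check_recurrence(recurrence, num_frames_per_agent, batch_size=None):
    """Hypothetical helper: True if `recurrence` satisfies the constraints stated above."""
    # recurrence must divide the number of frames collected per agent.
    if num_frames_per_agent % recurrence != 0:
        return False
    # For PPO, recurrence must also divide batch_size.
    if batch_size is not None and batch_size % recurrence != 0:
        return False
    return True


print(reshape_reward(obs=None, action=0, reward=0.5, done=False))  # 10.0
print(check_recurrence(4, num_frames_per_agent=128, batch_size=256))  # True
print(check_recurrence(5, num_frames_per_agent=128))  # False
```

A function like reshape_reward would be passed to the algorithm's __init__ through the reshape_reward parameter.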

torch_ac.ACModel has 2 abstract methods:

  • __init__, which takes an observation_space and an action_space as parameters.
  • forward, which takes N preprocessed observations obs and returns a PyTorch distribution dist and a tensor of values value. The tensor of values must be of size N, not N x 1.
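
The shape requirement on value can be illustrated with a small sketch. The class name TinyACModel, the layer sizes and the use of plain integers instead of observation_space and action_space are assumptions made to keep the example self-contained; in real use the class would also inherit from torch_ac.ACModel.

```python
import torch
import torch.nn as nn
from torch.distributions import Categorical


class TinyACModel(nn.Module):
    """Sketch of a non-recurrent actor-critic model (hypothetical name and sizes)."""

    def __init__(self, obs_dim, num_actions, hidden=64):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(obs_dim, hidden), nn.Tanh())
        self.actor = nn.Linear(hidden, num_actions)
        self.critic = nn.Linear(hidden, 1)

    def forward(self, obs):
        x = self.body(obs)
        # Discrete action space: a categorical distribution over action logits.
        dist = Categorical(logits=self.actor(x))
        # The critic outputs N x 1; squeeze the last dimension to get size N.
        value = self.critic(x).squeeze(1)
        return dist, value


model = TinyACModel(obs_dim=8, num_actions=3)
dist, value = model(torch.zeros(5, 8))
print(value.shape)  # torch.Size([5])
```

Without the squeeze, value would have size N x 1, which is the shape the description above rules out.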

torch_ac.RecurrentACModel has 3 abstract methods:

  • __init__, which takes the same parameters as torch_ac.ACModel.
  • forward, which takes the same parameters as torch_ac.ACModel along with a tensor of N memories memory of size N x M, where M is the size of a memory. It returns the same outputs as torch_ac.ACModel plus a tensor of N memories memory.
  • memory_size, which returns the size M of a memory.
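
A recurrent model following this interface might look like the sketch below. It is only an illustration: the class name, the layer sizes, the choice of nn.GRUCell and exposing memory_size as a property are all assumptions, and the constructor takes plain integers rather than observation_space and action_space. In real use the class would also inherit from torch_ac.RecurrentACModel.

```python
import torch
import torch.nn as nn
from torch.distributions import Categorical


class TinyRecurrentACModel(nn.Module):
    """Sketch of a recurrent actor-critic model (hypothetical name and sizes)."""

    def __init__(self, obs_dim, num_actions, hidden=32):
        super().__init__()
        self.embed = nn.Linear(obs_dim, hidden)
        self.rnn = nn.GRUCell(hidden, hidden)
        self.actor = nn.Linear(hidden, num_actions)
        self.critic = nn.Linear(hidden, 1)

    @property
    def memory_size(self):
        # M: with a GRU cell, a memory is just the hidden state.
        return self.rnn.hidden_size

    def forward(self, obs, memory):
        x = torch.tanh(self.embed(obs))
        # memory has size N x M; the cell returns the updated N x M memory.
        memory = self.rnn(x, memory)
        dist = Categorical(logits=self.actor(memory))
        value = self.critic(memory).squeeze(1)  # size N, not N x 1
        return dist, value, memory


model = TinyRecurrentACModel(obs_dim=8, num_actions=3)
memory = torch.zeros(5, model.memory_size)
dist, value, memory = model(torch.zeros(5, 8), memory)
print(value.shape, memory.shape)  # torch.Size([5]) torch.Size([5, 32])
```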

Note: The preprocess_obss function must return a list-indexable object (e.g. a PyTorch tensor). If your observations are dictionaries, your preprocess_obss function may first convert a list of dictionaries into a dictionary of lists and then make it list-indexable using the torch_ac.DictList class as follows:

>>> from torch_ac import DictList
>>> d = DictList({"a": [[1, 2], [3, 4]], "b": [[5], [6]]})
>>> d.a
[[1, 2], [3, 4]]
>>> d[0]
DictList({"a": [1, 2], "b": [5]})
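
The first step mentioned in the note, turning a list of dictionaries into a dictionary of lists, can be sketched in plain Python. The helper name dicts_to_dict_of_lists is hypothetical, and the sketch assumes every observation has the same keys.

```python
def dicts_to_dict_of_lists(obss):
    """Turn [{"a": ..., "b": ...}, ...] into {"a": [...], "b": [...]}."""
    keys = obss[0].keys()  # assumes all observations share the same keys
    return {key: [obs[key] for obs in obss] for key in keys}


obss = [{"a": [1, 2], "b": [5]}, {"a": [3, 4], "b": [6]}]
columns = dicts_to_dict_of_lists(obss)
print(columns)  # {'a': [[1, 2], [3, 4]], 'b': [[5], [6]]}
# With the package installed, torch_ac.DictList(columns) would then make it list-indexable.
```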

Note: If you use an RNN, you will need to set batch_first to True.

Examples

Examples of how to use the package components are given in the rl-starter-files repository.

Example of use of torch_ac.A2CAlgo and torch_ac.PPOAlgo

...

algo = torch_ac.PPOAlgo(envs, acmodel, args.frames_per_proc, args.discount, args.lr, args.gae_lambda,
                        args.entropy_coef, args.value_loss_coef, args.max_grad_norm, args.recurrence,
                        args.optim_eps, args.clip_eps, args.epochs, args.batch_size, preprocess_obss)

...

exps, logs1 = algo.collect_experiences()
logs2 = algo.update_parameters(exps)

More details here.

Example of use of torch_ac.DictList

torch_ac.DictList({
    "image": preprocess_images([obs["image"] for obs in obss], device=device),
    "text": preprocess_texts([obs["mission"] for obs in obss], vocab, device=device)
})

More details here.

Example of implementation of torch_ac.RecurrentACModel

class ACModel(nn.Module, torch_ac.RecurrentACModel):
    ...

    def forward(self, obs, memory):
        ...

        return dist, value, memory

More details here.

Examples of preprocess_obss functions

More details here.

torch-ac's People

Contributors

lcswillems
