corl-team / xland-minigrid

JAX-accelerated Meta-Reinforcement Learning Environments Inspired by XLand and MiniGrid 🏎️

License: Apache License 2.0

Python 99.16% Shell 0.84%
jax meta-reinforcement-learning minigrid python reinforcement-learning xland

xland-minigrid's Introduction

XLand-MiniGrid


Meta-Reinforcement Learning in JAX

🥳 We recently released XLand-100B, a large multi-task dataset for offline meta and in-context RL research, based on XLand-MiniGrid. It is currently the largest dataset for in-context RL, containing full learning histories for 30k unique tasks, 100B transitions, and 2.5B episodes. Check it out!

XLand-MiniGrid is a suite of tools, grid-world environments and benchmarks for meta-reinforcement learning research, inspired by the diversity and depth of XLand and the simplicity and minimalism of MiniGrid. Despite the similarities, XLand-MiniGrid is written in JAX from scratch and designed to be highly scalable, democratizing large-scale experimentation with limited resources. Ever wanted to reproduce the DeepMind AdA agent? Now you can, and not in months but in days!

Features

  • 🔮 System of rules and goals that can be combined in arbitrary ways to produce diverse task distributions
  • 🔧 Simple to extend and modify, comes with example environments ported from the original MiniGrid
  • 🪄 Fully compatible with all JAX transformations, can run on CPU, GPU and TPU
  • 📈 Easily scales to $2^{16}$ parallel environments and millions of steps per second on a single GPU
  • 🔥 Multi-GPU PPO baselines in the PureJaxRL style, which can achieve 1 trillion environment steps in under two days

How cool is that? For more details, take a look at the technical paper or examples, which will walk you through the basics and training your own adaptive agents in minutes!

Installation 🎁

The latest release of XLand-MiniGrid can be installed directly from PyPI:

pip install xminigrid
# or, from github directly
pip install "xminigrid @ git+https://github.com/corl-team/xland-minigrid.git"

Alternatively, if you want to install the latest development version from GitHub and run the provided algorithms or scripts, install from source as follows:

git clone git@github.com:corl-team/xland-minigrid.git
cd xland-minigrid

# additional dependencies for baselines
pip install -e ".[dev,baselines]"

Note that the installation of JAX may differ depending on your hardware accelerator! We advise users to explicitly install the correct JAX version (see the official installation guide).
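For example, on a Linux machine with a recent NVIDIA GPU, pip install -U "jax[cuda12]" installs a CUDA-enabled build; the exact extra to use depends on your CUDA and driver versions, so check the official guide first.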

Basic Usage 🕹️

Most users who are familiar with other popular JAX-based environments (such as gymnax or jumanji) will find the interface very similar. At a high level, the current API combines the dm_env and gymnax interfaces.

import jax
import xminigrid
from xminigrid.wrappers import GymAutoResetWrapper
from xminigrid.experimental.img_obs import RGBImgObservationWrapper

key = jax.random.key(0)
reset_key, ruleset_key = jax.random.split(key)

# to list available benchmarks: xminigrid.registered_benchmarks()
benchmark = xminigrid.load_benchmark(name="trivial-1m")
# choosing ruleset, see section on rules and goals
ruleset = benchmark.sample_ruleset(ruleset_key)

# to list available environments: xminigrid.registered_environments()
env, env_params = xminigrid.make("XLand-MiniGrid-R9-25x25")
env_params = env_params.replace(ruleset=ruleset)

# auto-reset wrapper
env = GymAutoResetWrapper(env)

# render obs as rgb images if needed (warn: this will affect speed greatly)
env = RGBImgObservationWrapper(env)

# fully jit-compatible step and reset methods
timestep = jax.jit(env.reset)(env_params, reset_key)
timestep = jax.jit(env.step)(env_params, timestep, action=0)

# optionally render the state
env.render(env_params, timestep)

Similar to gymnasium or jumanji, users can register new environment variations with register for convenient further usage with make. timestep is a dataclass containing step_type, reward, discount, observation, as well as the internal environment state.
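Since reset and step are pure functions, running many environments in parallel only requires jax.vmap. Below is a minimal sketch (our own example, not from the repo) that assumes a single ruleset shared across all environments:

# a minimal sketch: batching environments with jax.vmap, single shared ruleset
import jax
import jax.numpy as jnp
import xminigrid
from xminigrid.wrappers import GymAutoResetWrapper

num_envs = 1024
key = jax.random.key(0)
key, ruleset_key = jax.random.split(key)
reset_keys = jax.random.split(key, num=num_envs)  # one key per environment

benchmark = xminigrid.load_benchmark(name="trivial-1m")
ruleset = benchmark.sample_ruleset(ruleset_key)

env, env_params = xminigrid.make("XLand-MiniGrid-R9-25x25")
env_params = env_params.replace(ruleset=ruleset)
env = GymAutoResetWrapper(env)

# broadcast env_params to every environment, batch over keys/state/actions
reset_fn = jax.jit(jax.vmap(env.reset, in_axes=(None, 0)))
step_fn = jax.jit(jax.vmap(env.step, in_axes=(None, 0, 0)))

timestep = reset_fn(env_params, reset_keys)
actions = jnp.zeros(num_envs, dtype=jnp.int32)
timestep = step_fn(env_params, timestep, actions)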

For a slightly more advanced introduction, see the provided walkthrough notebook.

On environment interface

Currently, a lot of new JAX-based environments are appearing, each offering its own variant of the API. Initially, we tried to reuse Jumanji, but it turned out that its design is not suitable for meta-learning. The Gymnax design appeared to be more appropriate, but unfortunately it is not actively supported and often departs from the idea that parameters should only be contained in env_params. Furthermore, splitting the timestep into multiple entities seems suboptimal to us, as it complicates many things, such as envpool- or dm_env-style auto-reset, where the reset occurs on the next step (we need access to the done flag of the previous step).

Therefore, we decided to make a minimal interface that covers just our needs, without the goal of making it generic. The core of our library is interface-independent, and we plan to switch to a new one when/if a better design becomes available (e.g. when the stable Gymnasium FuncEnv is released).

Rules and Goals 🔮

In XLand-MiniGrid, the system of rules and goals is the cornerstone of emergent complexity and diversity. In the original MiniGrid, some environments have dynamic goals, but the dynamics themselves never change. To train and evaluate highly adaptive agents, we need to be able to change the dynamics in non-trivial ways.

Rules are functions that change the environment state in a deterministic way when their conditions are met. Goals are similar to rules, except that they do not change the state; they only test conditions. Every task is described by a goal, rules and initial objects. We call these rulesets. Currently, we support only one goal per task.

To illustrate, we provide a visualization for a specific ruleset. To solve this task, the agent should take the blue pyramid and put it near the purple square to transform both objects into a red circle. To complete the goal, the red circle should be placed near the green circle. However, placing the purple square near the yellow circle will make the task unsolvable in this trial. Initial object positions are randomized on each reset.

For a more advanced introduction, see the corresponding section in the provided walkthrough notebook.

Benchmarks 🎲

While composing rules and goals by hand is flexible, it can quickly become cumbersome. Besides, it is hard to express efficiently in a JAX-compatible way due to the large number of heterogeneous computations.

To avoid significant overhead during training and to facilitate reliable comparisons between agents, we pre-sampled several benchmarks with up to three million unique tasks, following the procedure used to train the DeepMind AdA agent in the original XLand. Each task is represented as a tree, where the root is a goal and all other nodes are production rules, which should be triggered in sequence to solve the task.

These benchmarks differ in their generation configs, producing distributions with varying levels of diversity and average task difficulty. They can be used for different purposes; for example, the trivial-1m benchmark can be used to debug your agents, allowing very quick iterations. However, we would caution against treating the benchmarks as a progression from simple to complex. They are just different 🤷.
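To check which pre-sampled benchmarks your installed version ships with (names other than trivial-1m may differ between releases):

import xminigrid

# prints the names accepted by xminigrid.load_benchmark(...)
print(xminigrid.registered_benchmarks())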

Pre-sampled benchmarks are hosted on HuggingFace and will be downloaded and cached on the first use:

import jax.random
import xminigrid
from xminigrid.benchmarks import Benchmark

# downloading to path specified by XLAND_MINIGRID_DATA,
# ~/.xland_minigrid by default
benchmark: Benchmark = xminigrid.load_benchmark(name="trivial-1m")
# reusing cached on the second use
benchmark: Benchmark = xminigrid.load_benchmark(name="trivial-1m")

# users can sample or get specific rulesets
benchmark.sample_ruleset(jax.random.key(0))
benchmark.get_ruleset(ruleset_id=benchmark.num_rulesets() - 1)

# or split them for train & test
train, test = benchmark.shuffle(key=jax.random.key(0)).split(prop=0.8)
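For meta-training it is often convenient to sample a whole batch of rulesets at once; since sample_ruleset is a pure function of a PRNG key, it can be wrapped in jax.vmap. A minimal sketch (our own example):

import jax
import xminigrid

benchmark = xminigrid.load_benchmark(name="trivial-1m")
train, test = benchmark.shuffle(key=jax.random.key(0)).split(prop=0.8)

# sample one ruleset per parallel environment in a single vmapped call
num_envs = 256
ruleset_keys = jax.random.split(jax.random.key(1), num=num_envs)
train_rulesets = jax.vmap(train.sample_ruleset)(ruleset_keys)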

We also provide the script used to generate these benchmarks. Users can use it for their own purposes:

python scripts/ruleset_generator.py --help

An in-depth description of all available benchmarks is provided in the technical paper (Section 3).

Environments 🌍

We provide environments from two domains. XLand is our main focus for meta-learning. For this domain we provide a single environment and numerous registered variants with different grid layouts and sizes. All of them can be combined with arbitrary rulesets.

To demonstrate the generality of our library, we also port the majority of non-language-based tasks from the original MiniGrid. Similarly, some environments come with multiple registered variants. However, we have no current plans to actively develop and support them (but that may change).

| Name | Domain | Goal |
| --- | --- | --- |
| XLand-MiniGrid | XLand | specified by the provided ruleset |
| MiniGrid-Empty | MiniGrid | go to the green goal |
| MiniGrid-EmptyRandom | MiniGrid | go to the green goal from different starting positions |
| MiniGrid-FourRooms | MiniGrid | go to the green goal, but goal and starting positions are randomized |
| MiniGrid-LockedRoom | MiniGrid | find the key to unlock the door, go to the green goal |
| MiniGrid-Memory | MiniGrid | remember the initial object and choose it at the end of the corridor |
| MiniGrid-Playground | MiniGrid | goal is not specified |
| MiniGrid-Unlock | MiniGrid | unlock the door with the key |
| MiniGrid-UnlockPickUp | MiniGrid | unlock the door and pick up the object in another room |
| MiniGrid-BlockedUnlockPickUp | MiniGrid | unlock the door blocked by an object and pick up the object in another room |
| MiniGrid-DoorKey | MiniGrid | unlock the door and go to the green goal |

Users can get all registered environments with xminigrid.registered_environments(). We also provide manual control to easily explore the environments:

python -m xminigrid.manual_control --env-id="MiniGrid-Empty-8x8"
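The ported MiniGrid tasks follow the same API as the XLand environment, just without a ruleset. A minimal sketch (our own example, using the MiniGrid-Empty-8x8 variant from the command above):

import jax
import xminigrid
from xminigrid.wrappers import GymAutoResetWrapper

# ported MiniGrid tasks need no ruleset, only the environment id
env, env_params = xminigrid.make("MiniGrid-Empty-8x8")
env = GymAutoResetWrapper(env)

key = jax.random.key(0)
timestep = jax.jit(env.reset)(env_params, key)
timestep = jax.jit(env.step)(env_params, timestep, action=0)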

Baselines 🚀

In addition to the environments, we provide high-quality, almost single-file implementations of recurrent PPO baselines in the style of PureJaxRL. With the help of the jax.pmap transformation, they scale to multiple accelerators, achieving throughput of millions of steps per second during training.
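To give a rough idea of how this scaling works (an illustrative sketch under our own assumptions, not the actual training code from /training), the vmapped environment step can itself be wrapped in jax.pmap, with environments arranged as a (devices, envs_per_device) batch:

# illustrative sketch only: spreading parallel environments across local devices
import jax
import jax.numpy as jnp
import xminigrid
from xminigrid.wrappers import GymAutoResetWrapper

num_devices = jax.local_device_count()
envs_per_device = 256

key = jax.random.key(0)
key, ruleset_key = jax.random.split(key)
benchmark = xminigrid.load_benchmark(name="trivial-1m")
ruleset = benchmark.sample_ruleset(ruleset_key)

env, env_params = xminigrid.make("XLand-MiniGrid-R1-9x9")
env_params = env_params.replace(ruleset=ruleset)
env = GymAutoResetWrapper(env)

# keys arranged as (devices, envs_per_device); env_params broadcast everywhere
reset_keys = jax.random.split(key, num=num_devices * envs_per_device)
reset_keys = reset_keys.reshape(num_devices, envs_per_device)

reset_fn = jax.pmap(jax.vmap(env.reset, in_axes=(None, 0)), in_axes=(None, 0))
step_fn = jax.pmap(jax.vmap(env.step, in_axes=(None, 0, 0)), in_axes=(None, 0, 0))

timestep = reset_fn(env_params, reset_keys)
actions = jnp.zeros((num_devices, envs_per_device), dtype=jnp.int32)
timestep = step_fn(env_params, timestep, actions)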

Agents can be trained from the terminal, and default arguments can be overridden from the command line or from a yaml config:

# for meta learning
python training/train_meta_task.py \
    --config-path='some-path/config.yaml' \
    --env_id='XLand-MiniGrid-R1-9x9'

# for minigrid envs
python training/train_single_task.py \
    --config-path='some-path/config.yaml' \
    --env_id='XLand-MiniGrid-R1-9x9'

For the source code and available hyperparameters, see /training or run python training/train_meta_task.py --help. Furthermore, we provide standalone implementations that can be trained in Colab: xland, minigrid.

P.S. Do not expect the provided baselines to solve the hardest environments or benchmarks available. How much fun would that be 🤔? However, we hope they will help you get started quickly!

Open Logs 📽

We value openness and reproducibility in science; therefore, all logs for the main experiments from the paper are open and available as a public wandb report. There you can discover all the latest plots, the behaviour of the losses, and see exactly which hyperparameters were used. Enjoy!

Contributing 🔨

We welcome anyone interested in helping out! Please take a look at our contribution guide for further instructions and open an issue if something is not clear.

See Also 🔎

A lot of other work is moving in a similar direction, transforming RL through JAX. Many of these projects have inspired us, and we encourage users to check them out as well.

  • Brax - fully differentiable physics engine used for research and development of robotics.
  • Gymnax - implements classic environments including classic control, bsuite, MinAtar and simplistic meta learning tasks.
  • Jumanji - a diverse set of environments ranging from simple games to NP-hard combinatorial problems.
  • Pgx - JAX implementations of classic board games, such as Chess, Go and Shogi.
  • JaxMARL - multi-agent RL in JAX with wide range of commonly used environments.
  • Craftax - Crafter reimplementation with JAX.
  • Purejaxql - off-policy Q-learning baselines with JAX for single and multi-agent RL.

Let's build together!

Citation 🙏

@inproceedings{
    nikulin2023xlandminigrid,
    title={{XL}and-MiniGrid: Scalable Meta-Reinforcement Learning Environments in {JAX}},
    author={Alexander Nikulin and Vladislav Kurenkov and Ilya Zisman and Viacheslav Sinii and Artem Agarkov and Sergey Kolesnikov},
    booktitle={Intrinsically-Motivated and Open-Ended Learning Workshop, NeurIPS2023},
    year={2023},
    url={https://openreview.net/forum?id=xALDC4aHGz}
}

xland-minigrid's People

Contributors

afspies, floringogianu, garymm, helpingstar, howuhh, vkurenkov


xland-minigrid's Issues

[Feature Request] Add Maze-like env

Hi!

Awesome job on the repo!

Feel free to ignore this request if it's not part of your roadmap. It's more of a suggestion to have other types of exploration tasks.

There's partial code on Farama-Foundation/Minigrid#317 to generate feasible mazes (with a unique direct path to the goal, I believe) based off mini-grid envs. Taking that code I was able to generate envs such as these:

I thought it might do for an interesting meta-RL exploration benchmark, i.e., can your algorithm learn to exhaustively explore the maze until it finds the goal? In principle it might not be that much different than exploring in an open-space grid, but who knows! Maybe the more constrained state-space might even accelerate (or slow down) training progress.

Cheers!

observations of env when not RGB

Hello,

Great codebase. I just want to double-check my understanding: for the non-RGB experiments you've done, do you use the uint8 grid values as inputs? I was expecting to see you convert them to one-hots or something.

Thanks

The rollout speed is slower than gymnasium

I am new to JAX. I can't see why it is extremely slow when I run the example code on CPU. The wall time is so long for each step.

Here is my code

import jax.random
import xminigrid
import time
from xminigrid.benchmarks import Benchmark

from xminigrid.wrappers import GymAutoResetWrapper
from xminigrid.experimental.img_obs import RGBImgObservationWrapper


num_envs = 8
benchmark = xminigrid.load_benchmark(name="trivial-1m")

rng = jax.random.PRNGKey(0)
ruleset_rng = jax.random.split(rng, num=num_envs)
reset_rng = jax.random.split(rng, num=num_envs)
train = jax.vmap(benchmark.sample_ruleset)(ruleset_rng)

def make_env(rulesets, img_obs=False):

    env, env_params = xminigrid.make("XLand-MiniGrid-R1-9x9")
    env_params = env_params.replace(ruleset=rulesets)

    env = GymAutoResetWrapper(env)

    if img_obs:
        # render obs as rgb images if needed (warn: this will affect speed greatly)
        env = RGBImgObservationWrapper(env)
    return env, env_params

train_env, train_params = make_env(train)

timestep = jax.vmap(train_env.reset, in_axes=(0, 0))(train_params, reset_rng)
start = time.time()
for i in range(10):
    timestep = jax.vmap(train_env.step, in_axes=0)(train_params, timestep, action=jax.numpy.zeros((8,), dtype=jax.numpy.uint8).squeeze())
print(time.time() - start)

It takes 30 seconds to run 10 steps. The MiniGrid environment in gymnasium only takes 0.3 seconds for 1000 steps:

import gymnasium as gym
import time
# env = gym.make("MiniGrid-Empty-5x5-v0")
env = gym.make("MiniGrid-Playground-v0")
observation, info = env.reset(seed=42)
start = time.time()
for _ in range(1000):
   action = env.action_space.sample()
   observation, reward, terminated, truncated, info = env.step(action)

   if terminated or truncated:
      observation, info = env.reset()
env.close()

print(time.time() - start)

Roadmap to v1.0

Roadmap

Approx v1.0 release: February 2024

General

  • type hints
  • tests
  • documentation and examples
  • environment registration
  • colab notebooks with examples

Baselines

  • PPO+RNN baselines (single task, meta-rl)
  • open wandb reports + configs to reproduce paper figures
  • multi-gpu training

Tiles

  • Key (like in Minigrid)
  • Door (like in Minigrid)
  • Box (like in Minigrid) (may reduce FPS!!!)

Actions

  • stochasticity (could be done with a wrapper)

Rules & Goals

  • procedural generator (like in xland v2)
  • pre-sampled benchmarks, 1M+ tasks

Map

  • different grid layouts (mazes, rooms, objects)

Envs

  • porting majority of minigrid envs
  • full xland procedural meta-env
  • observation/action spaces
  • calibrate max_steps with available benchmarks and baselines
  • provide rgb image observations

Beyond v1.0 directions

  • stable env API (from gymnasium function_env ?)
  • gymnasium adapters & wrappers
  • multi-agent version (highly unlikely)
  • MAML baselines, Mamba instead of GRU?
  • Dataset + offline baselines (AD, DPT)
  • procedural maze generation compatible with jit
  • new INTERESTING rules and goals
  • fast FOV algorithm for see_through_walls=False
  • goals composition (OR, AND, NOT)
  • masked rules as a part of the observation (could be done with a wrapper)
  • rules descriptions with language
  • porting multiroom, moving obstacles, safety envs from minigrid

Do you support single thread for a single environment?

I need to evaluate different hyperparameters on the same environment, so I run xland-minigrid in separate processes. However, I found that there are multiple threads in each process by using "top -Hp ". How can I use only one thread in each process? I believe the multi-threading has caused the job to hit the thread limit, with an LLVM error: pthread_create failed.

Fail to optimize single tasks in the given demo. Maybe intermediate reward is needed.

I tried train_single_task.py and modified the config to:

    env_id: str = "XLand-MiniGrid-R1-9x9"
    benchmark_id: Optional[str] = "trivial-1m"
    ruleset_id: Optional[int] = 0
    num_envs: int = 32
    total_timesteps: int = 1_00000

The final results are:

Compiling...
Done in 32.63s.
Training...
Done in 1308.51s
Logging...
Final return:  0.0
wandb:
wandb: Run history: (sparkline charts for actor_loss, entropy, eval/lengths, eval/returns, lr, total_loss, transitions, value_loss not reproduced here)
wandb:
wandb: Run summary:
wandb:       actor_loss 9e-05
wandb:          entropy 0.3547
wandb:     eval/lengths 243.0
wandb:     eval/returns 0.0
wandb:               lr 1e-05
wandb: steps_per_second 76.42288
wandb:       total_loss -0.00346
wandb:    training_time 1308.5086
wandb:      transitions 99840
wandb:       value_loss 0.0

It seems that ruleset_id=0 is difficult to learn compared to ruleset_id=1 (which reaches a return of 0.46 with fewer total_timesteps). The same holds for ruleset_id=3. I found out that ruleset_id=0 and ruleset_id=3 share the same TileNearRightGoal task. I guess the tile task needs more operations than the others; is it composed of AgentHoldGoal, AgentNearGoal and AgentOnTileGoal? I don't know if it is possible to add an intermediate reward for the tile tasks.
