zhikanggfu / diamond

This project forked from eloialonso/diamond


DIAMOND (DIffusion As a Model Of eNvironment Dreams) is a reinforcement learning agent trained in a diffusion world model.

Home Page: https://arxiv.org/abs/2405.12399

License: MIT License



Diffusion for World Modeling: Visual Details Matter in Atari

TL;DR We introduce DIAMOND (DIffusion As a Model Of eNvironment Dreams), a reinforcement learning agent trained in a diffusion world model.

(Animation: autoregressive imagination with DIAMOND on a subset of Atari games; the DIAMOND agent playing inside its world model.)

Quick install to try our pretrained world models using miniconda:

git clone git@github.com:eloialonso/diamond.git
cd diamond
conda create -n diamond python=3.10
conda activate diamond
pip install -r requirements.txt
python src/play.py --pretrained

Alternatively, if you do not have miniconda installed, you can use Python's built-in venv:

git clone git@github.com:eloialonso/diamond.git
cd diamond
python3 -m venv diamond_env
source diamond_env/bin/activate
pip install -r requirements.txt
python src/play.py --pretrained

And press m to take control (the policy is playing by default)!

Warning: the dependencies include Atari ROMs; by downloading them you acknowledge that you have the license to use them.

Quick Links

⬆️ Try our playable diffusion world models

python src/play.py --pretrained

Then select a game; a world model and policy pretrained on Atari 100k will be downloaded from our repository on the Hugging Face Hub 🤗 and cached on your machine.

Some things you might want to try:

  • Press m to change the policy between the agent and human (the policy is playing by default).
  • Press ↑/↓ to change the imagination horizon (default is 50 for playing).

To adjust the sampling parameters (number of denoising steps, stochasticity, order, etc.) of the trained diffusion world model, for instance to trade off sampling speed against quality, edit the world_model_env.diffusion_sampler section of config/trainer.yaml.
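As an aside, Hydra-style dotted keys such as world_model_env.diffusion_sampler map straightforwardly onto the nested YAML config. The sketch below shows that mapping with a minimal stand-in dict; the parameter names inside diffusion_sampler are illustrative assumptions, not a guaranteed list (check config/trainer.yaml for the real ones).

```python
# Sketch: how a dotted config path like
# "world_model_env.diffusion_sampler.num_steps_denoising" addresses the
# nested config. The leaf names below are illustrative assumptions.

def apply_override(config: dict, dotted_key: str, value) -> None:
    """Set a nested config value from a dotted key path, creating
    intermediate sections as needed."""
    *parents, leaf = dotted_key.split(".")
    node = config
    for part in parents:
        node = node.setdefault(part, {})
    node[leaf] = value

# Minimal stand-in for the world_model_env.diffusion_sampler section.
config = {
    "world_model_env": {
        "diffusion_sampler": {"num_steps_denoising": 3}
    }
}
# Trade quality for speed by reducing the number of denoising steps.
apply_override(config, "world_model_env.diffusion_sampler.num_steps_denoising", 1)
print(config["world_model_env"]["diffusion_sampler"]["num_steps_denoising"])  # 1
```

The same dotted syntax works as a command-line override when launching the scripts, since the project uses Hydra for configuration.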

See Visualization for more details about the available commands and options.

⬆️ Launch a training run

To train with the hyperparameters used in the paper, launch:

python src/main.py env.train.id=BreakoutNoFrameskip-v4 common.device=cuda:0

This creates a new folder for your run, located in outputs/YYYY-MM-DD/hh-mm-ss/.

To resume a run that crashed, navigate to the run folder and launch:

./scripts/resume.sh

⬆️ Configuration

We use Hydra for configuration management.

All configuration files are located in the config folder:

  • config/trainer.yaml: main configuration file.
  • config/agent/default.yaml: architecture hyperparameters.
  • config/env/atari.yaml: environment hyperparameters.

You can turn on logging to Weights & Biases in the wandb section of config/trainer.yaml.

Set training.model_free=true in the file config/trainer.yaml to "unplug" the world model and perform standard model-free reinforcement learning.

⬆️ Visualization

⬆️ Play mode (default)

To visualize your last checkpoint, launch from the run folder:

python src/play.py

By default, you visualize the policy playing in the world model. To play yourself, or switch to the real environment, use the controls described below.

Controls (play mode)

(Game-specific commands will be printed on start up)

⏎   : reset environment

m   : switch controller (policy/human)
↑/↓ : imagination horizon (+1/-1)
←/→ : next environment [world model ←→ real env (test) ←→ real env (train)]

.   : pause/unpause
e   : step-by-step (when paused)

Add -r to toggle "recording mode" (works only in play mode). Every completed episode will be saved in dataset/rec_<env_name>_<controller>. For instance:

  • dataset/rec_wm_π: Policy playing in world model.
  • dataset/rec_wm_H: Human playing in world model.
  • dataset/rec_test_H: Human playing in test real environment.

You can then use the "dataset mode" described in the next section to replay the stored episodes.
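The recording directory name follows the dataset/rec_&lt;env_name&gt;_&lt;controller&gt; pattern, with π for the policy and H for a human, as in the examples above. A small sketch of that naming rule (the helper function is ours, for illustration):

```python
# Sketch: reconstruct the recording directory name
# dataset/rec_<env_name>_<controller> used by recording mode, using the
# controller symbols from the examples above (π = policy, H = human).
from pathlib import Path

def recording_dir(env_name: str, human: bool, root: str = "dataset") -> Path:
    """Directory where completed episodes are saved in recording mode."""
    controller = "H" if human else "π"
    return Path(root) / f"rec_{env_name}_{controller}"

print(recording_dir("wm", human=False).as_posix())  # dataset/rec_wm_π
```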

⬆️ Dataset mode (add -d)

In the run folder, to visualize the datasets contained in the dataset subfolder, add -d to switch to "dataset mode":

python src/play.py -d

You can use the controls described below to navigate the datasets and episodes.

Controls (dataset mode)

m   : next dataset (if multiple datasets, like recordings, etc)
↑/↓ : next/previous episode
←/→ : next/previous timestep in episodes
PgUp: +10 timesteps
PgDn: -10 timesteps
⏎   : back to first timestep

⬆️ Other options, common to play/dataset modes

--fps FPS             Target frame rate (default 15).
--size SIZE           Window size (default 800).
--no-header           Remove header.

⬆️ Run folder structure

Each new run is located at outputs/YYYY-MM-DD/hh-mm-ss/. This folder is structured as follows:

outputs/YYYY-MM-DD/hh-mm-ss/
│
├─── checkpoints
│   │   state.pt  # full training state
│   │
│   └─── agent_versions
│       │   ...
│       │   agent_epoch_00999.pt
│       │   agent_epoch_01000.pt  # agent weights only
│
├─── config
│   │   trainer.yaml
│
├─── dataset
│   │
│   ├─── train
│   │   │   info.pt
│   │   │   ...
│   │
│   └─── test
│       │   info.pt
│       │   ...
│
├─── scripts
│   │   resume.sh
│   │   ...
│
├─── src
│   │   main.py
│   │   ...
│
└─── wandb
    │   ...

⬆️ Results

The file results/data/DIAMOND.json contains the results for each game and seed used in the paper.
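To aggregate such per-seed results, something like the sketch below works. The exact schema of results/data/DIAMOND.json is not documented here, so the {game: [score per seed]} shape and all numbers below are assumptions for illustration only:

```python
# Sketch: average per-game scores over seeds. The {game: [scores]} input
# shape is an assumed stand-in for the JSON schema, and the data below is
# synthetic, not the paper's numbers.
import json
from statistics import mean

def mean_scores(results: dict[str, list[float]]) -> dict[str, float]:
    """Average each game's returns over its seeds."""
    return {game: mean(scores) for game, scores in results.items()}

raw = json.dumps({"Breakout": [100.0, 120.0], "Pong": [18.0, 20.0]})
print(mean_scores(json.loads(raw)))  # {'Breakout': 110.0, 'Pong': 19.0}
```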

⬆️ Citation

@misc{alonso2024diffusion,
      title={Diffusion for World Modeling: Visual Details Matter in Atari},
      author={Eloi Alonso and Adam Jelley and Vincent Micheli and Anssi Kanervisto and Amos Storkey and Tim Pearce and François Fleuret},
      year={2024},
      eprint={2405.12399},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

⬆️ Credits

Contributors: eloialonso, adamjelley
