
eloialonso / iris


Transformers are Sample-Efficient World Models. ICLR 2023, notable top 5%.

Home Page: https://openreview.net/forum?id=vhFu1Acb0xb

License: GNU General Public License v3.0

Jupyter Notebook 34.05% Python 65.43% Shell 0.52%
artificial-intelligence atari deep-learning machine-learning reinforcement-learning research transformers world-models

iris's People

Contributors

eloialonso, vmicheli


iris's Issues

update to continuous actions

Thank you for sharing the code.

I have been updating the code to work with continuous-action environments such as the DeepMind Control Suite, but the actor-critic does not seem to learn properly: the actor gets stuck with a bad policy and never improves.

Here is what I updated in the code:

  1. I evaluate my code on the cartpole environment, which has a single continuous action.

  2. I discretize the continuous action into 11 bins, so the action space becomes 11 discrete actions (see the sketch below).

  3. The reward loss is a mean-squared error as the reward is a real number between 0 and 1.
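For reference, a minimal sketch of the discretization in step 2 (hypothetical helper names, not code from this repo): it maps a continuous action in [low, high] to one of 11 bin centers and back.

import numpy as np

# Hypothetical helpers illustrating step 2 above, not code from this repo.
def make_bins(low: float, high: float, num_bins: int = 11) -> np.ndarray:
    # Bin centers evenly spanning the continuous action range.
    return np.linspace(low, high, num_bins)

def continuous_to_token(action: float, bins: np.ndarray) -> int:
    # Index of the nearest bin center; this is the discrete action token.
    return int(np.argmin(np.abs(bins - action)))

def token_to_continuous(token: int, bins: np.ndarray) -> float:
    # Map the discrete token back to a continuous action for the env.
    return float(bins[token])

bins = make_bins(-1.0, 1.0, num_bins=11)
token = continuous_to_token(0.13, bins)
print(token, token_to_continuous(token, bins))  # 6 0.2 (approximately)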

I have already searched over different values of the learning rate and the cross-entropy coefficient. I also tried increasing the number of bins up to 15.

Any suggestions on how to solve this issue?

World model working

Hello,

Thank you for the great work you have done! I have a question about how the world model works.

From the paper,
'Our autoregressive Transformer is based on the implementation of minGPT (Karpathy, 2020). It takes as input a sequence of L(K + 1) tokens and embeds it into a L(K + 1) x D tensor using an A x D embedding table for actions, and an N x D embedding table for frame tokens.'
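To make the quoted shapes concrete, here is a small sketch of the two embedding tables (toy sizes assumed for illustration, not the repo's exact code):

import torch
import torch.nn as nn

# Toy sizes, assumed for illustration only.
A, N, D = 4, 512, 256   # action vocabulary, frame-token vocabulary, embed dim
L, K = 8, 16            # timesteps, frame tokens per timestep

act_emb = nn.Embedding(A, D)   # the A x D table for actions
obs_emb = nn.Embedding(N, D)   # the N x D table for frame tokens

frames = torch.randint(0, N, (L, K))   # K frame-token indices per step
actions = torch.randint(0, A, (L, 1))  # one action-token index per step
steps = torch.cat([obs_emb(frames), act_emb(actions)], dim=1)  # (L, K+1, D)
sequence = steps.reshape(L * (K + 1), D)
print(sequence.shape)  # torch.Size([136, 256]), i.e. L(K+1) x D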

In the code world_model.py,

def compute_loss(self, batch: Batch, tokenizer: Tokenizer, **kwargs: Any) -> LossWithIntermediateLosses:
    with torch.no_grad():
        obs_tokens = tokenizer.encode(batch['observations'], should_preprocess=True).tokens

Why do we use the token indices as input and then embed them into a new space when we already have the tokenizer's encodings of the frame?

Training on a multi-GPU system

Thanks for your work! Can you tell me how to train the algorithm on a multi-GPU system? I have skimmed the code, but I didn't find anything related to distributed training. I would also like to know how long training takes on a single 3090 desktop PC. Thank you for your help!
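(Not an official answer, but for anyone who lands here: the released code appears to target a single device via common.device, so multi-GPU use would mean wrapping the models yourself. Below is a generic single-node DDP skeleton using standard torch.distributed, not code from this repo; the Linear model is a stand-in for the agent.)

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Launch with: torchrun --nproc_per_node=4 this_script.py
def main():
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)
    model = torch.nn.Linear(256, 256).cuda(rank)  # stand-in for the agent's models
    model = DDP(model, device_ids=[rank])
    # ... training loop; DDP synchronizes gradients across the GPUs ...
    dist.destroy_process_group()

if __name__ == "__main__":
    main()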

Size mismatch when evaluating world model

Hi, I ran your code using python src/main.py env.train.id=BreakoutNoFrameskip-v4 common.device=cuda:0. However, I encountered the following issue when evaluating the world model for the first time:
[screenshot of the size-mismatch error]
Could you please help me fix this problem?

ForkingPickler: Can't pickle local object

I am trying to execute the example training run python src/main.py env.train.id=BreakoutNoFrameskip-v4 common.device=cuda:0 wandb.mode=online, but I am getting a Can't pickle local object error. A partial backtrace is shown here:

Exception has occurred: AttributeError
  (note: full exception trace is shown but execution is paused at: _run_module_as_main)

Can't pickle local object 'Trainer.__init__.<locals>.create_env.<locals>.<lambda>'
  File "[PATH]\Lib\multiprocessing\reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
  File "[PATH]\Lib\multiprocessing\popen_spawn_win32.py", line 93, in __init__
    reduction.dump(process_obj, to_child)
  File "[PATH]\Lib\multiprocessing\context.py", line 327, in _Popen
    return Popen(process_obj)
  File "[PATH]\Lib\multiprocessing\context.py", line 224, in _Popen
    return _default_context.get_context().Process._Popen(process_obj)
  File "[PATH]\Lib\multiprocessing\process.py", line 121, in start
    self._popen = self._Popen(self)
  File "src\envs\multi_process_env.py", line 61, in __init__
    p.start()
  File "src\trainer.py", line 66, in create_env
    return MultiProcessEnv(env_fn, num_envs, should_wait_num_envs_ratio=1.0) if num_envs > 1 else SingleProcessEnv(env_fn)
  File "src\trainer.py", line 74, in __init__
    test_env = create_env(cfg.env.test, cfg.collection.test.num_envs)
  File "src\main.py", line 9, in main
    trainer = Trainer(cfg)

Any recommendation on workarounds?
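(One common workaround for this class of error, sketched from general Python multiprocessing behaviour rather than the repo's official fix: on Windows, multiprocessing uses the "spawn" start method and must pickle the env factory, and a lambda defined inside a method cannot be pickled. Replacing the local lambda with a module-level function, or a functools.partial over one, makes it picklable.)

import functools
import gym

# Module-level factory: picklable under the "spawn" start method,
# unlike the local lambda created inside Trainer.__init__.
def make_env(env_id: str):
    return gym.make(env_id)  # stand-in for the repo's actual env construction

env_fn = functools.partial(make_env, "BreakoutNoFrameskip-v4")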

About world model training

Hi, thank you for your wonderful work!
I have a question about world model training. Looking at world_model.py, I think you are masking the tokens that the tokenizer outputs. Is the world model learning a masking problem? This seems different from the usual world model training presented in Dreamer and similar works.
Sincerely.

please share the code (delta-IRIS)

Hey Eloi, your delta-IRIS idea seems nice: a concrete improvement on the largest limitation of IRIS. I'm sure you've improved other parts of the code as well. Would you mind uploading the code, even just messy code to a branch?

After reading your delta-IRIS poster/paper, I wonder if you've considered running it for 1M steps so you could contribute to the Crafter benchmark?

AttributeError: 'EntryPoints' object has no attribute 'get'

I'm trying to launch a training run, but when I run the command, I get the following error:

Traceback (most recent call last):
  File "src/main.py", line 4, in <module>
    from trainer import Trainer
  File "/work/ahsia/autoexperiment_dataset/iris/src/trainer.py", line 17, in <module>
    from agent import Agent
  File "/work/ahsia/autoexperiment_dataset/iris/src/agent.py", line 7, in <module>
    from models.actor_critic import ActorCritic
  File "/work/ahsia/autoexperiment_dataset/iris/src/models/actor_critic.py", line 14, in <module>
    from envs.world_model_env import WorldModelEnv
  File "/work/ahsia/autoexperiment_dataset/iris/src/envs/__init__.py", line 2, in <module>
    from .wrappers import make_atari, ResizeObsWrapper
  File "/work/ahsia/autoexperiment_dataset/iris/src/envs/wrappers.py", line 7, in <module>
    import gym
  File "/work/ahsia/anaconda3/envs/iris/lib/python3.7/site-packages/gym/__init__.py", line 13, in <module>
    from gym.envs import make, spec, register
  File "/work/ahsia/anaconda3/envs/iris/lib/python3.7/site-packages/gym/envs/__init__.py", line 10, in <module>
    _load_env_plugins()
  File "/work/ahsia/anaconda3/envs/iris/lib/python3.7/site-packages/gym/envs/registration.py", line 250, in load_env_plugins
    for plugin in metadata.entry_points().get(entry_point, []):
AttributeError: 'EntryPoints' object has no attribute 'get'
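(A likely cause, inferred from the traceback rather than confirmed by the maintainers: importlib-metadata 5.0 removed the dict-style EntryPoints.get API that this version of gym still calls on Python 3.7. Pinning the backport below 5.0 usually resolves it.)

# Workaround sketch: pin the backport, e.g.
#   pip install "importlib-metadata<5.0"
# then check that the old dict-style API is available again:
import importlib_metadata
eps = importlib_metadata.entry_points()
print(hasattr(eps, "get"))  # True on < 5.0, False (this error) on >= 5.0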

About the training resume function

I would like to resume an interrupted training run using the following code in src/trainer.py.

if cfg.common.resume:
    self.load_checkpoint()

From what directory, and with what command, should I launch the run to use this function?

I changed cfg.common.resume to true in outputs/YYYY-MM-DD/hh-mm-ss/ and ran python src/main.py env.train.id=BreakoutNoFrameskip-v4 common.device=cuda:0 wandb.mode=online, but I got an error.

System Requirements

Hi, what are the system requirements for this?

I'm using a 3070 Ti and am running out of CUDA memory.

Thanks,

Actor Critic training time

Hi,

I am running IRIS on a cluster with 4 A100 GPUs. Using 1 GPU, training the actor-critic for one epoch (200 steps) takes around 25 minutes, which is a long time. I tried parallelising the code across 4 GPUs using torch DDP, but this slowed AC training down further, to 56 minutes. After profiling the code to find out what takes so long, I concluded that the world_model_env inside the actor-critic's imagine function accounts for almost all of this time, in both single-GPU and DDP training:

[profiler screenshot: ac_train_slow]
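(For anyone reproducing this measurement, a generic torch.profiler sketch; this is standard PyTorch profiling, not the repo's tooling, and the model and loop below are stand-ins for the imagination rollout.)

import torch
from torch.profiler import profile, ProfilerActivity

model = torch.nn.Sequential(torch.nn.Linear(256, 256), torch.nn.ReLU()).cuda()
x = torch.randn(64, 256, device="cuda")

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    for _ in range(20):  # stand-in for the step-by-step imagination rollout
        x = model(x)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))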

Is it normal for the AC to take 25 minutes per epoch? Is there a way to speed this up, either on a single GPU or with parallel training?

I am using a custom CARLA Gym env with observations of size 64 x 64 x 3.

Thank you.

Compatibility of continuous action space

Thanks for this great work and for open-sourcing it!

As far as I can see, the paper reports experimental results for discrete action spaces but not continuous ones. And as mentioned in #13 by @2M-kotb, "the transformer-based world model expects action tokens in discrete form". So, is it possible to modify IRIS for continuous actions?

If it is possible, do you have any suggestions for working on this?

Thanks!

predict the first token of next obs?

There are 4 x 4 = 16 tokens in one obs. At each step, only the last 15 tokens of the current obs and the first token of the next obs are predicted, rather than all 16 tokens of the next obs?
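(That alignment is what the standard next-token shift produces. A tiny sketch with toy labels, omitting action tokens for brevity:)

K = 16  # 4 x 4 tokens per observation
seq = [f"obs{t}_tok{k}" for t in range(2) for k in range(K)]
inputs, targets = seq[:-1], seq[1:]  # the input at position i predicts the token at i + 1
print(targets[:K])
# ['obs0_tok1', ..., 'obs0_tok15', 'obs1_tok0']:
# the last 15 tokens of the current obs plus the first token of the next obs
print(inputs[K - 1], "->", targets[K - 1])  # obs0_tok15 -> obs1_tok0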

Training Actor-Critics: Full Observation vs Latent-State (tokens)

Nice work and thanks for the code!

  1. Why did you decide to train the actor-critic on the full original or imagined RGB observation image, rather than simply using the output of the transformer, i.e. the tokens?

  2. Have you done any ablation study on this?
