Trajectory-Aware Imitation Learning from Observations (TAILO)

This repository contains the code for the Trajectory-Aware Imitation Learning from Observations (TAILO) method, from the NeurIPS 2023 submission "A Simple Solution for Offline Imitation from Observations and Examples with Possibly Incomplete Trajectories".

[Project Page]

File Structure

TAILO/data: The data for the experiments, which are partly given in the repo and partly generated by TAILO/datagen.py.

TAILO/envs, TAILO/maze_model.py: Code required by some environments, adapted from the SMODICE [1] repository (https://github.com/JasonMa2016/SMODICE).

TAILO/NN.py, TAILO/advance_NN.py: Code for neural networks.

TAILO/main.py: Main file and entry point of our algorithm.

TAILO/dataset.py: Code for dataset handling.

TAILO/utils.py, TAILO/EMA.py, TAILO/normalizer.py: Other auxiliary code.

Dependency

MuJoCo 2.1.0 (mujoco210) is required for all environments, and we ran our experiments with CUDA 11.3. The Python package dependencies are listed below:

d4rl == 1.1
dm-control == 1.0.7
ema-pytorch == 0.2.1
gym == 0.21.0
h5py == 2.10.0
mujoco-py == 2.1.2.14
numpy == 1.23.1
torch == 1.12.0
tqdm
wandb
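
Before launching long runs, it can help to confirm that the installed packages match the pins above. The following is a minimal sketch and not part of the repository: the helper name `check_versions` is hypothetical, and it assumes each package reports its version through standard package metadata under the distribution name used in the list.

```python
from importlib import metadata

# Pinned versions copied from the list above; tqdm and wandb are unpinned.
PINNED = {
    "d4rl": "1.1",
    "dm-control": "1.0.7",
    "ema-pytorch": "0.2.1",
    "gym": "0.21.0",
    "h5py": "2.10.0",
    "mujoco-py": "2.1.2.14",
    "numpy": "1.23.1",
    "torch": "1.12.0",
}


def check_versions(pinned, get_version=metadata.version):
    """Return {package: (expected, found)} for each mismatched or missing package."""
    problems = {}
    for name, expected in pinned.items():
        try:
            found = get_version(name)
        except metadata.PackageNotFoundError:
            found = None
        # Drop a local build tag such as "+cu113" before comparing.
        if found is None or found.split("+", 1)[0] != expected:
            problems[name] = (expected, found)
    return problems


if __name__ == "__main__":
    for name, (expected, found) in check_versions(PINNED).items():
        print(f"{name}: expected {expected}, found {found}")
```

An empty result means every pinned package is installed at the listed version.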

By default, OpenGL is used for the D4RL [2] environments. However, OpenGL can cause problems on headless machines; if it does, try setting the MUJOCO_GL environment variable to 'egl' or 'osmesa'.
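
The variable can be exported in the shell, or set from Python before anything that initialises rendering is imported. The sketch below shows the second option. The function name `configure_headless_rendering` is hypothetical, and it assumes the rendering backend reads MUJOCO_GL at import time, so the call has to come first in the entry script.

```python
import os

# The two values suggested above for headless machines.
VALID_BACKENDS = ("egl", "osmesa")


def configure_headless_rendering(backend="egl"):
    """Set MUJOCO_GL unless the shell already exported a value; return the value in effect."""
    if backend not in VALID_BACKENDS:
        raise ValueError(f"backend must be one of {VALID_BACKENDS}, got {backend!r}")
    # setdefault keeps a value that is already present in the environment.
    return os.environ.setdefault("MUJOCO_GL", backend)


if __name__ == "__main__":
    print("MUJOCO_GL =", configure_headless_rendering("egl"))
```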

Running Code

  1. Clone the repository.

  2. Install the dependencies as stated in the dependency section.

  3. From the TAILO directory, run the following commands:

cd envs
python generate_antmaze_random.py --noisy
cd ..
python datagen.py

The first script generates part of the task-agnostic data for the antmaze environment, and the second script generates most of the datasets. The first script may take about 10 minutes to run and the second 30-60 minutes. The rest of the data is provided directly in the data folder.

  4. Find the following lines in TAILO/main.py:
wandb.login(key=XXXXXXX)
wandb.init(entity=XXXXXXX, project= ...

Change XXXXXXX to your wandb key and username, respectively. See the official wandb documentation at https://docs.wandb.ai/ for details.

  5. Run the code to reproduce the results; see the next section for the commands.
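
For the wandb step above, one alternative to pasting the key into TAILO/main.py is to read it from the environment. This is only a sketch: the helper `wandb_credentials` is hypothetical and not in the repository. It relies on the WANDB_API_KEY and WANDB_ENTITY environment variables, which wandb itself recognises.

```python
import os


def wandb_credentials(env=os.environ):
    """Read the wandb key and entity from the environment instead of hard-coding them."""
    key = env.get("WANDB_API_KEY")
    entity = env.get("WANDB_ENTITY")
    if not key or not entity:
        raise RuntimeError("Set WANDB_API_KEY and WANDB_ENTITY before running main.py")
    return key, entity


# Hypothetical use inside main.py (the project name is elided, as in the original):
#   key, entity = wandb_credentials()
#   wandb.login(key=key)
#   wandb.init(entity=entity, project=...)
```

Keeping the key out of the source file also avoids committing it by accident.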

Commands for Reproducing Results

Learning from Task-Agnostic Dataset with Incomplete Trajectories (Sec. 4.1)

python main.py --env_name $ENV --skip_suffix_TA missing$X --seed 1000000 --auto 1

Here, $X = {2, 3, 5, 10, 20}. The --auto 1 flag automatically generates the run name in wandb; otherwise, the name must be entered manually.
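
To run the whole sweep over $X, the commands can be generated programmatically. The sketch below only builds and prints them. The function name `task_agnostic_commands` is hypothetical, and `hopper` is an example value for $ENV; this section does not state which environments apply.

```python
import shlex

# The $X values listed above.
MISSING_X = [2, 3, 5, 10, 20]


def task_agnostic_commands(env_name, seed=1000000):
    """Build one main.py command (as an argument list) per value of $X."""
    return [
        ["python", "main.py", "--env_name", env_name,
         "--skip_suffix_TA", f"missing{x}",
         "--seed", str(seed), "--auto", "1"]
        for x in MISSING_X
    ]


if __name__ == "__main__":
    for cmd in task_agnostic_commands("hopper"):
        # Print only; pass cmd to subprocess.run(cmd, check=True) to execute it.
        print(shlex.join(cmd))
```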

Learning from Task-Specific Dataset with Incomplete Trajectories (Sec. 4.2)

python main.py --env_name $ENV --skip_suffix_TS headtail${X}_${Y} --seed 1000000 --auto 1

Here, ($X, $Y) = {(0.1, 0.001), (0.09, 0.01), (0.05, 0.05), (0.01, 0.09), (0.001, 0.1)}.
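
The suffix is formed by substituting each pair into the pattern, which is easy to get wrong by hand. The sketch below builds the five suffixes. It assumes the values appear in the dataset name exactly as written above, so the pairs are kept as strings to stop float formatting from changing them (for example, 0.001 becoming 1e-03). The name `headtail_suffixes` is hypothetical.

```python
# The ($X, $Y) pairs listed above, stored as strings to preserve their spelling.
HEADTAIL_PAIRS = [
    ("0.1", "0.001"),
    ("0.09", "0.01"),
    ("0.05", "0.05"),
    ("0.01", "0.09"),
    ("0.001", "0.1"),
]


def headtail_suffixes(pairs=HEADTAIL_PAIRS):
    """Return the --skip_suffix_TS value for each (X, Y) pair."""
    return [f"headtail{x}_{y}" for x, y in pairs]


if __name__ == "__main__":
    for suffix in headtail_suffixes():
        print(suffix)
```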

Standard Imitation Learning from Observation (Sec. 4.3)

python main.py --env_name $ENV --skip_suffix_TA expert40 --seed 1000000 --auto 1

Here, $ENV = {hopper, halfcheetah, walker2d, ant}. The commands for the kitchen and antmaze environments differ because their dataset sizes are different:

python main.py --env_name kitchen --gamma 0.98 --eval_interval 1000 --N 80000 --seed 1000000 --auto 1
python main.py --env_name antmaze --auto 1 --seed 1000000

Learning from Examples (Sec. 4.4)

python main.py --env_name pointmaze --gamma 0.98 --skip_suffix_TS goal --auto 1 --eval_interval 1500 --N 170000 --seed 1000000
python main.py --env_name kitchen --eval_interval 1000 --N 80000 --skip_suffix_TS goal-kettle --seed 1000000 --auto 1
python main.py --env_name kitchen --eval_interval 1000 --N 80000 --skip_suffix_TS goal-microwave --seed 1000000 --auto 1
python main.py --env_name antmaze --skip_suffix_TS goal --auto 1 --seed 1000000

Learning from Mismatched Dynamics (Sec. 4.5)

python main.py --env_name halfcheetah --PU 3 --auto 1 --seed 1000000 --skip_suffix_TS mismatch 
python main.py --env_name ant --PU 3 --auto 1 --seed 1000000 --skip_suffix_TS mismatch
python main.py --env_name antmaze --PU 3 --auto 1 --seed 1000000 --skip_suffix_TS mismatch

Performance Comparison Against ReCOIL (Sec. F.2)

Random + Few Expert:

python main.py --env_name $ENV --seed 1000000 --skip_suffix_TA recoilRFE --auto 1

Medium + Expert:

python main.py --env_name $ENV --seed 1000000 --skip_suffix_TA recoilME --auto 1

Performance Comparison Under Identical SMODICE Settings (Sec. F.5)

python main.py --env_name $ENV --seed 1000000 --auto 1

Here, $ENV = {hopper, halfcheetah, walker2d, ant}.

References

[1] Y. J. Ma, A. Shen, D. Jayaraman, and O. Bastani. SMODICE: Versatile offline imitation learning via state occupancy matching. In ICML, 2022.

[2] J. Fu, A. Kumar, O. Nachum, G. Tucker, and S. Levine. D4RL: Datasets for deep data-driven reinforcement learning. arXiv:2004.07219, 2020.
