This repository contains the code for Trajectory-Aware Imitation Learning from Observations (TAILO), the method proposed in the NeurIPS 2023 submission "A Simple Solution for Offline Imitation from Observations and Examples with Possibly Incomplete Trajectories".
TAILO/data: Data for the experiments; part of it is included in the repository and the rest is generated by TAILO/datagen.py.
TAILO/envs, TAILO/maze_model.py: Code required by some of the environments, adapted from the SMODICE [1] repository https://github.com/JasonMa2016/SMODICE.
TAILO/NN.py, TAILO/advance_NN.py: Code for neural networks.
TAILO/main.py: Main file and entry point of our algorithm.
TAILO/dataset.py: Code for dataset handling.
TAILO/utils.py, TAILO/EMA.py, TAILO/normalizer.py: Other auxiliary code.
MuJoCo 2.1.0 (mujoco210) is required for all environments, and our experiments were run with CUDA 11.3. The required Python packages are:
d4rl == 1.1
dm-control == 1.0.7
ema-pytorch == 0.2.1
gym == 0.21.0
h5py == 2.10.0
mujoco-py == 2.1.2.14
numpy == 1.23.1
torch == 1.12.0
tqdm
wandb
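To check that an environment matches the pinned versions above, a small script like the following may help. It is a sketch, not part of the repository: `PINNED` simply restates the list above (tqdm and wandb are unpinned and therefore omitted), and `find_mismatches` is a hypothetical helper built on the standard-library `importlib.metadata`.

```python
from importlib import metadata
from typing import Callable, Dict, Optional

# Pinned versions restated from the dependency list (tqdm and wandb are unpinned).
PINNED: Dict[str, str] = {
    "d4rl": "1.1",
    "dm-control": "1.0.7",
    "ema-pytorch": "0.2.1",
    "gym": "0.21.0",
    "h5py": "2.10.0",
    "mujoco-py": "2.1.2.14",
    "numpy": "1.23.1",
    "torch": "1.12.0",
}


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of `package`, or None if it is not installed."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def find_mismatches(
    pinned: Dict[str, str],
    get_version: Callable[[str], Optional[str]] = installed_version,
) -> Dict[str, Optional[str]]:
    """Map each package that is missing or differs from its pin to the version found."""
    mismatches: Dict[str, Optional[str]] = {}
    for package, wanted in pinned.items():
        found = get_version(package)
        # Ignore local version labels such as "+cu113" on CUDA builds of torch.
        if found is None or found.split("+", 1)[0] != wanted:
            mismatches[package] = found
    return mismatches


if __name__ == "__main__":
    for package, found in find_mismatches(PINNED).items():
        print(f"{package}: expected {PINNED[package]}, found {found}")
```

An empty output means every pinned package was found at the expected version.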
By default, OpenGL is used for rendering the D4RL [2] environments. However, OpenGL can cause problems on headless machines; in that case, try setting the MUJOCO_GL environment variable to 'egl' or 'osmesa'.
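From the shell, `export MUJOCO_GL=egl` before running `python main.py ...` is usually enough. If the variable is set from Python instead, it should happen before the rendering libraries (e.g. dm_control) are imported, since they typically read it at import time. A minimal sketch; `configure_headless_gl` is a hypothetical helper, not part of the repository:

```python
import os
from typing import MutableMapping, Optional

# The two headless backends suggested above.
HEADLESS_BACKENDS = ("egl", "osmesa")


def configure_headless_gl(
    backend: str = "egl", env: Optional[MutableMapping[str, str]] = None
) -> str:
    """Set MUJOCO_GL unless a backend was already chosen; return the value in effect."""
    if backend not in HEADLESS_BACKENDS:
        raise ValueError(f"unsupported backend {backend!r}; expected one of {HEADLESS_BACKENDS}")
    env = os.environ if env is None else env
    # setdefault keeps an explicit user choice such as `MUJOCO_GL=osmesa python main.py ...`.
    return env.setdefault("MUJOCO_GL", backend)
```

Calling `configure_headless_gl("egl")` at the very top of a script, before any other imports, would apply it; if EGL is unavailable on the machine, `"osmesa"` is the software-rendering fallback.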
- Clone the repository.
- Install the dependencies listed above.
- From the TAILO directory, run the following commands:
cd envs
python generate_antmaze_random.py --noisy
cd ..
python datagen.py
The first script generates part of the task-agnostic data for the antmaze environment, and the second script generates most of the remaining datasets. The first script may take about 10 minutes and the second 30-60 minutes. The rest of the data is provided directly in the data folder.
- Find the following lines in TAILO/main.py:
wandb.login(key=XXXXXXX)
wandb.init(entity=XXXXXXX, project= ...
Replace XXXXXXX with your wandb API key and username (entity), respectively. See the official wandb documentation at https://docs.wandb.ai/ for details.
- Run the code to reproduce the results; the commands for each experiment are listed below.
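The two data-generation steps can also be scripted so that the run stops at the first failing step. This is a sketch under the assumption that it is launched from the TAILO directory; `data_generation_steps` and `run_steps` are hypothetical names, and only the two commands themselves come from this README.

```python
import subprocess
import sys
from pathlib import Path
from typing import List, Tuple

Step = Tuple[Path, List[str]]  # (working directory, command)


def data_generation_steps(tailo_dir: Path) -> List[Step]:
    """The two data-generation commands, each with the directory it runs in."""
    return [
        (tailo_dir / "envs", [sys.executable, "generate_antmaze_random.py", "--noisy"]),
        (tailo_dir, [sys.executable, "datagen.py"]),
    ]


def run_steps(steps: List[Step], dry_run: bool = False) -> None:
    for cwd, cmd in steps:
        print(f"[{cwd}] {' '.join(cmd)}")
        if not dry_run:
            # check=True raises CalledProcessError, so later steps are skipped on failure.
            subprocess.run(cmd, cwd=cwd, check=True)


if __name__ == "__main__":
    run_steps(data_generation_steps(Path(".").resolve()))
```

Using `sys.executable` runs both scripts with the same Python interpreter (and thus the same installed packages) as the driver.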
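Instead of hard-coding the wandb key in main.py, the credentials can be read from environment variables. `WANDB_API_KEY` and `WANDB_ENTITY` are variable names that wandb itself recognizes; the helper below is a hypothetical sketch, not code from the repository.

```python
import os
from typing import Mapping, Optional, Tuple


def wandb_credentials(env: Optional[Mapping[str, str]] = None) -> Tuple[str, str]:
    """Return (api_key, entity) from WANDB_API_KEY / WANDB_ENTITY, failing loudly if unset."""
    env = os.environ if env is None else env
    missing = [name for name in ("WANDB_API_KEY", "WANDB_ENTITY") if not env.get(name)]
    if missing:
        raise RuntimeError("please set the environment variable(s): " + ", ".join(missing))
    return env["WANDB_API_KEY"], env["WANDB_ENTITY"]
```

In main.py, the two XXXXXXX placeholders could then be filled via `key, entity = wandb_credentials()`, followed by `wandb.login(key=key)` and `wandb.init(entity=entity, project=...)`, keeping the remaining `init` arguments unchanged.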
python main.py --env_name $ENV --skip_suffix_TA missing$X --seed 1000000 --auto 1
Here, $X = {2, 3, 5, 10, 20}. The --auto 1 option automatically generates run names in wandb; otherwise, the run name must be entered manually.
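The full sweep over $ENV and $X can be generated rather than typed by hand. This sketch assumes $ENV ranges over the four MuJoCo environments listed in this README; `missing_commands` is a hypothetical helper that only builds the command strings.

```python
import shlex
from typing import List

ENVS = ["hopper", "halfcheetah", "walker2d", "ant"]  # assumed $ENV values
MISSING = [2, 3, 5, 10, 20]  # $X values for the missing$X datasets


def missing_commands(seed: int = 1000000) -> List[str]:
    """One main.py invocation per (environment, missing$X) combination."""
    commands = []
    for env in ENVS:
        for x in MISSING:
            args = ["python", "main.py", "--env_name", env,
                    "--skip_suffix_TA", f"missing{x}", "--seed", str(seed), "--auto", "1"]
            commands.append(shlex.join(args))  # shell-safe quoting if ever needed
    return commands
```

Each string can be printed and pasted into a shell, or run in turn with `subprocess.run(shlex.split(cmd), check=True)`.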
python main.py --env_name $ENV --skip_suffix_TS headtail${X}_${Y} --seed 1000000 --auto 1
Here, ($X, $Y) = {(0.1, 0.001), (0.09, 0.01), (0.05, 0.05), (0.01, 0.09), (0.001, 0.1)}.
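When scripting this sweep in a POSIX shell, note that `headtail$X_$Y` would expand a variable named `X_`, so the braces in `${X}_${Y}` matter. A Python sketch that sidesteps the issue and keeps the ($X, $Y) values as strings, so the suffix matches the listed values exactly; the environment list is assumed to be the four MuJoCo tasks, and `headtail_commands` is hypothetical.

```python
import shlex
from typing import List

ENVS = ["hopper", "halfcheetah", "walker2d", "ant"]  # assumed $ENV values
# ($X, $Y) pairs from the list above, kept as strings to avoid float formatting surprises.
HEADTAIL = [("0.1", "0.001"), ("0.09", "0.01"), ("0.05", "0.05"), ("0.01", "0.09"), ("0.001", "0.1")]


def headtail_commands(seed: int = 1000000) -> List[str]:
    """One main.py invocation per (environment, headtail$X_$Y) combination."""
    commands = []
    for env in ENVS:
        for x, y in HEADTAIL:
            suffix = f"headtail{x}_{y}"  # e.g. headtail0.1_0.001
            commands.append(shlex.join([
                "python", "main.py", "--env_name", env, "--skip_suffix_TS", suffix,
                "--seed", str(seed), "--auto", "1",
            ]))
    return commands
```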
python main.py --env_name $ENV --skip_suffix_TA expert40 --seed 1000000 --auto 1
Here, $ENV = {hopper, halfcheetah, walker2d, ant}. The commands for the kitchen and antmaze environments differ because of their different dataset sizes:
python main.py --env_name kitchen --gamma 0.98 --eval_interval 1000 --N 80000 --seed 1000000 --auto 1
python main.py --env_name antmaze --auto 1 --seed 1000000
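The per-environment differences above can be captured in a small lookup table. This sketch only restates the flags in these commands (the expert40 dataset for the four MuJoCo tasks, extra flags for kitchen, none for antmaze). `standard_command` is a hypothetical helper, and the sketch assumes main.py parses its flags with argparse, so their order does not matter.

```python
import shlex
from typing import Dict, List

MUJOCO_ENVS = ["hopper", "halfcheetah", "walker2d", "ant"]
# Extra flags for the environments whose commands differ, restated from above.
EXTRA_FLAGS: Dict[str, List[str]] = {
    "kitchen": ["--gamma", "0.98", "--eval_interval", "1000", "--N", "80000"],
    "antmaze": [],
}


def standard_command(env: str, seed: int = 1000000) -> str:
    """Build the command for `env`; raises KeyError for environments not listed above."""
    args = ["python", "main.py", "--env_name", env]
    if env in MUJOCO_ENVS:
        args += ["--skip_suffix_TA", "expert40"]
    else:
        args += EXTRA_FLAGS[env]
    args += ["--seed", str(seed), "--auto", "1"]
    return shlex.join(args)
```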
python main.py --env_name pointmaze --gamma 0.98 --skip_suffix_TS goal --auto 1 --eval_interval 1500 --N 170000 --seed 1000000
python main.py --env_name kitchen --eval_interval 1000 --N 80000 --skip_suffix_TS goal-kettle --seed 1000000 --auto 1
python main.py --env_name kitchen --eval_interval 1000 --N 80000 --skip_suffix_TS goal-microwave --seed 1000000 --auto 1
python main.py --env_name antmaze --skip_suffix_TS goal --auto 1 --seed 1000000
python main.py --env_name halfcheetah --PU 3 --auto 1 --seed 1000000 --skip_suffix_TS mismatch
python main.py --env_name ant --PU 3 --auto 1 --seed 1000000 --skip_suffix_TS mismatch
python main.py --env_name antmaze --PU 3 --auto 1 --seed 1000000 --skip_suffix_TS mismatch
Random + Few Expert:
python main.py --env_name $ENV --seed 1000000 --skip_suffix_TA recoilRFE --auto 1
Medium + Expert:
python main.py --env_name $ENV --seed 1000000 --skip_suffix_TA recoilME --auto 1
python main.py --env_name $ENV --seed 1000000 --auto 1
Here, $ENV = {hopper, halfcheetah, walker2d, ant}.
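The three commands above can be generated for all four environments in one loop. In this sketch, `None` stands for the last command, which passes no --skip_suffix_TA flag; `recoil_commands` is a hypothetical helper that only builds the strings.

```python
import shlex
from typing import List, Optional

ENVS = ["hopper", "halfcheetah", "walker2d", "ant"]
# Random + Few Expert, Medium + Expert, and the command without a dataset suffix.
SUFFIXES: List[Optional[str]] = ["recoilRFE", "recoilME", None]


def recoil_commands(seed: int = 1000000) -> List[str]:
    """One main.py invocation per (environment, dataset suffix) combination."""
    commands = []
    for env in ENVS:
        for suffix in SUFFIXES:
            args = ["python", "main.py", "--env_name", env, "--seed", str(seed)]
            if suffix is not None:
                args += ["--skip_suffix_TA", suffix]
            args += ["--auto", "1"]
            commands.append(shlex.join(args))
    return commands
```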
[1] Y. J. Ma, A. Shen, D. Jayaraman, and O. Bastani. SMODICE: Versatile offline imitation learning via state occupancy matching. In ICML, 2022.
[2] J. Fu, A. Kumar, O. Nachum, G. Tucker, and S. Levine. D4RL: Datasets for deep data-driven reinforcement learning. arXiv:2004.07219, 2020.