
Cascading Decision Tree (CDT) for Explainable Reinforcement Learning

Open-source code for the paper CDT: Cascading Decision Trees for Explainable Reinforcement Learning (https://arxiv.org/abs/2011.07553).

Data folders: the data folders (data and data_no_norm) must be placed at the root of the repo before running the code (see issue #4). The folders are hosted on Google Drive.
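Based on the description above, the expected layout after downloading the folders from Google Drive is roughly (exact contents depend on which experiments you run, and the comments on the two data folders are inferred from their names):

```
cascading-decision-tree/     # repo root
├── data/                    # data collected with state normalization
├── data_no_norm/            # data collected without state normalization
├── src/
└── visual/
```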

File Structure

  • data: all data for experiments (not maintained in the repo, but can be collected with the given scripts below)
    • mlp: data for MLP model;
    • cdt: data for CDT model;
    • sdt: data for SDT model;
    • il: data for general Imitation Learning (IL);
    • rl: data for general Reinforcement Learning (RL);
    • cdt_compare_depth: data for cdt with different depths in RL;
    • sdt_compare_depth: data for sdt with different depths in RL;
  • src: source code
    • mlp: training configurations for MLP as policy function approximator;
    • cdt: the Cascading Decision Tree (CDT) class and necessary functions;
    • sdt: the Soft Decision Tree (SDT) class and necessary functions;
    • hdt: the heuristic agents;
    • il: configurations for Imitation Learning (IL);
    • rl: configurations for Reinforcement Learning (RL) and RL agents (e.g., PPO) etc;
    • utils: some common functions;
    • il_data_collect.py: collect dataset (state-action from heuristic or well-trained policy) for IL;
    • rl_data_collect.py: collect dataset (states during training for calculating normalization statistics) for RL;
    • il_train.py: train IL agent with different function approximators (e.g., SDT, CDT);
    • rl_train.py: train RL agent with different function approximators (e.g., SDT, CDT, MLP);
    • il_eval.py: evaluate the trained IL agents before and after tree discretization, based on prediction accuracy;
    • rl_eval.py: evaluate the trained RL agents before and after tree discretization, based on episodic reward;
    • il_train.sh: bash to run IL test with different models on server;
    • rl_train.sh: bash to run RL test with different models on server;
    • rl_train_compare_sdt.py: train RL agent with SDT;
    • rl_train_compare_cdt.py: train RL agent with CDT;
    • rl_train_compare_sdt.sh: bash to run RL test with SDT of different depths on server;
    • rl_train_compare_cdt.sh: bash to run RL test with CDT of different depths on server;
  • visual
    • plot.ipynb: plot learning curves, etc.
    • params.ipynb: quantitative analysis of model parameters (SDT and CDT).
    • stability_analysis.ipynb: the stability analysis in the paper, comparing tree weights across runs.

To Run

To fully replicate the experiments in the paper, run the code in several stages.

A. Reinforcement Learning Comparison with SDT, CDT and MLP

  1. Collect dataset: for state normalization

    cd ./src
    python rl_data_collect.py
  2. Get statistics on dataset

    cd rl
    jupyter notebook

    open stats.ipynb and run cells in it to generate files for dataset statistics.

    Steps 1 and 2 can be skipped if not using state normalization.

  3. Train RL agents with different policy function approximators: SDT, CDT, MLP

    cd ..
    python rl_train.py --train --env='CartPole-v1' --method='sdt' --id=0
    python rl_train.py --train --env='LunarLander-v2' --method='cdt' --id=0
    python rl_train.py --train --env='MountainCar-v0' --method='mlp' --id=0

    or simply run with:

    ./rl_train.sh
  4. Evaluate the trained agents (with discretization operation)

    python rl_eval.py --env='CartPole-v1' --method='sdt'
    python rl_eval.py --env='LunarLander-v2' --method='cdt'
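The discretization step turns each learned soft (linear) decision node into a hard, axis-aligned split so that the tree becomes directly interpretable. A minimal illustrative sketch of one common way to do this, keeping only the dominant feature per node (the function name and details are hypothetical, not the repo's API):

```python
import numpy as np

def discretize_node(w, b):
    """Replace a soft split sigmoid(w @ x + b) with an axis-aligned one.

    Keeps only the feature with the largest absolute weight and rescales
    the bias to match. Illustrative only.
    """
    j = int(np.argmax(np.abs(w)))
    w_hard = np.zeros_like(w, dtype=float)
    w_hard[j] = np.sign(w[j])      # unit weight on the dominant feature
    b_hard = b / abs(w[j])         # rescale bias to the unit weight
    return w_hard, b_hard
```

rl_eval.py then compares episodic reward before and after such a discretization.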
  5. Results visualization

    cd ../visual
    jupyter notebook

    See plot.ipynb.

B. Imitation Learning Comparison with SDT and CDT

  1. Collect dataset: for (1) state normalization and (2) as imitation learning dataset

    cd ./src
    python il_data_collect.py
  2. Train IL agents with different policy function approximators: SDT, CDT

    python il_train.py --train --env='CartPole-v1' --method='sdt' --id=0
    python il_train.py --train --env='LunarLander-v2' --method='cdt' --id=0

    or simply run with:

    ./il_train.sh
  3. Evaluate the trained agents

    python il_eval.py --env='CartPole-v1' --method='sdt'
    python il_eval.py --env='LunarLander-v2' --method='cdt'
  4. Results visualization

    cd ../visual
    jupyter notebook
    

    See plot.ipynb.

B'. Imitation Learning with DAGGER and Q-DAGGER

The DAGGER and Q-DAGGER methods from VIPER are also compared in the paper under the imitation learning setting. The code is in ./src/viper/. Credit goes to Hangrui (Henry) Bi.

C. Tree Depths for SDT and CDT in Reinforcement Learning

Run the comparison with different tree depths:

For SDT:

./rl_train_compare_sdt.sh

For CDT:

./rl_train_compare_cdt.sh

D. Stability Analysis

Compare the tree weights of different agents in IL:

cd ./visual
jupyter notebook

See stability_analysis.ipynb.
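One simple way to compare tree weights across differently seeded agents, in the spirit of the notebook, is cosine similarity between flattened weight vectors; a hedged sketch (not the notebook's actual code):

```python
import numpy as np

def weight_similarity(w1, w2):
    """Cosine similarity between two flattened tree weight vectors."""
    w1, w2 = np.ravel(w1), np.ravel(w2)
    return float(w1 @ w2 / (np.linalg.norm(w1) * np.linalg.norm(w2)))
```

A value near 1 indicates that two runs learned near-identical splitting weights.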

E. Model Simplicity

Quantitative analysis of the number of model parameters:

cd ./visual
jupyter notebook

See params.ipynb.
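The parameter counts compared in params.ipynb can also be estimated analytically. For a fully grown soft decision tree of depth d, each of the 2^d - 1 inner nodes holds one linear split (state_dim weights plus a bias) and each of the 2^d leaves holds one action-distribution vector; a sketch under that standard-SDT assumption (not tied to the repo's exact implementation):

```python
def sdt_param_count(depth: int, state_dim: int, n_actions: int) -> int:
    """Analytical parameter count for a fully grown soft decision tree."""
    inner_nodes = 2 ** depth - 1   # each: state_dim weights + 1 bias
    leaves = 2 ** depth            # each: one logit per action
    return inner_nodes * (state_dim + 1) + leaves * n_actions
```

For example, a depth-2 SDT on CartPole (4-dim state, 2 actions) has 3*5 + 4*2 = 23 parameters.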

Citation:

@article{ding2020cdt,
  title={{CDT}: Cascading decision trees for explainable reinforcement learning},
  author={Ding, Zihan and Hernandez-Leal, Pablo and Ding, Gavin Weiguang and Li, Changjian and Huang, Ruitong},
  journal={arXiv preprint arXiv:2011.07553},
  year={2020}
}
