
Towards Explainable Multi-modal Motion Prediction using Graph Representations


This repository contains code for "Towards Explainable Motion Prediction using Heterogeneous Graph Representations" by Sandra Carrasco Limeros, Sylwia Majchrowska, Joakim Johnander, Christoffer Petersson and David Fernández Llorca, 2022.

@misc{Carrasco:22b,
  doi = {10.48550/ARXIV.2212.03806},
  url = {https://arxiv.org/abs/2212.03806},
  author = {Carrasco Limeros, Sandra and Majchrowska, Sylwia
            and Johnander, Joakim and Petersson, Christoffer
            and Llorca, David Fernández},
  title = {Towards Explainable Motion Prediction using Heterogeneous Graph Representations},
  publisher = {arXiv},
  year = {2022}
}

Note: This repository is built on the PGP repository.

Installation

  1. Clone this repository

  2. Set up a new conda environment

conda create --name xscout python=3.7.10
  3. Install dependencies
conda activate xscout

# nuScenes devkit
pip install nuscenes-devkit

# Pytorch: The code has been tested with Pytorch 1.7.1, CUDA 10.1, but should work with newer versions
conda install pytorch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2 cudatoolkit=10.1 -c pytorch

# Additional utilities
pip install ray
pip install psutil
pip install scipy
pip install positional-encodings
pip install imageio
pip install tensorboard
pip install dgl-cu101

Dataset

  1. Download the nuScenes dataset. For this project, we only need the following:

    • Metadata for the Trainval split (v1.0)
    • Map expansion pack (v1.3)
  2. Organize the nuScenes root directory as follows

└── nuScenes/
    ├── maps/
    |   ├── basemaps/
    |   ├── expansion/
    |   ├── prediction/
    |   ├── 36092f0b03a857c6a3403e25b4b7aab3.png
    |   ├── 37819e65e09e5547b8a3ceaefba56bb2.png
    |   ├── 53992ee3023e5494b90c316c183be829.png
    |   └── 93406b464a165eaba6d9de76ca09f5da.png
    └── v1.0-trainval
        ├── attribute.json
        ├── calibrated_sensor.json
        ...
        └── visibility.json         
  3. Run the following script to extract pre-processed data. This speeds up training significantly.
python preprocess.py -c configs/preprocess_nuscenes.yml -r path/to/nuScenes/root/directory -d path/to/directory/with/preprocessed/data

Alternatively, you can download the preprocessed data from this link.
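To catch path mistakes before preprocessing, the directory layout above can be sanity-checked with a short script. This is a sketch, not part of the repository; `check_nuscenes_root` is a hypothetical helper.

```python
import os

# Sub-directories the layout above expects under the nuScenes root.
REQUIRED_DIRS = [
    "maps/basemaps",
    "maps/expansion",
    "maps/prediction",
    "v1.0-trainval",
]

def check_nuscenes_root(root):
    """Return the expected sub-directories that are missing under `root`."""
    return [d for d in REQUIRED_DIRS if not os.path.isdir(os.path.join(root, d))]

# Example: report anything missing before running preprocess.py
missing = check_nuscenes_root("path/to/nuScenes/root/directory")
if missing:
    print("Missing:", ", ".join(missing))
```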

Evaluation

You can download the trained model weights using this link.

To evaluate on the nuScenes val set, run the following script. This will generate a text file with evaluation metrics in the specified output directory. The results should match the benchmark entry on Eval.ai.

python evaluate.py -c configs/xscout_pgp.yml -r path/to/nuScenes/root/directory -d path/to/directory/with/preprocessed/data -o path/to/output/directory -w path/to/trained/weights
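The generated metrics file can then be read programmatically. The snippet below is a sketch that assumes a plain `metric: value` line format; the actual file layout may differ, and `parse_metrics` is a hypothetical helper, not part of this repository.

```python
def parse_metrics(path):
    """Parse lines of the form 'metric_name: value' into a dict.

    Assumes a plain-text 'name: value' layout; adjust to the actual
    format produced by evaluate.py.
    """
    metrics = {}
    with open(path) as f:
        for line in f:
            name, sep, value = line.partition(":")
            if not sep:
                continue  # skip lines without a 'name: value' shape
            try:
                metrics[name.strip()] = float(value)
            except ValueError:
                pass  # skip non-numeric values
    return metrics
```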

Visualization

To visualize predictions, run the following script. This will generate GIFs for a set of instance tokens (track IDs) from the nuScenes val set in the specified output directory.

python visualize.py -c configs/xscout_pgp.yml -r path/to/nuScenes/root/directory -d path/to/directory/with/preprocessed/data -o path/to/output/directory -w path/to/trained/weights 

You can indicate the number of modes and the future temporal horizon to visualize with --num_modes and --tf, respectively.

Training

To train the model from scratch, run

python train.py -c configs/xscout_pgp.yml -r path/to/nuScenes/root/directory -d path/to/directory/with/preprocessed/data -o path/to/output/directory -n 100

The training script saves checkpoints and TensorBoard logs in the output directory. A Weights & Biases (wandb) logger is also supported; you need to specify the entity and project in the wandb.init call in train.py. If you do not want to log to wandb, use the --nowandb argument.
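As a sketch of how the --nowandb flag can gate logging (the entity and project names are placeholders you must replace with your own; the wandb calls are commented out so the snippet stands alone):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--nowandb", action="store_true",
                    help="disable Weights & Biases logging")
# Simulates: python train.py ... --nowandb
args = parser.parse_args(["--nowandb"])

use_wandb = not args.nowandb
if use_wandb:
    # import wandb
    # wandb.init(entity="your-entity", project="your-project")  # fill in yours
    pass
```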

To launch tensorboard, run

tensorboard --logdir=path/to/output/directory/tensorboard_logs

Robustness analysis

This repository also contains the code to reproduce the robustness analysis (Section IV) presented in "Towards Trustworthy Multi-Modal Motion Prediction: Evaluation and Interpretability" by Sandra Carrasco, Sylwia Majchrowska, Joakim Johnander, Christoffer Petersson and David Fernández Llorca, presented at .. 2022.

You can download the PGP trained model weights using this link.

To evaluate on the nuScenes val set, you can set the probability of randomly masking out dynamic objects and/or lanes via the agent_mask_p_veh, agent_mask_p_ped and lane_mask_prob arguments in the configuration file configs/pgp_gatx2_lvm_traversal.yml. Use the mask_frames_p argument to set the probability of masking out random frames of interacting agents.

python evaluate.py -c configs/pgp_gatx2_lvm_traversal.yml -r path/to/nuScenes/root/directory -d path/to/directory/with/preprocessed/data -o path/to/output/directory -w path/to/trained/weights
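For example, the masking probabilities might be set in the configuration file like this (a hypothetical excerpt: the key names come from this README, the values are purely illustrative):

```yaml
agent_mask_p_veh: 0.2   # probability of masking out each surrounding vehicle
agent_mask_p_ped: 0.2   # probability of masking out each pedestrian
lane_mask_prob: 0.1     # probability of masking out each lane
mask_frames_p: 0.1      # probability of masking random frames of interacting agents
```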
