
grammi's Introduction

GrAMMI: Graph Based Adversarial Modeling with Mutual Information

Arxiv - IROS 2023

GrAMMI is an adversarial tracking method that combines a graph neural network with a Gaussian mixture model regularized using mutual information. It is designed to track adversarial targets in sparse, partially observable environments.
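To make the two ingredients named above concrete, here is a minimal plain-Python sketch. It is not the authors' implementation. It shows the textbook negative log-likelihood of a 2D location under a diagonal-covariance Gaussian mixture, and the standard variational lower bound on the mutual information between a categorical mode c and the data, I(c; x) >= H(c) + E[log q(c | x)], as used in InfoGAN-style training. The function names, the 2D diagonal covariance, and the uniform prior over modes are assumptions made for illustration.

```python
import math


def gmm_nll(point, weights, means, stds):
    """Negative log-likelihood of a 2D point under a diagonal-covariance
    Gaussian mixture (illustrative; not GrAMMI's code)."""
    log_terms = []
    for w, (mx, my), (sx, sy) in zip(weights, means, stds):
        # Per-axis Gaussian log-density, summed over the two axes.
        lx = -0.5 * ((point[0] - mx) / sx) ** 2 - math.log(sx) - 0.5 * math.log(2 * math.pi)
        ly = -0.5 * ((point[1] - my) / sy) ** 2 - math.log(sy) - 0.5 * math.log(2 * math.pi)
        log_terms.append(math.log(w) + lx + ly)
    # Log-sum-exp over components for numerical stability.
    m = max(log_terms)
    return -(m + math.log(sum(math.exp(t - m) for t in log_terms)))


def mi_lower_bound(posterior_probs, true_modes, num_modes):
    """Variational bound I(c; x) >= H(c) + E[log q(c | x)], assuming the
    code c is drawn uniformly from num_modes categories (hypothetical setup)."""
    entropy = math.log(num_modes)  # entropy of a uniform categorical prior
    avg_log_q = sum(math.log(p[c]) for p, c in zip(posterior_probs, true_modes)) / len(true_modes)
    return entropy + avg_log_q


# Usage: a single standard-normal component evaluated at its mean.
print(gmm_nll((0.0, 0.0), [1.0], [(0.0, 0.0)], [(1.0, 1.0)]))
```

A posterior that always recovers the mode attains the maximum bound log(num_modes); a uniform posterior gives zero, meaning the mode carries no recoverable information.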

Training Models

1. Create Conda environment

conda env create -f requirements.yaml
conda activate grammi

2. Download Datasets

Visit the Google Drive to download the datasets. We provide datasets for the Smuggler domain (high and low visibility) and the Prisoner domain (high, medium, and low visibility). To use the default options, place the datasets in a grammi_datasets directory in the same directory as the code.
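Because a wrong dataset path is an easy mistake here, a quick check before training can help. The sketch below assumes the layout grammi_datasets/<domain>_datasets/<dataset>/<split>, which is inferred only from the path printed in the training log in the issue below (grammi_datasets/prisoner_datasets/3_detect/train). The helpers dataset_split_dir and missing_splits are hypothetical and not part of the repository.

```python
from pathlib import Path


def dataset_split_dir(root, domain, dataset, split):
    """Build the assumed path <root>/<domain>_datasets/<dataset>/<split>."""
    return Path(root) / f"{domain}_datasets" / dataset / split


def missing_splits(root, domain, dataset, splits=("train", "test")):
    """Return the splits whose directories do not exist under the assumed layout."""
    return [s for s in splits if not dataset_split_dir(root, domain, dataset, s).is_dir()]


# Usage: an empty list means every expected split directory was found.
print(missing_splits("grammi_datasets", "prisoner", "3_detect"))
```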

3. Train

The models can be trained using run_multiple.py, which trains multiple seeds for a single model type and a single prediction time horizon.
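The training log in the issue below writes outputs to paths such as logs/0/3_detect/categorical/20240321-1444/, which suggests one directory per seed of the form logs/<seed>/<dataset>/<model>/<timestamp>. The sketch below reproduces that pattern; the function seed_log_dirs is hypothetical and the pattern is inferred from the log, not taken from run_multiple.py.

```python
from datetime import datetime
from pathlib import Path


def seed_log_dirs(base, seeds, dataset, model_dir, now=None):
    """One log directory per seed, following the pattern observed in the
    training output: <base>/<seed>/<dataset>/<model_dir>/<YYYYmmdd-HHMM>."""
    stamp = (now or datetime.now()).strftime("%Y%m%d-%H%M")
    return [Path(base) / str(seed) / dataset / model_dir / stamp for seed in seeds]


# Usage: a driver script would train one model per directory.
for log_dir in seed_log_dirs("logs", range(3), "3_detect", "categorical",
                             now=datetime(2024, 3, 21, 14, 44)):
    print(log_dir.as_posix())
```

Sharing one timestamp across seeds keeps the runs of a single invocation easy to group.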

4. Evaluate the model

Once the models are trained, evaluate them with evaluate/evaluate_models_mi.py, which records each model's metrics on the test set.
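When several seeds are trained, a common way to report a test-set metric is its mean and standard deviation across seeds. The sketch below assumes each seed's metric is already available as a number; which metrics evaluate_models_mi.py records, and in what format, is not described here, and summarize_across_seeds is a hypothetical helper.

```python
import statistics


def summarize_across_seeds(metric_by_seed):
    """Mean and sample standard deviation of one scalar metric over seeds.

    metric_by_seed maps a seed to that seed's test-set value (assumed format)."""
    values = list(metric_by_seed.values())
    mean = statistics.mean(values)
    std = statistics.stdev(values) if len(values) > 1 else 0.0
    return mean, std


print(summarize_across_seeds({0: 1.0, 1: 2.0, 2: 3.0}))  # (2.0, 1.0)
```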

Details of configuration

Citation

If you find our code or paper useful, please consider citing:

@inproceedings{ye2023grammi,
  title={Learning Models of Adversarial Agent Behavior under Partial Observability},
  author={Ye, Sean and Natarajan, Manisha and Wu, Zixuan and Paleja, Rohan and Chen, Letian and Gombolay, Matthew},
  booktitle={IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS)},
  year={2023}
}

License

This code is distributed under the MIT License.

grammi's People

Contributors

xunil17


grammi's Issues

RuntimeError: mat1 and mat2 shapes cannot be multiplied (1024x37 and 45x16)

While reproducing your code, I found that the dataset paths are inconsistent: sometimes the code reads from the dataset directory and sometimes from the grammi_dataset directory, so I changed the relevant paths manually. However, when run_multiple.py reaches the last line of the run_categorical_mutual_info function, an error occurs, and I could not find a solution even after tracing the code back to its source. Can you help me find the problem?
The following is the error message:

grammi_datasets/prisoner_datasets/3_detect/train
logs/0/3_detect/categorical/20240321-1444/config.yaml
Training model: categorical_mi
Dataset: 3_detect
101 logs/0/3_detect/categorical/20240321-1444/best.pth

Training on Padded Dataset w/ prisoner
Training on Padded Dataset w/ prisoner
0%| | 0/100 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/mnt/workspace/GrAMMI/run_multiple.py", line 196, in
    run_categorical_mutual_info(timestep, dataset_paths, base_log_directory)
  File "/mnt/workspace/GrAMMI/run_multiple.py", line 80, in run_categorical_mutual_info
    main_config_mutual_info(config)
  File "/mnt/workspace/GrAMMI/train_mi_gaussian_posterior_categorical.py", line 314, in main_config_mutual_info
    train(seed,
  File "/mnt/workspace/GrAMMI/train_mi_gaussian_posterior_categorical.py", line 171, in train
    post = posterior(*posterior_input)
  File "/opt/conda/envs/grammi/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/mnt/workspace/GrAMMI/train_mi_gaussian_posterior_categorical.py", line 55, in forward
    x = self.leaky_relu(self.h1(x))
  File "/opt/conda/envs/grammi/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/envs/grammi/lib/python3.9/site-packages/torch/nn/modules/linear.py", line 103, in forward
    return F.linear(input, self.weight, self.bias)
  File "/opt/conda/envs/grammi/lib/python3.9/site-packages/torch/nn/functional.py", line 1848, in linear
    return torch._C._nn.linear(input, weight, bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (1024x37 and 45x16)
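Reading the message: F.linear multiplies the input (mat1, 1024x37, a batch of 1024 rows with 37 features each) by the transpose of h1's weight (mat2, 45x16). So h1 appears to have been built as nn.Linear(45, 16) and expects 45 input features, while the posterior input built at line 171 supplies only 37. The pure-Python sketch below mimics that shape rule so a mismatch can be spotted without running PyTorch. check_linear_input is a hypothetical helper, and it illustrates the error rather than fixing it. One place worth checking, as a suggestion rather than a confirmed diagnosis, is wherever the posterior's input size is computed, to confirm it matches the dataset actually being loaded after the path changes.

```python
def check_linear_input(input_shape, in_features, out_features):
    """Mimic the shape rule of a linear layer: y = x @ W.T with W of shape
    (out_features, in_features), so x's last dimension must equal in_features.
    Returns the output shape, or raises ValueError on a mismatch."""
    batch, features = input_shape
    if features != in_features:
        raise ValueError(
            f"mat1 and mat2 shapes cannot be multiplied "
            f"({batch}x{features} and {in_features}x{out_features}): "
            f"input has {features} features but the layer expects {in_features}"
        )
    return (batch, out_features)


print(check_linear_input((1024, 45), 45, 16))  # (1024, 16)
```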
