fubowen1229 / speedplusbaseline

This project forked from tpark94/speedplusbaseline

PyTorch implementations of CNNs and domain-bridging algorithms used in baseline studies of SPEED+ dataset

License: MIT License

SPEED+: Next-Generation Dataset for Spacecraft Pose Estimation across Domain Gap

This repository is developed by Tae Ha "Jeff" Park at Space Rendezvous Laboratory (SLAB) of Stanford University.

  • [2021.12.02] Our paper will be presented at the 2022 IEEE Aerospace Conference! This repository has been updated for our latest draft, which will soon be available on arXiv.

Introduction

This is the official repository of the baseline studies conducted in our paper titled SPEED+: Next-Generation Dataset for Spacecraft Pose Estimation across Domain Gap. It consists of the official PyTorch implementations of the following CNN models:

  • Keypoint Regression Network (KRN) [arXiv]
  • Spacecraft Pose Network (SPN) [arXiv]

The SPN implementation follows the original TensorFlow and MATLAB work by Sumant Sharma. The repository also supports the following algorithms for the KRN model:

Currently Unavailable Features

The SPEED+ dataset has been released and is currently used for the Satellite Pose Estimation Competition (SPEC2021), which is planned to last until the end of March 2022. To keep the competition fair, some items necessary to reproduce our results will not be released until the competition concludes at the earliest. These include:

  • Keypoint data used to train KRN (src/utils/tangoPoints.mat)
  • CSV files containing bounding box and keypoint labels (KRN) or spacecraft attitude classes (SPN) for all domains of SPEED+

However, you can still create your own data, generate the CSV files, and train and test the models.
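If you build your own labels, KRN-style keypoint and bounding-box labels can in principle be obtained by projecting the 3D keypoints into each image using the ground-truth pose and a pinhole camera model. The sketch below is a hypothetical illustration, not the repository's preprocess.py: it assumes a rotation matrix `R` mapping target-body coordinates into the camera frame, the target's translation `t` in the camera frame, and a camera intrinsic matrix `K`. The function names are made up for this example, and you should check the conventions against the SPEED+ camera definition and quaternion convention before relying on it.

```python
import numpy as np


def project_keypoints(points_3d, R, t, K):
    """Project 3D keypoints (3 x N, target body frame) to pixel coordinates (2 x N).

    Assumed (hypothetical) convention: p_cam = R @ p_body + t, with the camera
    z-axis along the boresight and all keypoints in front of the camera.
    """
    points_3d = np.asarray(points_3d, dtype=float)
    # Transform keypoints from the target body frame into the camera frame
    p_cam = R @ points_3d + np.asarray(t, dtype=float).reshape(3, 1)
    # Pinhole projection: divide by depth, then apply focal lengths and principal point
    p_norm = p_cam[:2] / p_cam[2]
    return K[:2, :2] @ p_norm + K[:2, 2:3]


def bounding_box(uv, width, height):
    """Tight [xmin, ymin, xmax, ymax] box around projected keypoints, clipped to the image."""
    xmin, ymin = uv.min(axis=1)
    xmax, ymax = uv.max(axis=1)
    xmin, xmax = np.clip([xmin, xmax], 0, width - 1)
    ymin, ymax = np.clip([ymin, ymax], 0, height - 1)
    return np.array([xmin, ymin, xmax, ymax])
```

A box fitted to keypoints only approximates the spacecraft's true image extent, so you may want to pad it before using it as a crop region.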

Installation

The code is developed and tested with Python 3.7 on Ubuntu 20.04. It is implemented with PyTorch 1.8.0, and the models were trained on a single NVIDIA GeForce RTX 2080 Ti 12GB GPU.

  1. Install PyTorch.

  2. Clone this repository. Its full path ($PROJROOT) should be specified for --projroot in config.py.

  3. Install dependencies:

pip install -r requirements.txt
  4. Download SPEED+. Its full path ($DATAROOT) should be specified for --dataroot in config.py.

  5. Place the appropriate CSV files under $DATAROOT/{domain}/splits_{model}/. For example, the CSV files for the synthetic training and validation sets for KRN should be placed under $DATAROOT/synthetic/splits_krn/. See Pre-processing below to create your own CSV files.

  6. Download the pre-trained AlexNet weights (bvlc_alexnet.npy) from here and place them under $PROJROOT/checkpoints/pretrained/ for use with SPN.
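Before training, it can help to confirm that the files sit where the steps above expect them. The helper below is a hypothetical sketch, not part of the repository: it builds the expected paths from $DATAROOT and $PROJROOT following the layout described above ($DATAROOT/{domain}/splits_{model}/ for CSV files, $PROJROOT/checkpoints/pretrained/bvlc_alexnet.npy for SPN) and reports which ones are missing.

```python
from pathlib import Path


def expected_paths(dataroot, projroot, domain, model, csv_names):
    """Return the file paths the setup steps expect (hypothetical helper)."""
    dataroot, projroot = Path(dataroot), Path(projroot)
    # CSV splits live under $DATAROOT/{domain}/splits_{model}/
    paths = [dataroot / domain / f"splits_{model}" / name for name in csv_names]
    # SPN additionally needs the pre-trained AlexNet weights
    if model == "spn":
        paths.append(projroot / "checkpoints" / "pretrained" / "bvlc_alexnet.npy")
    return paths


def missing_paths(paths):
    """List the expected paths that do not exist on disk."""
    return [p for p in paths if not p.exists()]
```

For example, `missing_paths(expected_paths("/path/to/speedplus", ".", "synthetic", "krn", ["train.csv", "validation.csv"]))` returns an empty list once both KRN splits are in place; the paths here are placeholders for your own $DATAROOT and $PROJROOT.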

Pre-processing

  1. First, recover 11 keypoints as described in this paper. The order of the keypoints does not matter as long as you use it consistently. Save them as a [3 x 11] array in a variable named tango3Dpoints in src/utils/tangoPoints.mat. If you save the file elsewhere, specify its location relative to $PROJROOT with --keypts_3d_model in config.py.

  2. For SPN, the attitude classes are provided at src/utils/attitudeClasses.mat.

  3. Pre-processing is done with preprocess.py. Specify the arguments below when running the script; it converts the JSON file at $DATAROOT/{domain}/{jsonfile} to $DATAROOT/{domain}/{csvfile}.

Argument        Description
--model_name    Model name, krn or spn (e.g. krn)
--domain        Dataset domain (e.g. synthetic)
--jsonfile      JSON file to convert (e.g. train.json)
--csvfile       CSV file to write (e.g. splits_krn/train.csv)

For example, to create the KRN CSV file for the SPEED+ synthetic training set, run

python preprocess.py --model_name krn --domain synthetic --jsonfile train.json --csvfile splits_krn/train.csv
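For the first pre-processing step, a minimal sketch assuming SciPy is available: `scipy.io.savemat` stores the [3 x 11] keypoint array under the variable name tango3Dpoints, and `scipy.io.loadmat` reads it back for a shape check. The helper names are hypothetical, and any coordinates you pass in must be the keypoints you recovered yourself; none are provided here.

```python
import numpy as np
from scipy.io import loadmat, savemat


def save_keypoints(path, points_3d):
    """Save 3D keypoints as a [3 x 11] array in the variable tango3Dpoints."""
    points_3d = np.asarray(points_3d, dtype=np.float64)
    if points_3d.shape != (3, 11):
        raise ValueError(f"expected shape (3, 11), got {points_3d.shape}")
    savemat(path, {"tango3Dpoints": points_3d})


def load_keypoints(path):
    """Read the keypoints back and verify their shape."""
    points_3d = loadmat(path)["tango3Dpoints"]
    if points_3d.shape != (3, 11):
        raise ValueError(f"unexpected shape {points_3d.shape} in {path}")
    return points_3d
```

Run from $PROJROOT, `save_keypoints("src/utils/tangoPoints.mat", my_points)` writes the file to the default location, where `my_points` is a placeholder for your own keypoint array, one column per keypoint in a fixed order.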

Training & Testing

Use the arguments below to toggle some settings on or off:

Argument              Description
--no_cuda             Disable GPU training
--use_fp16            Use mixed-precision training
--randomize_texture   Perform style augmentation online during training
--perform_dann        Perform domain adaptation via DANN

Note that the networks in this repository were not trained with mixed precision, but mixed precision is recommended if your GPU has Tensor Cores, as it can speed up training.

To train KRN on SPEED+ synthetic training set:

python train.py --savedir 'checkpoints/krn/synthetic_only' \
                --logdir 'log/krn/synthetic_only' \
                --model_name 'krn' --input_shape 224 224 \
                --batch_size 48 --max_epochs 75 \
                --optimizer 'adamw' --lr 0.001 \
                --weight_decay 0.01 --lr_decay_alpha 0.95 \
                --train_domain 'synthetic' --test_domain 'synthetic' \
                --train_csv 'train.csv' --test_csv 'test.csv'

Add --randomize_texture to train with style augmentation.
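The --lr, --lr_decay_alpha, and --lr_decay_step flags suggest a step-wise exponential decay. Assuming the learning rate is multiplied by lr_decay_alpha once every lr_decay_step epochs (an assumption; check train.py for the actual rule and for the default step when the flag is omitted), the schedule is lr(e) = lr * alpha^floor(e / step), sketched here:

```python
def step_decay_lr(epoch, base_lr=1e-3, alpha=0.95, step=10):
    """Learning rate at a 0-indexed epoch under an assumed step-wise decay.

    The rate is multiplied by `alpha` once every `step` epochs.
    """
    if step <= 0:
        raise ValueError("step must be positive")
    return base_lr * alpha ** (epoch // step)
```

In plain PyTorch, the same rule corresponds to `torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.95)` with `scheduler.step()` called once per epoch.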

To test KRN on synthetic validation images:

python test.py --pretrained 'checkpoints/krn/synthetic_only/model_best.pth.tar' \
               --logdir 'log/krn/synthetic_only' --resultfn 'results.txt' \
               --model_name 'krn' --input_shape 224 224 \
               --test_domain 'synthetic' --test_csv 'validation.csv'

which will write the test results to $PROJROOT/log/krn/synthetic_only/results.txt.
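For orientation, SPEED-style benchmarks such as SPEC2021 score a pose by its rotation error, 2 arccos(|q_est . q_gt|) in radians, plus its translation error normalized by the ground-truth distance, ||t_est - t_gt|| / ||t_gt||. The sketch below computes these quantities under the assumption that both quaternions use the same component order; it does not reproduce any thresholds or the exact contents of results.txt, which are not described here.

```python
import numpy as np


def rotation_error(q_est, q_gt):
    """Angle in radians between two quaternions given in the same component order."""
    q_est = np.asarray(q_est, dtype=float)
    q_gt = np.asarray(q_gt, dtype=float)
    q_est = q_est / np.linalg.norm(q_est)
    q_gt = q_gt / np.linalg.norm(q_gt)
    # |dot| handles the q / -q ambiguity; clipping guards arccos against round-off
    dot = np.clip(abs(np.dot(q_est, q_gt)), 0.0, 1.0)
    return 2.0 * np.arccos(dot)


def translation_error(t_est, t_gt):
    """Euclidean translation error normalized by the ground-truth distance."""
    t_est = np.asarray(t_est, dtype=float)
    t_gt = np.asarray(t_gt, dtype=float)
    return np.linalg.norm(t_est - t_gt) / np.linalg.norm(t_gt)


def pose_score(q_est, t_est, q_gt, t_gt):
    """Combined score: rotation error (rad) plus normalized translation error."""
    return rotation_error(q_est, q_gt) + translation_error(t_est, t_gt)
```

Normalizing the translation error by the target distance keeps far-away and close-range images on a comparable scale.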

To train KRN with DANN, adapting from the synthetic to the lightbox domain, and test it on lightbox images:

python adapt.py --savedir 'checkpoints/krn/dann_lightbox' \
                --logdir 'log/krn/dann_lightbox' --resultfn 'results.txt' \
                --model_name 'krn' --input_shape 224 224 \
                --batch_size 16 --max_epochs 750 --test_epoch 50 \
                --optimizer 'adamw' --lr 0.001 \
                --weight_decay 0.01 --lr_decay_alpha 0.95 --lr_decay_step 10 \
                --train_domain 'synthetic' --test_domain 'lightbox' \
                --train_csv 'train.csv' --test_csv 'lightbox.csv' \
                --perform_dann

which currently assumes that lightbox.csv, including test labels, is available for periodic validation. To skip testing entirely, comment out the relevant parts of adapt.py.
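DANN trains a domain classifier through a gradient reversal layer, which is the identity in the forward pass and multiplies the incoming gradient by -lambda in the backward pass. The original DANN work by Ganin et al. ramps lambda from 0 toward 1 as lambda_p = 2 / (1 + exp(-gamma * p)) - 1, where p in [0, 1] is the training progress and gamma = 10. Whether adapt.py uses this schedule is an assumption; the sketch below only illustrates the ramp and the reversal rule in plain Python (in PyTorch the reversal is typically written as a custom `torch.autograd.Function`).

```python
import math


def dann_lambda(progress, gamma=10.0):
    """Gradient-reversal weight for training progress in [0, 1] (DANN-style ramp-up)."""
    progress = min(max(progress, 0.0), 1.0)
    return 2.0 / (1.0 + math.exp(-gamma * progress)) - 1.0


def reverse_gradient(grad, lam):
    """Backward rule of a gradient reversal layer: scale the incoming gradient by -lam."""
    return [-lam * g for g in grad]
```

With --max_epochs 750, for instance, progress at epoch e could be taken as e / 750, so the domain loss barely influences the feature extractor early on and reaches nearly full weight by the end of training.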

License

The SPEED+ baseline studies repository is released under the MIT License.

Citation

If you find this repository and the SPEED+ dataset helpful in your research, please cite the paper below along with the dataset itself.

@article{park2021speedplus,
        title={{SPEED}+: Next-Generation Dataset for Spacecraft Pose Estimation across Domain Gap},
        author={Park, Tae Ha and M{\"a}rtens, Marcus and Lecuyer, Gurvan and Izzo, Dario and D'Amico, Simone},
        journal={arXiv preprint arXiv:2110.03101},
        year={2021},
}

KRN was introduced in the following paper:

@inproceedings{park2019krn,
	author={Park, Tae Ha and Sharma, Sumant and D'Amico, Simone},
	booktitle={2019 AAS/AIAA Astrodynamics Specialist Conference, Portland, Maine},
	title={Towards Robust Learning-Based Pose Estimation of Noncooperative Spacecraft},
	year={2019},
	month={August 11-15}
}

SPN was introduced in the following paper:

@inproceedings{sharma2019spn,
	author={Sharma, Sumant and D'Amico, Simone},
	booktitle={2019 AAS/AIAA Space Flight Mechanics Meeting, Ka'anapali, Maui, HI},
	title={Pose Estimation for Non-Cooperative Spacecraft Rendezvous Using Neural Networks},
	year={2019},
	month={January 13-17}
}
