
Superpoint Transformer


Official implementation for the ICCV 2023 paper
Efficient 3D Semantic Segmentation with Superpoint Transformer
🚀⚡🔥



📌 Description

SPT is a superpoint-based transformer 🤖 architecture that efficiently ⚡ performs semantic segmentation on large-scale 3D scenes. This method includes a fast algorithm that partitions 🧩 point clouds into a hierarchical superpoint structure, as well as a self-attention mechanism to exploit the relationships between superpoints at multiple scales.
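To build a rough intuition for the idea, here is a purely illustrative toy (NOT the project's implementation; the function name and all values are made up for this sketch): it groups points into coarse "superpoints" via voxelization, mean-pools their features, and runs standard self-attention between the resulting superpoints.

# Toy illustration of the superpoint idea, assuming only PyTorch is installed.
import torch

def toy_superpoint_attention(xyz, feats, voxel=1.0, dim=32):
    # Assign each point to a coarse voxel: a crude stand-in for a superpoint partition
    keys = torch.div(xyz, voxel, rounding_mode="floor").long()
    _, sp_idx = torch.unique(keys, dim=0, return_inverse=True)
    n_sp = int(sp_idx.max()) + 1

    # Mean-pool point features into one descriptor per superpoint
    pooled = torch.zeros(n_sp, feats.shape[1]).index_add_(0, sp_idx, feats)
    counts = torch.zeros(n_sp).index_add_(0, sp_idx, torch.ones(len(feats)))
    pooled = pooled / counts.unsqueeze(1)

    # Self-attention between superpoints (a single scale, for illustration only)
    proj = torch.nn.Linear(feats.shape[1], dim)
    attn = torch.nn.MultiheadAttention(dim, num_heads=4, batch_first=True)
    x = proj(pooled).unsqueeze(0)             # (1, n_superpoints, dim)
    out, _ = attn(x, x, x)
    return out.squeeze(0), sp_idx

xyz = torch.rand(1000, 3) * 10    # random toy point cloud
feats = torch.rand(1000, 8)       # per-point features
sp_feats, assignment = toy_superpoint_attention(xyz, feats)
print(sp_feats.shape, assignment.shape)

The actual SPT pipeline differs substantially: the partition is hierarchical and geometry-driven, and attention operates on superpoint graphs at multiple scales.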

✨ SPT in numbers ✨
📊 SOTA on S3DIS 6-Fold (76.0 mIoU)
📊 SOTA on KITTI-360 Val (63.5 mIoU)
📊 Near SOTA on DALES (79.6 mIoU)
🦋 212k parameters (PointNeXt ÷ 200, Stratified Transformer ÷ 40)
⚡ S3DIS training in 3h on 1 GPU (PointNeXt ÷ 7, Stratified Transformer ÷ 70)
⚡ Preprocessing x7 faster than SPG

📰 Updates


💻 Environment requirements

This project was tested with:

  • Linux OS
  • NVIDIA GTX 1080 Ti 11G, NVIDIA V100 32G, NVIDIA A40 48G
  • CUDA 11.8 (torch-geometric does not support CUDA 12.0 yet)
  • conda 23.3.1
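Before installing, you can check that your machine matches this tested configuration with standard NVIDIA and conda commands (nvcc is only available if the CUDA toolkit is on your PATH):

# Check GPU model and driver
nvidia-smi
# Check CUDA toolkit version
nvcc --version
# Check conda version
conda --version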

๐Ÿ— Installation

Simply run install.sh to install all dependencies in a new conda environment named spt.

# Creates a conda env named 'spt' and installs dependencies
./install.sh
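Once the script completes, you can activate the environment and run a quick sanity check (a minimal sketch; the environment name spt comes from install.sh):

# Activate the new environment and verify that PyTorch sees your GPU
conda activate spt
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"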

Note: See the Datasets page for setting up your dataset path and file structure.


🔩 Project structure

└── superpoint_transformer
    │
    ├── configs                   # Hydra configs
    │   ├── callbacks                 # Callbacks configs
    │   ├── data                      # Data configs
    │   ├── debug                     # Debugging configs
    │   ├── experiment                # Experiment configs
    │   ├── extras                    # Extra utilities configs
    │   ├── hparams_search            # Hyperparameter search configs
    │   ├── hydra                     # Hydra configs
    │   ├── local                     # Local configs
    │   ├── logger                    # Logger configs
    │   ├── model                     # Model configs
    │   ├── paths                     # Project paths configs
    │   ├── trainer                   # Trainer configs
    │   │
    │   ├── eval.yaml                 # Main config for evaluation
    │   └── train.yaml                # Main config for training
    │
    ├── data                      # Project data (see docs/datasets.md)
    │
    ├── docs                      # Documentation
    │
    ├── logs                      # Logs generated by hydra and lightning loggers
    │
    ├── media                     # Media illustrating the project
    │
    ├── notebooks                 # Jupyter notebooks
    │
    ├── scripts                   # Shell scripts
    │
    ├── src                       # Source code
    │   ├── data                      # Data structure for hierarchical partitions
    │   ├── datamodules               # Lightning DataModules
    │   ├── datasets                  # Datasets
    │   ├── dependencies              # Compiled dependencies
    │   ├── loader                    # DataLoader
    │   ├── loss                      # Loss
    │   ├── metrics                   # Metrics
    │   ├── models                    # Model architecture
    │   ├── nn                        # Model building blocks
    │   ├── optim                     # Optimization
    │   ├── transforms                # Functions for transforms, pre-transforms, etc
    │   ├── utils                     # Utilities
    │   ├── visualization             # Interactive visualization tool
    │   │
    │   ├── eval.py                   # Run evaluation
    │   └── train.py                  # Run training
    │
    ├── tests                     # Tests of any kind
    │
    ├── .env.example              # Example of file for storing private environment variables
    ├── .gitignore                # List of files ignored by git
    ├── .pre-commit-config.yaml   # Configuration of pre-commit hooks for code formatting
    ├── install.sh                # Installation script
    ├── LICENSE                   # Project license
    └── README.md

Note: See the Datasets page for further details on `data/`.

Note: See the Logs page for further details on `logs/`.


🚀 Usage

Datasets

See the Datasets page to set up your datasets.

Evaluating SPT

Use the following commands to evaluate SPT from a checkpoint file checkpoint.ckpt:

# Evaluate SPT on S3DIS Fold 5
python src/eval.py experiment=s3dis datamodule.fold=5 ckpt_path=/path/to/your/checkpoint.ckpt

# Evaluate SPT on KITTI-360 Val
python src/eval.py experiment=kitti360 ckpt_path=/path/to/your/checkpoint.ckpt

# Evaluate SPT on DALES
python src/eval.py experiment=dales ckpt_path=/path/to/your/checkpoint.ckpt

Note: The pretrained weights of the SPT and SPT-nano models for S3DIS 6-Fold, KITTI-360 Val, and DALES are available for download via the project's DOI link.

Training SPT

Use the following commands to train SPT on a 32G-GPU:

# Train SPT on S3DIS Fold 5
python src/train.py experiment=s3dis datamodule.fold=5

# Train SPT on KITTI-360 Val
python src/train.py experiment=kitti360 

# Train SPT on DALES
python src/train.py experiment=dales

Use the following to train SPT on an 11G-GPU 💾 (training time and performance may vary):

# Train SPT on S3DIS Fold 5
python src/train.py experiment=s3dis_11g datamodule.fold=5

# Train SPT on KITTI-360 Val
python src/train.py experiment=kitti360_11g 

# Train SPT on DALES
python src/train.py experiment=dales_11g

Note: Encountering CUDA Out-Of-Memory errors 💀💾? See our dedicated troubleshooting section.

Note: Other ready-to-use configs are provided in configs/experiment/. You can easily design your own experiments by composing configs:

# Train Nano-3 for 50 epochs on DALES
python src/train.py datamodule=dales model=nano-3 trainer.max_epochs=50

See Lightning-Hydra for more information on how the config system works and all the awesome perks of the Lightning+Hydra combo.
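For instance, Hydra's multirun mode lets you sweep over several values of a parameter in one command (the voxel sizes below are placeholder values; see the troubleshooting table for what datamodule.voxel controls):

# Sweep over two voxel sizes on DALES with Hydra multirun
python src/train.py -m experiment=dales datamodule.voxel=0.1,0.2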

Note: By default, your logs will automatically be uploaded to Weights and Biases, from where you can track and compare your experiments. Other loggers are available in configs/logger/. See Lightning-Hydra for more information on the logging options.
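As a sketch, switching loggers amounts to overriding the logger config group (the exact config names depend on the files in configs/logger/; csv is an assumption based on the Lightning-Hydra template):

# Log to local CSV files instead of Weights and Biases (config name assumed)
python src/train.py experiment=dales logger=csv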

Notebooks & visualization

We provide notebooks to help you get started with manipulating our core data structures, loading configs, instantiating datasets and models, running inference on each dataset, and visualization.

In particular, we created an interactive visualization tool ✨ which can be used to produce shareable HTMLs. Demos of how to use this tool are provided in the notebooks. Additionally, examples of such HTML files are provided in media/visualizations.7z.


📚 Documentation

  • README - General introduction to the project
  • Data - Introduction to NAG and Data, the core data structures of this project
  • Datasets - Introduction to Datasets and the project's data/ structure
  • Logging - Introduction to logging and the project's logs/ structure

Note: We endeavoured to comment our code as much as possible to make this project usable. Still, if you find some parts unclear or feel more documentation is needed, feel free to let us know by creating an issue!


👩‍🔧 Troubleshooting

Here are some common issues and tips for tackling them.

SPT on an 11G-GPU

Our default configurations are designed for a 32G-GPU. Yet, SPT can run on an 11G-GPU 💾, with minor time and performance variations.

We provide configs in configs/experiment/ for training SPT on an 11G-GPU 💾:

# Train SPT on S3DIS Fold 5
python src/train.py experiment=s3dis_11g datamodule.fold=5

# Train SPT on KITTI-360 Val
python src/train.py experiment=kitti360_11g 

# Train SPT on DALES
python src/train.py experiment=dales_11g

CUDA Out-Of-Memory Errors

Having some CUDA OOM errors 💀💾? Here are some parameters you can play with to mitigate GPU memory use, depending on when the error occurs.

Parameters affecting CUDA memory.

Legend: 🟡 Preprocessing | 🔴 Training | 🟣 Inference (including validation and testing during training)

| Parameter | Description | When |
|---|---|---|
| `datamodule.xy_tiling` | Splits dataset tiles into `xy_tiling^2` smaller tiles, based on a regular XY grid. Ideal for square-shaped tiles à la DALES. Note this will affect the number of training steps. | 🟡🟣 |
| `datamodule.pc_tiling` | Splits dataset tiles into `2^pc_tiling` smaller tiles, based on their principal component. Ideal for varying tile shapes à la S3DIS and KITTI-360. Note this will affect the number of training steps. | 🟡🟣 |
| `datamodule.max_num_nodes` | Limits the number of $P_1$ partition nodes/superpoints in the training batches. | 🔴 |
| `datamodule.max_num_edges` | Limits the number of $P_1$ partition edges in the training batches. | 🔴 |
| `datamodule.voxel` | Increasing the voxel size will reduce preprocessing, training, and inference times, but will reduce performance. | 🟡🔴🟣 |
| `datamodule.pcp_regularization` | Regularization for the partition levels. The larger, the fewer the superpoints. | 🟡🔴🟣 |
| `datamodule.pcp_spatial_weight` | Importance of the 3D position in the partition. The smaller, the fewer the superpoints. | 🟡🔴🟣 |
| `datamodule.pcp_cutoff` | Minimum superpoint size. The larger, the fewer the superpoints. | 🟡🔴🟣 |
| `datamodule.graph_k_max` | Maximum number of adjacent nodes in the superpoint graphs. The smaller, the fewer the superedges. | 🟡🔴🟣 |
| `datamodule.graph_gap` | Maximum distance between adjacent superpoints in the superpoint graphs. The smaller, the fewer the superedges. | 🟡🔴🟣 |
| `datamodule.graph_chunk` | Reduce to avoid OOM when `RadiusHorizontalGraph` preprocesses the superpoint graph. | 🟡 |
| `datamodule.dataloader.batch_size` | Controls the number of loaded tiles. Each training batch is composed of `batch_size * datamodule.sample_graph_k` spherical samplings. Inference is performed on entire validation and test tiles, without spherical sampling. | 🔴🟣 |
| `datamodule.sample_segment_ratio` | Randomly drops a fraction of the superpoints at each partition level. | 🔴 |
| `datamodule.sample_graph_k` | Controls the number of spherical samples in the training batches. | 🔴 |
| `datamodule.sample_graph_r` | Controls the radius of spherical samples in the training batches. Set `sample_graph_r<=0` to use the entire tile without spherical sampling. | 🔴 |
| `datamodule.sample_point_min` | Controls the minimum number of $P_0$ points sampled per superpoint in the training batches. | 🔴 |
| `datamodule.sample_point_max` | Controls the maximum number of $P_0$ points sampled per superpoint in the training batches. | 🔴 |
| `callbacks.gradient_accumulator.scheduling` | Gradient accumulation. Can be used to train with smaller batches, with more training steps. | 🔴 |
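Several of these parameters can be combined in a single command. For instance, a sketch of reducing training memory on S3DIS (the values below are illustrative placeholders, not recommended settings):

# Cap superpoint/edge counts and batch size to lower GPU memory use (placeholder values)
python src/train.py experiment=s3dis datamodule.fold=5 datamodule.max_num_nodes=40000 datamodule.max_num_edges=200000 datamodule.dataloader.batch_size=2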


💳 Credits


Citing our work

If your work uses all or part of the present code, please include the following citation:

@inproceedings{robert2023spt,
  title={Efficient 3D Semantic Segmentation with Superpoint Transformer},
  author={Robert, Damien and Raguet, Hugo and Landrieu, Loic},
  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
  year={2023}
}

You can find our paper on arXiv 📄.

Also, if you like this project, don't forget to give it a ⭐, it means a lot to us!
