
Vox2Cortex and related methods

Note: This repository has been refactored entirely; you can find the old Vox2Cortex repo here.

This repository implements several mesh-based segmentation methods for the cortex and abdominal organs, namely:

  • Vox2Cortex (CVPR 2022)
  • V2C-Flow (Medical Image Analysis 2024)
  • UNetFlow (Scientific Reports 2023)

See the Citation section below for the corresponding publications.

Installation

  1. Make sure you use Python 3.9.
  2. Clone this (Vox2Cortex) repo:

    git clone git@github.com:ai-med/Vox2Cortex.git
    cd Vox2Cortex/

  3. Create the conda environment:

    conda env create -f requirements.yml
    conda activate vox2organ

  4. Clone and install our pytorch3d fork as described therein (basically, running pip install -e . in the cloned pytorch3d repo; see the sketch below).
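For reference, installing the fork typically boils down to the following (the clone URL is a placeholder; use the fork linked above):

    git clone <URL-of-our-pytorch3d-fork>
    cd pytorch3d/
    pip install -e .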

Usage

You can include new datasets directly in vox2organ/data/supported_datasets.py. It is generally assumed that subject data (comprising image, meshes, and segmentation maps) is stored in the form data-raw-directory/sample-ID/subject-data; see the illustration below. Currently, the mapping from world to image space must be identical for all images. This can be achieved by affine registration of all input images to a common template, e.g., with niftyreg, and applying the computed affine transformation to the respective meshes. See the preprocessing/ directory for preprocessing scripts.
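As an illustration, a raw data directory could be organized as follows (subject IDs and file names are hypothetical; the exact names expected are defined by the dataset classes in vox2organ/data/):

    data-raw-directory/
        subject-001/
            mri.nii.gz
            lh_white.ply, rh_white.ply, ...
            segmentation.nii.gz
        subject-002/
            ...

A sketch of the affine registration step with niftyreg (file names again hypothetical):

    # Register a subject image to the common template and store the affine transformation
    reg_aladin -ref template.nii.gz -flo mri.nii.gz -aff affine.txt -res mri_registered.nii.gz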

Inference

We provide a pre-trained V2C-Flow-S model in vox2organ/pretrained_models/V2C-Flow-S-ADNI. For inference with this model, we recommend copying it to an experiment dir first.

mkdir experiments
cp -r vox2organ/pretrained_models/V2C-Flow-S-ADNI experiments/V2C-Flow-S-ADNI
cd vox2organ
python main.py --test -n V2C-Flow-S-ADNI --dataset YOUR_DATASET

Training

Training V2C-Flow on a new dataset, with subsequent model testing, can be started with

    cd vox2organ/
    python3 main.py --train --test --group "V2C-Flow-S" --dataset YOUR_DATASET

We recommend using the pre-trained V2C-Flow model as a starting point for cortex reconstruction to shorten training time and save resources, i.e.,

    python3 main.py --train --test --group "V2C-Flow-S" --dataset YOUR_DATASET --pretrained_model pretrained_models/V2C-Flow-S-ADNI/best.pt

For information about command-line options, see

    python3 main.py --help

Models and parameters

Training a UNetFlow model works similarly; see vox2organ/params/groups.py for implemented models and the example below. A list of all available parameters and their default values can be found in vox2organ/params/default.py. Parameters are overridden in the following order of precedence: CLI -> vox2organ/main.py -> vox2organ/params/groups.py -> vox2organ/params/default.py. That is, a parameter specified on the command line overrides one set in main.py, which in turn overrides parameter groups and default parameters.
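For instance, a UNetFlow training run could look like this (assuming a parameter group of this name is defined in vox2organ/params/groups.py):

    cd vox2organ/
    python3 main.py --train --test --group "UNetFlow" --dataset YOUR_DATASET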

Templates

A couple of mesh templates for the cortex and the abdomen are in supplementary_material/; new ones can also be added, of course.

Docker

We provide files for creating a Docker image in the docker/ directory; a usage sketch is given below.
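For example, building and running the image might look like the following (the image name, Dockerfile location, and flags are assumptions; check the files in docker/ for the intended usage):

    docker build -t vox2organ -f docker/Dockerfile .
    docker run --gpus all -it vox2organ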

Debugging

For debugging, it is usually helpful to start training/testing on a few samples (N) with the command-line arguments -n debug --overfit [N]. This disables wandb logging and writes output to a "debug" experiment; see the example below.
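For example (flags as described above; the sample count and dataset name are placeholders):

    python3 main.py --train -n debug --overfit 2 --dataset YOUR_DATASET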

Coordinate convention

The coordinate convention is the following:

  • Input/output meshes should be stored in scanner RAS coordinates. A simple check can be performed by loading an image/segmentation and the corresponding meshes in 3D Slicer, selecting "RAS" as the coordinate convention for the meshes. Note that FreeSurfer surfaces are, by default, stored in tkrRAS coordinates (see, for example, this link); conversion from tkrRAS to scanner RAS can be done with mris_convert --to-scanner input-surf output-surf
  • Internally, mesh coordinates are converted to image coordinates normalized by the image dimensions so that they fit the requirements of torch.nn.functional.grid_sample. A sample code snippet documenting this convention is provided below.
import torch
import torch.nn.functional as F

# 3x3x3 volume with a single nonzero voxel at (z, y, x) = (0, 1, 2)
a = torch.tensor([[[0, 0, 0], [0, 0, 1], [0, 0, 0]],
                  [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
                  [[0, 0, 0], [0, 0, 0], [0, 0, 0]]]).float()
c = torch.nonzero(a).float() - 1  # voxel indices (0..2) -> coords in [-1, 1]
c = torch.flip(c, dims=[1])  # (z, y, x) --> (x, y, z), as expected by grid_sample
a = a[None][None]  # add batch and channel dimensions: (1, 1, 3, 3, 3)
c = c[None][None][None]  # sampling grid of shape (1, 1, 1, 1, 3)
print(F.grid_sample(a, c, align_corners=True))  # recovers the voxel value 1

Output:

tensor([[[[[1.]]]]])

Normal convention

The normal convention of input meshes should follow the convention used in most libraries like pytorch3d or trimesh. That is, the face indices are ordered such that the normal of a face with vertex indices (i, j, k) is computed as (vj - vi) x (vk - vi). A minimal example is given below.
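The following sketch illustrates the convention for a single (arbitrary) triangle:

import torch

# Vertices of a triangle with vertex indices (i, j, k)
vi = torch.tensor([0., 0., 0.])
vj = torch.tensor([1., 0., 0.])
vk = torch.tensor([0., 1., 0.])

# Face normal according to the convention: (vj - vi) x (vk - vi)
normal = torch.cross(vj - vi, vk - vi, dim=0)
print(normal)  # tensor([0., 0., 1.]): counter-clockwise order yields +z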

Citation

If you find this work useful, please cite (depending on the model used):

@InProceedings{Bongratz2022Vox2Cortex,
	author    = {Bongratz, Fabian and Rickmann, Anne-Marie and P\"olsterl, Sebastian and Wachinger, Christian},
	title     = {Vox2Cortex: Fast Explicit Reconstruction of Cortical Surfaces From 3D MRI Scans With Geometric Deep Neural Networks},
	booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
	month     = {June},
	year      = {2022},
	pages     = {20773-20783}
}
@article{Bongratz2023Abdominal,
	year = {2023},
	month = oct,
	publisher = {Springer Science and Business Media {LLC}},
	volume = {13},
	number = {1},
	author = {Fabian Bongratz and Anne-Marie Rickmann and Christian Wachinger},
	title = {Abdominal organ segmentation via deep diffeomorphic mesh deformations},
	journal = {Scientific Reports}
}
@article{Bongratz2024V2CFlow,
	title = {Neural deformation fields for template-based reconstruction of cortical surfaces from MRI},
	volume = {93},
	ISSN = {1361-8415},
	journal = {Medical Image Analysis},
	publisher = {Elsevier BV},
	author = {Bongratz, Fabian and Rickmann, Anne-Marie and Wachinger, Christian},
	year = {2024},
	month = apr,
	pages = {103093}
}


Issues

About the template

Dear author, thank you for sharing such excellent work; I have learned a lot from it.

I would like to ask how to generate templates for other organs myself. I looked it up online and read that FreeSurfer is a brain processing tool.

Looking forward to your reply.

weights for segmentations?

Hello, I noticed that there are three final output maps for segmentation (because of deep supervision). Could you please post the weights you used for the three segmentation losses? And why is N_V_CLASSES (the number of vertex classes to distinguish) 2? Why should it include the background? Thanks a lot!

loss = self.training_step(model, data, iteration) seems to fail for larger meshes (GPU memory explosion)

It seems that the training step

    loss = self.training_step(model, data, iteration)

is intractable for meshes in our dataset with the current settings.

RuntimeError: CUDA out of memory. Tried to allocate 16.00 MiB (GPU 0; 31.75 GiB total capacity; 30.37 GiB already allocated; 13.94 MiB free; 30.52 GiB reserved in total by PyTorch)

I added exit points before and after this line of code, and it is the line that fails. I assume it must be the loss function calculation, since I don't appear to have that much data loaded.

compiling with gpu support

Do you have any suggestions on how to compile with GPU support? I ended up installing Facebook Research's pytorch3d, only to figure out that you modified their loss functions, because I can't compile with GPU support for some reason.
Traceback (most recent call last):
  File "main.py", line 260, in <module>
    main(hyper_ps)
  File "main.py", line 256, in main
    loglevel=hps['LOGLEVEL'], resume=args.resume)
  File "/v2c/vox2cortex/utils/train.py", line 553, in training_routine
    start_epoch=start_epoch)
  File "/v2c/vox2cortex/utils/train.py", line 328, in train
    loss = self.training_step(model, data, iteration)
  File "/v2c/vox2cortex/utils/logging.py", line 252, in time_wrapper
    return_value = func(*args, **kwargs)
  File "/v2c/vox2cortex/utils/train.py", line 157, in training_step
    loss_total = self.compute_loss(model, data, iteration)
  File "/v2c/vox2cortex/utils/logging.py", line 252, in time_wrapper
    return_value = func(*args, **kwargs)
  File "/v2c/vox2cortex/utils/train.py", line 205, in compute_loss
    pred = model(x.cuda())
  File "/opt/miniconda3/envs/vox2cortex/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/v2c/vox2cortex/utils/logging.py", line 252, in time_wrapper
    return_value = func(*args, **kwargs)
  File "/v2c/vox2cortex/models/vox2cortex.py", line 121, in forward
    pred_meshes, pred_deltaV = self.graph_net(encoder_skips + decoder_skips)
  File "/opt/miniconda3/envs/vox2cortex/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/v2c/vox2cortex/utils/logging.py", line 252, in time_wrapper
    return_value = func(*args, **kwargs)
  File "/v2c/vox2cortex/models/graph_net.py", line 238, in forward
    latent_features = self.graph_conv_first(verts_packed, edges_packed)
  File "/opt/miniconda3/envs/vox2cortex/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/v2c/vox2cortex/utils/utils_vox2cortex/graph_conv.py", line 80, in forward
    features = F.relu(self.norm_first(self.gconv_first(features, edges)))
  File "/opt/miniconda3/envs/vox2cortex/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/v2c/vox2cortex/utils/utils_vox2cortex/graph_conv.py", line 35, in forward
    return D_inv * super().forward(verts, edges)
  File "/app/pytorch3d/pytorch3d/ops/graph_conv.py", line 76, in forward
    neighbor_sums = gather_scatter(verts_w1, edges, self.directed)
  File "/app/pytorch3d/pytorch3d/ops/graph_conv.py", line 160, in forward
    output = _C.gather_scatter(input, edges, directed, backward)
RuntimeError: Not compiled with GPU support.

Datasets

Dear author, could you tell me about the format of the datasets and the criteria used to split them? Thanks.

Mesh Normalisation to original coordinate space

Hi all,

Thanks for this wonderful repo. I have a question, since you seem to be using pytorch3d.

When training your model, do you normalize the meshes to the space [-1, 1] and then transform the model's output mesh back to the original space via scaling and translation, or do you train the model directly so that the coordinates of the output mesh are in their original space?

Thanks!

AbdomenCT-1K

Hello, could you please share how you split AbdomenCT-1K into training, validation, and test sets?

supplementary templates missing

The code references some templates that haven't been uploaded; could you check that all supplementary materials are uploaded? Great paper, by the way. Thanks.

    '../supplementary_material/white_pial/cortex_4_1000_3_smoothed_42016_sps[192, 208, 192]_ps[128, 144, 128].obj'

is not on GitHub.

Implementation for V2C-Flow

Dear author, thank you so much for your excellent work!

Is the V2C-Flow network (2024) in vox2cortex.py? If so, where is the Vox2Cortex network (2022)?

Looking forward to your reply.

RuntimeError: stack expects each tensor to be equal size, but got [4, 40962, 3] at entry 0 and [4, 36685, 3] at entry 1

I ran preprocessing on this dataset, but I still get this error:
[INFO] Created training loader of length 5
Traceback (most recent call last):
  File "main.py", line 260, in <module>
    main(hyper_ps)
  File "main.py", line 256, in main
    loglevel=hps['LOGLEVEL'], resume=args.resume)
  File "/v2c/vox2cortex/utils/train.py", line 556, in training_routine
    start_epoch=start_epoch)
  File "/v2c/vox2cortex/utils/train.py", line 322, in train
    for iter_in_epoch, data in enumerate(training_loader):
  File "/opt/miniconda3/envs/vox2cortex/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 435, in __next__
    data = self._next_data()
  File "/opt/miniconda3/envs/vox2cortex/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 475, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/opt/miniconda3/envs/vox2cortex/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
    return self.collate_fn(data)
  File "/opt/miniconda3/envs/vox2cortex/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py", line 83, in default_collate
    return [default_collate(samples) for samples in transposed]
  File "/opt/miniconda3/envs/vox2cortex/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py", line 83, in <listcomp>
    return [default_collate(samples) for samples in transposed]
  File "/opt/miniconda3/envs/vox2cortex/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py", line 55, in default_collate
    return torch.stack(batch, 0, out=out)
RuntimeError: stack expects each tensor to be equal size, but got [4, 40962, 3] at entry 0 and [4, 36685, 3] at entry 1

Training time?

Hello, how long did it take to converge when you trained the final model?

nvcc fatal : Unsupported gpu architecture 'compute_86'

Thank you for publishing a great repo.
I am building the environment for the repo and got this error after running

    git checkout tags/vox2cortex_cvpr2022 -b vox2cortex_pytorch3d
    pip install -e .

I'm running CUDA 11.7 on an NVIDIA A6000 with Python 3.8.
Could you address it?

Running setup.py develop for pytorch3d
error: subprocess-exited-with-error

× python setup.py develop did not run successfully.
│ exit code: 1
╰─> [53 lines of output]
    running develop
    running egg_info
    creating pytorch3d.egg-info
    writing pytorch3d.egg-info/PKG-INFO
    writing dependency_links to pytorch3d.egg-info/dependency_links.txt
    writing requirements to pytorch3d.egg-info/requires.txt
    writing top-level names to pytorch3d.egg-info/top_level.txt
    writing manifest file 'pytorch3d.egg-info/SOURCES.txt'
    reading manifest file 'pytorch3d.egg-info/SOURCES.txt'
    adding license file 'LICENSE'
    writing manifest file 'pytorch3d.egg-info/SOURCES.txt'
    running build_ext
    building 'pytorch3d._C' extension
    creating build
    creating build/temp.linux-x86_64-cpython-38
    creating build/temp.linux-x86_64-cpython-38/home
    creating build/temp.linux-x86_64-cpython-38/home/thanhtul
    creating build/temp.linux-x86_64-cpython-38/home/thanhtul/code
    creating build/temp.linux-x86_64-cpython-38/home/thanhtul/code/pytorch3d
    creating build/temp.linux-x86_64-cpython-38/home/thanhtul/code/pytorch3d/pytorch3d
    creating build/temp.linux-x86_64-cpython-38/home/thanhtul/code/pytorch3d/pytorch3d/csrc
    creating build/temp.linux-x86_64-cpython-38/home/thanhtul/code/pytorch3d/pytorch3d/csrc/blending
    creating build/temp.linux-x86_64-cpython-38/home/thanhtul/code/pytorch3d/pytorch3d/csrc/compositing
    creating build/temp.linux-x86_64-cpython-38/home/thanhtul/code/pytorch3d/pytorch3d/csrc/face_areas_normals
    creating build/temp.linux-x86_64-cpython-38/home/thanhtul/code/pytorch3d/pytorch3d/csrc/gather_scatter
    creating build/temp.linux-x86_64-cpython-38/home/thanhtul/code/pytorch3d/pytorch3d/csrc/interp_face_attrs
    creating build/temp.linux-x86_64-cpython-38/home/thanhtul/code/pytorch3d/pytorch3d/csrc/knn
    creating build/temp.linux-x86_64-cpython-38/home/thanhtul/code/pytorch3d/pytorch3d/csrc/mesh_normal_consistency
    creating build/temp.linux-x86_64-cpython-38/home/thanhtul/code/pytorch3d/pytorch3d/csrc/packed_to_padded_tensor
    creating build/temp.linux-x86_64-cpython-38/home/thanhtul/code/pytorch3d/pytorch3d/csrc/point_mesh
    creating build/temp.linux-x86_64-cpython-38/home/thanhtul/code/pytorch3d/pytorch3d/csrc/pulsar
    creating build/temp.linux-x86_64-cpython-38/home/thanhtul/code/pytorch3d/pytorch3d/csrc/pulsar/cuda
    creating build/temp.linux-x86_64-cpython-38/home/thanhtul/code/pytorch3d/pytorch3d/csrc/pulsar/host
    creating build/temp.linux-x86_64-cpython-38/home/thanhtul/code/pytorch3d/pytorch3d/csrc/pulsar/pytorch
    creating build/temp.linux-x86_64-cpython-38/home/thanhtul/code/pytorch3d/pytorch3d/csrc/rasterize_meshes
    creating build/temp.linux-x86_64-cpython-38/home/thanhtul/code/pytorch3d/pytorch3d/csrc/rasterize_points
    /usr/bin/nvcc -DWITH_CUDA -DTHRUST_IGNORE_CUB_VERSION_CHECK -I/home/thanhtul/code/pytorch3d/pytorch3d/csrc -I/home/thanhtul/miniconda3/envs/py38/lib/python3.8/site-packages/torch/include -I/home/thanhtul/miniconda3/envs/py38/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/home/thanhtul/miniconda3/envs/py38/lib/python3.8/site-packages/torch/include/TH -I/home/thanhtul/miniconda3/envs/py38/lib/python3.8/site-packages/torch/include/THC -I/home/thanhtul/miniconda3/envs/py38/include/python3.8 -c /home/thanhtul/code/pytorch3d/pytorch3d/csrc/blending/sigmoid_alpha_blend.cu -o build/temp.linux-x86_64-cpython-38/home/thanhtul/code/pytorch3d/pytorch3d/csrc/blending/sigmoid_alpha_blend.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options '-fPIC' -std=c++14 -DCUDA_HAS_FP16=1 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -DTORCH_EXTENSION_NAME=_C -D_GLIBCXX_USE_CXX11_ABI=0 -gencode=arch=compute_86,code=sm_86
    nvcc fatal   : Unsupported gpu architecture 'compute_86'
    /home/thanhtul/code/pytorch3d/setup.py:72: UserWarning: The environment variable `CUB_HOME` was not found. NVIDIA CUB is required for compilation and can be downloaded from `https://github.com/NVIDIA/cub/releases`. You can unpack it to a location of your choice and set the environment variable `CUB_HOME` to the folder containing the `CMakeListst.txt` file.
      warnings.warn(
    /home/thanhtul/miniconda3/envs/py38/lib/python3.8/site-packages/setuptools/command/easy_install.py:144: EasyInstallDeprecationWarning: easy_install command is deprecated. Use build and pip and other standards-based tools.
      warnings.warn(
    /home/thanhtul/miniconda3/envs/py38/lib/python3.8/site-packages/setuptools/command/install.py:34: SetuptoolsDeprecationWarning: setup.py install is deprecated. Use build and pip and other standards-based tools.
      warnings.warn(
    /home/thanhtul/miniconda3/envs/py38/lib/python3.8/site-packages/torch/utils/cpp_extension.py:352: UserWarning: Attempted to use ninja as the BuildExtension backend but we could not find ninja.. Falling back to using the slow distutils backend.
      warnings.warn(msg.format('we could not find ninja.'))
    /home/thanhtul/miniconda3/envs/py38/lib/python3.8/site-packages/torch/cuda/__init__.py:104: UserWarning:
    NVIDIA RTX A6000 with CUDA capability sm_86 is not compatible with the current PyTorch installation.
    The current PyTorch install supports CUDA capabilities sm_37 sm_50 sm_60 sm_70 sm_75.
    If you want to use the NVIDIA RTX A6000 GPU with PyTorch, please check the instructions at https://pytorch.org/get-started/locally/

      warnings.warn(incompatible_device_warn.format(device_name, capability, " ".join(arch_list), device_name))
    error: command '/usr/bin/nvcc' failed with exit code 1
    [end of output]

note: This error originates from a subprocess, and is likely not a problem with pip.

Chamfer distance parameters

Hi, I wanted to use the chamfer curvature loss in some experiments. I was wondering if you could upload the version of pytorch3d.loss.chamfer_distance you were using. I may be missing something, but I am not sure how to interpret the points_weights parameter. I tried modifying the source code of the function, but I wanted to be sure about the implementation. Thanks in advance!

Possibility of training with extremely anisotropic images

Hi, I saw that the image size used in the paper is (128, 144, 128), while the images in my own dataset are extremely anisotropic (e.g., (192, 192, 32)). Is it possible to train Vox2Cortex directly on extremely anisotropic data without any modification?

About figure 3 in the paper

Hi,
Thanks for your work.
I would like to know the meaning of "NNs, ID" in the figure.
In line 342 of graph_net.py, "new_meshes.move_verts(deltaV_padded)", it seems that after the mesh vertex locations are updated, the next block uses the new locations for its computation. So I'm confused about the meaning of "NNs, ID", and there is no explanation in the paper either (I apologize if I missed anything).

Looking forward to your reply.
