
yejg2017 / brain-segmentation-pytorch

This project forked from mateuszbuda/brain-segmentation-pytorch




License: MIT License



U-Net for brain segmentation

U-Net implementation in PyTorch for FLAIR abnormality segmentation in brain MRI, based on the deep learning segmentation algorithm used in "Association of genomic subtypes of lower-grade gliomas with shape features automatically extracted by a deep learning algorithm".

This repository is an all-Python port of the official MATLAB/Keras implementation in brain-segmentation. Weights for trained models are provided and can be used for inference or for fine-tuning on a different dataset. If you use code or weights shared in this repository, please consider citing:

  Buda, Mateusz, Ashirbani Saha, and Maciej A. Mazurowski. "Association of genomic subtypes of lower-grade gliomas with shape features automatically extracted by a deep learning algorithm." Computers in Biology and Medicine.

Docker

docker build -t brainseg .
nvidia-docker run --rm --shm-size 8G -it -v `pwd`:/workspace brainseg

PyTorch Hub

Loading the model using PyTorch Hub:

import torch
model = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet',
    in_channels=3, out_channels=1, init_features=32, pretrained=True)
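The model loaded above produces a per-pixel score map rather than a finished mask. As a minimal sketch, assuming the output is a single-channel probability map with values in [0, 1] (an assumption; check the model definition for the final activation), a binary abnormality mask can be obtained by thresholding. The helper `binarize` and the 0.5 cut-off are illustrative choices, not part of the repository.

```python
from typing import List

Mask = List[List[int]]


def binarize(prob_map: List[List[float]], threshold: float = 0.5) -> Mask:
    """Turn a 2-D probability map into a 0/1 mask (hypothetical helper).

    Assumes values are probabilities in [0, 1]; pixels strictly above
    `threshold` are labelled as abnormality (1), all others as background (0).
    """
    return [[1 if p > threshold else 0 for p in row] for row in prob_map]


if __name__ == "__main__":
    probs = [[0.1, 0.7], [0.5, 0.95]]
    print(binarize(probs))  # [[0, 1], [0, 1]]
```

On a PyTorch output tensor the same operation is simply `output > 0.5`, which yields a boolean tensor of the same shape.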


Data

The dataset used for development and evaluation was made publicly available on Kaggle. It contains MR images from the TCIA LGG collection with segmentation masks approved by a board-certified radiologist at Duke University.
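Before training or evaluation, each MR slice has to be matched with its segmentation mask. The sketch below assumes (this is not stated above) that a mask shares its image's file name plus a `_mask` suffix, e.g. `<case>_<slice>.tif` next to `<case>_<slice>_mask.tif`; change `MASK_SUFFIX` if the download follows a different convention. `pair_images_and_masks` is a hypothetical helper that works on file names only, so it can be tested without the data.

```python
from pathlib import PurePosixPath
from typing import Dict, Iterable, List, Tuple

# Assumed naming convention: "<stem>.tif" for the image, "<stem>_mask.tif" for its mask.
MASK_SUFFIX = "_mask"


def pair_images_and_masks(filenames: Iterable[str]) -> List[Tuple[str, str]]:
    """Return sorted (image, mask) pairs; images without a mask are skipped."""
    images: Dict[str, str] = {}
    masks: Dict[str, str] = {}
    for name in filenames:
        # Key on the full path without extension so equal stems in
        # different folders cannot collide.
        key = str(PurePosixPath(name).with_suffix(""))
        if key.endswith(MASK_SUFFIX):
            masks[key[: -len(MASK_SUFFIX)]] = name
        else:
            images[key] = name
    return [(images[k], masks[k]) for k in sorted(images) if k in masks]


if __name__ == "__main__":
    # Hypothetical file names, for illustration only.
    files = ["case01/case01_1.tif", "case01/case01_1_mask.tif", "case01/case01_2.tif"]
    print(pair_images_and_masks(files))
    # [('case01/case01_1.tif', 'case01/case01_1_mask.tif')]
```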

Model

The segmentation model implemented in this repository is U-Net as described in "Association of genomic subtypes of lower-grade gliomas with shape features automatically extracted by a deep learning algorithm", with added batch normalization.
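The `init_features=32` argument in the PyTorch Hub call sets the width of the first encoder stage. Assuming the usual U-Net scheme of four pooling steps with the number of feature maps doubled at each step (an assumption about this implementation, consistent with the original U-Net design), the channel widths can be worked out as below. `unet_widths` is a hypothetical helper.

```python
from typing import List, Tuple


def unet_widths(init_features: int = 32, depth: int = 4) -> Tuple[List[int], int]:
    """Return (encoder widths, bottleneck width) for a channel-doubling U-Net.

    Each encoder stage has twice the channels of the previous one; the
    bottleneck below the last pooling step doubles once more. The decoder
    mirrors the encoder widths in reverse order.
    """
    encoder = [init_features * 2 ** level for level in range(depth)]
    bottleneck = init_features * 2 ** depth
    return encoder, bottleneck


if __name__ == "__main__":
    print(unet_widths(32))  # ([32, 64, 128, 256], 512)
```

Halving `init_features` halves every width, which is one way to shrink the model when memory is tight; the added batch normalization layers do not change these widths.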


Results

TCGA_DU_6404_19850629: 94% DSC
TCGA_HT_7879_19981009: 91% DSC
TCGA_CS_4944_20010208: 89% DSC

Qualitative results for validation cases from three different institutions with DSC of 94%, 91%, and 89%. Green outlines correspond to ground truth and red to model predictions. Images show FLAIR modality after preprocessing.
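DSC is the Dice similarity coefficient between a predicted mask A and the ground-truth mask B: DSC = 2|A ∩ B| / (|A| + |B|), which is 1 for perfect overlap and 0 for none. A minimal sketch on flattened 0/1 masks follows. Scoring two empty masks as 1.0 is a convention chosen here, and the text above does not specify whether DSC is computed per slice or per case.

```python
from typing import Sequence


def dice(pred: Sequence[int], truth: Sequence[int]) -> float:
    """Dice similarity coefficient of two equal-length 0/1 masks."""
    if len(pred) != len(truth):
        raise ValueError("masks must have the same number of pixels")
    intersection = sum(1 for p, t in zip(pred, truth) if p and t)
    total = sum(1 for p in pred if p) + sum(1 for t in truth if t)
    if total == 0:
        # Both masks empty: count as perfect agreement (a convention).
        return 1.0
    return 2.0 * intersection / total


if __name__ == "__main__":
    pred = [0, 1, 1, 1, 0, 0]
    truth = [0, 1, 1, 0, 1, 0]
    print(round(dice(pred, truth), 3))  # 0.667: 2 * 2 / (3 + 3)
```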


Distribution of DSC for 10 randomly selected validation cases. The red vertical line corresponds to mean DSC (91%) and the green one to median DSC (92%). Results may be biased since model selection was based on the mean DSC on these validation cases.
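Mean and median summarize a DSC distribution differently: a few poorly segmented cases pull the mean down more than the median, which is consistent with a median (92%) above the mean (91%). The values below are made up for illustration only and are not the validation results.

```python
from statistics import mean, median
from typing import Sequence, Tuple


def summarize(dsc_values: Sequence[float]) -> Tuple[float, float]:
    """Return (mean, median) of per-case DSC values."""
    return mean(dsc_values), median(dsc_values)


# Hypothetical per-case DSC values (illustrative only, not the reported results):
# one poorly segmented case drags the mean below the median.
example = [0.95, 0.94, 0.93, 0.93, 0.92, 0.92, 0.91, 0.90, 0.89, 0.71]
avg, mid = summarize(example)
print(f"mean DSC {avg:.3f}, median DSC {mid:.3f}")  # mean DSC 0.900, median DSC 0.920
```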

Inference

  1. Download and extract the dataset from Kaggle.
  2. Run the Docker container.
  3. Run the script with the paths to the weights and images specified. Trained weights for input images of size 256x256 are provided in the ./weights/ directory. For more options and help, run the script with --help.
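Inference operates on individual slices, but per-case evaluation needs the slices of each case back in order. Sorting file names as plain strings would put slice 10 before slice 2, so the sketch below sorts by the numeric index. It assumes a hypothetical naming convention in which slice files end in `_<index>.tif`; `group_slices` is an illustrative helper, not part of the repository.

```python
import re
from collections import defaultdict
from typing import DefaultDict, Dict, Iterable, List, Tuple

# Assumed pattern: "<case>_<slice index>.tif". Mask files ("..._mask.tif") do not match.
SLICE_RE = re.compile(r"^(?P<case>.+)_(?P<index>\d+)\.tif$")


def group_slices(filenames: Iterable[str]) -> Dict[str, List[str]]:
    """Group slice files by case and order each group by numeric slice index."""
    groups: DefaultDict[str, List[Tuple[int, str]]] = defaultdict(list)
    for name in filenames:
        match = SLICE_RE.match(name)
        if match is None:
            continue  # skip masks and unrelated files
        groups[match["case"]].append((int(match["index"]), name))
    return {case: [name for _, name in sorted(items)] for case, items in groups.items()}


if __name__ == "__main__":
    # Hypothetical file names, for illustration only.
    files = ["case01_10.tif", "case01_2.tif", "case01_1.tif", "case01_2_mask.tif"]
    print(group_slices(files))
    # {'case01': ['case01_1.tif', 'case01_2.tif', 'case01_10.tif']}
```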

Train

  1. Download and extract the dataset from Kaggle.
  2. Run the Docker container.
  3. Run the script. The default path to images is ./kaggle_3m. For more options and help, run the script with --help.
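Segmentation networks evaluated with DSC are often trained with a soft Dice loss: 1 minus a differentiable DSC computed on predicted probabilities instead of binary masks. The text above does not say which loss the training script uses, so treat this as an illustrative sketch of the technique. The smoothing term `smooth=1.0`, which keeps the ratio defined when both inputs are empty, is an assumed value.

```python
from typing import Sequence


def soft_dice_loss(probs: Sequence[float], truth: Sequence[float], smooth: float = 1.0) -> float:
    """1 - soft DSC between predicted probabilities and a 0/1 ground-truth mask."""
    if len(probs) != len(truth):
        raise ValueError("inputs must have the same number of pixels")
    intersection = sum(p * t for p, t in zip(probs, truth))
    denominator = sum(probs) + sum(truth)
    return 1.0 - (2.0 * intersection + smooth) / (denominator + smooth)


if __name__ == "__main__":
    truth = [0, 1, 1, 0]
    print(soft_dice_loss([0.0, 1.0, 1.0, 0.0], truth))  # 0.0 for a perfect prediction
    print(soft_dice_loss([0.5, 0.5, 0.5, 0.5], truth))  # larger loss for an uncertain one
```

Written with PyTorch tensor operations instead of Python sums, the same expression stays differentiable, which is what makes it usable as a training loss.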

Training can also be run using the Kaggle kernel shared together with the dataset. Due to memory limitations of Kaggle kernels, input images there are of size 224x224 instead of 256x256.
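The saving from the smaller input can be estimated from pixel counts. For a fully convolutional network such as U-Net, activation memory per image grows roughly in proportion to height times width, so this is an approximation that ignores fixed costs such as the model weights.

```python
def relative_pixels(small: int, large: int) -> float:
    """Ratio of pixel counts for square inputs of side `small` versus `large`."""
    return (small * small) / (large * large)


if __name__ == "__main__":
    ratio = relative_pixels(224, 256)
    print(f"{ratio:.4f}")  # 0.7656, i.e. about 23% fewer pixels per image
```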

Running this code on a custom dataset would likely require adjustments to the dataset-loading code. Should you need help with this, just open an issue.
