
artur112 / brain_tumour_segmentation


Code for training a 3D U-Net for brain tumour segmentation on the BraTS 2019 dataset, for extracting features from the segmented volumes, and for survival prediction. Run train.py for training, segment.py for segmenting test scans, and evaluate.py for evaluating the performance of those segmentations. Basic code is also included for survival prediction with a random forest classifier.


brain_tumour_segmentation's Introduction

Run train.py for training, segment.py for segmenting test scans (performing inference), and evaluate.py for evaluating the results of those segmentations. Basic code for survival prediction with a random forest classifier is in the surv_prediction folder.

This is old code and may have many issues. I suggest changing the data-loading code to read from HDF5 files instead of .npz files, which will significantly speed up training. Running inference on the full-sized BraTS volumes may also cause issues; I suggest resampling each volume to the training patch size, running inference on the downsampled volume, and then resampling the prediction back to the original BraTS size.
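The resampling suggested above can be sketched with scipy.ndimage.zoom — a minimal, hedged example (the shapes match BraTS conventions, but the patch size and the thresholding stand-in for the model are illustrative):

```python
import numpy as np
from scipy.ndimage import zoom

def resample_volume(vol, target_shape, order=1):
    """Resample a 3D volume to target_shape.
    order=1 (trilinear) for images, order=0 (nearest) for label maps."""
    factors = [t / s for t, s in zip(target_shape, vol.shape)]
    return zoom(vol, factors, order=order)

# Downsample a BraTS-sized volume to the training patch size, run inference,
# then resample the predicted labels back with nearest-neighbour interpolation.
full = np.random.rand(240, 240, 155).astype("float32")  # original BraTS shape
small = resample_volume(full, (128, 128, 128))          # network-sized input
pred = (small > 0.5).astype("uint8")                    # stand-in for model output
restored = resample_volume(pred, full.shape, order=0)   # back to BraTS size
```

Nearest-neighbour interpolation on the way back avoids inventing fractional label values.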

brain_tumour_segmentation's People

Contributors: artur112

brain_tumour_segmentation's Issues

matplotlib.use('TkAgg')

Is there any other option for this line?

matplotlib.use('TkAgg')

I am using Google Colab to run your code. Could you share your email so that you can also have a look? I will share the notebook with you.
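On headless environments such as Colab there is no Tk display, so a non-interactive backend is the usual workaround — a minimal sketch (the output filename is illustrative):

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; works headless (Colab, SSH, CI)
import matplotlib.pyplot as plt

# With Agg there is no window to show, so save figures to disk instead
# of calling plt.show().
fig, ax = plt.subplots()
ax.plot([0, 1], [0, 1])
fig.savefig("loss_curve.png")
```

The backend must be selected before pyplot is imported, which is why `matplotlib.use(...)` comes first.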

RuntimeError: Error(s) in loading state_dict for UNet3D:

RuntimeError: Error(s) in loading state_dict for UNet3D:
Missing key(s) in state_dict: "encoders.0.basic_module.SingleConv1.conv.weight", "encoders.0.basic_module.SingleConv1.groupnorm.weight", "encoders.0.basic_module.SingleConv1.groupnorm.bias", "encoders.0.basic_module.SingleConv2.conv.weight", "encoders.0.basic_module.SingleConv2.groupnorm.weight", "encoders.0.basic_module.SingleConv2.groupnorm.bias", "encoders.1.basic_module.SingleConv1.conv.weight", "encoders.1.basic_module.SingleConv1.groupnorm.weight", "encoders.1.basic_module.SingleConv1.groupnorm.bias", "encoders.1.basic_module.SingleConv2.conv.weight", "encoders.1.basic_module.SingleConv2.groupnorm.weight", "encoders.1.basic_module.SingleConv2.groupnorm.bias", "encoders.2.basic_module.SingleConv1.conv.weight", "encoders.2.basic_module.SingleConv1.groupnorm.weight", "encoders.2.basic_module.SingleConv1.groupnorm.bias", "encoders.2.basic_module.SingleConv2.conv.weight", "encoders.2.basic_module.SingleConv2.groupnorm.weight", "encoders.2.basic_module.SingleConv2.groupnorm.bias", "encoders.3.basic_module.SingleConv1.conv.weight", "encoders.3.basic_module.SingleConv1.groupnorm.weight", "encoders.3.basic_module.SingleConv1.groupnorm.bias", "encoders.3.basic_module.SingleConv2.conv.weight", "encoders.3.basic_module.SingleConv2.groupnorm.weight", "encoders.3.basic_module.SingleConv2.groupnorm.bias", "decoders.0.basic_module.SingleConv1.conv.weight", "decoders.0.basic_module.SingleConv1.groupnorm.weight", "decoders.0.basic_module.SingleConv1.groupnorm.bias", "decoders.0.basic...
Unexpected key(s) in state_dict: "s.0.basic_module.SingleConv1.conv.weight", "s.0.basic_module.SingleConv1.groupnorm.weight", "s.0.basic_module.SingleConv1.groupnorm.bias", "s.0.basic_module.SingleConv2.conv.weight", "s.0.basic_module.SingleConv2.groupnorm.weight", "s.0.basic_module.SingleConv2.groupnorm.bias", "s.1.basic_module.SingleConv1.conv.weight", "s.1.basic_module.SingleConv1.groupnorm.weight", "s.1.basic_module.SingleConv1.groupnorm.bias", "s.1.basic_module.SingleConv2.conv.weight", "s.1.basic_module.SingleConv2.groupnorm.weight", "s.1.basic_module.SingleConv2.groupnorm.bias", "s.2.basic_module.SingleConv1.conv.weight", "s.2.basic_module.SingleConv1.groupnorm.weight", "s.2.basic_module.SingleConv1.groupnorm.bias", "s.2.basic_module.SingleConv2.conv.weight", "s.2.basic_module.SingleConv2.groupnorm.weight", "s.2.basic_module.SingleConv2.groupnorm.bias", "s.3.basic_module.SingleConv1.conv.weight", "s.3.basic_module.SingleConv1.groupnorm.weight", "s.3.basic_module.SingleConv1.groupnorm.bias", "s.3.basic_module.SingleConv2.conv.weight", "s.3.basic_module.SingleConv2.groupnorm.weight", "s.3.basic_module.SingleConv2.groupnorm.bias", "onv.weight", "onv.bias".
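The unexpected keys look like the expected ones with a leading prefix chopped off ("s.0...." instead of "encoders.0...."), which suggests the checkpoint keys were truncated when saved. One hedged way to recover is to remap each saved key to the unique expected key it is a suffix of; ambiguous keys (like the fragment "onv.weight") are passed through unchanged:

```python
def remap_truncated_keys(expected_keys, saved_state):
    """Repair checkpoint keys that lost a leading prefix when saved.

    Heuristic (an assumption, not the repo's own fix): each truncated key is
    a suffix of exactly one expected key; anything ambiguous is left as-is.
    """
    fixed = {}
    for key, value in saved_state.items():
        matches = [ek for ek in expected_keys if ek.endswith(key)]
        fixed[matches[0] if len(matches) == 1 else key] = value
    return fixed

# Example with the pattern from the traceback: "encoders." was saved as "s.".
expected = ["encoders.0.basic_module.SingleConv1.conv.weight"]
saved = {"s.0.basic_module.SingleConv1.conv.weight": "tensor"}
remapped = remap_truncated_keys(expected, saved)
```

With a real model this would be used as `model.load_state_dict(remap_truncated_keys(model.state_dict().keys(), saved_state), strict=False)`; any keys the heuristic could not place are then reported by `load_state_dict` rather than silently dropped.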

data augmentation

Hi Artur112! Thanks for sharing the code, but I am confused by some of it. In DataAugment.elastic_deform, if my Y has the same shape as X (B, 4, H, W, D) — 3 foreground channels + 1 background channel, where each channel represents a class — what should I change in the code to adapt to my Y? Here is what I have changed; is this right?


```python
import random

import elasticdeform
import numpy as np
import torch


def elastic_deform(self, X, Y):
    # Elastic deformation with a random square deformation grid, where the
    # displacements are sampled from a normal distribution with standard
    # deviation sigma. Uses the elasticdeform package (gvtulder/elasticdeform).
    # X and Y both have shape (4, 128, 128, 128); Y is one-hot encoded.
    X = X.numpy()
    Y = Y.unsqueeze(1).numpy()

    # Multiply by 10 so foreground and background values are far apart and
    # voxels are not assigned the wrong label after the interpolation that
    # happens inside the deformation.
    brain_region = (X > 0).astype('float') * 10

    # Deform the label channels separately; deformed together they would mix.
    lbl0 = Y[0] * 10.0  # background channel
    lbl1 = Y[1] * 10.0
    lbl2 = Y[2] * 10.0
    lbl3 = Y[3] * 10.0

    sigma_nr = 5  # Random factor by which to deform
    # d = (1, 2, 3)  # Deform along all three axes
    # Deforming along two randomly chosen axes instead speeds up training:
    d = tuple(sorted(random.sample([1, 2, 3], k=2)))

    [X, brain_region, lbl0, lbl1, lbl2, lbl3] = elasticdeform.deform_random_grid(
        [X, brain_region, lbl0, lbl1, lbl2, lbl3],
        axis=[d] * 6, sigma=sigma_nr, points=3)

    brain_region = brain_region.astype('int') > 0
    X = X * brain_region  # Keep background voxels at 0 in the scans
    X[X < 0] = 0          # Remove negative interpolation artefacts near 0

    # Re-binarise each deformed one-hot channel against its own values, using
    # the midpoint of the 0..10 range as the threshold:
    lbl0 = (lbl0 >= 5).astype("float32")
    lbl1 = (lbl1 >= 5).astype("float32")
    lbl2 = (lbl2 >= 5).astype("float32")
    lbl3 = (lbl3 >= 5).astype("float32")

    Y = np.concatenate([lbl0, lbl1, lbl2, lbl3], axis=0).astype("float32")
    X = torch.from_numpy(X)
    Y = torch.from_numpy(Y)

    return X, Y
```

Thanks for your help!

survival prediction

Hi, thanks for your work on the survival prediction task.
I followed your code for this task but only got 0.379 validation accuracy.
What validation accuracy did you get? Do you have a paper showing the details?
I appreciate your reply.

the prediction class type

Hi again, I am confused about your prediction. If I understand correctly, you just relabel ET from 4 to 3 and keep the others the same, so the labels for each scan are 0 for background, 1 for necrotic and non-enhancing tumour, 2 for edema, and 3 for enhancing tumour. The net predicts those same 4 classes — not background / whole tumour / tumour core / enhancing tumour — right? When I submit my validation-set predictions, do I need to transform (background, necrotic and non-enhancing tumour, edema, enhancing tumour) to (background, WT, TC, ET)? Thanks a lot!
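For context: the BraTS evaluation derives the WT/TC/ET regions from the per-voxel labels itself, so no region conversion is needed — only the enhancing-tumour id needs mapping back from 3 to the BraTS label 4 before submission. A hedged sketch (assuming predictions are NumPy arrays of class ids 0–3, as in this repo's convention):

```python
import numpy as np

def to_brats_labels(pred):
    """Map network class ids back to BraTS label conventions:
    0 background, 1 necrotic/non-enhancing core, 2 edema, 3 -> 4 enhancing tumour."""
    out = pred.copy()
    out[pred == 3] = 4
    return out

mapped = to_brats_labels(np.array([0, 1, 2, 3]))  # → array([0, 1, 2, 4])
```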

Scatter function error inside utils file

2020-05-17 07:22:58.417980: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcudart.so.10.1
Traceback (most recent call last):
File "train.py", line 208, in
masks = expand_as_one_hot(masks, n_classes)
File "/content/drive/My Drive/Artur/model_utils/utils.py", line 426, in expand_as_one_hot
return torch.zeros(shape).to(input.device).scatter_(1, input, longx)
RuntimeError: Expected object of scalar type Long but got scalar type Int for argument #3 'index' in call to th_scatter
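The error says `scatter_` received an int32 index where int64 (Long) is required — masks loaded from .npz files often arrive as int32. A minimal sketch of the fix (this is an assumed simplification of the repo's `expand_as_one_hot`, with the cast added before the scatter):

```python
import torch

def expand_as_one_hot(input, n_classes):
    # scatter_'s index argument must be int64, so cast Int -> Long first.
    input = input.unsqueeze(1).long()
    shape = list(input.size())
    shape[1] = n_classes
    return torch.zeros(shape).scatter_(1, input, 1)

masks = torch.randint(0, 4, (2, 8, 8, 8), dtype=torch.int32)  # int32, as from .npz
one_hot = expand_as_one_hot(masks, 4)                          # shape (2, 4, 8, 8, 8)
```

Equivalently, casting once at load time (`masks = masks.long()`) in the dataloader fixes every downstream scatter call.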
