
segan's Introduction

SegAN: Semantic Segmentation with Adversarial Learning

PyTorch implementation of the basic ideas from the paper SegAN: Adversarial Network with Multi-scale L1 Loss for Medical Image Segmentation by Yuan Xue, Tao Xu, Han Zhang, L. Rodney Long, Xiaolei Huang.

The data and architecture are mainly from the paper Adversarial Learning with Multi-Scale Loss for Skin Lesion Segmentation by Yuan Xue, Tao Xu, Xiaolei Huang.

Dependencies

Python 2.7

PyTorch 1.2

Data

Training

  • The steps to train a SegAN model on the ISIC skin lesion segmentation dataset.
    • Run with: CUDA_VISIBLE_DEVICES=X python train.py --cuda, where X is your GPU id. You can change the training hyperparameters as you wish; the default output folder is ~/outputs. For now, only single-GPU training is supported. Training images will be saved in the ~/outputs folder.
    • The training code also includes validation: validation results are reported every 10 epochs, and validation images are also saved in the ~/outputs folder.
  • To try your own dataset, preprocess your data into a format similar to this skin lesion segmentation dataset and put it in a folder structured like ~/ISIC-2017_Training_Data. You can run the model directly on a natural image dataset; for 3D medical data such as brain MRI scans, you need to extract 2D slices from the original data first. If your dataset has more than one label class, you can run multiple S1-1C models as described in the SegAN paper.
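For 3D data, the slice-extraction step could look like the minimal NumPy sketch below. It assumes each scan and its label are already loaded as arrays of shape (depth, height, width); the helper name extract_slices and the option to drop slices without any foreground are illustrative assumptions, not part of the SegAN code. Each resulting 2D slice and mask would then be saved as images in the same layout as the ISIC folders.

```python
import numpy as np

def extract_slices(volume, mask, axis=0, skip_empty=True):
    """Split a 3D volume and its label mask into paired 2D slices.

    Hypothetical helper: assumes `volume` and `mask` are NumPy arrays with
    identical shapes, e.g. (depth, height, width) for an MRI scan.
    """
    if volume.shape != mask.shape:
        raise ValueError("volume and mask must have the same shape")
    pairs = []
    for i in range(volume.shape[axis]):
        img = np.take(volume, i, axis=axis)
        lbl = np.take(mask, i, axis=axis)
        # Optionally drop slices that contain no foreground label.
        if skip_empty and not lbl.any():
            continue
        pairs.append((img, lbl))
    return pairs

# Toy 4-slice volume where only slices 1 and 2 contain a label.
vol = np.random.rand(4, 8, 8).astype(np.float32)
lab = np.zeros((4, 8, 8), dtype=np.uint8)
lab[1:3, 2:5, 2:5] = 1
slices = extract_slices(vol, lab)
print(len(slices), slices[0][0].shape)  # 2 (8, 8)
```

Whether to keep empty slices is a dataset-specific choice; keeping them gives the model background-only examples, dropping them reduces class imbalance.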

Citing SegAN

If you find SegAN useful in your research, please consider citing:

@article{xue2017segan,
  title={SegAN: Adversarial Network with Multi-scale $L_1$ Loss for Medical Image Segmentation},
  author={Xue, Yuan and Xu, Tao and Zhang, Han and Long, Rodney and Huang, Xiaolei},
  journal={arXiv preprint arXiv:1706.01805},
  year={2017}
}

References

  • Some of the code for the Global Convolutional Block is borrowed from Zijun Deng's excellent code.
  • We thank the PyTorch team; some of our image preprocessing code is borrowed from the official PyTorch examples.

segan's People

Contributors

never-less, yuanxue1993


segan's Issues

Errors

File "train.py", line 114, in <module>
target = target.type(torch.FloatTensor)
NameError: name 'target' is not defined

The loss of S and C optimization is different

The paper says that S and C are optimized with the same loss, but in the code a Dice loss is added to the objective of S. When I try to remove it, the results become very poor. Why?
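For context, a soft Dice loss on probability maps is commonly computed as 1 - 2|P∩G| / (|P| + |G|). The NumPy sketch below shows that formula; it is an illustrative version written for this page, assuming predictions in [0, 1] with shape (N, H, W), and is not necessarily identical to the term in train.py. In a PyTorch training loop the same expression would be written with tensor ops so that it stays differentiable.

```python
import numpy as np

def soft_dice_loss(pred, target, eps=1e-6):
    """1 - soft Dice coefficient, averaged over the batch.

    pred:   predicted foreground probabilities, shape (N, H, W), values in [0, 1]
    target: binary ground-truth masks, same shape
    eps:    small constant (an assumption) to avoid division by zero
    """
    pred = pred.reshape(pred.shape[0], -1)
    target = target.reshape(target.shape[0], -1)
    intersection = (pred * target).sum(axis=1)
    denom = pred.sum(axis=1) + target.sum(axis=1)
    dice = (2.0 * intersection + eps) / (denom + eps)
    return 1.0 - dice.mean()

gt = np.zeros((1, 4, 4)); gt[0, :2, :] = 1       # top half is foreground
perfect = gt.copy()
print(round(soft_dice_loss(perfect, gt), 4))      # 0.0
half = np.zeros((1, 4, 4)); half[0, 0, :] = 1     # predicts only the first row
print(round(soft_dice_loss(half, gt), 4))         # 0.3333
```

Unlike a purely adversarial signal, this term directly rewards pixel overlap with the ground truth, which is one reason removing it can change training behaviour noticeably.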

Difficulties reproducing the results

Hi, Yuan!

I'm trying to reproduce the results you published in Adversarial Learning with Multi-Scale Loss for Skin Lesion Segmentation. The problem is that I'm not able to achieve the same results as you did (mDice = 0.867); the highest I get is around 0.832. I'm using CUDA 8.0 with Python 2.7 and PyTorch 1.2. The hyperparameters I use are the same as described in the paper. I'm training for 490 epochs, using the ISIC 2017 train and test sets as training and validation sets, respectively.

The image below shows the results of 5 experiments with the code available in this repo. I would really appreciate any thoughts on what I might be missing.

[Image: segan_exp]

Thanks a lot! Best regards!

Why are the prediction results all NaN?

After training, I saved the model parameters of NetS and then ran segmentation prediction with the NetS model. Why is the result all NaN?

I want to perform segmentation on my own dataset, but there seems to be a problem with NetS

Traceback (most recent call last):
File "/home_lv/yingjie.wang/fenge/train3.py", line 20, in <module>
from net import NetS, NetC
File "/home_lv/yingjie.wang/fenge/net.py", line 488, in <module>
summary(model, input_size=(3,160,160))
File "/home_lv/yingjie.wang/anaconda3/envs/pytorch/lib/python3.9/site-packages/torchsummary/torchsummary.py", line 72, in summary
model(*x)
File "/home_lv/yingjie.wang/anaconda3/envs/pytorch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/home_lv/yingjie.wang/fenge/net.py", line 314, in forward
decoder4 = self.deconvblock4(decoder3)
File "/home_lv/yingjie.wang/anaconda3/envs/pytorch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/home_lv/yingjie.wang/anaconda3/envs/pytorch/lib/python3.9/site-packages/torch/nn/modules/container.py", line 141, in forward
input = module(input)
File "/home_lv/yingjie.wang/anaconda3/envs/pytorch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1120, in _call_impl
result = forward_call(*input, **kwargs)
File "/home_lv/yingjie.wang/fenge/net.py", line 38, in forward
x_l = self.conv_l1(x)
File "/home_lv/yingjie.wang/anaconda3/envs/pytorch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1120, in _call_impl
result = forward_call(*input, **kwargs)
File "/home_lv/yingjie.wang/anaconda3/envs/pytorch/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 446, in forward
return self._conv_forward(input, self.weight, self.bias)
File "/home_lv/yingjie.wang/anaconda3/envs/pytorch/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 442, in _conv_forward
return F.conv2d(input, weight, bias, self.stride,
TypeError: conv2d() received an invalid combination of arguments - got (Tensor, Parameter, Parameter, tuple, tuple, tuple, int), but expected one of:

  • (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, tuple of ints padding, tuple of ints dilation, int groups)
    didn't match because some of the arguments have invalid types: (Tensor, !Parameter!, !Parameter!, !tuple!, !tuple!, !tuple!, int)
  • (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, str padding, tuple of ints dilation, int groups)
    didn't match because some of the arguments have invalid types: (Tensor, !Parameter!, !Parameter!, !tuple!, !tuple!, !tuple!, int)
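The traceback shows conv2d failing inside a global convolution branch (self.conv_l1, net.py line 38) under Python 3.9, while the repository targets Python 2.7. One plausible cause, not confirmed here, is padding computed with "/": in Python 2, (7 - 1) / 2 is the int 3, but in Python 3 it is the float 3.0, and conv2d only accepts tuples of ints. Assuming the block computes its padding this way (check net.py to confirm), switching to floor division "//" keeps the values as ints. The helper same_padding below is a hypothetical illustration of that fix, not code from the repository.

```python
def same_padding(kernel_size):
    """Padding that preserves spatial size for an odd kernel at stride 1.

    Uses floor division so the result is an int under both Python 2 and 3;
    (k - 1) / 2 would yield a float under Python 3.
    """
    return (kernel_size - 1) // 2

kernel = (7, 1)  # e.g. the k x 1 branch of a global convolution
pad_float = ((kernel[0] - 1) / 2, (kernel[1] - 1) / 2)      # Python 3: (3.0, 0.0)
pad_int = (same_padding(kernel[0]), same_padding(kernel[1]))  # (3, 0)
print(pad_float, pad_int)
```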

'GlobalConvBlock' object has no attribute 'weight'

I am trying to apply your code to a medical image segmentation problem. When I run train.py, it always gives me the error 'GlobalConvBlock' object has no attribute 'weight'.
Please try to resolve this issue.
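A frequent source of this error in DCGAN-derived code is a weight-initialization function that selects modules by class name, e.g. classname.find('Conv') != -1. Because 'GlobalConvBlock' contains the substring 'Conv', such a function would try to access m.weight on a container module that owns no weight itself. Whether SegAN's init function does exactly this is an assumption; check net.py. The plain-Python sketch below uses hypothetical stand-in classes to show the pitfall and a safer check; in PyTorch you would typically test isinstance(m, nn.Conv2d) instead of matching names.

```python
class Conv2dStandIn:
    """Stand-in for a real convolution layer: it owns a weight."""
    def __init__(self):
        self.weight = [0.0]

class GlobalConvBlock:
    """Stand-in for a container block: it holds convs but has no weight itself."""
    def __init__(self):
        self.conv_l1 = Conv2dStandIn()
        self.conv_r1 = Conv2dStandIn()

def matches_by_name(module):
    # DCGAN-style rule: any class whose name contains 'Conv'.
    return type(module).__name__.find('Conv') != -1

def matches_safely(module):
    # Safer rule: only initialize modules that actually have a weight.
    return matches_by_name(module) and hasattr(module, 'weight')

block = GlobalConvBlock()
modules = [block, block.conv_l1, block.conv_r1]
print([matches_by_name(m) for m in modules])  # [True, True, True]
print([matches_safely(m) for m in modules])   # [False, True, True]
```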

Reproducing the results

Hi! I tried to reproduce the results with CUDA 9 and PyTorch 1.1. I ran into several difficulties that required me to modify the code slightly, namely these two problems:

  • RuntimeError: expand(torch.cuda.FloatTensor{[36, 1, 128, 128]}, size=[36, 128, 128]): the number of sizes provided (3) must be greater or equal to the number of dimensions in the tensor (4)
    at the lines
    output_masked[:,d,:,:] = input_mask[:,d,:,:].unsqueeze(1) * output
    target_masked[:,d,:,:] = input_mask[:,d,:,:].unsqueeze(1) * target
    I had to change them to
    output_masked[:,d,:,:] = (input_mask[:,d,:,:].unsqueeze(1) * output).squeeze(1)
    target_masked[:,d,:,:] = (input_mask[:,d,:,:].unsqueeze(1) * target).squeeze(1)

  • IndexError: invalid index of a 0-dim tensor. Use tensor.item() to convert a 0-dim tensor to a Python number
    at the lines
    print("===> Epoch[{}]({}/{}): Batch Dice: {:.4f}".format(epoch, i, len(dataloader), 1 - loss_dice.data[0]))
    print("===> Epoch[{}]({}/{}): G_Loss: {:.4f}".format(epoch, i, len(dataloader), loss_G.data[0]))
    print("===> Epoch[{}]({}/{}): D_Loss: {:.4f}".format(epoch, i, len(dataloader), loss_D.data[0]))
    I think it should be loss_dice.data or loss_dice.item() instead.

Could you please look at these? I think the issues might be caused by changes in newer PyTorch versions.
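The first problem is a shape mismatch: input_mask[:,d,:,:].unsqueeze(1) * output has shape (N, 1, H, W), but the slot output_masked[:,d,:,:] it is assigned to has shape (N, H, W). The NumPy sketch below reproduces the same shapes with toy sizes (NumPy slice assignment follows the same broadcasting rule here) and applies the squeeze fix to both the output and the target; the channel count C = 3 is an assumption. For the second problem, loss_dice.item() is the standard replacement for loss_dice.data[0] since PyTorch 0.4, where scalar losses became 0-dim tensors.

```python
import numpy as np

rng = np.random.default_rng(0)
N, C, H, W = 2, 3, 4, 4                 # toy sizes; the error message shows N=36, H=W=128
input_mask = rng.random((N, C, H, W))   # copy of the input image batch
output = rng.random((N, 1, H, W))       # segmentor prediction, one channel
target = (rng.random((N, 1, H, W)) > 0.5).astype(np.float64)  # binary ground truth

output_masked = np.empty((N, C, H, W))
target_masked = np.empty((N, C, H, W))
for d in range(C):
    # (N, 1, H, W) * (N, 1, H, W) -> (N, 1, H, W); index axis 1 away to get (N, H, W).
    output_masked[:, d] = (input_mask[:, d][:, None] * output)[:, 0]
    target_masked[:, d] = (input_mask[:, d][:, None] * target)[:, 0]

# Without removing the singleton axis, the assignment fails as in the issue.
try:
    output_masked[:, 0] = input_mask[:, 0][:, None] * output
except ValueError as err:
    print("shape error:", err)
```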

Source code for multi-class segmentation requested

I am very interested in your recently published paper:
"SegAN: Adversarial Network with Multi-scale L1 Loss for Medical Image Segmentation"
arxiv.org/pdf/1706.01805.pdf

However, the code for multi-class segmentation is not included in the GitHub repository: github.com/YuanXue1993/SegAN.
Could you please provide the source code for multi-class segmentation (especially the S3-3C model)?

Thanks.

Why are the discriminator outputs multiplied by 2 and 4?

        output = torch.cat((input.view(batchsize,-1),1*out1.view(batchsize,-1),
                            2*out2.view(batchsize,-1),2*out3.view(batchsize,-1),
                            2*out4.view(batchsize,-1),2*out5.view(batchsize,-1),
                            4*out6.view(batchsize,-1)),1)

What is the purpose of multiplying out2 to out5 by 2 and out6 by 4? Thank you
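The snippet builds the critic's multi-scale feature vector by flattening the masked input and each layer's output and concatenating them, so the factors 1, 2, 2, 2, 2, 4 act as per-scale weights in the L1 distance computed on that vector: a difference at a deeper layer counts 2 or 4 times as much. Why the authors chose these particular values is not stated on this page. The NumPy sketch below reproduces the weighting with made-up feature shapes and computes the mean absolute difference between "real" and "fake" feature vectors; multiscale_features, the shapes, and the weights list are illustrative assumptions.

```python
import numpy as np

def multiscale_features(inp, feats, weights):
    """Flatten the input and each feature map per sample, scale, and concatenate.

    weights[0] applies to the raw input, weights[1:] to the successive feature maps.
    """
    n = inp.shape[0]
    parts = [weights[0] * inp.reshape(n, -1)]
    parts += [w * f.reshape(n, -1) for w, f in zip(weights[1:], feats)]
    return np.concatenate(parts, axis=1)

rng = np.random.default_rng(0)
batch = 2
shapes = [(batch, 4, 8, 8), (batch, 8, 4, 4), (batch, 16, 2, 2)]  # toy feature sizes
weights = [1, 1, 2, 4]  # input, then three feature scales (illustrative)

real_in, fake_in = rng.random((batch, 3, 16, 16)), rng.random((batch, 3, 16, 16))
real_feats = [rng.random(s) for s in shapes]
fake_feats = [rng.random(s) for s in shapes]

real_vec = multiscale_features(real_in, real_feats, weights)
fake_vec = multiscale_features(fake_in, fake_feats, weights)
multiscale_l1 = np.mean(np.abs(real_vec - fake_vec))  # weighted multi-scale L1
print(real_vec.shape, multiscale_l1)
```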

loss

Why are my loss_G and loss_D outputs the negatives of each other? In your code, loss_G and loss_D differ only in sign. Also, after training completes, the predictions are all NaN. Why is this?
I really hope to hear from you.
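For reference, as we understand the SegAN paper (arXiv:1706.01805), training is a min-max game over a single multi-scale L1 objective. Here S is the segmentor, C the critic, f_C^i the critic features at layer i of L layers, x_n an image, y_n its ground-truth mask, and ∘ pixel-wise multiplication. C maximizes the objective and S minimizes it, so a critic loss written as the negative of the segmentor's adversarial term is expected by design; the two printed values differ slightly only because C is updated between the two computations, and any extra term added to S (such as a Dice loss) also breaks the exact symmetry.

```latex
\min_{\theta_S}\max_{\theta_C}\ \mathcal{L}(\theta_S,\theta_C)
  = \frac{1}{N}\sum_{n=1}^{N}
    \ell_{mae}\big(f_C(x_n \circ S(x_n)),\ f_C(x_n \circ y_n)\big),
\qquad
\ell_{mae}\big(f_C(x), f_C(x')\big)
  = \frac{1}{L}\sum_{i=1}^{L}\big\lVert f_C^{i}(x) - f_C^{i}(x')\big\rVert_1
```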

Why NetC.zero_grad()?

I am wondering why you call NetC.zero_grad() in train.py, since zero_grad() is usually called on the optimizer. Could you please explain this? It would be great if the code had more comments.
Thank you
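nn.Module.zero_grad() and optimizer.zero_grad() both reset the .grad of parameters: the module version covers every parameter of that module, while the optimizer version covers the parameters it was constructed with. When C's optimizer is built from NetC.parameters(), calling either one is equivalent. The small check below uses a toy linear layer as a stand-in for NetC and Adam as the optimizer; both are assumptions for illustration.

```python
import torch
import torch.nn as nn

net_c = nn.Linear(4, 1)                                  # toy stand-in for NetC
opt_c = torch.optim.Adam(net_c.parameters(), lr=1e-3)

def grads_cleared(module):
    # zero_grad() sets grads to None (PyTorch >= 2.0 default) or to zeros (older versions).
    return all(p.grad is None or not p.grad.any() for p in module.parameters())

x = torch.randn(8, 4)
net_c(x).sum().backward()
print(grads_cleared(net_c))  # False: gradients were accumulated
net_c.zero_grad()            # module-level reset
print(grads_cleared(net_c))  # True

net_c(x).sum().backward()
opt_c.zero_grad()            # optimizer-level reset of the same parameters
print(grads_cleared(net_c))  # True
```

In a GAN loop this matters because the segmentor's backward pass also flows through the critic and leaves gradients in its parameters, so they must be cleared before the critic's own update.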

There is an error at line 125

Traceback (most recent call last):
File "train.py", line 125, in <module>
output_masked[:,d,:,:] = (input_mask[:,d,:,:].unsqueeze(1)) * output
RuntimeError: expand(torch.cuda.FloatTensor{[36, 1, 128, 128]}, size=[36, 128, 128]): the number of sizes provided (3) must be greater or equal to the number of dimensions in the tensor (4)

How to configure the models and ground truth for multi-class segmentation?

Hey, I am adapting this code to work on the dataset used in last year's LiTS competition, so I have two distinct classes plus the background.

Currently, I am making the model predict a 2-channel output, one channel per class. I am also splitting the ground truth into 2 channels to compare with the prediction during loss calculation. At the end of the loop, I merge both channels again so the result looks like the original ground truth (for visualization purposes).

The problem is that the model seems to learn one of the classes well but not the other.

Can you please give some pointers as to why this may be happening, or suggest a better way to configure SegAN for multi-class segmentation?

Thanks.
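The channel split and merge described in this issue could look like the minimal NumPy sketch below. It assumes the ground truth is a single label map with 0 = background, 1 = first class and 2 = second class (for LiTS, liver and tumor), and that later channels take priority where thresholded masks overlap (tumor drawn on top of liver); the helper names to_channels and merge_channels are hypothetical.

```python
import numpy as np

def to_channels(label_map, num_classes=2):
    """Split a label map (H, W) with values 0..num_classes into binary masks.

    Returns shape (num_classes, H, W); channel k is the mask for label k + 1.
    Background (label 0) gets no channel, matching the setup in the issue.
    """
    return np.stack([(label_map == k + 1).astype(np.float32)
                     for k in range(num_classes)])

def merge_channels(probs, threshold=0.5):
    """Merge per-class probability maps (C, H, W) back into one label map.

    Assumption: later channels overwrite earlier ones where masks overlap.
    """
    label_map = np.zeros(probs.shape[1:], dtype=np.int64)
    for k in range(probs.shape[0]):
        label_map[probs[k] > threshold] = k + 1
    return label_map

gt = np.array([[0, 1, 1],
               [0, 1, 2],
               [0, 0, 2]])
channels = to_channels(gt)
print(channels.shape)                                # (2, 3, 3)
print(np.array_equal(merge_channels(channels), gt))  # True
```

With this layout each channel is an independent binary problem, so a small class such as tumor contributes far fewer foreground pixels than liver; that imbalance is one thing worth checking when only one class is learned.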

PyTorch version?

Your README says the PyTorch version is 1.2, but I can't find this version on the PyTorch website. Could you please tell me the exact version? I also have another question: when I run your code, line 169 of train.py fails with TypeError (min=float, max=float), where min and max need to be ints.

Why do we need to clip gradient in netC?

Thanks for sharing your code. In your code, it has

# clip parameters in D
for p in NetC.parameters():
    p.data.clamp_(-0.05, 0.05)

Why do we need to clip the parameters in D? How were the values -0.05 and 0.05 chosen? Thanks
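Clamping every critic parameter to a fixed range after each update matches the weight-clipping trick from Wasserstein GAN (Arjovsky et al., 2017), where it serves as a crude way to keep the critic Lipschitz-bounded (the WGAN paper clips to ±0.01). The bound is a hyperparameter, and this page does not say how 0.05 was chosen. The NumPy sketch below shows the same in-place clamp on toy parameter arrays; clip_parameters mirrors p.data.clamp_(-0.05, 0.05) but is not the repository's code.

```python
import numpy as np

def clip_parameters(params, bound=0.05):
    """Clamp each parameter array in place to [-bound, bound]."""
    for p in params:
        np.clip(p, -bound, bound, out=p)

rng = np.random.default_rng(0)
critic_params = [rng.normal(0, 0.1, size=(3, 3)), rng.normal(0, 0.1, size=(3,))]
clip_parameters(critic_params)
print(max(np.abs(p).max() for p in critic_params))  # at most 0.05
```

A smaller bound constrains the critic more strongly but can make it too weak to provide useful gradients; a larger one weakens the constraint.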
