
usuyama / pytorch-unet

Stars: 817 · Watchers: 10 · Forks: 228 · Size: 365 KB

Simple PyTorch implementations of U-Net/FullyConvNet (FCN) for image segmentation

Home Page: https://colab.research.google.com/github/usuyama/pytorch-unet/blob/master/pytorch_unet_resnet18_colab.ipynb

License: MIT License

Python 2.23% Jupyter Notebook 97.77%
image-segmentation unet fully-convolutional-networks semantic-segmentation

pytorch-unet's People

Contributors

karanchahal, usuyama


pytorch-unet's Issues

RGB Target mask

First, thanks for your implementation; it is really good.

Now I want to try a case where the input image is grayscale [512, 512] and the target mask is an RGB image [512, 512, 3] (same figures). After SimDataset and DataLoader the shapes are torch.Size([20, 512, 512]) and torch.Size([20, 512, 512, 3]), and training fails with RuntimeError: Given groups=1, weight of size [64, 3, 3, 3], expected input[1, 20, 512, 512] to have 3 channels, but got 20 channels instead.

How to fix this issue?
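A minimal sketch of one likely fix, assuming the batches come out of the DataLoader exactly as shaped in the report. A batch of shape [20, 512, 512] has no channel axis, so the 3-channel first convolution appears to read the 20 images as 20 input channels. Adding a channel axis to the images and moving the mask channels in front of H and W gives the [N, C, H, W] layout that Conv2d and the loss expect. The helper name to_nchw is hypothetical, and the spatial size is reduced to keep the example small.

```python
import torch
import torch.nn as nn

def to_nchw(images: torch.Tensor, masks: torch.Tensor):
    """Hypothetical helper: reshape a grayscale batch and a channels-last mask batch to NCHW."""
    # [N, H, W] grayscale -> [N, 1, H, W]: add an explicit channel axis
    images = images.unsqueeze(1)
    # [N, H, W, 3] channels-last masks -> [N, 3, H, W]
    masks = masks.permute(0, 3, 1, 2).contiguous()
    return images, masks

images = torch.rand(20, 64, 64)      # stands in for torch.Size([20, 512, 512])
masks = torch.rand(20, 64, 64, 3)    # stands in for torch.Size([20, 512, 512, 3])
x, y = to_nchw(images, masks)

# The first conv still expects 3 input channels (weight [64, 3, 3, 3]), so either
# repeat the gray channel three times, or build the model with a 1-channel first conv.
x_rgb = x.repeat(1, 3, 1, 1)
first_conv = nn.Conv2d(3, 64, kernel_size=3, padding=1)
out = first_conv(x_rgb)
print(x.shape, y.shape, out.shape)
```

The model's final layer must also produce 3 channels so that its output has the same shape as y.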

Resnet Model Validity

Hi,

I can't get your pytorch_resnet18_unet.ipynb working.

On the line layer4 = self.layer4_1x1(layer4), it throws the following error:

RuntimeError: Given groups=1, weight of size [1024, 2048, 1, 1], expected input[2, 512, 7, 7] to have 2048 channels, but got 512 channels instead

I haven't modified any of your code.
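The error suggests the 1x1 adapter after layer4 was sized for a ResNet-50-style backbone (2048 channels), while resnet18's layer4 outputs 512 channels. In torchvision, resnet18 and resnet34 use BasicBlock (512 channels out of layer4), while resnet50 and deeper use Bottleneck blocks (2048). A minimal sketch, assuming the adapter is a 1x1 convolution followed by ReLU; make_layer4_1x1 and the channel table are hypothetical names, not the notebook's code.

```python
import torch
import torch.nn as nn

# layer4 output channels of torchvision ResNet variants (hypothetical lookup table)
LAYER4_CHANNELS = {"resnet18": 512, "resnet34": 512, "resnet50": 2048}

def make_layer4_1x1(backbone: str, out_channels: int = 512) -> nn.Module:
    """Hypothetical helper: 1x1 conv + ReLU whose in_channels match the backbone."""
    in_channels = LAYER4_CHANNELS[backbone]
    return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=1),
                         nn.ReLU(inplace=True))

layer4 = torch.rand(2, 512, 7, 7)  # the shape reported in the error message

ok = make_layer4_1x1("resnet18")(layer4)
print(ok.shape)  # torch.Size([2, 512, 7, 7])

# A 2048-channel adapter reproduces the reported mismatch on resnet18 features.
mismatch_raised = False
try:
    nn.Conv2d(2048, 1024, kernel_size=1)(layer4)
except RuntimeError:
    mismatch_raised = True
```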

Image dimensions do not match

I've just downloaded your Jupyter notebook and got an error saying the dimensions do not match. I'm referring to the very first cell in pytorch_fcn.ipynb.

Printed output: "(3,192,192,3) (3,6,192,192)"

Thanks in advance for your help. 👍
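The two printed shapes use different layouts: the images appear to be channels-last (N, H, W, C) while the masks are channels-first (N, C, H, W). That is not necessarily a bug if a later transform converts each image to C, H, W before training; torchvision's transforms.ToTensor does this for H x W x C arrays. A minimal sketch of the same conversion done by hand, assuming uint8 images; the zero arrays only stand in for the generated data.

```python
import numpy as np
import torch

input_images = np.zeros((3, 192, 192, 3), dtype=np.uint8)    # N, H, W, C
target_masks = np.zeros((3, 6, 192, 192), dtype=np.float32)  # N, C, H, W

# Move the channel axis in front of H and W and scale uint8 pixels to [0, 1],
# which is what ToTensor does per image.
x = torch.from_numpy(input_images).permute(0, 3, 1, 2).float() / 255.0
y = torch.from_numpy(target_masks)
print(tuple(x.shape), tuple(y.shape))  # (3, 3, 192, 192) (3, 6, 192, 192)
```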

Grayscale Images

Hi, how do I change the model to accept grayscale images instead of 3-channel images?
Thanks
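Two common options: build the first convolution with in_channels=1, or, when starting from a pretrained ResNet, replace its first convolution (the conv1 attribute of a torchvision ResNet) with a 1-channel copy whose kernel is the sum of the RGB kernels. Summing is exact for a gray image replicated to three channels, since conv(w_R, x) + conv(w_G, x) + conv(w_B, x) = conv(w_R + w_G + w_B, x). A minimal sketch assuming an ungrouped nn.Conv2d; the helper name rgb_conv_to_gray is hypothetical.

```python
import torch
import torch.nn as nn

def rgb_conv_to_gray(conv: nn.Conv2d) -> nn.Conv2d:
    """Hypothetical helper: turn a 3-channel Conv2d into a 1-channel one (assumes groups=1)."""
    gray = nn.Conv2d(1, conv.out_channels, kernel_size=conv.kernel_size,
                     stride=conv.stride, padding=conv.padding,
                     dilation=conv.dilation, bias=conv.bias is not None)
    with torch.no_grad():
        # Sum over the input-channel axis: [out, 3, kH, kW] -> [out, 1, kH, kW]
        gray.weight.copy_(conv.weight.sum(dim=1, keepdim=True))
        if conv.bias is not None:
            gray.bias.copy_(conv.bias)
    return gray

torch.manual_seed(0)
rgb_conv = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)  # ResNet-style stem
gray_conv = rgb_conv_to_gray(rgb_conv)

x = torch.rand(2, 1, 32, 32)
same = torch.allclose(rgb_conv(x.repeat(1, 3, 1, 1)), gray_conv(x), atol=1e-5)
print(same)
```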

Use ConvTranspose2d instead

In your U-Net code you use an upsampling layer. You should use ConvTranspose2d instead.

UpSampling2D (nn.Upsample in PyTorch) simply scales the image up using nearest-neighbour or bilinear interpolation, so nothing is learned. Its advantage is that it is cheap.

Conv2DTranspose (nn.ConvTranspose2d in PyTorch) is a convolution whose kernel is learned during training, just like a normal Conv2d. It also upsamples its input, but the key difference is that the model learns the best upsampling for the task.
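A minimal side-by-side of the two options in PyTorch, assuming a 2x upsampling step on a 64-channel feature map (the channel count and spatial size are arbitrary). Both double H and W; only the transposed convolution has weights to learn.

```python
import torch
import torch.nn as nn

x = torch.rand(1, 64, 16, 16)

# Fixed interpolation: no parameters
upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
# Learned upsampling: out = (in - 1) * stride - 2 * padding + kernel_size = 15 * 2 + 2 = 32
up_conv = nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2)

def n_params(m: nn.Module) -> int:
    return sum(p.numel() for p in m.parameters())

print(upsample(x).shape, n_params(upsample))  # torch.Size([1, 64, 32, 32]) 0
print(up_conv(x).shape, n_params(up_conv))    # torch.Size([1, 64, 32, 32]) 16448
```

The learned version costs 64 * 64 * 2 * 2 weights plus 64 biases here, which is the price of letting the model choose its own upsampling.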

making label

Your implementation is very good. Why don't you use the background as a segmentation category? FCN implementations with one-hot labels often treat the background as a separate category.
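One way to add a background category, assuming the masks are binary one-hot maps of shape [N, C, H, W] in which background pixels are zero in every channel: derive a background channel as 1 minus the per-pixel maximum over classes, prepend it, and take the argmax to get the class-index map that F.cross_entropy expects. The helper name add_background_channel is hypothetical.

```python
import torch

def add_background_channel(masks: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: prepend a background channel to binary masks [N, C, H, W]."""
    foreground = masks.max(dim=1, keepdim=True).values  # 1 where any class is present
    background = 1.0 - foreground
    return torch.cat([background, masks], dim=1)         # [N, C + 1, H, W]

masks = torch.zeros(1, 2, 2, 2)
masks[0, 0, 0, 0] = 1.0  # class 1 at pixel (0, 0)
masks[0, 1, 1, 1] = 1.0  # class 2 at pixel (1, 1)

with_bg = add_background_channel(masks)
labels = with_bg.argmax(dim=1)  # class indices for F.cross_entropy, background = 0
print(labels.tolist())  # [[[1, 0], [0, 2]]]
```

If classes can overlap, argmax picks the lowest-indexed class, so the index map is only a faithful encoding when the masks are mutually exclusive.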

RuntimeError with ResNetUNet

Hello!

When I get to the line

model = ResNetUNet(6)

I receive the following error message:

Traceback (most recent call last):
  File "", line 1, in
  File "", line 4, in __init__
  File "/home/paulr/Programming/anaconda3/lib/python3.7/site-packages/torchvision/models/resnet.py", line 237, in resnet18
    **kwargs)
  File "/home/paulr/Programming/anaconda3/lib/python3.7/site-packages/torchvision/models/resnet.py", line 220, in _resnet
    model = ResNet(block, layers, **kwargs)
  File "/home/paulr/Programming/anaconda3/lib/python3.7/site-packages/torchvision/models/resnet.py", line 142, in __init__
    bias=False)
  File "/home/paulr/Programming/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py", line 332, in __init__
    False, _pair(0), groups, bias, padding_mode)
  File "/home/paulr/Programming/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py", line 46, in __init__
    self.reset_parameters()
  File "/home/paulr/Programming/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py", line 49, in reset_parameters
    init.kaiming_uniform_(self.weight, a=math.sqrt(5))
  File "/home/paulr/Programming/anaconda3/lib/python3.7/site-packages/torch/nn/init.py", line 315, in kaiming_uniform_
    return tensor.uniform_(-bound, bound)
RuntimeError: a leaf Variable that requires grad has been used in an in-place operation.

This looks like it might be a bug in PyTorch (or maybe just in how I'm using it), but if you have any pointers I'd be grateful.

Thanks,
P
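For context on the message itself: PyTorch raises it whenever an in-place op such as uniform_ runs on a leaf tensor that requires grad while autograd is recording, and the torch.nn.init functions normally avoid it by running under torch.no_grad(). The sketch below only reproduces the message; it does not diagnose this particular setup. Including torch.__version__ (and the torchvision version) is a reasonable first step when reporting it.

```python
import torch

w = torch.empty(3, 3, requires_grad=True)  # a leaf tensor, like a freshly created nn.Parameter

# Outside no_grad, an in-place init on a leaf that requires grad is rejected.
raised = False
try:
    w.uniform_(-0.1, 0.1)
except RuntimeError as err:
    raised = True
    print(err)

# Inside no_grad (as the torch.nn.init functions do internally), the same call is allowed.
with torch.no_grad():
    w.uniform_(-0.1, 0.1)

print(torch.__version__)  # worth including when reporting this kind of error
```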

Dice value is high at the beginning

Hi @usuyama,

Thanks for sharing your good work. I tried running the training on the simulated data, and the Dice value is quite high at the beginning (0.98).

Have you ever experienced this? I assume something is wrong; any suggestions? Thanks
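In the calc_loss code quoted in a later issue on this page, metrics['dice'] accumulates the output of dice_loss, that is 1 minus the Dice coefficient. If the logged "dice" is that loss, 0.98 at the start means almost no overlap, which is expected before training. A minimal sketch of a common soft Dice loss (assumed to be close to, but not necessarily identical to, the repo's dice_loss) showing the two extremes:

```python
import torch

def soft_dice_loss(pred: torch.Tensor, target: torch.Tensor, smooth: float = 1.0) -> torch.Tensor:
    """1 - soft Dice coefficient, averaged over batch and channels. pred holds probabilities."""
    inter = (pred * target).sum(dim=(2, 3))
    denom = pred.sum(dim=(2, 3)) + target.sum(dim=(2, 3))
    return (1 - (2 * inter + smooth) / (denom + smooth)).mean()

target = torch.zeros(1, 1, 64, 64)
target[..., 10:20, 10:20] = 1.0  # 100 foreground pixels

untrained = soft_dice_loss(torch.zeros_like(target), target)  # no overlap: 1 - 1/101
perfect = soft_dice_loss(target.clone(), target)               # full overlap
print(round(untrained.item(), 3), perfect.item())  # 0.99 0.0
```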

CUDA out of memory

Hello!
Thanks for your excellent work! I'm just learning PyTorch and the U-Net architecture. When I ran your scripts I got the following error:
OutOfMemoryError: CUDA out of memory. Tried to allocate 226.00 MiB (GPU 0; 4.00 GiB total capacity; 3.24 GiB already allocated; 0 bytes free; 3.27 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
The error occurred during training, and no other process was using the GPU.
Thanks for your reply.
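On a 4 GiB GPU the usual first step is a smaller batch_size in the DataLoader; running validation under torch.no_grad() also avoids storing activations there. If the effective batch size matters, gradient accumulation keeps it while holding only one micro-batch of activations in memory at a time. A minimal sketch, assuming the loss is a mean over elements and all samples have the same shape; accumulate_gradients is a hypothetical helper, and the tiny conv model only stands in for the U-Net.

```python
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

def accumulate_gradients(model: nn.Module, inputs: torch.Tensor, targets: torch.Tensor,
                         micro_batch: int) -> float:
    """Hypothetical helper: backprop a full batch in chunks; caller zeroes grads and steps."""
    total = 0.0
    n = inputs.size(0)
    for x, y in zip(inputs.split(micro_batch), targets.split(micro_batch)):
        loss = F.binary_cross_entropy_with_logits(model(x), y)
        # Weight each chunk by its share of the batch so the summed gradients
        # equal those of the mean loss over the whole batch.
        (loss * x.size(0) / n).backward()
        total += loss.item() * x.size(0) / n
    return total

torch.manual_seed(0)
model = nn.Conv2d(3, 2, kernel_size=3, padding=1)
inputs = torch.rand(8, 3, 16, 16)
targets = torch.randint(0, 2, (8, 2, 16, 16)).float()

# Reference: one full-batch backward pass on an identical copy of the model.
reference = copy.deepcopy(model)
F.binary_cross_entropy_with_logits(reference(inputs), targets).backward()

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
optimizer.zero_grad()
accumulate_gradients(model, inputs, targets, micro_batch=2)
match = (torch.allclose(model.weight.grad, reference.weight.grad, atol=1e-6)
         and torch.allclose(model.bias.grad, reference.bias.grad, atol=1e-6))
optimizer.step()
print(match)
```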

Dimensions not working

Dimensions don't match up.

/usr/local/lib/python2.7/dist-packages/torch/nn/functional.py:1474: UserWarning: Using a target size (torch.Size([1966080])) that is different to the input size (torch.Size([3010560])) is deprecated. Please ensure they have the same size.
"Please ensure they have the same size.".format(target.size(), input.size()))

I am using BCELoss. I have also tried 192*192 inputs, and it still does not work.
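The warning means the prediction and target tensors hold different numbers of elements (3010560 vs 1966080), so the model output and the masks disagree in batch size, channel count, or spatial size. A minimal sketch of a shape check to run before the loss, assuming both tensors are [N, C, H, W]; check_shapes is a hypothetical helper. BCEWithLogitsLoss is used here because it accepts raw logits, whereas BCELoss expects probabilities.

```python
import torch
import torch.nn as nn

def check_shapes(pred: torch.Tensor, target: torch.Tensor) -> None:
    """Hypothetical helper: fail early with a readable message instead of a size warning."""
    if pred.shape != target.shape:
        raise ValueError(f"pred {tuple(pred.shape)} != target {tuple(target.shape)}; "
                         "check batch size, number of classes and H x W")

criterion = nn.BCEWithLogitsLoss()

pred = torch.randn(2, 6, 192, 192)
target = torch.randint(0, 2, (2, 6, 192, 192)).float()
check_shapes(pred, target)
loss = criterion(pred, target)

mismatch_caught = False
try:
    check_shapes(pred, target[:, :, :160, :160])  # e.g. masks cropped to a different size
except ValueError as err:
    mismatch_caught = True
    print(err)
```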

License?

I hope your license is MIT or GPL3

network output

I'm using this implementation for a two-class problem (a one-channel target image with binary 0/1 values), but my output is not binary (0/1). Did I miss something? Should I add an activation function at the last layer?
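In the calc_loss code quoted in the next issue, pred goes through binary_cross_entropy_with_logits and only afterwards through torch.sigmoid, which indicates the model outputs raw logits. To get a binary mask at inference, a common approach is to apply a sigmoid and threshold the probability, assuming 0.5 as the cut-off (equivalent to thresholding the logit at 0). The helper name logits_to_mask is hypothetical.

```python
import torch

def logits_to_mask(logits: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
    """Hypothetical helper: raw model outputs -> binary {0, 1} mask."""
    probs = torch.sigmoid(logits)       # squash to (0, 1)
    return (probs > threshold).float()  # hard 0/1 decision per pixel

logits = torch.tensor([[-2.0, 0.5], [3.0, -0.1]])
print(logits_to_mask(logits).tolist())  # [[0.0, 1.0], [1.0, 0.0]]
```

Adding a sigmoid layer inside the model is unnecessary during training when the loss is binary_cross_entropy_with_logits, since that function applies the sigmoid internally.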

Can dice and bce loss work on a multi-class task?

Thanks for the great implementation code.

I am confused about the loss function. As far as I can see, Dice and BCE are both used for binary tasks. Do they work well on multi-class tasks? In your code the losses work fine, but what about a bigger dataset?

I tried F.cross_entropy(), but it gives me this: RuntimeError: 1only batches of spatial targets supported (non-empty 3D tensors) but got targets of size: : [36, 4, 224, 224]. Could you please tell me what's wrong? Thanks.

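import torch
import torch.nn.functional as F
from loss import dice_loss  # the repo's soft Dice helper (module name assumed)

def calc_loss(pred, target, metrics, bce_weight=0.5):
    # pred: raw logits [N, C, H, W]; target: one-hot float masks [N, C, H, W]
    bce = F.binary_cross_entropy_with_logits(pred, target)

    # F.cross_entropy expects class indices [N, H, W] (int64), not one-hot masks,
    # so take the argmax over the channel axis, on the same device as pred.
    target_idx = target.argmax(dim=1).to(pred.device)
    ce = F.cross_entropy(pred, target_idx)

    pred = torch.sigmoid(pred)
    dice = dice_loss(pred, target)

    loss = bce * bce_weight + dice * (1 - bce_weight)

    # metrics is assumed to be a collections.defaultdict(float)
    metrics['bce'] += bce.item() * target.size(0)
    metrics['ce'] += ce.item() * target.size(0)
    metrics['dice'] += dice.item() * target.size(0)
    metrics['loss'] += loss.item() * target.size(0)

    return loss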
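F.cross_entropy expects class indices of shape [N, H, W] (dtype int64) rather than one-hot masks of shape [N, C, H, W], which is what the error about [36, 4, 224, 224] points to. For mutually exclusive classes, a common multi-class setup pairs cross-entropy with a soft Dice term computed on softmax probabilities instead of per-channel sigmoids. A minimal sketch under those assumptions; multiclass_loss and ce_weight are hypothetical names, not the repo's API.

```python
import torch
import torch.nn.functional as F

def multiclass_loss(logits: torch.Tensor, onehot: torch.Tensor,
                    ce_weight: float = 0.5, smooth: float = 1.0) -> torch.Tensor:
    """Cross-entropy + soft Dice for mutually exclusive classes. Shapes: [N, C, H, W]."""
    target_idx = onehot.argmax(dim=1)                    # [N, H, W], int64 class indices
    ce = F.cross_entropy(logits, target_idx)
    probs = F.softmax(logits, dim=1)                     # per-pixel probabilities sum to 1
    inter = (probs * onehot).sum(dim=(2, 3))
    denom = probs.sum(dim=(2, 3)) + onehot.sum(dim=(2, 3))
    dice = 1 - ((2 * inter + smooth) / (denom + smooth)).mean()
    return ce * ce_weight + dice * (1 - ce_weight)

torch.manual_seed(0)
idx = torch.randint(0, 4, (2, 8, 8))                     # random ground-truth classes
onehot = F.one_hot(idx, num_classes=4).permute(0, 3, 1, 2).float()

good = multiclass_loss(onehot * 20.0, onehot)            # confident, correct logits
bad = multiclass_loss(torch.zeros_like(onehot), onehot)  # uniform guess over 4 classes
print(good.item() < bad.item())  # True
```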
