
pytorch-deep-image-matting's People

Contributors

huochaitiantang


pytorch-deep-image-matting's Issues

About learning rate setting

Hello @huochaitiantang

Recently, I retrained the model with your code. The learning rate in train.sh is 0.00001, while the default value in train.py is 0.001. Which learning rate did you use to train the current model? Thanks!

Cannot re-produce the results

Thanks for the great work! I tried to resume training from the pre-trained model (epoch 12, stage 0) after incorporating the latest changes you made. However, the loss does not seem to improve. Would you share the hyperparameters you used during training? Do you plan to release new pre-trained models that incorporate the latest changes (erosion, etc.)? Thanks!

the evaluation code

Thanks for your contribution. Have you tried writing the evaluation code in Python? I saw there is a compute_gradient function in deploy.py; is it equivalent to the grad metric in the evaluation code?
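For reference, a minimal Python sketch of the gradient error metric as it is commonly implemented (first-order Gaussian derivatives of the normalized alpha maps, squared differences summed over the unknown trimap region); whether deploy.py's compute_gradient matches this exactly is an assumption to verify:

    import numpy as np
    from scipy.ndimage import gaussian_filter

    def gradient_error(pred, gt, trimap, sigma=1.4):
        # Normalize alphas to [0, 1].
        pred = pred.astype(np.float64) / 255.0
        gt = gt.astype(np.float64) / 255.0
        # First-order Gaussian derivatives along each axis.
        pred_x = gaussian_filter(pred, sigma, order=[1, 0])
        pred_y = gaussian_filter(pred, sigma, order=[0, 1])
        gt_x = gaussian_filter(gt, sigma, order=[1, 0])
        gt_y = gaussian_filter(gt, sigma, order=[0, 1])
        # Gradient magnitudes of prediction and ground truth.
        pred_mag = np.sqrt(pred_x ** 2 + pred_y ** 2)
        gt_mag = np.sqrt(gt_x ** 2 + gt_y ** 2)
        # Squared difference, summed over the unknown region only.
        error = ((pred_mag - gt_mag) ** 2) * (trimap == 128)
        return error.sum()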

Can I run this without CUDA?

Hi,

I'm trying to run this without CUDA. I got this error:

RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.

But when I try to use it with map_location=torch.device('cpu'), I get: RuntimeError: Error(s) in loading state_dict for VGG16: size mismatch for conv6_1.weight: copying a param with shape torch.Size([512, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([512, 512, 3, 3]).

Am I going about this wrong? Please help
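For the first error, map_location is indeed the standard fix; a minimal sketch follows (the import path, constructor, and 'state_dict' checkpoint key are assumptions about this repo, not confirmed). The second error is unrelated to CPU loading: the checkpoint contains a 1x1 conv6_1 while the current model definition builds a 3x3 one, so the code revision must match the one the checkpoint was trained with.

    import torch

    # Hypothetical import of the repo's model definition; the real
    # module/class names and constructor may differ.
    from core.net import VGG16

    model = VGG16(args)  # must be the revision the checkpoint was trained with
    ckpt = torch.load(args.resume, map_location=torch.device('cpu'))
    model.load_state_dict(ckpt['state_dict'])  # 'state_dict' key is an assumption
    model.eval()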

trimap used for testing

Hi, what trimaps did you use for testing? did you use the trimaps provided by Adobe or the trimaps generated from alpha, similar to the training process?

absence of normalization

Hi, thanks for the great code.
I inspected the input stream and found that the input image is not normalized.
As far as I know, the pretrained PyTorch VGG assumes normalized pixel values:
RGB values should be divided by 255, then normalized with mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225].
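For comparison, a minimal sketch of the standard torchvision VGG preprocessing described above; whether this repo's released checkpoints were trained with it is exactly the open question:

    import torchvision.transforms as T

    # Standard ImageNet/VGG preprocessing: scale to [0, 1], then normalize.
    vgg_preprocess = T.Compose([
        T.ToTensor(),  # HxWxC uint8 in [0, 255] -> CxHxW float in [0, 1]
        T.Normalize(mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]),
    ])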

Training data

Hello, could you provide the data you used for training?

CUDA out of memory

Hi,
Is there a way to limit the GPU memory? My GPU has 6 GB, and when I watch with the nvidia-smi command, the memory is consumed almost immediately. Thank you.

The error prompt is below:
RuntimeError: CUDA out of memory. Tried to allocate 500.00 MiB (GPU 0; 6.00 GiB total capacity; 4.10 GiB already allocated; 226.63 MiB free; 26.43 MiB cached)
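There is no memory-cap flag in this code as far as I can tell, but two generic mitigations usually help at inference time; a hedged sketch (MAX_SIZE and the single-tensor model signature are illustrative, not this repo's exact API):

    import torch
    import torch.nn.functional as F

    MAX_SIZE = 800  # illustrative cap for a 6 GB card

    def predict(model, inp):  # inp: 1xCxHxW float tensor
        h, w = inp.shape[2:]
        scale = min(1.0, MAX_SIZE / float(max(h, w)))
        if scale < 1.0:
            # Downscale so the largest side is at most MAX_SIZE pixels.
            inp = F.interpolate(inp, scale_factor=scale, mode='bilinear',
                                align_corners=False)
        with torch.no_grad():  # no autograd buffers -> much lower GPU memory
            return model(inp)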

question regarding alpha being used in generating trimap

Hi

I have a question about the alpha needed during the testing phase. As I understand it, the paper states that for testing/inference only the original image and the corresponding trimap are needed, and the encoder-decoder framework predicts the alpha. Ground-truth alpha (generated with Photoshop) is required only for training.

I am aware that the authors originally created alphas manually for all 431 objects and, after compositing, distributed them among the training and testing sets, so both sets get high-quality, manually produced alphas. That is fine as long as the authors' test images are used.

However, suppose I choose another image from the internet and want to pass it through the network (per the paper, I should only need to provide the image and a corresponding trimap).

Why is an alpha passed during the testing phase to generate the trimap, and how do I get that alpha? Could a mask generated by any segmentation framework be used as an initial/rough alpha input to generate_trimap, so that running through the network produces a better alpha?
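For what it's worth, a minimal sketch of that second idea, assuming OpenCV: a rough binary mask from any segmentation model can be turned into a trimap by marking a band around its boundary as unknown (128), similar in spirit to this repo's generate_trimap (function name and band width here are illustrative):

    import cv2
    import numpy as np

    def mask_to_trimap(mask, band=10):
        """mask: HxW uint8 with values 0/255, e.g. from a segmentation model."""
        kernel = np.ones((band, band), np.uint8)
        dilated = cv2.dilate(mask, kernel)
        eroded = cv2.erode(mask, kernel)
        trimap = np.full(mask.shape, 128, np.uint8)  # unknown by default
        trimap[eroded == 255] = 255   # confident foreground
        trimap[dilated == 0] = 0      # confident background
        return trimap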

Many thanks.

Dataset

How can I create a dataset using VOC 12 and COCO 17? Are there tools that can help with creating the dataset?

Sigmoid as last layer

Hi, I just quickly wanted to say thanks: I used your structure to re-implement the model myself using resnet34 instead of VGG16, but turned to your code for inspiration or when I got stuck.

Anyway, I wanted to mention: you write here that you're not adding a Sigmoid at the end because your results converge to zero.
I ran into the same issue. The model learns within a single batch to only predict 0s. For me, the issue was that I had not added erosion to the matting stage. The result is that the "ground truth" for the model mostly contains 0s (>90%), causing it to predict 0s everywhere. I fixed it by also adding erosion, the same amount as dilation. This "balances" the labels to contain both 1s and 0s, punishing a model that only predicts 0s (the explanation here is a bit rubbish, happy to elaborate more if you want me to).

I saw here that you're only dilating, not eroding. Just wanted to suggest you check it out, maybe you're running into the same problem I had. Adding erosion might enable the usage of a sigmoid.
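A minimal sketch of the suggested fix, assuming OpenCV (the function name and kernel size are illustrative): build the training trimap from the ground-truth alpha with equal erosion and dilation, so the unknown band contains both 0s and 1s:

    import cv2
    import numpy as np

    def balanced_trimap(alpha, size=10):
        """alpha: HxW uint8 ground-truth alpha matte."""
        fg = (alpha == 255).astype(np.uint8) * 255     # definite foreground
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size))
        dilated = cv2.dilate(fg, kernel)   # grow outward: unknown band gains 0s
        eroded = cv2.erode(fg, kernel)     # shrink inward: unknown band gains 1s
        trimap = np.full(alpha.shape, 128, np.uint8)
        trimap[eroded == 255] = 255
        trimap[dilated == 0] = 0
        return trimap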

Thanks again for the inspiration!

Automatic Background Removal technology

I am looking for a deep learning library/sdk which can be used to remove the background from any image automatically (with quality as good as www.remove.bg).

I tried some image segmentation SDKs with pre-trained models such as Tensorflow Lite & Fritz AI, but the accuracy of the cutout mask was very low, amongst other issues.

Criteria :-

  1. Background Removal rather than just Human/Portrait Segmentation

If the foreground consists of a person holding a balloon, sitting on a chair, with a pet at his side, then I want all of this to be extracted, not just the human cutout. The segmentation SDKs I tried only extract humans (the chair vanishes), and with a very low-quality mask at that (hair gets cut, parts of the ear get cut, etc.).

  2. Mask quality should be Super-Accurate

I want even the finer details like the hair, delicate clothes, etc. to be extracted perfectly.

  3. Fast & Lightweight (for mobile phones)

I want to use this technology on mobile phones (in an Android app), ideally working even in an offline environment. If this is difficult to achieve, plan B would be to install the technology on our server.

  4. Technology

What technology should I be exploring to achieve this? Is it called image segmentation, or is the better term image matting? (e.g. http://alphamatting.com/eval_25.php)

I have been reading a lot and I am currently lost in the sea of technologies out there (OpenCV, Deep Matting, Mask RCNN, Instance Segmentation, Detectron2, Tensorflow, Pytorch, etc.). I wonder what magic is happening behind the curtains of www.remove.bg.

Would your library help me to achieve what I am looking for? Any help you could provide or a nudge in the right direction would be awesome.

Thanks a ton!

A question on data augmentation part

np.random.randint() returns coordinates as (h, w) rather than (w, h). However, in core/data.py, lines 42-44, the coordinates are treated as (w, h). Is this a bug, or am I missing something? Please correct me if I'm wrong.
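For reference, numpy images are row-major, so coordinates drawn for a crop must be applied as (row, column); a tiny illustration:

    import numpy as np

    img = np.zeros((480, 640, 3), np.uint8)    # H=480 rows, W=640 columns
    h, w = img.shape[:2]
    y = np.random.randint(0, h - 320)          # vertical (row) offset
    x = np.random.randint(0, w - 320)          # horizontal (column) offset
    crop = img[y:y + 320, x:x + 320]           # index rows first, then columns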

Thanks.

License

Hello!
I would like to thank you for developing this and sharing your code. As part of my job at the University of Mons (UMONS, Belgium), I will need to use your pretrained model and modify some of your scripts. The eventual goal is to create a deep learning live demo for the general public and hopefully raise awareness of new technologies; no commercial use.
Do you plan to add a license to this project? If not, may I use and modify your code, with citation of your work?
Regards.

PS: As this is quite urgent, could you answer before October 18th, please?

composition.py

Thanks for your great work. When I compose the image using your composition.py file (OpenCV), I get a strange merged image like this:

[image]

But when I use the composition code provided by the original DIM (PIL), I get the correct merged image:

[image]

I don't know what caused the error. It seems that OpenCV may have some bug here.
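For reference, a minimal sketch of the paper's compositing equation, I = alpha * F + (1 - alpha) * B, in float arithmetic; doing the blend in uint8, or mixing OpenCV's BGR channel order with PIL's RGB, are both plausible causes of such artifacts (this sketch is generic, not a diagnosis of composition.py):

    import numpy as np

    def composite(fg, bg, alpha):
        """fg, bg: HxWx3 uint8 (same color order!); alpha: HxW uint8."""
        a = alpha.astype(np.float64)[:, :, None] / 255.0
        merged = a * fg.astype(np.float64) + (1.0 - a) * bg.astype(np.float64)
        return np.clip(merged + 0.5, 0, 255).astype(np.uint8)  # round, not truncate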

Doubts about the alpha_loss

After reading the code and running the training, I have some doubts about the alpha_loss calculation. According to the paper, the formulation should relate the GT alpha and the predicted alpha, so why is the trimap involved? Also, during training the alpha_loss stayed at about 0.2 and never converged.
Glad to talk with you.
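For reference, per the paper the alpha-prediction loss is only computed over the unknown region of the trimap, which is why the trimap appears; a minimal sketch (the epsilon value follows the paper, the masking and normalization details are an assumption about this repo's implementation):

    import torch

    def alpha_prediction_loss(pred, gt, trimap, eps=1e-6):
        # pred, gt: Nx1xHxW in [0, 1]; trimap: Nx1xHxW with values 0/128/255.
        unknown = (trimap == 128).float()          # only the unknown region counts
        loss = torch.sqrt((pred - gt) ** 2 + eps ** 2) * unknown
        return loss.sum() / (unknown.sum() + 1.0)  # normalize by unknown pixels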

Question about real world images

I trained stage 0 and the result on the test dataset was good.

But when I use my own images for inference, the results are far from what I expected.

I think it's because of the TRIMAP.

The trimaps Adobe provides are very precise; how can we get such nice trimaps for our own images?

step parameter

Hi @huochaitiantang , firstly thanks for the good job!
I am training the model and everything works well, but I don't understand the meaning of the parameter "step".
In train.sh it is set to -1, and in train.py to 10. What do these numbers mean?
It seems to influence the learning rate during training, but I don't understand how.
Thanks in advance!
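A hedged guess at the usual meaning of such a parameter (check train.py's learning-rate logic to confirm): decay the learning rate every step epochs, with -1 meaning no decay at all. A sketch of that convention:

    def adjust_learning_rate(optimizer, base_lr, epoch, step, gamma=0.1):
        """Hypothetical helper: decay lr by `gamma` every `step` epochs."""
        if step <= 0:                # step=-1: keep the learning rate constant
            return base_lr
        lr = base_lr * (gamma ** (epoch // step))
        for group in optimizer.param_groups:
            group['lr'] = lr
        return lr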

How to calculate connectivity error

Hello, I am new to the image matting task and I was wondering how you calculated the connectivity error. Could you share some code for this?

pytorch-deep-image-matting implementation on Android

Hi,
Thanks for your work! The output looks awesome.
I want to integrate your demo into an Android project. Is it possible to integrate the model into an Android project? If so, how can I do it?
Can you please give some suggestions?
Thanks in advance.

Getting horrible performance, why?

Hi, I ran this code exactly as provided, but I am getting very poor results:
Avg-Cost: 0.4214326207637787 s/image
Eval-MSE: 0.0514946369834948
Eval-SAD: 92497.11225042422
Why?

Got an error when running demo.py

Load Error:'utf-8' codec can't decode byte 0xda in position 5: invalid continuation byte
Try Load Again...
Traceback (most recent call last):
File ".\core\demo.py", line 17, in my_torch_load
ckpt = torch.load(fname)
File "C:\ProgramData\Anaconda3\lib\site-packages\torch\serialization.py", line 585, in load
return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
File "C:\ProgramData\Anaconda3\lib\site-packages\torch\serialization.py", line 765, in _legacy_load
result = unpickler.load()
UnicodeDecodeError: 'utf-8' codec can't decode byte 0xda in position 5: invalid continuation byte

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File ".\core\demo.py", line 66, in
ckpt = my_torch_load(args.resume)
File ".\core\demo.py", line 30, in my_torch_load
ckpt = torch.load(args.resume, pickle_module=c)
File "C:\ProgramData\Anaconda3\lib\site-packages\torch\serialization.py", line 585, in load
return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
File "C:\ProgramData\Anaconda3\lib\site-packages\torch\serialization.py", line 755, in _legacy_load
magic_number = pickle_module.load(f, **pickle_load_args)
TypeError: c_load() got an unexpected keyword argument 'encoding'

My Python version is 3.7.4.
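A minimal sketch of what usually fixes this: the checkpoint appears to have been saved under Python 2, and torch.load forwards keyword arguments to pickle, so a latin1 encoding typically works without swapping in a custom pickle module (args.resume follows demo.py):

    import torch

    # latin1 lets pickle decode byte strings written by Python 2; map_location
    # is optional but avoids a second failure on CPU-only machines.
    ckpt = torch.load(args.resume, map_location=torch.device('cpu'),
                      encoding='latin1')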

training error

Hi, and thanks for making this code available. I am trying to train it and am getting the following error:

===> Environment init
===> Loading datasets
--Valid Samples: 0
Traceback (most recent call last):
  File "core/train.py", line 357, in <module>
    main()
  File "core/train.py", line 337, in main
    train_loader = get_dataset(args)
  File "core/train.py", line 67, in get_dataset
    train_set = MatDatasetOffline(args, train_transform, normalize)
  File "X:\dev\ML\DEEP_MATTE\pytorch-deep-image-matting-master\core\data.py", line 99, in __init__
    assert(self.cnt > 0)
AssertionError

running with this cmd:

python core/train.py --crop_h=320,480,640 --crop_w=320,480,640 --size_h=320 --size_w=320 --alphaDir=X:/dev/ML/DEEP_MATTE/data/train/gt --fgDir=X:/dev/ML/DEEP_MATTE/data/train/fg --bgDir=X:/dev/ML/DEEP_MATTE/data/train/bg --imgDir=X:/dev/ML/DEEP_MATTE/data/train/image --saveDir=model/stage0_norm --batchSize=1 --nEpochs=25 --step=-1 --lr=0.00001 --wl_weight=0.5 --threads=4 --printFreq=10 --ckptSaveFreq=1 --cuda --stage=0 --pretrain=model/vgg_state_dict.pth --testFreq=1 --testImgDir=X:/dev/ML/DEEP_MATTE/data/test/image --testTrimapDir=X:/dev/ML/DEEP_MATTE/data/test/trimap --testAlphaDir=X:/dev/ML/DEEP_MATTE/data/test/gt --testResDir=result/tmp --crop_or_resize=whole --max_size=1600

What am I doing wrong here?
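A small diagnostic sketch (paths copied from the command above): the assertion assert(self.cnt > 0) fires when the dataset matches zero samples, which usually means a directory path is wrong or the filenames across the four directories don't correspond:

    import os

    # Every directory must exist and the filenames across them must correspond.
    for d in ["X:/dev/ML/DEEP_MATTE/data/train/gt",
              "X:/dev/ML/DEEP_MATTE/data/train/fg",
              "X:/dev/ML/DEEP_MATTE/data/train/bg",
              "X:/dev/ML/DEEP_MATTE/data/train/image"]:
        names = sorted(os.listdir(d))
        print(d, len(names), names[:3])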

Thanks!

Problem in modifying the code to multi GPU process

Hi, thanks for your awesome work. I want to modify the code for multi-GPU training, so I modified your main code as below:

    if args.cuda:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        model = nn.DataParallel(model)
        model.to(device)

But, I got error:

 Traceback (most recent call last):
  File "core/train.py", line 361, in <module>
    main()
  File "core/train.py", line 353, in main
    train(args, model, optimizer, train_loader, epoch)
  File "core/train.py", line 200, in train
    pred_mattes, pred_alpha = model(input_img)
  File "/home/chaofan/lib/anaconda2/envs/python36/lib/python3.6/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/chaofan/lib/anaconda2/envs/python36/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 124, in forward
    return self.gather(outputs, self.output_device)
  File "/home/chaofan/lib/anaconda2/envs/python36/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 136, in gather
    return gather(outputs, output_device, dim=self.dim)
  File "/home/chaofan/lib/anaconda2/envs/python36/lib/python3.6/site-packages/torch/nn/parallel/scatter_gather.py", line 67, in gather
    return gather_map(outputs)
  File "/home/chaofan/lib/anaconda2/envs/python36/lib/python3.6/site-packages/torch/nn/parallel/scatter_gather.py", line 62, in gather_map
    return type(out)(map(gather_map, zip(*outputs)))
  File "/home/chaofan/lib/anaconda2/envs/python36/lib/python3.6/site-packages/torch/nn/parallel/scatter_gather.py", line 62, in gather_map
    return type(out)(map(gather_map, zip(*outputs)))
TypeError: zip argument #1 must support iteration

I have no idea about this problem, do you have some suggestions? Thank you!
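A hedged guess at the cause: DataParallel's gather step can only collect tensors (or nests of tensors) from the replicas, so if forward returns a plain Python value for one of its outputs (e.g., a scalar 0 for the refinement output in stage 0), it fails exactly like this. An illustrative sketch of the workaround, not this repo's actual forward:

    import torch

    def forward(self, x):                      # illustrative structure only
        pred_mattes = self.decoder(self.encoder(x))
        if self.stage == 0:
            # Return a tensor placeholder instead of a Python scalar so that
            # DataParallel can gather both outputs from every replica.
            return pred_mattes, torch.zeros_like(pred_mattes)
        pred_alpha = self.refine(x, pred_mattes)
        return pred_mattes, pred_alpha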

performance for stage2

Hi~ Thank you for the excellent work.
I have reproduced the performance of stage 1 following your code, but I cannot reproduce the stage 2 performance reported in the paper (50 SAD).
Would you share your stage 2 performance and model if you have tried it?
Thanks!

Run Demo error

Hi, I met this error when trying to run core/demo.py. Has anyone solved it?

Traceback (most recent call last):
File "core/demo.py", line 42, in
ckpt = torch.load(args.resume)
File "/home/mmdb-gp8/anaconda3/envs/pytorch/lib/python3.5/site-packages/torch/serialization.py", line 303, in load
return _load(f, map_location, pickle_module)
File "/home/mmdb-gp8/anaconda3/envs/pytorch/lib/python3.5/site-packages/torch/serialization.py", line 469, in _load
result = unpickler.load()
UnicodeDecodeError: 'ascii' codec can't decode byte 0x94 in position 1: ordinal not in range(128)

BTW, the model Stage1-SAD=57.1 cannot be accessed ☹️
