richzhang / colorization-pytorch

PyTorch reimplementation of Interactive Deep Colorization

Home Page: https://richzhang.github.io/ideepcolor/

License: MIT License

Python 96.88% Shell 3.12%
computer-graphics computer-vision convolutional-neural-networks deep-learning image-colorization pytorch siggraph

colorization-pytorch's People

Contributors

alanyee, andersasa, andyli, gdlg, guopzhao, iver56, jpmerc, junyanz, lambdawill, layumi, levirve, mengcz13, naruto-sasuke, pertence, phillipi, richzhang, ruotianluo, simontreu, ssnl, strob, taesungp, tariqahassan, tylercarberry

colorization-pytorch's Issues

Errors when running CPU only

I'm trying to run the code on CPU only in a VM. I keep getting the following error during the first part of the batch script:

[Network G] Total number of parameters : 34.187 M
No handlers could be found for logger "visdom"
create web directory ./checkpoints/siggraph_class_small/web...
Traceback (most recent call last):
  File "train.py", line 61, in <module>
    model.optimize_parameters()
  File "/home/testing/Desktop/colorization-pytorch-master/models/pix2pix_model.py", line 193, in optimize_parameters
    self.forward()
  File "/home/testing/Desktop/colorization-pytorch-master/models/pix2pix_model.py", line 123, in forward
    self.fake_B_dec_max = self.netG.module.upsample4(util.decode_max_ab(self.fake_B_class, self.opt))
  File "/home/testing/.local/lib/python2.7/site-packages/torch/nn/modules/module.py", line 576, in __getattr__
    type(self).__name__, name))
AttributeError: 'SIGGRAPHGenerator' object has no attribute 'module'
mkdir: cannot create directory ‘./checkpoints/siggraph_class’: File exists
cp: cannot stat './checkpoints/siggraph_class_small/latest_net_G.pth': No such file or directory

test.py is not compatible with python 3

Hi, I am trying to test your model in Google Colab. I put a dummy image in a dummy folder:
'/content/colorization-pytorch/dataset/ilsvrc2012/val/MyImages'

I tested the following two commands:
#!python test.py --name siggraph_caffemodel --mask_cent 0
!python test.py --name siggraph_retrained

Both led to the same error:

Traceback (most recent call last):
  File "test.py", line 53, in <module>
    data_raw[0] = util.crop_mult(data_raw[0], mult=8)
  File "/content/colorization-pytorch/util/util.py", line 277, in crop_mult
    return data[:,:,h:h+Hnew,w:w+Wnew]
TypeError: slice indices must be integers or None or have an __index__ method

I fixed it by casting indices to integers. Then I got another error in test.py line 57:
img_path = [string.replace('%08d_%.3f' % (i, sample_p), '.', 'p')]

string.replace is deprecated in Python 2.7 and does not exist in Python 3; see: https://docs.python.org/2.7/library/string.html?highlight=string%20replace#string.replace

However, when I switch my Google Colab runtime to Python 2, test.py works like a charm.
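Both failures come from Python 2 semantics: in Python 3, / always returns a float, so slice indices need floor division (//), and the module-level string.replace function is gone in favour of the str.replace method. A minimal sketch of both fixes, assuming crop_mult center-crops an N×C×H×W array so that H and W become multiples of mult (the full original body is not shown here, so crop_bounds and make_img_name are hypothetical helpers):

```python
def crop_bounds(H, W, mult=8):
    """Center-crop bounds that make H and W multiples of `mult`.

    Floor division keeps every value an int, which slicing requires
    under Python 3.
    """
    Hnew = mult * (H // mult)
    Wnew = mult * (W // mult)
    h = (H - Hnew) // 2
    w = (W - Wnew) // 2
    return h, w, Hnew, Wnew


def crop_mult(data, mult=8):
    """Crop an N x C x H x W array or tensor to multiples of `mult`."""
    H, W = data.shape[2], data.shape[3]
    h, w, Hnew, Wnew = crop_bounds(H, W, mult)
    return data[:, :, h:h + Hnew, w:w + Wnew]


def make_img_name(i, sample_p):
    """Python 3 replacement for string.replace(s, '.', 'p')."""
    return ('%08d_%.3f' % (i, sample_p)).replace('.', 'p')


if __name__ == "__main__":
    print(crop_bounds(250, 333))      # (1, 2, 248, 328)
    print(make_img_name(12, 0.125))   # 00000012_0p125
```

In test.py the failing line would then read img_path = [make_img_name(i, sample_p)], or equivalently call .replace('.', 'p') directly on the formatted string.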

num_threads or nThreads?

I've noticed that the base options for this repository use opt.nThreads, but I can't see it being used anywhere.

I've also done some digging: in the data folder's __init__.py, the CustomDatasetLoader uses opt.num_threads, which I believe is the name the pix2pix repo uses in its base options, although I don't think this code is used in this project either.

Is the data loading in this project actually multi-threaded?
train.py calls torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size, shuffle=True). Should I add num_workers=int(opt.nThreads) if I want to speed up data loading?

Thanks for the amazing code/project! :)
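Passing num_workers to torch.utils.data.DataLoader is the standard way to parallelize loading; workers are separate processes rather than threads, and the default num_workers=0 loads everything in the main process. A minimal sketch that accepts either option name, assuming opt may define num_threads or nThreads (the helper loader_kwargs is hypothetical):

```python
from types import SimpleNamespace


def loader_kwargs(opt, shuffle=True):
    """Build DataLoader keyword arguments from a parsed options object.

    Prefers `num_threads` (pix2pix naming), then `nThreads`, and falls
    back to 0, which keeps loading in the main process.
    """
    workers = getattr(opt, "num_threads", getattr(opt, "nThreads", 0))
    return {
        "batch_size": opt.batch_size,
        "shuffle": shuffle,
        "num_workers": int(workers),
    }


# Hypothetical use in train.py:
#     loader = torch.utils.data.DataLoader(dataset, **loader_kwargs(opt))

if __name__ == "__main__":
    opt = SimpleNamespace(batch_size=25, nThreads=4)
    print(loader_kwargs(opt))
```

Reading both names keeps the loader working whichever option the base options actually define.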

decode_ind_ab

In util.py, decode_ind_ab() computes

    data_a = data_q/opt.A
    data_b = data_q - data_a*opt.A
    data_ab = torch.cat((data_a, data_b), dim=1)

However, given how the encoding was done, I believe we should instead have something like

    data_a = (data_q - data_b)/opt.A

I imagine this would have to be solved with linear programming or something similar.
I was just wondering whether you're aware of this, or whether I am missing something.

My issue is that when I use decode_ind_ab, all my b values currently come through as -1, because with

    data_a = data_q/opt.A (eq1)
    data_b = data_q - data_a*opt.A (eq2)

we can substitute eq1 into eq2 to show that

    data_b = data_q - data_q = 0

which is then scaled and shifted to -1 before being returned.
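The cancellation above relies on / being true division. Assuming the encoding is q = a·A + b with 0 ≤ b < A (inferred from the decoding formulas in this issue), floor division recovers both indices exactly; the original code presumably relied on integer division under Python 2 and older PyTorch. On integer tensors the equivalent is data_q // opt.A, or torch.div(data_q, opt.A, rounding_mode='floor') in PyTorch 1.8 and later. A pure-Python sketch of the round trip, with hypothetical helper names and without the final scale-and-shift step:

```python
def encode_index(a_idx, b_idx, A):
    """Flatten a (row, column) bin pair on an A x A grid into one index."""
    return a_idx * A + b_idx


def decode_index(q, A):
    """Invert encode_index using floor division.

    With true division (q / A), q - (q / A) * A is zero up to rounding,
    which is why every b index collapses to the same value.
    """
    a_idx = q // A          # floor division keeps the integer row
    b_idx = q - a_idx * A   # remainder, identical to q % A
    return a_idx, b_idx


if __name__ == "__main__":
    A = 23
    q = encode_index(5, 17, A)
    print(q, decode_index(q, A))  # 132 (5, 17)
```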

Classification output in SIGGRAPHGenerator, #channel=529 ?

Hi,

I'm reading the implementation details of SIGGRAPHGenerator but have some trouble understanding the network's classification output.
Classification output in SIGGRAPHGenerator:

model_class=[nn.Conv2d(256, 529, kernel_size=1, padding=0, dilation=1, stride=1, bias=use_bias),]

Is 529 the number of quantized colors Q? I expected this number to be 313.
Could anyone please explain where it comes from?

Thanks,
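One way to read 529 is as a full square grid: 529 = 23 × 23. Assuming the ab plane is quantized with a bin size of 10 over [-110, 110] on each axis (values assumed here, not confirmed in this thread), each axis has 2·110/10 + 1 = 23 bins, so the classifier would predict over all 23 × 23 cells rather than only the 313 in-gamut bins used in the earlier automatic colorization work. A quick check of that arithmetic, with a hypothetical helper name:

```python
def num_ab_classes(ab_max=110, ab_quant=10):
    """Bins per axis and total cells for a square ab grid on [-ab_max, ab_max].

    Integer division keeps the count an int (true division would give
    23.0 under Python 3).
    """
    per_axis = 2 * ab_max // ab_quant + 1   # bins along a (and along b)
    return per_axis, per_axis * per_axis


if __name__ == "__main__":
    print(num_ab_classes())  # (23, 529)
```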

Custom Dataset

The model is trained on ILSVRC2012. Could someone please help clarify how to train it on my custom dataset?
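The dataset paths quoted in other issues on this page (for example ./dataset/ilsvrc2012/val/MyImages and ./dataset/ilsvrc2012/val/1/...) suggest that images are expected one folder below each split directory, as in a torchvision ImageFolder layout. Assuming that is how the loader reads data, one way to prepare a custom dataset is to copy a flat folder of images into such a layout and point the training options at the new root. The function name, the class-folder name and the extension list below are hypothetical:

```python
import os
import shutil
import tempfile

IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp"}


def make_image_folder(src_dir, dst_root, split="train", class_name="images"):
    """Copy images from a flat folder into <dst_root>/<split>/<class_name>/.

    ImageFolder-style loaders need at least one subdirectory per split;
    the class label itself does not matter for colorization.
    Returns the number of files copied.
    """
    out_dir = os.path.join(dst_root, split, class_name)
    os.makedirs(out_dir, exist_ok=True)
    copied = 0
    for name in sorted(os.listdir(src_dir)):
        src = os.path.join(src_dir, name)
        ext = os.path.splitext(name)[1].lower()
        if os.path.isfile(src) and ext in IMAGE_EXTS:
            shutil.copy2(src, os.path.join(out_dir, name))
            copied += 1
    return copied


if __name__ == "__main__":
    # Self-contained demo on throwaway directories.
    src, dst = tempfile.mkdtemp(), tempfile.mkdtemp()
    for fname in ("a.JPG", "b.png", "notes.txt"):
        open(os.path.join(src, fname), "w").close()
    print(make_image_folder(src, dst))  # 2
```

The same call with split="val" would prepare a validation folder alongside the training one.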

Global hints network

Hello, Mr. Zhang. I've read the source code, but I can't find the global hints module. In pix2pix_model.py, the forward method is as follows:

    def forward(self):
        (self.fake_B_class, self.fake_B_reg) = self.netG(self.real_A, self.hint_B, self.mask_B)
        self.fake_B_dec_max = self.netG.module.upsample4(util.decode_max_ab(self.fake_B_class, self.opt))
        self.fake_B_distr = self.netG.module.softmax(self.fake_B_class)

        self.fake_B_dec_mean = self.netG.module.upsample4(util.decode_mean(self.fake_B_distr, self.opt))

        self.fake_B_entr = self.netG.module.upsample4(-torch.sum(self.fake_B_distr * torch.log(self.fake_B_distr + 1.e-10), dim=1, keepdim=True))
        # embed()

The local hint module is applied, but I can't find the global hints module, and the ground-truth color distribution is not calculated either. Maybe I am missing something important; could you give me some pointers? Thank you!
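Judging by the function names in the forward() excerpt above, it turns each pixel's class scores into three quantities: the color of the most likely bin (decode_max_ab), the mean color under the softmax distribution (decode_mean), and the entropy -sum(p * log(p + 1e-10)). A pure-Python sketch of these for a single pixel, assuming each bin has a known ab center; the function name and bin centers are hypothetical, and the repository works on tensors rather than lists:

```python
import math


def summarize_pixel(logits, centers, eps=1e-10):
    """Return (argmax center, mean ab, entropy) for one pixel.

    logits  -- unnormalized class scores, one per ab bin
    centers -- (a, b) center of each bin, in the same order as logits
    """
    # Numerically stable softmax.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Color of the single most likely bin.
    k_max = max(range(len(probs)), key=probs.__getitem__)
    argmax_ab = centers[k_max]

    # Probability-weighted mean color.
    mean_a = sum(p * c[0] for p, c in zip(probs, centers))
    mean_b = sum(p * c[1] for p, c in zip(probs, centers))

    # Entropy with the same small epsilon guard as in forward().
    entropy = -sum(p * math.log(p + eps) for p in probs)
    return argmax_ab, (mean_a, mean_b), entropy


if __name__ == "__main__":
    centers = [(-10.0, 0.0), (10.0, 0.0)]
    print(summarize_pixel([0.0, 0.0], centers))
```

A uniform distribution gives the maximum entropy, while a confident prediction drives it toward zero.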

Training tends to yellowish colors

Hi,
Thanks for this PyTorch implementation. I'm following your tutorial exactly to train the model on my Mac's CPU. The problem is that, as I monitor the visdom console, the fake_reg output image becomes more and more yellowish as the loss curve goes down. I had the exact same issue when training the original colorization model (i.e. the one without hints) introduced by the paper's authors. I don't know why this happens; any ideas?

Thank you!

I found the error in util.py

On line 248, the average patch values are not computed correctly.
In util.py, add_color_patches_rand_gt() (line 248) computes

    torch.mean(torch.mean(data['B'][nn,:,h:h+P,w:w+P],dim=2,keepdim=True),dim=1,keepdim=True).view(1,C,1,1)

However, I believe this code should be changed to

    torch.mean(torch.mean(data['B'][nn,:,h:h+P,w:w+P],dim=2,keepdim=True),dim=-1,keepdim=True).view(1,C,1,1)

The current code calculates the mean over the ab channels, while the proposed version calculates it over the patch values.
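Which dimensions each torch.mean call reduces depends on the shape of the indexed patch. Assuming nn is an integer sample index and data['B'] is N×C×H×W, the slice data['B'][nn,:,h:h+P,w:w+P] has shape (C, P, P): dim=2 is the patch width, dim=1 the patch height, and dim=-1 refers to dim 2 again. A small pure-Python shape trace (hypothetical helpers, requires Python 3.8+ for math.prod) makes this checkable for both versions:

```python
from math import prod


def mean_keepdim_shape(shape, dim):
    """Shape after torch.mean(x, dim=dim, keepdim=True) on a tensor of `shape`."""
    dim = dim % len(shape)          # negative dims count from the end
    out = list(shape)
    out[dim] = 1
    return tuple(out)


def traced(shape, first_dim, second_dim):
    """Shape after two successive keepdim means."""
    return mean_keepdim_shape(mean_keepdim_shape(shape, first_dim), second_dim)


if __name__ == "__main__":
    C, P = 2, 5                       # ab channels, patch size (illustrative)
    patch = (C, P, P)                 # data['B'][nn, :, h:h+P, w:w+P]
    current = traced(patch, 2, 1)     # dim=2, then dim=1
    proposed = traced(patch, 2, -1)   # dim=2, then dim=-1
    # .view(1, C, 1, 1) needs exactly C elements.
    print(current, prod(current) == C)
    print(proposed, prod(proposed) == C)
```

Under these assumptions, the first expression ends at shape (C, 1, 1), while the second reduces dimension 2 twice and leaves C·P elements, which .view(1, C, 1, 1) would reject for P > 1.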

Discriminator in the pretrained model

I can see that a discriminator is defined in the graph if lambda_GAN > 0. Do you have pretrained models that use the discriminator? I am curious about the difference in performance with and without it.
If providing a pretrained model is not possible, what hyperparameters did you use?

Where are the local and global hints networks defined?

I cannot find a reference to either the local or the global hints network anywhere in the source code. I've looked at the definition of the SIGGRAPHGenerator class, but it seems to be just the main colorization network. I would be very grateful if you could point me to the relevant lines. Thank you.
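In the forward() excerpt quoted in an earlier issue, the generator is called as netG(real_A, hint_B, mask_B), so local hints apparently enter the network as extra inputs: a sparse ab map plus a binary mask marking where hints were placed. A pure-Python sketch of building such maps from a few patch hints, assuming 0 means "no hint" and ignoring any mask offset the repository may apply (the function name and hint format are hypothetical):

```python
def make_hint_maps(H, W, hints, P=1):
    """Build sparse ab hint maps and a binary mask.

    hints -- list of (row, col, a, b): top-left corner of a P x P patch
             and the ab color to place there
    Returns (hint_a, hint_b, mask), each an H x W nested list.
    """
    hint_a = [[0.0] * W for _ in range(H)]
    hint_b = [[0.0] * W for _ in range(H)]
    mask = [[0.0] * W for _ in range(H)]
    for row, col, a, b in hints:
        # Clip the patch at the image border.
        for r in range(row, min(row + P, H)):
            for c in range(col, min(col + P, W)):
                hint_a[r][c] = a
                hint_b[r][c] = b
                mask[r][c] = 1.0
    return hint_a, hint_b, mask


if __name__ == "__main__":
    ha, hb, m = make_hint_maps(4, 4, [(1, 1, 0.3, -0.2)], P=2)
    for row in m:
        print(row)
```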

Error running test.py on Mac

I tried running test.py but got this error. Any thoughts?

Traceback (most recent call last):
  File "/Users/spikeyuan/PycharmProjects/pythonProject6/colorization-pytorch/test.py", line 61, in <module>
    model.test(True)  # True means that losses will be computed
  File "/Users/spikeyuan/PycharmProjects/pythonProject6/colorization-pytorch/models/base_model.py", line 56, in test
    self.forward()
  File "/Users/spikeyuan/PycharmProjects/pythonProject6/colorization-pytorch/models/pix2pix_model.py", line 123, in forward
    self.fake_B_dec_max = self.netG.module.upsample4(util.decode_max_ab(self.fake_B_class, self.opt))
  File "/opt/anaconda3/envs/pythonProject6/lib/python2.7/site-packages/torch/nn/modules/module.py", line 576, in __getattr__
    type(self).__name__, name))
AttributeError: 'SIGGRAPHGenerator' object has no attribute 'module'

I ran this project on Colab and encountered a problem

The error message is: No such file or directory: './dataset/ilsvrc2012/val/1/ILSVRC2012_val_00000059.JPEG'. However, I can find this image in that directory. I would appreciate a prompt reply. Thanks for your help!
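Because the path in the message is relative (./dataset/...), it is resolved against the process's current working directory, not the location of test.py; in Colab, notebook cells start in /content by default. A small diagnostic, hypothetical and not part of the repository, shows where such a path actually points:

```python
import os


def explain_missing(path):
    """Report how a relative path resolves from the current working directory."""
    resolved = os.path.abspath(path)
    return {
        "cwd": os.getcwd(),
        "resolved": resolved,
        "exists": os.path.exists(resolved),
    }


if __name__ == "__main__":
    print(explain_missing("./dataset/ilsvrc2012/val/1/ILSVRC2012_val_00000059.JPEG"))
```

If "resolved" does not point inside the cloned repository, changing into the repository directory before running test.py is one likely fix.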

'SIGGRAPHGenerator' object has no attribute 'model'

When running on Mac (Python 3.5.4), I encountered the following error: 'SIGGRAPHGenerator' object has no attribute 'model'. How can I fix it?

Traceback (most recent call last):
  File "test.py", line 40, in <module>
    model.setup(opt)
  File "/Users/joanna/Desktop/colorization/colorization-pytorch-master/models/base_model.py", line 42, in setup
    self.load_networks(opt.which_epoch)
  File "/Users/joanna/Desktop/colorization/colorization-pytorch-master/models/base_model.py", line 136, in load_networks
    self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
  File "/Users/joanna/Desktop/colorization/colorization-pytorch-master/models/base_model.py", line 116, in __patch_instance_norm_state_dict
    self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
  File "/Users/joanna/opt/anaconda3/envs/color/lib/python3.5/site-packages/torch/nn/modules/module.py", line 576, in __getattr__
    type(self).__name__, name))
AttributeError: 'SIGGRAPHGenerator' object has no attribute 'model'

Model training Loss graph

[Image: training loss graph]

Plot.ly Link: https://plot.ly/~Ugness/1/

I am training the model on the ILSVRC2012 training set with the same options as your implementation, and my loss graph looks like the one above.
I am not sure whether my model's loss is decreasing correctly. Could you check this loss graph or share your own?
Thanks.
