richzhang / colorization-pytorch
PyTorch reimplementation of Interactive Deep Colorization
Home Page: https://richzhang.github.io/ideepcolor/
License: MIT License
So is the L1 loss function more suitable for the colorization task?
I'm trying to run using CPU only in a VM. I keep getting the following error during the first part of the batch script:
No handlers could be found for logger "visdom"
create web directory ./checkpoints/siggraph_class_small/web...
Traceback (most recent call last):
File "train.py", line 61, in
model.optimize_parameters()
File "/home/testing/Desktop/colorization-pytorch-master/models/pix2pix_model.py", line 193, in optimize_parameters
self.forward()
File "/home/testing/Desktop/colorization-pytorch-master/models/pix2pix_model.py", line 123, in forward
self.fake_B_dec_max = self.netG.module.upsample4(util.decode_max_ab(self.fake_B_class, self.opt))
File "/home/testing/.local/lib/python2.7/site-packages/torch/nn/modules/module.py", line 576, in getattr
type(self).name, name))
AttributeError: 'SIGGRAPHGenerator' object has no attribute 'module'
mkdir: cannot create directory ‘./checkpoints/siggraph_class’: File exists
cp: cannot stat './checkpoints/siggraph_class_small/latest_net_G.pth': No such file or directory
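For anyone hitting this: the crash seems to come from the netG.module access, and .module only exists when the network is wrapped in nn.DataParallel (i.e. when running on GPU). A minimal sketch of a guard, with a dummy Conv2d standing in for SIGGRAPHGenerator:

```python
import torch.nn as nn

# dummy stand-ins; the real SIGGRAPHGenerator lives in models/networks.py
plain = nn.Conv2d(1, 2, 3)                     # what you get on CPU
wrapped = nn.DataParallel(nn.Conv2d(1, 2, 3))  # what you get on GPU

# guard the access instead of assuming .module exists
core_a = plain.module if hasattr(plain, 'module') else plain
core_b = wrapped.module if hasattr(wrapped, 'module') else wrapped
print(type(core_a).__name__, type(core_b).__name__)  # Conv2d Conv2d
```

Applying the same guard wherever the code writes self.netG.module.… should make the forward pass work on CPU and GPU alike.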
Hi, I am trying to test your model in Google Colab. I put a dummy image in a dummy folder:
'/content/colorization-pytorch/dataset/ilsvrc2012/val/MyImages'
I tested these two following commands:
#!python test.py --name siggraph_caffemodel --mask_cent 0
!python test.py --name siggraph_retrained
Both led to the same error:
Traceback (most recent call last):
File "test.py", line 53, in
data_raw[0] = util.crop_mult(data_raw[0], mult=8)
File "/content/colorization-pytorch/util/util.py", line 277, in crop_mult
return data[:,:,h:h+Hnew,w:w+Wnew]
TypeError: slice indices must be integers or None or have an index method
I fixed it by casting the indices to integers. Then I got another error in test.py line 57:
img_path = [string.replace('%08d_%.3f' % (i, sample_p), '.', 'p')]
string.replace is deprecated in Python 2.7 and does not exist in Python 3; see here: https://docs.python.org/2.7/library/string.html?highlight=string%20replace#string.replace
However, when I change my google colab execution to python 2, test.py works like a charm.
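For reference, a sketch of the two Python 3 fixes described above (the crop_mult body is reconstructed from the traceback, so treat the exact arithmetic as an assumption):

```python
import torch

def crop_mult(data, mult=8):
    # center-crop H and W down to multiples of `mult`; the int() casts are the fix,
    # since `/` returns a float in Python 3 and floats cannot be slice indices
    H, W = data.shape[2], data.shape[3]
    Hnew, Wnew = int(mult * (H // mult)), int(mult * (W // mult))
    h, w = int((H - Hnew) / 2), int((W - Wnew) / 2)
    return data[:, :, h:h + Hnew, w:w + Wnew]

cropped = crop_mult(torch.zeros(1, 3, 50, 67))
print(cropped.shape)  # torch.Size([1, 3, 48, 64])

# and the string.replace fix: call the str method instead of the removed module function
i, sample_p = 3, 0.125
img_path = [('%08d_%.3f' % (i, sample_p)).replace('.', 'p')]
print(img_path)  # ['00000003_0p125']
```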
I've noticed that the base options for this repository define opt.nThreads, but I can't actually see anywhere this is used.
I've also done some digging, and in the data folder's __init__.py the CustomDatasetLoader calls opt.num_threads, which I believe is what the pix2pix repo uses in its base options, although I don't think this is used in this project either.
Is the data loading for this project actually multi-threaded?
In train.py, torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size, shuffle=True) is called. Should I change this to include num_workers=int(opt.nThreads) if I want to speed up the data loading?
Thanks for the amazing code/project! :)
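For what it's worth, passing num_workers is exactly how torch.utils.data.DataLoader enables multi-process loading; a minimal sketch with a toy dataset and a plain variable standing in for opt.nThreads:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.zeros(100, 3, 8, 8))
nThreads = 2  # stand-in for opt.nThreads from the base options
# num_workers > 0 spawns worker processes that prefetch batches in parallel
loader = DataLoader(dataset, batch_size=10, shuffle=True, num_workers=int(nThreads))
batch, = next(iter(loader))
print(batch.shape)  # torch.Size([10, 3, 8, 8])
```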
This is a little comment about some functions in util. Can you explain the ab transformation in detail? Thank you very much.
How can I export ONNX model from latest_net_G.pth?
In util's decode_ind_ab(), the calculation is
data_a = data_q/opt.A
data_b = data_q - data_a*opt.A
data_ab = torch.cat((data_a, data_b), dim=1)
however, I believe that according to how the encoding was done, we should instead have something like
data_a = (data_q - data_b)/opt.A
I'm imagining this would have to be solved using linear programming or something.
I was just wondering if this is something you're aware of and whether I am missing something?
My issue is that when I use decode_ind_ab, currently all my b values come through as -1, because with
data_a = data_q/opt.A (eq1)
data_b = data_q - data_a*opt.A (eq2)
we can substitute eq1 into eq2 to show that
data_b = data_q - data_q = 0
which then gets scaled and shifted to -1 before being returned.
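If the encoding packs the two bin indices as data_q = a_idx * A + b_idx (my assumption, inferred from the decode shown above), then flooring the division recovers both indices without any circular dependence; a small sketch:

```python
import torch

A = 23  # bins per ab axis (assumption: 2*110/10 + 1 from the repo's defaults)
a_idx, b_idx = torch.tensor([3]), torch.tensor([7])
data_q = a_idx * A + b_idx                            # encoded index: 3*23 + 7 = 76

data_a = torch.div(data_q, A, rounding_mode='floor')  # floor, not true division
data_b = data_q - data_a * A                          # remainder recovers b
print(data_a.item(), data_b.item())  # 3 7
```

With true (float) division the a term keeps the fractional part, so the subtraction cancels to zero, which matches the all -1 b values after scaling.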
flake8 testing of https://github.com/richzhang/colorization-pytorch on Python 3.7.0
$ flake8 . --count --select=E901,E999,F821,F822,F823 --show-source --statistics
./util/util.py:70:18: F821 undefined name 'visuals'
subset = visuals
^
1 F821 undefined name 'visuals'
1
Hi,
I'm reading the implementation details of the SIGGRAPHGenerator, but had some trouble understanding the classification output of this network.
Classification output in SIGGRAPHGenerator:
colorization-pytorch/models/networks.py, line 315 (commit 9fd9bd8)
Is 529 the number of quantized colors Q? I would have expected this number to be 313 instead...
Could anyone please explain this number?
Thanks,
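One hedged explanation: if this repo quantizes the full ab grid (rather than only the 313 in-gamut bins from the original colorization paper), the default option values would give exactly 529:

```python
ab_max, ab_quant = 110.0, 10.0      # default option values (assumption)
A = int(2 * ab_max / ab_quant + 1)  # 23 bins per ab axis
Q = A * A                           # full 23x23 grid, including out-of-gamut cells
print(Q)  # 529
```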
In your paper there is no information about a GAN, so why does your code have a GAN? What is it used for in your code?
I did not modify any parameters; is there any problem, or should I change some training parameters? I hope you can help me, thanks!
The model is trained on ILSVRC2012. Could someone please help me out in clarifying how to train it on my custom dataset?
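As a hedged sketch (the directory names are my guesses from the ILSVRC2012 paths appearing elsewhere in these threads, and the flag names follow pix2pix-style options, so verify both against the repo's options/ folder), training on a custom dataset mostly means mirroring the expected folder layout:

```shell
# mirror the train/val layout an ImageFolder-style loader expects (assumption)
mkdir -p dataset/mydata/train/images dataset/mydata/val/images

# copy your RGB images into those folders, then point training at the new root;
# check the exact flag names against options/ before running
# python train.py --name mymodel_small --dataroot ./dataset/mydata
ls dataset/mydata
```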
Hi, thanks for your great work. I tried the interactive colorization GUI from https://github.com/junyanz/interactive-deep-colorization with your provided retrained model, but it doesn't work. The screenshot is as follows:
Hello, Mr. Zhang. I've read this source code, but I can't find the global hints module. In the pix2pix_model.py file, the forward method is as follows:
def forward(self):
(self.fake_B_class, self.fake_B_reg) = self.netG(self.real_A, self.hint_B, self.mask_B)
self.fake_B_dec_max = self.netG.module.upsample4(util.decode_max_ab(self.fake_B_class, self.opt))
self.fake_B_distr = self.netG.module.softmax(self.fake_B_class)
self.fake_B_dec_mean = self.netG.module.upsample4(util.decode_mean(self.fake_B_distr, self.opt))
self.fake_B_entr = self.netG.module.upsample4(-torch.sum(self.fake_B_distr * torch.log(self.fake_B_distr + 1.e-10), dim=1, keepdim=True))
# embed()
The local hints module is applied, but I can't find the global hints module, and the ground-truth color distribution is also not calculated. Maybe I am missing something important; can you give me some information? Thank you!
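For reference, the entropy line in the forward pass above can be traced shape-by-shape on a dummy tensor (the 529 channels are an assumption, matching the classification output discussed in another thread here):

```python
import torch

logits = torch.randn(1, 529, 8, 8)    # stand-in for fake_B_class (assumed shape)
distr = torch.softmax(logits, dim=1)  # fake_B_distr: per-pixel color distribution
# per-pixel entropy; the 1e-10 guards log(0)
entr = -torch.sum(distr * torch.log(distr + 1e-10), dim=1, keepdim=True)
print(entr.shape)  # torch.Size([1, 1, 8, 8]); larger values = more color uncertainty
```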
Hi,
Thanks for this PyTorch implementation. I'm following your exact tutorial to train the model on my Mac's CPU. The problem is that, as I monitor the visdom console, I see the output fake_reg image becoming more yellowish as the loss curves go down. I had this same exact issue while training the original colorization model (i.e. the one without hints) introduced by the paper authors. I don't know why this happens; any ideas?
Thank you!
In line number 248, the average patch values are not computed appropriately.
In util's add_color_patches_rand_gt() (line number 248), the calculation is
torch.mean(torch.mean(data['B'][nn,:,h:h+P,w:w+P],dim=2,keepdim=True),dim=1,keepdim=True).view(1,C,1,1)
however, I believe this code should be changed to
torch.mean(torch.mean(data['B'][nn,:,h:h+P,w:w+P],dim=2,keepdim=True),dim=-1,keepdim=True).view(1,C,1,1)
The previous code calculates the mean value within the ab channel, while the second calculates it within the patch values.
So I can see that there is a discriminator defined in the graph if lambda_GAN > 0. Do you have pretrained models that use the discriminator? I am curious about the difference in performance with and without the discriminator.
Or, if providing a pretrained model is not possible, what hyperparameters did you use?
I cannot find a reference to either the local or global hints networks anywhere in the source code. I've had a look at the definition of the SIGGRAPHGenerator class, but that seems to be just the main colorization network. I would be very grateful for a pointer to the relevant lines. Thank you.
In your paper you say that, using the Global Hints Network, you can colorize a grayscale image using global histograms from an input color image. I want to know how to implement this using the pretrained model. Thank you.
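My current guess at how the global histogram input would be built (the bin sizes are assumptions matching the grid quantization discussed in other threads; how the histogram is then fed to the pretrained model is exactly what I'm asking about):

```python
import torch

ab_max, ab_quant = 110.0, 10.0
A = int(2 * ab_max / ab_quant + 1)               # 23 bins per ab axis (assumption)
ab = torch.rand(2, 64, 64) * 2 * ab_max - ab_max  # dummy reference ab channels in [-110, 110]

# map each pixel's (a, b) pair to a grid cell index, then count and normalize
a_idx = torch.clamp(((ab[0] + ab_max) / ab_quant).long(), 0, A - 1)
b_idx = torch.clamp(((ab[1] + ab_max) / ab_quant).long(), 0, A - 1)
q = (a_idx * A + b_idx).flatten()
hist = torch.bincount(q, minlength=A * A).float()
hist = hist / hist.sum()                         # normalized global color distribution
print(hist.shape)  # torch.Size([529])
```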
I tried running test.py but got this error. Any thoughts?
Traceback (most recent call last):
File "/Users/spikeyuan/PycharmProjects/pythonProject6/colorization-pytorch/test.py", line 61, in <module>
model.test(True) # True means that losses will be computed
File "/Users/spikeyuan/PycharmProjects/pythonProject6/colorization-pytorch/models/base_model.py", line 56, in test
self.forward()
File "/Users/spikeyuan/PycharmProjects/pythonProject6/colorization-pytorch/models/pix2pix_model.py", line 123, in forward
self.fake_B_dec_max = self.netG.module.upsample4(util.decode_max_ab(self.fake_B_class, self.opt))
File "/opt/anaconda3/envs/pythonProject6/lib/python2.7/site-packages/torch/nn/modules/module.py", line 576, in __getattr__
type(self).__name__, name))
AttributeError: 'SIGGRAPHGenerator' object has no attribute 'module'
The error information is: No such file or directory: './dataset/ilsvrc2012/val/1/ILSVRC2012_val_00000059.JPEG'. But I can find this image in the directory. Your prompt attention to my question is appreciated. Thanks for your help!
When running on macOS (Python 3.5.4), I encountered the following problem: 'SIGGRAPHGenerator' object has no attribute 'model'
I wonder how to fix it?
Traceback (most recent call last):
File "test.py", line 40, in
model.setup(opt)
File "/Users/joanna/Desktop/colorization/colorization-pytorch-master/models/base_model.py", line 42, in setup
self.load_networks(opt.which_epoch)
File "/Users/joanna/Desktop/colorization/colorization-pytorch-master/models/base_model.py", line 136, in load_networks
self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
File "/Users/joanna/Desktop/colorization/colorization-pytorch-master/models/base_model.py", line 116, in __patch_instance_norm_state_dict
self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
File "/Users/joanna/opt/anaconda3/envs/color/lib/python3.5/site-packages/torch/nn/modules/module.py", line 576, in getattr
type(self).name, name))
AttributeError: 'SIGGRAPHGenerator' object has no attribute 'model'
Plot.ly Link: https://plot.ly/~Ugness/1/
I am training the model with ILSVRC2012 training set and same options as your implementation and my loss graph looks like above.
I am not sure whether my model's loss is decreasing correctly. Can you check this loss graph or share your loss graph?
Thanks.