hsinyinglee / drit
Learning diverse image-to-image translation from unpaired data
There are two typos in the updated /src/model.py:
1. "if not self.no_ms" should be "if not self.no_ms:" (the colon is missing).
2. "lozz_lz_BA" should be "loss_lz_BA".
Do DRIT and MSGAN work with batch size > 2 and half size > 1? I ran some preliminary tests and saw zero diversity (full mode collapse) when these settings are used.
Can DRIT and MSGAN support, or be modified to support, these settings? Thanks.
When running the code, it encounters an error at https://github.com/HsinYingLee/DRIT/blob/master/src/model.py#L362-L367:
  # latent regression loss
  if self.concat:
    loss_z_L1_a = torch.mean(torch.abs(self.mu2_a - self.z_random)) * 10
    loss_z_L1_b = torch.mean(torch.abs(self.mu2_b - self.z_random)) * 10
  else:
    loss_z_L1_a = torch.mean(torch.abs(self.z_attr_random_a - self.z_random)) * 10
    loss_z_L1_b = torch.mean(torch.abs(self.z_attr_random_b - self.z_random)) * 10
because self.z_random is generated without specifying the GPU here:
  def get_z_random(self, batchSize, nz, random_type='gauss'):
    z = torch.cuda.FloatTensor(batchSize, nz)
    z.copy_(torch.randn(batchSize, nz))
    return z
This worked for me:
  def get_z_random(self, batchSize, nz, random_type='gauss'):
    z = torch.FloatTensor(batchSize, nz)
    z.copy_(torch.randn(batchSize, nz))
    return z.cuda(self.gpu)
First of all, thanks for your great work!
In your code, disA2 and disB2 are trained with random attribute codes and are used to update the generator in backward_G_alone.
However, backward_G_alone doesn't seem to be used in the training process.
What is the purpose of disA2, disB2 and backward_G_alone?
I think line 157 in models.py should be
self.z_content_recon_a, self.z_content_recon_b = self.enc_c.forward(self.fake_A_encoded, self.fake_B_encoded)
instead of
self.z_content_recon_b, self.z_content_recon_a = self.enc_c.forward(self.fake_A_encoded, self.fake_B_encoded)
Hi,
Could you send me the pretrained model weights? My email address is [email protected]
Hi, I am getting the following error. How can I solve this? Thanks in advance.
--- train ---
/home/mahfuj/drit_virtual/lib/python3.5/site-packages/torch/nn/functional.py:1006: UserWarning: nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.
warnings.warn("nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.")
Traceback (most recent call last):
File "/home/mahfuj/drit_virtual/lib/python3.5/site-packages/PIL/Image.py", line 2460, in fromarray
mode, rawmode = _fromarray_typemap[typekey]
KeyError: ((1, 1, 438), '|u1')
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "train.py", line 78, in
main()
File "train.py", line 56, in main
saver.write_display(total_it, model)
File "/media/user/DATA/wacv_round2/DRIT/src/saver.py", line 55, in write_display
self.writer.add_image('Image', image_dis, total_it)
File "/home/user/drit_virtual/lib/python3.5/site-packages/tensorboardX/writer.py", line 412, in add_image
self.file_writer.add_summary(image(tag, img_tensor), global_step, walltime)
File "/home/user/drit_virtual/lib/python3.5/site-packages/tensorboardX/summary.py", line 205, in image
image = make_image(tensor, rescale=rescale)
File "/home/user/drit_virtual/lib/python3.5/site-packages/tensorboardX/summary.py", line 243, in make_image
image = Image.fromarray(tensor)
File "/home/user/drit_virtual/lib/python3.5/site-packages/PIL/Image.py", line 2463, in fromarray
raise TypeError("Cannot handle this data type")
TypeError: Cannot handle this data type
Hi Lee, could you provide the network code for the domain adaptation experiment? Running it with the current network is too slow.
When I change the crop_size of the images, the following error occurs:
File networks.py line , in forward
return self.model(x)
Runtime error: sizes must be non-negative
Is DRIT only suitable for an image size of 216 * 216?
#24
HsinYingLee:
I think there are several possible settings:
1. If you would like to compare with the ground-truth set and other methods on, say, a collection of N images, you can translate N target images from N source images. Then you can randomly sample M pairs out of these N images to calculate the diversity.
2. Another way to compare different methods is, given N source images, to generate M translations for each image and calculate a diversity score among these M target images. You then average the scores over all N images.
We use the first setting in our experiments. Since there's no standard setting for this kind of experiment yet, I believe any setting that makes sense should be okay.
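The two settings described above can be sketched as follows. This is only an illustration under stated assumptions: images are flat lists of floats, and distance() is a hypothetical stand-in (mean absolute difference) for whatever perceptual metric is actually used; neither function name comes from the repository.

```python
import itertools
import random
from statistics import mean


def distance(a, b):
    # Hypothetical placeholder metric: mean absolute difference per element.
    return mean(abs(x - y) for x, y in zip(a, b))


def diversity_setting_1(translated, num_pairs, rng):
    """Setting 1: sample M random pairs out of the N translated images."""
    pairs = [rng.sample(range(len(translated)), 2) for _ in range(num_pairs)]
    return mean(distance(translated[i], translated[j]) for i, j in pairs)


def diversity_setting_2(translations_per_source):
    """Setting 2: score the M translations of each source, then average over sources."""
    per_source = [
        mean(distance(a, b) for a, b in itertools.combinations(outputs, 2))
        for outputs in translations_per_source
    ]
    return mean(per_source)
```

In setting 1 the pairs mix translations of different source images, while in setting 2 every pair shares the same source, so the two scores are not directly comparable.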
I don't understand setting 1, "randomly sample M pairs out of this N images to calculate the diversity". In Table 2, for "real images .448 ± .012", does each pair consist of a ground-truth image A and a ground-truth image B? And for "DRIT .424 ± .010", does each pair consist of a translated image A (translated from ground-truth image A with DRIT) and a translated image B (translated from ground-truth image B with DRIT)?
@HsinYingLee
Could you please share how you do the linear interpolation between two attributes? Is it just a linspace?
Thank you.
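For reference, a linspace-style interpolation between two attribute vectors could look like the sketch below. This is one plausible reading of the question, not the authors' confirmed procedure; interpolate_attributes is a hypothetical helper and the vectors are plain Python lists.

```python
def interpolate_attributes(z_start, z_end, steps):
    """Return `steps` vectors evenly spaced from z_start to z_end (both included)."""
    if steps < 2:
        raise ValueError("steps must be at least 2")
    out = []
    for k in range(steps):
        t = k / (steps - 1)  # t runs from 0.0 to 1.0
        out.append([(1.0 - t) * a + t * b for a, b in zip(z_start, z_end)])
    return out
```

Each interpolated vector would then be passed to the generator together with a fixed content code to produce one frame of the interpolation.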
Hi,
Thank you so much for the code and paper. My dataset contains 4 domains. The dataset structure is: trainA, trainB, trainC, trainD, testA, testB, testC, testD. How can I use this code when the dataset has 4 domains? Thanks in advance.
Hi,
I have a question regarding the cross-cycle consistency loss.
By the way, I saw your ECCV 2018 oral presentation in Munich and I was greatly inspired by your work.
Thanks.
According to the paper, the attribute representations are swapped during the translation stage to calculate the cross-cycle consistency loss.
However, as far as I understand, the content representations are swapped when I look at the expressions carefully.
Is my understanding correct? (I can't be sure about my understanding, in fact...)
Hi, thanks for this marvelous work!
I want to run your project on the Yosemite data. It seems that in your code the number of training epochs is 1200 and the batch size is 2.
How long did the training process take?
Did you try with the re-parametrization trick in the KL loss at training stage? For instance here: https://github.com/HsinYingLee/DRIT/blob/master/src/model.py#L134-L140.
Would it be worth it?
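For readers unfamiliar with the term, the re-parametrization trick samples z as a deterministic function of mu, logvar and independent noise, so gradients can flow back to the encoder outputs. A minimal sketch of the general technique, not a description of what the linked lines do (reparameterize and its list-based inputs are assumptions for illustration):

```python
import math
import random


def reparameterize(mu, logvar, rng):
    """Sample z = mu + sigma * eps with eps ~ N(0, 1) and sigma = exp(0.5 * logvar)."""
    return [
        m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0)
        for m, lv in zip(mu, logvar)
    ]
```

With a very negative logvar the sample collapses onto mu; with logvar = 0 it is drawn from a unit-variance Gaussian centred at mu.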
Which version is this code: DRIT or DRIT++?
E:\Users\Raytine\Anaconda3\python.exe F:/zhaiyao/DRIT-master/src/ptest.py --dataroot ../datasets/portrait --name portrait --resume ../results/01199.pth
--- load options ---
a2b: 1
concat: 1
crop_size: 216
dataroot: ../datasets/portrait
gpu: 0
input_dim_a: 3
input_dim_b: 3
nThreads: 4
name: portrait
num: 5
phase: test
resize_size: 256
result_dir: ../outputs
resume: ../results/01199.pth
--- load dataset ---
A: 594 images
--- load model ---
Traceback (most recent call last):
File "F:/zhaiyao/DRIT-master/src/ptest.py", line 50, in
main()
File "F:/zhaiyao/DRIT-master/src/ptest.py", line 25, in main
model.resume(opts.resume, train=False)
File "F:\zhaiyao\DRIT-master\src\model.py", line 401, in resume
self.enc_a.load_state_dict(checkpoint['enc_a'])
File "E:\Users\Raytine\Anaconda3\lib\site-packages\torch\nn\modules\module.py", line 721, in load_state_dict
self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for E_attr_concat:
Missing key(s) in state_dict: "fc_A.0.bias", "fc_A.0.weight", "fcVar_A.0.bias", "fcVar_A.0.weight", "conv_A.1.bias", "conv_A.1.weight", "conv_A.2.conv.2.bias", "conv_A.2.conv.2.weight", "conv_A.2.conv.4.1.bias", "conv_A.2.conv.4.1.weight", "conv_A.2.shortcut.1.bias", "conv_A.2.shortcut.1.weight", "conv_A.3.conv.2.bias", "conv_A.3.conv.2.weight", "conv_A.3.conv.4.1.bias", "conv_A.3.conv.4.1.weight", "conv_A.3.shortcut.1.bias", "conv_A.3.shortcut.1.weight", "conv_A.4.conv.2.bias", "conv_A.4.conv.2.weight", "conv_A.4.conv.4.1.bias", "conv_A.4.conv.4.1.weight", "conv_A.4.shortcut.1.bias", "conv_A.4.shortcut.1.weight", "fc_B.0.bias", "fc_B.0.weight", "fcVar_B.0.bias", "fcVar_B.0.weight", "conv_B.1.bias", "conv_B.1.weight", "conv_B.2.conv.2.bias", "conv_B.2.conv.2.weight", "conv_B.2.conv.4.1.bias", "conv_B.2.conv.4.1.weight", "conv_B.2.shortcut.1.bias", "conv_B.2.shortcut.1.weight", "conv_B.3.conv.2.bias", "conv_B.3.conv.2.weight", "conv_B.3.conv.4.1.bias", "conv_B.3.conv.4.1.weight", "conv_B.3.shortcut.1.bias", "conv_B.3.shortcut.1.weight", "conv_B.4.conv.2.bias", "conv_B.4.conv.2.weight", "conv_B.4.conv.4.1.bias", "conv_B.4.conv.4.1.weight", "conv_B.4.shortcut.1.bias", "conv_B.4.shortcut.1.weight".
Unexpected key(s) in state_dict: "model_a.1.weight", "model_a.1.bias", "model_a.4.weight", "model_a.4.bias", "model_a.7.weight", "model_a.7.bias", "model_a.10.weight", "model_a.10.bias", "model_a.13.weight", "model_a.13.bias", "model_a.16.weight", "model_a.16.bias", "model_b.1.weight", "model_b.1.bias", "model_b.4.weight", "model_b.4.bias", "model_b.7.weight", "model_b.7.bias", "model_b.10.weight", "model_b.10.bias", "model_b.13.weight", "model_b.13.bias", "model_b.16.weight", "model_b.16.bias".
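A quick way to read an error like this is to compare the top-level module names on each side. The sketch below uses only plain Python sets and a few key names copied from the log; diff_state_dict_keys and top_level_prefixes are hypothetical helpers, not part of PyTorch or the repository.

```python
def diff_state_dict_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) key lists, mirroring the two lists in the error."""
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    return missing, unexpected


def top_level_prefixes(keys):
    """Collapse parameter names to their first component, e.g. 'fc_A.0.bias' -> 'fc_A'."""
    return sorted({k.split(".", 1)[0] for k in keys})


# A few key names taken from the log above.
model_keys = ["fc_A.0.bias", "conv_A.1.weight", "fc_B.0.bias"]
checkpoint_keys = ["model_a.1.weight", "model_b.1.bias"]
missing, unexpected = diff_state_dict_keys(model_keys, checkpoint_keys)
```

Here the two sides share no module names at all (fc_A, conv_A, fc_B versus model_a, model_b). That pattern is what one would expect if the checkpoint was saved from a different attribute encoder class than the E_attr_concat being built with concat: 1; this is an inference from the log, not something confirmed in the thread.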
Hi, thanks for this marvelous work!
I am reading the paper and found equation (2) a little bit confusing.
Why is the content adversarial loss function like this ?
If we assign domain x to be 0, and domain y to be 1, then shouldn't this binary cross entropy loss look like? :
Have I misunderstood anything?
Thanks in advance!
Hi, http://vllab.ucmerced.edu/hylee/DRIT/datasets/$DATASET.zip is 404, can you check it? Thanks.
Why don't you use mean in kl_loss?
Why sum?
In the other kl_loss you use mean!
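To make the sum-versus-mean question concrete: for the KL divergence between N(mu, sigma^2) and N(0, 1), the two reductions differ only by a constant factor equal to the number of elements, which in practice behaves like a change of loss weight. A minimal sketch on plain lists (kl_terms, kl_sum and kl_mean are illustrative names, and this is not claimed to match the repository's exact expression):

```python
import math


def kl_terms(mu, logvar):
    """Per-element KL(N(mu, sigma^2) || N(0, 1)) with logvar = log(sigma^2)."""
    return [0.5 * (m * m + math.exp(lv) - lv - 1.0) for m, lv in zip(mu, logvar)]


def kl_sum(mu, logvar):
    return sum(kl_terms(mu, logvar))


def kl_mean(mu, logvar):
    terms = kl_terms(mu, logvar)
    return sum(terms) / len(terms)
```

For a latent of size n, kl_sum equals n times kl_mean, so switching reductions without rescaling the KL weight changes the effective regularization strength.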
Why are there 2 discriminators (disA, disA2)? link
What do you think about just updating once with batch size 1?
Why did you calculate the latent regression loss using only mu?
link
link2
According to the paper, shouldn't we sample from the normal distribution defined by mu and stddev and then calculate the loss?
According to the paper, KL_loss seems to apply only to the attribute. Why does the code also calculate KL_loss for the content? link
Hi, thanks for this amazing job.
Without changing any code, I encountered an error when I ran python train.py. Any idea why this happened?
File "E:/project/DRIT/src/train.py", line 78, in
main()
File "E:/project/DRIT/src/train.py", line 52, in main
model.update_EG()
File "E:\project\DRIT\src\model.py", line 301, in update_EG
self.backward_G_alone()
File "E:\project\DRIT\src\model.py", line 401, in backward_G_alone
loss_z_L1.backward()
File "xxxxxx\anaconda3\lib\site-packages\torch\tensor.py", line 198, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph)
File "xxxxxx\anaconda3\lib\site-packages\torch\autograd\__init__.py", line 100, in backward
allow_unreachable=True) # allow_unreachable flag
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [256, 8]], which is output 0 of TBackward, is at version 2; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
Hi, thanks for your great work about GAN.
In your paper you say you use a batch size of 1, but in this code the default training batch size is set to 2.
Is there any special training trick regarding batch size?
Hi,
I tried your code, but after 300 epochs it raises an error in the model writer. The original message is:
TypeError: write_model() missing 1 positional argument: 'model'
I didn't change your code; everything is as it is. I don't know what the problem is. The error is at line 62 in train:
saver.write_model(-1,model)
In epoch 302, when it reaches 183 iterations, it stops with the above message.
Please have a look. Thanks.
Thanks for your work. I tried your model with batch size 1 on small GPUs (8GB, 12GB) and it doesn't work. I checked it on both Ubuntu and Windows:
on Ubuntu it shows the error message: floating point exception (core dumped)
on Windows it shows nothing, but the execution stops.
Can you tell me the reason? By the way, I also checked it on my friend's computer and the error is the same.
Congrats Lee and your team for your great work!
I have tried the method on different datasets and it worked very well. However, edges2shoes is an exception. I have attached the result after 14 epochs of training (~40 hours).
I am wondering if the edges are the problem. Has anyone tried training the method with a similar dataset?
Thank you!
Since no_ms is a train-only option, opt.no_ms is not recognized at test time (when the test options are loaded), which raises an error.
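One common way to tolerate an option that exists only in the training parser is to read it with a fallback. The sketch below uses argparse.Namespace as a stand-in for the parsed options; read_no_ms and the choice of default are assumptions for illustration, not the fix adopted in the repository.

```python
import argparse


def read_no_ms(opts, default):
    """Return opts.no_ms when present, otherwise the caller-supplied default."""
    return getattr(opts, "no_ms", default)


# Stand-ins for parsed options: the test options lack no_ms, the train options have it.
test_opts = argparse.Namespace(phase="test")
train_opts = argparse.Namespace(phase="train", no_ms=False)
```

Which default is appropriate depends on how the model uses the flag outside training, so it is left to the caller here.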
Thank you very much for sharing your paper code. When I tested your code, the download link for the pretrained model failed. Can you send me a copy of your trained summer-winter model weights or give me a download link? Thank you! My email address is [email protected]
Thanks for your excellent work!
I am confused about why a Gaussian noise layer is added in E_encoder. Hoping for your reply.
When I train, I get some pkl files, but testing needs a pth file?
Hi,
Thanks for sharing. When I try to download the dataset you shared, I encounter a "403" error. Could you please tell me what is going on?
Is this code the self-reconstruction loss?
loss_G_L1_A = self.criterionL1(self.fake_A_recon, self.real_A_encoded)
loss_G_L1_B = self.criterionL1(self.fake_B_recon, self.real_B_encoded)
And is this code the cross-cycle consistency loss?
loss_G_L1_AA = self.criterionL1(self.fake_AA_encoded, self.real_A_encoded)
loss_G_L1_BB = self.criterionL1(self.fake_BB_encoded, self.real_B_encoded)
Thanks for your awesome work!
I see that there are two discriminators for each domain: disA and disA2 for domain A, and disB and disB2 for domain B. I have noticed that they are updated with different inputs. However, I don't know if double discriminators are necessary here.
Have I missed something in the paper? Could you explain a little more about this?
In your original paper, the attribute code is mostly exchanged to guide the translation, but in your code you also use random noise to guide the translation, like MUNIT. Why?
Actually, I also have some questions about mode seeking in translation models.
If I translate a cat into a dog, then because of cycle consistency the attribute encoder has to encode the detailed information of the original cat image in the attribute code. The details of the cat image should increase the diversity of the translation (but in experiments this does not seem to be the case). Why is that? Is it because the weight of the cycle loss is too small?
Thank you for sharing your source code.
At line 343, the loss value is not added to ad_loss. Is this correct?
I think you should add the loss values to ad_loss like this:
ad_loss += nn.functional.binary_cross_entropy(outputs_fake, all_half)
Lines 338 to 344 in dbfa804
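The difference between overwriting and accumulating can be seen in a small pure-Python sketch. It assumes, purely for illustration, that the loss is built from one binary cross-entropy term per domain with all targets set to 0.5 (the all_half target in the suggested line); bce and content_adv_loss are hypothetical names and the loop structure is not taken from the repository.

```python
import math


def bce(prob, target):
    """Binary cross-entropy for a single probability and target."""
    return -(target * math.log(prob) + (1.0 - target) * math.log(1.0 - prob))


def content_adv_loss(outputs_per_domain, accumulate=True):
    ad_loss = 0.0
    for outputs in outputs_per_domain:
        term = sum(bce(p, 0.5) for p in outputs) / len(outputs)
        if accumulate:
            ad_loss += term  # every domain contributes
        else:
            ad_loss = term   # earlier terms are silently discarded
    return ad_loss
```

With a target of 0.5 each term is smallest when the discriminator outputs 0.5, where it equals log 2. Overwriting means only the last term would drive the encoder, which is the concern raised in this issue.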
Hi,
I found that if the size of the larger training set (trainA or trainB) is an integer multiple of opt.d_iter, the model will not execute model.update_D(images_a, images_b) at the end of each epoch.
Lines 46 to 52 in a590bb4
Lines 69 to 70 in a590bb4
As a result, there are some problems when saving the result images: the input images are updated to a new pair, while the other images remain unchanged.
For example, my trainA contains 2733 images, and I set opt.d_iter=3. Here is my result grid.
If I delete 1 image from trainA, it works well.
Hi, thanks for sharing the impressive work. I have a question about the network updating strategy in the train.py file.
For if (it + 1) % opts.d_iter != 0 and it < len(train_loader) - 2:
Line 47 in f19f50a
If d_iter is set to 3, then in each epoch the content discriminator is updated using almost 2/3 of the total batches, and model.update_D(images_a, images_b) and model.update_EG() are run on only about 1/3 of the data, right?
In each epoch, model.update_D(images_a, images_b) and model.update_EG() are fed only 1 batch out of every 3, so why not use the same batches to update all three parts?
For example, every batch could be used to update all three parts, with the content discriminator trained multiple times (e.g. 2 times if d_iter=3) on each batch.
I am new to GAN, if I have any misunderstanding, please correct me, thanks!
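The split described above can be checked by replaying the quoted condition over one epoch. In this sketch num_batches stands in for len(train_loader), and the labels "D_content" and "D+EG" are illustrative names for the two branches rather than method names from train.py.

```python
def update_schedule(num_batches, d_iter):
    """Label each iteration using the condition quoted from train.py."""
    schedule = []
    for it in range(num_batches):
        if (it + 1) % d_iter != 0 and it < num_batches - 2:
            schedule.append("D_content")  # content discriminator only
        else:
            schedule.append("D+EG")       # update_D followed by update_EG
    return schedule


example = update_schedule(9, 3)
```

For 9 batches and d_iter=3 the last two iterations both fall into the "D+EG" branch because of the it < num_batches - 2 guard, so the split is 5 to 4 rather than exactly 2/3 to 1/3; for long epochs the ratio approaches 2/3 to 1/3.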
The idea of this paper is the same as MUNIT: https://github.com/NVlabs/MUNIT
Both propose a shared content space and separate style spaces. I would like to know the differences.
Hello,
From your code, the style encoder in backward_EG receives gradients only from the KL loss. However, as far as I can see, backward_G_alone also provides gradients through loss_z_L1, and you are calling self.enc_c_opt.step() after this function instead of self.enc_a_opt.step(). Does it make sense? Am I missing something here?
Thank you.
Congratulations on your paper being accepted at ECCV as an oral!
I think there is a small mistake in your code.
In the attribute encoder of domain B:
Line 164 in ef1c6a2
Maybe it should be input_dim_b rather than input_dim_a.
Hi
What is the difference between E_attr_concat and E_attr?
Which one do you use, concat or no_concat?
Is there a difference in performance?
Could you tell me?
Thank you.
How can I train with multiple GPUs? nn.DataParallel doesn't work.
Hi
very nice work.
I can't find the supplementary material.
Where is it? Could you share it?
Thank you.
Could you please provide more details about the evaluation? You mentioned 1000 pairs of randomly sampled images translated from 100 images. Is that 10 pairs per image (20 fixed random styles for the whole set?), followed by the average?
Thank you.
Did you employ WGAN for training? I notice this clip_grad_norm_ in your code.
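As background for this question: gradient-norm clipping and the weight clipping of the original WGAN are different operations. The sketch below contrasts them on plain lists; clip_grad_norm mimics the rescaling rule used by torch.nn.utils.clip_grad_norm_ (scale by max_norm / (total_norm + 1e-6) when that factor is below 1), and both function names are illustrative rather than library calls.

```python
import math


def clip_grad_norm(grads, max_norm, eps=1e-6):
    """Rescale all gradients together so their joint L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    scale = max_norm / (total_norm + eps)
    if scale < 1.0:
        grads = [g * scale for g in grads]
    return grads, total_norm


def clip_weights(weights, c):
    """WGAN-style weight clipping: clamp every parameter into [-c, c]."""
    return [max(-c, min(c, w)) for w in weights]
```

Gradient-norm clipping keeps the gradient direction and only limits the step size, so its presence alone does not imply a WGAN objective; which loss the code actually uses has to be read from the discriminator update itself.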
Hi Lee, I recently wanted to reproduce the domain adaptation experiments in this paper, but I can't find the details in your README. Could you provide some guidance? I would really appreciate it.
Hi,
Could you please tell me how you compared different models? Did you use the same learning rate, number of epochs, number of decay epochs, image size, and optimizer for all models? Also, did you collect test results using the final saved generator, or did you pick the best results after testing all saved generators at different epochs?