aelnouby / text-to-image-synthesis
PyTorch implementation of the Generative Adversarial Text-to-Image Synthesis paper
License: GNU General Public License v3.0
1. When training the GAN, I did not see any text-encoding step. You store the text in sample:
` right_images = sample['right_images']
right_embed = sample['right_embed']
wrong_images = sample['wrong_images']
right_images = Variable(right_images.float()).cuda()
right_embed = Variable(right_embed.float()).cuda()
wrong_images = Variable(wrong_images.float()).cuda()
`
There is no text_embedding.
2. When we train the GAN, should we extract the text and image embeddings first, or do we simply feed raw images into the network?
What is the embedding path for custom dataset preparation?
When I try to use the data file provided by the author, my IDE reports the error "RuntimeError: Unable to get group info (wrong B-tree signature)", shown below.
import h5py
f = h5py.File('flowers.hdf5', 'r')
for k in f.keys():
    print(f["test"])
Traceback (most recent call last):
File "h5py\_objects.pyx", line 54, in h5py._objects.with_phil.wrapper
File "h5py\_objects.pyx", line 55, in h5py._objects.with_phil.wrapper
File "F:\Anaconda\lib\site-packages\h5py\_hl\group.py", line 623, in __repr__
r = '<HDF5 group %s (%d members)>' % (namestr, len(self))
File "h5py\_objects.pyx", line 54, in h5py._objects.with_phil.wrapper
File "h5py\_objects.pyx", line 55, in h5py._objects.with_phil.wrapper
File "F:\Anaconda\lib\site-packages\h5py\_hl\group.py", line 443, in __len__
return self.id.get_num_objs()
File "h5py\_objects.pyx", line 54, in h5py._objects.with_phil.wrapper
File "h5py\_objects.pyx", line 55, in h5py._objects.with_phil.wrapper
File "h5py\h5g.pyx", line 336, in h5py.h5g.GroupID.get_num_objs
RuntimeError: Unable to get group info (wrong B-tree signature)
Then I used ViTables to check whether there is any problem with this data file. I can only see the groups named train, test, and valid, with no datasets inside them.
I am new to these tools. Could anyone tell me what the problem is?
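A "wrong B-tree signature" error from HDF5 often points to a truncated or corrupted download rather than a mistake in the calling code, although that is an assumption here and not something confirmed in this thread. Below is a minimal sketch of a first sanity check using only the standard library: it reports the file size and a SHA-256 digest (useful for comparing against a second download) and checks for the 8-byte HDF5 format signature at the start of the file. The helper name `inspect_hdf5_file` is hypothetical, and a file can pass this check and still be damaged further in.

```python
import hashlib
import os

# The 8-byte signature that starts an HDF5 superblock.
HDF5_SIGNATURE = b"\x89HDF\r\n\x1a\n"


def inspect_hdf5_file(path):
    """Return (size_in_bytes, sha256_hex, has_signature) for the file at path.

    Hypothetical helper for illustration; it only checks offset 0, which is
    where the signature sits when the file has no user block.
    """
    digest = hashlib.sha256()
    with open(path, "rb") as handle:
        header = handle.read(len(HDF5_SIGNATURE))
        digest.update(header)
        # Hash the rest of the file in 1 MiB chunks to keep memory use low.
        for chunk in iter(lambda: handle.read(1 << 20), b""):
            digest.update(chunk)
    return os.path.getsize(path), digest.hexdigest(), header == HDF5_SIGNATURE


if __name__ == "__main__":
    import sys

    if len(sys.argv) > 1:
        size, sha256_hex, ok = inspect_hdf5_file(sys.argv[1])
        print("size:", size, "sha256:", sha256_hex, "HDF5 signature:", ok)
```

If the size or digest differs between two downloads of the same file, re-downloading is probably the first thing to try.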
Hi, @aelnouby
Thank you for sharing.
While getting the data, I wanted to convert it myself, so I downloaded the dataset as described at https://github.com/reedscot/cvpr2016. That gave me a link (for flowers, for example) https://drive.google.com/open?id=0B0ywwgffWnLLcms2WWJQRFNSWXM, from which I downloaded a file named cvpr2016_flowers.tar.gz. Then I extracted it.
I can't understand how you organized your files to execute the script 'convert_flowers_to_hd5_script.py' correctly. Did I download the wrong file?
Looking forward to your reply!
Can you please tell me how to test it? I need to give an input and see the generated image.
Hi,
During training, you generate fake images twice per mini-batch. According to the paper, however, it seems that the D and G updates both use the same fake images, so I'm confused.
I think generating fake images twice may make training less stable.
Hoping for your reply.
ModuleNotFoundError: No module named 'torch.utils.serialization'
Thanks for your contribution!
But when I run the code, an error occurs:
Traceback (most recent call last):
File "runtime.py", line 42, in <module>
trainer.train(args.cls)
File "/home/zzw/program/text2img/text-to-Image-Synthesis-pytorch/trainer.py", line 65, in train
self._train_wgan(cls)
File "/home/zzw/program/text2img/text-to-Image-Synthesis-pytorch/trainer.py", line 103, in _train_wgan
sample = next(data_iterator)
File "/home/zzw/.local/lib/python2.7/site-packages/torch/utils/data/dataloader.py", line 281, in __next__
return self._process_next_batch(batch)
File "/home/zzw/.local/lib/python2.7/site-packages/torch/utils/data/dataloader.py", line 301, in _process_next_batch
raise batch.exc_type(batch.exc_msg)
IOError: Traceback (most recent call last):
File "/home/zzw/.local/lib/python2.7/site-packages/torch/utils/data/dataloader.py", line 55, in _worker_loop
samples = collate_fn([dataset[i] for i in batch_indices])
File "/home/zzw/program/text2img/text-to-Image-Synthesis-pytorch/txt2image_dataset.py", line 46, in __getitem__
right_image = Image.open(io.BytesIO(right_image)).resize((64, 64))
File "/home/zzw/.local/lib/python2.7/site-packages/PIL/Image.py", line 2590, in open
% (filename if filename else fp))
IOError: cannot identify image file <_io.BytesIO object at 0x7f1be1801770>
My Pillow version is 5.1.0, and it seems like something related to the version.
Could anyone help me out?
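`cannot identify image file` means Pillow did not recognize the bytes it was given as any image format it supports, which may indicate that the bytes read from the HDF5 file are not a raw encoded image. Below is a minimal debugging sketch, using only the standard library, that inspects the leading magic bytes of a buffer before it is handed to `Image.open`. The function name `sniff_image_format` is hypothetical, and only JPEG and PNG are checked.

```python
def sniff_image_format(data):
    """Guess the image format of a bytes buffer from its magic bytes.

    Hypothetical debugging helper: returns "jpeg", "png", or None.
    """
    if data[:3] == b"\xff\xd8\xff":
        return "jpeg"  # JPEG data starts with the SOI marker FF D8, then FF
    if data[:8] == b"\x89PNG\r\n\x1a\n":
        return "png"  # fixed 8-byte PNG signature
    return None
```

Printing `sniff_image_format(right_image)` and `len(right_image)` just before the failing `Image.open` call would show whether the buffer looks like an encoded image at all.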
Hi,
Thanks for the nice and very helpful work! I am also trying to do text-to-image generation, but in TensorFlow.
My loss graphs are going totally wrong:
d_loss_real = tf.reduce_mean(disc_real_image_logits)
d_loss_fake = tf.reduce_mean(disc_fake_image_logits)
d_loss_wrong = tf.reduce_mean(disc_wrong_image_logits)
d_w_loss = d_loss_fake + d_loss_wrong - d_loss_real
g_w_loss = -1*(d_loss_fake)
rms_d_optim = tf.train.RMSPropOptimizer(learning_rate=5e-5).minimize(loss['d_loss'], var_list=variables['d_vars'])
rms_g_optim = tf.train.RMSPropOptimizer(learning_rate=5e-5).minimize(loss['g_loss'], var_list=variables['g_vars'])
d_clip = [v.assign(tf.clip_by_value(v, -args.d_clip_limit, args.d_clip_limit)) for v in variables['d_vars']]
with tf.control_dependencies([rms_d_optim]):
    rms_d_optim = tf.tuple(d_clip)
for epoch in range(100):
    for diter in range(10):
        sess.run([rms_d_optim], feed_dict=feed)
        sess.run(d_clip)
    sess.run([rms_g_optim], feed_dict=feed)
Could you suggest a direction for fixing this?
I went through your code (I am not yet well versed in PyTorch), and it seems that you are using the same losses. Please correct me if I am mistaken.
Thanks
I tried to run your code, but it requires an NVIDIA GPU and I don't have an NVIDIA graphics card. Can this work with other types of GPU?
Thanks
Hi, I'm learning text-to-image synthesis, and thanks a lot for your reimplementation in PyTorch!
However, I cannot find the text_c10 files for the birds and flowers datasets. Are they the descriptions of the images?
It looks like the authors of icml2016 did not provide the original descriptions, only the text embeddings in Torch format.
Waiting for your reply.
Hi!
I have run into a very weird problem: the model trains for a few iterations but then fails with the error below.
Hoping for your reply!
Epoch: 0, d_loss= 2.205419, g_loss= 31.352011, D(X)= 0.381197, D(G(X))= 0.543290
Epoch: 0, d_loss= 1.640341, g_loss= 30.674690, D(X)= 0.512067, D(G(X))= 0.349348
Epoch: 0, d_loss= 1.334773, g_loss= 34.140278, D(X)= 0.618809, D(G(X))= 0.341588
Traceback (most recent call last):
File "runtime.py", line 43, in <module>
trainer.train(args.cls)
File "/home/jurh/disk2/liupeng/text2image_3/Text-to-Image-Synthesis/trainer.py", line 67, in train
self._train_gan(cls)
File "/home/jurh/disk2/liupeng/text2image_3/Text-to-Image-Synthesis/trainer.py", line 177, in _train_gan
for sample in self.data_loader:
File "/usr/local/lib/python2.7/dist-packages/torch/utils/data/dataloader.py", line 187, in __next__
return self._process_next_batch(batch)
File "/usr/local/lib/python2.7/dist-packages/torch/utils/data/dataloader.py", line 221, in _process_next_batch
raise batch.exc_type(batch.exc_msg)
TypeError: function takes exactly 5 arguments (1 given)
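The final `TypeError` here may be masking the real failure. The traceback shows the data loader re-raising a worker error with `batch.exc_type(batch.exc_msg)`, that is, by calling the exception class with a single message argument. Some built-in exceptions cannot be constructed that way: `UnicodeDecodeError`, for example, requires five arguments. So one plausible cause, though it is a guess and not confirmed in this thread, is a Unicode error raised inside the dataset's `__getitem__`. Below is a minimal sketch that reproduces the mechanism without PyTorch; the helpers `reraise_like_worker` and `describe_failure` are hypothetical.

```python
def reraise_like_worker(exc_type, exc_msg):
    """Mimic `raise batch.exc_type(batch.exc_msg)` from the traceback above."""
    raise exc_type(exc_msg)


def describe_failure(exc_type):
    """Return the name of the exception that actually surfaces."""
    try:
        reraise_like_worker(exc_type, "formatted traceback text from the worker")
    except Exception as error:  # broad on purpose: we want to see what comes out
        return type(error).__name__


if __name__ == "__main__":
    print(describe_failure(ValueError))          # the original error type survives
    print(describe_failure(UnicodeDecodeError))  # a TypeError surfaces instead
```

If this is what is happening, running the data loader with `num_workers=0` avoids the worker re-raise and should surface the original exception.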
Hi,
I trained with gan_cls (the conditioned version, not the vanilla one) on flowers, using the shared hdf5 file, and got the curves at https://drive.google.com/open?id=1pASanOh9YUdg__I5OPRi_srmu3T2JYx8.
The discriminator loss keeps going down and then almost converges at 0.483, but the generator loss keeps going up, without converging, to 18.58; D(X) reaches 0.846 and D(G(X)) reaches 0.016. My prediction results look similar to the reported images.
I think the curves I got during training suggest divergence, correct? If training converged, we should see the generator loss also going down and D(G(X)) going up, correct?
Am I missing anything here?
Do you have any suggestions for making the training converge? (I see that you've implemented many tricks from https://github.com/soumith/ganhacks#13-add-noise-to-inputs-decay-over-time. I'm trying #2, flipping real_label and fake_label for the generator, but it doesn't seem to help.)
I look forward to your answer. Thanks!
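When reading these numbers, it may help to separate the adversarial part of the generator loss from the rest. In the `_train_gan` code quoted further down this page, `g_loss` is the sum of a cross-entropy term, a feature-matching term weighted by `l2_coef`, and an L1 term weighted by `l1_coef`. Below is a rough back-of-the-envelope sketch, assuming the cross-entropy term is about `-log(D(G(X)))` evaluated at the reported average. The average of a log is not the log of an average, so this is only an approximation.

```python
import math


def adversarial_term(d_of_g):
    """Binary cross-entropy against the 'real' label for one discriminator output."""
    return -math.log(d_of_g)


reported_g_loss = 18.58   # generator loss reported in the question
reported_d_of_g = 0.016   # D(G(X)) reported in the question

adv = adversarial_term(reported_d_of_g)
# What the weighted L1 and feature-matching terms would have to supply.
remainder = reported_g_loss - adv

print(f"adversarial term ~ {adv:.2f}, remaining terms ~ {remainder:.2f}")
```

Under that assumption, most of the reported generator loss would come from the L1 and feature-matching terms rather than from the adversarial term, so a high `g_loss` on its own need not mean divergence.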
Can someone provide guidance on how to generate an hdf5 file for the COCO dataset?
How do I get the following?
birds_images_path: '/export/mlrg/aelnouby/projects/GANs/Birds dataset/CUB_200_2011/CUB_200_2011/images/'
birds_embedding_path: '/export/mlrg/aelnouby/projects/GANs/Birds dataset/cub_icml/'
birds_text_path: '/export/mlrg/aelnouby/projects/GANs/Birds dataset/cvpr2016_cub/text_c10/'
val_split_path: '/export/mlrg/aelnouby/projects/GANs/Birds dataset/cub_icml/valclasses.txt'
train_split_path: '/export/mlrg/aelnouby/projects/GANs/Birds dataset/cub_icml/trainclasses.txt'
test_split_path: '/export/mlrg/aelnouby/projects/GANs/Birds dataset/cub_icml/testclasses.txt'
There was a reference to https://github.com/reedscot/cvpr2016, but when I checked the code, it didn't provide much insight. Any help would be appreciated.
As I was reading convert_cub_to_hd5_script.py to get a better understanding of the .hdf5 structure, I came across this line:
txt_choice = np.random.choice(range(10), 5)
where five text embeddings are supposedly randomly sampled. However, by default, np.random.choice samples with replacement, which in this case could cause the same embedding to be picked more than once. Would this be an issue?
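To make the difference concrete: `np.random.choice(range(10), 5)` can return repeated indices, while passing `replace=False` cannot. Below is a minimal standard-library sketch of the same distinction, using `random.choices` (with replacement) and `random.sample` (without replacement). The function names are hypothetical, and whether duplicates matter for training is a separate question that this sketch does not answer.

```python
import random


def pick_with_replacement(n_items, k, rng):
    """k indices drawn independently; repeats are possible."""
    return rng.choices(range(n_items), k=k)


def pick_without_replacement(n_items, k, rng):
    """k distinct indices; the NumPy equivalent is np.random.choice(n_items, k, replace=False)."""
    return rng.sample(range(n_items), k)


rng = random.Random(0)
distinct = pick_without_replacement(10, 5, rng)
maybe_repeated = pick_with_replacement(10, 5, rng)
print(distinct, maybe_repeated)
```

With 10 captions and 5 draws, sampling with replacement yields at least one repeat about 70% of the time, since 1 - (10·9·8·7·6)/10^5 = 0.6976.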
Hi, I want to use the gradient penalty, since I want to try improved WGAN.
What does "build from source" mean? Does it mean I should train the model from scratch, or build PyTorch from source code?
Thanks.
Hi,
Can you please release pre-trained models for all 3 datasets ?
Hi! When I trained on the birds dataset, I found the results were not good. Have you trained on the birds dataset, and how were the results? Do you have any advice for training on birds?
Hoping for a reply!
Thank you for your code. I want to test it with the HDF5 files for Birds and Flowers, but the download link is invalid. Can you update the download link?
Thanks!
Can you report the inception scores of the modified model?
I want to train with my own data. To prepare my dataset, I have to convert text files to .t7 format for the text embeddings. How do I achieve that?
Hello, author. I want to reproduce the experimental results. When I trained the model, I found that the generator loss was always high and the generator seemed much weaker than the discriminator. What should I do?
Why don't I get any results after running runtime.py for a long time? There are no errors and no output.
Hello, I really want to use this repository for my deep learning course. After downloading the datasets, I tried to run 'runtime.py' on Google Colab (I don't have a GPU on my own system) and some errors occurred. The error is related to 'visualize.py'. I really need this code, so please help me solve it. The error is:
Traceback (most recent call last):
File "/usr/local/lib/python3.7/dist-packages/urllib3/connectionpool.py", line 600, in urlopen
chunked=chunked)
File "/usr/local/lib/python3.7/dist-packages/urllib3/connectionpool.py", line 354, in _make_request
conn.request(method, url, **httplib_request_kw)
File "/usr/lib/python3.7/http/client.py", line 1281, in request
self._send_request(method, url, body, headers, encode_chunked)
File "/usr/lib/python3.7/http/client.py", line 1327, in _send_request
self.endheaders(body, encode_chunked=encode_chunked)
File "/usr/lib/python3.7/http/client.py", line 1276, in endheaders
self._send_output(message_body, encode_chunked=encode_chunked)
File "/usr/lib/python3.7/http/client.py", line 1036, in _send_output
self.send(msg)
File "/usr/lib/python3.7/http/client.py", line 976, in send
self.connect()
File "/usr/local/lib/python3.7/dist-packages/urllib3/connection.py", line 181, in connect
conn = self._new_conn()
File "/usr/local/lib/python3.7/dist-packages/urllib3/connection.py", line 168, in _new_conn
self, "Failed to establish a new connection: %s" % e)
urllib3.exceptions.NewConnectionError: <urllib3.connection.HTTPConnection object at 0x7fcf91fe6c50>: Failed to establish a new connection: [Errno 111] Connection refused
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/usr/local/lib/python3.7/dist-packages/requests/adapters.py", line 449, in send
timeout=timeout
File "/usr/local/lib/python3.7/dist-packages/urllib3/connectionpool.py", line 638, in urlopen
_stacktrace=sys.exc_info()[2])
File "/usr/local/lib/python3.7/dist-packages/urllib3/util/retry.py", line 399, in increment
raise MaxRetryError(_pool, url, error or ResponseError(cause))
urllib3.exceptions.MaxRetryError: HTTPConnectionPool(host='localhost', port=8097): Max retries exceeded with url: /events (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7fcf91fe6c50>: Failed to establish a new connection: [Errno 111] Connection refused'))
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/usr/local/lib/python3.7/dist-packages/urllib3/connection.py", line 159, in _new_conn
(self._dns_host, self.port), self.timeout, **extra_kw)
File "/usr/local/lib/python3.7/dist-packages/urllib3/util/connection.py", line 80, in create_connection
raise err
File "/usr/local/lib/python3.7/dist-packages/urllib3/util/connection.py", line 70, in create_connection
sock.connect(sa)
ConnectionRefusedError: [Errno 111] Connection refused
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/usr/local/lib/python3.7/dist-packages/urllib3/connectionpool.py", line 600, in urlopen
chunked=chunked)
File "/usr/local/lib/python3.7/dist-packages/urllib3/connectionpool.py", line 354, in _make_request
conn.request(method, url, **httplib_request_kw)
File "/usr/lib/python3.7/http/client.py", line 1281, in request
self._send_request(method, url, body, headers, encode_chunked)
File "/usr/lib/python3.7/http/client.py", line 1327, in _send_request
self.endheaders(body, encode_chunked=encode_chunked)
File "/usr/lib/python3.7/http/client.py", line 1276, in endheaders
self._send_output(message_body, encode_chunked=encode_chunked)
File "/usr/lib/python3.7/http/client.py", line 1036, in _send_output
self.send(msg)
File "/usr/lib/python3.7/http/client.py", line 976, in send
self.connect()
File "/usr/local/lib/python3.7/dist-packages/urllib3/connection.py", line 181, in connect
conn = self._new_conn()
File "/usr/local/lib/python3.7/dist-packages/urllib3/connection.py", line 168, in _new_conn
self, "Failed to establish a new connection: %s" % e)
urllib3.exceptions.NewConnectionError: <urllib3.connection.HTTPConnection object at 0x7fcf940a45d0>: Failed to establish a new connection: [Errno 111] Connection refused
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/usr/local/lib/python3.7/dist-packages/requests/adapters.py", line 449, in send
timeout=timeout
File "/usr/local/lib/python3.7/dist-packages/urllib3/connectionpool.py", line 638, in urlopen
_stacktrace=sys.exc_info()[2])
File "/usr/local/lib/python3.7/dist-packages/urllib3/util/retry.py", line 399, in increment
raise MaxRetryError(_pool, url, error or ResponseError(cause))
urllib3.exceptions.MaxRetryError: HTTPConnectionPool(host='localhost', port=8097): Max retries exceeded with url: /events (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7fcf940a45d0>: Failed to establish a new connection: [Errno 111] Connection refused'))
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/usr/local/lib/python3.7/dist-packages/visdom/__init__.py", line 263, in _send
data=json.dumps(msg),
File "/usr/local/lib/python3.7/dist-packages/requests/api.py", line 119, in post
return request('post', url, data=data, json=json, **kwargs)
File "/usr/local/lib/python3.7/dist-packages/requests/api.py", line 61, in request
return session.request(method=method, url=url, **kwargs)
File "/usr/local/lib/python3.7/dist-packages/requests/sessions.py", line 530, in request
resp = self.send(prep, **send_kwargs)
File "/usr/local/lib/python3.7/dist-packages/requests/sessions.py", line 643, in send
r = adapter.send(request, **kwargs)
File "/usr/local/lib/python3.7/dist-packages/requests/adapters.py", line 516, in send
raise ConnectionError(e, request=request)
requests.exceptions.ConnectionError: HTTPConnectionPool(host='localhost', port=8097): Max retries exceeded with url: /events (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7fcf940a45d0>: Failed to establish a new connection: [Errno 111] Connection refused'))
Traceback (most recent call last):
File "runtime.py", line 42, in <module>
trainer.train(args.cls)
File "/content/gdrive/MyDrive/Txt2Img(2)/trainer.py", line 67, in train
self._train_gan(cls)
File "/content/gdrive/MyDrive/Txt2Img(2)/trainer.py", line 177, in _train_gan
for sample in self.data_loader:
File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 652, in __next__
data = self._next_data()
File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1327, in _next_data
return self._process_data(data)
File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1373, in _process_data
data.reraise()
File "/usr/local/lib/python3.7/dist-packages/torch/_utils.py", line 460, in reraise
raise RuntimeError(msg) from None
RuntimeError: Caught UnicodeDecodeError in DataLoader worker process 6.
Original Traceback (most recent call last):
File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/worker.py", line 302, in _worker_loop
data = fetcher.fetch(index)
File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/content/gdrive/MyDrive/Txt2Img(2)/txt2image_dataset.py", line 52, in __getitem__
txt = np.array(example['txt']).astype(str)
UnicodeDecodeError: 'ascii' codec can't decode byte 0xef in position 4: ordinal not in range(128)
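The last traceback fails while converting the stored caption to `str`, and byte `0xef` is outside ASCII, so the caption apparently contains non-ASCII data. The earlier `Connection refused` messages come from visdom trying to reach a server on `localhost:8097` and are separate from this failure. One possible workaround for the decoding error, sketched here with the standard library only, is to decode the raw bytes explicitly as UTF-8 with a fallback for invalid sequences. This assumes the captions were written as UTF-8 bytes, which the thread does not confirm; `decode_caption` is a hypothetical helper, not part of the repository.

```python
def decode_caption(raw):
    """Turn a caption stored as bytes into text without raising on odd bytes.

    Assumes UTF-8; undecodable bytes become U+FFFD instead of raising.
    """
    if isinstance(raw, str):
        return raw
    return bytes(raw).decode("utf-8", errors="replace")


# ascii() keeps the printed form safe on terminals that cannot show non-ASCII.
print(ascii(decode_caption(b"caf\xc3\xa9 colored petals")))
```

Replacing invalid bytes keeps training running, at the cost of a placeholder character in the saved caption text.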
I ran 'python runtime.py --type=gan' and got the error below:
File "runtime.py", line 42, in <module>
trainer.train(args.cls)
File "/private/home/xjwang/twoDomainGAN-pytorch/Text-to-Image-Synthesis/trainer.py", line 67, in train
self._train_gan(cls)
File "/private/home/xjwang/twoDomainGAN-pytorch/Text-to-Image-Synthesis/trainer.py", line 295, in _train_gan
for sample in self.data_loader:
File "/public/apps/anaconda3/4.3.1/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 212, in __next__
return self._process_next_batch(batch)
File "/public/apps/anaconda3/4.3.1/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 239, in _process_next_batch
raise batch.exc_type(batch.exc_msg)
TypeError: Traceback (most recent call last):
File "/public/apps/anaconda3/4.3.1/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 41, in _worker_loop
samples = collate_fn([dataset[i] for i in batch_indices])
File "/public/apps/anaconda3/4.3.1/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 41, in <listcomp>
samples = collate_fn([dataset[i] for i in batch_indices])
File "/private/home/xjwang/twoDomainGAN-pytorch/Text-to-Image-Synthesis/txt2image_dataset.py", line 42, in __getitem__
right_image = bytes(np.array(example['img']))
TypeError: only integer scalar arrays can be converted to a scalar index
`import numpy as np
import torch
import yaml
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from PIL import Image
import os


class Trainer(object):
    def __init__(self, type, dataset, split, lr, diter, vis_screen, save_path, l1_coef, l2_coef, pre_trained_gen, pre_trained_disc, batch_size, num_workers, epochs):
        self.generator = torch.nn.DataParallel(gan_factory.generator_factory(type).cuda())
        self.discriminator = torch.nn.DataParallel(gan_factory.discriminator_factory(type).cuda())

        if pre_trained_disc:
            self.discriminator.load_state_dict(torch.load(pre_trained_disc))
        else:
            self.discriminator.apply(Utils.weights_init)

        if pre_trained_gen:
            self.generator.load_state_dict(torch.load(pre_trained_gen))
        else:
            self.generator.apply(Utils.weights_init)

        if dataset == 'birds':
            self.dataset = Text2ImageDataset('Data/Birds/', split=split)
        elif dataset == 'flowers':
            self.dataset = Text2ImageDataset('Data/Flowers/flowers.hdf5', split=split)
        else:
            print('Dataset not supported, please select either birds or flowers.')
            exit()

        self.noise_dim = 100
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.lr = lr
        self.beta1 = 0.5
        self.num_epochs = epochs
        self.DITER = diter
        self.l1_coef = l1_coef
        self.l2_coef = l2_coef

        self.data_loader = DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True,
                                      num_workers=self.num_workers)

        self.optimD = torch.optim.Adam(self.discriminator.parameters(), lr=self.lr, betas=(self.beta1, 0.999))
        self.optimG = torch.optim.Adam(self.generator.parameters(), lr=self.lr, betas=(self.beta1, 0.999))

        self.logger = Logger(vis_screen)
        self.checkpoints_path = 'checkpoints'
        self.save_path = save_path
        self.type = type

    def train(self, cls=False):
        if self.type == 'wgan':
            self._train_wgan(cls)
        elif self.type == 'gan':
            self._train_gan(cls)
        elif self.type == 'vanilla_wgan':
            self._train_vanilla_wgan()
        elif self.type == 'vanilla_gan':
            self._train_vanilla_gan()

    def _train_wgan(self, cls):
        one = torch.FloatTensor([1])
        mone = one * -1

        one = Variable(one).cuda()
        mone = Variable(mone).cuda()

        gen_iteration = 0
        for epoch in range(self.num_epochs):
            iterator = 0
            data_iterator = iter(self.data_loader)

            while iterator < len(self.data_loader):
                if gen_iteration < 25 or gen_iteration % 500 == 0:
                    d_iter_count = 100
                else:
                    d_iter_count = self.DITER

                d_iter = 0

                # Train the discriminator
                while d_iter < d_iter_count and iterator < len(self.data_loader):
                    d_iter += 1

                    for p in self.discriminator.parameters():
                        p.requires_grad = True

                    self.discriminator.zero_grad()

                    sample = next(data_iterator)
                    iterator += 1

                    right_images = sample['right_images']
                    right_embed = sample['right_embed']
                    wrong_images = sample['wrong_images']

                    right_images = Variable(right_images.float()).cuda()
                    right_embed = Variable(right_embed.float()).cuda()
                    wrong_images = Variable(wrong_images.float()).cuda()

                    outputs, _ = self.discriminator(right_images, right_embed)
                    real_loss = torch.mean(outputs)
                    real_loss.backward(mone)

                    if cls:
                        outputs, _ = self.discriminator(wrong_images, right_embed)
                        wrong_loss = torch.mean(outputs)
                        wrong_loss.backward(one)

                    noise = Variable(torch.randn(right_images.size(0), self.noise_dim), volatile=True).cuda()
                    noise = noise.view(noise.size(0), self.noise_dim, 1, 1)

                    fake_images = Variable(self.generator(right_embed, noise).data)
                    outputs, _ = self.discriminator(fake_images, right_embed)
                    fake_loss = torch.mean(outputs)
                    fake_loss.backward(one)

                    ## NOTE: Pytorch had a bug with gradient penalty at the time of this project development
                    ## , uncomment the next two lines and remove the params clamping below if you want to try gradient penalty
                    # gp = Utils.compute_GP(self.discriminator, right_images.data, right_embed, fake_images.data, LAMBDA=10)
                    # gp.backward()

                    d_loss = real_loss - fake_loss

                    if cls:
                        d_loss = d_loss - wrong_loss

                    self.optimD.step()

                    for p in self.discriminator.parameters():
                        p.data.clamp_(-0.01, 0.01)

                # Train Generator
                for p in self.discriminator.parameters():
                    p.requires_grad = False

                self.generator.zero_grad()
                noise = Variable(torch.randn(right_images.size(0), 100)).cuda()
                noise = noise.view(noise.size(0), 100, 1, 1)
                fake_images = self.generator(right_embed, noise)
                outputs, _ = self.discriminator(fake_images, right_embed)

                g_loss = torch.mean(outputs)
                g_loss.backward(mone)
                g_loss = - g_loss
                self.optimG.step()

                gen_iteration += 1

                self.logger.draw(right_images, fake_images)
                self.logger.log_iteration_wgan(epoch, gen_iteration, d_loss, g_loss, real_loss, fake_loss)

            self.logger.plot_epoch(gen_iteration)

            if (epoch+1) % 50 == 0:
                Utils.save_checkpoint(self.discriminator, self.generator, self.checkpoints_path, epoch)

    def _train_gan(self, cls):
        criterion = nn.CrossEntropyLoss()
        l2_loss = nn.MSELoss()
        l1_loss = nn.L1Loss()
        iteration = 0

        for epoch in range(self.num_epochs):
            for sample in self.data_loader:
                iteration += 1
                right_images = sample['right_images']
                right_embed = sample['right_embed']
                wrong_images = sample['wrong_images']

                right_images = Variable(right_images.float()).cuda()
                right_embed = Variable(right_embed.float()).cuda()
                wrong_images = Variable(wrong_images.float()).cuda()

                real_labels = torch.ones(right_images.size(0))
                fake_labels = torch.zeros(right_images.size(0))

                # ======== One sided label smoothing ==========
                # Helps preventing the discriminator from overpowering the
                # generator adding penalty when the discriminator is too confident
                # =============================================
                smoothed_real_labels = torch.FloatTensor(Utils.smooth_label(real_labels.numpy(), -0.1))

                real_labels = Variable(real_labels).cuda()
                smoothed_real_labels = Variable(smoothed_real_labels).cuda()
                fake_labels = Variable(fake_labels).cuda()

                # Train the discriminator
                self.discriminator.zero_grad()
                outputs, activation_real = self.discriminator(right_images, right_embed)
                real_loss = criterion(outputs, smoothed_real_labels.squeeze())
                real_score = outputs

                if cls:
                    outputs, _ = self.discriminator(wrong_images, right_embed)
                    wrong_loss = criterion(outputs, fake_labels)
                    wrong_score = outputs

                noise = Variable(torch.randn(right_images.size(0), 100)).cuda()
                noise = noise.view(noise.size(0), 100, 1, 1)
                fake_images = self.generator(right_embed, noise)
                outputs, _ = self.discriminator(fake_images, right_embed)
                fake_loss = criterion(outputs, fake_labels.squeeze())
                fake_score = outputs

                d_loss = real_loss + fake_loss

                if cls:
                    d_loss = d_loss + wrong_loss

                d_loss.backward()
                self.optimD.step()

                # Train the generator
                self.generator.zero_grad()
                noise = Variable(torch.randn(right_images.size(0), 100)).cuda()
                noise = noise.view(noise.size(0), 100, 1, 1)
                fake_images = self.generator(right_embed, noise)
                outputs, activation_fake = self.discriminator(fake_images, right_embed)
                _, activation_real = self.discriminator(right_images, right_embed)

                activation_fake = torch.mean(activation_fake, 0)
                activation_real = torch.mean(activation_real, 0)

                #======= Generator Loss function============
                # This is a customized loss function, the first term is the regular cross entropy loss
                # The second term is feature matching loss, this measure the distance between the real and generated
                # images statistics by comparing intermediate layers activations
                # The third term is L1 distance between the generated and real images, this is helpful for the conditional case
                # because it links the embedding feature vector directly to certain pixel values.
                #===========================================
                g_loss = criterion(outputs, real_labels) \
                         + self.l2_coef * l2_loss(activation_fake, activation_real.detach()) \
                         + self.l1_coef * l1_loss(fake_images, right_images)

                g_loss.backward()
                self.optimG.step()

                if iteration % 5 == 0:
                    self.logger.log_iteration_gan(epoch,d_loss, g_loss, real_score, fake_score)
                    self.logger.draw(right_images, fake_images)

            self.logger.plot_epoch_w_scores(epoch)

            if (epoch) % 10 == 0:
                Utils.save_checkpoint(self.discriminator, self.generator, self.checkpoints_path, self.save_path, epoch)

    def _train_vanilla_wgan(self):
        one = Variable(torch.FloatTensor([1])).cuda()
        mone = one * -1
        gen_iteration = 0

        for epoch in range(self.num_epochs):
            iterator = 0
            data_iterator = iter(self.data_loader)

            while iterator < len(self.data_loader):
                if gen_iteration < 25 or gen_iteration % 500 == 0:
                    d_iter_count = 100
                else:
                    d_iter_count = self.DITER

                d_iter = 0

                # Train the discriminator
                while d_iter < d_iter_count and iterator < len(self.data_loader):
                    d_iter += 1

                    for p in self.discriminator.parameters():
                        p.requires_grad = True

                    self.discriminator.zero_grad()

                    sample = next(data_iterator)
                    iterator += 1

                    right_images = sample['right_images']
                    right_images = Variable(right_images.float()).cuda()

                    outputs, _ = self.discriminator(right_images)
                    real_loss = torch.mean(outputs)
                    real_loss.backward(mone)

                    noise = Variable(torch.randn(right_images.size(0), self.noise_dim), volatile=True).cuda()
                    noise = noise.view(noise.size(0), self.noise_dim, 1, 1)

                    fake_images = Variable(self.generator(noise).data)
                    outputs, _ = self.discriminator(fake_images)
                    fake_loss = torch.mean(outputs)
                    fake_loss.backward(one)

                    ## NOTE: Pytorch had a bug with gradient penalty at the time of this project development
                    ## , uncomment the next two lines and remove the params clamping below if you want to try gradient penalty
                    # gp = Utils.compute_GP(self.discriminator, right_images.data, right_embed, fake_images.data, LAMBDA=10)
                    # gp.backward()

                    d_loss = real_loss - fake_loss
                    self.optimD.step()

                    for p in self.discriminator.parameters():
                        p.data.clamp_(-0.01, 0.01)

                # Train Generator
                for p in self.discriminator.parameters():
                    p.requires_grad = False

                self.generator.zero_grad()
                noise = Variable(torch.randn(right_images.size(0), 100)).cuda()
                noise = noise.view(noise.size(0), 100, 1, 1)
                fake_images = self.generator(noise)
                outputs, _ = self.discriminator(fake_images)

                g_loss = torch.mean(outputs)
                g_loss.backward(mone)
                g_loss = - g_loss
                self.optimG.step()

                gen_iteration += 1

                self.logger.draw(right_images, fake_images)
                self.logger.log_iteration_wgan(epoch, gen_iteration, d_loss, g_loss, real_loss, fake_loss)

            self.logger.plot_epoch(gen_iteration)

            if (epoch + 1) % 50 == 0:
                Utils.save_checkpoint(self.discriminator, self.generator, self.checkpoints_path, epoch)
def _train_vanilla_gan(self):
criterion = nn.BCELoss()
l2_loss = nn.MSELoss()
l1_loss = nn.L1Loss()
iteration = 0
for epoch in range(self.num_epochs):
for sample in self.data_loader:
iteration += 1
right_images = sample['right_images']
right_images = Variable(right_images.float()).cuda()
real_labels = torch.ones(right_images.size(0))
fake_labels = torch.zeros(right_images.size(0))
# ======== One sided label smoothing ==========
# Helps preventing the discriminator from overpowering the
# generator adding penalty when the discriminator is too confident
# =============================================
smoothed_real_labels = torch.FloatTensor(Utils.smooth_label(real_labels.numpy(), -0.1))
real_labels = Variable(real_labels).cuda()
smoothed_real_labels = Variable(smoothed_real_labels).cuda()
fake_labels = Variable(fake_labels).cuda()
# Train the discriminator
self.discriminator.zero_grad()
outputs, activation_real = self.discriminator(right_images)
real_loss = criterion(outputs, smoothed_real_labels)
real_score = outputs
noise = Variable(torch.randn(right_images.size(0), 100)).cuda()
noise = noise.view(noise.size(0), 100, 1, 1)
fake_images = self.generator(noise)
outputs, _ = self.discriminator(fake_images)
fake_loss = criterion(outputs, fake_labels)
fake_score = outputs
d_loss = real_loss + fake_loss
d_loss.backward()
self.optimD.step()
# Train the generator
self.generator.zero_grad()
noise = Variable(torch.randn(right_images.size(0), 100)).cuda()
noise = noise.view(noise.size(0), 100, 1, 1)
fake_images = self.generator(noise)
outputs, activation_fake = self.discriminator(fake_images)
_, activation_real = self.discriminator(right_images)
activation_fake = torch.mean(activation_fake, 0)
activation_real = torch.mean(activation_real, 0)
# ======= Generator Loss function============
# This is a customized loss function, the first term is the regular cross entropy loss
# The second term is feature matching loss, this measure the distance between the real and generated
# images statistics by comparing intermediate layers activations
# The third term is L1 distance between the generated and real images, this is helpful for the conditional case
# because it links the embedding feature vector directly to certain pixel values.
g_loss = criterion(outputs, real_labels) \
+ self.l2_coef * l2_loss(activation_fake, activation_real.detach()) \
+ self.l1_coef * l1_loss(fake_images, right_images)
g_loss.backward()
self.optimG.step()
if iteration % 5 == 0:
    self.logger.log_iteration_gan(epoch, d_loss, g_loss, real_score, fake_score)
    self.logger.draw(right_images, fake_images)
self.logger.plot_epoch_w_scores(iteration)
if (epoch) % 50 == 0:
    Utils.save_checkpoint(self.discriminator, self.generator, self.checkpoints_path, epoch)
def predict(self):
    for sample in self.data_loader:
        right_images = sample['right_images']
        right_embed = sample['right_embed']
        txt = sample['txt']
        if not os.path.exists('results/{0}'.format(self.save_path)):
            os.makedirs('results/{0}'.format(self.save_path))
        right_images = Variable(right_images.float()).cuda()
        right_embed = Variable(right_embed.float()).cuda()
        # Generate images from the text embeddings and random noise
        noise = Variable(torch.randn(right_images.size(0), 100)).cuda()
        noise = noise.view(noise.size(0), 100, 1, 1)
        fake_images = self.generator(right_embed, noise)
        self.logger.draw(right_images, fake_images)
        for image, t in zip(fake_images, txt):
            # Rescale x -> 127.5 * x + 127.5, cast to bytes, and reorder CHW to HWC
            im = Image.fromarray(image.data.mul_(127.5).add_(127.5).byte().permute(1, 2, 0).cpu().numpy())
            im.save('results/{0}/{1}.jpg'.format(self.save_path, t.replace("/", "")[:100]))
            print(t)
`
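The three-term generator loss described in the code comments can be written compactly. This is only a reading of the snippet as posted, not a formula taken from the repository. Here $D(\cdot)$ is the discriminator's first output, $f_D(\cdot)$ is the intermediate activation it returns as its second output, the bar denotes the mean over the batch, and $\lambda_2$, $\lambda_1$ stand for `self.l2_coef` and `self.l1_coef`. The functions `criterion`, `l2_loss` and `l1_loss` are defined outside the excerpt, so their exact form is an assumption based on the comments.

```latex
\mathcal{L}_G
  = \underbrace{\mathrm{criterion}\big(D(G(z)),\ \mathbf{1}\big)}_{\text{adversarial term}}
  + \lambda_2\, \underbrace{\ell_2\big(\bar{f}_D(G(z)),\ \bar{f}_D(x)\big)}_{\text{feature matching}}
  + \lambda_1\, \underbrace{\ell_1\big(G(z),\ x\big)}_{\text{pixel distance}}
```

Note that the generator term uses the unsmoothed `real_labels`, while the discriminator's real-image term uses `smoothed_real_labels`, and that `activation_real` is detached so the feature-matching term only sends gradients through the generated images.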
In the original paper, the authors show that their manifold interpolation yields a large improvement. Did you implement this part in your code?
Thank you for sharing the code. I noticed that the generator produces 64 images at a time from 64 different embeddings. However, when I copied one embedding over several others:
for i in range(10):
    right_embed[i] = right_embed[0]
it still worked and generated normal images, with the first 10 images looking the same. But when most of the embeddings are the same (for example, with range(50)), the generator fails and produces weird results.
So what should I do to generate a small number of images from the embeddings of 2 or 3 descriptions? I can't achieve this by simply copying those embeddings 64 times.
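One possible explanation, offered as an assumption rather than something confirmed in this thread: if the generator contains batch-normalization layers and is left in training mode, each layer normalizes with the statistics of the current batch, so a batch made mostly of copies of one embedding has almost no variance along the embedding path. The sketch below uses plain Python lists instead of PyTorch, and the helper `batch_norm` and the numbers are made up for illustration. It only shows the arithmetic: normalizing a fully duplicated batch with its own statistics collapses every feature to zero, while normalizing with fixed (stored) statistics keeps the information.

```python
import math


def batch_norm(batch, mean=None, var=None, eps=1e-5):
    """Normalize each feature column of `batch` (a list of equal-length rows).

    Hypothetical helper, not part of the repository. With `mean`/`var` left
    as None, the statistics come from the batch itself (training-mode
    behaviour). Passing them in mimics stored running statistics
    (evaluation-mode behaviour).
    """
    n = len(batch)
    dims = len(batch[0])
    if mean is None:
        mean = [sum(row[j] for row in batch) / n for j in range(dims)]
    if var is None:
        var = [sum((row[j] - mean[j]) ** 2 for row in batch) / n
               for j in range(dims)]
    return [[(row[j] - mean[j]) / math.sqrt(var[j] + eps) for j in range(dims)]
            for row in batch]


# Made-up embedding, copied to fill the whole batch.
embedding = [0.8, -1.5, 2.0]
duplicated = [list(embedding) for _ in range(4)]

# Batch statistics: the variance is zero, so every feature collapses to ~0.
collapsed = batch_norm(duplicated)

# Fixed statistics (assumed here to be mean 0, variance 1): values survive.
preserved = batch_norm(duplicated, mean=[0.0, 0.0, 0.0], var=[1.0, 1.0, 1.0])

print(collapsed[0])
print(preserved[0])
```

If this is indeed the cause, calling `self.generator.eval()` before sampling makes PyTorch's `BatchNorm` layers use their running statistics instead of the batch statistics, which may be worth trying. Whether it resolves this particular case has not been verified here.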
I downloaded flowers.hdf5, but I'm confused about how to proceed.
How do I get the following paths?
flowers_images_path: '/export/mlrg/aelnouby/projects/GANs/flowers_dataset/'
flowers_embedding_path: '/export/mlrg/aelnouby/projects/GANs/flowers_dataset/flowers_icml/'
flowers_text_path: '/export/mlrg/aelnouby/projects/GANs/flowers_dataset/cvpr2016_flowers/text_c10/'
flowers_dataset_path: '/scratch/aelnouby/text2image/flowers.hdf5'
flowers_val_split_path: '/export/mlrg/aelnouby/projects/GANs/flowers_dataset/flowers_icml/valclasses.txt'
flowers_train_split_path: '/export/mlrg/aelnouby/projects/GANs/flowers_dataset/flowers_icml/trainclasses.txt'
flowers_test_split_path: '/export/mlrg/aelnouby/projects/GANs/flowers_dataset/flowers_icml/testclasses.txt'
File "trainer.py", line 32, in __init__
self.discriminator.apply(Utils.weights_init)
File "/home/aara/anaconda2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 136, in apply
module.apply(fn)
File "/home/aara/anaconda2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 136, in apply
module.apply(fn)
File "/home/aara/anaconda2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 136, in apply
module.apply(fn)
File "/home/aara/anaconda2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 137, in apply
fn(self)
TypeError: unbound method weights_init() must be called with Utils instance as first argument (got Conv2d instance instead)
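The paths in the traceback show Python 2.7. There, a function defined in a class body and accessed through the class is an unbound method, which insists on an instance of that class as its first argument. That matches the message above. A likely fix, assuming `weights_init` is defined in `Utils` without a decorator, is to mark it `@staticmethod`. Running under Python 3, where the plain function is returned, should avoid the error too. In the sketch below, `TinyModule` is a hypothetical stand-in for `torch.nn.Module` that only mimics `apply()`, and the body of `weights_init` is a placeholder rather than the repository's initializer.

```python
class Utils(object):
    @staticmethod
    def weights_init(m):
        # Placeholder body: the real function would initialize m's weights.
        m.initialized = True


class TinyModule(object):
    """Hypothetical stand-in for torch.nn.Module, only to mimic apply()."""

    def __init__(self, children=()):
        self.children = list(children)
        self.initialized = False

    def apply(self, fn):
        # Same order as in the traceback: recurse into children, then self.
        for child in self.children:
            child.apply(fn)
        fn(self)
        return self


root = TinyModule([TinyModule(), TinyModule([TinyModule()])])

# With @staticmethod, Utils.weights_init is a plain function on both
# Python 2 and Python 3, so it can be passed directly to apply().
root.apply(Utils.weights_init)
print(root.initialized)
```

An alternative with the same effect is to move `weights_init` out of the class and make it a module-level function.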