
aelnouby / text-to-image-synthesis


Pytorch implementation of Generative Adversarial Text-to-Image Synthesis paper

License: GNU General Public License v3.0

Python 100.00%
gans image-generation pytorch text-to-image zero-shot-learning

text-to-image-synthesis's People

Contributors

aelnouby


text-to-image-synthesis's Issues

A question about text and image encoding

  1. When training the GAN, I did not see any text-encoding step. The text seems to be stored in the sample already:

         right_images = sample['right_images']
         right_embed = sample['right_embed']
         wrong_images = sample['wrong_images']

         right_images = Variable(right_images.float()).cuda()
         right_embed = Variable(right_embed.float()).cuda()
         wrong_images = Variable(wrong_images.float()).cuda()

     There is no text_embedding step.
  2. When training the GAN, should we extract the text and image embeddings first, or simply feed raw images into the network?

Can't Use flowers.hdf5

When I try to use the data file provided by the author, I get the error "RuntimeError: Unable to get group info (wrong B-tree signature)", shown below.

    import h5py
    f = h5py.File('flowers.hdf5', 'r')
    for k in f.keys():
        print(f["test"])

    Traceback (most recent call last):
      File "h5py\_objects.pyx", line 54, in h5py._objects.with_phil.wrapper
      File "h5py\_objects.pyx", line 55, in h5py._objects.with_phil.wrapper
      File "F:\Anaconda\lib\site-packages\h5py\_hl\group.py", line 623, in __repr__
        r = '<HDF5 group %s (%d members)>' % (namestr, len(self))
      File "h5py\_objects.pyx", line 54, in h5py._objects.with_phil.wrapper
      File "h5py\_objects.pyx", line 55, in h5py._objects.with_phil.wrapper
      File "F:\Anaconda\lib\site-packages\h5py\_hl\group.py", line 443, in __len__
        return self.id.get_num_objs()
      File "h5py\_objects.pyx", line 54, in h5py._objects.with_phil.wrapper
      File "h5py\_objects.pyx", line 55, in h5py._objects.with_phil.wrapper
      File "h5py\h5g.pyx", line 336, in h5py.h5g.GroupID.get_num_objs
    RuntimeError: Unable to get group info (wrong B-tree signature)

I then opened the data file in ViTables to check it for problems. I can only see the groups named train, test and valid, with no datasets inside them.

I am new to these tools. Could anyone tell me what the problem is?
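One possible cause, which this thread does not confirm, is an incomplete or corrupted download. A minimal standard-library sketch for collecting basic facts about the file follows. `describe_file` is a hypothetical helper, and the check assumes the file has no HDF5 user block, so the 8-byte format signature sits at offset 0. A valid signature does not prove the rest of the file is intact, so the size and checksum are mainly useful for comparing against a fresh download.

```python
import hashlib
from pathlib import Path

# First 8 bytes of an HDF5 file that has no user block.
HDF5_SIGNATURE = b"\x89HDF\r\n\x1a\n"


def describe_file(path):
    """Return size, signature check and SHA-256 of a downloaded file."""
    data_path = Path(path)
    digest = hashlib.sha256()
    with data_path.open("rb") as handle:
        header = handle.read(len(HDF5_SIGNATURE))
        digest.update(header)
        # Hash the remainder in 1 MiB chunks to keep memory use small.
        for chunk in iter(lambda: handle.read(1 << 20), b""):
            digest.update(chunk)
    return {
        "size_bytes": data_path.stat().st_size,
        "has_hdf5_signature": header == HDF5_SIGNATURE,
        "sha256": digest.hexdigest(),
    }
```

Calling `describe_file('flowers.hdf5')` on two separate downloads and comparing the results would show whether the file changes between attempts.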

Question about data preprocessing

Hi, @aelnouby
Thank you for your sharing.
I want to convert the data myself, so I downloaded the dataset as described at https://github.com/reedscot/cvpr2016. That gave me a link (for flowers, for example) https://drive.google.com/open?id=0B0ywwgffWnLLcms2WWJQRFNSWXM, from which I downloaded a file named cvpr2016_flowers.tar.gz. After unzipping it, the folder looked like this:
[image: screenshot of the unzipped folder]
I can't work out how you organized your files so that the script 'convert_flowers_to_hd5_script.py' runs correctly. Did I download the wrong file?
Looking forward to your reply!

Doubt about generating fake images in a mini-batch

Hi,
During training you generate fake images twice per mini-batch. According to the paper, however, it seems that the updates of D and G both use the same fake images, so I'm confused about this.
I think generating fake images twice may make training less stable.

Hoping for your reply.
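For reference, one GAN-CLS training step as I understand Algorithm 1 of the paper can be summarised as below. Here $x$ is a real image, $h$ its matching text embedding, $\hat{h}$ a mismatching embedding and $z$ the noise vector. This restates the paper, not this repository's code.

```latex
\begin{align*}
\hat{x} &= G(z, h) && \text{fake image, generated once per step} \\
s_r &= D(x, h), \quad s_w = D(x, \hat{h}), \quad s_f = D(\hat{x}, h) \\
\mathcal{L}_D &= \log s_r + \tfrac{1}{2}\bigl(\log(1 - s_w) + \log(1 - s_f)\bigr) \\
\mathcal{L}_G &= \log s_f
\end{align*}
```

In this formulation the single $\hat{x}$ appears in both $\mathcal{L}_D$ and $\mathcal{L}_G$, which matches the reading in the question. Whether sampling a second batch of fakes for the generator step hurts stability is not something the formulas settle.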

torch model

ModuleNotFoundError: No module named 'torch.utils.serialization'
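
For context: `torch.utils.serialization`, which provided `load_lua` for reading Torch7 `.t7` files, was removed in PyTorch 1.0, so this import fails on newer installs. A hedged sketch of a fallback follows. It assumes the third-party `torchfile` package is an acceptable replacement for reading the `.t7` files used here, which this thread does not confirm. `has_module` and `pick_t7_loader` are hypothetical helper names, not part of this repository.

```python
import importlib
import importlib.util


def has_module(name):
    """Return True if `name` can be found (parent packages get imported)."""
    try:
        return importlib.util.find_spec(name) is not None
    except (ModuleNotFoundError, ValueError):
        # Raised when a parent package (e.g. `torch`) is itself missing.
        return False


def pick_t7_loader():
    """Return a callable that loads a .t7 file, or None if none is installed."""
    if has_module("torch.utils.serialization"):
        # Old PyTorch (< 1.0) still ships load_lua.
        return importlib.import_module("torch.utils.serialization").load_lua
    if has_module("torchfile"):
        # Assumption: torchfile (pip install torchfile) can read these files.
        return importlib.import_module("torchfile").load
    return None
```

If `pick_t7_loader()` returns `None`, neither loader is available in the current environment.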

cannot identify image file <_io.BytesIO object at 0x7f2d80b2e770>

Thanks for your contribution!
But when I run the code, an error occurs:

Traceback (most recent call last):
File "runtime.py", line 42, in <module>
trainer.train(args.cls)
File "/home/zzw/program/text2img/text-to-Image-Synthesis-pytorch/trainer.py", line 65, in train
self._train_wgan(cls)
File "/home/zzw/program/text2img/text-to-Image-Synthesis-pytorch/trainer.py", line 103, in _train_wgan
sample = next(data_iterator)
File "/home/zzw/.local/lib/python2.7/site-packages/torch/utils/data/dataloader.py", line 281, in __next__
return self._process_next_batch(batch)
File "/home/zzw/.local/lib/python2.7/site-packages/torch/utils/data/dataloader.py", line 301, in _process_next_batch
raise batch.exc_type(batch.exc_msg)
IOError: Traceback (most recent call last):
File "/home/zzw/.local/lib/python2.7/site-packages/torch/utils/data/dataloader.py", line 55, in _worker_loop
samples = collate_fn([dataset[i] for i in batch_indices])
File "/home/zzw/program/text2img/text-to-Image-Synthesis-pytorch/txt2image_dataset.py", line 46, in __getitem__
right_image = Image.open(io.BytesIO(right_image)).resize((64, 64))
File "/home/zzw/.local/lib/python2.7/site-packages/PIL/Image.py", line 2590, in open
% (filename if filename else fp))
IOError: cannot identify image file <_io.BytesIO object at 0x7f1be1801770>

My Pillow version is 5.1.0, and it seems like something related to the version.
Could anyone help me out?
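
Pillow raises "cannot identify image file" when the bytes it receives do not start with a header it recognises, so it may be worth checking what `right_image` actually contains before blaming the Pillow version. The sketch below is a standard-library check of the leading bytes. `sniff_image_format` is a hypothetical helper, and it covers only three common formats.

```python
def sniff_image_format(data):
    """Guess an image format from its leading bytes; None if unrecognised."""
    if data.startswith(b"\xff\xd8\xff"):
        return "jpeg"
    if data.startswith(b"\x89PNG\r\n\x1a\n"):
        return "png"
    if data[:6] in (b"GIF87a", b"GIF89a"):
        return "gif"
    return None
```

If this returns `None` for the bytes read from the HDF5 file, the stored data or the way it is converted to bytes is the more likely culprit, although that is an inference and not something confirmed in this thread.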

Loss function

Hi,
Thanks for the nice and very helpful work! I am also trying to do text-to-image generation, but in TensorFlow.
My loss graphs look completely wrong:
[image: loss graphs]

  1. The loss functions I am using are:

         d_loss_real = tf.reduce_mean(disc_real_image_logits)
         d_loss_fake = tf.reduce_mean(disc_fake_image_logits)
         d_loss_wrong = tf.reduce_mean(disc_wrong_image_logits)
         d_w_loss = d_loss_fake + d_loss_wrong - d_loss_real
         g_w_loss = -1*(d_loss_fake)

  2. For optimisation I am using the following lines:

         rms_d_optim = tf.train.RMSPropOptimizer(learning_rate=5e-5).minimize(loss['d_loss'], var_list=variables['d_vars'])
         rms_g_optim = tf.train.RMSPropOptimizer(learning_rate=5e-5).minimize(loss['g_loss'], var_list=variables['g_vars'])
         d_clip = [v.assign(tf.clip_by_value(v, -args.d_clip_limit, args.d_clip_limit)) for v in variables['d_vars']]
         with tf.control_dependencies([rms_d_optim]):
             rms_d_optim = tf.tuple(d_clip)
         for epoch in range(100):
             for diter in range(10):
                 sess.run([rms_d_optim], feed_dict=feed)
                 sess.run(d_clip)
             sess.run([rms_g_optim], feed_dict=feed)

Could you suggest some direction for fixing this?
I went through your code (I am not well versed in PyTorch as of now), and it seems that you are using the same losses. Please correct me if I am mistaken.
Thanks
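
Reading the `_train_wgan` code quoted in the size-mismatch issue further down this page, the signs passed to `backward` (`mone` for the real term, `one` for the wrong and fake terms) and the `clamp_(-0.01, 0.01)` call correspond to the objective below. This is an interpretation of that code, not a statement from the author. The symbols are notation introduced here: $f_w$ is the critic, $x$ a real image, $x_w$ a mismatched image, $h$ the text embedding and $z$ the noise. The $x_w$ term is only present when `cls` is set.

```latex
\begin{align*}
\mathcal{L}_D &= \mathbb{E}\bigl[f_w(G(z,h),h)\bigr] + \mathbb{E}\bigl[f_w(x_w,h)\bigr] - \mathbb{E}\bigl[f_w(x,h)\bigr] \\
\mathcal{L}_G &= -\,\mathbb{E}\bigl[f_w(G(z,h),h)\bigr] \\
w &\leftarrow \operatorname{clip}(w,\,-0.01,\,0.01) \quad \text{after each critic step}
\end{align*}
```

Under that reading, `d_w_loss` and `g_w_loss` in the question carry the same signs as the PyTorch code, so the difference in behaviour would have to come from elsewhere.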

Can you run this without NVIDIA?

I tried to run your code, but it requires NVIDIA hardware and I don't have an NVIDIA graphics card. Can this work with other types of GPU?
Thanks
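
The code excerpts on this page call `.cuda()` directly, which needs a CUDA build of PyTorch and a supported GPU. A common way to loosen that is to pick the device at runtime, sketched below. `pick_device` and `to_device` are hypothetical helpers, and whether the rest of the repository runs correctly on CPU after such a change is not confirmed here. Training on CPU would also be much slower.

```python
try:
    import torch
except ImportError:  # keeps the sketch importable on machines without PyTorch
    torch = None


def pick_device():
    """Return "cuda" when a CUDA-capable build and GPU are present, else "cpu"."""
    if torch is not None and torch.cuda.is_available():
        return "cuda"
    return "cpu"


def to_device(obj, device=None):
    """Move a tensor or module to the chosen device instead of calling .cuda()."""
    return obj.to(device or pick_device())
```

With this, a line such as `Variable(right_images.float()).cuda()` would become `to_device(right_images.float())`.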

Can't find the file:"text_c10"

Hi, I'm learning text-to-image synthesis, and thanks a lot for your PyTorch reimplementation!
However, I cannot find the text_c10 files for the birds and flowers datasets. Are they the descriptions of the images?
It looks like the authors of icml2016 did not provide the original descriptions, only the text embeddings in Torch format.

Waiting for your reply.

TypeError: function takes exactly 5 arguments (1 given)

Hi!
I have run into a very strange problem: the model trains for several iterations, but then fails with the error below.
I hope for your reply!
Epoch: 0, d_loss= 2.205419, g_loss= 31.352011, D(X)= 0.381197, D(G(X))= 0.543290
Epoch: 0, d_loss= 1.640341, g_loss= 30.674690, D(X)= 0.512067, D(G(X))= 0.349348
Epoch: 0, d_loss= 1.334773, g_loss= 34.140278, D(X)= 0.618809, D(G(X))= 0.341588
Traceback (most recent call last):
File "runtime.py", line 43, in <module>
trainer.train(args.cls)
File "/home/jurh/disk2/liupeng/text2image_3/Text-to-Image-Synthesis/trainer.py", line 67, in train
self._train_gan(cls)
File "/home/jurh/disk2/liupeng/text2image_3/Text-to-Image-Synthesis/trainer.py", line 177, in _train_gan
for sample in self.data_loader:
File "/usr/local/lib/python2.7/dist-packages/torch/utils/data/dataloader.py", line 187, in __next__
return self._process_next_batch(batch)
File "/usr/local/lib/python2.7/dist-packages/torch/utils/data/dataloader.py", line 221, in _process_next_batch
raise batch.exc_type(batch.exc_msg)
TypeError: function takes exactly 5 arguments (1 given)

Generator diverges?

Hi,

I trained gan_cls (the conditioned version, not the vanilla one) on flowers with the shared hdf5 file, and got the curves at https://drive.google.com/open?id=1pASanOh9YUdg__I5OPRi_srmu3T2JYx8.

The discriminator loss keeps going down and then almost converges at 0.483, but the generator loss keeps going up (without converging) to 18.58, with D(X) at 0.846 and D(G(X)) at 0.016. The images I got on prediction are similar to the reported ones.

I think the curves I got during training suggest divergence, correct? If training converged, we should see the generator loss going down as well, and D(G(X)) going up, correct?

Am I missing anything here?
And do you have any suggestions for making the training converge? (I see that you've implemented many tricks from https://github.com/soumith/ganhacks#13-add-noise-to-inputs-decay-over-time. I'm trying #2, flipping real_label and fake_label for the generator, but it doesn't seem to help.)

Look forward to your answer. Thanks!
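
One thing that may help when reading these numbers: the `_train_gan` code quoted further down this page adds a feature-matching term and an L1 term (weighted by `l2_coef` and `l1_coef`) to the generator loss, so the logged g_loss is not the adversarial term alone. The arithmetic below is a rough sketch under two assumptions that the thread does not confirm: that the adversarial term is a binary cross-entropy against real labels, and that the reported D(G(X)) can stand in for a single representative score. `bce_generator_term` is a hypothetical helper.

```python
import math


def bce_generator_term(d_of_fake):
    """Generator term -log(D(G(z))) for one discriminator score."""
    return -math.log(d_of_fake)


adversarial_part = bce_generator_term(0.016)  # score reported in the question
equilibrium_part = bce_generator_term(0.5)    # score if D could not tell real from fake
remainder = 18.58 - adversarial_part          # share left for the other loss terms
```

Under those assumptions the adversarial term is roughly 4.1, so most of the reported 18.58 would come from the feature-matching and L1 terms. A high g_loss by itself would then say little about whether the adversarial game has diverged.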

New dataset

Can someone provide guidance on how to generate an hdf5 file for the COCO dataset?
How do I get the following?
birds_images_path: '/export/mlrg/aelnouby/projects/GANs/Birds dataset/CUB_200_2011/CUB_200_2011/images/'
birds_embedding_path: '/export/mlrg/aelnouby/projects/GANs/Birds dataset/cub_icml/'
birds_text_path: '/export/mlrg/aelnouby/projects/GANs/Birds dataset/cvpr2016_cub/text_c10/'

val_split_path: '/export/mlrg/aelnouby/projects/GANs/Birds dataset/cub_icml/valclasses.txt'
train_split_path: '/export/mlrg/aelnouby/projects/GANs/Birds dataset/cub_icml/trainclasses.txt'
test_split_path: '/export/mlrg/aelnouby/projects/GANs/Birds dataset/cub_icml/testclasses.txt'

There was a reference to https://github.com/reedscot/cvpr2016, but when I checked the code, it didn't provide much insight. Any help would be appreciated.

Duplicate embeddings in convert_cub_to_hd5_script.py

As I was reading convert_cub_to_hd5_script.py to get a better understanding of the .hdf5 structure, I came across this line:

txt_choice = np.random.choice(range(10), 5)

where five text embeddings are supposedly randomly sampled. However, by default, np.random.choice does sampling with replacement, which in this case could cause the same embedding to be picked twice. Would this be an issue?
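
To put a number on this, the sketch below computes how often five draws with replacement from ten items contain at least one repeat. It uses only the standard library; `random.sample` stands in for sampling without replacement, which in NumPy would be `np.random.choice(range(10), 5, replace=False)`. Whether duplicates matter for training is a separate question that this calculation does not answer.

```python
import math
import random

POOL, DRAWS = 10, 5  # values taken from np.random.choice(range(10), 5)

# Probability that 5 draws with replacement from 10 items are all distinct.
p_all_distinct = math.perm(POOL, DRAWS) / POOL ** DRAWS
p_some_duplicate = 1 - p_all_distinct

# Standard-library analogue of sampling without replacement.
distinct_choice = random.sample(range(POOL), DRAWS)
```

Here `p_all_distinct` is 30240 / 100000, so roughly 70% of images would get at least one repeated embedding under the current call.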

Cannot download HDF5 file

Thank you for your code. I want to test it with the HDF5 files for the Birds and Flowers datasets, but the download link is invalid. Can you update the download link?

Thanks!

Need advice

Hi,

  • We trained our model for 200 epochs.
  • Ran predict as follows
    python runtime.py --inference --split 2
  • Images got generated in 'results' folder
  • But all the generated images were a little off - please see below:

a beautiful flower with outer whorl of prominent big white sepals and the corolla contains several t

Can you please tell us what could have gone wrong?

text embeddings preparation

I want to train on my own data. To prepare the dataset, I will have to convert my text files to .t7 format for the text embeddings. How do I do that?

The loss of generator

Hello, author. I want to reproduce the experimental results. When I trained the model, the generator loss was always high, and the generator looked much weaker than the discriminator. What should I do?

Error on visualize.py

Hello, I would really like to use this repository for my deep learning course. After downloading the datasets, I tried to run 'runtime.py' on Google Colab (I don't have a GPU on my own machine) and got an error that seems related to 'visualize.py'. I really need this code, so please help me solve it. The error is:

Traceback (most recent call last):
File "/usr/local/lib/python3.7/dist-packages/urllib3/connectionpool.py", line 600, in urlopen
chunked=chunked)
File "/usr/local/lib/python3.7/dist-packages/urllib3/connectionpool.py", line 354, in _make_request
conn.request(method, url, **httplib_request_kw)
File "/usr/lib/python3.7/http/client.py", line 1281, in request
self._send_request(method, url, body, headers, encode_chunked)
File "/usr/lib/python3.7/http/client.py", line 1327, in _send_request
self.endheaders(body, encode_chunked=encode_chunked)
File "/usr/lib/python3.7/http/client.py", line 1276, in endheaders
self._send_output(message_body, encode_chunked=encode_chunked)
File "/usr/lib/python3.7/http/client.py", line 1036, in _send_output
self.send(msg)
File "/usr/lib/python3.7/http/client.py", line 976, in send
self.connect()
File "/usr/local/lib/python3.7/dist-packages/urllib3/connection.py", line 181, in connect
conn = self._new_conn()
File "/usr/local/lib/python3.7/dist-packages/urllib3/connection.py", line 168, in _new_conn
self, "Failed to establish a new connection: %s" % e)
urllib3.exceptions.NewConnectionError: <urllib3.connection.HTTPConnection object at 0x7fcf91fe6c50>: Failed to establish a new connection: [Errno 111] Connection refused

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "/usr/local/lib/python3.7/dist-packages/requests/adapters.py", line 449, in send
timeout=timeout
File "/usr/local/lib/python3.7/dist-packages/urllib3/connectionpool.py", line 638, in urlopen
_stacktrace=sys.exc_info()[2])
File "/usr/local/lib/python3.7/dist-packages/urllib3/util/retry.py", line 399, in increment
raise MaxRetryError(_pool, url, error or ResponseError(cause))
urllib3.exceptions.MaxRetryError: HTTPConnectionPool(host='localhost', port=8097): Max retries exceeded with url: /events (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7fcf91fe6c50>: Failed to establish a new connection: [Errno 111] Connection refused'))

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "/usr/local/lib/python3.7/dist-packages/visdom/__init__.py", line 263, in _send
data=json.dumps(msg),
File "/usr/local/lib/python3.7/dist-packages/requests/api.py", line 119, in post
return request('post', url, data=data, json=json, **kwargs)
File "/usr/local/lib/python3.7/dist-packages/requests/api.py", line 61, in request
return session.request(method=method, url=url, **kwargs)
File "/usr/local/lib/python3.7/dist-packages/requests/sessions.py", line 530, in request
resp = self.send(prep, **send_kwargs)
File "/usr/local/lib/python3.7/dist-packages/requests/sessions.py", line 643, in send
r = adapter.send(request, **kwargs)
File "/usr/local/lib/python3.7/dist-packages/requests/adapters.py", line 516, in send
raise ConnectionError(e, request=request)
requests.exceptions.ConnectionError: HTTPConnectionPool(host='localhost', port=8097): Max retries exceeded with url: /events (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7fcf91fe6c50>: Failed to establish a new connection: [Errno 111] Connection refused'))
Exception in user code:

Traceback (most recent call last):
File "/usr/local/lib/python3.7/dist-packages/urllib3/connection.py", line 159, in _new_conn
(self._dns_host, self.port), self.timeout, **extra_kw)
File "/usr/local/lib/python3.7/dist-packages/urllib3/util/connection.py", line 80, in create_connection
raise err
File "/usr/local/lib/python3.7/dist-packages/urllib3/util/connection.py", line 70, in create_connection
sock.connect(sa)
ConnectionRefusedError: [Errno 111] Connection refused

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "/usr/local/lib/python3.7/dist-packages/urllib3/connectionpool.py", line 600, in urlopen
chunked=chunked)
File "/usr/local/lib/python3.7/dist-packages/urllib3/connectionpool.py", line 354, in _make_request
conn.request(method, url, **httplib_request_kw)
File "/usr/lib/python3.7/http/client.py", line 1281, in request
self._send_request(method, url, body, headers, encode_chunked)
File "/usr/lib/python3.7/http/client.py", line 1327, in _send_request
self.endheaders(body, encode_chunked=encode_chunked)
File "/usr/lib/python3.7/http/client.py", line 1276, in endheaders
self._send_output(message_body, encode_chunked=encode_chunked)
File "/usr/lib/python3.7/http/client.py", line 1036, in _send_output
self.send(msg)
File "/usr/lib/python3.7/http/client.py", line 976, in send
self.connect()
File "/usr/local/lib/python3.7/dist-packages/urllib3/connection.py", line 181, in connect
conn = self._new_conn()
File "/usr/local/lib/python3.7/dist-packages/urllib3/connection.py", line 168, in _new_conn
self, "Failed to establish a new connection: %s" % e)
urllib3.exceptions.NewConnectionError: <urllib3.connection.HTTPConnection object at 0x7fcf940a45d0>: Failed to establish a new connection: [Errno 111] Connection refused

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "/usr/local/lib/python3.7/dist-packages/requests/adapters.py", line 449, in send
timeout=timeout
File "/usr/local/lib/python3.7/dist-packages/urllib3/connectionpool.py", line 638, in urlopen
_stacktrace=sys.exc_info()[2])
File "/usr/local/lib/python3.7/dist-packages/urllib3/util/retry.py", line 399, in increment
raise MaxRetryError(_pool, url, error or ResponseError(cause))
urllib3.exceptions.MaxRetryError: HTTPConnectionPool(host='localhost', port=8097): Max retries exceeded with url: /events (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7fcf940a45d0>: Failed to establish a new connection: [Errno 111] Connection refused'))

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "/usr/local/lib/python3.7/dist-packages/visdom/__init__.py", line 263, in _send
data=json.dumps(msg),
File "/usr/local/lib/python3.7/dist-packages/requests/api.py", line 119, in post
return request('post', url, data=data, json=json, **kwargs)
File "/usr/local/lib/python3.7/dist-packages/requests/api.py", line 61, in request
return session.request(method=method, url=url, **kwargs)
File "/usr/local/lib/python3.7/dist-packages/requests/sessions.py", line 530, in request
resp = self.send(prep, **send_kwargs)
File "/usr/local/lib/python3.7/dist-packages/requests/sessions.py", line 643, in send
r = adapter.send(request, **kwargs)
File "/usr/local/lib/python3.7/dist-packages/requests/adapters.py", line 516, in send
raise ConnectionError(e, request=request)
requests.exceptions.ConnectionError: HTTPConnectionPool(host='localhost', port=8097): Max retries exceeded with url: /events (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7fcf940a45d0>: Failed to establish a new connection: [Errno 111] Connection refused'))
Traceback (most recent call last):
File "runtime.py", line 42, in <module>
trainer.train(args.cls)
File "/content/gdrive/MyDrive/Txt2Img(2)/trainer.py", line 67, in train
self._train_gan(cls)
File "/content/gdrive/MyDrive/Txt2Img(2)/trainer.py", line 177, in _train_gan
for sample in self.data_loader:
File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 652, in __next__
data = self._next_data()
File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1327, in _next_data
return self._process_data(data)
File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1373, in _process_data
data.reraise()
File "/usr/local/lib/python3.7/dist-packages/torch/_utils.py", line 460, in reraise
raise RuntimeError(msg) from None
RuntimeError: Caught UnicodeDecodeError in DataLoader worker process 6.
Original Traceback (most recent call last):
File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/worker.py", line 302, in _worker_loop
data = fetcher.fetch(index)
File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/content/gdrive/MyDrive/Txt2Img(2)/txt2image_dataset.py", line 52, in __getitem__
txt = np.array(example['txt']).astype(str)
UnicodeDecodeError: 'ascii' codec can't decode byte 0xef in position 4: ordinal not in range(128)
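
The log above appears to contain two separate problems. The repeated "Connection refused" tracebacks come from the visdom client failing to reach a server at localhost:8097 (a visdom server is normally started with `python -m visdom.server`). The run actually stops on the final `UnicodeDecodeError`, raised at `np.array(example['txt']).astype(str)` when caption bytes containing a non-ASCII character are decoded as ASCII. A hedged sketch of decoding the caption explicitly follows. `decode_caption` is a hypothetical helper, and it assumes the captions were stored as UTF-8, which the thread does not confirm (0xef is, for example, the first byte of a UTF-8 byte-order mark).

```python
def decode_caption(raw):
    """Decode caption bytes as UTF-8, replacing anything undecodable."""
    if isinstance(raw, str):
        return raw
    return bytes(raw).decode("utf-8", errors="replace")


# Example bytes with 0xef at position 4, similar to the reported failure.
sample = "this\ufeff flower has white petals".encode("utf-8")
caption = decode_caption(sample)
```

Using `errors="replace"` keeps the data loader running on odd captions at the cost of substituting a placeholder character for bytes that cannot be decoded.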

typeError during data loading

I ran 'python runtime.py --type=gan' and got below error:

File "runtime.py", line 42, in <module>
trainer.train(args.cls)
File "/private/home/xjwang/twoDomainGAN-pytorch/Text-to-Image-Synthesis/trainer.py", line 67, in train
self._train_gan(cls)
File "/private/home/xjwang/twoDomainGAN-pytorch/Text-to-Image-Synthesis/trainer.py", line 295, in _train_gan
for sample in self.data_loader:
File "/public/apps/anaconda3/4.3.1/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 212, in __next__
return self._process_next_batch(batch)
File "/public/apps/anaconda3/4.3.1/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 239, in _process_next_batch
raise batch.exc_type(batch.exc_msg)
TypeError: Traceback (most recent call last):
File "/public/apps/anaconda3/4.3.1/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 41, in _worker_loop
samples = collate_fn([dataset[i] for i in batch_indices])
File "/public/apps/anaconda3/4.3.1/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 41, in <listcomp>
samples = collate_fn([dataset[i] for i in batch_indices])
File "/private/home/xjwang/twoDomainGAN-pytorch/Text-to-Image-Synthesis/txt2image_dataset.py", line 42, in __getitem__
right_image = bytes(np.array(example['img']))
TypeError: only integer scalar arrays can be converted to a scalar index

trainer.py RuntimeError: size mismatch (got input: [2], target: [64])

[image: screenshot of the error]

I get the above error when running trainer.py. Its code is below:

`import numpy as np
import torch
import yaml
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader

from PIL import Image
import os

class Trainer(object):
def __init__(self, type, dataset, split, lr, diter, vis_screen, save_path, l1_coef, l2_coef, pre_trained_gen, pre_trained_disc, batch_size, num_workers, epochs):

    self.generator = torch.nn.DataParallel(gan_factory.generator_factory(type).cuda())
    self.discriminator = torch.nn.DataParallel(gan_factory.discriminator_factory(type).cuda())

    if pre_trained_disc:
        self.discriminator.load_state_dict(torch.load(pre_trained_disc))
    else:
        self.discriminator.apply(Utils.weights_init)

    if pre_trained_gen:
        self.generator.load_state_dict(torch.load(pre_trained_gen))
    else:
        self.generator.apply(Utils.weights_init)

    if dataset == 'birds':
        self.dataset = Text2ImageDataset('Data/Birds/', split=split)
    elif dataset == 'flowers':
        self.dataset = Text2ImageDataset('Data/Flowers/flowers.hdf5', split=split)
    else:
        print('Dataset not supported, please select either birds or flowers.')
        exit()

    self.noise_dim = 100
    self.batch_size = batch_size
    self.num_workers = num_workers
    self.lr = lr
    self.beta1 = 0.5
    self.num_epochs = epochs
    self.DITER = diter

    self.l1_coef = l1_coef
    self.l2_coef = l2_coef

    self.data_loader = DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True,
                            num_workers=self.num_workers)

    self.optimD = torch.optim.Adam(self.discriminator.parameters(), lr=self.lr, betas=(self.beta1, 0.999))
    self.optimG = torch.optim.Adam(self.generator.parameters(), lr=self.lr, betas=(self.beta1, 0.999))

    self.logger = Logger(vis_screen)
    self.checkpoints_path = 'checkpoints'
    self.save_path = save_path
    self.type = type

def train(self, cls=False):

    if self.type == 'wgan':
        self._train_wgan(cls)
    elif self.type == 'gan':
        self._train_gan(cls)
    elif self.type == 'vanilla_wgan':
        self._train_vanilla_wgan()
    elif self.type == 'vanilla_gan':
        self._train_vanilla_gan()

def _train_wgan(self, cls):
    one = torch.FloatTensor([1])
    mone = one * -1

    one = Variable(one).cuda()
    mone = Variable(mone).cuda()

    gen_iteration = 0
    for epoch in range(self.num_epochs):
        iterator = 0
        data_iterator = iter(self.data_loader)

        while iterator < len(self.data_loader):

            if gen_iteration < 25 or gen_iteration % 500 == 0:
                d_iter_count = 100
            else:
                d_iter_count = self.DITER

            d_iter = 0

            # Train the discriminator
            while d_iter < d_iter_count and iterator < len(self.data_loader):
                d_iter += 1

                for p in self.discriminator.parameters():
                    p.requires_grad = True

                self.discriminator.zero_grad()

                sample = next(data_iterator)
                iterator += 1

                right_images = sample['right_images']
                right_embed = sample['right_embed']
                wrong_images = sample['wrong_images']

                right_images = Variable(right_images.float()).cuda()
                right_embed = Variable(right_embed.float()).cuda()
                wrong_images = Variable(wrong_images.float()).cuda()

                outputs, _ = self.discriminator(right_images, right_embed)
                real_loss = torch.mean(outputs)
                real_loss.backward(mone)

                if cls:
                    outputs, _ = self.discriminator(wrong_images, right_embed)
                    wrong_loss = torch.mean(outputs)
                    wrong_loss.backward(one)

                noise = Variable(torch.randn(right_images.size(0), self.noise_dim), volatile=True).cuda()
                noise = noise.view(noise.size(0), self.noise_dim, 1, 1)

                fake_images = Variable(self.generator(right_embed, noise).data)
                outputs, _ = self.discriminator(fake_images, right_embed)
                fake_loss = torch.mean(outputs)
                fake_loss.backward(one)

                ## NOTE: Pytorch had a bug with gradient penalty at the time of this project development
                ##       , uncomment the next two lines and remove the params clamping below if you want to try gradient penalty
                # gp = Utils.compute_GP(self.discriminator, right_images.data, right_embed, fake_images.data, LAMBDA=10)
                # gp.backward()

                d_loss = real_loss - fake_loss

                if cls:
                    d_loss = d_loss - wrong_loss

                self.optimD.step()

                for p in self.discriminator.parameters():
                    p.data.clamp_(-0.01, 0.01)

            # Train Generator
            for p in self.discriminator.parameters():
                p.requires_grad = False
            self.generator.zero_grad()
            noise = Variable(torch.randn(right_images.size(0), 100)).cuda()
            noise = noise.view(noise.size(0), 100, 1, 1)
            fake_images = self.generator(right_embed, noise)
            outputs, _ = self.discriminator(fake_images, right_embed)

            g_loss = torch.mean(outputs)
            g_loss.backward(mone)
            g_loss = - g_loss
            self.optimG.step()

            gen_iteration += 1

            self.logger.draw(right_images, fake_images)
            self.logger.log_iteration_wgan(epoch, gen_iteration, d_loss, g_loss, real_loss, fake_loss)
            
        self.logger.plot_epoch(gen_iteration)

        if (epoch+1) % 50 == 0:
            Utils.save_checkpoint(self.discriminator, self.generator, self.checkpoints_path, epoch)

def _train_gan(self, cls):
    criterion = nn.CrossEntropyLoss()
    l2_loss = nn.MSELoss()
    l1_loss = nn.L1Loss()
    iteration = 0

    for epoch in range(self.num_epochs):
        for sample in self.data_loader:
            iteration += 1
            right_images = sample['right_images']
            right_embed = sample['right_embed']
            wrong_images = sample['wrong_images']

            right_images = Variable(right_images.float()).cuda()
            right_embed = Variable(right_embed.float()).cuda()
            wrong_images = Variable(wrong_images.float()).cuda()

            real_labels = torch.ones(right_images.size(0))
            fake_labels = torch.zeros(right_images.size(0))

            # ======== One sided label smoothing ==========
            # Helps preventing the discriminator from overpowering the
            # generator adding penalty when the discriminator is too confident
            # =============================================
            smoothed_real_labels = torch.FloatTensor(Utils.smooth_label(real_labels.numpy(), -0.1))

            real_labels = Variable(real_labels).cuda()
            smoothed_real_labels = Variable(smoothed_real_labels).cuda()
            fake_labels = Variable(fake_labels).cuda()

            # Train the discriminator
            self.discriminator.zero_grad()
            outputs, activation_real = self.discriminator(right_images, right_embed)
            real_loss = criterion(outputs, smoothed_real_labels.squeeze())
            real_score = outputs

            if cls:
                outputs, _ = self.discriminator(wrong_images, right_embed)
                wrong_loss = criterion(outputs, fake_labels)
                wrong_score = outputs

            noise = Variable(torch.randn(right_images.size(0), 100)).cuda()
            noise = noise.view(noise.size(0), 100, 1, 1)
            fake_images = self.generator(right_embed, noise)
            outputs, _ = self.discriminator(fake_images, right_embed)
            fake_loss = criterion(outputs, fake_labels.squeeze())
            fake_score = outputs

            d_loss = real_loss + fake_loss

            if cls:
                d_loss = d_loss + wrong_loss

            d_loss.backward()
            self.optimD.step()

            # Train the generator
            self.generator.zero_grad()
            noise = Variable(torch.randn(right_images.size(0), 100)).cuda()
            noise = noise.view(noise.size(0), 100, 1, 1)
            fake_images = self.generator(right_embed, noise)
            outputs, activation_fake = self.discriminator(fake_images, right_embed)
            _, activation_real = self.discriminator(right_images, right_embed)

            activation_fake = torch.mean(activation_fake, 0)
            activation_real = torch.mean(activation_real, 0)


            #======= Generator Loss function============
            # This is a customized loss function, the first term is the regular cross entropy loss
            # The second term is feature matching loss, this measure the distance between the real and generated
            # images statistics by comparing intermediate layers activations
            # The third term is L1 distance between the generated and real images, this is helpful for the conditional case
            # because it links the embedding feature vector directly to certain pixel values.
            #===========================================
            g_loss = criterion(outputs, real_labels) \
                     + self.l2_coef * l2_loss(activation_fake, activation_real.detach()) \
                     + self.l1_coef * l1_loss(fake_images, right_images)

            g_loss.backward()
            self.optimG.step()

            if iteration % 5 == 0:
                self.logger.log_iteration_gan(epoch,d_loss, g_loss, real_score, fake_score)
                self.logger.draw(right_images, fake_images)

        self.logger.plot_epoch_w_scores(epoch)

        if (epoch) % 10 == 0:
            Utils.save_checkpoint(self.discriminator, self.generator, self.checkpoints_path, self.save_path, epoch)

def _train_vanilla_wgan(self):
    one = Variable(torch.FloatTensor([1])).cuda()
    mone = one * -1
    gen_iteration = 0

    for epoch in range(self.num_epochs):
        iterator = 0
        data_iterator = iter(self.data_loader)

        while iterator < len(self.data_loader):

            if gen_iteration < 25 or gen_iteration % 500 == 0:
                d_iter_count = 100
            else:
                d_iter_count = self.DITER

            d_iter = 0

            # Train the discriminator
            while d_iter < d_iter_count and iterator < len(self.data_loader):
                d_iter += 1

                for p in self.discriminator.parameters():
                    p.requires_grad = True

                self.discriminator.zero_grad()

                sample = next(data_iterator)
                iterator += 1

                right_images = sample['right_images']
                right_images = Variable(right_images.float()).cuda()

                outputs, _ = self.discriminator(right_images)
                real_loss = torch.mean(outputs)
                real_loss.backward(mone)

                noise = Variable(torch.randn(right_images.size(0), self.noise_dim), volatile=True).cuda()
                noise = noise.view(noise.size(0), self.noise_dim, 1, 1)

                fake_images = Variable(self.generator(noise).data)
                outputs, _ = self.discriminator(fake_images)
                fake_loss = torch.mean(outputs)
                fake_loss.backward(one)

                # NOTE: PyTorch had a bug with gradient penalty at the time this project was developed.
                #       To try gradient penalty, uncomment the next two lines and remove the parameter clamping below.
                # gp = Utils.compute_GP(self.discriminator, right_images.data, right_embed, fake_images.data, LAMBDA=10)
                # gp.backward()

                d_loss = real_loss - fake_loss
                self.optimD.step()

                for p in self.discriminator.parameters():
                    p.data.clamp_(-0.01, 0.01)

            # Train the generator
            for p in self.discriminator.parameters():
                p.requires_grad = False

            self.generator.zero_grad()
            noise = Variable(torch.randn(right_images.size(0), 100)).cuda()
            noise = noise.view(noise.size(0), 100, 1, 1)
            fake_images = self.generator(noise)
            outputs, _ = self.discriminator(fake_images)

            g_loss = torch.mean(outputs)
            g_loss.backward(mone)
            g_loss = - g_loss
            self.optimG.step()

            gen_iteration += 1

            self.logger.draw(right_images, fake_images)
            self.logger.log_iteration_wgan(epoch, gen_iteration, d_loss, g_loss, real_loss, fake_loss)

        self.logger.plot_epoch(gen_iteration)

        if (epoch + 1) % 50 == 0:
            Utils.save_checkpoint(self.discriminator, self.generator, self.checkpoints_path, epoch)

def _train_vanilla_gan(self):
    criterion = nn.BCELoss()
    l2_loss = nn.MSELoss()
    l1_loss = nn.L1Loss()
    iteration = 0

    for epoch in range(self.num_epochs):
        for sample in self.data_loader:
            iteration += 1
            right_images = sample['right_images']

            right_images = Variable(right_images.float()).cuda()

            real_labels = torch.ones(right_images.size(0))
            fake_labels = torch.zeros(right_images.size(0))

            # ======== One sided label smoothing ==========
            # Helps prevent the discriminator from overpowering the
            # generator by adding a penalty when the discriminator is too confident
            # =============================================
            smoothed_real_labels = torch.FloatTensor(Utils.smooth_label(real_labels.numpy(), -0.1))

            real_labels = Variable(real_labels).cuda()
            smoothed_real_labels = Variable(smoothed_real_labels).cuda()
            fake_labels = Variable(fake_labels).cuda()


            # Train the discriminator
            self.discriminator.zero_grad()
            outputs, activation_real = self.discriminator(right_images)
            real_loss = criterion(outputs, smoothed_real_labels)
            real_score = outputs

            noise = Variable(torch.randn(right_images.size(0), 100)).cuda()
            noise = noise.view(noise.size(0), 100, 1, 1)
            fake_images = self.generator(noise)
            outputs, _ = self.discriminator(fake_images)
            fake_loss = criterion(outputs, fake_labels)
            fake_score = outputs

            d_loss = real_loss + fake_loss

            d_loss.backward()
            self.optimD.step()

            # Train the generator
            self.generator.zero_grad()
            noise = Variable(torch.randn(right_images.size(0), 100)).cuda()
            noise = noise.view(noise.size(0), 100, 1, 1)
            fake_images = self.generator(noise)
            outputs, activation_fake = self.discriminator(fake_images)
            _, activation_real = self.discriminator(right_images)

            activation_fake = torch.mean(activation_fake, 0)
            activation_real = torch.mean(activation_real, 0)

            # ======= Generator Loss function============
            # This is a customized loss function. The first term is the regular cross-entropy loss.
            # The second term is a feature-matching loss, which measures the distance between the real and
            # generated image statistics by comparing intermediate-layer activations.
            # The third term is the L1 distance between the generated and real images. This helps in the
            # conditional case because it links the embedding feature vector directly to certain pixel values.
            g_loss = criterion(outputs, real_labels) \
                     + self.l2_coef * l2_loss(activation_fake, activation_real.detach()) \
                     + self.l1_coef * l1_loss(fake_images, right_images)

            g_loss.backward()
            self.optimG.step()

            if iteration % 5 == 0:
                self.logger.log_iteration_gan(epoch, d_loss, g_loss, real_score, fake_score)
                self.logger.draw(right_images, fake_images)

        self.logger.plot_epoch_w_scores(iteration)

        if (epoch) % 50 == 0:
            Utils.save_checkpoint(self.discriminator, self.generator, self.checkpoints_path, epoch)

def predict(self):
    for sample in self.data_loader:
        right_images = sample['right_images']
        right_embed = sample['right_embed']
        txt = sample['txt']

        if not os.path.exists('results/{0}'.format(self.save_path)):
            os.makedirs('results/{0}'.format(self.save_path))

        right_images = Variable(right_images.float()).cuda()
        right_embed = Variable(right_embed.float()).cuda()

        # Generate images from the text embeddings
        noise = Variable(torch.randn(right_images.size(0), 100)).cuda()
        noise = noise.view(noise.size(0), 100, 1, 1)
        fake_images = self.generator(right_embed, noise)

        self.logger.draw(right_images, fake_images)

        for image, t in zip(fake_images, txt):
            im = Image.fromarray(image.data.mul_(127.5).add_(127.5).byte().permute(1, 2, 0).cpu().numpy())
            im.save('results/{0}/{1}.jpg'.format(self.save_path, t.replace("/", "")[:100]))
            print(t)

`

Manifold interpolation

In the original paper, the authors show that manifold interpolation gives a large improvement. Did you implement this part in your code?
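For context, the interpolation in the paper (the GAN-INT variant) mixes pairs of training text embeddings as beta * t1 + (1 - beta) * t2, with beta fixed at 0.5. Below is a minimal sketch of just that mixing step. It uses NumPy instead of the tensors in the trainer, and the name `interpolate_embeddings` and the choice to pair each row with a shuffled copy of the same batch are illustrative assumptions, not code from this repository.

```python
import numpy as np


def interpolate_embeddings(embeds, beta=0.5, rng=None):
    """Mix each embedding with another one drawn from the same batch.

    Hypothetical helper: returns beta * t1 + (1 - beta) * t2, where t2 is
    a shuffled copy of the batch. Also returns the permutation used.
    """
    rng = np.random.default_rng() if rng is None else rng
    perm = rng.permutation(embeds.shape[0])
    mixed = beta * embeds + (1.0 - beta) * embeds[perm]
    return mixed, perm


# Toy batch: 4 embeddings of dimension 3 (real text embeddings are much larger).
embeds = np.arange(12, dtype=np.float32).reshape(4, 3)
mixed, perm = interpolate_embeddings(embeds, beta=0.5, rng=np.random.default_rng(0))
```

In a training loop, the mixed embeddings would presumably be passed to the generator as an extra batch with no matching real image, so they would only contribute to the generator's adversarial term.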

batch image generation

Thank you for sharing the code. I noticed that the generator produces 64 images at a time from 64 different embeddings. However, when I copied embeddings:
for i in range(10):
    right_embed[i] = right_embed[0]
it still worked and generated normal images, with the first 10 images looking the same. But when most of the embeddings are identical (range(50)), the generator fails and produces weird results.
So, what should I do to generate a small number of images from the embeddings of 2-3 descriptions? I can't achieve this by simply copying these embeddings 64 times.
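A rough sketch of the batch construction being described: a few embeddings repeated round-robin to fill a batch, each row paired with fresh noise. The helper name `tile_embeddings` and the 1024-dimensional embedding size are assumptions for illustration; the noise shape follows the `predict` code above. One possible (untested) explanation for the weird results is that, if the generator contains batch normalization layers left in training mode, a batch of near-identical inputs changes the batch statistics. In that case, calling `self.generator.eval()` before generating would make those layers use their running statistics instead.

```python
import numpy as np


def tile_embeddings(embeds, batch_size=64):
    """Repeat a few embeddings round-robin until the batch is full (hypothetical helper)."""
    embeds = np.asarray(embeds, dtype=np.float32)
    reps = -(-batch_size // embeds.shape[0])  # ceiling division
    return np.tile(embeds, (reps, 1))[:batch_size]


# Three descriptions; the 1024-d embedding size is an assumption.
few = np.random.randn(3, 1024).astype(np.float32)
batch = tile_embeddings(few, batch_size=64)

# Fresh noise for every row, so repeated embeddings can still yield different images.
noise = np.random.randn(64, 100, 1, 1).astype(np.float32)
```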

flowers.hd5

I downloaded flowers.hdf5, but I am unsure how to proceed.
How do I obtain the following paths?

flowers_images_path: '/export/mlrg/aelnouby/projects/GANs/flowers_dataset/'
flowers_embedding_path: '/export/mlrg/aelnouby/projects/GANs/flowers_dataset/flowers_icml/'
flowers_text_path: '/export/mlrg/aelnouby/projects/GANs/flowers_dataset/cvpr2016_flowers/text_c10/'
flowers_dataset_path: '/scratch/aelnouby/text2image/flowers.hdf5'

flowers_val_split_path: '/export/mlrg/aelnouby/projects/GANs/flowers_dataset/flowers_icml/valclasses.txt'
flowers_train_split_path: '/export/mlrg/aelnouby/projects/GANs/flowers_dataset/flowers_icml/trainclasses.txt'
flowers_test_split_path: '/export/mlrg/aelnouby/projects/GANs/flowers_dataset/flowers_icml/testclasses.txt'

Instance as first argument

  File "trainer.py", line 32, in __init__
    self.discriminator.apply(Utils.weights_init)
  File "/home/aara/anaconda2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 136, in apply
    module.apply(fn)
  File "/home/aara/anaconda2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 136, in apply
    module.apply(fn)
  File "/home/aara/anaconda2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 136, in apply
    module.apply(fn)
  File "/home/aara/anaconda2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 137, in apply
    fn(self)
TypeError: unbound method weights_init() must be called with Utils instance as first argument (got Conv2d instance instead)
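This error is specific to Python 2, where a plain function defined in a class body and accessed through the class becomes an unbound method that checks the type of its first argument. A minimal sketch of one likely fix, assuming `weights_init` is defined in `Utils` without a decorator: mark it as a `@staticmethod` (running under Python 3 also avoids the check). The `Conv2d` class and the body of `weights_init` below are stand-ins for illustration, not the repository's code.

```python
class Utils(object):
    @staticmethod
    def weights_init(m):
        # Hypothetical body: flag layers whose class name contains "Conv".
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            m.initialized = True


class Conv2d(object):
    """Stand-in for torch.nn.Conv2d so the sketch runs without PyTorch."""
    initialized = False


layer = Conv2d()
# With @staticmethod, calling through the class works in Python 2 and 3.
Utils.weights_init(layer)
```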
