
pytorch-progressive_growing_of_gans's People

Contributors

github-pengge, yuanzhaoyz


pytorch-progressive_growing_of_gans's Issues

error: salt must be bytes

Hi, thanks for your work.
I ran into the error "salt must be bytes".
What is the cause? Have you encountered it?

No longer other-dataset friendly

The new code only works with the HDF5 dataset. Would it be possible to have a more abstract data loader, with the option of using either h5 files or raw image files? I was using the old version with a different dataset.

Bug related to some objects using/not using GPU

When trying to train from scratch, I encounter the following error:
x = self.scale * x
RuntimeError: Expected object of type torch.FloatTensor but found type torch.cuda.FloatTensor for argument #2 'other'

I believe it comes from one tensor being on the GPU while the other is on the CPU, so the fix is either to move one of them to the GPU or the other to the CPU. Yet such a change might trigger a chain of further changes.
I would like to know whether you have had this problem before, and how I should handle it.
Thanks a lot.

torch0.3 py=3.6 RuntimeError

Traceback (most recent call last):
  File "train.py", line 365, in <module>
    pggan.train()
  File "train.py", line 286, in train
    self.train_phase(R, phase, batch_size, _range[0]*batch_size, _range[0], _range[1])
  File "train.py", line 240, in train_phase
    self.forward_D(cur_level, detach=True)
  File "train.py", line 196, in forward_D
    self.d_real = self.D(self.real, cur_level=cur_level, gdrop_strength=strength)
  File "/home/jurh/anaconda2/envs/th03/lib/python3.6/site-packages/torch/nn/modules/module.py", line 325, in __call__
    result = self.forward(*input, **kwargs)
  File "models/model.py", line 218, in forward
    return self.output_layer(x, y, cur_level, insert_y_at)
  File "/home/jurh/anaconda2/envs/th03/lib/python3.6/site-packages/torch/nn/modules/module.py", line 325, in __call__
    result = self.forward(*input, **kwargs)
  File "models/base_model.py", line 280, in forward
    x = self.chain[max_level](x)
  File "/home/jurh/anaconda2/envs/th03/lib/python3.6/site-packages/torch/nn/modules/module.py", line 325, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/jurh/anaconda2/envs/th03/lib/python3.6/site-packages/torch/nn/modules/container.py", line 67, in forward
    input = module(input)
  File "/home/jurh/anaconda2/envs/th03/lib/python3.6/site-packages/torch/nn/modules/module.py", line 325, in __call__
    result = self.forward(*input, **kwargs)
  File "models/base_model.py", line 74, in forward
    vals = torch.mean(vals, keepdim=True)
RuntimeError: mean() missing 1 required positional arguments: "dim"
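On PyTorch 0.3 the call fails because torch.mean() there requires an explicit dim. One possible workaround, sketched here in NumPy rather than PyTorch (the helper name mean_all_keepdim is hypothetical), is to reduce one axis at a time while keeping each reduced axis; because every slice along an axis has the same length, the mean of means equals the global mean.

```python
import numpy as np

def mean_all_keepdim(vals):
    """Global mean of `vals`, returned with every axis kept at size 1."""
    for axis in range(vals.ndim):
        # Each step averages equal-length slices, so chaining the steps gives the global mean
        vals = np.mean(vals, axis=axis, keepdims=True)
    return vals

x = np.arange(24, dtype=np.float64).reshape(1, 2, 3, 4)
m = mean_all_keepdim(x)
print(m.shape, m.item())
```

In PyTorch the same loop would presumably call vals.mean(axis, keepdim=True) for each axis, assuming the installed version accepts keepdim together with an explicit dim.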

Possible bug in `he_init` function

It seems that in the current version of master, the he_init function passes gain as the argument a.
But according to the PyTorch code in both version 1.0 and 0.4.1, kaiming_normal_() expects a to be the negative slope and takes the nonlinearity as a separate argument. So the following call should suffice:

kaiming_normal_(layer.weight, nonlinearity=nonlinearity, a=param)

PyTorch source code link: https://github.com/pytorch/pytorch/blob/v0.4.1/torch/nn/init.py#L296

I have tested it on some layers, and the call above gives the expected value for layer.weight.std().

Your code might have been based on a different version of PyTorch that expected the gain to be passed separately, but I thought I should give you a heads-up just in case.
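To see why passing the gain as a changes the result, here is a small plain-Python sketch of the documented formulas: for kaiming_normal_ with the default mode='fan_in' and nonlinearity='leaky_relu', the standard deviation is gain / sqrt(fan_in), where calculate_gain('leaky_relu', a) = sqrt(2 / (1 + a^2)). The helper names and the 512-channel 3x3 layer are illustrative assumptions, not taken from the repo.

```python
import math

def leaky_relu_gain(a):
    # Gain used by torch.nn.init for nonlinearity='leaky_relu' with negative slope a
    return math.sqrt(2.0 / (1 + a ** 2))

def kaiming_normal_std(fan_in, a=0.0):
    # Std of kaiming_normal_(w, a=a) with mode='fan_in', nonlinearity='leaky_relu'
    return leaky_relu_gain(a) / math.sqrt(fan_in)

fan_in = 512 * 3 * 3                 # e.g. a 3x3 conv with 512 input channels
intended_gain = math.sqrt(2)         # He gain for ReLU
intended = intended_gain / math.sqrt(fan_in)
correct_call = kaiming_normal_std(fan_in, a=0.0)              # slope 0 gives gain sqrt(2)
gain_as_slope = kaiming_normal_std(fan_in, a=intended_gain)   # gain misread as the slope
print(intended, correct_call, gain_as_slope)
```

With the gain misread as the slope, the effective gain drops from sqrt(2) to sqrt(2/3), so the weights come out with about 58% of the intended standard deviation.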

Might be a bug during fade in

These lines

        for it in range(from_it, total_it):
            if phase == 'stabilize':
                cur_level = R
            else:
                cur_level = R + total_it/float(from_it)
            cur_resol = 2 ** int(np.ceil(cur_level+1))

are problematic.

That's because, during fade-in, from_it = (train_kimg)//batch_size and total_it = (train_kimg+total_kimg)//batch_size. Neither depends on it, so according to your code cur_level will always be R+1.99. But it should increase progressively over [R, R+1].

It should be

        for it in range(from_it, total_it):
            if phase == 'stabilize':
                cur_level = R
            else:
                cur_level = R + (it - from_it) / float(total_it - from_it)
            cur_resol = 2 ** int(np.ceil(cur_level+1))

Please correct me if my understanding is wrong.
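A plain-Python sketch comparing the two expressions over one fade-in phase (the function names and the numbers R=3, from_it=100, total_it=200 are arbitrary illustrations):

```python
def fade_in_levels(R, from_it, total_it):
    """Corrected schedule: cur_level rises linearly from R towards R + 1 over the phase."""
    levels = []
    for it in range(from_it, total_it):
        # Fraction of the fade-in phase completed so far, in [0, 1)
        levels.append(R + (it - from_it) / float(total_it - from_it))
    return levels

def original_level(R, from_it, total_it):
    # Expression from the issue: it never depends on `it`, so it is the same on every iteration
    return R + total_it / float(from_it)

levels = fade_in_levels(R=3, from_it=100, total_it=200)
print(levels[0], levels[50], levels[-1])          # rises from 3.0 towards 4.0
print(original_level(R=3, from_it=100, total_it=200))
```

The corrected version gives one distinct, increasing level per iteration, while the original expression returns a single constant for the whole phase.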

resize_activation function

Hi, you say you replaced repeat with torch.nn.functional.upsample, but in the original author's source code it seems that repeat is used. Have you noticed this?
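Whether the swap matters depends on the upsampling mode. Assuming the upsample call uses mode='nearest' with an integer scale factor, each output pixel (i, j) copies input pixel (i // factor, j // factor), which is exactly what repeating rows and columns produces. A NumPy sketch (NCHW layout; both function names are hypothetical) that checks the equivalence:

```python
import numpy as np

def upsample_repeat(x, factor=2):
    # Repeat each pixel `factor` times along height (axis 2) and width (axis 3)
    return np.repeat(np.repeat(x, factor, axis=2), factor, axis=3)

def upsample_nearest_index(x, factor=2):
    # Nearest-neighbour lookup: output pixel (i, j) reads input pixel (i // factor, j // factor)
    n, c, h, w = x.shape
    rows = np.arange(h * factor) // factor
    cols = np.arange(w * factor) // factor
    return x[:, :, rows[:, None], cols[None, :]]

x = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
print(np.array_equal(upsample_repeat(x), upsample_nearest_index(x)))
```

With a smoothing mode such as mode='bilinear' the two would differ, so this equivalence only holds for nearest-neighbour upsampling.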

About GDropLayer?

In tkarras's (the original author's) code, I found that the default option for GDropLayer is False.
I just wonder whether you found it helpful?

incompatibilities with python2

In addition to the print statement, there are the following incompatibilities with Python 2:

  1. The copy in base_model.py will not work: shape.copy() is not available for lists in Python 2. Use
    target_shape = shape[:]
    or something similar.
  2. Imports need __init__.py in Python versions earlier than 3.3, so import models.<> or import utils.<> will not work. An alternative is to add sys.path.append('utils') and sys.path.append('models'), as is already done, and then
    from data import CelebA, RandomNoiseGenerator
    from model import Generator, Discriminator
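A minimal sketch of a shape copy that behaves the same on Python 2 and 3 (the helper name copy_shape is illustrative). For the print statement, adding from __future__ import print_function at the top of each module is the usual fix.

```python
def copy_shape(shape):
    # list.copy() exists only on Python 3.3+; list(shape) and shape[:] work on both 2 and 3
    return list(shape)

shape = [16, 512, 4, 4]
target_shape = copy_shape(shape)
target_shape[1] = 1          # modifying the copy leaves the original untouched
print(shape, target_shape)
```

list(shape) also accepts tuples, so it works whether the shape arrives as a list or as a tuple.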

Did you finish using PG-GAN in CycleGAN?

I read your article on PaperWeekly, and I have the same idea as you: I want to use this GAN in CycleGAN to improve the image quality. But if you have already done this, I may need to find a different way to improve it. Could you email me and tell me about your progress on this work? My email is [email protected]. Of course, you could also add me on QQ. Waiting for your answer, and thank you so much!

Why update_lr()?

I was confused by the use of self.update_lr(cur_nimg) in train.py (train.py#L281), which may produce a negative learning-rate coefficient during training.

I have no idea why we should use this function: it just makes the learning rate fluctuate.

Green Image

I trained the network with the following settings:

python train.py --gpu 0 --train_kimg 600 --transition_kimg 600 --beta1 0 --beta2 0.99 --gan lsgan --first_resol 4 --target_resol 256 --no_tanh --exp_dir /exp

I am using the new CelebA-HQ dataset from NVIDIA,

and I got the following stabilization image after 75000 steps:
[image: 128x128-stabilize-074999]

The results for all samples (at various sizes) are similar. I am using Python 2.7.

Is anyone else getting similar samples?

Edit: I also tried Python 3, using train_no_tanh. The image is still the same; nothing changes.

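Is it finished?

Your blog post said you were still working on the PyTorch implementation. Is the current version the final one? Does it differ from the implementation in the paper?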

How to generate PNG images?

Hi, thanks for your work. I used your code, but it only created .h5 files.
I want to get an image dataset, so how do I generate the images?

Do you have pre-trained models?

I keep running into bugs during training. Do you have pre-trained models that can reproduce the results shown in the samples?

TypeError: float() argument must be a string or a number when I tried to create the Celeba-HQ

I tried to create the CelebA-HQ dataset from the original dataset and the .dat file,

but the following error occurs:

(202599L, 5L, 2L)
Loading CelebA-HQ deltas from ./celeba-hq/Delta
Traceback (most recent call last):
  File "h5tool.py", line 708, in <module>
    execute_cmdline(sys.argv)
  File "h5tool.py", line 703, in execute_cmdline
    func(**vars(args))
  File "h5tool.py", line 609, in create_celeba_hq
    aidx, aimg64, aimg128, aimg256, aimg512, aimg1024 = process_func(x)
  File "h5tool.py", line 556, in process_func
    img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
TypeError: float() argument must be a string or a number

Does anyone have any idea how to solve this?

Some questions about MinibatchStatConcatLayer

I have some questions about MinibatchStatConcatLayer. They are mostly about the shape of vals in the code.
Let's say the shape of the input x is [b, c, h, w].

After this line:

vals = self.adjusted_std(x, dim=0, keepdim=True)# per activation, over minibatch dim

the shape of vals should be [1, c, h, w].

  1. For the case of "all", we should get vals of shape [1, 1, 1, 1], since it is meant to "average everything --> 1 value per minibatch". However, this line

    vals = torch.mean(vals, dim=1, keepdim=True)#vals = torch.mean(vals, keepdim=True)

    outputs vals of shape [1, 1, h, w].
    I think we should use vals = torch.mean(vals, keepdim=True) instead, which you have commented out for an unknown reason.

  2. What is the purpose of this line?

    target_shape = [target_shape[0]] + [s for s in target_shape[1:]]

    It seems equivalent to target_shape = target_shape, so we still get [b, c, h, w].
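To check the shapes discussed above, here is a NumPy sketch of a minibatch-stddev concat layer. adjusted_std is a hypothetical stand-in (square root of the variance over the minibatch plus a small epsilon), not the repo's implementation, and averaging='all' follows the description quoted in point 1.

```python
import numpy as np

def adjusted_std(x, axis=0, keepdims=True, eps=1e-8):
    # Hypothetical stand-in: square root of the variance over `axis`, plus a small epsilon
    mean = np.mean(x, axis=axis, keepdims=True)
    return np.sqrt(np.mean((x - mean) ** 2, axis=axis, keepdims=keepdims) + eps)

def minibatch_stat_concat(x, averaging="all"):
    b, c, h, w = x.shape
    vals = adjusted_std(x, axis=0, keepdims=True)    # (1, c, h, w): per activation, over the minibatch
    if averaging == "all":
        vals = np.mean(vals, keepdims=True)           # (1, 1, 1, 1): one value per minibatch
    else:
        vals = np.mean(vals, axis=1, keepdims=True)   # (1, 1, h, w): what dim=1 produces
    vals = np.broadcast_to(vals, (b, 1, h, w))        # tile over samples and spatial positions
    return np.concatenate([x, vals], axis=1)          # (b, c + 1, h, w)

x = np.random.RandomState(0).randn(4, 3, 8, 8)
out_all = minibatch_stat_concat(x, "all")
out_dim1 = minibatch_stat_concat(x, "dim1")
print(out_all.shape, out_dim1.shape)
```

Both variants end up with shape (b, c + 1, h, w) after tiling, which is why the difference is easy to miss: with "all" the extra channel is one constant, while the dim=1 version varies across spatial positions.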

output sample images are incorrectly normalized

Some of the sample images look washed out. I suspect that the min/max pixel values of the real samples differ from those of the generated ones, and the range fluctuates wildly from one output to the next. This makes it hard to judge quality during training most of the time.
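One way to make samples comparable is to convert every image, real or generated, with the same fixed range instead of each image's own min/max. A NumPy sketch, assuming both are meant to lie in [-1, 1] (as with a tanh output); since --no_tanh lets generator outputs leave that range, values are clipped. Both function names are hypothetical.

```python
import numpy as np

def to_uint8_fixed(img, lo=-1.0, hi=1.0):
    # Same mapping for every image: lo -> 0, hi -> 255, values outside [lo, hi] are clipped
    scaled = (np.asarray(img, dtype=np.float64) - lo) / (hi - lo)
    return (np.clip(scaled, 0.0, 1.0) * 255.0 + 0.5).astype(np.uint8)

def to_uint8_minmax(img):
    # Per-image stretch: brightness depends on each image's own extremes
    img = np.asarray(img, dtype=np.float64)
    scaled = (img - img.min()) / max(img.max() - img.min(), 1e-8)
    return (scaled * 255.0 + 0.5).astype(np.uint8)

dim_image = np.array([-0.2, 0.0, 0.2])     # a low-contrast output
print(to_uint8_fixed(dim_image), to_uint8_minmax(dim_image))
```

The min/max version stretches the low-contrast image to the full 0 to 255 range, so its apparent contrast no longer reflects what the generator produced; the fixed mapping keeps it dim.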

download CelebA dataset and deltas files

I can't download the original CelebA dataset or the additional delta files because of access permissions. Please tell me how to solve this problem. Thank you!

Where to change the weight of the layers?

Sorry to bother you. Here is the code:
min_level_weight, max_level_weight = int(cur_level+1)-cur_level, cur_level-int(cur_level)

I wonder whether cur_level needs to change during the iterations, so that max_level_weight grows as the iteration count grows?
But in the code it is set as follows and never changes:

phases = {'stabilize':[0, train_kimg//batch_size], 'fade_in':[train_kimg//batch_size+1, (transition_kimg+train_kimg)//batch_size]}
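The two weights only change if cur_level changes from one iteration to the next. The phases dictionary fixes the iteration range of each phase, while in the fade-in loop quoted earlier on this page cur_level is computed inside the per-iteration loop. Evaluating the quoted expressions directly (level_weights is a hypothetical helper name) shows how they blend:

```python
def level_weights(cur_level):
    # The two expressions quoted in the issue
    min_level_weight = int(cur_level + 1) - cur_level
    max_level_weight = cur_level - int(cur_level)
    return min_level_weight, max_level_weight

for cur_level in [2.0, 2.25, 2.5, 2.75]:
    print(cur_level, level_weights(cur_level))
```

As cur_level moves from R towards R + 1, max_level_weight grows from 0 towards 1 and the two weights always sum to 1; if cur_level stays fixed, so do the weights.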

About the loss function of LSGAN

Your code for D_loss:

d_adv_loss = self.compute_adv_loss(self.d_real, True, 0.5) + self.compute_adv_loss(self.d_fake, False, 0.5)

The official code for D_loss:

if type == 'lsgan':
    G_loss = L2(fake_scores_out, 0)
    D_loss = L2(real_scores_out, 0) + L2(fake_scores_out, 1) * L2_fake_weight

The value of L2_fake_weight is 0.1.

Have you observed this?
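For reference, the quoted official loss can be written out in NumPy, assuming L2(x, t) means the mean squared difference between the scores x and a constant target t (the helper names are hypothetical). The official code uses 0 as the target for real scores and 1 for fake scores, and scales only the fake term by L2_fake_weight:

```python
import numpy as np

def l2(scores, target):
    # Assumed meaning of L2 in the official code: mean squared error to a constant target
    return np.mean((np.asarray(scores, dtype=np.float64) - target) ** 2)

def official_lsgan_losses(real_scores_out, fake_scores_out, l2_fake_weight=0.1):
    g_loss = l2(fake_scores_out, 0)
    d_loss = l2(real_scores_out, 0) + l2(fake_scores_out, 1) * l2_fake_weight
    return g_loss, d_loss

# Fake scores sitting at the "real" target: G loss is 0, D pays only the down-weighted fake term
g, d = official_lsgan_losses(real_scores_out=[0.0, 0.0], fake_scores_out=[0.0, 0.0])
print(g, d)
```

With equal weights of 0.5 on both terms, as in the repo line above, the fake term would count five times as much relative to the real term as it does with L2_fake_weight = 0.1.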

WGAN-GP Loss

Have you tried the WGAN-GP loss? I am having trouble fading in a new layer with the WGAN-GP loss: it makes the loss go crazy and results in mode collapse.

About LayerNormLayer

In the forward function, x = torch.inverse(...) is used. But torch.inverse computes a matrix inverse, not an elementwise reciprocal:

tensor = torch.FloatTensor([[1,2],[3,4]])
t_out = torch.inverse(tensor)
t_o = 1/tensor
print(tensor)
print(t_out)
print(t_o)
Output:
1 2
3 4
[torch.FloatTensor of size 2x2]

-2.0000 1.0000
1.5000 -0.5000
[torch.FloatTensor of size 2x2]

1.0000 0.5000
0.3333 0.2500

So what we need is t_o: in Theano, T.inv(x) is the elementwise 1/x, as mentioned in the paper.
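The difference is easy to reproduce in NumPy: np.linalg.inv plays the role of torch.inverse, and np.reciprocal the role of 1/tensor (or torch.reciprocal). The pixel_norm helper is a hypothetical illustration, assuming the layer is meant to divide each pixel's feature vector by its root mean square over channels, as in the paper's pixelwise feature normalization.

```python
import numpy as np

t = np.array([[1.0, 2.0], [3.0, 4.0]])
matrix_inverse = np.linalg.inv(t)   # like torch.inverse: the matrix m with t @ m = I
elementwise = np.reciprocal(t)      # like 1 / tensor, torch.reciprocal or Theano's T.inv

def pixel_norm(x, eps=1e-8):
    # Scale each pixel's feature vector (axis 1 of NCHW) to unit mean square,
    # using an elementwise reciprocal of the root mean square
    return x * np.reciprocal(np.sqrt(np.mean(x ** 2, axis=1, keepdims=True) + eps))

print(matrix_inverse)
print(elementwise)
```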

batch sizes don't match paper

Why are these default batch sizes used? The original paper uses 16 for resolutions from 4x4 to 128x128, which should be faster overall than what is currently used.

Error when generating the dataset

Hello,
Thanks for the great work! However, I'm running into some errors when generating the image dataset. One of them occurs when I execute the command python2 h5tool.py create_celeba_hq file_name_to_save /path/to/celeba_dataset/ /path/to/celeba_hq_deltas: it shows the error message below. How should I solve this? Thank you so much.


Traceback (most recent call last):
  File "h5tool.py", line 697, in <module>
    execute_cmdline(sys.argv)
  File "h5tool.py", line 692, in execute_cmdline
    func(**vars(args))
  File "h5tool.py", line 596, in create_celeba_hq
    for idx, img in pool.process_items_concurrently(fields['idx'], process_func=process_func, max_items_in_flight=num_tasks):
  File "h5tool.py", line 161, in process_items_concurrently
    for res in retire_result(): yield res
  File "h5tool.py", line 149, in retire_result
    processed, (prepared, idx) = self.get_result(task_func)
  File "h5tool.py", line 126, in get_result
    raise Exception('%s, %s' % (result.type, result.value))
Exception: <type 'exceptions.ValueError'>, Unable to create correctly shaped tuple from ((420, 1730), (420, 1906), (0, 0))
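The message means np.pad received three (before, after) pairs for an array that does not have three dimensions. That can happen, for example, if an image reaches this step as a 2D grayscale array, or if the conversion to a NumPy array did not produce an H x W x C array at all; which of these applies here is not clear from the traceback. A defensive NumPy sketch of the padding step (the function name is hypothetical; the pad index order is copied from the np.pad call in h5tool.py quoted in the CelebA-HQ TypeError issue above):

```python
import numpy as np

def reflect_pad_hwc(img, pad):
    """Reflect-pad an image; pad = (left, top, right, bottom), matching the h5tool.py np.pad call."""
    img = np.asarray(img, dtype=np.float32)
    if img.ndim == 2:
        # Grayscale: replicate to three channels so the three pad pairs match the array's dimensions
        img = np.stack([img] * 3, axis=-1)
    if img.ndim != 3:
        raise ValueError("expected an H x W x C image, got shape %r" % (img.shape,))
    return np.pad(img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')

rgb = np.zeros((4, 5, 3))
gray = np.zeros((4, 5))
print(reflect_pad_hwc(rgb, (1, 2, 3, 1)).shape)   # height 4+2+1, width 5+1+3
print(reflect_pad_hwc(gray, (1, 1, 1, 1)).shape)
```

Checking img.ndim before padding turns the opaque "correctly shaped tuple" error into a message that shows the actual shape of the offending image.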
