style-based-gan-pytorch's Introduction

Style-Based GAN in PyTorch

Update (2019/09/01)

I found bugs in the implementation, thanks to @adambielski and @TropComplique (#33, #34). These have been fixed and the checkpoints updated.

Update (2019/07/04)
  • The trainer now uses a pre-resized lmdb dataset for more stable data loading and training.
  • The model architecture now matches the official implementation more closely.

Implementation of A Style-Based Generator Architecture for Generative Adversarial Networks (https://arxiv.org/abs/1812.04948) in PyTorch

Usage:

First, prepare an lmdb dataset:

python prepare_data.py --out LMDB_PATH --n_worker N_WORKER DATASET_PATH

This converts the images to JPEG and pre-resizes them to each training resolution (for example, 8/16/32/64/128/256/512/1024). Then you can train StyleGAN.

For CelebA:

python train.py --mixing LMDB_PATH

For FFHQ:

python train.py --mixing --loss r1 --sched LMDB_PATH

Resolution | Model & Optimizer
256px      | Link
512px      | Link
1024px     | Link

Model & optimizer checkpoints are saved at the end of each resolution's training phase (that is, the 512px checkpoint is saved at the end of 512px training).
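
To resume from one of these files, a minimal sketch (assuming the checkpoint keys match the torch.save call quoted in the issues below, and that the models and optimizers are already constructed as in train.py):

    import torch

    # Hedged sketch: restore the generator/discriminator and their optimizers from a
    # "Model & Optimizer" checkpoint before continuing training. The path is
    # hypothetical; train.py also exposes a --ckpt argument for this purpose.
    ckpt = torch.load('checkpoint/stylegan-512px.model', map_location='cpu')
    generator.load_state_dict(ckpt['generator'])
    discriminator.load_state_dict(ckpt['discriminator'])
    g_optimizer.load_state_dict(ckpt['g_optimizer'])
    d_optimizer.load_state_dict(ckpt['d_optimizer'])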

Sample

[Sample of the model trained on FFHQ] [Style-mixing sample of the model trained on FFHQ]

512px sample from the generator trained on FFHQ.

Old Checkpoints

Resolution | Model & Optimizer | Running average of generator
128px      | Link              | 100k iter: Link
256px      | Link              | 140k iter: Link
512px      | Link              | 180k iter: Link

These checkpoints are from an older version of the code. Because the gradient penalty and discriminator activations have changed, it is better to use the new checkpoints for further training. You can still use these checkpoints for sampling, as the generator architecture is unchanged.

The running average of the generator is saved at the iterations listed above, so the two files for each resolution come from different iterations. (Yes, this is my mistake.)

style-based-gan-pytorch's People

Contributors

andreasjansson, michaelmonashev, penguinbing, rosinality, vinayakarannil


style-based-gan-pytorch's Issues

GPU requirements

Which graphics card do you use? I reach about 67% of training and then it crashes with 'CUDA out of memory'.

Question regarding the difference in fade in with the official implementation

As stated in the original ProGAN paper, they first stabilize a resolution (e.g. 4x4) for 800,000 images, and then fade in 8x8 over another 800,000 images, and so on. In the current implementation, it seems that fading in happens even during the stabilizing stage. Is this correct?

PS: please refer to the ProGAN paper, Appendix A.1.
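
For reference, a minimal sketch of the schedule described in ProGAN Appendix A.1 (illustrative constants and names, not this repo's actual code): each new resolution is first faded in over a fixed number of images and then trained with alpha fixed at 1 for the same number of images.

    # Illustrative ProGAN-style schedule: fade in a new resolution over `phase`
    # images (alpha: 0 -> 1), then stabilize it (alpha = 1) for another `phase`.
    PHASE = 800_000  # images per phase, as in the ProGAN paper

    def alpha_for(images_seen_at_resolution, phase=PHASE):
        if images_seen_at_resolution < phase:      # fade-in stage
            return images_seen_at_resolution / phase
        return 1.0                                 # stabilizing stage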

Continue train

Is there a simple way to continue training? Is it possible without saving/loading the discriminator state?

Default loss function

I noticed that the loss function is set to wgan-gp by default; however, in the official implementation the loss is r1 by default. Is that correct? If so, why did you choose it?
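
For context, a hedged sketch of the R1 regularizer that the --loss r1 option refers to (non-saturating loss with a gradient penalty on real samples only, from Mescheder et al. 2018); real_image, discriminator, step, and alpha are assumed to come from train.py, and gamma = 10 is just a common choice:

    import torch.nn.functional as F
    from torch import autograd

    # Illustrative R1 penalty: squared gradient norm of D's output w.r.t. real images.
    real_image.requires_grad_(True)
    real_pred = discriminator(real_image, step=step, alpha=alpha)
    (grad_real,) = autograd.grad(real_pred.sum(), real_image, create_graph=True)
    r1_penalty = grad_real.pow(2).reshape(grad_real.size(0), -1).sum(1).mean()

    # Non-saturating logistic loss for real samples plus the weighted penalty.
    d_loss_real = F.softplus(-real_pred).mean() + (10 / 2) * r1_penalty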

discriminator discrepancy with proGAN paper

I'm trying to compare the discriminator's downsampling procedure to Figure 2(b) of the ProGAN paper (specifically the case where 0 < alpha < 1). Here is the relevant code:

    def forward(self, input, step=0, alpha=-1):
        for i in range(step, -1, -1):
            index = self.n_layer - i - 1

            if i == step:
                out = self.from_rgb[index](input)

            if i == 0:
                out_std = torch.sqrt(out.var(0, unbiased=False) + 1e-8)
                mean_std = out_std.mean()
                mean_std = mean_std.expand(out.size(0), 1, 4, 4)
                out = torch.cat([out, mean_std], 1)

            out = self.progression[index](out)

            if i > 0:
                # out = F.avg_pool2d(out, 2)
                out = F.interpolate(
                    out, scale_factor=0.5, mode='bilinear', align_corners=False
                )

                if i == step and 0 <= alpha < 1:
                    # skip_rgb = F.avg_pool2d(input, 2)
                    skip_rgb = self.from_rgb[index + 1](input)
                    skip_rgb = F.interpolate(
                        skip_rgb, scale_factor=0.5, mode='bilinear', align_corners=False
                    )

                    out = (1 - alpha) * skip_rgb + alpha * out

        out = out.squeeze(2).squeeze(2)
        # print(input.size(), out.size(), step)
        out = self.linear(out)

        return out

No matter how many times I look at this, I feel like the lines

                    skip_rgb = self.from_rgb[index + 1](input)
                    skip_rgb = F.interpolate(
                        skip_rgb, scale_factor=0.5, mode='bilinear', align_corners=False
                    )

should be swapped. In the paper, the pathway that gets multiplied by (1 - alpha) takes the input image, downsamples it, and then applies a from_rgb layer. However, in this code the input image is first passed through the from_rgb layer associated with the previous step, and the result is then downsampled. Am I understanding this correctly? If so, was it intentional?
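
For comparison, a hedged sketch of the ordering shown in the ProGAN figure: downsample the raw RGB input first, then apply the lower-resolution from_rgb. If the from_rgb layers are 1x1 convolutions (as in the official architecture), the two orderings are mathematically equivalent, because a pointwise linear layer commutes with the spatial averaging performed by the downsampling; this rearrangement is illustrative only:

    # Paper ordering (Fig. 2(b)): downsample the RGB input, then from_rgb.
    skip_rgb = F.interpolate(
        input, scale_factor=0.5, mode='bilinear', align_corners=False
    )
    skip_rgb = self.from_rgb[index + 1](skip_rgb)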

Pretrained model?

Can you / will you upload any pretrained model to compare results, please?

How to test the trained model

Thank you for your work; I am a newcomer trying to learn about it. I noticed that the README only provides the command for training the model. How can I test the model once training is complete?

In addition, could you provide your system environment, such as the PyTorch version? Thank you very much.
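
For reference (inferred from commands quoted in other issues below, so treat the exact invocation as approximate), sampling from a trained checkpoint is done with generate.py, e.g.:

python generate.py checkpoint/pretrained/stylegan-256px-new.model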

[-1, 1] normalization on fake_image?

Hi, thanks for this great repo and I have a question on a small detail.

This may be minor, but I found that each pixel value of real_image is normalized to the dynamic range [-1, 1] while fake_image is not. I also couldn't find any operation handling this (e.g. F.tanh or nn.Tanh). Is this what you intended?

How to train on my own dataset

It's me again. If I want to train on my own dataset, how do I do that?
Your code supports the CelebA dataset through PyTorch's datasets.ImageFolder API.
My dataset follows the same datasets.ImageFolder layout, like:

root/class1/xxx.jpg
root/class1/yy.jpg
root/class1/zz.jpg

root/class2/xxx.jpg
root/class2/yy.jpg
root/class2/zz.jpg

Can I train with this layout? Also, is multi-GPU training already supported?
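
For what it's worth, the workflow in the README does not use the class subfolders once the lmdb is built; a hedged example with hypothetical paths (my_images/ laid out as above, my_lmdb as the output) would be:

python prepare_data.py --out my_lmdb --n_worker 8 my_images
python train.py --mixing my_lmdb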

Grayscale input

Hi,

I am trying to use your code to train the generator and discriminator on grayscale inputs, but I can't figure out which values need to change from 3 to 1; everywhere I tried, I later got a shape-mismatch error.
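
As a hedged pointer (module names inferred from the code quoted in other issues, so verify against model.py): the channel count 3 appears in the discriminator's from_rgb convolutions, the generator's to_rgb convolutions, and in the data pipeline. A minimal torchvision sketch for feeding 1-channel data, illustrative only:

    from torchvision import transforms

    # Illustrative transform for grayscale training data; the model's
    # from_rgb / to_rgb layers would also need 1 input/output channel.
    transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),  # one mean/std entry per channel
    ])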

Checkpoint files for new versions

In the repo's issue discussions I found a model checkpoint for the previous version of this repo, which seems incompatible with the current version. Is there any chance of getting newly trained models, especially in light of the new high-quality results on FFHQ?

Or maybe I am just doing something wrong and the old checkpoint is still usable?

How to control the style we want to generate?

Hi, I looked at the train.py code and it seems the label tensor is not used anywhere.

If I want to generate something like this image, how should I write the code?

The generator input is only gen_in1 or gen_in2, so how can I control the style that I want to generate?
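
A hedged sketch of how mixing appears to work here, based on the gen_in1 = [code1, code2] pattern quoted from train.py in a later issue; the generator call signature (step, alpha) is assumed, and code_size comes from train.py:

    import torch

    # Illustrative: a list of two latent codes asks the generator to use the
    # first code for coarse layers and the second for finer layers.
    z_a = torch.randn(1, code_size, device='cuda')
    z_b = torch.randn(1, code_size, device='cuda')
    mixed_image = generator([z_a, z_b], step=step, alpha=1)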

Slow training on multiple GPUs

1 GPU gives 1.55 it/s:

> CUDA_VISIBLE_DEVICES=0 python3 train.py --mixing --loss r1 --sched train
Size: 8; G: 7.498; D: 0.792; Grad: 1.636; Alpha: 1.00000:   0%|                                                          | 34/30000 [00:21<5:22:53,  1.55it/s]

8 GPUs are about 5 times slower, giving only 3.19 s/it.

> CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 train.py --mixing --loss r1 --sched train
Size: 8; G: 3.594; D: 0.328; Grad: 0.440; Alpha: 1.00000:   0%|                                                         | 29/30000 [01:59<26:34:16,  3.19s/it]

Missing keys when generating new samples

I downloaded the generator parameters you provided here: #3 (comment). But I keep getting the following missing key(s) error:

(base) Rahuls-MacBook-Pro:style-based-gan-pytorch rahulbhalley$ python generate.py 
Traceback (most recent call last):
  File "generate.py", line 9, in <module>
    generator.load_state_dict(torch.load('checkpoint/style-gan-600k.model', map_location=device))
  File "/Users/rahulbhalley/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 777, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for StyledGenerator:
	Missing key(s) in state_dict: "generator.progression.1.conv1.1.conv.bias", "generator.progression.1.conv1.1.conv.weight_orig", "generator.progression.1.conv1.2.weight", "generator.progression.1.conv1.2.weight_flip", "generator.progression.2.conv1.1.conv.bias", "generator.progression.2.conv1.1.conv.weight_orig", "generator.progression.2.conv1.2.weight", "generator.progression.2.conv1.2.weight_flip", "generator.progression.3.conv1.1.conv.bias", "generator.progression.3.conv1.1.conv.weight_orig", "generator.progression.3.conv1.2.weight", "generator.progression.3.conv1.2.weight_flip", "generator.progression.4.conv1.1.conv.bias", "generator.progression.4.conv1.1.conv.weight_orig", "generator.progression.4.conv1.2.weight", "generator.progression.4.conv1.2.weight_flip", "generator.progression.5.conv1.0.weight", "generator.progression.5.conv1.0.bias", "generator.progression.5.conv1.1.weight", "generator.progression.5.conv1.1.weight_flip", "generator.progression.6.conv1.0.weight", "generator.progression.6.conv1.0.bias", "generator.progression.6.conv1.1.weight", "generator.progression.6.conv1.1.weight_flip", "generator.progression.6.noise1.weight_orig", "generator.progression.6.adain1.style.linear.bias", "generator.progression.6.adain1.style.linear.weight_orig", "generator.progression.6.conv2.conv.bias", "generator.progression.6.conv2.conv.weight_orig", "generator.progression.6.noise2.weight_orig", "generator.progression.6.adain2.style.linear.bias", "generator.progression.6.adain2.style.linear.weight_orig", "generator.progression.7.conv1.0.weight", "generator.progression.7.conv1.0.bias", "generator.progression.7.conv1.1.weight", "generator.progression.7.conv1.1.weight_flip", "generator.progression.7.noise1.weight_orig", "generator.progression.7.adain1.style.linear.bias", "generator.progression.7.adain1.style.linear.weight_orig", "generator.progression.7.conv2.conv.bias", "generator.progression.7.conv2.conv.weight_orig", "generator.progression.7.noise2.weight_orig", "generator.progression.7.adain2.style.linear.bias", "generator.progression.7.adain2.style.linear.weight_orig", "generator.progression.8.conv1.0.weight", "generator.progression.8.conv1.0.bias", "generator.progression.8.conv1.1.weight", "generator.progression.8.conv1.1.weight_flip", "generator.progression.8.noise1.weight_orig", "generator.progression.8.adain1.style.linear.bias", "generator.progression.8.adain1.style.linear.weight_orig", "generator.progression.8.conv2.conv.bias", "generator.progression.8.conv2.conv.weight_orig", "generator.progression.8.noise2.weight_orig", "generator.progression.8.adain2.style.linear.bias", "generator.progression.8.adain2.style.linear.weight_orig", "generator.to_rgb.6.conv.bias", "generator.to_rgb.6.conv.weight_orig", "generator.to_rgb.7.conv.bias", "generator.to_rgb.7.conv.weight_orig", "generator.to_rgb.8.conv.bias", "generator.to_rgb.8.conv.weight_orig". 
	Unexpected key(s) in state_dict: "generator.progression.1.conv1.conv.bias", "generator.progression.1.conv1.conv.weight_orig", "generator.progression.2.conv1.conv.bias", "generator.progression.2.conv1.conv.weight_orig", "generator.progression.3.conv1.conv.bias", "generator.progression.3.conv1.conv.weight_orig", "generator.progression.4.conv1.conv.bias", "generator.progression.4.conv1.conv.weight_orig", "generator.progression.5.conv1.conv.bias", "generator.progression.5.conv1.conv.weight_orig". 

From your README it looks like you recently modified the network architecture, which is why this error occurs. Could you please upload new parameters for both the generator and discriminator networks, for further training and high-resolution image synthesis? It would be really helpful, as I am basing my research on this implementation, and yours looks particularly clean to me.

By the way thanks a lot for open-sourcing such a nice and simple implementation of StyleGAN in PyTorch. 🙂

Saturation artifacts with large batch sizes

Hey! Thanks a bunch for making this great implementation. I noticed that when I tune the batch sizes to be closer to Tero's numbers for 1 GPU, my training falls apart after transitioning resolutions. This doesn't seem to happen at a batch size of 16.

Looks like this (batch size is 64 for all of these):

(Right before the transition from 8x8 to 16x16) [image]

(Right after the transition) [image]

(A few hundred iterations after the transition) [image]

(1,500 iterations after the transition) [image]

(2,000 iterations after the transition) [image]

From here it just keeps collapsing. Did you notice this in your training at all? It doesn't seem to happen in Tero's TensorFlow implementation with an equivalent number of iterations.

Did you update your pretrained model weights?

Hi, I looked at issue #3 and found generator and discriminator weights, but they may not be a matching pair.

Could you provide generator and discriminator weights from the same resolution and the same dataset?

Uses only 1 GPU instead of four

I run training on four GTX 2080 Ti GPUs, but training does not start and fails with 'CUDA out of memory' on GPU 0. I looked and saw that the program uses only one GPU instead of four. How can I fix this?

KeyError: 'generator' when continuing from checkpoint

Hey there,
First, thank you for your amazing work on this PyTorch StyleGAN; I got it to work quite flawlessly.
I trained on a fairly small custom dataset on the free GPU on Google Colab for several hours, and now have two .model checkpoints saved in the mounted Google Drive. However, when I try to continue training from the checkpoint "020000.model", I get a KeyError: 'generator'. Help would be really appreciated, since I have been searching for a fix for quite some time now.

See output:

!python ./train.py --ckpt ./checkpoint/020000.model ./datasets/custom

Traceback (most recent call last):
File "./train.py", line 316, in
generator.module.load_state_dict(ckpt['generator'])
KeyError: 'generator'

Thank you in advance!

_pickle.UnpicklingError: invalid load key, '\xb8'.

Hi, I use the pre-trained model but get the following error. My PyTorch version is 1.0.0. Can anyone help me?
python generate.py
Traceback (most recent call last):
File "generate.py", line 7, in
generator.load_state_dict(torch.load('./checkpoint/style-gan-600k.model'))
File "/home/ubuntu/miniconda3/envs/dl/lib/python3.6/site-packages/torch/serialization.py", line 367, in load
return _load(f, map_location, pickle_module)
File "/home/ubuntu/miniconda3/envs/dl/lib/python3.6/site-packages/torch/serialization.py", line 528, in _load
magic_number = pickle_module.load(f)
_pickle.UnpicklingError: invalid load key, '\xb8'.

image size and the differences with official implementation

Thanks for this work! But I still have two questions:

  1. The number of StyledConvBlocks is 6, so the output image size is at most 128, am I right? Can this be extended to 512 or even 1024 images?
  2. Are there any differences between this implementation and the official code (or paper)?

Wrong file permissions

These files have their execution bit set, but should not:

-rwxrwxr-x generate.py
-rwxrwxr-x model.py
-rwxrwxr-x README.md
-rwxrwxr-x train.py

How to fix it:

git update-index --chmod=-x generate.py model.py README.md train.py
git commit -m "Changing file permissions" generate.py model.py README.md train.py
git push

Sorry, GitHub does not have a web interface for making a pull request that changes file permissions, so I am filing this issue instead.

TypeError: forward() missing 1 required positional argument: 'input'

> CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 train.py --mixing --loss r1 --sched --phase 16 train-full
Size: 256; G: 2.811; D: 3.419; Grad: 0.003; Alpha: 1.00000:   0%|                                                       | 13/3000000 [01:08<17:22:00,  2.08s/it]Traceback (most recent call last):
  File "train_v3.py", line 345, in <module>
    train(args, dataset, generator, discriminator)
  File "train_v3.py", line 133, in train
    real_scores = discriminator(real_image, step=step, alpha=alpha)
  File "/home/michael/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/michael/.local/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 152, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/home/michael/.local/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 162, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/home/michael/.local/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 85, in parallel_apply
    output.reraise()
  File "/home/michael/.local/lib/python3.6/site-packages/torch/_utils.py", line 369, in reraise
    raise self.exc_type(msg)
TypeError: Caught TypeError in replica 4 on device 4.
Original Traceback (most recent call last):
  File "/home/michael/.local/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 60, in _worker
    output = module(*input, **kwargs)
  File "/home/michael/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
TypeError: forward() missing 1 required positional argument: 'input'

Size: 256; G: 2.811; D: 3.419; Grad: 0.003; Alpha: 1.00000:   0%|                                                     | 13/3000000 [01:14<48:01:02,  5.76s/it]

ConvBlock in Discriminator

Hi,

Is line 520 in model.py deliberate? Notice the input channel count is 513 (not 512):
ConvBlock(513, 512, 3, 1, 4, 0),

Thanks for the clarification.
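
A likely explanation, hedged and based only on the Discriminator.forward code quoted earlier in these issues: the final block receives one extra feature map holding the minibatch standard deviation, hence 512 + 1 = 513 input channels.

    # From the forward pass quoted above: one std-dev feature map is
    # concatenated before the last block, so 512 + 1 = 513 channels.
    out = torch.cat([out, mean_std], 1)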

accumulate() function?

Hi, can I ask a question about the accumulate() function in train.py?

What does this function do, and which part of the style-based generator paper does it correspond to?
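
A minimal sketch of what such a function usually does (hedged; the repo's exact code may differ): it keeps g_running as an exponential moving average of the generator's weights, the weight-averaging trick used for the reported samples in the ProGAN/StyleGAN papers rather than a named component of the architecture.

    # Illustrative exponential moving average of generator parameters.
    def accumulate(model_ema, model, decay=0.999):
        ema_params = dict(model_ema.named_parameters())
        params = dict(model.named_parameters())
        for name, ema_param in ema_params.items():
            ema_param.data.mul_(decay).add_(params[name].data * (1 - decay))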

Generating 256px images with the pretrained model gives weird pictures

Hi Seonghyeon,
Thank you so much for your efforts.
I used your pretrained 256px model to generate images by simply running:

python generate.py 'checkpoint/pretrained/stylegan-256px-new.model'

Here is what I got:

[image]

Could you please point out what I am doing wrong here?

Thank you

Reverse recovery of Latent vector

Hi,

Has anyone tried to do a reverse backprop using a target image, to find the best style vector and/or closest GAN image?
I can't seem to make it work.

Thanks!

Style generation...

Why do you sample 1000 styles and then average them in get_mean_style()? Why not use a single style?
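
A hedged reading: averaging many mapped latents estimates the mean of the style (W) space, which is then used to truncate samples toward that mean; a single style would only be a noisy estimate. Illustrative sketch with a hypothetical style_mapping callable and a 512-dimensional code:

    import torch

    # style_mapping is hypothetical; it stands for the generator's mapping network.
    styles = torch.stack([style_mapping(torch.randn(1, 512)) for _ in range(1000)])
    mean_style = styles.mean(0)

    # Truncation: pull a new style toward the mean (0.7 is an illustrative factor).
    truncated = mean_style + 0.7 * (style_mapping(torch.randn(1, 512)) - mean_style)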

Package Versions

Could you post the versions of the various packages you're using? I'm particularly curious about what version of pytorch you use!

Why use "g_running" and "accumulate"?

Hi! I'm new to the GAN field. I have a question: why are "g_running" and "accumulate" used to keep a weighted average of the generator's parameters during training? And why is "g_running" used instead of the trained generator during testing?

Is this a trick for GAN training, or does it appear in any paper?

Thanks a lot.

Generated images color issue

The generated output images are inverted / rendered with the wrong color palette. Could you please check what is wrong here?

CUDA out of memory in generate, even with 1 image

Before I start, a huge thank you for the pre-trained models;
but generating is problematic. Each time I try to run the code, this happens: [output image]
Surprisingly enough, training works perfectly fine.
I am using PyCharm on Windows 10 Fall Creators Update with Python 3.7, debugging in the console.

How to launch testing?

Hello!
I've managed to launch training, but what's next? How do I test it afterwards? I would like to reproduce the style-transfer experiments, but I do not see how to provide two images to the generator.

Is it possible to use TF official weights?

Hi, thank you for your brilliant code. I'm wondering whether the official repo's weights can be used directly in your model. In a previous project I successfully initialized a PyTorch InceptionV3 with TF weights, and the error was within 1e-5. Therefore, if the architecture is the same, reusing the weights should also be possible. Do you think this is possible with your implementation?

How to get the latent code of an image

Nice job! I have a question about the latent code: the latent code in generate.py is randomly generated. I want to get the latent code of an existing image and apply that image's style to another image, so how can I obtain the latent code of an image?

64 image-size

Hello,

What do I need to modify if the image size of my dataset is 64x64? I'm afraid I may be missing some places. Thanks!

Crash when running on 8 GPUs

I've found that the training loop crashes after tqdm is successfully initialized. We are using the r1 flag. Basically, it gets stuck somewhere in the first training iteration and never gets to the second one. We're unable to kill it normally and have to use kill -9.

We were using 8 V100s for this run, and have successfully run it with the same dataset on two 1080s.

What is the importance of the lines below?

        # With probability 0.9, sample two latent codes per generator input
        # (style mixing regularization); otherwise use a single latent code.
        if mixing and random.random() < 0.9:
            gen_in11, gen_in12, gen_in21, gen_in22 = torch.randn(4, b_size, code_size, device='cuda').chunk(4, 0)
            gen_in1 = [gen_in11.squeeze(0), gen_in12.squeeze(0)]
            gen_in2 = [gen_in21.squeeze(0), gen_in22.squeeze(0)]

        else:
            gen_in1, gen_in2 = torch.randn(2, b_size, code_size, device='cuda').chunk(2, 0)
            gen_in1 = gen_in1.squeeze(0)
            gen_in2 = gen_in2.squeeze(0)

Bad results when continuing training

First, thank you for this implementation!
I want to continue training a model that I trained before. I run the following command:
python train.py --mixing --init_size 64 ./images/ --ckpt ./checkpoint/140000.model
I don't get any errors when loading, but the generated images don't look like the ones I got before; instead, I get really bad results.

Error on string operation in train.py

Hello,

I'm trying to run your code per the instructions in the README, but I got stuck on several string operations in train.py, such as:

torch.save(
                {
                    'generator': generator.module.state_dict(),
                    'discriminator': discriminator.module.state_dict(),
                    'g_optimizer': g_optimizer.state_dict(),
                    'd_optimizer': d_optimizer.state_dict(),
                },
                f'checkpoint/train_step-{step}.model', <-- here
            )

What type of string operation is this? Or is it a different Python version? I'm using Python 3.5.
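
For reference, f'checkpoint/train_step-{step}.model' is an f-string (formatted string literal), which requires Python 3.6 or newer; an equivalent form for Python 3.5 would be:

    # Equivalent formatting without f-strings, for Python < 3.6.
    'checkpoint/train_step-{}.model'.format(step)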
