# Closed-Loop Training for Projected GAN

This repository contains code for our paper:

**Closed-Loop Training for Projected GAN**
Jiangwei Zhao, Liang Zhang, Lili Pan, Hongliang Li
IEEE Signal Processing Letters (submitted)
**Abstract:** Projected GAN, a GAN with a pre-trained discriminator, has been found to perform well in generating images from only a few training samples. However, it struggles with extended training, which may lead to decreased performance over time. This is because the pre-trained discriminator consistently surpasses the generator, creating an unstable training environment. In this work, we propose a solution to this issue by introducing closed-loop control (CLC) into the dynamics of Projected GAN, stabilizing training and improving generation performance. Our proposed method consistently reduces the Fréchet Inception Distance (FID) relative to previous methods; for example, it reduces the FID of Projected GAN by 4.31 on the Obama dataset. Our finding is fundamental and can be applied to other pre-trained GANs. The code is available at https://github.com/learninginvision/ProjectedGAN-CLC.
## Dependencies
- 64-bit Python 3.8
- PyTorch 1.9.0 (or later). See https://pytorch.org for PyTorch install instructions.
## Installation

First, clone this repo:

```bash
git clone https://github.com/learninginvision/ProjectedGAN-CLC
```

Then, create and activate a virtual environment using conda:

```bash
conda env create -f environment.yaml
conda activate pg-clc
```
## Data Preparation

For a quick start, you can download the few-shot datasets provided by the authors of FastGAN. You can download them here. To prepare a dataset at the desired resolution, run, for example:

```bash
python dataset_tool.py --source=./data/pokemon --dest=./data/pokemon256.zip \
  --resolution=256x256 --transform=center-crop
```
You can get the datasets we used in our paper at their respective websites: AFHQ, Landscape.
## Training

Training your own PG-CLC on Pokemon using 2 GPUs:

```bash
python train.py --outdir=./training-runs/ --cfg=fastgan_lite --data=./data/pokemon256.zip \
  --gpus=2 --batch=64 --mirror=1 --snap=50 --batch-gpu=16 --kimg=10000
```
`--batch` specifies the overall batch size; `--batch-gpu` specifies the batch size per GPU.
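As a rough sketch of how these flags typically relate in StyleGAN-style training loops (an assumption about this codebase, not a quote from it): when `--batch` exceeds `--gpus × --batch-gpu`, the remainder is covered by gradient accumulation.

```python
# Illustrative only: how --batch, --gpus and --batch-gpu typically relate
# in StyleGAN-style training loops (gradient accumulation fills the gap).
def accumulation_rounds(batch: int, gpus: int, batch_gpu: int) -> int:
    """Number of forward/backward passes per optimizer step."""
    assert batch % (gpus * batch_gpu) == 0, "batch must be divisible by gpus * batch-gpu"
    return batch // (gpus * batch_gpu)

# With the command above: --batch=64 --gpus=2 --batch-gpu=16
print(accumulation_rounds(64, 2, 16))  # 2 accumulation rounds per step
```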
We use a lightweight version of FastGAN (`--cfg=fastgan_lite`). This backbone trains fast in terms of wall-clock time and yields better results on small datasets like Pokemon.
Samples and metrics are saved in `outdir`. To monitor the training progress, you can inspect `fid50k_full.json` or run tensorboard in `training-runs`.
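If you want to pull the numbers out of the metric log programmatically, note that StyleGAN-style codebases usually write one JSON object per line. A minimal sketch, assuming that per-line format (the exact file layout in this repo may differ):

```python
import json

# Hypothetical helper: scan a metric log where each line is a JSON object,
# e.g. {"results": {"fid50k_full": 25.04}, ...}, and return the best FID seen.
def best_fid(path: str, metric: str = "fid50k_full"):
    best = None
    with open(path) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            entry = json.loads(line)
            fid = entry["results"][metric]
            if best is None or fid < best:
                best = fid
    return best
```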
You can change the CLC configuration in train.py#L240-L243.
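The paper's exact CLC formulation lives in the code; purely as an illustration of the closed-loop idea, and not the update rule implemented in `train.py`, a proportional feedback that damps a dominant discriminator might look like this (the names `loss_weight` and `queue_factor` mirror the hyperparameters in the table below, but the rule itself is a generic sketch):

```python
# Toy sketch of closed-loop control for GAN training (illustrative only,
# NOT the paper's implementation): keep a queue of recent discriminator
# outputs and penalize the discriminator when it becomes too confident.
from collections import deque

class FeedbackController:
    def __init__(self, loss_weight: float = 0.05, queue_factor: int = 100):
        self.queue = deque(maxlen=queue_factor)  # feedback signal buffer
        self.loss_weight = loss_weight

    def regularizer(self, d_real_logit: float) -> float:
        """Penalty that grows as the running mean logit drifts from zero."""
        self.queue.append(d_real_logit)
        mean_logit = sum(self.queue) / len(self.queue)
        # Proportional feedback: push the running mean logit toward zero.
        return self.loss_weight * mean_logit ** 2
```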
We provide the following pretrained models (pass the URL as `PATH_TO_NETWORK_PKL`):
| Dataset | Loss Weight | Queue Factor | FID | PATH |
|---|---|---|---|---|
| Pokemon | 0.1 | 100 | 25.04 | https://drive.google.com/file/d/18-678PSsr4sYX28qtIkdkOd3TtdpKCWf |
| Art-Paint | 0.05 | 200 | 26.91 | https://drive.google.com/file/d/1if_qohz0PYtSzuSlL72nE71oATxuSmVT |
| Flowers | 0.05 | 200 | 12.82 | https://drive.google.com/file/d/1B844ooziyOhk3dGbS389XWujIPjTpYbN |
| Landscapes | 0.05 | 100 | 6.55 | https://drive.google.com/file/d/1RpDg4vRPgD6UXajzmWDNSuyxkS2F_pwK |
| Obama | 0.05 | 100 | 20.12 | https://drive.google.com/file/d/1A0SbqW3xvHMfWVs_Pp7nUs8Ih5Uj9aYL |
By default, `train.py` tracks FID50k during training. To calculate metrics for a specific network snapshot, run:

```bash
python calc_metrics.py --metrics=fid50k_full --network=PATH_TO_NETWORK_PKL
```
To see the available metrics, run:

```bash
python calc_metrics.py --help
```
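For reference, FID compares Gaussian fits of Inception features of real and generated images. The standard definition (not code from this repo) can be written as:

```python
import numpy as np
from scipy import linalg

def frechet_distance(mu1, sigma1, mu2, sigma2):
    """FID between two Gaussians N(mu1, sigma1) and N(mu2, sigma2):
    ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrt(sigma1 @ sigma2))."""
    diff = mu1 - mu2
    # Matrix square root of the covariance product; numerical noise can
    # introduce a small imaginary component, so keep only the real part.
    covmean = linalg.sqrtm(sigma1 @ sigma2)
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    return float(diff @ diff + np.trace(sigma1 + sigma2 - 2.0 * covmean))
```

Identical distributions give an FID of 0; larger values indicate a bigger gap between real and generated statistics.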
Our codebase builds on and extends the awesome StyleGAN2-ADA, ProjectedGAN, and StyleGAN3 repos.