decent's Introduction

Decent: Unpaired Image-to-Image Translation with Density Changing Regularization (NeurIPS 2022)

Basic Usage

  • Training:
    python train.py --dataroot=datasets/selfie2anime
  • Test:
    python test.py --dataroot=datasets/selfie2anime
  • Multi-GPU training:
    python train.py --dataroot=datasets/selfie2anime --gpu=0,1,2,3 --batch_size=4
  • Weight of the density changing loss: --lambda_var=0.01
  • Compute the density changing loss across images: --var_all (I have tested var_all=False.)
  • Number of flow blocks: --flow_blocks=1
  • Learning rate of the flow: --flow_lr=0.001
  • Flow type: --flow_type=bnaf (BNAF works best for me. Feel free to experiment with other flows.)
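The options listed above can be combined into a single training command. The following is a minimal sketch that assembles such a command in Python; the `options` dict and the `build_command` helper are illustrative names, not part of the repository, and `--var_all` is left out because the list does not say how that flag is passed.

```python
# Flag names and values are copied from the list above.
# `options` and `build_command` are hypothetical helpers for illustration only.
options = {
    "dataroot": "datasets/selfie2anime",
    "lambda_var": 0.01,
    "flow_blocks": 1,
    "flow_lr": 0.001,
    "flow_type": "bnaf",
}


def build_command(script, options):
    """Return an argv list of the form ['python', script, '--name=value', ...]."""
    args = ["python", script]
    for name, value in options.items():
        args.append(f"--{name}={value}")
    return args


command = build_command("train.py", options)
print(" ".join(command))
```

The resulting list can be passed to `subprocess.run(command)` from the repository root if launching from Python is preferred over typing the command by hand.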

Pretrained Models

Evaluation Script of label2city

Different pretrained DRN models and evaluation protocols can cause large performance gaps, so I created a repository with the evaluation script for label2city. I hope the script makes future evaluation easier.

Citation

If you use this code for your research, please cite our paper:

@inproceedings{xieunsupervised,
  title={Unsupervised Image-to-Image Translation with Density Changing Regularization},
  author={Xie, Shaoan and Ho, Qirong and Zhang, Kun},
  booktitle={Advances in Neural Information Processing Systems},
  year={2022},
}

decent's People

Contributors

mid-push

decent's Issues

Good defaults and experiment configs

Very cool work! Wanted to see how much it has improved from previous methods but wondering what settings you are recommending. Did you use the default options or something else? I see you mention a few different flags in the readme which made me wonder.

The Weight --lambda_var=0.01
Compute density changing loss across images --var_all I have tested var_all=False.
Number of Flow Blocks --flow_blocks=1
Learning Rate of Flow --flow_lr=0.001
Different flows --flow_type=bnaf BNAF works best for me. Feel free to experiment other flows.

Multi GPU Training Issue

Hello, authors! Thanks for your excellent work.

I have trouble with multi-GPU training. My command line looks like this:

python train.py --dataroot $dataset_path --name $model_name --gpu 0,1,2,3 --batch_size 1

And the error is below:

Traceback (most recent call last):
  File "/home/shen/Rain/Methods/Decent/train.py", line 49, in <module>
    model.data_dependent_initialize(data)
  File "/home/shen/Rain/Methods/Decent/models/decent_gan_model.py", line 99, in data_dependent_initialize
    self.compute_F_loss().backward()                   # calculate graidents for F
  File "/home/shen/Rain/Methods/Decent/models/decent_gan_model.py", line 189, in compute_F_loss
    assert len(log_prob_a) == self.opt.batch_size * self.opt.num_patches
AssertionError

I print the values below for debugging.

print(f"{len(log_prob_a)} != {self.opt.batch_size} * {self.opt.num_patches}")

which gives me

0 != 1 * 256

Since len(log_prob_a) is 0, we get an empty list for log_prob_a in multi-GPU training.

Did you encounter this issue when training your models? How can I solve it?
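One thing that differs from the README's multi-GPU example is the batch size: the README uses --batch_size=4 with four GPUs, while the command above uses a batch size of 1. The sketch below assumes the batch is split across GPUs in ceil-sized chunks, as `torch.chunk` does; whether this repository splits batches that way, and whether that is what empties `log_prob_a`, is not confirmed by the traceback. `split_batch` and `patches_per_replica` are hypothetical helpers.

```python
from typing import List


def split_batch(batch_size: int, num_gpus: int) -> List[int]:
    """Per-GPU sample counts under ceil-sized chunking (an assumption, see above).

    Fewer chunks than GPUs are returned when the batch is too small.
    """
    if batch_size <= 0 or num_gpus <= 0:
        raise ValueError("batch_size and num_gpus must be positive")
    chunk = -(-batch_size // num_gpus)  # ceiling division
    sizes = []
    remaining = batch_size
    while remaining > 0:
        take = min(chunk, remaining)
        sizes.append(take)
        remaining -= take
    return sizes


def patches_per_replica(batch_size: int, num_gpus: int, num_patches: int) -> List[int]:
    """Number of patches each replica would see for its share of the batch."""
    return [n * num_patches for n in split_batch(batch_size, num_gpus)]


# Settings from the failing command: batch size 1 on four GPUs, 256 patches.
print(split_batch(1, 4))               # only one GPU receives data
print(patches_per_replica(4, 4, 256))  # the README setting: one sample per GPU
```

Under this assumption, a batch size of 1 leaves three of the four GPUs without any samples, so retrying with a batch size of at least the number of GPUs is a cheap first experiment.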

Error while training the network

Minimal steps to reproduce the error:

Training on an NVIDIA A6000 GPU
OS: Ubuntu 20.04
CUDA version: 12.0

Pip list:
absl-py 1.4.0
cachetools 4.2.4
certifi 2021.5.30
cffi 1.14.6
charset-normalizer 2.0.12
cycler 0.11.0
dataclasses 0.8
decorator 4.4.2
dominate 2.4.0
google-auth 2.18.1
google-auth-oauthlib 0.4.6
GPUtil 1.4.0
grpcio 1.48.2
idna 3.4
importlib-metadata 4.8.3
importlib-resources 5.4.0
jsonpatch 1.32
jsonpointer 2.3
kiwisolver 1.3.1
Markdown 3.3.7
matplotlib 3.3.4
mkl-fft 1.3.0
mkl-random 1.1.1
mkl-service 2.3.0
networkx 2.5.1
nflows 0.14
numpy 1.16.4
oauthlib 3.2.2
opencv-python 4.7.0.72
packaging 21.3
Pillow 8.4.0
pip 21.2.2
protobuf 3.19.6
pyasn1 0.5.0
pyasn1-modules 0.3.0
pycparser 2.21
pyparsing 3.0.9
python-dateutil 2.8.2
pyzmq 25.0.2
requests 2.27.1
requests-oauthlib 1.3.1
rsa 4.9
scipy 1.5.2
setuptools 58.0.4
six 1.16.0
tensorboard 2.10.1
tensorboard-data-server 0.6.1
tensorboard-plugin-wit 1.8.1
torch 1.4.0
torch-fidelity 0.3.0
torchfile 0.1.0
torchvision 0.5.0
tornado 6.1
tqdm 4.64.1
typing_extensions 4.1.1
urllib3 1.26.16
visdom 0.2.4
websocket-client 1.3.1
Werkzeug 2.0.3
wheel 0.37.1
zipp 3.6.0

Dataset used: horse2zebra
Batch size: 24, single GPU only

Command for training:
python train.py --dataroot=datasets/horse2zebra/ --gpu=1 --batch_size=24

Error obtained:
learning rate = 0.0002000
(epoch: 10, iters: 120, time: 0.141, data: 0.008) G: nan G_GAN: nan D_real: nan D_fake: nan idt: nan var: nan nll_A: nan nll_B: nan exp_A: nan exp_B: nan
(epoch: 10, iters: 720, time: 0.139, data: 0.010) G: nan G_GAN: nan D_real: nan D_fake: nan idt: nan var: nan nll_A: nan nll_B: nan exp_A: nan exp_B: nan
(epoch: 10, iters: 1320, time: 0.138, data: 0.008) G: nan G_GAN: nan D_real: nan D_fake: nan idt: nan var: nan nll_A: nan nll_B: nan exp_A: nan exp_B: nan
[*] start evaluation!
datasets/horse2zebra/testB
./checkpoints/debug/horse2zebra_AtoB/var0.01_np256_nb1_nl0_nd10_lr0.001_ema0.998_var_single/fake
Traceback (most recent call last):
  File "train.py", line 82, in <module>
    eval_dict = eval_loader(model, test_loader_A, test_loader_B, opt.run_dir, opt)
  File "/home/user/anaconda3/envs/decent/lib/python3.6/site-packages/torch/autograd/grad_mode.py", line 49, in decorate_no_grad
    return func(*args, **kwargs)
  File "/home/user/manjunath/GAN/Decent/models/utils.py", line 76, in eval_loader
    return eval_loader_one(model, test_loader_a, test_loader_b, output_directory, opt)
  File "/home/user/anaconda3/envs/decent/lib/python3.6/site-packages/torch/autograd/grad_mode.py", line 49, in decorate_no_grad
    return func(*args, **kwargs)
  File "/home/user/manjunath/GAN/Decent/models/utils.py", line 97, in eval_loader_one
    eval_dict = eval_method_one(real_dir, fake_dir, opt)
  File "/home/user/anaconda3/envs/decent/lib/python3.6/site-packages/torch/autograd/grad_mode.py", line 49, in decorate_no_grad
    return func(*args, **kwargs)
  File "/home/user/manjunath/GAN/Decent/models/utils.py", line 115, in eval_method_one
    metric_dict_AB = torch_fidelity.calculate_metrics(input1=realB_path, input2=fakeB_path, **eval_args)
  File "/home/user/anaconda3/envs/decent/lib/python3.6/site-packages/torch_fidelity/metrics.py", line 258, in calculate_metrics
    metric_fid = fid_statistics_to_metric(fid_stats_1, fid_stats_2, get_kwarg('verbose', kwargs))
  File "/home/user/anaconda3/envs/decent/lib/python3.6/site-packages/torch_fidelity/metric_fid.py", line 47, in fid_statistics_to_metric
    covmean, _ = scipy.linalg.sqrtm(sigma1.dot(sigma2), disp=False)
  File "/home/user/anaconda3/envs/decent/lib/python3.6/site-packages/scipy/linalg/_matfuncs_sqrtm.py", line 161, in sqrtm
    A = _asarray_validated(A, check_finite=True, as_inexact=True)
  File "/home/user/anaconda3/envs/decent/lib/python3.6/site-packages/scipy/_lib/_util.py", line 263, in _asarray_validated
    a = toarray(a)
  File "/home/user/anaconda3/envs/decent/lib/python3.6/site-packages/numpy/lib/function_base.py", line 498, in asarray_chkfinite
    "array must not contain infs or NaNs")
ValueError: array must not contain infs or NaNs

I also noticed that the learning rate does not change after each epoch.
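Every loss in the log is already nan before evaluation starts, so the final ValueError is presumably a downstream symptom rather than the root cause. A minimal sketch for finding the first bad iteration in a saved log follows; it assumes log lines keep the `(epoch: ..., iters: ...) name: value` layout shown above, and `parse_losses` and `nonfinite_losses` are hypothetical helpers, not functions from the repository.

```python
import math
import re


def parse_losses(line):
    """Extract 'name: value' pairs that follow the closing parenthesis of a log line."""
    _, _, tail = line.partition(")")
    pairs = re.findall(r"(\w+): (\S+)", tail)
    return {name: float(value) for name, value in pairs}


def nonfinite_losses(line):
    """Return the names of losses on this line that are nan or infinite."""
    return [name for name, value in parse_losses(line).items()
            if not math.isfinite(value)]


# First line is copied from the log above; the second is a made-up healthy line.
bad = ("(epoch: 10, iters: 120, time: 0.141, data: 0.008) G: nan G_GAN: nan "
       "D_real: nan D_fake: nan idt: nan var: nan nll_A: nan nll_B: nan "
       "exp_A: nan exp_B: nan")
good = "(epoch: 1, iters: 100, time: 0.140, data: 0.009) G: 1.25 G_GAN: 0.5"

print(nonfinite_losses(bad))
print(nonfinite_losses(good))
```

Running this over the full training log, line by line, shows at which epoch and iteration the losses first become non-finite, which narrows down where to look before the evaluation step is reached.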
