
Contrastive Unpaired Translation (CUT)





We provide our PyTorch implementation of unpaired image-to-image translation based on patchwise contrastive learning and adversarial learning. No hand-crafted loss or inverse network is used. Compared to CycleGAN, training is faster and less memory-intensive. In addition, our method can be extended to single-image training, where each “domain” is only a single image.

Contrastive Learning for Unpaired Image-to-Image Translation
Taesung Park, Alexei A. Efros, Richard Zhang, Jun-Yan Zhu
UC Berkeley and Adobe Research
In ECCV 2020




Pseudocode

import torch
cross_entropy_loss = torch.nn.CrossEntropyLoss()

# Input: f_q (BxCxS) are sampled features from H(G_enc(x))
# Input: f_k (BxCxS) are sampled features from H(G_enc(G(x)))
# Input: tau is the temperature used in NCE loss.
# Output: PatchNCE loss
def PatchNCELoss(f_q, f_k, tau=0.07):
    # batch size, channel size, and number of sample locations
    B, C, S = f_q.shape

    # calculate v * v+: BxSx1
    l_pos = (f_k * f_q).sum(dim=1)[:, :, None]

    # calculate v * v-: BxSxS
    l_neg = torch.bmm(f_q.transpose(1, 2), f_k)

    # The diagonal entries are not negatives. Remove them.
    identity_matrix = torch.eye(S, dtype=torch.bool, device=f_q.device)[None, :, :]
    l_neg.masked_fill_(identity_matrix, -float('inf'))

    # calculate logits: (B)x(S)x(S+1)
    logits = torch.cat((l_pos, l_neg), dim=2) / tau

    # return NCE loss
    predictions = logits.flatten(0, 1)
    targets = torch.zeros(B * S, dtype=torch.long, device=f_q.device)
    return cross_entropy_loss(predictions, targets)
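
As a quick sanity check, the loss can be exercised on random tensors. The shapes below are arbitrary, and the explicit L2 normalization mirrors the assumption that H outputs unit-norm features; this snippet is an illustration, not part of the training code.

import torch
import torch.nn.functional as F

B, C, S = 4, 256, 64  # arbitrary batch size, channels, and number of patches
f_q = F.normalize(torch.randn(B, C, S), dim=1)  # unit-norm along the channel dim
f_k = F.normalize(torch.randn(B, C, S), dim=1)
loss = PatchNCELoss(f_q, f_k, tau=0.07)
print(loss)  # a scalar tensor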

Example Results

Unpaired Image-to-Image Translation

Single Image Unpaired Translation

Russian Blue Cat to Grumpy Cat

Parisian Street to Burano's painted houses

Prerequisites

  • Linux or macOS
  • Python 3
  • CPU or NVIDIA GPU + CUDA cuDNN

Getting started

  • Clone this repo:
git clone https://github.com/taesungp/contrastive-unpaired-translation CUT
cd CUT
  • Install PyTorch 1.4 and other dependencies (e.g., torchvision, func-timeout, gputil).

For pip users, run pip install -r requirements.txt. For Conda users, we provide an installation script scripts/conda_deps.sh. Alternatively, you can create a new Conda environment with conda env create -f environment.yml.

CUT and FastCUT Training and Test

  • Download the grumpify dataset (Fig. 8 of the paper; Russian Blue -> Grumpy Cats):
bash ./datasets/download_cut_dataset.sh grumpifycat

The dataset will be downloaded and unzipped to ./datasets/grumpifycat/.

The other datasets can be downloaded with the same script, which is provided by the CycleGAN repo:

bash ./datasets/download_cut_dataset.sh [dataset_name]

  • Train the model:
# Trains the CUT model
python train.py --dataroot ./datasets/grumpifycat --name grumpycat_CUT --CUT_mode CUT

# Trains the FastCUT model
python train.py --dataroot ./datasets/grumpifycat --name grumpycat_FastCUT --CUT_mode FastCUT

The checkpoints are stored at ./checkpoints/grumpycat_*/web.

Training using our launcher scripts

Please see experiments/grumpifycat_launcher.py, which generates the above command-line arguments. The launcher scripts are useful for managing the rather complicated command-line arguments of training and testing.

Using the launcher, the commands below generate the training commands for CUT and FastCUT.

python -m experiments grumpifycat train 0
python -m experiments grumpifycat train 1

To test using the launcher:

python -m experiments grumpifycat test 0
python -m experiments grumpifycat test 1

Possible commands are run, run_test, launch, close, and so on. Please see experiments/main.py for the full list of commands.
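
For reference, a launcher is a small Python module that declares option sets and turns them into train/test command lines. The sketch below shows the general shape of such a launcher, assuming Options and TmuxLauncher helpers from experiments/tmux_launcher.py; treat the exact field names as illustrative rather than authoritative.

from .tmux_launcher import Options, TmuxLauncher

class Launcher(TmuxLauncher):
    def common_options(self):
        return [
            # Command 0: CUT
            Options(dataroot="./datasets/grumpifycat",
                    name="grumpycat_CUT",
                    CUT_mode="CUT"),
            # Command 1: FastCUT
            Options(dataroot="./datasets/grumpifycat",
                    name="grumpycat_FastCUT",
                    CUT_mode="FastCUT"),
        ]

    def commands(self):
        # "python -m experiments grumpifycat train <id>" dispatches to one of these
        return ["python train.py " + str(opt) for opt in self.common_options()]

    def test_commands(self):
        # "python -m experiments grumpifycat test <id>" dispatches to one of these
        return ["python test.py " + str(opt) for opt in self.common_options()]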

Apply a pre-trained CUT model and evaluate

The tutorial for applying pretrained models will be released soon.
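
In the meantime, since the code base follows pytorch-CycleGAN-and-pix2pix, testing is expected to go through test.py with the same flags as training; a plausible invocation (option values are placeholders) would be:

python test.py --dataroot ./datasets/grumpifycat --name grumpycat_CUT --CUT_mode CUT

By the conventions of that repo, results would land under ./results/grumpycat_CUT/.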

SinCUT Single Image Unpaired Training

The tutorial for the Single-Image Translation will be released soon.

Citation

If you use this code for your research, please cite our paper.

@inproceedings{park2020cut,
  title={Contrastive Learning for Unpaired Image-to-Image Translation},
  author={Taesung Park and Alexei A. Efros and Richard Zhang and Jun-Yan Zhu},
  booktitle={European Conference on Computer Vision},
  year={2020}
}

Acknowledgments

We thank Allan Jabri and Phillip Isola for helpful discussion and feedback. Our code is developed based on pytorch-CycleGAN-and-pix2pix. We also thank pytorch-fid for FID computation and drn for mIoU computation, and stylegan2-pytorch for the PyTorch implementation of StyleGAN2 used in single-image translation.
