
ict's Introduction

Image Completion Transformer (ICT)

This repository is the official PyTorch implementation of our ICCV 2021 paper, High-Fidelity Pluralistic Image Completion with Transformers.

Ziyu Wan1, Jingbo Zhang1, Dongdong Chen2, Jing Liao1
1City University of Hong Kong, 2Microsoft Cloud AI

🎈 Prerequisites

  • Python >=3.6
  • PyTorch >=1.6
  • NVIDIA GPU + CUDA cuDNN
pip install -r requirements.txt

To run inference directly, first download the pretrained models from Dropbox, then

cd ICT
wget -O ckpts_ICT.zip https://www.dropbox.com/s/we886b1fqf2qyrs/ckpts_ICT.zip?dl=1
unzip ckpts_ICT.zip

If Dropbox doesn't work for you, please try Baidu Drive. Verification code: 6g4f

Another option to download the checkpoints is using OneDrive.

Some tips:

  • Masks should be binarized.
  • The extensions of images and masks should be .png.
  • The model is trained for 256x256 input resolution only.
  • Make sure that the downsampled (32x32 or 48x48) mask covers all the regions you want to fill; if not, dilate the mask (see the sketch below).
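
A minimal sketch of this preprocessing, assuming OpenCV and NumPy are available (file names are placeholders):

import cv2
import numpy as np

# Binarize the mask: everything above 127 becomes 255 (region to fill), the rest 0.
mask = cv2.imread("mask.png", cv2.IMREAD_GRAYSCALE)
_, mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)

def covered(m, low_res=32):
    # True if no hole pixel is lost when the mask is downsampled to low_res.
    small = cv2.resize(m, (low_res, low_res), interpolation=cv2.INTER_NEAREST)
    back = cv2.resize(small, m.shape[::-1], interpolation=cv2.INTER_NEAREST)
    return cv2.countNonZero(cv2.bitwise_and(m, cv2.bitwise_not(back))) == 0

# Dilate until the downsampled mask covers every region to fill.
while not covered(mask):
    mask = cv2.dilate(mask, np.ones((5, 5), np.uint8))

cv2.imwrite("mask.png", mask)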

🌟 Pipeline

Why transformer?

Compared with traditional CNN-based methods, transformers are better at understanding shape and geometry.

🚀 Training

1) Transformer

cd Transformer
python main.py --name [exp_name] --ckpt_path [save_path] \
               --data_path [training_image_path] \
               --validation_path [validation_image_path] \
               --mask_path [mask_path] \
               --BERT --batch_size 64 --train_epoch 100 \
               --nodes 1 --gpus 8 --node_rank 0 \
               --n_layer [transformer_layer #] --n_embd [embedding_dimension] \
               --n_head [head #] --ImageNet --GELU_2 \
               --image_size [input_resolution]

Notes on the transformer:

  • --AMP: Reduces the memory cost during training, but can sometimes lead to NaN.
  • --use_ImageFolder: Enable this option when training on ImageNet.
  • --random_stroke: Generate the masks on the fly (see the sketch below).
  • Our code is also ready for training on multiple machines.
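
Purely as an illustration of the --random_stroke idea, a free-form stroke mask could be generated like this (a sketch with made-up parameters, not the repo's actual generator):

import cv2
import numpy as np

def random_stroke_mask(size=256, strokes=4, max_len=60, width=15, seed=None):
    # Draw a few random polyline strokes on a black canvas; white = region to fill.
    rng = np.random.default_rng(seed)
    mask = np.zeros((size, size), np.uint8)
    for _ in range(strokes):
        x, y = rng.integers(0, size, 2)
        for _ in range(rng.integers(2, 6)):
            angle = rng.uniform(0, 2 * np.pi)
            length = rng.integers(10, max_len)
            nx = int(np.clip(x + length * np.cos(angle), 0, size - 1))
            ny = int(np.clip(y + length * np.sin(angle), 0, size - 1))
            cv2.line(mask, (int(x), int(y)), (nx, ny), 255, width)
            x, y = nx, ny
    return mask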

2) Guided Upsampling

cd Guided_Upsample
python train.py --model 2 --checkpoints [save_path] \
                --config_file ./config_list/config_template.yml \
                --Generator 4 --use_degradation_2

Notes on guided upsampling:

  • --use_degradation_2: Bilinear downsampling, to match the transformer training (see the sketch below).
  • --prior_random_degree: Stochastically perturb the sequence elements within their K nearest neighbours.
  • Modify the provided config template according to your own training environment.
  • Training the upsampling part does not require many GPUs.
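
To illustrate what matching the transformer via bilinear downsampling means, here is a sketch of the degradation idea (assumed shapes, not the repo's exact code): the upsampler sees inputs degraded to look like the transformer's low-resolution output.

import torch.nn.functional as F

def degrade(image, low_res=32):
    # Downsample a (B, C, 256, 256) batch to low_res, then bilinearly upsample
    # it back, mimicking the coarse appearance the upsampler learns to refine.
    small = F.interpolate(image, size=(low_res, low_res), mode="bilinear",
                          align_corners=False)
    return F.interpolate(small, size=image.shape[-2:], mode="bilinear",
                         align_corners=False)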

⚡ Inference

We provide a convenient and neat script for inference.

python run.py --input_image [test_image_folder] \
              --input_mask [test_mask_folder] \
              --sample_num 1  --save_place [save_path] \
              --ImageNet --visualize_all

Notes on inference:

  • --sample_num: How many completion results to generate.
  • --visualize_all: Disable this option to save each output result individually.
  • --ImageNet --FFHQ --Places2_Nature: You must enable exactly one of these options to select the corresponding ckpts.
  • Please use absolute paths.

More results

FFHQ

Places2

ImageNet

⏳ To Do

  • Release training code
  • Release testing code
  • Release pre-trained models
  • Add Google Colab

📔 Citation

If you find our work useful for your research, please consider citing the following papers :)

@article{wan2021high,
  title={High-Fidelity Pluralistic Image Completion with Transformers},
  author={Wan, Ziyu and Zhang, Jingbo and Chen, Dongdong and Liao, Jing},
  journal={arXiv preprint arXiv:2103.14031},
  year={2021}
}

The real-world application of image inpainting is also ready! Try and cite our old photo restoration algorithm here.

@inproceedings{wan2020bringing,
  title={Bringing Old Photos Back to Life},
  author={Wan, Ziyu and Zhang, Bo and Chen, Dongdong and Zhang, Pan and Chen, Dong and Liao, Jing and Wen, Fang},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={2747--2757},
  year={2020}
}

💡 Acknowledgments

This repo is built upon minGPT and Edge-Connect. We also thank OpenAI for the provided cluster centers.

📨 Contact

This repo is currently maintained by Ziyu Wan (@Raywzy) and is for academic research use only. Discussions and questions are welcome via [email protected].

ict's People

Contributors

josephrocca, raywzy


ict's Issues

Pretrained model

Could you release the pretrained models? The given link is empty.

New easy-to-use inpainting method with transformers

Dear researcher, please also consider checking our newly introduced face inpainting method, which addresses the symmetry problems of general inpainting methods by using a Swin Transformer and semantic-aware discriminators.
Our proposed method shows better results in terms of FID score, and of a newly proposed metric focused on face symmetry, compared to some state-of-the-art methods including LaMa.
Our paper is available at:
https://www.researchgate.net/publication/366984165_SFI-Swin_Symmetric_Face_Inpainting_with_Swin_Transformer_by_Distinctly_Learning_Face_Components_Distributions

The code will also be published at:
https://github.com/mohammadrezanaderi4/SFI-Swin

TypeError: 'NoneType' object does not support item assignment

Hello, when I am running the Guided Upsampling command

cd Guided_Upsample
python train.py --model 2 --checkpoints [save_path] \
                --config_file ./config_list/config_template.yml \
                --Generator 4 --use_degradation_2

the error message below appears. How can I fix it?

[screenshot of error]

Pretrained upsampler weights for Places do not work well

Hi, @raywzy

I tried both the FFHQ and Places2 pretrained upsampler weights.
However, the pretrained upsampler weights for Places do not reach sufficient quality.
We know that the generator weights were trained for 322,000 iterations.
Do you think the attached results are correct?

From left to right,
1st stage output | blended result with masked input image | raw output of upsampler | GT image | masked input image | raw output of upsampler within given mask

[debug image attached]

Training Data Format

Can you tell me the training data format? As I understand it:

  1. The training data contains images for the pretrained models, 256x256, in a directory, for example X.
  2. The training data contains masks of the same size as the original training images (256x256 in my example), in a directory Y.
  3. Optionally, if I use the flag --random_stroke, the training data does not need to contain the masks.

If I am right, please confirm; if I am wrong, please tell me the right format.

Thanks in advance.

weights

Why can't I download the weights?

About the validation_path of Transformer training.

There is a parameter --validation_path, described as "validation_image_path", in the transformer training, but I couldn't find a parameter like "validation_mask_path" in ./Transformer/main.py. Does this mean that the validation set doesn't need its own mask set, that it uses the same mask set as the training set, or something else? Sorry for my poor English.

About the test results not being good

It's a great job, but I have a problem when testing. Figure 1 is great, but Figure 2 is bad. Why does this happen? Is it because the downsampled mask does not cover all the regions? Sorry, my English is not very good; I look forward to your reply.

[figure 1 and figure 2 attached]

num_sample in sample_mask function

Hi,

I am confused by the sample_mask function in transformer/utils/util.py. It seems that it does not use the argument num_sample but keeps num_sample=1; is that intended?

Moreover, you use top_k=40 but the paper uses top_k=50. What is the best choice?

Thanks,

inference time

How much time does inference need for one image?
Thanks

Downloading pretrained models from a non-Baidu source

Hi, thanks for your work on this.
I'm trying to download the pretrained models, but it looks like I can't download from Baidu Drive unless I have a Baidu account (which requires a Chinese phone number).

Would it be possible to upload the models to a non-Baidu source as well?

### Something Wrong ###

Trying to evaluate your code:
WSL2 under Windows 10
NVIDIA RTX 3090

Upon starting the provided inference script (regardless of which ckpts are used), I get this message from the very beginning:

### Something Wrong ###
  0%|  

After that, the calculation continues.
If --FFHQ or --Places2_Nature is specified, inference finishes with no error.

However, if --ImageNet is specified, inference finishes with an error:

raise AssertionError("Invalid device id")
AssertionError: Invalid device id

NVCC report:

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Tue_May__3_19:00:59_Pacific_Daylight_Time_2022
Cuda compilation tools, release 11.7, V11.7.64
Build cuda_11.7.r11.7/compiler.31294372_0

I tried to change line 44 in "run.py" from

CUDA_VISIBLE_DEVICES=0,1

to

CUDA_VISIBLE_DEVICES=0

-- no luck.

I found the reason for the message

### Something Wrong ###
  0%|  

It happens when the image and mask file names differ (they should be identical).
However, the rest of the issue still exists.
Besides that, after inference finishes (regardless of the specified ckpts), the "output" subfolder created in the "Guided_Upsample" folder is empty, while the "output" subfolder in the "Transformer" folder is created and contains the generated 32- or 48-pixel images.
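
For reference, a quick way to check for the mismatched file names this message points to (folder paths are placeholders):

import os

# File stems must pair up one-to-one between the image and mask folders.
images = {os.path.splitext(f)[0] for f in os.listdir("/abs/path/test_images")}
masks = {os.path.splitext(f)[0] for f in os.listdir("/abs/path/test_masks")}

print("images without masks:", sorted(images - masks))
print("masks without images:", sorted(masks - images))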

RuntimeError: [enforce fail at ..\c10\core\CPUAllocator.cpp:75] data. DefaultCPUAllocator: not enough memory: you tried to allocate 17179869184 bytes. Buy new RAM!

(ICT) E:\ZSL\ICT-main\Transformer>python main.py --name ICT --ckpt_path ./ckpt --data_path E:\ZSL\ICT-main\data\train_256_png --validation_path E:\ZSL\ICT-main\data\test_256_png --mask_path E:\ZSL\ICT-main\data\mask --BERT --batch_size 1 --train_epoch 100 --nodes 1 --gpus 1 --node_rank 0 --n_layer 12 --n_embd 512 --n_head 8 --GELU_2 --image_size 256 --AMP

# Mask is 2000, # Image is 1688

# Mask is 2000, # Image is 303

Traceback (most recent call last):
  File "main.py", line 139, in <module>
    mp.spawn(main_worker, nprocs=opts.gpus, args=(opts,))
  File "D:\anaconda\envs\ICT\lib\site-packages\torch\multiprocessing\spawn.py", line 230, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "D:\anaconda\envs\ICT\lib\site-packages\torch\multiprocessing\spawn.py", line 188, in start_processes
    while not context.join():
  File "D:\anaconda\envs\ICT\lib\site-packages\torch\multiprocessing\spawn.py", line 150, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException:

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "D:\anaconda\envs\ICT\lib\site-packages\torch\multiprocessing\spawn.py", line 59, in _wrap
    fn(i, *args)
  File "E:\ZSL\ICT-main\Transformer\main.py", line 57, in main_worker
    IGPT_model=GPT(model_config)
  File "E:\ZSL\ICT-main\Transformer\models\model.py", line 142, in __init__
    self.blocks = nn.Sequential(*[Block_2(config) for _ in range(config.n_layer)])
  File "E:\ZSL\ICT-main\Transformer\models\model.py", line 142, in <listcomp>
    self.blocks = nn.Sequential(*[Block_2(config) for _ in range(config.n_layer)])
  File "E:\ZSL\ICT-main\Transformer\models\model.py", line 88, in __init__
    self.attn = CausalSelfAttention(config)
  File "E:\ZSL\ICT-main\Transformer\models\model.py", line 44, in __init__
    self.register_buffer("mask", torch.tril(torch.ones(config.block_size, config.block_size))
RuntimeError: [enforce fail at ..\c10\core\CPUAllocator.cpp:75] data. DefaultCPUAllocator: not enough memory: you tried to allocate 17179869184 bytes. Buy new RAM!
What is causing this problem?
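
For reference (my reading of the traceback, not an official answer): the failed allocation is exactly the causal attention buffer built at model.py line 44. With --image_size 256 the token sequence has 256 × 256 = 65,536 elements, so the float32 tril mask alone needs 65,536² × 4 bytes = 17,179,869,184 bytes, i.e. 16 GiB. The transformer stage is meant for low resolutions such as 32 or 48 (see the tips above); the guided upsampler handles 256x256.

# Size of torch.tril(torch.ones(block_size, block_size)) in float32:
seq_len = 256 * 256              # --image_size 256 -> 65,536 tokens
print(seq_len * seq_len * 4)     # 17179869184 bytes (16 GiB)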

Google Colab demo

I've been playing around trying to create a minimal Google Colab demo for this repository, but am running into some error messages that you can see here:

As you can see, the ImageNet inference command (same as in README) throws an error like this:

RuntimeError: Error(s) in loading state_dict for InpaintGenerator_5:
	Missing key(s) in state_dict: "encoder.1.weight", "encoder.1.bias", "encoder.3.weight", ...............
	Unexpected key(s) in state_dict: "module.encoder.1.weight", "module.encoder.1.bias", ...............
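
For reference, this kind of mismatch is the usual DataParallel artifact: the checkpoint was saved from a wrapped model, so every key carries a module. prefix that the bare model does not expect. A common workaround (an assumption on my part, not necessarily the repo's intended fix) is to strip the prefix before loading:

import torch

# `model` is assumed to be the already-instantiated generator.
state_dict = torch.load("ckpt.pth", map_location="cpu")
state_dict = {k[len("module."):] if k.startswith("module.") else k: v
              for k, v in state_dict.items()}
model.load_state_dict(state_dict)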

I haven't dug into the other (FFHQ) error yet - it says something about transparency, which may be a mistake on my part in preparing the input image and mask (image, mask).

Thanks for your work on this repo and publicly releasing the code and pre-trained models! Can't wait to try it out.

Obtaining completion probability maps (from paper)

Hello,

First of all, super cool model, and thanks for being so helpful with past questions. I was just wondering specifically how you generated the pixel-wise completion probability maps in Fig. 9 of your paper (I get how it's done in theory; I just wanted to see the code if possible).

Thanks!
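
For what it's worth, one plausible way to build such a map (a sketch under assumed shapes, not the authors' code) is to read off, from the transformer's softmax output, the probability assigned to the token actually placed at each masked position, then upsample the grid:

import torch
import torch.nn.functional as F

# Assumed: per-position logits for one 32x32 token grid over a 512-token vocabulary.
logits = torch.randn(32 * 32, 512)                  # placeholder logits
tokens = logits.argmax(dim=-1)                      # token placed at each position

probs = F.softmax(logits, dim=-1)                   # distribution over the vocabulary
conf = probs.gather(1, tokens[:, None]).squeeze(1)  # probability of the placed token

prob_map = conf.view(1, 1, 32, 32)                  # low-resolution confidence map
prob_map = F.interpolate(prob_map, size=(256, 256), mode="bilinear",
                         align_corners=False)[0, 0] # pixel-wise map at full size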

Transformer training problem

Hello,
Congrats on your nice work.
I use a 16 GB GPU, but a single card can only fit batch_size 3. I turned on mixed-precision training, and the other settings are --n_layer 35 --n_embd 512 --n_head 8, the same as your model trained on Places2.
So I want to know: how do you use 8 GPUs and set the batch_size to 64 to train the transformer model?

Train on own dataset

Hello,

Congrats on your amazing work. Could you explain the dataset specification for training on one's own dataset?

  • Which resolution?
  • How many images?
  • Which specification for creating the masks?
  • Which size for the masks?

Could you provide a small example dataset with a few samples?

And finally, what is the folder structure for the dataset?

Have a good day,

Training question

Hello, does the code support debugging on Windows? And how should I modify the configuration files to train on my own dataset?

Will your kmeans data work well on other domains?

Hi,
Thanks for sharing this great work! I have a question about kmeans_centers.npy.
According to your paper, you use clustering data generated from ImageNet to reduce the computational cost.

To further reduce the dimension and faithfully re-represent the low-resolution image, 
an extra visual vocabulary with spatial size 512 × 3 is generated using KMeans cluster centers of the whole ImageNet [8] RGB pixel spaces. 

Will your clustering data work well for other domains (like faces, paintings or maps)?
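
For context, a minimal sketch of how such a pixel vocabulary is typically applied (assuming kmeans_centers.npy stores 512 cluster centers of shape (512, 3) in the same scale as the image pixels):

import numpy as np

centers = np.load("kmeans_centers.npy")    # (512, 3) RGB cluster centers

def quantize(image):
    # Map each pixel of an (H, W, 3) image to the index of its nearest center.
    pixels = image.reshape(-1, 3).astype(np.float32)
    dists = ((pixels[:, None, :] - centers[None, :, :]) ** 2).sum(-1)
    return dists.argmin(axis=1).reshape(image.shape[:2])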

Pretrained Model link not working

Hello, thank you for the great work. I tried to download the pretrained model but I get this error:

--2021-11-23 18:50:20-- https://www.dropbox.com/s/cqjgcj0serkbdxd/ckpts_ICT.zip?dl=1
Resolving www.dropbox.com (www.dropbox.com)... 162.125.3.18, 2620:100:6018:18::a27d:312
Connecting to www.dropbox.com (www.dropbox.com)|162.125.3.18|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: /s/dl/cqjgcj0serkbdxd/ckpts_ICT.zip [following]
--2021-11-23 18:50:20-- https://www.dropbox.com/s/dl/cqjgcj0serkbdxd/ckpts_ICT.zip
Reusing existing connection to www.dropbox.com:443.
HTTP request sent, awaiting response... 404 Not Found
2021-11-23 18:50:20 ERROR 404: Not Found.

The link doesn't seem to work.

How much RAM does inference need?

I tried to get it working in Google Colab, but it seems that 25 GB of RAM is not enough; it crashes during the first step. How much RAM is expected?

Attention Mask

Hi, thanks for your nice work.
There are some details that bothered me. I would appreciate it if you could give me some advice.

  1. BERT choice
    I noticed the --BERT option is used in all transformer training and inference processes. In which situations do we not need to select this option?
  2. attention mask
    As the paper describes, the transformer model captures the unmasked information to predict the probability distribution of the missing regions. But in the CausalSelfAttention code, I found that the model attends to all positions, and no attention mask is applied apart from the input occlusion. How can it be guaranteed to attend only to the unmasked information?
  3. auto-regressive
    As far as I understand, the model generates all masked pixels in an end-to-end fashion rather than auto-regressively, and during inference it generates one pixel per iteration to improve sampling quality. If it works as I described, how do we guarantee the pixel quality of the first masked position in each iteration during inference?

Thanks again.

Why not end-to-end network?

Thank you for proposing this good idea of using a transformer as prior information!
But why not use an end-to-end network for training? Is it because the results are not good?

About the test

Hello, I finished the two stages of training and am ready to test. The test path is correct, but the images in the test set cannot be loaded. Can you help me? Thank you!

Minimum recommended GPUs

Hello, what do you think are the minimum recommended GPU specs (memory, etc.) for good performance, both for training on a new dataset and for testing the pluralistic completion? Thanks!

RuntimeError: Expected object of scalar type Long but got scalar type Float for argument

Hi, thank you for sharing code.

I want to run the training code on grayscale images, but I got the following error.

# Mask is 12022, # Image is 12022
# Mask is 12022, # Image is 0
Warnning: There is no trained model found. An initialized model will be used.
Warnning: There is no previous optimizer found. An initialized optimizer will be used.
Resume from Epoch 0
Traceback (most recent call last):
  File "main.py", line 139, in <module>
    mp.spawn(main_worker, nprocs=opts.gpus, args=(opts,))
  File "/home/naoki/.pyenv/versions/3.8.6/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 230, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/home/naoki/.pyenv/versions/3.8.6/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 188, in start_processes
    while not context.join():
  File "/home/naoki/.pyenv/versions/3.8.6/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 150, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException:

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/home/naoki/.pyenv/versions/3.8.6/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 59, in _wrap
    fn(i, *args)
  File "/home/naoki/ICT/Transformer/main.py", line 73, in main_worker
    trainer.train(loaded_ckpt)
  File "/home/naoki/ICT/Transformer/DDP_trainer.py", line 203, in train
    run_epoch('train')
  File "/home/naoki/ICT/Transformer/DDP_trainer.py", line 139, in run_epoch
    logits, loss = model(x, y)
  File "/home/naoki/.pyenv/versions/3.8.6/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/naoki/.pyenv/versions/3.8.6/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 799, in forward
    output = self.module(*inputs[0], **kwargs[0])
  File "/home/naoki/.pyenv/versions/3.8.6/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/naoki/ICT/Transformer/models/model.py", line 254, in forward
    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
  File "/home/naoki/.pyenv/versions/3.8.6/lib/python3.8/site-packages/torch/nn/functional.py", line 2824, in cross_entropy
    return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #2 'target' in call to _thnn_nll_loss_forward

I checked the sizes and data types before the cross-entropy loss:

logits: torch.Size([12, 1024, 512]), torch.float32
targets: torch.Size([12, 1024]), torch.float32

Could you tell me how to solve this problem?
Thank you in advance.
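
For reference, a likely cause (my reading, not an official answer): F.cross_entropy expects targets to be integer class indices of dtype torch.long, while this grayscale pipeline produces float targets. A minimal sketch of the cast:

# Cast the float targets to class indices before the loss.
loss = F.cross_entropy(logits.view(-1, logits.size(-1)),
                       targets.view(-1).long())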

Transformer train loss leads to Nan

Hi, @raywzy

I am trying to train the model on ImageNet with the following setting:
--data_path /ILSVRC2012/train --validation_path /ILSVRC2012/val --mask_path /PConv-Keras/irregular_mask/train_mask --batch_size 32 --train_epoch 100 --nodes 1 --gpus 8 --node_rank 0 --n_layer 35 --n_embd 1024 --n_head 8 --GELU_2 --image_size 32 --use_ImageFolder

But I am getting NaN for the train and test loss (screenshot attached).

This is happening for smaller datasets like Paris StreetView (#train_images: 14900, #test_images: 100) as well (screenshot attached).

Any suggestions on how to fix this issue?
