diyiiyiii / StyTR-2
StyTr2: Image Style Transfer with Transformers
Hello, and thank you for open-sourcing this.
I ran test.py as described and supplied a content image and a style image, but nothing is generated in the output folder and no error is reported. What could be the cause?
Hello, thank you very much for sharing your code.
I have a question about the loss weights: how did you determine them?
Were they found purely by experiment, or through some other method?
Hello, I would like to know when you will provide the code.
Thanks for sharing your code. It's wonderful work, I think.
I have a question about the content loss score. I applied StyTr2 to a dataset of 800 images using your pre-trained model. To stay consistent with your test settings, I resized all images to 256x256 before computing the content loss. However, I noticed significant differences in the content loss values compared to what is reported in your paper.
I understand that some variation is expected when different images are used. Still, the style loss scores follow a similar trend, while the content loss scores show noticeable discrepancies. May I ask how you calculate the content loss? Could you share your metric code, or tell me where I am going wrong? My script is below.
#!/usr/bin/env python3
import argparse
import os
import torch
import torch.nn as nn
from tqdm import tqdm
import cv2
parser = argparse.ArgumentParser()
parser.add_argument("--resize", type=int, default=256, help="resize_image_size")
parser.add_argument("--content_dir", default=r'\input\content', help="the directory of content images")
parser.add_argument("--style_dir", default=r'\input\style', help="the directory of style images")
parser.add_argument("--stylized_dir", default=r\StyTR-2-main\output', required=False, help="the directory of stylized images")
parser.add_argument("--log_path", default=r't\metrics', required=False, help="the directory of stylized images")
parser.add_argument('--mode', type=int, default=1, help="0 for style loss, 1 for content loss, 2 for both")
args = parser.parse_args()
device = torch.device("cuda")
vgg = nn.Sequential(
    nn.Conv2d(3, 3, (1, 1)),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(3, 64, (3, 3)),
    nn.ReLU(),  # relu1-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 64, (3, 3)),
    nn.ReLU(),  # relu1-2
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 128, (3, 3)),
    nn.ReLU(),  # relu2-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 128, (3, 3)),
    nn.ReLU(),  # relu2-2
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 256, (3, 3)),
    nn.ReLU(),  # relu3-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-4
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 512, (3, 3)),
    nn.ReLU(),  # relu4-1 (used for the content loss)
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-4
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-1 (the last layer this script actually uses)
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU()  # relu5-4
)
vgg.load_state_dict(torch.load("../models/vgg_normalised.pth", map_location=device))
vgg.eval()
enc_1 = nn.Sequential(*list(vgg.children())[:4]) # input -> relu1_1
enc_2 = nn.Sequential(*list(vgg.children())[4:11]) # relu1_1 -> relu2_1
enc_3 = nn.Sequential(*list(vgg.children())[11:18]) # relu2_1 -> relu3_1
enc_4 = nn.Sequential(*list(vgg.children())[18:31]) # relu3_1 -> relu4_1
enc_5 = nn.Sequential(*list(vgg.children())[31:44]) # relu4_1 -> relu5_1
enc_1.to(device)
enc_2.to(device)
enc_3.to(device)
enc_4.to(device)
enc_5.to(device)
def calc_content_loss(input, target):
    assert input.size() == target.size()
    return nn.MSELoss()(input, target)
content_dir = args.content_dir
style_dir = args.style_dir
stylized_dir = args.stylized_dir
log_dir = args.log_path
stylized_files = os.listdir(stylized_dir)
folder_components = stylized_dir.split(os.path.sep)
name = folder_components[-2]
sub_name = folder_components[-1]
log_path = os.path.join(args.log_path, name + '_log.txt')
with torch.no_grad():
    if args.mode == 1 or args.mode == 2:
        loss_c_sum = 0.
        count = 0
        for stylized in tqdm(stylized_files):
            stylized_path = os.path.join(stylized_dir, stylized)
            stylized_img = cv2.imread(stylized_path)  # stylized image
            if stylized_img is None or stylized_img.size == 0:
                print('Failed to load the image:', stylized_path)
                continue
            stylized_img = cv2.resize(stylized_img, (args.resize, args.resize))
            content_name = stylized.split("_stylized_")[0]  # parse the content image's name
            content_path = os.path.join(content_dir, content_name + '.jpg')
            content_img = cv2.imread(content_path)  # content image
            if content_img is None or content_img.size == 0:
                print('Failed to load the image:', content_path)
                continue
            content_img = cv2.resize(content_img, (args.resize, args.resize))
            # cv2.imread returns BGR; the normalised VGG expects RGB input in [0, 1]
            stylized_img = cv2.cvtColor(stylized_img, cv2.COLOR_BGR2RGB)
            content_img = cv2.cvtColor(content_img, cv2.COLOR_BGR2RGB)
            stylized_img = torch.tensor(stylized_img, dtype=torch.float) / 255
            stylized_img = stylized_img.unsqueeze(0).permute([0, 3, 1, 2]).to(device)
            content_img = torch.tensor(content_img, dtype=torch.float) / 255
            content_img = content_img.unsqueeze(0).permute([0, 3, 1, 2]).to(device)
            loss_c = 0.
            # content loss is averaged over the relu4_1 and relu5_1 features
            o1 = enc_4(enc_3(enc_2(enc_1(stylized_img))))
            c1 = enc_4(enc_3(enc_2(enc_1(content_img))))
            loss_c += calc_content_loss(o1, c1)
            o2 = enc_5(o1)
            c2 = enc_5(c1)
            loss_c += calc_content_loss(o2, c2)
            print("Content Loss: {}".format(loss_c / 2))
            loss_c_sum += float(loss_c / 2)
            count += 1
        print("Total num: {}".format(count))
        print("Average Content Loss: {}".format(loss_c_sum / count))
Hi, is there a way to generate the output image while keeping the same aspect ratio and resolution as the input content image? It seems that currently only square images are supported.
For example, if I have a 1920x1080 content image and a 1000x500 style image, could I get a 1920x1080 output as well?
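One workaround, if the model itself can consume non-square inputs, is to pad the content image so that its height and width are multiples of the patch stride, run inference, and crop the output back. A minimal sketch, not code from this repository; the multiple of 8 and the helper name are assumptions:

import torch.nn.functional as F

def pad_to_multiple(img, multiple=8):
    # img: (N, C, H, W); pad right/bottom so the patch embedding divides evenly,
    # then crop the stylized output back to (h, w) afterwards
    _, _, h, w = img.shape
    pad_h = (multiple - h % multiple) % multiple
    pad_w = (multiple - w % multiple) % multiple
    return F.pad(img, (0, pad_w, 0, pad_h), mode='reflect'), (h, w)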
pillow<=7.2.0 is known to work; I am not sure about other versions.
Also, what should I do about AttributeError: 'tuple' object has no attribute 'cpu'?
The source code obtains patches with a convolution, so each patch is not actually cut out of the image, and the patches seem to have no relative geometric positions on the image. Is the positional encoding then introduced only as a supplement to the content information? There also seems to be no positional relationship between patches. I don't quite understand this and would appreciate your explanation, thanks!
When will the code be available?
It seems that the PatchEmbed module embeds images by dividing them into patches, implemented with a dilated convolution and a large kernel size. But how does it realize the locality demonstrated in Fig. 2 of the paper? That is, the patches in the framework figure are divided continuously on the pixel grid, but I don't see any pixel-partition operation in the code. Or is this just my misunderstanding of the paper?
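For what it's worth, a standard ViT-style patch embedding is exactly a strided convolution: when the kernel size equals the stride, each output location sees one non-overlapping patch of the pixel grid and nothing else, so the split and the linear projection happen in a single op. A minimal sketch with illustrative sizes, not the exact configuration of this repository's PatchEmbed:

import torch
import torch.nn as nn

x = torch.randn(1, 3, 256, 256)                     # dummy image batch
proj = nn.Conv2d(3, 512, kernel_size=8, stride=8)   # split + project in one op
tokens = proj(x)                                    # (1, 512, 32, 32): one token per 8x8 patch
tokens = tokens.flatten(2).transpose(1, 2)          # (1, 1024, 512): patch-token sequence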
Hi, I'm very interested in your work and want to reproduce it, could you share the style dataset?
Thank you very much for this work. Could the author share a more detailed description of the environment configuration? Many thanks.
I replaced the code with yours, but I got the output shown here: ![result](https://user-images.githubusercontent.com/53261018/186422373-0678c37a-8d2d-45ac-97cd-b772c99f0368.jpg)
Originally posted by @phhandong in #10 (comment)
I would like to learn how the parameters in test.py impact the final results. Specifically, style_interpolation_weights and alpha do not seem to be referenced anywhere in the source code. Is there any runtime parameter we can use to tweak the stylization results? Thanks.
I tried to train this code on two RTX 3090s, but it still reports CUDA out of memory. In your paper, you say the model was trained on 2 Tesla V100 and 2 GeForce RTX 3090 GPUs, taking about one day. I want to know whether this code can be trained on two 3090s. Thank you for your contribution.
How can I get the code?
Hello, I have been reading your paper recently and have the following question:
The paper states that "typical CNN-based style transfer methods are biased toward content representation by visualizing the content leak of the stylization process".
Given the three reasons pointed out in ArtFlow (quoted below), doesn't a CNN damage content features? Why is it said to be biased toward content representation?
1. Reconstruction error
Although an image reconstruction loss [32] or a content loss [20] is used to train the decoder, Li et al. [32] acknowledge that the decoder is far from perfect due to the loss of spatial information brought by the pooling operations in the encoder. Consequently, the accumulated image reconstruction error may gradually disturb the content details and lead to the Content Leak.
2. Biased decoder training
Due to Ls, the decoder is trained to trade off between Lc and Ls, rather than trying to reconstruct images perfectly. ...Consequently, the auto-encoder of AdaIN is biased towards rendering more artistic effects, which causes Content Leak. With the increase of the inference rounds, weird artistic patterns gradually appear in the produced results, which indicates that the auto-encoder of AdaIN may memorize image styles in training and bias towards the training styles in inference.
3. Biased style transfer module
Since such a patch replacement is irreversible, fc cannot be recovered from fcs, which makes fcs be biased towards style and consequently causes the Content Leak phenomenon.
Thanks.
Testing
Pretrained models: vgg-model, decoder [Coming SOON], Transformer_module [Coming SOON]
Please download them and put them into the folder ./experiments/
Traceback (most recent call last):
File "test.py", line 174, in
output = output.cpu()
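A common cause of this error is a forward pass that returns a tuple (e.g. the stylized image together with auxiliary losses), in which case the tensor has to be unpacked before calling .cpu(). A hedged guard, where network, content, and style are placeholder names and the image is assumed to be the first tuple element:

out = network(content, style)  # may be a tensor or a tuple
output = out[0] if isinstance(out, tuple) else out
output = output.cpu()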
Hi,
In the paper you wrote that lambda_content is 10 and lambda_style is 7, but in the code the values are inverted.
Also, lambda_l1 is 50 in the paper, but in the provided code it is set to 70.
So what are the correct values for reproducing the results of the paper?
Thanks for your awesome work!
I wonder about the scale of the training dataset in your training stage. More specifically, how many images did you use in your content dataset and style dataset? I would appreciate any further details.
How many GB of GPU memory is appropriate, or which GPU did you use?
Hello, I have a problem when trying to run train.py in Google Colab.
/content/drive/MyDrive/dataset/metal
/content/drive/MyDrive/dataset/cardboard
/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py:566: UserWarning: This DataLoader will create 16 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
cpuset_checked))
Traceback (most recent call last):
File "/content/drive/MyDrive/StyTr/train.py", line 138, in <module>
{'params': network.module.transformer.parameters()},
File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1208, in __getattr__
type(self).__name__, name))
AttributeError: 'StyTrans' object has no attribute 'module'
How can I fix it?
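For context, the .module attribute only exists after a model has been wrapped in nn.DataParallel or DistributedDataParallel; on a bare StyTrans instance the lookup fails. A hedged workaround is to fall back to the unwrapped model (keep whatever parameter groups train.py actually builds; args.lr is a placeholder):

model = network.module if hasattr(network, 'module') else network
optimizer = torch.optim.Adam([
    {'params': model.transformer.parameters()},
    # ...the remaining parameter groups from train.py, unchanged...
], lr=args.lr)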
Is the source code available?
I want to know more about how you implement the calculation of the style loss. Could you please share that part of the code?
Hello, I am having trouble understanding the number of dataset samples used for training.
Hello! I have a question: in your source code, each patch is just a channel-dimension vector of the tensor obtained by feeding the image through a convolution, but the figure in the paper seems to show patches obtained by splitting an image into different regions. Why is that?