
stytr-2's Issues

CUDA out of memory.

I tried to train this code on 2 × RTX 3090, but it still reports CUDA out of memory. In your paper, you say that your model is trained on 2 Tesla V100 and 2 GeForce RTX 3090 GPUs in about one day. I want to know whether this code can be trained on 2 × 3090. Thank you for your contribution.

code

Is the source code available?

inference parameters

I would like to learn how the parameters in test.py affect the final prediction results. Specifically, style_interpolation_weights and alpha seem to have no reference anywhere in the source code. Are there any runtime parameters we can use to tweak the stylization results? Thanks.
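
For context, flags with these names usually follow the AdaIN convention, where alpha blends the stylized features back toward the content features before decoding; whether this repo actually wires them up is exactly the question above. A minimal sketch of that convention (function and tensor names are illustrative, not from this codebase):

    import torch

    def blend_features(content_feat: torch.Tensor,
                       stylized_feat: torch.Tensor,
                       alpha: float = 1.0) -> torch.Tensor:
        """Convex blend between content and stylized features.
        alpha = 1.0 keeps the full stylization; alpha = 0.0
        returns the content features unchanged."""
        assert 0.0 <= alpha <= 1.0
        return alpha * stylized_feat + (1.0 - alpha) * content_feat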

About the patch partition.

It seems that the PatchEmbed module embeds images and divides them into patches, implemented with a dilated convolution and a large kernel size. But how does it realize the locality demonstrated in Fig. 2 of the paper? That is, the patches in the framework figure are cut contiguously from the pixel grid, but I don't see any pixel-partition operation in the code. Or is it just my misunderstanding of the paper?
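
For reference, a convolution whose kernel size and stride both equal the patch size touches each pixel exactly once, so it is equivalent to cutting the image into non-overlapping patches and linearly projecting each one. A minimal sketch of that standard construction (not the repo's exact module, which the question notes uses dilation):

    import torch.nn as nn

    class SimplePatchEmbed(nn.Module):
        """Split an image into non-overlapping patches with a strided conv.
        A Conv2d with kernel_size == stride == patch_size touches each
        pixel exactly once, i.e. a per-patch linear projection on the
        pixel grid."""
        def __init__(self, patch_size=8, in_chans=3, embed_dim=512):
            super().__init__()
            self.proj = nn.Conv2d(in_chans, embed_dim,
                                  kernel_size=patch_size, stride=patch_size)

        def forward(self, x):                    # x: (B, 3, H, W)
            x = self.proj(x)                     # (B, C, H/p, W/p)
            return x.flatten(2).transpose(1, 2)  # (B, N, C) patch tokens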

Hello, author! Some questions about the positional encoding

The source code obtains patches with a convolution, so each patch is not actually cut out of the image, and a patch seemingly has no geometric position relative to the image. Is the positional encoding introduced here then only a supplement to the content information? There also seems to be no positional relationship between patches, so I don't quite understand. Hoping for your answer, thanks!

A question about the transformer built in the code

The shape of pos_embed_c in the code seems to differ from the shape of content after interpolation (unless content and style have the same shape). The two are added together in the encode_c forward pass, so wouldn't this raise an error? Looking forward to the author's reply, thanks!
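
For what it's worth, the usual way to avoid such a shape mismatch is to bilinearly resize the positional embedding to the spatial size of the feature map it is added to; a hedged sketch of that pattern (names are illustrative, not the repo's code):

    import torch.nn.functional as F

    def add_pos_embed(feat, pos_embed):
        """Bilinearly resize a (1, C, Hp, Wp) positional embedding to the
        spatial size of the (B, C, H, W) feature map, then add it."""
        if pos_embed.shape[-2:] != feat.shape[-2:]:
            pos_embed = F.interpolate(pos_embed, size=feat.shape[-2:],
                                      mode='bilinear', align_corners=False)
        return feat + pos_embed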

Arbitrary output size instead of square

Hi, is there a way to generate the output image while keeping the same aspect ratio and resolution as the input content image? It seems it currently only supports square images.

For example, given a 1920x1080 content image and a 1000x500 style image, I would like the output image to be 1920x1080 as well.
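
One possible workaround, assuming the network itself accepts non-square inputs, is to resize the content image so both sides keep their proportion but are snapped to a multiple of the patch size; a sketch (the patch size of 8 here is an assumption to check against the repo):

    from PIL import Image

    def resize_keep_ratio(img, long_side=1920, multiple=8):
        """Resize so the longer side is `long_side`, keeping the aspect
        ratio and snapping both sides to a multiple of the patch size."""
        w, h = img.size
        scale = long_side / max(w, h)
        new_w = max(multiple, round(w * scale / multiple) * multiple)
        new_h = max(multiple, round(h * scale / multiple) * multiple)
        return img.resize((new_w, new_h), Image.BICUBIC)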

About the scale of the training dataset

Thanks for your awesome work!
I wonder about the scale of the dataset used during your training stage. More specifically, how many images did you use from your content dataset and your style dataset? I would appreciate any further details.

About loss weight setting

Hello, thank you very much for sharing your code.
I have a question about the loss weight settings: how did you determine the weight of each loss term?
Did you find them by experiment or through some other means?
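
For reference, the paper's total objective is a weighted sum of a content term, a style term, and two identity terms, so the weights only set the relative trade-off between them (and are indeed usually tuned empirically). A schematic sketch, with the default lambdas taken from the values reported in the paper (see the issue further down on how the released code differs):

    def total_loss(loss_c, loss_s, loss_id1, loss_id2,
                   lambda_c=10.0, lambda_s=7.0,
                   lambda_id1=50.0, lambda_id2=1.0):
        """Weighted sum of the four loss terms in the paper's objective.
        Defaults follow the values stated in the paper; the released
        code reportedly uses different numbers (see the issue below)."""
        return (lambda_c * loss_c + lambda_s * loss_s
                + lambda_id1 * loss_id1 + lambda_id2 * loss_id2)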

Hi, please help me

Excuse me, could you please help me figure out why my prediction results look like this?

[attached result: COCO_val2014_000000000192_stylized_anthony-van-dyck_portrait-of-prince-charles-louis-elector-palatine-1641]

Error in running "train.py"

Hello. I have a problem when trying to run train.py in Google Colab.

    /content/drive/MyDrive/dataset/metal
    /content/drive/MyDrive/dataset/cardboard
    /usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py:566: UserWarning: This DataLoader will create 16 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
      cpuset_checked))
    Traceback (most recent call last):
      File "/content/drive/MyDrive/StyTr/train.py", line 138, in <module>
        {'params': network.module.transformer.parameters()},
      File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1208, in __getattr__
        type(self).__name__, name))
    AttributeError: 'StyTrans' object has no attribute 'module'

How can I fix it?
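
In case it helps other readers: the .module attribute only exists after a model has been wrapped in nn.DataParallel (or DistributedDataParallel). On a single GPU you can either add that wrapper or drop the indirection. A hedged sketch of the latter, as a drop-in for the optimizer construction around line 138 of train.py (only the transformer parameter group from the traceback is shown; args.lr is assumed from the script's flags):

    import torch

    # `.module` is added by nn.DataParallel / DistributedDataParallel,
    # e.g. network = torch.nn.DataParallel(network). Without the wrapper,
    # access the submodules directly:
    net = network.module if hasattr(network, 'module') else network
    optimizer = torch.optim.Adam(
        [{'params': net.transformer.parameters()}],  # other groups elided
        lr=args.lr)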

Loss weight values different from the paper

Hi,
In the paper you write that lambda_content is 10 and lambda_style is 7, but in the code the values are swapped.
Also, lambda_l1 is 50 in the paper but is set to 70 in the provided code.
So what are the correct values for reproducing the results of the paper?

Request for elaboration regarding the Dataset

Hello, I am having an issue understanding the number of dataset samples used for training.

  1. The COCO 2014 dataset contains 82k training images. Did you train on all of these or on a selected subset?
  2. The WikiArt dataset is huge, and not all of its images are suitable for style transfer tasks. Did you manually select and prepare the style dataset, and what is its final size?

content leak

Hello, I have recently been reading your paper and have the following question:
In this paper, "typical CNN-based style transfer methods are biased toward content representation by visualizing the content leak of the stylization process",
Don't CNNs damage content features, for the three reasons given in ArtFlow? Why then are they said to be biased toward content representation?

1. Reconstruction error
Although an image reconstruction loss [32] or a content loss [20] is used to train the decoder, Li et al. [32] acknowledge that the decoder is far from perfect due to the loss of spatial information brought by the pooling operations in the encoder. Consequently, the accumulated image reconstruction error may gradually disturb the content details and lead to the Content Leak.
2. Biased decoder training
Due to Ls, the decoder is trained to trade off between Lc and Ls, rather than trying to reconstruct images perfectly. ...Consequently, the auto-encoder of AdaIN is biased towards rendering more artistic effects, which causes Content Leak. With the increase of the inference rounds, weird artistic patterns gradually appear in the produced results, which indicates that the auto-encoder of AdaIN may memorize image styles in training and bias towards the training styles in inference.
3. Biased style transfer module
Since such a patch replacement is irreversible, fc cannot be recovered from fcs, which makes fcs be biased towards style and consequently causes the Content Leak phenomenon.
Thanks.

Error when running train.py

When I run train.py to train on my own dataset, an error is reported.

[error screenshot attached]

I have no idea how to deal with this error; I don't know whether my PyTorch environment is wrong.

[screenshot attached]

Moreover, I used an RTX 4090 to run train.py; I thought the memory would be enough.

code

Hello, I would like to know when you will provide the code.

About the metric score of StyTr2: Image Style Transfer with Transformers

Thanks for sharing your code. It's wonderful work, I think.

I have one question about the content loss score. I applied StyTr2 to a dataset of 800 images using your pre-trained model. To be consistent with your test settings, I resized all images to 256x256 before calculating the content loss. However, I noticed significant differences in the content loss values compared to what is reported in your paper.

I understand that variations in scores are expected when different images are used. Nonetheless, the style loss scores follow a similar trend, while the content loss scores show noticeable discrepancies. So may I ask how you calculate the content loss? Could you share your metric code or tell me where I went wrong? My script is below.

    #!/usr/bin/env python3
    import argparse
    import os
    import torch
    import torch.nn as nn
    from tqdm import tqdm
    import cv2
    
    parser = argparse.ArgumentParser()
    parser.add_argument("--resize", type=int, default=256, help="resize_image_size")
    parser.add_argument("--content_dir", default=r'\input\content', help="the directory of content images")
    parser.add_argument("--style_dir", default=r'\input\style', help="the directory of style images")
    parser.add_argument("--stylized_dir", default=r\StyTR-2-main\output', required=False, help="the directory of stylized images")
    parser.add_argument("--log_path", default=r't\metrics', required=False, help="the directory of stylized images")
    parser.add_argument('--mode', type=int, default=1, help="0 for style loss, 1 for content loss, 2 for both")
    args = parser.parse_args()
    
    device = torch.device("cuda")
    vgg = nn.Sequential(
        nn.Conv2d(3, 3, (1, 1)),
        nn.ReflectionPad2d((1, 1, 1, 1)),
        nn.Conv2d(3, 64, (3, 3)),
        nn.ReLU(),  # relu1-1
        nn.ReflectionPad2d((1, 1, 1, 1)),
        nn.Conv2d(64, 64, (3, 3)),
        nn.ReLU(),  # relu1-2
        nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
        nn.ReflectionPad2d((1, 1, 1, 1)),
        nn.Conv2d(64, 128, (3, 3)),
        nn.ReLU(),  # relu2-1
        nn.ReflectionPad2d((1, 1, 1, 1)),
        nn.Conv2d(128, 128, (3, 3)),
        nn.ReLU(),  # relu2-2
        nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
        nn.ReflectionPad2d((1, 1, 1, 1)),
        nn.Conv2d(128, 256, (3, 3)),
        nn.ReLU(),  # relu3-1
        nn.ReflectionPad2d((1, 1, 1, 1)),
        nn.Conv2d(256, 256, (3, 3)),
        nn.ReLU(),  # relu3-2
        nn.ReflectionPad2d((1, 1, 1, 1)),
        nn.Conv2d(256, 256, (3, 3)),
        nn.ReLU(),  # relu3-3
        nn.ReflectionPad2d((1, 1, 1, 1)),
        nn.Conv2d(256, 256, (3, 3)),
        nn.ReLU(),  # relu3-4
        nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
        nn.ReflectionPad2d((1, 1, 1, 1)),
        nn.Conv2d(256, 512, (3, 3)),
        nn.ReLU(),  # relu4-1
        nn.ReflectionPad2d((1, 1, 1, 1)),
        nn.Conv2d(512, 512, (3, 3)),
        nn.ReLU(),  # relu4-2
        nn.ReflectionPad2d((1, 1, 1, 1)),
        nn.Conv2d(512, 512, (3, 3)),
        nn.ReLU(),  # relu4-3
        nn.ReflectionPad2d((1, 1, 1, 1)),
        nn.Conv2d(512, 512, (3, 3)),
        nn.ReLU(),  # relu4-4
        nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
        nn.ReflectionPad2d((1, 1, 1, 1)),
        nn.Conv2d(512, 512, (3, 3)),
        nn.ReLU(),  # relu5-1
        nn.ReflectionPad2d((1, 1, 1, 1)),
        nn.Conv2d(512, 512, (3, 3)),
        nn.ReLU(),  # relu5-2
        nn.ReflectionPad2d((1, 1, 1, 1)),
        nn.Conv2d(512, 512, (3, 3)),
        nn.ReLU(),  # relu5-3
        nn.ReflectionPad2d((1, 1, 1, 1)),
        nn.Conv2d(512, 512, (3, 3)),
        nn.ReLU()  # relu5-4
    )
    
    # load the normalised VGG weights before switching to eval mode
    vgg.load_state_dict(torch.load("../models/vgg_normalised.pth"))
    vgg.eval()
    
    enc_1 = nn.Sequential(*list(vgg.children())[:4])  # input -> relu1_1
    enc_2 = nn.Sequential(*list(vgg.children())[4:11])  # relu1_1 -> relu2_1
    enc_3 = nn.Sequential(*list(vgg.children())[11:18])  # relu2_1 -> relu3_1
    enc_4 = nn.Sequential(*list(vgg.children())[18:31])  # relu3_1 -> relu4_1
    enc_5 = nn.Sequential(*list(vgg.children())[31:44])  # relu4_1 -> relu5_1
    
    enc_1.to(device)
    enc_2.to(device)
    enc_3.to(device)
    enc_4.to(device)
    enc_5.to(device)
    
    def calc_content_loss(input, target):
        assert (input.size() == target.size())
        return torch.nn.MSELoss()(input, target)
    
    content_dir = args.content_dir
    style_dir = args.style_dir
    stylized_dir = args.stylized_dir
    log_dir = args.log_path
    
    stylized_files = os.listdir(stylized_dir)
    folder_components = stylized_dir.split(os.path.sep)
    name = folder_components[-2]
    sub_name = folder_components[-1]
    log_path = os.path.join(args.log_path, name + '_log.txt')
    
    with torch.no_grad():
        if args.mode == 1 or args.mode == 2:
            loss_c_sum = 0.
            count = 0
    
            for i, stylized in enumerate(tqdm(stylized_files)):
                stylized_img = cv2.imread(stylized_dir + os.sep + stylized)   # stylized image
                if stylized_img is None or stylized_img.size == 0:
                    print('Failed to load the image:', stylized_dir + os.sep + stylized)
                    continue
                stylized_img = cv2.resize(stylized_img, (args.resize, args.resize))
                # note: cv2.imread returns BGR; the normalised VGG is usually fed
                # RGB in [0, 1], so convert before encoding
                stylized_img = cv2.cvtColor(stylized_img, cv2.COLOR_BGR2RGB)
                parts = stylized.split("_stylized_")  # parse the content image's name
                content_img = cv2.imread(content_dir + os.sep + parts[0] + '.jpg')   # content image
                if content_img is None or content_img.size == 0:
                    print('Failed to load the image:', content_dir + os.sep + parts[0] + '.jpg')
                    continue
    
                content_img = cv2.resize(content_img, (args.resize, args.resize))
                content_img = cv2.cvtColor(content_img, cv2.COLOR_BGR2RGB)
    
                stylized_img = torch.tensor(stylized_img, dtype=torch.float)
                stylized_img = stylized_img / 255
                stylized_img = torch.unsqueeze(stylized_img, dim=0)
                stylized_img = stylized_img.permute([0, 3, 1, 2])  # NHWC -> NCHW
                stylized_img = stylized_img.to(device)
    
                content_img = torch.tensor(content_img, dtype=torch.float)
                content_img = content_img / 255
                content_img = torch.unsqueeze(content_img, dim=0)
                content_img = content_img.permute([0, 3, 1, 2])  # NHWC -> NCHW
                content_img = content_img.to(device)
    
                loss_c = 0.
    
                o1 = enc_4(enc_3(enc_2(enc_1(stylized_img))))
                c1 = enc_4(enc_3(enc_2(enc_1(content_img))))
    
                loss_c += calc_content_loss(o1, c1)
    
                o2 = enc_5(o1)
                c2 = enc_5(c1)
                loss_c += calc_content_loss(o2, c2)
    
                print("Content Loss: {}".format(loss_c / 2))
                loss_c_sum += float(loss_c / 2)
                count += 1
    
            print("Total num: {}".format(count))
            print("Average Content Loss: {}".format(loss_c_sum / count))

dataset

Hi, I'm very interested in your work and want to reproduce it. Could you share the style dataset?

Which versions of the packages are required?

I know pillow<=7.2.0 is needed, but I'm not sure about the others.
Also, what should I do about AttributeError: 'tuple' object has no attribute 'cpu'?
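
On the AttributeError: it usually means a forward pass returns a tuple (for example, the stylized image plus auxiliary losses) while the calling code still treats the result as a single tensor. A hedged illustration of the usual fix (network, content, and style are illustrative names):

    def first_tensor(out):
        """If a forward pass returns (stylized, aux, ...) instead of a
        single tensor, unpack the first element before calling .cpu()."""
        return out[0] if isinstance(out, tuple) else out

    # usage: stylized = first_tensor(network(content, style)).cpu()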

Experimental setup

How much GPU memory is appropriate, or which GPUs did you use?

No results are generated

Hello, and thanks for open-sourcing the code.
Following test.py, after passing in the content and style images, nothing is generated in the output directory and no error is reported. What could be the reason?
