strotss's Introduction

Style Transfer by Relaxed Optimal Transport and Self-Similarity (STROTSS)

Code for the paper https://arxiv.org/abs/1904.12785 (CVPR 2019)

webdemo: http://style.ttic.edu/

UPDATE 5/8/2020: David Futschik (https://github.com/futscdav) very kindly pointed out a bug in the feature extraction pipeline: the images were not properly normalized with ImageNet's mean and standard deviation for each color channel. Fixing this dramatically improves results in many cases. He has also implemented a much faster and more memory-efficient version of STROTSS (https://github.com/futscdav/strotss), although it does not support spatial guidance.

Dependencies:

  • python3 >= 3.5
  • pytorch >= 1.0
  • imageio >= 2.2
  • numpy >= 1.1

Usage:

Unconstrained Style Transfer:

python3 styleTransfer.py {PATH_TO_CONTENT} {PATH_TO_STYLE} {CONTENT_WEIGHT} {MAX_SCALES}

The default content weight is 1.0 (for the images provided my personal favorite is 0.5, but 1.0 generally works well for most inputs). The content weight is actually multiplied by 16; see section 2.5 of the paper for an explanation. I recommend running the algorithm with MAX_SCALES set to 5.

The resolution of the output can be set on line 80 of styleTransfer.py. The current scale is 5, which produces outputs that are 512 pixels on the long side; setting it to 4 or 6 produces outputs that are 256 or 1024 pixels on the long side, respectively. Most GPUs will run out of memory for settings of this variable above 6.
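The three scale settings quoted above fit a simple doubling rule. The sketch below is inferred only from those data points (4, 5 and 6 giving 256, 512 and 1024 pixels); the helper name `long_side_for_scale` is hypothetical and is not part of styleTransfer.py.

```python
def long_side_for_scale(scale: int) -> int:
    """Long-side length in pixels for a given scale setting.

    Assumption: extrapolated from the README's data points
    (4 -> 256, 5 -> 512, 6 -> 1024), i.e. the size doubles per step.
    """
    if scale < 1:
        raise ValueError("scale must be a positive integer")
    return 2 ** (scale + 4)


for scale in (4, 5, 6):
    print(scale, long_side_for_scale(scale))
# 4 256
# 5 512
# 6 1024
```

Each step up doubles the long side and so roughly quadruples the pixel count, which is consistent with memory running out quickly above 6.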

The output will appear in the same folder as 'styleTransfer.py' and be named 'output.png'.

Spatially Guided Style Transfer:

python3 styleTransfer.py {PATH_TO_CONTENT} {PATH_TO_STYLE} {CONTENT_WEIGHT} -gr {PATH_TO_CONTENT_GUIDANCE} {PATH_TO_STYLE_GUIDANCE}

Guidance should take the form of two masks, a content mask and a style mask, where regions that you wish to map onto each other have the same color.
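To illustrate the "same color" convention, here is a minimal sketch that pairs up regions across two masks by color. It is not the repository's implementation: the function names are hypothetical, and the masks are represented as plain nested lists of RGB tuples rather than loaded images.

```python
from collections import defaultdict


def regions_by_color(mask):
    """Group pixel coordinates by color.

    `mask` is a list of rows, each row a list of (r, g, b) tuples.
    Returns a dict mapping each color to a set of (row, col) positions.
    """
    regions = defaultdict(set)
    for i, row in enumerate(mask):
        for j, color in enumerate(row):
            regions[tuple(color)].add((i, j))
    return dict(regions)


def match_regions(content_mask, style_mask):
    """Pair content and style regions that share a color.

    Colors that appear in only one of the two masks are ignored.
    """
    content = regions_by_color(content_mask)
    style = regions_by_color(style_mask)
    shared = sorted(set(content) & set(style))
    return {color: (content[color], style[color]) for color in shared}


RED, BLUE, GREEN = (255, 0, 0), (0, 0, 255), (0, 255, 0)
content_mask = [[RED, RED], [BLUE, BLUE]]
style_mask = [[BLUE, RED, GREEN]]
pairs = match_regions(content_mask, style_mask)
print(sorted(pairs))  # [(0, 0, 255), (255, 0, 0)]
```

The two masks do not need to have the same size; only the colors have to agree.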

strotss's Issues

dec_lap_pyr parameter ?

hello nick.
in pyr_lap.py the function dec_lap_pyr takes levs as a parameter.
in st_helper you call the function with levs=5
when using more levels than 5 (as in your defaults) should levs be set to max_scl ?
thanks in advance
luc

error when run python3 styleTransfer.py content.jpg style2.jpg 1

"See the documentation of nn.Upsample for details.".format(mode))
(1024, 848)
torch.Size([3, 512, 424])

(1024, 848)
torch.Size([3, 64, 53])

(2613, 1920, 3)
torch.Size([3, 64, 47])

/usr/local/lib/python3.5/dist-packages/torch/nn/functional.py:2622: UserWarning: nn.functional.upsample_bilinear is deprecated. Use nn.functional.interpolate instead.
warnings.warn("nn.functional.upsample_bilinear is deprecated. Use nn.functional.interpolate instead.")
1
Traceback (most recent call last):
File "styleTransfer.py", line 104, in <module>
loss,canvas = run_st(content_path,style_path,content_weight,max_scl,coords,use_guidance_points,regions)
File "styleTransfer.py", line 59, in run_st
stylized_im, final_loss = style_transfer(stylized_im, content_im, style_path, output_path, scl, long_side, 0., use_guidance=use_guidance, coords=coords, content_weight=content_weight, lr=lr, regions=regions)
File "/home/happyhugo/STROTSS/st_helper.py", line 68, in style_transfer
z_s, style_ims = load_style_folder(phi2, paths, regions,ri, n_samps=-1, subsamps=1000, scale=long_side, inner=5)
File "/home/happyhugo/STROTSS/utils.py", line 321, in load_style_folder
r = F.upsample(r_temp,(style_im.size(3),style_im.size(2)),mode='bilinear')[0,0,:,:].numpy()
File "/usr/local/lib/python3.5/dist-packages/torch/nn/functional.py", line 2458, in upsample
return interpolate(input, size, scale_factor, mode, align_corners)
File "/usr/local/lib/python3.5/dist-packages/torch/nn/functional.py", line 2569, in interpolate
raise NotImplementedError("Got 5D input, but bilinear mode needs 4D input")
NotImplementedError: Got 5D input, but bilinear mode needs 4D input

[Q] Speed up the code

Great project! I am interested in your work and would like to implement it in C#. How can I speed up the code as much as possible? What parameters could I change?

question about processing time at each scale

hello nick
while testing your code I noticed the processing time at each scale.
on the example I'm working on :
scl1 (long side 64): 250 iters takes 87s
scl2 (long side 128): 250 iters takes 63s
scl3 (long side 256): 250 iters takes 54s
scl4 (long side 512): 250 iters takes 60s
scl5 (long side 1024): 250 iters takes 98s

intuitively I would have thought the lower scales should be much faster than the upper scales,
because of the size of the tensors.
would it make sense to use fewer iterations at the lower scales, for example? (CNNMRF gives you the ability to choose the number of iterations at each scale)
just wondering.
luc

about implementation details

Hi, I am not clear about some implementation details. 1024 locations are sampled from the style and content images respectively (so is the output image X_(t) sampled accordingly?). How do you then make sure every location of the output image is optimized? Is it because of the resampling after each step of RMSprop, so that with enough resampling the whole output image gets optimized? Can you help me figure it out?
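A rough way to reason about the resampling guess is to estimate how many locations get drawn at least once. The sketch below assumes every step draws its locations uniformly at random and independently of earlier steps; that sampling model and the helper name `expected_coverage` are assumptions made for illustration, not something confirmed in this thread or taken from the code.

```python
def expected_coverage(n_locations: int, n_samples: int, n_steps: int) -> float:
    """Expected fraction of locations drawn at least once.

    Assumes each step draws `n_samples` distinct locations uniformly at
    random out of `n_locations`, independently of all earlier steps.
    """
    if not 0 < n_samples <= n_locations:
        raise ValueError("need 0 < n_samples <= n_locations")
    if n_steps < 0:
        raise ValueError("n_steps must be non-negative")
    # Probability that one fixed location is missed in a single step.
    p_miss_one_step = 1.0 - n_samples / n_locations
    return 1.0 - p_miss_one_step ** n_steps


# Example: a 512 x 424 grid, 1024 samples per step, 250 steps.
print(round(expected_coverage(512 * 424, 1024, 250), 2))  # 0.69
```

Under this simple model, coverage grows with the number of steps but is not guaranteed to reach every location, so resampling alone may not be the full answer.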

...loss_g?

What does the '_g' suffix in a loss definition mean? For example:

def dp_loss_g(X,Y,GX):

    d = X.size(1)

    X = X.transpose(0,1).contiguous().view(d,-1).transpose(0,1)
    Y = Y.transpose(0,1).contiguous().view(d,-1).transpose(0,1)
    GX = GX.transpose(0,1).contiguous().view(d,-1).transpose(0,1)

    betas,_ = torch.max(torch.pow(get_DMat(X, GX),1),1)
    betas = betas.unsqueeze(1).detach()
    betas = torch.matmul(betas,betas.transpose(0,1))

    Mx = get_DMat(X,X,1.,splits=[X.size(1)])
    Mx = Mx/Mx.sum(0,keepdim=True)

    My = get_DMat(Y,Y,1.,splits=[X.size(1)])
    My = My/My.sum(0,keepdim=True)

    d = torch.abs(betas*(Mx-My)).sum(0).mean()


    return d

and what's the meaning of 'max_scl' in styleTransfer.py?

Hi Nick! Great Paper, I had a small doubt!

In section 2 of the paper, where you introduce the objective function, you say:

We describe the content term of our loss αl_C in Section 2.2, and the style term l_m + l_r + (1/α)l_p in Section 2.3.

But section 2.2 talks about the style loss and section 2.3 talks about the content loss. Is this right? If I have understood correctly, αl_C is the content loss and l_m + l_r + (1/α)l_p is the style loss. Pardon me if I have misunderstood. Thanks in advance.

content weight question

hello and thanks a lot for the code.
I have a little confusion about content weight.
you multiply the user input by 16 and divide by 2 at the end of each iteration
so for max_scl = 5 and a user input of content weight = 1
scl = 1 : content weight = 16
scl = 2 : content weight = 8
scl = 3 : content weight = 4
scl = 4 : content weight = 2

shouldn't it be multiplied by 32 (or 2^max_scl)?
not a big deal but just wanted to be sure.
thanks
luc
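The schedule described in this issue can be written out as a short sketch. It only reproduces the behaviour as the issue describes it (start at the user weight times a multiplier, halve after every scale); the function name and the `multiplier` parameter are hypothetical and have not been checked against the repository.

```python
def content_weight_schedule(user_weight, n_scales, multiplier=16.0):
    """Content weight used at each scale, as described in this issue.

    Assumption: the weight starts at user_weight * multiplier and is
    halved once after every scale.
    """
    weights = []
    weight = user_weight * multiplier
    for _ in range(n_scales):
        weights.append(weight)
        weight /= 2.0
    return weights


# The four scales listed in the issue, with a user weight of 1.
print(content_weight_schedule(1.0, 4))                   # [16.0, 8.0, 4.0, 2.0]
# The alternative starting point of 32 that the issue asks about.
print(content_weight_schedule(1.0, 4, multiplier=32.0))  # [32.0, 16.0, 8.0, 4.0]
```

With a multiplier of 16 the last of the four listed scales runs at twice the user's weight; with 32 it would run at four times the user's weight.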

get black result picture when using 3-channel gray picture as style image

When using a 3-channel gray picture as the style image, the result is always black.
It seems OK after modifying the following code.

  def pairwise_distances_cos(x, y):
    x_norm = torch.sqrt((x**2).sum(1).view(-1, 1))
    y_t = torch.transpose(y, 0, 1)
    y_norm = torch.sqrt((y**2).sum(1).view(1, -1))
    #dist = 1.-torch.mm(x, y_t)/(x_norm+1e-10)/(y_norm+1e-10)
    dist = 1.-torch.mm(x, y_t)/x_norm/y_norm
    return dist

Could not find a format to read the specified file in mode 'i'

Hi Nick!

When I try to run your code I get this error and I can't solve it, so I'm sorry to disturb you.

Traceback (most recent call last):
File "styleTransfer.py", line 99, in <module>
regions = [[imread(content_path)[:,:,0]*0.+1.], [imread(style_path)[:,:,0]*0.+1.]]
File "/home/lyx/anaconda3/lib/python3.6/site-packages/imageio/core/functions.py", line 221, in imread
reader = read(uri, format, "i", **kwargs)
File "/home/lyx/anaconda3/lib/python3.6/site-packages/imageio/core/functions.py", line 139, in get_reader
"Could not find a format to read the specified file " "in mode %r" % mode
ValueError: Could not find a format to read the specified file in mode 'i'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "styleTransfer.py", line 101, in <module>
regions = [[imread(content_path)[:,:]*0.+1.], [imread(style_path)[:,:]*0.+1.]]
File "/home/lyx/anaconda3/lib/python3.6/site-packages/imageio/core/functions.py", line 221, in imread
reader = read(uri, format, "i", **kwargs)
File "/home/lyx/anaconda3/lib/python3.6/site-packages/imageio/core/functions.py", line 139, in get_reader
"Could not find a format to read the specified file " "in mode %r" % mode
ValueError: Could not find a format to read the specified file in mode 'i'

My content dataset comes from COCO 2014, and my style dataset comes from WikiArt.
I tried Unconstrained Style Transfer, passing PATH_TO_CONTENT, PATH_TO_STYLE and CONTENT_WEIGHT.
Looking forward to your reply.
Thanks in advance!
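One thing that may be worth ruling out, purely as a guess since the thread does not say what the paths looked like, is that a path points at a dataset folder or at a file without an image extension. The sketch below is a small pre-check to run before the script calls imread; the function name and the suffix list are assumptions for illustration and do not reflect imageio's actual format support.

```python
from pathlib import Path

# Assumption: a small illustrative set, not imageio's real format list.
LIKELY_IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".bmp"}


def describe_image_path(path):
    """Return a short diagnosis for a path before handing it to imread."""
    p = Path(path)
    if not p.exists():
        return "missing"
    if p.is_dir():
        return "directory"
    if p.suffix.lower() not in LIKELY_IMAGE_SUFFIXES:
        return "unrecognized suffix"
    return "ok"
```

Calling `describe_image_path` on both PATH_TO_CONTENT and PATH_TO_STYLE should return "ok" for each if both are individual image files.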

Should create a pull request?

Hi Nick!

When I try to run the code on a non-CUDA device, for example a MacBook Pro, I keep getting this error:

File "/Users/adarsh/anaconda/envs/py36/lib/python3.6/site-packages/torch/cuda/__init__.py", line 75, in _check_driver
raise AssertionError("Torch not compiled with CUDA enabled")
AssertionError: Torch not compiled with CUDA enabled

To fix this, edit line 88 in STROTSS/utils.py from
return tensor.cuda() to return tensor
and also change line 15 of stylize_objectives.py from self.z_dist = torch.zeros(1).cuda() to self.z_dist = torch.zeros(1)

After that it works fine again, although a GPU would definitely speed up the process.
I thought I should include this for anyone who might be facing the same problem. Thanks in advance!

What Python is this?

Thanks for the code, but I think it's broken now.

from .st_helper import *
gives error
ModuleNotFoundError: No module named 'main.st_helper'; 'main' is not a package

Never seen imports like this. Is this some new Python feature?

Once I removed all the leading periods, I still ended up with a "list index out of range" error on max scale in the inference script.

output path and file name

My apologies, but I seem to have gone illiterate.
How does a person take an image from one folder, put it through STROTSS, and have it reach the other side in a new folder, with the original file name?
I've tried adjusting output_path through command line and in the code itself, but I'm just not getting my head around it.
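One workaround that avoids touching output_path at all is to move the result after each run. The sketch below relies only on the README's statement that the result is written as 'output.png' next to styleTransfer.py; the helper name `collect_output` and its arguments are hypothetical.

```python
import shutil
from pathlib import Path


def collect_output(repo_dir, content_path, out_dir):
    """Move repo_dir/output.png to out_dir/<content file stem>.png.

    Assumes a finished run has left 'output.png' in `repo_dir`.
    Returns the path of the moved file.
    """
    produced = Path(repo_dir) / "output.png"
    if not produced.is_file():
        raise FileNotFoundError(f"expected {produced} to exist after a run")
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    # The result is a PNG, so keep the original name but use a .png suffix.
    target = out_dir / (Path(content_path).stem + ".png")
    shutil.move(str(produced), str(target))
    return target
```

Calling `collect_output` once after every `python3 styleTransfer.py ...` run, for each content image in the source folder, leaves one file per input in the new folder under the original name.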
