
Support for ConvTransposeNd (einconv, 8 comments, open)

f-dangel commented on June 12, 2024
Support for ConvTransposeNd

from einconv.

Comments (8)

f-dangel commented on June 12, 2024

Quick update: I pulled together the einconv.expressions.conv_transposeNd_forward.einsum_expression on the conv-transpose-expression branch. For now it is tested for N=1 and I will add cases for N=2,3 next, then merge into development.

With that, it should be relatively straightforward to make the following contributions (1. and 2. can be separate, but for each, the testing part of 3. would be mandatory 😉):

  1. Add einconv.functionals.conv_transpose.py that contains an equivalent of PyTorch's nn.functional.conv_transpose{1,2,3}d. The code should be very similar to the convolution case in einconv.functionals.conv.py.
  2. Add einconv.modules.conv_transpose.py that contains an equivalent of PyTorch's nn.ConvTranspose{1,2,3}d. The code should be very similar to the convolution case in einconv.modules.conv.py. Also, one needs to use torch.nn.modules.conv._ConvTransposeNd._output_padding to compute the output_padding parameter if the user specifies output_size in forward. See here for an example.
  3. Add to API, add to documentation, add tests, then merge 🚀
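Regarding point 2, the `output_size` mechanism can be observed through the public API alone: PyTorch resolves `output_size` to an `output_padding` internally. A sketch with made-up dimensions (not einconv code):

```python
import torch

# Sketch with made-up dimensions: ConvTranspose2d.forward accepts an
# `output_size`, which PyTorch resolves to an `output_padding` internally
# (via _ConvTransposeNd._output_padding).
ct = torch.nn.ConvTranspose2d(8, 4, kernel_size=3, stride=2)
h = torch.randn(1, 8, 5, 5)

# Default spatial size: (5 - 1) * stride + kernel_size = 11
y_default = ct(h)
assert tuple(y_default.shape) == (1, 4, 11, 11)

# Any size in [11, 11 + stride - 1] is reachable via output_padding:
y_padded = ct(h, output_size=(12, 12))
assert tuple(y_padded.shape) == (1, 4, 12, 12)
```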


f-dangel commented on June 12, 2024

Hi,
yes, that should indeed be possible. Can you clarify what functionality of transpose convolution you need? I'd be happy to offer advice on how to implement it; maybe you can even submit a PR.

(I haven't had time to ship many additional operations, but they are all pretty similar. Feel free to check out Table B.2 in the paper for an extended overview)

Best,
Felix


JCBrouwer commented on June 12, 2024

My network has an initial step which extracts patches from the input (like a ViT) which I then need to reverse at the end of the network to get back the same original shape:

import torch

x = torch.randn(16, 3, 32, 32)

conv_in = torch.nn.Conv2d(3, 64, kernel_size=8, stride=8)
conv_out = torch.nn.ConvTranspose2d(64, 3, kernel_size=8, stride=8)

print(x.size())  # torch.Size([16, 3, 32, 32])

h = conv_in(x)

print(h.size())  # torch.Size([16, 64, 4, 4])
# generally there's a bunch more layers in the middle here e.g. a Transformer

y = conv_out(h)

print(y.size())  # torch.Size([16, 3, 32, 32])

The way I understand it, a ConvTranspose op is essentially just the transposed version of the lowered matrix form of the convolution. This makes me think that in an einsum the ConvTranspose would just amount to reversing the order of the input labels, although I'm not sure if my intuition maps that nicely to reality.
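For the stride == kernel_size case in the snippet above, that intuition can be checked numerically with `unfold` + `einsum` (a sketch, not einconv's actual expression):

```python
import torch

# Sketch of the einsum view of Conv2d for the patch-extraction case
# (stride == kernel_size, no padding); not einconv's implementation.
k = 8
x = torch.randn(16, 3, 32, 32)
conv_in = torch.nn.Conv2d(3, 64, kernel_size=k, stride=k)

# Cut the input into non-overlapping k x k patches ...
patches = x.unfold(2, k, k).unfold(3, k, k)  # (16, 3, 4, 4, 8, 8)
# ... and contract the channel + intra-patch dimensions with the kernel.
h = torch.einsum("ocpq,bchwpq->bohw", conv_in.weight, patches)
h = h + conv_in.bias[None, :, None, None]

assert torch.allclose(h, conv_in(x), atol=1e-5)
```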

Is it clear enough what I'm looking to do?


f-dangel commented on June 12, 2024

Hi,

I think your intuition is correct: you can simply exchange the roles of the spatial input and output dimensions in the corresponding einsum of the convolution.
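Concretely, for the stride == kernel_size case from your snippet, exchanging the roles looks like this (a sketch, not einconv's actual expression; note that ConvTranspose2d stores its weight as (in_channels, out_channels, *kernel)):

```python
import torch

k = 8
h = torch.randn(16, 64, 4, 4)
conv_out = torch.nn.ConvTranspose2d(64, 3, kernel_size=k, stride=k)

# weight: (in_channels=64, out_channels=3, k, k). Each input pixel is
# scattered to a k x k output patch; interleaving the coarse spatial
# indices (h, w) with the intra-patch indices (p, q) restores the grid.
y = torch.einsum("ocpq,bohw->bchpwq", conv_out.weight, h)
y = y.reshape(16, 3, 4 * k, 4 * k) + conv_out.bias[None, :, None, None]

assert torch.allclose(y, conv_out(h), atol=1e-5)
```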

May I ask why you cannot rely on the built-in transpose convolution?


JCBrouwer commented on June 12, 2024

I'd like my model to be Nd!

Right now I have to maintain three different copies of it for different data dimensionalities, but by leveraging an Nd convolution I'd only need one.


f-dangel commented on June 12, 2024

Gotcha,

would you be interested in adding the equivalents of torch.nn.functional.conv_transpose{1,2,3}d and torch.nn.ConvTransposeNd to the library? I can provide the einsum expression.

I should also mention that the einsum variants of ConvNd and ConvTransposeNd are (unsurprisingly) usually slower than the highly optimized cuDNN implementations. For your specific case, where stride == kernel_size, the performance gap should be relatively small, but I want to be transparent about this downside up front.


JCBrouwer commented on June 12, 2024

Yes, I think if you could help with the einsum expression, then I can probably figure out how to add the necessary functions to the library.

Reduced performance is definitely a downside. I'll have to benchmark it and see how big the gap is in practice. I think a minor performance degradation would still be worth it for a fully dimensionality-agnostic architecture.
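A quick benchmark along those lines could look like this (illustrative sketch only; timings depend heavily on hardware, backend, and whether cuDNN is in play):

```python
import time
import torch

# Compare PyTorch's built-in Conv2d against an unfold + einsum formulation
# for the stride == kernel_size case (made-up sizes; CPU timing sketch).
k = 8
x = torch.randn(16, 3, 32, 32)
conv = torch.nn.Conv2d(3, 64, kernel_size=k, stride=k)

def via_einsum(inp):
    patches = inp.unfold(2, k, k).unfold(3, k, k)  # (B, C, H/k, W/k, k, k)
    out = torch.einsum("ocpq,bchwpq->bohw", conv.weight, patches)
    return out + conv.bias[None, :, None, None]

# Same result, possibly different speed:
assert torch.allclose(conv(x), via_einsum(x), atol=1e-5)

for name, fn in [("built-in", conv), ("einsum", via_einsum)]:
    start = time.perf_counter()
    for _ in range(100):
        fn(x)
    print(f"{name}: {time.perf_counter() - start:.4f} s")
```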


f-dangel commented on June 12, 2024

If performance is crucial, I suggest using PyTorch's built-in convolution and something like

from torch.nn import (
    Conv1d, Conv2d, Conv3d,
    ConvTranspose1d, ConvTranspose2d, ConvTranspose3d,
)

N = conv_dim  # spatial dimension of your data (1, 2, or 3)
conv_cls = {1: Conv1d, 2: Conv2d, 3: Conv3d}[N]
conv_transpose_cls = {1: ConvTranspose1d, 2: ConvTranspose2d, 3: ConvTranspose3d}[N]

in your code.
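Put together, a dimension-agnostic patchify/unpatchify module along these lines might look as follows (class and parameter names are made up for illustration):

```python
import torch

# Minimal sketch of the dispatch pattern for a dimension-agnostic
# patchify/unpatchify pair (module and parameter names are made up).
class PatchCodec(torch.nn.Module):
    def __init__(self, N, in_channels, hidden, patch_size):
        super().__init__()
        conv_cls = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}[N]
        convT_cls = {
            1: torch.nn.ConvTranspose1d,
            2: torch.nn.ConvTranspose2d,
            3: torch.nn.ConvTranspose3d,
        }[N]
        self.conv_in = conv_cls(in_channels, hidden, patch_size, stride=patch_size)
        self.conv_out = convT_cls(hidden, in_channels, patch_size, stride=patch_size)

    def forward(self, x):
        # In a real model, e.g. a Transformer would sit between these two.
        return self.conv_out(self.conv_in(x))

codec = PatchCodec(N=3, in_channels=3, hidden=64, patch_size=4)
x = torch.randn(2, 3, 8, 8, 8)
assert tuple(codec(x).shape) == (2, 3, 8, 8, 8)  # shape round-trips
```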

I'll keep you updated on the einsum expression for transpose convolution 👍

