
Module export about e2cnn (closed)

quva-lab commented on June 7, 2024
Module export

from e2cnn.

Comments (3)

Gabri95 commented on June 7, 2024

Hi @danushv07

Unfortunately, SequentialModule can only contain EquivariantModules; you cannot add a torch.nn.Linear module to it.

If I understand correctly, you want:

  • an equivariant architecture,
  • followed by a pooling layer
  • and a final linear layer (e.g. for classification)

You can achieve something similar this way:


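```python
from e2cnn import gspaces
from e2cnn.nn import FieldType, SequentialModule, R2Conv, ReLU, PointwiseMaxPool, GroupPooling

# assumed setup, not part of the original snippet: 8 rotations,
# one scalar input channel, 16 regular feature fields, 10 classes
s = gspaces.Rot2dOnR2(N=8)
in_ = FieldType(s, [s.trivial_repr])
out_ = FieldType(s, [s.regular_repr] * 16)
out_channels = 10

net = SequentialModule(
    R2Conv(in_, out_, 3, bias=False),
    ReLU(out_, inplace=True),
    PointwiseMaxPool(out_, kernel_size=2, stride=2),
    GroupPooling(out_),
)

# this is the out_type of the last GroupPooling
final_feature_type = net.out_type

# `out_channels` invariant outputs
output_type = FieldType(s, [s.trivial_repr] * out_channels)

# add the final linear layer as a 1x1 convolution
net.add_module('classifier', R2Conv(final_feature_type, output_type, kernel_size=1))
```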

The final R2Conv will be a 1x1 convolution that behaves just like your torch.nn.Linear, assuming the output of PointwiseMaxPool is a 1x1 feature map.
You should now be able to export() your model.
Note, however, that the output tensor will have shape B x out_channels x 1 x 1 rather than B x out_channels, so you may need to reshape it manually.
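A quick way to check the 1x1-convolution claim, and the reshaping step, is the following plain-PyTorch sketch. It does not use e2cnn at all and the sizes are arbitrary assumptions: it copies the weights of a torch.nn.Conv2d with kernel_size=1 into a torch.nn.Linear and checks that both produce the same result on a B x C x 1 x 1 input once the spatial dimensions are flattened.

```python
import torch

torch.manual_seed(0)

B, C_in, C_out = 4, 16, 10  # arbitrary sizes chosen for illustration

conv = torch.nn.Conv2d(C_in, C_out, kernel_size=1)
linear = torch.nn.Linear(C_in, C_out)

# Copy the 1x1 conv parameters into the linear layer:
# conv.weight has shape (C_out, C_in, 1, 1), linear.weight has shape (C_out, C_in)
with torch.no_grad():
    linear.weight.copy_(conv.weight.view(C_out, C_in))
    linear.bias.copy_(conv.bias)

x = torch.randn(B, C_in, 1, 1)  # a 1x1 feature map, like the pooled output

y_conv = conv(x)                             # shape (B, C_out, 1, 1)
y_flat = torch.flatten(y_conv, start_dim=1)  # manual reshape to (B, C_out)
y_lin = linear(x.view(B, C_in))              # shape (B, C_out)

print(y_conv.shape, y_flat.shape)
print(torch.allclose(y_flat, y_lin, atol=1e-5))
```

With e2cnn, the forward pass of the network returns a GeometricTensor, so the same torch.flatten(..., start_dim=1) call would be applied to its .tensor attribute; after export() the output is already a plain torch tensor.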

Is my understanding correct? Does this help?

Best,
Gabriele


danushv07 commented on June 7, 2024

Thank you for the prompt reply, @Gabri95. The aforementioned solution works well. However, if a torch.nn.Linear layer is required, the SequentialModule and the linear layer can both be wrapped in a torch.nn.Module, and then methods such as .modules() or .children() can be used to export the required layers.


Gabri95 commented on June 7, 2024

Hi @danushv07

I am not sure I understand exactly what you mean. Could you share a simple code snippet to illustrate your approach?

Thanks,
Gabriele
