
Comments (5)

tscohen commented on August 17, 2024

NumPy has the ndarray class, which has some great features like reshaping and broadcasting. The idea is to have an array with the same features that stores group elements or functions (feature maps) instead of numbers. Using these abstractions, we can implement the GConv in a way that exactly mirrors the mathematics in the paper. Specifically, if you look at make_gconv_indices.py, you can see that for each group we just have to create a GFuncArray for the filter (we think of the filter as a function on G) and then apply each element of the group to it to create the transformed filters, all in a single expression. E.g. for C4 we do x = np.random.randn(4, ksize, ksize); f = P4FuncArray(v=x); i = f.left_translation_indices(C4[:, None, None, None]).

In hindsight, however, I think it would have been clearer to just create the index arrays directly without using any abstractions. If you feel like it, you could look at the output of the functions in make_gconv_indices.py (e.g. make_c4_z2_indices()) and write a simple loop that outputs the same indexing arrays. In the end, all the indexing arrays do is rotate the kernels and permute the orientation channels (if any). The only thing that happens during inference is basically conv2d(feature_maps, weights[precomputed_indices]). A simple, minimal implementation would be a great learning resource for people new to group convolutions. If you keep the code super simple and test it well, you could even try to get it merged into PyTorch or TensorFlow.
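The "simple loop" idea could look roughly like the sketch below. It is not the GrouPy code: make_c4_z2_indices_simple and expand_filters are hypothetical names, and the index layout and rotation direction (np.rot90, counter-clockwise) are assumptions that may differ from what make_c4_z2_indices() actually returns.

```python
import numpy as np

def make_c4_z2_indices_simple(ksize):
    """Hypothetical loop-based index builder (not the GrouPy function).

    Returns an int array of shape (4, ksize, ksize). Entry [r, u, v] is the
    flat position in the original kernel that ends up at (u, v) after r
    counter-clockwise quarter turns.
    """
    base = np.arange(ksize * ksize).reshape(ksize, ksize)
    idx = np.empty((4, ksize, ksize), dtype=np.int64)
    for r in range(4):
        idx[r] = np.rot90(base, k=r)
    return idx

def expand_filters(weights, idx):
    """Gather the rotated copies: (out, in, k, k) -> (out * 4, in, k, k)."""
    out_c, in_c, k, _ = weights.shape
    flat = weights.reshape(out_c, in_c, k * k)
    rotated = flat[:, :, idx]                    # (out, in, 4, k, k)
    rotated = rotated.transpose(0, 2, 1, 3, 4)   # (out, 4, in, k, k)
    return rotated.reshape(out_c * 4, in_c, k, k)

rng = np.random.default_rng(0)
weights = rng.standard_normal((2, 3, 3, 3))      # 2 output, 3 input channels
indices = make_c4_z2_indices_simple(3)           # precomputed once
expanded = expand_filters(weights, indices)      # weights[precomputed_indices]
print(expanded.shape)  # (8, 3, 3, 3)
```

The expanded array can then be passed to any ordinary conv2d as its weight tensor; the four rotated copies of each filter become four consecutive output channels.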

from groupy.

shuida commented on August 17, 2024

Got it. Thank you so much!


mvoelk commented on August 17, 2024

I wrote a super simple Keras layer.
https://github.com/mvoelk/keras_layers


tscohen commented on August 17, 2024

Hi Markus, this looks great, thanks. I've been thinking for a long time that it would be nice if someone did a less abstract implementation...

Just a minor comment about the statement "The computation performed in the layer is still slower compared to standard convolution." -- This is true if one uses more channels (they get multiplied by 4 or 8), but if one uses the same number of channels then the difference is negligible. The only overhead is filter expansion. At test time, one can convert the G-CNN to a CNN by doing the filter expansion once, and then there is no overhead. Would be nice if you could mention this in the docs.
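A minimal sketch of the "expand once at test time" idea, for a layer that maps P4 feature maps to P4 feature maps. Everything here is an assumption made for illustration: the class P4ConvP4Sketch, its freeze() method, the helper names, the (out, in, 4, k, k) weight layout, and the convention that a rotation by r quarter turns rotates each kernel and cyclically shifts the orientation channels by r. It uses a naive NumPy correlation in place of a framework conv2d.

```python
import numpy as np

def correlate2d_valid(x, w):
    """Naive 'valid' cross-correlation. x: (C, H, W), w: (O, C, k, k)."""
    o, c, k, _ = w.shape
    h_out, w_out = x.shape[1] - k + 1, x.shape[2] - k + 1
    y = np.zeros((o, h_out, w_out))
    for u in range(k):
        for v in range(k):
            patch = x[:, u:u + h_out, v:v + w_out]
            y += np.einsum('oc,chw->ohw', w[:, :, u, v], patch)
    return y

def expand_p4_filters(w):
    """(out, in, 4, k, k) -> (out * 4, in * 4, k, k).

    For output rotation r, input orientation s reads the stored orientation
    (s - r) mod 4, rotated spatially by r quarter turns.
    """
    out_c, in_c, _, k, _ = w.shape
    tw = np.empty((out_c, 4, in_c, 4, k, k))
    for r in range(4):
        for s in range(4):
            tw[:, r, :, s] = np.rot90(w[:, :, (s - r) % 4], k=r, axes=(2, 3))
    return tw.reshape(out_c * 4, in_c * 4, k, k)

class P4ConvP4Sketch:
    """Hypothetical layer: expands filters on every call until freeze()."""

    def __init__(self, w):
        self.w = w
        self._frozen = None

    def freeze(self):
        # Do the filter expansion once; afterwards this is a plain convolution.
        self._frozen = expand_p4_filters(self.w)

    def __call__(self, x):
        w_full = self._frozen if self._frozen is not None else expand_p4_filters(self.w)
        return correlate2d_valid(x, w_full)

rng = np.random.default_rng(0)
layer = P4ConvP4Sketch(rng.standard_normal((2, 3, 4, 3, 3)))
x = rng.standard_normal((3 * 4, 9, 9))   # 3 input channels x 4 orientations
y_train = layer(x)                       # expands filters inside the call
layer.freeze()
y_test = layer(x)                        # reuses the cached expansion
print(np.allclose(y_train, y_test))
```

After freeze(), the cached array is just the weight of a standard convolution with 12 input and 8 output channels, so it could be exported to a regular CNN layer.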

Another thing I've found to be useful is to add unit tests that check for equivariance.
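Such an equivariance test might look like the sketch below, written against a toy first layer (image in, P4 feature map out) rather than any real library. The helper names and the naive NumPy correlation are assumptions for illustration; the property checked is that rotating the input by a quarter turn rotates every output map and shifts the orientation axis by one.

```python
import numpy as np

def correlate2d_valid(x, w):
    """Naive 'valid' cross-correlation. x: (C, H, W), w: (O, C, k, k)."""
    o, c, k, _ = w.shape
    h_out, w_out = x.shape[1] - k + 1, x.shape[2] - k + 1
    y = np.zeros((o, h_out, w_out))
    for u in range(k):
        for v in range(k):
            patch = x[:, u:u + h_out, v:v + w_out]
            y += np.einsum('oc,chw->ohw', w[:, :, u, v], patch)
    return y

def expand_z2_filters(w):
    """(out, in, k, k) -> (out * 4, in, k, k): four rotated copies per filter."""
    rotated = np.stack(
        [np.rot90(w, k=r, axes=(2, 3)) for r in range(4)], axis=1
    )                                             # (out, 4, in, k, k)
    return rotated.reshape(w.shape[0] * 4, *w.shape[1:])

def p4_conv_z2(x, w):
    """Toy first-layer group convolution; returns (out, 4, H', W')."""
    y = correlate2d_valid(x, expand_z2_filters(w))
    return y.reshape(w.shape[0], 4, *y.shape[1:])

def test_rotation_equivariance():
    rng = np.random.default_rng(0)
    x = rng.standard_normal((3, 9, 9))            # square input, 3 channels
    w = rng.standard_normal((2, 3, 3, 3))
    y = p4_conv_z2(x, w)
    y_rot = p4_conv_z2(np.rot90(x, k=1, axes=(1, 2)), w)
    # Expected: each map rotated, orientation channels cyclically shifted.
    expected = np.roll(np.rot90(y, k=1, axes=(2, 3)), shift=1, axis=1)
    np.testing.assert_allclose(y_rot, expected, atol=1e-10)

test_rotation_equivariance()
```

The test uses a square input and 'valid' padding so that array rotation lines up exactly; with other padding or strides the comparison may need cropping.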


mvoelk commented on August 17, 2024

I ran experiments on CIFAR-10 with Conv2D (96 features) and GroupConv2D (12 * 8 features). Conv2D takes 22 ms/step, GroupConv2D 34 ms/step. I assume it is a matter of implementation. I do more reshapes on the kernel and also reshapes on the features, but that should not be so expensive... In this case, I do not think having fewer parameters makes it any faster. What are your experiences?

> channels (they get multiplied by 4 or 8)

I should mention that somewhere.

> At test time, one can convert the G-CNN to a CNN by doing the filter expansion once, and then there is no overhead.

Good idea, I hadn't thought of that yet.

The example in examples.ipynb checks for equivariance.

