
Comments (7)

Gabri95 commented on May 28, 2024

Hi @kristian-georgiev ,

I had a quick look at the new code in the gist.
I don't see any obvious mistake, to be honest.

I don't know if you are still using the same equivariance check you had in the first gist.
Back then you were using 191x191 inputs.
I guess there has been some change in the architecture, so now you need to use 193x193 inputs to get 0 equivariance error (maybe you changed the first conv layer?).
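
For reference, a generic equivariance check of this kind looks roughly like the sketch below (a minimal example with a placeholder single-layer model; swap in your own architecture, field types and input size):

```python
import torch
from e2cnn import gspaces, nn

# placeholder model: replace with your own architecture and field types
gspace = gspaces.Rot2dOnR2(N=4)
in_type = nn.FieldType(gspace, 3 * [gspace.trivial_repr])
out_type = nn.FieldType(gspace, 8 * [gspace.regular_repr])
model = nn.R2Conv(in_type, out_type, kernel_size=5)

# the input size (e.g. 193x193) is whatever your architecture needs for exact equivariance
x = nn.GeometricTensor(torch.randn(1, 3, 193, 193), in_type)
with torch.no_grad():
    for g in gspace.testing_elements:
        out_of_transformed = model(x.transform(g)).tensor
        transformed_out = model(x).transform(g).tensor
        err = (out_of_transformed - transformed_out).abs().max().item()
        print(f"group element {g}: max equivariance error {err:.2e}")
```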

Do you still have the same issue with the NaN loss?
I did not try to train this model, to be honest, but I can have a better look if you have a specific problem.

One note: I see you tried to make your code very general, supporting both SO(2) (with different frequencies) and C_N (for different values of N).
I'd recommend implementing two different models, since the kinds of operations you use for SO(2) and C_N are generally different. Implementing the two models separately allows for much simpler, more readable code.
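
For example (illustrative only; the maximum frequency and the order of the cyclic group are arbitrary choices here), the two models would simply start from two different gspaces and never need to share the frequency/order logic:

```python
from e2cnn import gspaces

# SO(2)-equivariant model: choose a maximum frequency for the irreps
so2_space = gspaces.Rot2dOnR2(N=-1, maximum_frequency=3)

# C_N-equivariant model: choose the order N of the cyclic group
c8_space = gspaces.Rot2dOnR2(N=8)
```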

Let me know if you have more questions

Best,
Gabriele


Gabri95 commented on May 28, 2024

Hi @kristian-georgiev

However, the current architecture seems to be unstable (loss consistently becomes NaN around the middle of the first epoch on ImageNet, for both max_frequency=-3 and -5); any pointers on what may be causing this are greatly appreciated.

To answer this, I will need to look a bit more in detail at your architecture and probably try to run it myself.
Unfortunately, I don't have time for it at the moment, so, if you don't mind, I'll come back to you about this later.

I will answer the other (very good, btw) questions in the meantime:

I don't understand what is happening in these lines in the gated_normpool layers.

There, we use gated non-linearities only on the non-trivial channels (see line 1247).
For trivial channels, we still use ELU.
Because the gate for a gated non-linearity is a trivial field, we also need to add an additional trivial field for each non-trivial irrep in the input; that is why you see two sets of trivial representations there (the ELU features and the gates).
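
Concretely, the construction looks roughly like the sketch below (a rough, hand-written example with arbitrary numbers of fields, not the code from the gist):

```python
from e2cnn import gspaces, nn

gspace = gspaces.Rot2dOnR2(N=8)
triv = gspace.trivial_repr
irr1 = gspace.irrep(1)                                 # a non-trivial irrep

n_triv, n_gated = 4, 4
elu_type = nn.FieldType(gspace, n_triv * [triv])       # trivial features, passed through ELU
gated_type = nn.FieldType(gspace, n_gated * [irr1])    # non-trivial features, to be gated
gates_type = nn.FieldType(gspace, n_gated * [triv])    # one extra trivial field (gate) per non-trivial field

# the convolution feeding this block must produce all three sets of fields
conv_out_type = elu_type + gated_type + gates_type

elu = nn.ELU(elu_type)                                 # acts only on the trivial features
gated = nn.GatedNonLinearity1(
    gated_type + gates_type,
    gates=n_gated * ['gated'] + n_gated * ['gate'],    # labels distinguish the gated fields from their gates
)
# in the actual network, the two modules are applied to their own slices of the
# convolution output (e.g. through a MultipleModule) and the gates are discarded
```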

what are S and M? It seems that they go unused(?)

You're right, they are there just for documentation. In the comment at line 1290, I use them to describe the total channels.
I is the number of fields which require a gate and, therefore, is also the number of gates (which are trivial fields) to be added.
S is the size of the features, i.e. the total size of all the irreps which require a gate, plus 1 for the trivial field using ELU.
M is the total size: features + #gates, i.e. S + I
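
In plain Python, the bookkeeping amounts to something like this (a hypothetical helper that just restates the comment; it assumes the standard size attribute and is_trivial() method of e2cnn representations):

```python
def count_block_channels(irreps):
    # irreps: the representations of the feature fields in one block
    # I: number of non-trivial fields, each of which needs one (trivial) gate
    I = sum(1 for r in irreps if not r.is_trivial())
    # S: size of the features (sizes of the gated irreps, plus 1 for each trivial ELU field)
    S = sum(r.size for r in irreps)
    # M: total output size of the convolution, features + gates
    M = S + I
    return I, S, M
```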

Why do we not have the same setup in the hnet_normpool layers (here)?

That is a slightly different architecture.
We still apply ELU on each trivial irrep and an independent non-linearity to each non-trivial irrep.
Here, however, the independent non-linearity is a norm-relu, not a gated one.
While norm-relu is computed directly on the input irrep, the gated non-linearity requires an additional gate.
The code for the network using the gated non-linearity is a bit more complex, since I need to account for the additional parameters introduced in the model by adding these additional outputs (the gates) in each convolution block.
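
As an illustration (a minimal sketch with arbitrary sizes, using e2cnn's NormNonLinearity), the norm-relu version needs no extra gate fields, so the convolution output type is just the feature type itself:

```python
from e2cnn import gspaces, nn

gspace = gspaces.Rot2dOnR2(N=-1, maximum_frequency=3)      # SO(2), as an example

triv_type = nn.FieldType(gspace, 4 * [gspace.trivial_repr])
irrep_type = nn.FieldType(gspace, 4 * [gspace.irrep(1)] + 4 * [gspace.irrep(2)])

elu = nn.ELU(triv_type)                                     # pointwise ELU on the trivial fields
norm_relu = nn.NormNonLinearity(irrep_type, function='n_relu', bias=True)
```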

Where does the expression t /= 16 * s ** 2 * 3 / 4 (from here) come from ... ?

This is just a simple heuristic I created manually to ensure the model has roughly the same number of parameters as the C_16 regular GCNN. It should work for different frequencies of the HNet, but only if you compare to the specific C_16 architecture I used in those experiments. I would not trust that formula in another setting.

Is there any benefit/harm in using max_frequency higher than the highest irrep used? E.g. have max_frequency=10 but use irreps of frequency at most 3?

The answer is "it depends" 😅

  • if you use SO(2) (or O(2)) equivariance, the filters mapping an irrep of freq m to an irrep of freq n can have both frequency n+m and n-m. That means that, even if your irreps have max frequency 3, you could still use some filter of frequency 6 (e.g. mapping from 3 to 3).
  • if you use a discrete group, e.g. C_N, and map from irrep of freq n to irrep of freq m, you have filters of freq m+n and m-n as before, but also any filter with frequency n+m +tN or n-m+tN for any integer t.
In the first case, having max_frequency=10 but only irreps <= 3 is the same as max_frequency=6.
In the second case, higher values of max_frequency almost always imply more high-frequency filters in the basis.
In both cases, however, it is hard to tell a priori whether this is fine or not. I'd recommend using the frequencies_cutoff parameter to control the frequencies of the filters (see the sketch below). Indeed, even in the first case, you may not want to use the frequency-6 filter if you only have 3x3 filters.
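
To make the two cases concrete, here is a small illustration (pure Python arithmetic plus e2cnn's R2Conv with its frequencies_cutoff argument; layer sizes and kernel size are arbitrary):

```python
from e2cnn import gspaces, nn

# filter frequencies appearing in the basis when mapping an irrep of frequency n
# to an irrep of frequency m
def so2_filter_freqs(n, m):
    return {abs(n + m), abs(n - m)}

def cn_filter_freqs(n, m, N, t_range=range(-2, 3)):
    return {abs(f + t * N) for f in (n + m, n - m) for t in t_range}

print(so2_filter_freqs(3, 3))         # {0, 6}: irreps <= 3 can still require a frequency-6 filter
print(cn_filter_freqs(3, 3, N=16))    # aliased copies repeat every N frequencies

# limiting the filter frequencies explicitly, independently of max_frequency
gspace = gspaces.Rot2dOnR2(N=-1, maximum_frequency=10)
in_type = nn.FieldType(gspace, 8 * [gspace.irrep(3)])
out_type = nn.FieldType(gspace, 8 * [gspace.irrep(3)])
conv = nn.R2Conv(in_type, out_type, kernel_size=3, frequencies_cutoff=3.0)
```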

Rather open-ended: Do you expect the trends from rows 29-44 of Table 5.1 to hold true for higher-resolution harder datasets than MNIST (e.g. ImageNet)?

I think this result is indeed strongly related to the low resolution of the data used.
On very high resolution data, this might change.
Note, however, that this also refers to the resolution of the field of view of a neuron.
So, even if you have very high resolution inputs, the neurons in the first layers will only process small patches.
Probably, neurons in the deepest layers of the network can benefit from higher frequencies, but I doubt the first layers can (unless you use very wide filters).

A silly question, but just to double-check: The order of representation in the specification of the field type doesn't matter, correct?

It depends. A permutation of the representations inside a FieldType will result in a fully equivalent architecture.
However, if you keep all representations of the same type next to each other in the FieldType, inference will be noticeably faster.
This is because the code can access the parts of the input tensor associated with the same representation by using slicing rather than indexing. While slicing only requires a view internally, advanced indexing requires a full copy of the part indexed (see this and this).
I just realised this was not really explicit in the documentation, that's my fault. I will update the documentation with some more comments on this.
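
For example (illustrative field types; both contain exactly the same representations, only in a different order):

```python
from e2cnn import gspaces, nn

gspace = gspaces.Rot2dOnR2(N=8)
triv, irr1 = gspace.trivial_repr, gspace.irrep(1)

interleaved = nn.FieldType(gspace, [triv, irr1, triv, irr1])   # forces advanced indexing (copies)
grouped     = nn.FieldType(gspace, 2 * [triv] + 2 * [irr1])    # contiguous fields: cheap slicing (views)
```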

Let me know if you have any other doubts

Best,
Gabriele


kristian-georgiev commented on May 28, 2024

Hi @JoaoGuibs and @ahyunSeo. I have not updated my code snippet since my last comment in this thread.


Gabri95 commented on May 28, 2024

Hi @kristian-georgiev,

I tried running your network and this problem seems related to your other issue: #36

If you try to feed inputs of shape 191x191, the equivariance error of your model becomes 0.
If you want to work on higher-resolution images, I would recommend adapting the strides and/or the number of convolution layers in the model. Otherwise, you could also try downsampling your images to 191x191 before feeding them to the model.
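
For instance, a plain PyTorch resize before building the GeometricTensor would look like this (a sketch; the interpolation mode is an arbitrary choice):

```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 3, 224, 224)          # example ImageNet-sized batch
x = F.interpolate(x, size=(191, 191), mode='bilinear', align_corners=False)
```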

Hope this helps!

Best,
Gabriele


kristian-georgiev commented on May 28, 2024

Thank you for the quick response and apologies for the delay. I agree with what you said, the lack of equivariance indeed seems to stem from pooling and dilation. In addition, it seems like I've made other poor design choices since I was not able to train the networks to a reasonable accuracy on ImageNet. I've since taken a closer look at your experiments (Table 5.1 and e2cnn_experiments/experiments/models/exp_e2sfcnn.py in particular) and have made some corrections to my architecture (gist is updated). However, the current architecture seems to be unstable (loss consistently becomes NaN around the middle of the first epoch on ImageNet, for both max_frequency=-3 and -5); any pointers on what may be causing this are greatly appreciated. And potentially related, I have a couple of high-level questions:

  • I don't understand what is happening in these lines in the gated_normpool layers. Why do we have two sets of trivial reps? Why do we not have the same setup in the hnet_normpool layers (here)? In the same function, what are S and M? It seems that they go unused(?)
  • Where does the expression t /= 16 * s ** 2 * 3 / 4 (from here) come from when we try to keep number of total params constant? Are any of the numbers hardcoded for a particular frequency?
  • Is there any benefit/harm in using max_frequency higher than the highest irrep used? E.g. have max_frequency=10 but use irreps of frequency at most 3?
  • Rather open-ended: Do you expect the trends from rows 29-44 of Table 5.1 to hold true for higher-resolution harder datasets than MNIST (e.g. ImageNet)? In particular, do you expect a model that uses irreps <= 3 to still outperform models that introduce higher frequencies?
  • A silly question, but just to double-check: The order of representation in the specification of the field type doesn't matter, correct?

Thanks again for your time and apologies for the long issue!


ahyunSeo commented on May 28, 2024

Hello, @kristian-georgiev

I'm sorry to leave a comment on a year-old issue.
Did you fix the issue with the NaN loss? (I also got a NaN loss using your gist code)

Best,
Ahyun


JoaoGuibs commented on May 28, 2024

Hi @kristian-georgiev, I stumbled upon this discussion and was wondering whether you had managed to train the equivariant ResNet model to a reasonable accuracy on ImageNet? Thank you in advance.

