
Comments (11)

mauriceweiler commented on May 27, 2024

Thanks for the feedback @drewm1980. Slicing is indeed something we should implement, but we have not done so yet due to time constraints.

I would not go for the solution of skipping the channel axis, since the result might be unexpected for inexperienced users. It seems better to make the channel dimension explicit while guaranteeing that the user can't break the equivariance (i.e. split within fields).
I see three options:

  1. We enforce that no splits are allowed in the channel dimension, i.e. that the slices are necessarily of the form [N, :, X, Y].
  2. Slices are of the form [N, C, X, Y] where C counts channels. The method throws an exception if C splits within a field.
  3. Slices are of the form [N, F, X, Y] where F counts fields. This would be similar to the behavior of GeometricTensor.split() and seems most logical from the viewpoint of steerable CNNs. The downside is that users might intuitively expect the behavior of option 2 and get confused (the sketch after this list contrasts the two).
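
To make the difference concrete, here is a hypothetical sketch (the tensor x and the index names n0, y0, etc. are made up for illustration); assume x wraps two regular C4 fields of 4 channels each, i.e. 8 channels in total:

x[n0:n1, :, y0:y1, x0:x1]     # option 1: only ':' is allowed on the channel axis
x[n0:n1, 0:4, y0:y1, x0:x1]   # option 2: counts channels; 0:3 would split the first field -> exception
x[n0:n1, 0:1, y0:y1, x0:x1]   # option 3: counts fields; selects the whole first field (4 channels)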


drewm1980 commented on May 27, 2024

Thanks for the improvement! I'm trying it out. I needed this again today, and it again involved cropping in the spatial dimensions to make tensors compatible for concatenation (for some skip connections in my network). Also, I agree with the decision to keep indexing numpy/pytorch compatible as much as possible.


drewm1980 commented on May 27, 2024

I trained a new "champion" network that was using this internally today for skip connections. I haven't uncovered any issues related to this yet. Thanks!


drewm1980 commented on May 27, 2024

This gets me past the error; will know later if it actually works.

    # assumes: import torch; from typing import List; from e2cnn import nn
    def center_crop(self, geometric_tensor: nn.GeometricTensor, target_size: List[int]) -> nn.GeometricTensor:

        # Unpack to a plain tensor so we can slice
        tensor: torch.Tensor = geometric_tensor.tensor
        field_type: nn.FieldType = geometric_tensor.type

        # Offsets of a centered crop of size target_size = [height, width]
        _, _, tensor_height, tensor_width = tensor.size()
        diff_y = (tensor_height - target_size[0]) // 2
        diff_x = (tensor_width - target_size[1]) // 2
        tensor_sliced = tensor[
            :, :, diff_y : (diff_y + target_size[0]), diff_x : (diff_x + target_size[1])
        ]

        # Repack into a GeometricTensor
        geometric_tensor_sliced = nn.GeometricTensor(tensor_sliced, field_type)

        return geometric_tensor_sliced

I'm still new to torch's memory model regarding slicing. Will passing a sliced tensor back into GeometricTensor's constructor cause problems?


drewm1980 commented on May 27, 2024

Update: Casting back to a torch.Tensor as above is inelegant, but it seems to work. It's no longer urgent for me, but I'll leave the ticket open in case you have a more elegant solution for indexing. Was it a deliberate decision to disable indexing so that users don't mess up the channels dimension in a way that would break equivariance? If so, maybe there's a way to allow indexing as long as the caller only slices on the spatial dimensions...


Gabri95 commented on May 27, 2024

Hi @drewm1980,

sorry for my late reply.

Yeah, GeometricTensor does not support indexing. A simple solution is unwrapping the underlying torch.Tensor, cropping it, and wrapping it again in a GeometricTensor. I usually do the same.
When instantiated, a GeometricTensor doesn't do much more than store a reference to the torch.Tensor and a reference to the FieldType, so I don't think you need to worry about memory management.
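
As a small sketch (hypothetical names and shapes, along the lines of the examples below): basic slicing of a torch.Tensor returns a view of the same storage, and GeometricTensor only keeps a reference to the tensor it is given, so no copy is made:

space = Rot2dOnR2(4)
feat_type = FieldType(space, [space.regular_repr])
x = GeometricTensor(torch.randn(10, feat_type.size, 7, 7), feat_type)

# the crop below is a view of x.tensor; wrapping it in a new GeometricTensor does not copy the data
cropped = GeometricTensor(x.tensor[:, :, 2:5, 2:5], feat_type)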

Was it a deliberate decision to disable indexing so that users don't mess up the channels dimension in a way that would break equivariance? If so, maybe there's a way to allow indexing as long as the caller only slices on the spatial dimensions...

Yes, channels cannot be split freely, as that could break equivariance. If you only need to slice in the channel dimension, you can use GeometricTensor.split().

For the spatial dimensions, it was not a common use case for me, so I did not implement any additional interface. I agree it deserves a better solution, though.
I think I can override the brackets operator (the one usually used for indexing) for GeometricTensor. I will try to implement it this way and let you know if it works.

Best,
Gabriele


Gabri95 commented on May 27, 2024

Supporting full slicing of the underlying tensor is dangerous as it can split channels which belong to the same field. The method GeometricTensor.split() can be used for this purpose.

I can implement slicing along the batch and the spatial dimensions by adding the following method to the GeometricTensor class:

def __getitem__(self, slices):
    # Slicing is not supported on the channel dimension, so at most
    # (ndim - 1) indices are accepted: batch and spatial dimensions only.
    if isinstance(slices, tuple):
        if len(slices) > len(self.tensor.shape) - 1:
            raise TypeError("too many indices: the channel dimension can not be sliced")
    else:
        slices = (slices,)

    # Equivalent to using [:] on the channel dimension
    idxs = (slices[0], slice(None, None, None), *slices[1:])
    sliced_tensor = self.tensor[idxs]
    return GeometricTensor(sliced_tensor, self.type)

This would allow slicing as you would usually do it in PyTorch or NumPy, but it would skip the channel dimension when multiple indices are passed.
This is an example:

space = Rot2dOnR2(4)
type = FieldType(space, [space.regular_repr])
geom_tensor = GeometricTensor(torch.randn(10, type.size, 7, 7), type)

geom_tensor.shape
>> torch.Size([10, 4, 7, 7])

geom_tensor[1:3, 2:5, 2:5].shape
>> torch.Size([2, 4, 3, 3])

Here, I've passed 3 indices which are then used for the first (batch), third and fourth (spatial) dimensions, skipping the second (channels) one.
Do you think this is a valid solution or would this behaviour be confusing?

I appreciate any feedback and if anyone has a better suggestion, feel free to write it here!

Thanks!
Gabriele


Gabri95 commented on May 27, 2024

I personally really like option 3 as it seems a very clean solution and could also replace GeometricTensor.split().
To be more precise, it would produce this behaviour:

space = Rot2dOnR2(4)
type = FieldType(space, [space.regular_repr]*2 + [space.irrep(1)]*3)
geom_tensor = GeometricTensor(torch.randn(10, type.size, 7, 7), type)

geom_tensor.shape
>> torch.Size([10, 14, 7, 7])

geom_tensor[1:3, :, 2:5, 2:5].shape
>> torch.Size([2, 14, 3, 3])

# the first 2 fields are regular fields of size 4. In total, they contain 2*4 = 8 channels
geom_tensor[:, :2, :, :].shape
>> torch.Size([10, 8, 7, 7])

# the last 2 fields are vector fields of size 2. In total, they contain 2*2 = 4 channels
geom_tensor[:, -2:, :, :].shape
>> torch.Size([10, 4, 7, 7])

# the first 3 fields are 2 regular fields and 1 vector field. In total, they contain 2*4 + 2 = 10 channels
geom_tensor[:, :3, :, :].shape
>> torch.Size([10, 10, 7, 7])

This could also simplify the use of MultipleModule.

Option 1 would probably be the most user-friendly for new users.


mauriceweiler commented on May 27, 2024

Alright, let's go for option 3 then. It includes option 1, and if users really want to slice within fields, they can do it with the workaround proposed by @drewm1980.


Gabri95 commented on May 27, 2024

We now support slicing on all axes and simple indexing (i.e. with a single index per dimension).
Slicing in the second dimension is done over fields instead of channels.
We do not support advanced indexing, though.

You can find some examples and additional details here.

I hope this can be helpful!

Best,
Gabriele


Gabri95 commented on May 27, 2024

Hi @drewm1980

good to hear!

Thanks for the feedback! We really appreciate it!

Please, let us know if you encounter any issues!

