Comments (11)
Thanks for the feedback @drewm1980. Slicing is indeed something we should implement, but we have not done so yet due to time constraints.
I would not go for the solution of skipping the channel axis, since the result might be unexpected for inexperienced users. It seems better to make the channel dimension explicit while guaranteeing that the user can't break the equivariance (i.e. split within fields).
I see three options:
- We enforce that no splits are allowed in the channel dimension, i.e. that the slices are necessarily of the form [N, :, X, Y].
- Slices are of the form [N, C, X, Y] where C counts channels. The method throws an exception if C splits within a field.
- Slices are of the form [N, F, X, Y] where F counts fields. This would be similar to the behavior of GeometricTensor.split(). This seems most logical from the viewpoint of steerable CNNs. The downside is that users might intuitively expect the behavior of option 2 and get confused.
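To make the difference concrete, the field-boundary bookkeeping behind options 2 and 3 can be sketched in plain Python. This is only an illustrative sketch, not library code; the field sizes below are made up (e.g. two regular C4 fields and one vector field):

```python
from itertools import accumulate

def field_boundaries(field_sizes):
    # Channel indices at which each field starts, plus the total size.
    return [0] + list(accumulate(field_sizes))

def splits_a_field(field_sizes, start, stop):
    # True if the channel slice [start, stop) cuts through a field,
    # i.e. the check option 2 would need before raising an exception.
    bounds = set(field_boundaries(field_sizes))
    return start not in bounds or stop not in bounds

sizes = [4, 4, 2]                    # channels per field (illustrative)
print(field_boundaries(sizes))       # [0, 4, 8, 10]
print(splits_a_field(sizes, 0, 4))   # False: the whole first field
print(splits_a_field(sizes, 0, 6))   # True: cuts the second field
```

Option 2 would run such a check and raise; option 3 sidesteps it by indexing fields, so every slice automatically lands on these boundaries.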
from e2cnn.
Thanks for the improvement! I'm trying it out. I needed this again today, and it again involved cropping in the spatial dimensions to make tensors compatible for concatenation (for some skip connections in my network). Also, I agree with the decision to keep indexing numpy/pytorch compatible as much as possible.
I trained a new "champion" network that was using this internally today for skip connections. I haven't uncovered any issues related to this yet. Thanks!
This gets me past the error; will know later if it actually works.
```python
def center_crop(self, geometric_tensor: nn.GeometricTensor, target_size: List[int]) -> nn.GeometricTensor:
    # Unpack to a plain tensor so we can slice
    tensor: torch.Tensor = geometric_tensor.tensor
    field_type: nn.FieldType = geometric_tensor.type
    _, _, tensor_height, tensor_width = tensor.size()
    diff_y = (tensor_height - target_size[0]) // 2
    diff_x = (tensor_width - target_size[1]) // 2
    tensor_sliced = tensor[
        :, :, diff_y : (diff_y + target_size[0]), diff_x : (diff_x + target_size[1])
    ]
    # Repack into a GeometricTensor
    geometric_tensor_sliced = nn.GeometricTensor(tensor_sliced, field_type)
    return geometric_tensor_sliced
```
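The centering arithmetic can be checked in isolation without torch; here is a pure-Python sketch of the same `// 2` offset logic (the helper name is mine, not part of any library):

```python
def center_crop_slices(tensor_hw, target_hw):
    # Same centering arithmetic as center_crop above: floor-divide the
    # excess in each spatial dimension and slice symmetrically.
    dy = (tensor_hw[0] - target_hw[0]) // 2
    dx = (tensor_hw[1] - target_hw[1]) // 2
    return (slice(dy, dy + target_hw[0]), slice(dx, dx + target_hw[1]))

print(center_crop_slices((7, 7), (3, 3)))   # (slice(2, 5), slice(2, 5))
print(center_crop_slices((8, 8), (3, 3)))   # odd excess rounds toward the top-left
```

Note that with an odd excess the crop is off-center by one pixel (toward the top-left), which is usually acceptable for skip connections.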
I'm still new to torch's memory model regarding slicing. Will passing a sliced tensor back into GeometricTensor's constructor cause problems?
Update: Casting back to a torch.Tensor as above is inelegant, but it seems to work. It's no longer urgent for me, but I'll leave the ticket open in case you have a more elegant solution for indexing. Was it a deliberate decision to disable indexing so that users don't mess up the channels dimension in a way that would break equivariance? If so, maybe there's a way to allow indexing as long as the caller only slices on the spatial dimensions...
Hi @drewm1980,
sorry for my late reply.
Yeah, GeometricTensor does not support indexing. A simple solution is unwrapping the underlying torch.Tensor, cropping it, and wrapping it again in a GeometricTensor. I usually do the same.
When instantiated, a GeometricTensor does little more than store a reference to the torch.Tensor and a reference to the FieldType. So I don't think you need to worry about memory management.
> Was it a deliberate decision to disable indexing so that users don't mess up the channels dimension in a way that would break equivariance? If so, maybe there's a way to allow indexing as long as the caller only slices on the spatial dimensions...
Yes, channels cannot be split freely, as that could break equivariance. If you only need to slice in the channel dimension, you can use GeometricTensor.split().
For the spatial dimensions, it was not a common use case for me, so I did not implement any additional interface. I agree it deserves a better solution, though.
I think I can override the brackets operator (the one usually used for indexing) for GeometricTensor. I will try to implement it this way and let you know if it works.
Best,
Gabriele
Supporting full slicing of the underlying tensor is dangerous as it can split channels which belong to the same field. The method GeometricTensor.split() can be used for this purpose.
I can implement slicing along the batch and the spatial dimensions by adding the following method to the GeometricTensor class:
```python
def __getitem__(self, slices):
    # Slicing is not supported on the channel dimension.
    if isinstance(slices, tuple):
        if len(slices) > len(self.tensor.shape) - 1:
            raise TypeError()
    else:
        slices = (slices,)
    # This is equivalent to using [:] on the channel dimension
    idxs = (slices[0], slice(None, None, None), *slices[1:])
    sliced_tensor = self.tensor[idxs]
    return GeometricTensor(sliced_tensor, self.type)
```
This would allow slicing as you would usually do it in PyTorch or Numpy, but it would skip the channel dimension when multiple indices are passed.
This is an example:
```python
space = Rot2dOnR2(4)
type = FieldType(space, [space.regular_repr])
geom_tensor = GeometricTensor(torch.randn(10, type.size, 7, 7), type)

geom_tensor.shape
>> torch.Size([10, 4, 7, 7])

geom_tensor[1:3, 2:5, 2:5].shape
>> torch.Size([2, 4, 3, 3])
```
Here, I've passed 3 indices, which are then used for the first (batch), third, and fourth (spatial) dimensions, skipping the second (channel) dimension.
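The index-shifting trick can be mimicked in plain Python to see exactly which dimension each user index lands on. This is a standalone sketch (the function name is mine; `ndim=4` assumes the usual N, C, X, Y layout):

```python
def shift_past_channels(slices, ndim=4):
    # Mimics the proposed __getitem__: the first user index goes to the
    # batch dimension, a full slice is inserted for the channel dimension,
    # and the remaining indices go to the spatial dimensions.
    if not isinstance(slices, tuple):
        slices = (slices,)
    if len(slices) > ndim - 1:
        raise TypeError("too many indices: the channel dimension cannot be indexed")
    return (slices[0], slice(None), *slices[1:])

# Three user indices expand to four tensor indices, with the channel
# dimension taken whole:
print(shift_past_channels((slice(1, 3), slice(2, 5), slice(2, 5))))
```

Passing `ndim` or more indices raises, which is how the channel dimension stays protected.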
Do you think this is a valid solution or would this behaviour be confusing?
I appreciate any feedback; if anyone has a better suggestion, feel free to write it here!
Thanks!
Gabriele
I personally really like option 3 as it seems a very clean solution and could also replace GeometricTensor.split().
To be more precise, it would produce this behaviour:
```python
space = Rot2dOnR2(4)
type = FieldType(space, [space.regular_repr]*2 + [space.irrep(1)]*3)
geom_tensor = GeometricTensor(torch.randn(10, type.size, 7, 7), type)

geom_tensor.shape
>> torch.Size([10, 14, 7, 7])

geom_tensor[1:3, :, 2:5, 2:5].shape
>> torch.Size([2, 14, 3, 3])

# the first 2 fields are regular fields of size 4; in total, they contain 2*4 = 8 channels
geom_tensor[:, :2, :, :].shape
>> torch.Size([10, 8, 7, 7])

# the last 2 fields are vector fields of size 2; in total, they contain 2*2 = 4 channels
geom_tensor[:, -2:, :, :].shape
>> torch.Size([10, 4, 7, 7])

# the first 3 fields are 2 regular fields and 1 vector field; in total, they contain 2*4 + 2 = 10 channels
geom_tensor[:, :3, :, :].shape
>> torch.Size([10, 10, 7, 7])
```
This could also simplify the use of MultipleModule.
Option 1 would probably be the most user-friendly for new users.
Alright, let's go for option 3 then. It includes option 1, and if users really want to slice within fields, they can use the workaround proposed by @drewm1980.
We now support slicing on all axes and simple indexing (i.e. with a single index per dimension).
Slicing in the second dimension is done over fields instead of channels.
We do not support advanced indexing, though.
You can find some examples and additional details here.
I hope this can be helpful!
Best,
Gabriele
Hi @drewm1980
good to hear!
Thanks for the feedback! We really appreciate it!
Please, let us know if you encounter any issues!