Comments (2)
Thanks for the clarifications,
I understand it would require additional logic to infer dimensions, or cause more work for a user to specify them. For me, this feature is interesting because I'm working with operations that can be expressed neatly in this syntax and will allow me to replace rearrange -> einsum -> rearrange with a single call to einsum.
Feel free to close the issue.
Best,
Felix
I'm leaving behind my own cooked up version that supports the syntax. Maybe it's useful for someone else:
"""Einsum with support for ``einops``' index un-grouping syntax."""

from typing import Union

from einops import einsum as einops_einsum
from einops import rearrange
from einops.einops import Tensor
from torch import allclose, manual_seed, rand


def einsum(*tensors_and_pattern: Union[Tensor, str], **axes_lengths: int) -> Tensor:
    """Same as ``einops.einsum`` but supports index un-grouping notation.

    For example, the following operation does not work (yet) in ``einops.einsum``
    (https://github.com/arogozhnikov/einops/blob/fcd36c9b017d52e452a301e75c1373b75ec23ee0/einops/einops.py#L833-L834),
    but works with this version: ``einsum(A, B, '(a b) c, a b c -> (a b) c', a=a, b=b)``.

    Args:
        tensors_and_pattern:
            tensors: tensors of any supported library (numpy, tensorflow, pytorch, jax).
            pattern: string, einsum pattern, with commas separating specifications for
                each tensor. Pattern should be provided after all tensors.
        axes_lengths: Length of axes that cannot be inferred.

    Returns:
        Tensor of the same type as input, after processing with einsum.

    Raises:
        NotImplementedError: If the pattern contains unsupported features.
    """
    try:
        return einops_einsum(*tensors_and_pattern)
    except NotImplementedError as e:
        tensors, pattern = tensors_and_pattern[:-1], tensors_and_pattern[-1]
        if "(" not in pattern or ")" not in pattern:
            raise NotImplementedError from e

        # un-group the operands
        lefts, right = pattern.split("->")
        lefts = lefts.split(",")
        lefts_ungrouped = [l.replace("(", "").replace(")", "") for l in lefts]
        tensors_ungrouped = [
            rearrange(t, " -> ".join([l, l_u]), **axes_lengths) if l != l_u else t
            for t, l, l_u in zip(tensors, lefts, lefts_ungrouped)
        ]

        # compute the result with un-grouped indices
        right_ungrouped = right.replace("(", "").replace(")", "")
        pattern_ungrouped = " -> ".join([",".join(lefts_ungrouped), right_ungrouped])
        result_ungrouped = einops_einsum(*tensors_ungrouped, pattern_ungrouped)

        # group the indices in the result tensor
        return (
            rearrange(
                result_ungrouped, " -> ".join([right_ungrouped, right]), **axes_lengths
            )
            if right_ungrouped != right
            else result_ungrouped
        )


def test_einsum():
    """Test einsum with support for index un-grouping syntax."""
    manual_seed(0)

    a, b, c = (3, 4, 5)
    A = rand(a, b, c)
    B = rand(a * b, c)

    # NOTE Need to specify dims ``a, b`` although they could be inferred
    axes_lengths = dict(a=a, b=b)

    # no rearrange of result tensor
    C = einsum(A, B, "a b c, (a b) c -> a b c", **axes_lengths)
    C_truth = einsum(A, B.reshape(a, b, c), "a b c, a b c -> a b c")
    assert allclose(C, C_truth)

    # rearrange required before returning the result
    C = einsum(A, B, "a b c, (a b) c -> (a b) c", **axes_lengths)
    C_truth = einsum(A, B.reshape(a, b, c), "a b c, a b c -> a b c").reshape(a * b, c)
    assert allclose(C, C_truth)


if __name__ == "__main__":
    test_einsum()
Hi Felix,
> I wanted to know if there is a design problem that prevents supporting this, or whether it is simply not implemented.

Somewhat in the middle.
There are many potential feature requests for einops.einsum that would require some investment (e.g. collapsed ellipsis, 1-axes, composition/decomposition of axes).
E.g. if you think of the pattern
(a b c), (b c ...), ..., b ..., (a d) -> a b c d
you'll see this requires quite a number of steps to infer all shapes.
The alternative is to ask the user to provide a sufficient number of dimensions, e.g. in this case 'a', 'b' and 'c'.
So I'd say there should be a strong reason to justify maintaining this logic.
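To make the "number of steps" concrete, here is a rough pure-Python sketch of how such inference could proceed: repeatedly solve any dimension whose group has exactly one unknown axis, until nothing changes. This is only an illustration and not how einops works internally; the helper names parse_groups and infer_axis_lengths are hypothetical, and ellipses are left out to keep it short.

```python
import re
from math import prod


def parse_groups(pattern):
    """Split an operand pattern like '(a b) c' into [['a', 'b'], ['c']]."""
    tokens = re.findall(r"\([^()]*\)|[^\s()]+", pattern)
    return [token.strip("()").split() for token in tokens]


def infer_axis_lengths(operands, known=None):
    """Infer axis lengths from (pattern, shape) pairs by repeated propagation.

    Axes that cannot be resolved are simply absent from the returned dict.
    """
    lengths = dict(known or {})
    progress = True
    while progress:
        progress = False
        for pattern, shape in operands:
            groups = parse_groups(pattern)
            if len(groups) != len(shape):
                raise ValueError(f"Pattern {pattern!r} does not match shape {shape}")
            for axes, dim in zip(groups, shape):
                unknown = [ax for ax in axes if ax not in lengths]
                known_prod = prod(lengths[ax] for ax in axes if ax in lengths)
                if len(unknown) == 1:
                    # Exactly one missing axis: solve for it.
                    if dim % known_prod:
                        raise ValueError(f"Cannot split {dim} in {pattern!r}")
                    lengths[unknown[0]] = dim // known_prod
                    progress = True
                elif not unknown and known_prod != dim:
                    raise ValueError(f"Inconsistent lengths in {pattern!r}")
    return lengths


# The pattern from above with ellipses dropped, using a=2, b=3, c=4, d=5.
operands = [("(a b c)", (24,)), ("(b c)", (12,)), ("b", (3,)), ("(a d)", (10,))]
lengths = infer_axis_lengths(operands)
```

Even in this simplified setting the loop needs several passes: b is resolved first, then c, and only then a and d. A lone '(a b)' operand stays unresolved unless one of the two lengths is passed via known.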
> but there are potential optimizations, e.g. merging parallel legs

I expect the backend framework to take care of that. Doing this on the einops side may induce additional copies (frameworks can make these allocations transient or partial, or even avoid them altogether).