Add support for JAX Arrays #1

Open · wants to merge 5 commits into main

13 changes: 10 additions & 3 deletions README.md
@@ -1,6 +1,6 @@
# Fancy Einsum

-This is a simple wrapper around `np.einsum` and `torch.einsum` that allows the use of self-documenting variable names instead of just single letters in the equations. Inspired by the syntax in [einops](https://github.com/arogozhnikov/einops).
+This is a simple wrapper around `np.einsum`, `jnp.einsum`, and `torch.einsum` that allows the use of self-documenting variable names instead of just single letters in the equations. Inspired by the syntax in [einops](https://github.com/arogozhnikov/einops).

For example, instead of writing:

@@ -9,16 +9,23 @@ import torch
torch.einsum('bct,bcs->bcts', a, b)
```

or

```python
import numpy as np
np.einsum('bct,bcs->bcts', a, b)
```

or

```python
import jax.numpy as jnp
jnp.einsum('bct,bcs->bcts', a, b)
```

With this library you can write:

```python
from fancy_einsum import einsum
einsum('batch channels time1, batch channels time2 -> batch channels time1 time2', a, b)
```
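
For illustration, the same long-name equation should also work on `jax.numpy` arrays once this PR's backend is in place. A minimal sketch (the array shapes and values below are assumptions for illustration, not part of the diff):

```python
import jax.numpy as jnp
from fancy_einsum import einsum

# Illustrative shapes only: batch=4, channels=3, time1=5, time2=7
a = jnp.ones((4, 3, 5))
b = jnp.ones((4, 3, 7))

out = einsum('batch channels time1, batch channels time2 -> batch channels time1 time2', a, b)
print(out.shape)  # (4, 3, 5, 7)
```
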
30 changes: 23 additions & 7 deletions fancy_einsum/__init__.py
@@ -67,8 +67,24 @@ def is_appropriate_type(self, tensor):
    def einsum(self, equation, *operands):
        return self.np.einsum(equation, *operands)


class JaxBackend(AbstractBackend):
    framework_name = 'jax'

    def __init__(self):
        import jax
        import jax.numpy as jnp
        self.jax = jax
        self.jnp = jnp

    def is_appropriate_type(self, tensor):
        return isinstance(tensor, self.jax.Array)

    def einsum(self, equation, *operands):
        return self.jnp.einsum(equation, *operands)


# end part following einops


_part_re = re.compile(r'\.{3}|\w+|,|->')
@@ -86,7 +102,7 @@ def convert_equation(equation: str) -> str:
    rhs.extend(sorted(term for term in terms if
                      term not in SPECIAL and terms.count(term) == 1))
    terms.extend(rhs)

    # First pass: prefer to map long names to first letter, uppercase if needed
    # so "time" becomes t if possible, then T.
    short_to_long = {}
@@ -107,8 +123,8 @@ def try_make_abbr(s):
    # Handle multiple long with same first letter. Second one gets first available letter
    conflicts = []
    for term in terms:
        if (term not in SPECIAL and
                term not in long_to_short and
                term not in conflicts and
                not try_make_abbr(term)):
            conflicts.append(term)
@@ -126,11 +142,11 @@ def try_make_abbr(s):

def einsum(equation: str, *operands):
"""Evaluates the Einstein summation convention on the operands.
See:

See:
https://pytorch.org/docs/stable/generated/torch.einsum.html
https://numpy.org/doc/stable/reference/generated/numpy.einsum.html
"""
backend = get_backend(operands[0])
new_equation = convert_equation(equation)
return backend.einsum(new_equation, *operands)
return backend.einsum(new_equation, *operands)
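
Since `einsum` picks the backend from the type of the first operand (`get_backend(operands[0])` above), the same public call should route NumPy inputs to `np.einsum` and JAX inputs to the new `jnp.einsum` path. A quick sketch from the caller's side (the equation and shapes here are illustrative assumptions, not taken from the diff):

```python
import numpy as np
import jax.numpy as jnp
from fancy_einsum import einsum

# The backend is chosen per call from the first operand's type.
x_np = np.ones((2, 3))
x_jx = jnp.ones((2, 3))

print(type(einsum('rows cols -> rows', x_np)))  # numpy.ndarray
print(type(einsum('rows cols -> rows', x_jx)))  # a JAX Array type
```
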
55 changes: 55 additions & 0 deletions fancy_einsum/test_jax.py
@@ -0,0 +1,55 @@
import jax.numpy as jnp
from jax.numpy import allclose
from hypothesis import given
from hypothesis.strategies import integers, composite, lists
from hypothesis.extra.numpy import arrays

from fancy_einsum import einsum


def tensor(draw, shape):
    return draw(arrays(dtype=int, shape=shape))


@composite
def square_matrix(draw):
    n = draw(integers(2, 10))
    return tensor(draw, (n, n))


@given(square_matrix())
def test_simple_matmul(mat):
    mat = jnp.array(mat)
    actual = einsum("length length ->", mat)
    assert allclose(actual, jnp.einsum("aa->", mat))


@composite
def matmul_compatible(draw):
    b = draw(integers(1, 10))
    r = draw(integers(1, 10))
    t = draw(integers(1, 10))
    c = draw(integers(1, 10))
    return tensor(draw, (b, r, t)), tensor(draw, (b, t, c))


@given(matmul_compatible())
def test_ellipse_matmul(args):
    a, b = args
    a, b = jnp.array(a), jnp.array(b)
    actual = einsum("...rows temp, ...temp cols -> ...rows cols", a, b)
    assert allclose(actual, jnp.einsum("...rt,...tc->...rc", a, b))


@composite
def chain_matmul(draw):
    sizes = [draw(integers(1, 4)) for _ in range(5)]
    shapes = [(sizes[i - 1], sizes[i]) for i in range(1, len(sizes))]
    return [tensor(draw, shape) for shape in shapes]


@given(chain_matmul())
def test_chain_matmul(args):
    args = [jnp.array(arg) for arg in args]
    actual = einsum("rows t1, t1 t2, t2 t3, t3 cols -> rows cols", *args)
    assert allclose(actual, jnp.einsum("ab,bc,cd,de->ae", *args))
3 changes: 2 additions & 1 deletion requirements_dev.txt
@@ -1,4 +1,5 @@
torch
numpy
jax
pytest
hypothesis