Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade Tensor Puzzlers from torchtyping to jaxtyping. #25

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 45 additions & 42 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ If you are interested, there is also a youtube walkthrough of the puzzles
[![Watch the video](https://img.youtube.com/vi/SiwTAyyvt5s/default.jpg)](https://youtu.be/Hafo7hIl8MU)

```python
!pip install -qqq torchtyping hypothesis pytest git+https://github.com/danoneata/chalk@srush-patch-1
!pip install -qqq jaxtyping beartype hypothesis pytest git+https://github.com/chalk-diagrams/chalk
!wget -q https://github.com/srush/Tensor-Puzzles/raw/main/lib.py
```

Expand All @@ -35,8 +35,19 @@ If you are interested, there is also a youtube walkthrough of the puzzles
from lib import draw_examples, make_test, run_test
import torch
import numpy as np
from torchtyping import TensorType as TT
import jaxtyping
from lib import Ints, Reals, Bools # jaxtyping shorthand

import beartype
Tensor = torch.Tensor
tensor = torch.tensor

# Uncommenting the following will turn on type checking of your output
# sizes, but will interfere with the line counting at the end of this
# exercise. https://github.com/google/jaxtyping/issues/160
#
# %load_ext jaxtyping
# %jaxtyping.typechecker beartype.beartype
```

## Rules
Expand Down Expand Up @@ -121,8 +132,8 @@ Compute [ones](https://numpy.org/doc/stable/reference/generated/numpy.ones.html)
def ones_spec(out):
for i in range(len(out)):
out[i] = 1
def ones(i: int) -> TT["i"]:

def ones(i: int) -> Ints["{i}"]:
raise NotImplementedError

test_ones = make_test("one", ones, ones_spec, add_sizes=["i"])
Expand Down Expand Up @@ -150,7 +161,7 @@ def sum_spec(a, out):
for i in range(len(a)):
out[0] += a[i]

def sum(a: TT["i"]) -> TT[1]:
def sum(a: Reals["i"]) -> Reals["1"]:
raise NotImplementedError


Expand Down Expand Up @@ -179,7 +190,7 @@ def outer_spec(a, b, out):
for j in range(len(out[0])):
out[i][j] = a[i] * b[j]

def outer(a: TT["i"], b: TT["j"]) -> TT["i", "j"]:
def outer(a: Reals["i"], b: Reals["j"]) -> Reals["i j"]:
raise NotImplementedError

test_outer = make_test("outer", outer, outer_spec)
Expand All @@ -206,7 +217,7 @@ def diag_spec(a, out):
for i in range(len(a)):
out[i] = a[i][i]

def diag(a: TT["i", "i"]) -> TT["i"]:
def diag(a: Reals["i i"]) -> Reals["i"]:
raise NotImplementedError


Expand Down Expand Up @@ -234,7 +245,7 @@ def eye_spec(out):
for i in range(len(out)):
out[i][i] = 1

def eye(j: int) -> TT["j", "j"]:
def eye(j: int) -> Ints["{j} {j}"]:
raise NotImplementedError

test_eye = make_test("eye", eye, eye_spec, add_sizes=["j"])
Expand Down Expand Up @@ -265,10 +276,9 @@ def triu_spec(out):
else:
out[i][j] = 0

def triu(j: int) -> TT["j", "j"]:
def triu(j: int) -> Ints["{j} {j}"]:
raise NotImplementedError


test_triu = make_test("triu", triu, triu_spec, add_sizes=["j"])
```

Expand All @@ -295,7 +305,7 @@ def cumsum_spec(a, out):
out[i] = total + a[i]
total += a[i]

def cumsum(a: TT["i"]) -> TT["i"]:
def cumsum(a: Reals["i"]) -> Reals["i"]:
raise NotImplementedError

test_cumsum = make_test("cumsum", cumsum, cumsum_spec)
Expand Down Expand Up @@ -323,7 +333,8 @@ def diff_spec(a, out):
for i in range(1, len(out)):
out[i] = a[i] - a[i - 1]

def diff(a: TT["i"], i: int) -> TT["i"]:
def diff(a: Reals["i"], i: int) -> Reals["i"]:
assert(i == a.shape[0])
raise NotImplementedError

test_diff = make_test("diff", diff, diff_spec, add_sizes=["i"])
Expand Down Expand Up @@ -351,10 +362,9 @@ def vstack_spec(a, b, out):
out[0][i] = a[i]
out[1][i] = b[i]

def vstack(a: TT["i"], b: TT["i"]) -> TT[2, "i"]:
def vstack(a: Reals["i"], b: Reals["i"]) -> Reals["2 i"]:
raise NotImplementedError


test_vstack = make_test("vstack", vstack, vstack_spec)
```

Expand Down Expand Up @@ -382,7 +392,8 @@ def roll_spec(a, out):
else:
out[i] = a[i + 1 - len(out)]

def roll(a: TT["i"], i: int) -> TT["i"]:
def roll(a: Reals["i"], i: int) -> Reals["i"]:
assert(i == a.shape[0])
raise NotImplementedError


Expand Down Expand Up @@ -410,10 +421,10 @@ def flip_spec(a, out):
for i in range(len(out)):
out[i] = a[len(out) - i - 1]

def flip(a: TT["i"], i: int) -> TT["i"]:
def flip(a: Reals["i"], i: int) -> Reals["i"]:
assert(i == a.shape[0])
raise NotImplementedError


test_flip = make_test("flip", flip, flip_spec, add_sizes=["i"])
```

Expand Down Expand Up @@ -442,10 +453,10 @@ def compress_spec(g, v, out):
out[j] = v[i]
j += 1

def compress(g: TT["i", bool], v: TT["i"], i:int) -> TT["i"]:
def compress(g: Bools["i"], v: Reals["i"], i: int) -> Reals["i"]:
assert(i == v.shape[0])
raise NotImplementedError


test_compress = make_test("compress", compress, compress_spec, add_sizes=["i"])
```

Expand All @@ -471,11 +482,10 @@ def pad_to_spec(a, out):
for i in range(min(len(out), len(a))):
out[i] = a[i]


def pad_to(a: TT["i"], i: int, j: int) -> TT["j"]:
def pad_to(a: Reals["i"], i: int, j: int) -> Reals["{j}"]:
assert(i == a.shape[0])
raise NotImplementedError


test_pad_to = make_test("pad_to", pad_to, pad_to_spec, add_sizes=["i", "j"])
```

Expand Down Expand Up @@ -505,15 +515,13 @@ def sequence_mask_spec(values, length, out):
else:
out[i][j] = 0

def sequence_mask(values: TT["i", "j"], length: TT["i", int]) -> TT["i", "j"]:
def sequence_mask(values: Reals["i j"], length: Ints["i"]) -> Reals["i j"]:
raise NotImplementedError


def constraint_set_length(d):
d["length"] = d["length"] % d["values"].shape[1]
return d


test_sequence = make_test("sequence_mask",
sequence_mask, sequence_mask_spec, constraint=constraint_set_length
)
Expand All @@ -540,10 +548,11 @@ def bincount_spec(a, out):
for i in range(len(a)):
out[a[i]] += 1

def bincount(a: TT["i"], j: int) -> TT["j"]:
def bincount(a: Ints["i"], j: int) -> Ints["{j}"]:
assert j >= max(a)
assert all(x >= 0 for x in a)
raise NotImplementedError


def constraint_set_max(d):
d["a"] = d["a"] % d["return"].shape[0]
return d
Expand Down Expand Up @@ -575,15 +584,13 @@ def scatter_add_spec(values, link, out):
for j in range(len(values)):
out[link[j]] += values[j]

def scatter_add(values: TT["i"], link: TT["i"], j: int) -> TT["j"]:
def scatter_add(values: Reals["i"], link: Ints["i"], j: int) -> Reals["{j}"]:
raise NotImplementedError


def constraint_set_max(d):
d["link"] = d["link"] % d["return"].shape[0]
return d


test_scatter_add = make_test("scatter_add",
scatter_add, scatter_add_spec, add_sizes=["j"], constraint=constraint_set_max
)
Expand Down Expand Up @@ -613,7 +620,9 @@ def flatten_spec(a, out):
out[k] = a[i][j]
k += 1

def flatten(a: TT["i", "j"], i:int, j:int) -> TT["i * j"]:
def flatten(a: Reals["i j"], i:int, j:int) -> Reals["i*j"]:
assert(i == a.shape[0])
assert(j == a.shape[1])
raise NotImplementedError

test_flatten = make_test("flatten", flatten, flatten_spec, add_sizes=["i", "j"])
Expand All @@ -640,7 +649,7 @@ def linspace_spec(i, j, out):
for k in range(len(out)):
out[k] = float(i + (j - i) * k / max(1, len(out) - 1))

def linspace(i: TT[1], j: TT[1], n: int) -> TT["n", float]:
def linspace(i: Reals["1"], j: Reals["1"], n: int) -> Reals["{n}"]:
raise NotImplementedError

test_linspace = make_test("linspace", linspace, linspace_spec, add_sizes=["n"])
Expand Down Expand Up @@ -670,7 +679,7 @@ def heaviside_spec(a, b, out):
else:
out[k] = int(a[k] > 0)

def heaviside(a: TT["i"], b: TT["i"]) -> TT["i"]:
def heaviside(a: Reals["i"], b: Reals["i"]) -> Reals["i"]:
raise NotImplementedError

test_heaviside = make_test("heaviside", heaviside, heaviside_spec)
Expand All @@ -697,16 +706,11 @@ def repeat_spec(a, d, out):
for i in range(d[0]):
for k in range(len(a)):
out[i][k] = a[k]

def constraint_set(d):
d["d"][0] = d["return"].shape[0]
return d


def repeat(a: TT["i"], d: TT[1]) -> TT["d", "i"]:
def repeat(a: Reals["i"], d: int) -> Reals["{d} i"]:
raise NotImplementedError

test_repeat = make_test("repeat", repeat, repeat_spec, constraint=constraint_set)
test_repeat = make_test("repeat", repeat, repeat_spec, add_sizes=['d'])
```

![svg](Tensor%20Puzzlers_files/Tensor%20Puzzlers_68_0.svg)
Expand All @@ -730,8 +734,7 @@ def constraint_set(d):
d["boundaries"] = np.abs(d["boundaries"]).cumsum()
return d


def bucketize(v: TT["i"], boundaries: TT["j"]) -> TT["i"]:
def bucketize(v: Reals["i"], boundaries: Reals["j"]) -> Ints["i"]:
raise NotImplementedError

test_bucketize = make_test("bucketize", bucketize, bucketize_spec,
Expand Down
Loading