Skip to content

Commit

Permalink
Add a non-one shape constraint to TensorType
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Nov 4, 2022
1 parent fa01971 commit 1b025d6
Show file tree
Hide file tree
Showing 8 changed files with 233 additions and 70 deletions.
21 changes: 15 additions & 6 deletions aesara/link/jax/dispatch/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,21 @@ def shape_i(x):
def jax_funcify_SpecifyShape(op, **kwargs):
def specifyshape(x, *shape):
assert x.ndim == len(shape)
assert jnp.all(x.shape == tuple(shape)), (
"got shape",
x.shape,
"expected",
shape,
)
for s_x, s in zip(x.shape, shape):
if s == -1:
assert s_x != 1, (
"got shape",
s_x,
"expected",
s,
)
elif s > -1:
assert s_x == s, (
"got shape",
s_x,
"expected",
s,
)
return x

return specifyshape
Expand Down
17 changes: 13 additions & 4 deletions aesara/link/numba/dispatch/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import warnings
from contextlib import contextmanager
from functools import singledispatch
from textwrap import dedent
from textwrap import dedent, indent
from typing import Union

import numba
Expand Down Expand Up @@ -650,17 +650,26 @@ def numba_funcify_SpecifyShape(op, node, **kwargs):
shape_input_names = ["shape_" + str(i) for i in range(len(shape_inputs))]

func_conditions = [
f"assert x.shape[{i}] == {shape_input_names}"
for i, (shape_input, shape_input_names) in enumerate(
dedent(
f"""
if {shape_name} == -1:
assert x.shape[{i}] != 1
if {shape_name} > -1:
assert x.shape[{i}] == {shape_name}
"""
)
for i, (shape_input, shape_name) in enumerate(
zip(shape_inputs, shape_input_names)
)
if shape_input is not NoneConst
]

conditions_block = "\n".join(func_conditions)

func = dedent(
f"""
def specify_shape(x, {create_arg_string(shape_input_names)}):
{"; ".join(func_conditions)}
{indent(conditions_block, ' ' * 12)}
return x
"""
)
Expand Down
99 changes: 73 additions & 26 deletions aesara/tensor/shape.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import warnings
from numbers import Number
from textwrap import dedent
from typing import Dict, List, Tuple, Union
from typing import Dict, List, Sequence, Tuple, Union

import numpy as np

Expand All @@ -17,11 +17,65 @@
from aesara.tensor import basic as at
from aesara.tensor import get_vector_length
from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.type import DenseTensorType, TensorType, int_dtypes, tensor
from aesara.tensor.type import (
DenseTensorType,
TensorType,
int_dtypes,
shape_encode,
tensor,
)
from aesara.tensor.type_other import NoneConst
from aesara.tensor.var import TensorConstant, TensorVariable


def filter_shape_vars(
ref_shape: Tuple[int, ...], shape: Sequence[Variable], shape_is_encoded: bool = True
) -> Tuple[int, ...]:
r"""Compute the most \"informative\" shape based on a static reference.
Parameters
----------
ref_shape
A static shape reference using static shape constraint encoding.
shape
A symbolic shape.
shape_is_encoded
If ``True``, `shape` is assumed to be static shape constraint encoded.
Returns
-------
The most specific, and compatible (with `ref_shape`), static shape
constraint encoded values.
"""
shape_bottom = shape_encode(None)
type_shape = ()
for i, (xts, s) in enumerate(zip(ref_shape, shape)):

try:
# TODO FIXME: We shouldn't need to do this; let a rewrite
# do constant folding and update the `TensorType`s.
s_val = at.get_scalar_constant_value(s)

if isinstance(s_val, np.ndarray):
s_val = s_val.item()

if shape_is_encoded or s_val is not None and s_val > 0:
type_s = shape_encode(s_val)
else:
type_s = shape_bottom
except NotScalarConstantError:
type_s = shape_bottom

if not (xts <= -1 or type_s <= -1 or type_s == xts):
raise AssertionError(
f"SpecifyShape: Got shape {xts} at index {i}, expected {type_s}."
)

type_shape += (max(type_s, xts),)

return type_shape


def register_shape_c_code(type, code, version=()):
"""
Tell Shape Op how to generate C code for an Aesara Type.
Expand Down Expand Up @@ -394,7 +448,6 @@ class SpecifyShape(COp):
_f16_ok = True

def make_node(self, x, *shape):
from aesara.tensor.basic import get_scalar_constant_value

x = at.as_tensor_variable(x)

Expand All @@ -417,18 +470,7 @@ def make_node(self, x, *shape):
f"Input `x` is {x.type.ndim}-dimensional and will never match a shape of length {len(shape)}."
)

type_shape = [None] * x.ndim
for i, (xts, s) in enumerate(zip(x.type.shape, shape)):
if xts is not None:
type_shape[i] = xts
else:
try:
type_s = get_scalar_constant_value(s)
if type_s is not None:
type_shape[i] = int(type_s)
except NotScalarConstantError:
pass

type_shape = filter_shape_vars(x.type.shape_encoded, shape)
out_var = x.type.clone(shape=type_shape)()

return Apply(self, [x, *shape], [out_var])
Expand All @@ -441,10 +483,10 @@ def perform(self, node, inp, out_):
raise AssertionError(
f"SpecifyShape: Got {x.ndim} dimensions (shape {x.shape}), expected {ndim} dimensions with shape {tuple(shape)}."
)
if not all(xs == s for xs, s in zip(x.shape, shape) if s is not None):
raise AssertionError(
f"SpecifyShape: Got shape {x.shape}, expected {tuple(int(s) if s is not None else None for s in shape)}."
)
for xs, s in zip(x.shape, shape):
if (s == -1 and xs == 1) or (s is not None and s > -1 and not xs == s):
raise AssertionError(f"SpecifyShape: Got shape {xs}, expected {s}.")

out[0] = x

def infer_shape(self, fgraph, node, shapes):
Expand All @@ -454,11 +496,11 @@ def infer_shape(self, fgraph, node, shapes):
for dim in range(node.inputs[0].type.ndim):
s = shape[dim]
try:
s = at.get_scalar_constant_value(s)
# We assume that `None` shapes are always retrieved by
s = shape_encode(at.get_scalar_constant_value(s))
# We assume that negative shapes are always retrieved by
# `get_scalar_constant_value`, and only in that case do we default to
# the shape of the input variable
if s is None:
if s < 0:
s = xshape[dim]
except NotScalarConstantError:
pass
Expand Down Expand Up @@ -502,6 +544,9 @@ def c_code(self, node, name, i_names, o_names, sub):
);
{fail};
}}
npy_intp shp;
npy_intp actual_shp;
"""
)

Expand All @@ -510,9 +555,11 @@ def c_code(self, node, name, i_names, o_names, sub):
continue
code += dedent(
f"""
if (py_{shp_name} != Py_None){{
dtype_{shp_name} shp = ((dtype_{shp_name}*)PyArray_GETPTR1({shp_name}, 0))[0];
if (PyArray_DIMS({x_name})[{i}] != shp) {{
shp = ((dtype_{shp_name}*)PyArray_GETPTR1({shp_name}, 0))[0];
if (shp > -2) {{
actual_shp = PyArray_DIMS({x_name})[{i}];
if (actual_shp == -1 && shp == 1 || actual_shp != shp) {{
PyErr_Format(PyExc_AssertionError,
"SpecifyShape: dim %d of input has shape %d, expected %d.",
{i}, PyArray_DIMS({x_name})[{i}], shp
Expand All @@ -533,7 +580,7 @@ def c_code(self, node, name, i_names, o_names, sub):
return code

def c_code_cache_version(self):
return (2,)
return (3,)


_specify_shape = SpecifyShape()
Expand Down
Loading

0 comments on commit 1b025d6

Please sign in to comment.