diff --git a/aesara/link/jax/dispatch/shape.py b/aesara/link/jax/dispatch/shape.py index 77da552713..800d0a50ec 100644 --- a/aesara/link/jax/dispatch/shape.py +++ b/aesara/link/jax/dispatch/shape.py @@ -46,12 +46,21 @@ def shape_i(x): def jax_funcify_SpecifyShape(op, **kwargs): def specifyshape(x, *shape): assert x.ndim == len(shape) - assert jnp.all(x.shape == tuple(shape)), ( - "got shape", - x.shape, - "expected", - shape, - ) + for s_x, s in zip(x.shape, shape): + if s == -1: + assert s_x != 1, ( + "got shape", + s_x, + "expected", + s, + ) + elif s > -1: + assert s_x == s, ( + "got shape", + s_x, + "expected", + s, + ) return x return specifyshape diff --git a/aesara/link/numba/dispatch/basic.py b/aesara/link/numba/dispatch/basic.py index 79960e52fa..12d1baba0f 100644 --- a/aesara/link/numba/dispatch/basic.py +++ b/aesara/link/numba/dispatch/basic.py @@ -2,7 +2,7 @@ import warnings from contextlib import contextmanager from functools import singledispatch -from textwrap import dedent +from textwrap import dedent, indent from typing import Union import numba @@ -650,17 +650,26 @@ def numba_funcify_SpecifyShape(op, node, **kwargs): shape_input_names = ["shape_" + str(i) for i in range(len(shape_inputs))] func_conditions = [ - f"assert x.shape[{i}] == {shape_input_names}" - for i, (shape_input, shape_input_names) in enumerate( + dedent( + f""" + if {shape_name} == -1: + assert x.shape[{i}] != 1 + if {shape_name} > -1: + assert x.shape[{i}] == {shape_name} + """ + ) + for i, (shape_input, shape_name) in enumerate( zip(shape_inputs, shape_input_names) ) if shape_input is not NoneConst ] + conditions_block = "\n".join(func_conditions) + func = dedent( f""" def specify_shape(x, {create_arg_string(shape_input_names)}): - {"; ".join(func_conditions)} + {indent(conditions_block, ' ' * 12)} return x """ ) diff --git a/aesara/tensor/shape.py b/aesara/tensor/shape.py index 065a9d7b35..c3bc5c53c4 100644 --- a/aesara/tensor/shape.py +++ b/aesara/tensor/shape.py @@ -1,7 +1,7 @@ import warnings from numbers import Number from textwrap import dedent -from typing import Dict, List, Tuple, Union +from typing import Dict, List, Sequence, Tuple, Union import numpy as np @@ -17,11 +17,65 @@ from aesara.tensor import basic as at from aesara.tensor import get_vector_length from aesara.tensor.exceptions import NotScalarConstantError -from aesara.tensor.type import DenseTensorType, TensorType, int_dtypes, tensor +from aesara.tensor.type import ( + DenseTensorType, + TensorType, + int_dtypes, + shape_encode, + tensor, +) from aesara.tensor.type_other import NoneConst from aesara.tensor.var import TensorConstant, TensorVariable +def filter_shape_vars( + ref_shape: Tuple[int, ...], shape: Sequence[Variable], shape_is_encoded: bool = True +) -> Tuple[int, ...]: + r"""Compute the most \"informative\" shape based on a static reference. + + Parameters + ---------- + ref_shape + A static shape reference using static shape constraint encoding. + shape + A symbolic shape. + shape_is_encoded + If ``True``, `shape` is assumed to be static shape constraint encoded. + + Returns + ------- + The most specific, and compatible (with `ref_shape`), static shape + constraint encoded values. + """ + shape_bottom = shape_encode(None) + type_shape = () + for i, (xts, s) in enumerate(zip(ref_shape, shape)): + + try: + # TODO FIXME: We shouldn't need to do this; let a rewrite + # do constant folding and update the `TensorType`s. + s_val = at.get_scalar_constant_value(s) + + if isinstance(s_val, np.ndarray): + s_val = s_val.item() + + if shape_is_encoded or s_val is not None and s_val > 0: + type_s = shape_encode(s_val) + else: + type_s = shape_bottom + except NotScalarConstantError: + type_s = shape_bottom + + if not (xts <= -1 or type_s <= -1 or type_s == xts): + raise AssertionError( + f"SpecifyShape: Got shape {xts} at index {i}, expected {type_s}." + ) + + type_shape += (max(type_s, xts),) + + return type_shape + + def register_shape_c_code(type, code, version=()): """ Tell Shape Op how to generate C code for an Aesara Type. @@ -394,7 +448,6 @@ class SpecifyShape(COp): _f16_ok = True def make_node(self, x, *shape): - from aesara.tensor.basic import get_scalar_constant_value x = at.as_tensor_variable(x) @@ -417,18 +470,7 @@ def make_node(self, x, *shape): f"Input `x` is {x.type.ndim}-dimensional and will never match a shape of length {len(shape)}." ) - type_shape = [None] * x.ndim - for i, (xts, s) in enumerate(zip(x.type.shape, shape)): - if xts is not None: - type_shape[i] = xts - else: - try: - type_s = get_scalar_constant_value(s) - if type_s is not None: - type_shape[i] = int(type_s) - except NotScalarConstantError: - pass - + type_shape = filter_shape_vars(x.type.shape_encoded, shape) out_var = x.type.clone(shape=type_shape)() return Apply(self, [x, *shape], [out_var]) @@ -441,10 +483,10 @@ def perform(self, node, inp, out_): raise AssertionError( f"SpecifyShape: Got {x.ndim} dimensions (shape {x.shape}), expected {ndim} dimensions with shape {tuple(shape)}." ) - if not all(xs == s for xs, s in zip(x.shape, shape) if s is not None): - raise AssertionError( - f"SpecifyShape: Got shape {x.shape}, expected {tuple(int(s) if s is not None else None for s in shape)}." - ) + for xs, s in zip(x.shape, shape): + if (s == -1 and xs == 1) or (s is not None and s > -1 and not xs == s): + raise AssertionError(f"SpecifyShape: Got shape {xs}, expected {s}.") + out[0] = x def infer_shape(self, fgraph, node, shapes): @@ -454,11 +496,11 @@ def infer_shape(self, fgraph, node, shapes): for dim in range(node.inputs[0].type.ndim): s = shape[dim] try: - s = at.get_scalar_constant_value(s) - # We assume that `None` shapes are always retrieved by + s = shape_encode(at.get_scalar_constant_value(s)) + # We assume that negative shapes are always retrieved by # `get_scalar_constant_value`, and only in that case do we default to # the shape of the input variable - if s is None: + if s < 0: s = xshape[dim] except NotScalarConstantError: pass @@ -502,6 +544,9 @@ def c_code(self, node, name, i_names, o_names, sub): ); {fail}; }} + + npy_intp shp; + npy_intp actual_shp; """ ) @@ -510,9 +555,11 @@ def c_code(self, node, name, i_names, o_names, sub): continue code += dedent( f""" - if (py_{shp_name} != Py_None){{ - dtype_{shp_name} shp = ((dtype_{shp_name}*)PyArray_GETPTR1({shp_name}, 0))[0]; - if (PyArray_DIMS({x_name})[{i}] != shp) {{ + shp = ((dtype_{shp_name}*)PyArray_GETPTR1({shp_name}, 0))[0]; + + if (shp > -2) {{ + actual_shp = PyArray_DIMS({x_name})[{i}]; + if (actual_shp == -1 && shp == 1 || actual_shp != shp) {{ PyErr_Format(PyExc_AssertionError, "SpecifyShape: dim %d of input has shape %d, expected %d.", {i}, PyArray_DIMS({x_name})[{i}], shp @@ -533,7 +580,7 @@ def c_code(self, node, name, i_names, o_names, sub): return code def c_code_cache_version(self): - return (2,) + return (3,) _specify_shape = SpecifyShape() diff --git a/aesara/tensor/type.py b/aesara/tensor/type.py index 5890b6e22e..8d6f1781df 100644 --- a/aesara/tensor/type.py +++ b/aesara/tensor/type.py @@ -54,6 +54,20 @@ } +def parse_bcast_and_shape(s): + if isinstance(s, (bool, np.bool_)): + return 1 if s else None + else: + return s + + +def shape_encode(s): + if s is None: + return -2 + + return int(s) + + class TensorType(CType[np.ndarray], HasDataType, HasShape): r"""Symbolic `Type` representing `numpy.ndarray`\s.""" @@ -82,7 +96,8 @@ def __init__( A NumPy dtype (e.g. ``"int64"``). shape The static shape information. ``None``\s are used to indicate - unknown shape values for their respective dimensions. + unknown shape values for their respective dimensions and ``-1`` to + indicate the constraint ``shape != 1``. If `shape` is a list of ``bool``\s, the ``True`` elements of are converted to ``1``\s and the ``False`` values are converted to ``None``\s. @@ -96,7 +111,10 @@ def __init__( "The `broadcastable` keyword is deprecated; use `shape`.", DeprecationWarning, ) - shape = broadcastable + + shape = tuple(parse_bcast_and_shape(s) for s in shape) + + self.name = name if str(dtype) == "floatX": self.dtype = config.floatX @@ -106,15 +124,15 @@ def __init__( self.dtype = np.dtype(dtype).name - def parse_bcast_and_shape(s): - if isinstance(s, (bool, np.bool_)): - return 1 if s else None - else: - return s + self.shape_encoded = tuple(shape_encode(s) for s in shape) + + # TODO: Raise more informative exceptions + assert isinstance(self.shape_encoded, tuple) + assert all( + isinstance(s, int) and not isinstance(s, bool) for s in self.shape_encoded + ) - self.shape = tuple(parse_bcast_and_shape(s) for s in shape) self.dtype_specs() # error checking is done there - self.name = name self.numpy_dtype = np.dtype(self.dtype) def clone( @@ -125,11 +143,13 @@ def clone( "The `broadcastable` keyword is deprecated; use `shape`.", DeprecationWarning, ) - shape = broadcastable + shape = tuple(parse_bcast_and_shape(s) for s in shape) + if dtype is None: dtype = self.dtype if shape is None: - shape = self.shape + shape = self.shape_encoded + return type(self)(dtype, shape, name=self.name) def filter(self, data, strict=False, allow_downcast=None): @@ -243,16 +263,24 @@ def filter(self, data, strict=False, allow_downcast=None): " Aesara C code does not support that.", ) - if not all( - ds == ts if ts is not None else True - for ds, ts in zip(data.shape, self.shape) - ): - raise TypeError( - f"The type's shape ({self.shape}) is not compatible with the data's ({data.shape})" - ) + def check_shape_info(i, s_val, s_info): + if s_info == -1 and s_val == 1: + raise TypeError( + f"Value's shape in dimension {i} is not compatible " + f"with the constraint: {s_val} != 1" + ) + if s_info > -1 and s_val != s_info: + raise TypeError( + f"Value's shape in dimension {i} is not compatible " + f"with the constraint: {s_val} == {s_info}" + ) + + for i, (s_val, s_info) in enumerate(zip(np.shape(data), self.shape_encoded)): + check_shape_info(i, s_val, s_info) if self.filter_checks_isfinite and not np.all(np.isfinite(data)): raise ValueError("Non-finite elements not allowed") + return data def filter_variable(self, other, allow_convert=True): @@ -308,7 +336,11 @@ def in_same_class(self, otype): if ( isinstance(otype, TensorType) and otype.dtype == self.dtype - and otype.broadcastable == self.broadcastable + and self.ndim == otype.ndim + and all( + s == o_s if s == 1 or o_s == 1 else True + for s, o_s in zip(self.shape, otype.shape) + ) ): return True return False @@ -320,7 +352,10 @@ def is_super(self, otype): and otype.ndim == self.ndim # `otype` is allowed to be as or more shape-specific than `self`, # but not less - and all(sb == ob or sb is None for sb, ob in zip(self.shape, otype.shape)) + and all( + s == o_s if s > -1 and o_s > -1 else s <= o_s + for s, o_s in zip(self.shape_encoded, otype.shape_encoded) + ) ): return True @@ -334,8 +369,19 @@ def convert_variable(self, var): if (self.ndim == var.type.ndim) and (self.dtype == var.type.dtype): # `var.type` only differs from `self` in that its shape is (at least partially) # less specific than `self`, so we convert `var` to `self`'s `Type`. - # `specify_shape` will combine the more precise shapes of the two types - return aesara.tensor.specify_shape(var, self.shape) + # `specify_shape` will combine the more precise shapes of the two types. + + new_shape_encoded = () + for s, o_s in zip(self.shape_encoded, var.type.shape_encoded): + + if s > -1 and o_s > -1 and s != o_s: + raise TypeError( + f"Incompatible shapes: {self.shape_encoded}, {var.type.shape_encoded}" + ) + + new_shape_encoded += (max(s, o_s),) + + return aesara.tensor.specify_shape(var, new_shape_encoded) @staticmethod def values_eq(a, b, force_same_dtype=True): @@ -367,10 +413,15 @@ def __eq__(self, other): if type(self) != type(other): return NotImplemented - return other.dtype == self.dtype and other.shape == self.shape + return other.dtype == self.dtype and other.shape_encoded == self.shape_encoded def __hash__(self): - return hash((type(self), self.dtype, self.shape)) + return hash((type(self), self.dtype, self.shape_encoded)) + + @property + def shape(self) -> Tuple[Optional[Union[int]]]: + """Return a static shape tuple with unknown values equal to ``None``.""" + return tuple(s if s > -1 else None for s in self.shape_encoded) @property def broadcastable(self): @@ -388,12 +439,14 @@ def __str__(self): else: def shape_str(s): - if s is None: + if s == -1: + return ">1" + elif s < -1: return "?" else: return str(s) - formatted_shape = ", ".join([shape_str(s) for s in self.shape]) + formatted_shape = ", ".join([shape_str(s) for s in self.shape_encoded]) if len(self.shape) == 1: formatted_shape += "," diff --git a/tests/link/jax/test_shape.py b/tests/link/jax/test_shape.py index 6b1bd442fa..3abe06fb3e 100644 --- a/tests/link/jax/test_shape.py +++ b/tests/link/jax/test_shape.py @@ -36,6 +36,12 @@ def test_jax_specify_shape(): compare_jax_and_py(x_fg, []) + x_np = np.zeros((20, 3)) + x = SpecifyShape()(at.as_tensor_variable(x_np), (20, -2)) + x_fg = FunctionGraph([], [x]) + + compare_jax_and_py(x_fg, []) + with config.change_flags(compute_test_value="off"): x = SpecifyShape()(at.as_tensor_variable(x_np), *(2, 3)) @@ -44,6 +50,13 @@ def test_jax_specify_shape(): with pytest.raises(AssertionError): compare_jax_and_py(x_fg, []) + x_np = np.zeros((20, 1)) + x = SpecifyShape()(at.as_tensor_variable(x_np), *(2, -1)) + x_fg = FunctionGraph([], [x]) + + with pytest.raises(AssertionError): + compare_jax_and_py(x_fg, []) + def test_jax_Reshape(): a = vector("a") diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index 96ef53203d..bcd6eadd59 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -580,6 +580,16 @@ def test_Reshape_scalar(): (1, None), False, ), + ( + set_test_value(at.matrix(), np.array([[1.0, 2.0]], dtype=config.floatX)), + (1, -2), + False, + ), + ( + set_test_value(at.matrix(), np.array([[1.0, 2.0]], dtype=config.floatX)), + (-1, None), + True, + ), ], ) def test_SpecifyShape(v, shape, fails): diff --git a/tests/tensor/test_shape.py b/tests/tensor/test_shape.py index e93829d6d2..7e969d260c 100644 --- a/tests/tensor/test_shape.py +++ b/tests/tensor/test_shape.py @@ -478,6 +478,10 @@ def test_bad_shape(self): with pytest.raises(AssertionError, match="SpecifyShape:.*"): f(xval, 3) + f = aesara.function([x], specify_shape(x, None, -2), mode=self.mode) + x_val = np.zeros((3, 1), dtype=config.floatX) + assert np.array_equal(f(xval), xval) + def test_infer_shape(self): rng = np.random.default_rng(3453) adtens4 = dtensor4() @@ -514,7 +518,9 @@ def test_direct_return(self): assert specify_shape(x, (1, 2, 3)) is not x assert specify_shape(x, (None, None, 3)) is not x - assert specify_shape(x, (1, 3, None)) is not x + + with pytest.raises(AssertionError, match="SpecifyShape:.*"): + specify_shape(x, (1, 3, None)) def test_specify_shape_in_grad(self): x = matrix() diff --git a/tests/tensor/test_type.py b/tests/tensor/test_type.py index 3df380dd03..f4123958d8 100644 --- a/tests/tensor/test_type.py +++ b/tests/tensor/test_type.py @@ -316,13 +316,29 @@ def test_fixed_shape_convert_variable(): assert res.type.shape == (3, 2) -def test_deprecated_kwargs(): - with pytest.warns(DeprecationWarning, match=".*broadcastable.*"): - res = TensorType("float64", broadcastable=(True, False)) +def test_non_ones_interface(): + """Make sure the ``-1`` shape encoding works as expected.""" - assert res.shape == (1, None) + x_type = TensorType("float64", shape=(-1, 0, 1, 5, None)) - with pytest.warns(DeprecationWarning, match=".*broadcastable.*"): - new_res = res.clone(broadcastable=(False, True)) + with pytest.raises(TypeError): + x_type.filter(np.ones((1, 0, 1, 5, 1))) + + x_val = np.ones((2, 0, 1, 5, 1), dtype=np.float64) + assert x_type.filter(x_val) is x_val + + assert str(x_type) == "TensorType(float64, (>1, 0, 1, 5, ?))" + + assert x_type.shape == (None, 0, 1, 5, None) + assert x_type.shape_encoded == (-1, 0, 1, 5, -2) + + y_type = TensorType("float64", shape=(None, -1, 1, 5, -1)) + + assert x_type != y_type + + # Make sure the filtering and ordering works + xy = x_type.convert_variable(y_type()) + assert xy.type.shape == (None, 0, 1, 5, None) + assert xy.type.shape_encoded == (-1, 0, 1, 5, -1) - assert new_res.shape == (None, 1) + assert not x_type.is_super(y_type)