Add a non-one shape constraint to TensorType
brandonwillard committed Oct 31, 2022
1 parent 6a003e5 commit ac64604
Showing 17 changed files with 305 additions and 181 deletions.
12 changes: 5 additions & 7 deletions aesara/gradient.py
@@ -1802,13 +1802,11 @@ def verify_grad(
mode=mode,
)

tensor_pt = [
aesara.tensor.type.TensorType(
aesara.tensor.as_tensor_variable(p).dtype,
aesara.tensor.as_tensor_variable(p).broadcastable,
)(name=f"input {i}")
for i, p in enumerate(pt)
]
tensor_pt = []
for i, p in enumerate(pt):
p_t = aesara.tensor.as_tensor_variable(p).type()
p_t.name = f"input {i}"
tensor_pt.append(p_t)

# fun can be either a function or an actual Op instance
o_output = fun(*tensor_pt)
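
A minimal sketch of the new pattern in this hunk, assuming Aesara at this revision (the test point p is made up): calling a variable's type returns a fresh symbolic variable with the same dtype and static shape information, so the manual TensorType(dtype, broadcastable) construction is no longer needed.

    import numpy as np
    import aesara.tensor as at

    p = np.zeros((3, 1))  # hypothetical test point
    p_var = at.as_tensor_variable(p)

    # An Aesara type is callable; calling it returns a new variable of that type.
    p_t = p_var.type()
    p_t.name = "input 0"

    print(p_t.type.dtype, p_t.type.ndim)  # same dtype and ndim as `p_var`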
2 changes: 1 addition & 1 deletion aesara/link/c/params_type.py
@@ -324,7 +324,7 @@ class ParamsType(CType):
`ParamsType` constructor takes key-value args. Key will be the name of the
attribute in the struct. Value is the Aesara type of this attribute,
ie. an instance of (a subclass of) :class:`CType`
(eg. ``TensorType('int64', (False,))``).
(eg. ``TensorType('int64', (None,))``).
In a Python code any attribute named ``key`` will be available via::
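The docstring example switches to the new shape convention: a None entry marks a dimension of unknown length, while 1 marks a broadcastable dimension, replacing the old boolean broadcastable pattern. A small sketch of the convention, assuming Aesara at this revision:

    from aesara.tensor.type import TensorType

    vec = TensorType("int64", shape=(None,))      # 1-d, length unknown
    col = TensorType("float64", shape=(None, 1))  # 2-d, second dimension broadcastable

    # The broadcastable pattern is still available, derived from the static shape.
    print(vec.broadcastable)  # (False,)
    print(col.broadcastable)  # (False, True)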
4 changes: 3 additions & 1 deletion aesara/sandbox/multinomial.py
@@ -44,7 +44,9 @@ def make_node(self, pvals, unis, n=1):
odtype = pvals.dtype
else:
odtype = self.odtype
out = at.tensor(dtype=odtype, shape=pvals.type.broadcastable)
out = at.tensor(
dtype=odtype, shape=tuple(1 if s == 1 else None for s in pvals.type.shape)
)
return Apply(self, [pvals, unis, as_scalar(n)], [out])

def grad(self, ins, outgrads):
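The new output type above keeps only the length-1 (broadcastable) information from pvals and leaves every other dimension unspecified. A pure-Python sketch of that mapping (the helper name is illustrative):

    def to_broadcast_shape(static_shape):
        # Known length-1 dimensions stay 1; concrete lengths and unknown
        # (`None`) dimensions both become `None` in the output type.
        return tuple(1 if s == 1 else None for s in static_shape)

    print(to_broadcast_shape((1, 5, None)))  # (1, None, None)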
11 changes: 7 additions & 4 deletions aesara/sandbox/rng_mrg.py
@@ -379,10 +379,13 @@ def make_node(self, rstate, size):
# this op should not be called directly.
#
# call through MRG_RandomStream instead.
broad = []
out_shape = []
for i in range(self.output_type.ndim):
broad.append(at.extract_constant(size[i]) == 1)
output_type = self.output_type.clone(shape=broad)()
if at.extract_constant(size[i]) == 1:
out_shape.append(1)
else:
out_shape.append(None)
output_type = self.output_type.clone(shape=out_shape)()
rstate = as_tensor_variable(rstate)
size = as_tensor_variable(size)
return Apply(self, [rstate, size], [rstate.type(), output_type])
@@ -392,7 +395,7 @@ def new(cls, rstate, ndim, dtype, size):
v_size = as_tensor_variable(size)
if ndim is None:
ndim = get_vector_length(v_size)
op = cls(TensorType(dtype, (False,) * ndim))
op = cls(TensorType(dtype, shape=(None,) * ndim))
return op(rstate, v_size)

def perform(self, node, inp, out, params):
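In the new classmethod above, the output type is now created with every dimension unknown via shape=(None,) * ndim, where the old code passed (False,) * ndim broadcastable flags. A minimal sketch, assuming Aesara at this revision:

    from aesara.tensor.type import TensorType

    ndim = 2
    out_type = TensorType("float64", shape=(None,) * ndim)

    print(out_type.ndim)           # 2
    print(out_type.broadcastable)  # (False, False)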
6 changes: 2 additions & 4 deletions aesara/sparse/type.py
@@ -70,9 +70,8 @@ def __init__(
dtype: Union[str, np.dtype],
shape: Optional[Iterable[Optional[Union[bool, int]]]] = None,
name: Optional[str] = None,
broadcastable: Optional[Iterable[bool]] = None,
):
if shape is None and broadcastable is None:
if shape is None:
shape = (None, None)

if format not in self.format_cls:
@@ -82,13 +81,12 @@

self.format = format

super().__init__(dtype, shape=shape, name=name, broadcastable=broadcastable)
super().__init__(dtype, shape=shape, name=name)

def clone(
self,
dtype=None,
shape=None,
broadcastable=None,
**kwargs,
):
format: Optional[SparsityTypes] = kwargs.pop("format", self.format)
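With the deprecated broadcastable keyword removed, the sparse type is configured through shape alone, defaulting to a 2-d type of unknown shape. A sketch, assuming Aesara at this revision and that the class in this module is SparseTensorType:

    from aesara.sparse.type import SparseTensorType

    # No `broadcastable` argument anymore; omitting `shape` defaults to (None, None).
    x_type = SparseTensorType("csr", "float64")
    print(x_type.format, x_type.shape)  # csr (None, None)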
41 changes: 23 additions & 18 deletions aesara/tensor/basic.py
@@ -110,10 +110,12 @@ def _as_tensor_Variable(x, name, ndim, **kwargs):

if x.type.ndim > ndim:
# Strip off leading broadcastable dimensions
non_broadcastables = [idx for idx in range(x.ndim) if not x.broadcastable[idx]]
non_broadcastables = [
idx for idx in range(x.type.ndim) if x.type.shape_encoded[idx] != 1
]

if non_broadcastables:
x = x.dimshuffle(list(range(x.ndim))[non_broadcastables[0] :])
x = x.dimshuffle(list(range(x.type.ndim))[non_broadcastables[0] :])
else:
x = x.dimshuffle()

@@ -2210,18 +2212,18 @@ def make_node(self, axis, *tensors):
"Join cannot handle arguments of dimension 0."
" Use `stack` to join scalar values."
)
# Handle single-tensor joins immediately.

if len(tensors) == 1:
bcastable = list(tensors[0].type.broadcastable)
out_shape = tensors[0].type.shape_encoded
else:
# When the axis is fixed, a dimension should be
# broadcastable if at least one of the inputs is
# broadcastable on that dimension (see justification below),
# except for the axis dimension.
# Initialize bcastable all false, and then fill in some trues with
# the loops.
bcastable = [False] * len(tensors[0].type.broadcastable)
ndim = len(bcastable)
ndim = tensors[0].type.ndim
out_shape = [None] * ndim

if not isinstance(axis, int):
try:
@@ -2246,25 +2248,25 @@
axis += ndim

for x in tensors:
for current_axis, bflag in enumerate(x.type.broadcastable):
for current_axis, s in enumerate(x.type.shape_encoded):
# Constant negative axis can no longer be negative at
# this point. It safe to compare this way.
if current_axis == axis:
continue
if bflag:
bcastable[current_axis] = True
if s == 1:
out_shape[current_axis] = 1
try:
bcastable[axis] = False
out_shape[axis] = None
except IndexError:
raise ValueError(
f"Axis value {axis} is out of range for the given input dimensions"
)
else:
# When the axis may vary, no dimension can be guaranteed to be
# broadcastable.
bcastable = [False] * len(tensors[0].type.broadcastable)
out_shape = [None] * tensors[0].type.ndim

if not builtins.all(x.ndim == len(bcastable) for x in tensors):
if not builtins.all(x.ndim == len(out_shape) for x in tensors):
raise TypeError(
"Only tensors with the same number of dimensions can be joined"
)
@@ -2274,7 +2276,7 @@
if inputs[0].type.dtype not in int_dtypes:
raise TypeError(f"Axis value {inputs[0]} must be an integer type")

return Apply(self, inputs, [tensor(dtype=out_dtype, shape=bcastable)])
return Apply(self, inputs, [tensor(dtype=out_dtype, shape=out_shape)])

def perform(self, node, axis_and_tensors, out_):
(out,) = out_
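
The Join shape rule above: nothing is assumed about the join axis, and any other output dimension is marked length-1 when at least one input is length-1 there. A pure-Python sketch of that inference (the helper name is illustrative):

    def join_static_shape(axis, *static_shapes):
        # Start with every output dimension unknown.
        out_shape = [None] * len(static_shapes[0])
        for shape in static_shapes:
            for dim, s in enumerate(shape):
                if dim == axis:
                    continue
                if s == 1:
                    # At least one input is broadcastable here, so the output is too.
                    out_shape[dim] = 1
        # The join axis itself is always left unknown.
        out_shape[axis] = None
        return tuple(out_shape)

    print(join_static_shape(0, (None, 1), (3, None)))  # (None, 1)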
@@ -2385,6 +2387,7 @@ def grad(self, axis_and_tensors, grads):
# Split.make_node isn't always able to infer the right
# broadcast. As the grad need to keep the information,
# read it if needed.
# TODO FIXME: Remove all this broadcastable stuff.
split_gz = [
g
if g.type.broadcastable == t.type.broadcastable
@@ -2771,6 +2774,8 @@ def flatten(x, ndim=1):
else:
dims = (-1,)
x_reshaped = _x.reshape(dims)

# TODO FIXME: Remove all this broadcastable stuff.
bcast_kept_dims = _x.broadcastable[: ndim - 1]
bcast_new_dim = builtins.all(_x.broadcastable[ndim - 1 :])
broadcastable = bcast_kept_dims + (bcast_new_dim,)
@@ -2882,7 +2887,7 @@ def make_node(self, start, stop, step):
assert step.ndim == 0

inputs = [start, stop, step]
outputs = [tensor(self.dtype, (False,))]
outputs = [tensor(self.dtype, shape=(None,))]

return Apply(self, inputs, outputs)

@@ -3158,11 +3163,11 @@ def make_node(self, x, y, inverse):
elif x_dim < y_dim:
x = shape_padleft(x, n_ones=(y_dim - x_dim))

# Compute the broadcastable pattern of the output
out_broadcastable = [
xb and yb for xb, yb in zip(x.type.broadcastable, y.type.broadcastable)
out_shape = [
1 if xb == 1 and yb == 1 else None
for xb, yb in zip(x.type.shape_encoded, y.type.shape_encoded)
]
out_type = tensor(dtype=x.type.dtype, shape=out_broadcastable)
out_type = tensor(dtype=x.type.dtype, shape=out_shape)

inputlist = [x, y, inverse]
outputlist = [out_type]
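
The output pattern above is the elementwise combination of the two inputs: a dimension is known to be length 1 only when it is length 1 in both x and y, and is left unknown otherwise. A small pure-Python sketch (the helper name is illustrative):

    def combine_static_shapes(x_shape, y_shape):
        # Broadcastable in the output only if broadcastable in both inputs.
        return tuple(
            1 if xs == 1 and ys == 1 else None for xs, ys in zip(x_shape, y_shape)
        )

    print(combine_static_shapes((1, None, 1), (1, 1, 5)))  # (1, None, None)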
40 changes: 21 additions & 19 deletions aesara/tensor/blas.py
@@ -167,6 +167,7 @@
from aesara.tensor.shape import specify_broadcastable
from aesara.tensor.type import (
DenseTensorType,
TensorType,
integer_dtypes,
tensor,
values_eq_approx_remove_inf_nan,
@@ -1204,11 +1205,11 @@ def _as_scalar(res, dtype=None):
"""Return ``None`` or a `TensorVariable` of float type"""
if dtype is None:
dtype = config.floatX
if all(res.type.broadcastable):
if all(s == 1 for s in res.type.shape):
while res.owner and isinstance(res.owner.op, DimShuffle):
res = res.owner.inputs[0]
# may still have some number of True's
if res.type.broadcastable:
if res.type.ndim > 0:
rval = res.dimshuffle()
else:
rval = res
@@ -1230,16 +1231,16 @@ def _is_real_matrix(res):
return (
res.type.dtype in ("float16", "float32", "float64")
and res.type.ndim == 2
and res.type.broadcastable[0] is False
and res.type.broadcastable[1] is False
and res.type.shape[0] != 1
and res.type.shape[1] != 1
) # cope with tuple vs. list


def _is_real_vector(res):
return (
res.type.dtype in ("float16", "float32", "float64")
and res.type.ndim == 1
and res.type.broadcastable[0] is False
and res.type.shape[0] != 1
)


@@ -1298,9 +1299,7 @@ def scaled(thing):
else:
return scale * thing

try:
r.type.broadcastable
except Exception:
if not isinstance(r.type, TensorType):
return None

if (r.type.ndim not in (1, 2)) or r.type.dtype not in (
@@ -1333,10 +1332,10 @@ def scaled(thing):
vectors = []
matrices = []
for i in r.owner.inputs:
if all(i.type.broadcastable):
if all(s == 1 for s in i.type.shape):
while i.owner and isinstance(i.owner.op, DimShuffle):
i = i.owner.inputs[0]
if i.type.broadcastable:
if i.type.ndim > 0:
scalars.append(i.dimshuffle())
else:
scalars.append(i)
@@ -1681,8 +1680,7 @@ def make_node(self, x, y):
raise TypeError(y)
if y.type.dtype != x.type.dtype:
raise TypeError("dtype mismatch to Dot22")
bz = (x.type.broadcastable[0], y.type.broadcastable[1])
outputs = [tensor(x.type.dtype, bz)]
outputs = [tensor(x.type.dtype, (x.type.shape[0], y.type.shape[1]))]
return Apply(self, [x, y], outputs)

def perform(self, node, inp, out):
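
Dot22 now propagates static shapes directly: the product's output shape is the row count of x paired with the column count of y, either of which may be unknown. A minimal sketch, assuming Aesara at this revision (the variable names are made up):

    from aesara.tensor.type import tensor

    x = tensor("float64", shape=(None, 4))
    y = tensor("float64", shape=(4, 1))

    # Mirrors the make_node above: (rows of x, columns of y).
    out = tensor(x.type.dtype, shape=(x.type.shape[0], y.type.shape[1]))
    print(out.type.shape)  # (None, 1)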
@@ -1986,8 +1984,8 @@ def make_node(self, x, y, a):
if not a.dtype.startswith("float") and not a.dtype.startswith("complex"):
raise TypeError("Dot22Scalar requires float or complex args", a.dtype)

bz = [x.type.broadcastable[0], y.type.broadcastable[1]]
outputs = [tensor(x.type.dtype, bz)]
sz = [x.type.shape[0], y.type.shape[1]]
outputs = [tensor(x.type.dtype, shape=sz)]
return Apply(self, [x, y, a], outputs)

def perform(self, node, inp, out):
@@ -2213,12 +2211,16 @@ def make_node(self, *inputs):
dtype = aesara.scalar.upcast(*[input.type.dtype for input in inputs])
# upcast inputs to common dtype if needed
upcasted_inputs = [at.cast(input, dtype) for input in inputs]
broadcastable = (
(inputs[0].type.broadcastable[0] or inputs[1].type.broadcastable[0],)
+ inputs[0].type.broadcastable[1:-1]
+ inputs[1].type.broadcastable[2:]
out_shape = (
(
1
if inputs[0].type.shape[0] == 1 or inputs[1].type.shape[0] == 1
else None,
)
+ inputs[0].type.shape[1:-1]
+ inputs[1].type.shape[2:]
)
return Apply(self, upcasted_inputs, [tensor(dtype, broadcastable)])
return Apply(self, upcasted_inputs, [tensor(dtype, out_shape)])

def perform(self, node, inp, out):
x, y = inp
13 changes: 6 additions & 7 deletions aesara/tensor/math.py
@@ -1911,15 +1911,14 @@ def make_node(self, *inputs):
"aesara.tensor.dot instead."
)

i_broadcastables = [input.type.broadcastable for input in inputs]
bx, by = i_broadcastables
if len(by) == 2: # y is a matrix
bz = bx[:-1] + by[-1:]
elif len(by) == 1: # y is vector
bz = bx[:-1]
sx, sy = [input.type.shape for input in inputs]
if len(sy) == 2: # y is a matrix
sz = sx[:-1] + sy[-1:]
elif len(sy) == 1: # y is vector
sz = sx[:-1]

i_dtypes = [input.type.dtype for input in inputs]
outputs = [tensor(aes.upcast(*i_dtypes), bz)]
outputs = [tensor(aes.upcast(*i_dtypes), shape=sz)]
return Apply(self, inputs, outputs)

def perform(self, node, inp, out):
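The same shape-based inference now drives Dot: with a matrix y the output shape is sx[:-1] + sy[-1:], and with a vector y it is just sx[:-1]. A pure-Python sketch (the helper name is illustrative):

    def dot_static_shape(sx, sy):
        # Matrix `y` contributes its last dimension; vector `y` contributes nothing.
        if len(sy) == 2:
            return sx[:-1] + sy[-1:]
        elif len(sy) == 1:
            return sx[:-1]
        raise ValueError("y must be a vector or a matrix")

    print(dot_static_shape((None, 3), (3, 1)))  # (None, 1)
    print(dot_static_shape((None, 3), (3,)))    # (None,)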