diff --git a/aesara/sparse/rewriting.py b/aesara/sparse/rewriting.py index 0945c8afda..69865c7fe8 100644 --- a/aesara/sparse/rewriting.py +++ b/aesara/sparse/rewriting.py @@ -677,7 +677,7 @@ def make_node(self, alpha, x_val, x_ind, x_ptr, x_nrows, y, z): assert x_ind.dtype == "int32" assert x_ptr.dtype == "int32" assert x_nrows.dtype == "int32" - assert alpha.ndim == 2 and alpha.type.broadcastable == (True, True) + assert alpha.ndim == 2 and alpha.type.shape == (1, 1) assert x_val.ndim == 1 assert y.ndim == 2 assert z.ndim == 2 @@ -905,7 +905,7 @@ def c_code_cache_version(self): { "pattern": "alpha", "constraint": lambda expr: ( - all(expr.type.broadcastable) and config.blas__ldflags + all(s == 1 for s in expr.type.shape) and config.blas__ldflags ), }, (sparse._dot, "x", "y"),