diff --git a/aesara/tensor/rewriting/math.py b/aesara/tensor/rewriting/math.py
index 3c38a6086b..58c3744051 100644
--- a/aesara/tensor/rewriting/math.py
+++ b/aesara/tensor/rewriting/math.py
@@ -82,7 +82,6 @@
     register_uncanonicalize,
     register_useless,
 )
-from aesara.tensor.rewriting.elemwise import FusionOptimizer, fuse_seqopt
 from aesara.tensor.shape import Shape, Shape_i
 from aesara.tensor.subtensor import Subtensor
 from aesara.tensor.type import (
@@ -2843,66 +2842,6 @@ def check_input(inputs):
     return [ret]
 
 
-def local_add_mul_fusion(fgraph, node):
-    """Fuse consecutive add or mul in one such node with more inputs.
-
-    It is better to fuse add/mul that way then in a Composite node as
-    this make the inner graph of the Composite smaller. This allow to
-    put more computation in a Composite before hitting the max
-    recursion limit when pickling Composite.
-
-    """
-    if not isinstance(node.op, Elemwise) or not isinstance(
-        node.op.scalar_op, (aes.Add, aes.Mul)
-    ):
-        return False
-
-    s_op = node.op.scalar_op.__class__
-    new_inp = []
-    fused = False
-    nb_inputs = len(node.inputs)
-    max_inputs = float("inf")
-    if hasattr(node.op, "max_inputs"):
-        max_inputs = node.op.max_inputs(node)
-    for inp in node.inputs:
-        if (
-            inp.owner
-            and isinstance(inp.owner.op, Elemwise)
-            and isinstance(inp.owner.op.scalar_op, s_op)
-            and
-            # Do not duplicate the operation.
-            len(fgraph.clients[inp]) == 1
-            and (nb_inputs + len(inp.owner.inputs) - 1) <= max_inputs
-        ):
-            new_inp.extend(inp.owner.inputs)
-            fused = True
-        else:
-            new_inp.append(inp)
-
-    # We can not compare the number of inputs as Mul and Add could have
-    # 0 or 1 inputs in some corner cases.
-    if fused:
-        output = node.op(*new_inp)
-        copy_stack_trace(node.outputs[0], output)
-
-        # Do the recursion here to help lower the number of
-        # FusionOptimizer iteration.
-        if output.owner:
-            output2 = local_add_mul_fusion(fgraph, output.owner)
-            if output2:
-                return output2
-        return [output]
-
-
-fuse_seqopt.register(
-    "local_add_mul_fusion",
-    FusionOptimizer(local_add_mul_fusion),
-    "fast_run",
-    "fusion",
-    position=0,
-)
-
-
 def _skip_mul_1(r):
     if r.owner and r.owner.op == mul:
         not_is_1 = [i for i in r.owner.inputs if not _is_1(i)]
diff --git a/tests/tensor/rewriting/test_elemwise.py b/tests/tensor/rewriting/test_elemwise.py
index cfb9b6a61d..2bf776d017 100644
--- a/tests/tensor/rewriting/test_elemwise.py
+++ b/tests/tensor/rewriting/test_elemwise.py
@@ -998,23 +998,35 @@ def test_big_fusion(self):
             for node in dlogp.maker.fgraph.toposort()
         )
 
-    def test_add_mul_fusion_inplace(self):
-
-        rewrites = RewriteDatabaseQuery(
-            include=[
-                "local_elemwise_fusion",
-                "composite_elemwise_fusion",
-                "canonicalize",
-                "inplace",
-            ],
-            exclude=["cxx_only", "BlasOpt"],
-        )
-
-        mode = Mode(self.mode.linker, rewrites)
-
+    def test_add_mul_fusion_precedence(self):
+        """Test that additions and multiplications are "fused together" before
+        a `Composite` `Op` is introduced. This fusion is done by canonicalization.
+        """
+        x, y, z = vectors("x", "y", "z")
+        out = log((x + y + z) / (x * y * z))
+        f = aesara.function([x, y, z], out, mode=self.mode)
+        # There should be a single Composite Op
+        nodes = f.maker.fgraph.apply_nodes
+        assert len(nodes) == 1
+        (node,) = nodes
+        assert isinstance(node.op, Elemwise)
+        scalar_op = node.op.scalar_op
+        assert isinstance(scalar_op, Composite)
+        assert [node.op for node in scalar_op.fgraph.toposort()] == [
+            # There should be a single mul
+            aes.mul,
+            # There should be a single add
+            aes.add,
+            aes.true_div,
+            aes.log,
+        ]
+
+    def test_add_mul_fusion_inplace(self):
+        # Note: This has nothing to do with the FusionOptimizer, as the "fusion"
+        # is done during canonicalization.
         x, y, z = dmatrices("xyz")
         out = dot(x, y) + x + y + z
-        f = function([x, y, z], out, mode=mode)
+        f = function([x, y, z], out, mode=self.mode)
         topo = [n for n in f.maker.fgraph.toposort()]
         assert len(topo) == 2
         assert topo[-1].op.inplace_pattern
@@ -1026,7 +1038,9 @@ def test_add_mul_fusion_inplace(self):
 
         # TODO: Do we really need to do this?
         _ = f(
-            np.random.random((5, 5)), np.random.random((5, 5)), np.random.random((5, 5))
+            np.random.random((5, 5)),
+            np.random.random((5, 5)),
+            np.random.random((5, 5)),
         )
 
     @pytest.mark.skipif(not config.cxx, reason="No cxx compiler")
diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py
index a73632db9d..ec371ec92c 100644
--- a/tests/tensor/rewriting/test_math.py
+++ b/tests/tensor/rewriting/test_math.py
@@ -16,7 +16,7 @@
 from aesara.compile.mode import Mode, get_default_mode, get_mode
 from aesara.compile.ops import DeepCopyOp, deep_copy_op
 from aesara.configdefaults import config
-from aesara.graph.basic import Apply, Constant, equal_computations
+from aesara.graph.basic import Apply, equal_computations
 from aesara.graph.fg import FunctionGraph
 from aesara.graph.rewriting.basic import (
     SequentialNodeRewriter,
@@ -46,7 +46,6 @@
     bitwise_or,
     bitwise_xor,
     conj,
-    cos,
     cosh,
     deg2rad,
     dot,
@@ -59,14 +58,10 @@
     ge,
     gt,
     int_div,
-    invert,
-    iround,
     le,
     log,
     log1mexp,
     log1p,
-    log2,
-    log10,
     lt,
 )
 from aesara.tensor.math import max as at_max
@@ -74,11 +69,20 @@
 from aesara.tensor.math import min as at_min
 from aesara.tensor.math import minimum, mul, neg, neq
 from aesara.tensor.math import pow as at_pow
-from aesara.tensor.math import prod, rad2deg, reciprocal
-from aesara.tensor.math import round as at_round
-from aesara.tensor.math import sgn, sigmoid, sin, sinh, softplus, sqr, sqrt, sub
+from aesara.tensor.math import (
+    prod,
+    rad2deg,
+    reciprocal,
+    sgn,
+    sigmoid,
+    sinh,
+    softplus,
+    sqr,
+    sqrt,
+    sub,
+)
 from aesara.tensor.math import sum as at_sum
-from aesara.tensor.math import tan, tanh, true_div, xor
+from aesara.tensor.math import tanh, true_div, xor
 from aesara.tensor.rewriting.elemwise import local_dimshuffle_lift
 from aesara.tensor.rewriting.math import (
     compute_mul,
@@ -102,7 +106,6 @@
     dvector,
     fmatrices,
     fmatrix,
-    fscalar,
     ftensor4,
     fvector,
     imatrices,
@@ -1072,745 +1075,6 @@ def test_cast_in_mul_canonizer():
     f([1], [1])
 
 
-class TestFusion:
-    rewrites = RewriteDatabaseQuery(
-        include=[
-            "local_elemwise_fusion",
-            "composite_elemwise_fusion",
-            "canonicalize",
-            "inplace",
-        ],
-        exclude=["cxx_only", "BlasOpt"],
-    )
-    mode = Mode(get_default_mode().linker, rewrites)
-    _shared = staticmethod(shared)
-    topo_exclude = ()
-
-    def do(self, mode, shared_fn, shp, nb_repeat=1, assert_len_topo=True, slice=None):
-        """
-        param shared_fn: if None, will use function
-        verify that the elemwise
fusion work - Test with and without DimShuffle - """ - # TODO: disable the canonizer? - def my_init(shp, dtype="float64", num=0): - ret = np.zeros(shp, dtype=dtype) + num - return ret - - fw, fx, fy, fz = [ - tensor(dtype="float32", shape=[False] * len(shp), name=n) for n in "wxyz" - ] - dw, dx, dy, dz = [ - tensor(dtype="float64", shape=[False] * len(shp), name=n) for n in "wxyz" - ] - ix, iy, iz = [ - tensor(dtype="int32", shape=[False] * len(shp), name=n) for n in "xyz" - ] - fv = fvector("v") - fs = fscalar("s") - - fwv = my_init(shp, "float32", 1) - fxv = my_init(shp, "float32", 2) - fyv = my_init(shp, "float32", 3) - fzv = my_init(shp, "float32", 4) - fvv = _asarray(np.random.random((shp[0])), dtype="float32") - fsv = np.asarray(np.random.random(), dtype="float32") - dwv = my_init(shp, "float64", 5) - ixv = _asarray(my_init(shp, num=60), dtype="int32") - iyv = _asarray(my_init(shp, num=70), dtype="int32") - izv = _asarray(my_init(shp, num=70), dtype="int32") - fwx = fw + fx - ftanx = tan(fx) - cases = [ - ( - fx + fy + fz, - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv + fyv + fzv, - "float32", - ), # 0 - ( - fx * fy * fz, - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv * fyv * fzv, - "float32", - ), # 1 - ( - fx + fy * fz, - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv + fyv * fzv, - "float32", - ), # 2 - ( - fx * fy + fz, - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv * fyv + fzv, - "float32", - ), # 3 - ( - fw + fx + fy + fz, - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv + fxv + fyv + fzv, - "float32", - ), - ( - (fw + fx) + (fy + fz), - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv + fxv + fyv + fzv, - "float32", - ), # 5 - ( - ((fw + fx) + fy) + fz, - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv + fxv + fyv + fzv, - "float32", - ), - ( - (fw + (fx + fy)) + fz, - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv + fxv + fyv + fzv, - "float32", - ), - ( - (fw + (fx + fy) + fz), - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv + fxv + fyv + fzv, - "float32", - ), - ( - fw + (fx + (fy + fz)), - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv + fxv + fyv + fzv, - "float32", - ), - ( - (fw + fx) + (fy + fz), - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv + fxv + fyv + fzv, - "float32", - ), # 10 - ( - fw * fx * fy * fz, - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv * fxv * fyv * fzv, - "float32", - ), - ( - fw + fx * fy * fz, - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv + fxv * fyv * fzv, - "float32", - ), - ( - fx + fy * fz * fx, - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv + fyv * fzv * fxv, - "float32", - ), - ( - fx * fy + fz + fy, - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv * fyv + fzv + fyv, - "float32", - ), - ( - fx * fy * fz * fw + fx + fy + fz + fw, - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fxv * fyv * fzv * fwv + fxv + fyv + fzv + fwv, - "float32", - ), # 15 - # test with constant - ( - (fw + fx) + (fy + fz) + 2.0, - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv + fxv + fyv + fzv + 2, - "float32", - ), - ( - ((fw + fx) + 2.0 + fy) + fz, - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv + fxv + fyv + fzv + 2, - "float32", - ), - ( - (fw + (fx + 2.0 + fy)) + fz, - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv + fxv + fyv + fzv + 2, - "float32", - ), - ( - (fw + (fx + fy) + 2 + fz), - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv + fxv + fyv + fzv + 2, - "float32", - ), - ( - fw + (fx + (fy + fz) + 2.0), - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv + fxv + fyv + fzv + 2, 
- "float32", - ), # 20 - ( - 2 + (fw + fx) + (fy + fz), - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv + fxv + fyv + fzv + 2, - "float32", - ), - # mix float32 and float64 - ( - 2 + (dw + fx) + (fy + fz), - (dw, fx, fy, fz), - (dwv, fxv, fyv, fzv), - 1, - dwv + fxv + fyv + fzv + 2, - "float64", - ), - ( - 2 + (fw + dw) + (fy + fz), - (fw, dw, fy, fz), - (fwv, dwv, fyv, fzv), - 1, - fwv + dwv + fyv + fzv + 2, - "float64", - ), - ( - 2 + (fw + fx) + (dw + fz), - (fw, fx, dw, fz), - (fwv, fxv, dwv, fzv), - 1, - fwv + fxv + dwv + fzv + 2, - "float64", - ), - ( - 2 + (fw + fx) + (fy + dw), - (fw, fx, fy, dw), - (fwv, fxv, fyv, dwv), - 1, - fwv + fxv + fyv + dwv + 2, - "float64", - ), # 25 - # test when their is other op then elemwise. - ( - (fwx.sum()) + (fwx) + (fy + fz), - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 4, - (fwv + fxv).sum() + fwv + fxv + fyv + fzv, - "float32", - ), - # test other elemwise op - ( - fx + fy + cos(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv + fyv + np.cos(fzv), - "float32", - ), - ( - fx + fy + cosh(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv + fyv + np.cosh(fzv), - "float32", - ), - ( - fx + fy + abs(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv + fyv + np.absolute(fzv), - "float32", - ), - ( - ix + iy + abs(iz), - (ix, iy, iz), - (ixv, iyv, izv), - 1, - ixv + iyv + np.absolute(izv), - "int32", - ), # 30 - ( - fx + fy + log(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv + fyv + np.log(fzv), - "float32", - ), - ( - fx + fy + log2(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv + fyv + np.log2(fzv), - "float32", - ), - ( - fx + fy + log10(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv + fyv + np.log10(fzv), - "float32", - ), - ( - fx + fy**fz, - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv + fyv**fzv, - "float32", - ), # pow - ( - fx + fy + exp(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv + fyv + np.exp(fzv), - "float32", - ), # 35 - ( - fx - fy - fz, - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - fyv - fzv, - "float32", - ), - ( - fx - (fy / fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - (fyv / fzv), - "float32", - ), - ( - fx - true_div(fy, 2), - (fx, fy), - (fxv, fyv), - 1, - fxv - (fyv / 2), - "float32", - ), - ( - fx - true_div(fy, fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - (fyv / fzv), - "float32", - ), - ( - fx - int_div(ix * 100, iy * 1000), - (fx, ix, iy), - (fxv, ixv, iyv), - 1, - fxv - ((ixv * 100) // (iyv * 1000)), - { - "custom": "float64", - "numpy + floatX": config.floatX, - "numpy": "float64", - }, - ), # 40 - (fx - (fy / 2), (fx, fy), (fxv, fyv), 1, fxv - (fyv / 2), "float32"), - ( - fx - (fy % fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - (fyv % fzv), - "float32", - ), - ( - fx - (fy > fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - (fyv > fzv), - "float32", - ), - ( - fx - (fy >= fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - (fyv >= fzv), - "float32", - ), - ( - fx - (fy < fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - (fyv < fzv), - "float32", - ), # 45 - ( - fx - (fy <= fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - (fyv <= fzv), - "float32", - ), - ( - fx - eq(fy, fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - (fyv == fzv), - "float32", - ), - ( - fx - neq(fy, fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - (fyv != fzv), - "float32", - ), - ( - fx - fy + tan(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - fyv + np.tan(fzv), - "float32", - ), - ( - fx - fy + tanh(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - fyv + np.tanh(fzv), - 
"float32", - ), # 50 - ( - fx - fy + sin(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - fyv + np.sin(fzv), - "float32", - ), - ( - fx - fy + sinh(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - fyv + np.sinh(fzv), - "float32", - ), - ( - fx - fy + sqr(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - fyv + (fzv * fzv), - "float32", - ), - ( - fx - fy + sqrt(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - fyv + np.sqrt(fzv), - "float32", - ), - ( - fx - fy + reciprocal(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - fyv + (1 / fzv), - "float32", - ), # 55 - ( - fx - fy + neg(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - fyv + (-fzv), - "float32", - ), - ( - fx - fy + at_round(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - fyv + np.round(fzv), - "float32", - ), - ( - ix - iy + iround(fz), - (ix, iy, fz), - (ixv, iyv, fzv), - 1, - ixv - iyv + np.round(fzv), - "int64", - ), - # Bit op - ( - fx - bitwise_or(iy, iz), - (fx, iy, iz), - (fxv, iyv, izv), - 1, - fxv - (iyv | izv), - { - "custom": "float64", - "numpy + floatX": config.floatX, - "numpy": "float64", - }, - ), - ( - fx - xor(iy, iz), - (fx, iy, iz), - (fxv, iyv, izv), - 1, - fxv - (iyv ^ izv), - { - "custom": "float64", - "numpy + floatX": config.floatX, - "numpy": "float64", - }, - ), # 60 - ( - fx - bitwise_and(iy, iz), - (fx, iy, iz), - (fxv, iyv, izv), - 1, - fxv - (iyv & izv), - { - "custom": "float64", - "numpy + floatX": config.floatX, - "numpy": "float64", - }, - ), - ( - fx - invert(iy), - (fx, iy), - (fxv, iyv), - 1, - fxv - (~iyv), - { - "custom": "float64", - "numpy + floatX": config.floatX, - "numpy": "float64", - }, - ), - ( - fx - at.cast(fy, dtype="float64"), - (fx, fy), - (fxv, fyv), - 1, - fxv - np.asarray(fyv, "float64"), - "float64", - ), - ( - at_pow(fx * fy + fz, fx * fy), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - np.power(fxv * fyv + fzv, fxv * fyv), - "float32", - ), - ( - fv + fy**fz, - (fv, fy, fz), - (fvv, fyv, fzv), - 2, - fvv + fyv**fzv, - "float32", - ), # fused with a dimshuffle #65 - ( - fv - fy + tanh(fz), - (fv, fy, fz), - (fvv, fyv, fzv), - 2, - fvv - fyv + np.tanh(fzv), - "float32", - ), # fused with a dimshuffle - # Cases where the same input is reused many times. - ( - mul(fx, fx, fx, fx), - (fx,), - (fxv,), - 1, - fxv * fxv * fxv * fxv, - "float32", - ), - ( - mul(fx, ftanx, ftanx), - (fx,), - (fxv,), - 1, - fxv * np.tan(fxv) * np.tan(fxv), - "float32", - ), - ( - mul(fx, ftanx, ftanx, fx), - (fx,), - (fxv,), - 1, - fxv * np.tan(fxv) * np.tan(fxv) * fxv, - "float32", - ), - ( - mul(ftanx, ftanx, fx + fy), - (fx, fy), - (fxv, fyv), - 1, - np.tan(fxv) * np.tan(fxv) * (fxv + fyv), - "float32", - ), # 70 - # Cases with different broadcast pattern. 
They should not - # be merged as this would duplicate computation - # The graph should have 2 elemwise and 1 dimshuffle - ( - fx * sin(fs), - (fx, fs), - (fxv, fsv), - 3, - fxv * np.sin(fsv), - "float32", - ), - ] - if slice: - cases = cases[slice] - times = np.zeros(len(cases)) - fail1 = [] - fail2 = [] - fail3 = [] - fail4 = [] - for ( - id, - [g, sym_inputs, val_inputs, nb_elemwise, answer, out_dtype], - ) in enumerate(cases): - if isinstance(out_dtype, dict): - out_dtype = out_dtype[config.cast_policy] - - if shared_fn is None: - f = function(list(sym_inputs), g, mode=mode) - for x in range(nb_repeat): - out = f(*val_inputs) - t1 = time.time() - else: - out = shared_fn(np.zeros(shp, dtype=out_dtype), "out") - assert out.dtype == g.dtype - f = function(sym_inputs, [], updates=[(out, g)], mode=mode) - t0 = time.time() - for x in range(nb_repeat): - f(*val_inputs) - t1 = time.time() - out = out.get_value() - - times[id] = t1 - t0 - atol = 1e-8 - if out_dtype == "float32": - atol = 1e-6 - if not np.allclose(out, answer * nb_repeat, atol=atol): - fail1.append(id) - topo = f.maker.fgraph.toposort() - topo_ = [n for n in topo if not isinstance(n.op, self.topo_exclude)] - if assert_len_topo: - if len(topo_) != nb_elemwise: - fail3.append((id, topo_, nb_elemwise)) - if nb_elemwise == 1: - # if no variable appears multiple times in the - # input of g, - # check that the number of input to the Composite - # Elemwise is ok - if len(set(g.owner.inputs)) == len(g.owner.inputs): - expected_len_sym_inputs = sum( - not isinstance(x, Constant) for x in topo_[0].inputs - ) - assert expected_len_sym_inputs == len(sym_inputs) - - if out_dtype != out.dtype: - fail4.append((id, out_dtype, out.dtype)) - - assert len(fail1 + fail2 + fail3 + fail4) == 0 - - return times - - def test_add_mul_fusion_inplace(self): - - rewrites_query = RewriteDatabaseQuery( - include=[ - "local_elemwise_fusion", - "composite_elemwise_fusion", - "canonicalize", - "inplace", - ], - exclude=["cxx_only", "BlasOpt"], - ) - - mode = Mode(self.mode.linker, rewrites_query) - - x, y, z = dmatrices("xyz") - out = dot(x, y) + x + y + z - f = function([x, y, z], out, mode=mode) - topo = [n for n in f.maker.fgraph.toposort()] - assert len(topo) == 2 - assert topo[-1].op.inplace_pattern - - new_out = f.maker.fgraph.outputs[0] - assert isinstance(new_out.owner.op, Elemwise) - assert isinstance(new_out.owner.op.scalar_op, aes.basic.Add) - assert len(new_out.owner.inputs) == 4 - - # TODO: Do we really need to do this? - _ = f( - np.random.random((5, 5)), np.random.random((5, 5)), np.random.random((5, 5)) - ) - - @utt.assertFailure_fast def test_log1p(): m = config.mode
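Note on the new test: the behavior asserted by `test_add_mul_fusion_precedence` can be reproduced outside the test suite with a small script like the sketch below. This is not part of the patch; it assumes an Aesara build with the default rewrites enabled, and it only prints the rewritten graph rather than asserting an exact node ordering.

    import aesara
    import aesara.tensor as at

    x, y, z = at.vectors("x", "y", "z")
    out = at.log((x + y + z) / (x * y * z))

    # With the default rewrites, canonicalization first collapses the nested
    # binary `add`/`mul` nodes into a single n-ary `add` and a single n-ary
    # `mul`; the elemwise fusion rewrites then wrap the whole expression in
    # one `Composite` `Elemwise`, which is why the separate
    # `local_add_mul_fusion` step is no longer needed.
    f = aesara.function([x, y, z], out)

    # Inspect the compiled graph; a single Composite node should appear.
    aesara.dprint(f)

The toposort-based assertions in the new test pin down the same structure explicitly: one `mul`, one `add`, then `true_div` and `log` inside the `Composite`.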