From 6916e84ef70f6f8bb49423d6274edaf954d3c890 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Tue, 30 Aug 2022 14:29:03 -0500 Subject: [PATCH 01/11] Use helper function instead of a mixin in tests.graph.test_basic --- tests/graph/test_basic.py | 51 +++++++++++++++++++++------------------ 1 file changed, 27 insertions(+), 24 deletions(-) diff --git a/tests/graph/test_basic.py b/tests/graph/test_basic.py index 91085168d1..cdd362b00b 100644 --- a/tests/graph/test_basic.py +++ b/tests/graph/test_basic.py @@ -87,56 +87,59 @@ def perform(self, *args, **kwargs): MyOp = MyOp() -class X: - def leaf_formatter(self, leaf): - return str(leaf.type) - - def node_formatter(self, node, argstrings): - return f"{node.op}({', '.join(argstrings)})" - - def str(self, inputs, outputs): - return as_string( - inputs, - outputs, - leaf_formatter=self.leaf_formatter, - node_formatter=self.node_formatter, - ) +def leaf_formatter(leaf): + return str(leaf.type) + + +def node_formatter(node, argstrings): + return f"{node.op}({', '.join(argstrings)})" -class TestStr(X): +def format_graph(inputs, outputs): + return as_string( + inputs, + outputs, + leaf_formatter=leaf_formatter, + node_formatter=node_formatter, + ) + + +class TestStr: def test_as_string(self): r1, r2 = MyVariable(1), MyVariable(2) node = MyOp.make_node(r1, r2) - s = self.str([r1, r2], node.outputs) + s = format_graph([r1, r2], node.outputs) assert s == ["MyOp(R1, R2)"] def test_as_string_deep(self): r1, r2, r5 = MyVariable(1), MyVariable(2), MyVariable(5) node = MyOp.make_node(r1, r2) node2 = MyOp.make_node(node.outputs[0], r5) - s = self.str([r1, r2, r5], node2.outputs) + s = format_graph([r1, r2, r5], node2.outputs) assert s == ["MyOp(MyOp(R1, R2), R5)"] def test_multiple_references(self): r1, r2, r5 = MyVariable(1), MyVariable(2), MyVariable(5) node = MyOp.make_node(r1, r2) node2 = MyOp.make_node(node.outputs[0], node.outputs[0]) - assert self.str([r1, r2, r5], node2.outputs) == ["MyOp(*1 -> MyOp(R1, R2), *1)"] + assert format_graph([r1, r2, r5], node2.outputs) == [ + "MyOp(*1 -> MyOp(R1, R2), *1)" + ] def test_cutoff(self): r1, r2 = MyVariable(1), MyVariable(2) node = MyOp.make_node(r1, r2) node2 = MyOp.make_node(node.outputs[0], node.outputs[0]) - assert self.str(node.outputs, node2.outputs) == ["MyOp(R3, R3)"] - assert self.str(node2.inputs, node2.outputs) == ["MyOp(R3, R3)"] + assert format_graph(node.outputs, node2.outputs) == ["MyOp(R3, R3)"] + assert format_graph(node2.inputs, node2.outputs) == ["MyOp(R3, R3)"] -class TestClone(X): +class TestClone: def test_accurate(self): r1, r2 = MyVariable(1), MyVariable(2) node = MyOp.make_node(r1, r2) _, new = clone([r1, r2], node.outputs, False) - assert self.str([r1, r2], new) == ["MyOp(R1, R2)"] + assert format_graph([r1, r2], new) == ["MyOp(R1, R2)"] def test_copy(self): r1, r2, r5 = MyVariable(1), MyVariable(2), MyVariable(5) @@ -160,10 +163,10 @@ def test_not_destructive(self): _, new = clone([r1, r2, r5], node.outputs, False) new_node = new[0].owner new_node.inputs = [MyVariable(7), MyVariable(8)] - assert self.str(graph_inputs(new_node.outputs), new_node.outputs) == [ + assert format_graph(graph_inputs(new_node.outputs), new_node.outputs) == [ "MyOp(R7, R8)" ] - assert self.str(graph_inputs(node.outputs), node.outputs) == [ + assert format_graph(graph_inputs(node.outputs), node.outputs) == [ "MyOp(MyOp(R1, R2), R5)" ] From cd98f20f47d1e00efb1e003557db61c47f1574c4 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Mon, 29 Aug 2022 14:56:06 -0500 Subject: [PATCH 02/11] Use WeakValueDictionary for NominalVariable instance caching --- aesara/graph/basic.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/aesara/graph/basic.py b/aesara/graph/basic.py index e8d3392c4d..9f8d0408a3 100644 --- a/aesara/graph/basic.py +++ b/aesara/graph/basic.py @@ -26,6 +26,7 @@ Union, cast, ) +from weakref import WeakValueDictionary import numpy as np @@ -44,7 +45,7 @@ if TYPE_CHECKING: from aesara.graph.op import Op - from aesara.graph.type import Type + from aesara.graph.type import Type # noqa: F401 OpType = TypeVar("OpType", bound="Op") @@ -672,7 +673,8 @@ def clone(self, **kwargs): class NominalVariable(AtomicVariable[_TypeType]): """A variable that enables alpha-equivalent comparisons.""" - __instances__: Dict[Tuple["Type", Hashable], "NominalVariable"] = {} + # WeakValueDictionary[Tuple["Type", Hashable], "NominalVariable"] + __instances__: WeakValueDictionary = WeakValueDictionary() def __new__(cls, id: _IdType, typ: _TypeType, **kwargs): if (typ, id) not in cls.__instances__: From 14830f82b491a00db5886cd592f8ec0ebe5f471f Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Mon, 29 Aug 2022 15:00:54 -0500 Subject: [PATCH 03/11] Add HashableNDArray type --- aesara/utils.py | 82 ++++++++++++++++++++++++++++++++++++++++++++- tests/test_utils.py | 33 ++++++++++++++++++ 2 files changed, 114 insertions(+), 1 deletion(-) create mode 100644 tests/test_utils.py diff --git a/aesara/utils.py b/aesara/utils.py index 613d39ad16..03a539ccc0 100644 --- a/aesara/utils.py +++ b/aesara/utils.py @@ -4,6 +4,7 @@ import inspect import logging import os +import pickle import struct import subprocess import sys @@ -12,7 +13,9 @@ from collections import OrderedDict from collections.abc import Callable from functools import partial, wraps -from typing import List, Set +from typing import Hashable, List, Set + +import numpy as np __all__ = [ @@ -458,3 +461,80 @@ def copy(self): def __copy__(self): return type(self)(self.default_factory, self) + + +class HashableNDArray(np.ndarray, Hashable): + """A subclass of Numpy's ndarray that uses `tostring` hashing and `array_equal` equality testing. + + Usage + ----- + >>> import numpy as np + >>> from symbolic_pymc.utils import HashableNDArray + >>> x = np.r_[1, 2, 3] + >>> x_new = x.view(HashableNDArray) + >>> assert hash(x_new) == hash(x.tostring()) + >>> assert x_new == np.r_[1, 2, 3] + """ + + use_ndarray = False + _hash_val = None + + def __hash__(self): + """ + NDArray hashing based on `joblib`: + https://github.com/joblib/joblib/blob/1fdf3086674d7b1be27688e8c7aebd3159d89997/joblib/hashing.py#L178 + """ + if self._hash_val is not None: + return self._hash_val + + if self.shape == (): + self_c = self.flatten() + elif self.flags.c_contiguous: + self_c = self + elif self.flags.f_contiguous: + self_c = self.T + else: + self_c = self.flatten() + + self_c_bytes = self_c.view(np.uint8) + + if hasattr(np, "getbuffer"): + self_buffer = np.getbuffer(self_c_bytes) + else: + self_buffer = memoryview(self_c_bytes) + + h = hashlib.sha256() + h.update(self_buffer) + + if self.use_ndarray and isinstance(self, np.memmap): + sig = (np.ndarray,) + else: + sig = (self.__class__,) + + sig += ((self.dtype, self.shape, self.strides),) + + # XXX: Weak references have a hard time with this, and will raise + # import errors during pickling when the Python session ends (i.e. when + # "destructors" are called). It could be that packages are being + # removed/unloaded (or at least their symbols within the pickler's + # dispatch tables) _before_ this method is called. + # The hash memoizing below avoids this problem. + h.update(pickle.dumps(sig)) + + _hash_val = int(h.hexdigest(), 16) + + # If we can assume that the underlying array won't change, + # we shouldn't need to recompute this hash every time. + if not self.flags["WRITEABLE"]: + self._hash_val = _hash_val + + return _hash_val + + def __eq__(self, other): + return np.array_equal(self, other) + + def __ne__(self, other): + if self.__eq__(other): + return False + + return NotImplemented diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000000..e2f739be57 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,33 @@ +import pickle + +import numpy as np + +from aesara.utils import HashableNDArray + + +def test_HashableNDArray(): + rng = np.random.default_rng(2392) + + x = rng.random(size=(3, 2)) + + x_hnda_1 = x.view(HashableNDArray) + x_hnda_2 = x.view(HashableNDArray) + + assert x_hnda_1 is not x_hnda_2 + assert x_hnda_1 == x_hnda_2 + assert hash(x_hnda_1) == hash(x_hnda_2) + + x_pkl = pickle.dumps(x_hnda_1) + x_hnda_3 = pickle.loads(x_pkl) + + assert x_hnda_3 == x_hnda_1 + + import weakref + + wd = weakref.WeakValueDictionary() + wd[(1, 2)] = x_hnda_1 + wd[(2, 3)] = x_hnda_2 + assert wd[(1, 2)] is x_hnda_1 + del x_hnda_1 + + assert (1, 2) not in wd From 3c46fc49e9b3ce2b6bd17e0a843324303376bf2e Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Sun, 11 Sep 2022 16:32:47 -0500 Subject: [PATCH 04/11] Replace string comparisons with equal_computations in rewrite tests --- tests/graph/rewriting/test_basic.py | 124 ++++++++++++---------------- 1 file changed, 55 insertions(+), 69 deletions(-) diff --git a/tests/graph/rewriting/test_basic.py b/tests/graph/rewriting/test_basic.py index e68d42b9cb..b347ce3990 100644 --- a/tests/graph/rewriting/test_basic.py +++ b/tests/graph/rewriting/test_basic.py @@ -17,7 +17,6 @@ SubstitutionNodeRewriter, WalkingGraphRewriter, in2out, - logging, node_rewriter, pre_constant_merge, pre_greedy_node_rewriter, @@ -65,50 +64,50 @@ def test_replace_output(self): # replacing the whole graph x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") e = op1(op2(x, y), z) - g = FunctionGraph([x, y, z], [e]) + g = FunctionGraph([x, y, z], [e], clone=False) OpKeyPatternNodeRewriter((op1, (op2, "1", "2"), "3"), (op4, "3", "2")).rewrite( g ) - assert str(g) == "FunctionGraph(Op4(z, y))" + assert equal_computations(g.outputs, [op4(z, y)]) def test_nested_out_pattern(self): x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") e = op1(x, y) - g = FunctionGraph([x, y, z], [e]) + g = FunctionGraph([x, y, z], [e], clone=False) OpKeyPatternNodeRewriter( (op1, "1", "2"), (op4, (op1, "1"), (op2, "2"), (op3, "1", "2")) ).rewrite(g) - assert str(g) == "FunctionGraph(Op4(Op1(x), Op2(y), Op3(x, y)))" + assert equal_computations(g.outputs, [op4(op1(x), op2(y), op3(x, y))]) def test_unification_1(self): x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") e = op1(op2(x, x), z) # the arguments to op2 are the same - g = FunctionGraph([x, y, z], [e]) + g = FunctionGraph([x, y, z], [e], clone=False) OpKeyPatternNodeRewriter( (op1, (op2, "1", "1"), "2"), # they are the same in the pattern (op4, "2", "1"), ).rewrite(g) # So the replacement should occur - assert str(g) == "FunctionGraph(Op4(z, x))" + assert equal_computations(g.outputs, [op4(z, x)]) def test_unification_2(self): x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") e = op1(op2(x, y), z) # the arguments to op2 are different - g = FunctionGraph([x, y, z], [e]) + g = FunctionGraph([x, y, z], [e], clone=False) OpKeyPatternNodeRewriter( (op1, (op2, "1", "1"), "2"), # they are the same in the pattern (op4, "2", "1"), ).rewrite(g) # The replacement should NOT occur - assert str(g) == "FunctionGraph(Op1(Op2(x, y), z))" + assert equal_computations(g.outputs, [op1(op2(x, y), z)]) def test_replace_subgraph(self): # replacing inside the graph x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") e = op1(op2(x, y), z) - g = FunctionGraph([x, y, z], [e]) + g = FunctionGraph([x, y, z], [e], clone=False) OpKeyPatternNodeRewriter((op2, "1", "2"), (op1, "2", "1")).rewrite(g) - assert str(g) == "FunctionGraph(Op1(Op1(y, x), z))" + assert equal_computations(g.outputs, [op1(op1(y, x), z)]) def test_no_recurse(self): # if the out pattern is an acceptable in pattern @@ -116,40 +115,40 @@ def test_no_recurse(self): # it should do the replacement and stop x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") e = op1(op2(x, y), z) - g = FunctionGraph([x, y, z], [e]) + g = FunctionGraph([x, y, z], [e], clone=False) OpKeyPatternNodeRewriter((op2, "1", "2"), (op2, "2", "1"), ign=True).rewrite(g) - assert str(g) == "FunctionGraph(Op1(Op2(y, x), z))" + assert equal_computations(g.outputs, [op1(op2(y, x), z)]) def test_multiple(self): # it should replace all occurrences of the pattern x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") - e = op1(op2(x, y), op2(x, y), op2(y, z)) - g = FunctionGraph([x, y, z], [e]) + e = op1(op2(x, y), op2(y, x), op2(y, z)) + g = FunctionGraph([x, y, z], [e], clone=False) OpKeyPatternNodeRewriter((op2, "1", "2"), (op4, "1")).rewrite(g) - assert str(g) == "FunctionGraph(Op1(Op4(x), Op4(x), Op4(y)))" + assert equal_computations(g.outputs, [op1(op4(x), op4(y), op4(y))]) def test_nested_even(self): # regardless of the order in which we rewrite, this # should work x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") e = op1(op1(op1(op1(x)))) - g = FunctionGraph([x, y, z], [e]) + g = FunctionGraph([x, y, z], [e], clone=False) OpKeyPatternNodeRewriter((op1, (op1, "1")), "1").rewrite(g) - assert str(g) == "FunctionGraph(x)" + assert equal_computations(g.outputs, [x]) def test_nested_odd(self): x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") e = op1(op1(op1(op1(op1(x))))) - g = FunctionGraph([x, y, z], [e]) + g = FunctionGraph([x, y, z], [e], clone=False) OpKeyPatternNodeRewriter((op1, (op1, "1")), "1").rewrite(g) - assert str(g) == "FunctionGraph(Op1(x))" + assert equal_computations(g.outputs, [op1(x)]) def test_expand(self): x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") e = op1(op1(op1(x))) - g = FunctionGraph([x, y, z], [e]) + g = FunctionGraph([x, y, z], [e], clone=False) OpKeyPatternNodeRewriter((op1, "1"), (op2, (op1, "1")), ign=True).rewrite(g) - assert str(g) == "FunctionGraph(Op2(Op1(Op2(Op1(Op2(Op1(x)))))))" + assert equal_computations(g.outputs, [op2(op1(op2(op1(op2(op1(x))))))]) def test_ambiguous(self): # this test should always work with WalkingGraphRewriter and the @@ -157,23 +156,23 @@ def test_ambiguous(self): # = True or with other NodeProcessingGraphRewriters may differ. x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") e = op1(op1(op1(op1(op1(x))))) - g = FunctionGraph([x, y, z], [e]) + g = FunctionGraph([x, y, z], [e], clone=False) WalkingPatternNodeRewriter((op1, (op1, "1")), (op1, "1"), ign=False).rewrite(g) - assert str(g) == "FunctionGraph(Op1(x))" + assert equal_computations(g.outputs, [op1(x)]) def test_constant(self): x = Constant(MyType(), 2, name="x") y = MyVariable("y") z = Constant(MyType(), 2, name="z") e = op1(op1(x, y), y) - g = FunctionGraph([y], [e]) + g = FunctionGraph([y], [e], clone=False) OpKeyPatternNodeRewriter((op1, z, "1"), (op2, "1", z)).rewrite(g) - assert str(g) == "FunctionGraph(Op1(Op2(y, z), y))" + assert equal_computations(g.outputs, [op1(op2(y, z), y)]) def test_constraints(self): x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") e = op4(op1(op2(x, y)), op1(op1(x, y))) - g = FunctionGraph([x, y, z], [e]) + g = FunctionGraph([x, y, z], [e], clone=False) def constraint(r): # Only replacing if the input is an instance of Op2 @@ -182,14 +181,14 @@ def constraint(r): OpKeyPatternNodeRewriter( (op1, {"pattern": "1", "constraint": constraint}), (op3, "1") ).rewrite(g) - assert str(g) == "FunctionGraph(Op4(Op3(Op2(x, y)), Op1(Op1(x, y))))" + assert equal_computations(g.outputs, [op4(op3(op2(x, y)), op1(op1(x, y)))]) def test_match_same(self): x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") e = op1(x, x) - g = FunctionGraph([x, y, z], [e]) + g = FunctionGraph([x, y, z], [e], clone=False) OpKeyPatternNodeRewriter((op1, "x", "y"), (op3, "x", "y")).rewrite(g) - assert str(g) == "FunctionGraph(Op3(x, x))" + assert equal_computations(g.outputs, [op3(x, x)]) @pytest.mark.xfail( reason="This pattern & constraint case isn't used and doesn't make much sense." @@ -197,7 +196,7 @@ def test_match_same(self): def test_match_same_illegal(self): x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") e = op2(op1(x, x), op1(x, y)) - g = FunctionGraph([x, y, z], [e]) + g = FunctionGraph([x, y, z], [e], clone=False) def constraint(r): # Only replacing if the input is an instance of Op2 @@ -206,27 +205,26 @@ def constraint(r): OpKeyPatternNodeRewriter( {"pattern": (op1, "x", "y"), "constraint": constraint}, (op3, "x", "y") ).rewrite(g) - assert str(g) == "FunctionGraph(Op2(Op1(x, x), Op3(x, y)))" + assert equal_computations(g.outputs, [op2(op1(x, x), op3(x, y))]) def test_allow_multiple_clients(self): x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") e0 = op1(x, y) # `e0` has multiple clients (i.e. the `op4` and `op3` nodes) e = op3(op4(e0), e0) - g = FunctionGraph([x, y, z], [e]) + g = FunctionGraph([x, y, z], [e], clone=False) OpKeyPatternNodeRewriter((op4, (op1, "x", "y")), (op3, "x", "y")).rewrite(g) - assert str(g) == "FunctionGraph(Op3(Op4(*1 -> Op1(x, y)), *1))" + assert equal_computations(g.outputs, [op3(op4(op1(x, y)), op1(x, y))]) def test_eq(self): # replacing the whole graph x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") e = op1(op_y(x, y), z) - g = FunctionGraph([x, y, z], [e]) + g = FunctionGraph([x, y, z], [e], clone=False) OpKeyPatternNodeRewriter((op1, (op_z, "1", "2"), "3"), (op4, "3", "2")).rewrite( g ) - str_g = str(g) - assert str_g == "FunctionGraph(Op4(z, y))" + assert equal_computations(g.outputs, [op4(z, y)]) def KeyedSubstitutionNodeRewriter(op1, op2): @@ -237,16 +235,16 @@ class TestSubstitutionNodeRewriter: def test_straightforward(self): x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") e = op1(op1(op1(op1(op1(x))))) - g = FunctionGraph([x, y, z], [e]) + g = FunctionGraph([x, y, z], [e], clone=False) KeyedSubstitutionNodeRewriter(op1, op2).rewrite(g) - assert str(g) == "FunctionGraph(Op2(Op2(Op2(Op2(Op2(x))))))" + assert equal_computations(g.outputs, [op2(op2(op2(op2(op2(x)))))]) def test_straightforward_2(self): x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") e = op1(op2(x), op3(y), op4(z)) - g = FunctionGraph([x, y, z], [e]) + g = FunctionGraph([x, y, z], [e], clone=False) KeyedSubstitutionNodeRewriter(op3, op4).rewrite(g) - assert str(g) == "FunctionGraph(Op1(Op2(x), Op4(y), Op4(z)))" + assert equal_computations(g.outputs, [op1(op2(x), op4(y), op4(z))]) class NoInputOp(Op): @@ -450,8 +448,7 @@ def test_1(self): # TODO FIXME: These `Op`s don't have matching/consistent `__prop__`s # and `__init__`s, so they can't be `etuplized` correctly e = op3(op4(x, y)) - g = FunctionGraph([x, y, z], [e]) - # print g + g = FunctionGraph([x, y, z], [e], clone=False) rewriter = EquilibriumGraphRewriter( [ PatternNodeRewriter((op1, "x", "y"), (op2, "x", "y")), @@ -461,14 +458,12 @@ def test_1(self): max_use_ratio=10, ) rewriter.rewrite(g) - # print g - assert str(g) == "FunctionGraph(Op2(x, y))" + assert equal_computations(g.outputs, [op2(x, y)]) def test_2(self): x, y, z = map(MyVariable, "xyz") e = op1(op1(op3(x, y))) - g = FunctionGraph([x, y, z], [e]) - # print g + g = FunctionGraph([x, y, z], [e], clone=False) rewriter = EquilibriumGraphRewriter( [ PatternNodeRewriter((op1, (op2, "x", "y")), (op4, "x", "y")), @@ -480,33 +475,24 @@ def test_2(self): max_use_ratio=10, ) rewriter.rewrite(g) - assert str(g) == "FunctionGraph(Op2(x, y))" + assert equal_computations(g.outputs, [op2(x, y)]) - @config.change_flags(on_opt_error="ignore") + @config.change_flags(on_opt_error="raise") def test_low_use_ratio(self): x, y, z = map(MyVariable, "xyz") e = op3(op4(x, y)) - g = FunctionGraph([x, y, z], [e]) - # print 'before', g - # display pesky warnings along with stdout - # also silence logger for 'aesara.graph.rewriting.basic' - _logger = logging.getLogger("aesara.graph.rewriting.basic") - oldlevel = _logger.level - _logger.setLevel(logging.CRITICAL) - try: - rewriter = EquilibriumGraphRewriter( - [ - PatternNodeRewriter((op1, "x", "y"), (op2, "x", "y")), - PatternNodeRewriter((op4, "x", "y"), (op1, "x", "y")), - PatternNodeRewriter((op3, (op2, "x", "y")), (op4, "x", "y")), - ], - max_use_ratio=1.0 / len(g.apply_nodes), - ) + g = FunctionGraph([x, y, z], [e], clone=False) + rewriter = EquilibriumGraphRewriter( + [ + PatternNodeRewriter((op1, "x", "y"), (op2, "x", "y")), + PatternNodeRewriter((op4, "x", "y"), (op1, "x", "y")), + PatternNodeRewriter((op3, (op2, "x", "y")), (op4, "x", "y")), + ], + max_use_ratio=1.0 / len(g.apply_nodes), + ) + with pytest.raises(AssertionError): rewriter.rewrite(g) - finally: - _logger.setLevel(oldlevel) - # print 'after', g - assert str(g) == "FunctionGraph(Op1(x, y))" + assert equal_computations(g.outputs, [op1(x, y)]) def test_pre_constant_merge(): From a0fe9a87c7f0cfb824017148f7d7288824d4aa54 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Sat, 17 Sep 2022 13:09:41 -0500 Subject: [PATCH 05/11] Remove pre_constant_merge --- aesara/graph/rewriting/basic.py | 68 ----------------------------- tests/graph/rewriting/test_basic.py | 48 +------------------- 2 files changed, 1 insertion(+), 115 deletions(-) diff --git a/aesara/graph/rewriting/basic.py b/aesara/graph/rewriting/basic.py index a6bd80f9b2..4df9838b95 100644 --- a/aesara/graph/rewriting/basic.py +++ b/aesara/graph/rewriting/basic.py @@ -25,7 +25,6 @@ from aesara.graph.basic import ( Apply, AtomicVariable, - Constant, Variable, applys_between, io_toposort, @@ -881,73 +880,6 @@ def merge_none_number(v1, v2): ) -def pre_constant_merge(fgraph, variables): - """Merge constants in the graphs given by `variables`. - - .. warning:: - - This changes the nodes in a graph in-place! - - Parameters - ---------- - fgraph - A `FunctionGraph` instance in which some of these `variables` may - reside. - - We want to avoid terms in `variables` that are contained in `fgraph`. - The reason for that: it will break consistency of `fgraph` and its - features (e.g. `ShapeFeature`). - - variables - A list of nodes for which we want to merge constant inputs. - - Notes - ----- - It is used to pre-merge nodes generated inside an rewrite. It is - useful if there are many such replacements to make, so that `DebugMode` - will not check each of them. - - """ - seen_var = set() - # signature -> variable (for constants) - const_sig_inv = {} - if isinstance(variables, Variable): - variables = [variables] - - def recursive_merge(var): - - if var in seen_var: - return var - - if not hasattr(var, "owner"): - return var - - # We don't want to merge constants that are *within* the - # `FunctionGraph` - if var.owner in fgraph.apply_nodes: - return var - - seen_var.add(var) - - if isinstance(var, Constant): - sig = var.signature() - - if sig in const_sig_inv: - return const_sig_inv[sig] - - const_sig_inv[sig] = var - - return var - - if var.owner: - for idx, inp in enumerate(var.owner.inputs): - # XXX: This is changing the graph in place! - var.owner.inputs[idx] = recursive_merge(inp) - return var - - return [recursive_merge(v) for v in variables] - - class MetaNodeRewriter(NodeRewriter): r""" Base class for meta-rewriters that try a set of `NodeRewriter`\s diff --git a/tests/graph/rewriting/test_basic.py b/tests/graph/rewriting/test_basic.py index b347ce3990..6bdd6bd8c5 100644 --- a/tests/graph/rewriting/test_basic.py +++ b/tests/graph/rewriting/test_basic.py @@ -18,15 +18,13 @@ WalkingGraphRewriter, in2out, node_rewriter, - pre_constant_merge, pre_greedy_node_rewriter, ) from aesara.raise_op import assert_op from aesara.tensor.math import Dot, add, dot from aesara.tensor.rewriting.basic import constant_folding -from aesara.tensor.subtensor import AdvancedSubtensor from aesara.tensor.type import matrix, values_eq_approx_always_true -from aesara.tensor.type_other import MakeSlice, SliceConstant, slicetype +from aesara.tensor.type_other import MakeSlice, SliceConstant from tests.graph.utils import ( MyOp, MyType, @@ -495,50 +493,6 @@ def test_low_use_ratio(self): assert equal_computations(g.outputs, [op1(x, y)]) -def test_pre_constant_merge(): - - empty_fgraph = FunctionGraph([], []) - - x = MyVariable("x") - y = MyVariable("y") - c1 = Constant(MyType(), 1, "c1") - c2 = Constant(MyType(), 1, "c1") - o1 = op2(c1, x) - o2 = op1(o1, y, c2) - - assert c1 is not c2 - - res = pre_constant_merge(empty_fgraph, [o2]) - - assert [o2] == res - assert o2.owner.inputs[2] is c1 - - o2 = op1(o1, y, c2) - fg = FunctionGraph([x, y], [o2], clone=False) - - assert o2.owner in fg.apply_nodes - - res = pre_constant_merge(fg, [o2]) - - assert res == [o2] - assert o2.owner.inputs[2] is c2 - - # What is this supposed to test? - ms = MakeSlice()(1) - res = pre_constant_merge(empty_fgraph, [ms]) - - assert res == [ms] - - const_slice = SliceConstant(type=slicetype, data=slice(1, None, 2)) - - assert isinstance(const_slice, Constant) - - adv = AdvancedSubtensor()(matrix(), [2, 3], const_slice) - - res = pre_constant_merge(empty_fgraph, adv) - assert res == [adv] - - def test_pre_greedy_node_rewriter(): empty_fgraph = FunctionGraph([], []) From 2625228ccd622ffbef0d9b601cdf664826d455a0 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Sat, 17 Sep 2022 13:24:42 -0500 Subject: [PATCH 06/11] Remove pre_greedy_node_rewriter --- aesara/graph/rewriting/basic.py | 101 ---------------------------- tests/graph/rewriting/test_basic.py | 48 ------------- 2 files changed, 149 deletions(-) diff --git a/aesara/graph/rewriting/basic.py b/aesara/graph/rewriting/basic.py index 4df9838b95..135581820b 100644 --- a/aesara/graph/rewriting/basic.py +++ b/aesara/graph/rewriting/basic.py @@ -2771,102 +2771,6 @@ def check_chain(r, *chain): return _check_chain(r, reduce(list.__iadd__, ([x, 0] for x in chain))) -def pre_greedy_node_rewriter( - fgraph: FunctionGraph, rewrites: Sequence[NodeRewriter], out: Variable -) -> Variable: - """Apply node rewriters throughout a graph in a greedy, pre-traversal way. - - This function traverses the computation graph in the graph before the - variable `out` but that are not in the `fgraph`. It applies - `rewrites` to each variable on the traversed graph. - - .. warning:: - - This changes the nodes in a graph in-place. - - Its main use is to apply locally constant folding when generating - the graph of the indices of a `Subtensor`. - - Changes should not be applied to nodes that are in an `fgraph`, - so we use `fgraph` to prevent that. - - Notes - ----- - This doesn't do an equilibrium rewrite, so, if there is a rewrite--like - `local_upcast_elemwise_constant_inputs`--in the list that adds additional - nodes to the inputs of the node, it might be necessary to call this - function multiple times. - - Parameters - ---------- - fgraph - The graph used to avoid/filter nodes. - rewrites - A sequence of rewrites to apply. - out - The graph to rewrite. - - """ - - def local_recursive_function( - rewrite_list: Sequence[NodeRewriter], - out: Variable, - rewritten_vars: Dict[Variable, Variable], - depth: int, - ) -> Tuple[List[Variable], Dict[Variable, Variable]]: - if not getattr(out, "owner", None): - return [out], rewritten_vars - node = out.owner - - if node in fgraph.apply_nodes: - return node.outputs, rewritten_vars - - # Walk up the graph via the node's inputs - for idx, inp in enumerate(node.inputs): - if inp in rewritten_vars: - nw_in = rewritten_vars[inp] - else: - if inp.owner: - outs, rewritten_vars = local_recursive_function( - rewrite_list, inp, rewritten_vars, depth + 1 - ) - for k, v in zip(inp.owner.outputs, outs): - rewritten_vars[k] = v - nw_in = outs[inp.owner.outputs.index(inp)] - - else: - nw_in = inp - rewritten_vars[inp] = inp - - # XXX: An in-place change - node.inputs[idx] = nw_in - - # Apply the rewrites - results = node.outputs - for rewrite in rewrite_list: - ret = rewrite.transform(fgraph, node) - if ret is not False and ret is not None: - assert isinstance(ret, Sequence) - assert len(ret) == len(node.outputs), rewrite - for k, v in zip(node.outputs, ret): - rewritten_vars[k] = v - results = ret - if ret[0].owner: - node = out.owner - else: - break - - return results, rewritten_vars - - if out.owner: - out_index: int = out.owner.outputs.index(out) - else: - out_index = 0 - - final_outs, rewritten_nodes = local_recursive_function(rewrites, out, {}, 0) - return final_outs[out_index] - - def copy_stack_trace(from_var, to_var): r"""Copy the stack traces from `from_var` to `to_var`. @@ -3092,11 +2996,6 @@ def apply(self, fgraph): "`local_optimizer` is deprecated: use `node_rewriter` instead.", node_rewriter, ), - ( - "pre_greedy_local_optimizer", - "`pre_greedy_local_optimizer` is deprecated: use `pre_greedy_node_rewriter` instead.", - pre_greedy_node_rewriter, - ), ( "FromFunctionOptimizer", "`FromFunctionOptimizer` is deprecated: use `FromFunctionGraphRewriter` instead.", diff --git a/tests/graph/rewriting/test_basic.py b/tests/graph/rewriting/test_basic.py index 6bdd6bd8c5..d0d85030a6 100644 --- a/tests/graph/rewriting/test_basic.py +++ b/tests/graph/rewriting/test_basic.py @@ -18,13 +18,10 @@ WalkingGraphRewriter, in2out, node_rewriter, - pre_greedy_node_rewriter, ) from aesara.raise_op import assert_op from aesara.tensor.math import Dot, add, dot -from aesara.tensor.rewriting.basic import constant_folding from aesara.tensor.type import matrix, values_eq_approx_always_true -from aesara.tensor.type_other import MakeSlice, SliceConstant from tests.graph.utils import ( MyOp, MyType, @@ -493,51 +490,6 @@ def test_low_use_ratio(self): assert equal_computations(g.outputs, [op1(x, y)]) -def test_pre_greedy_node_rewriter(): - - empty_fgraph = FunctionGraph([], []) - - x = MyVariable("x") - y = MyVariable("y") - c1 = Constant(MyType(), 1, "c1") - c2 = Constant(MyType(), 2, "c2") - o1 = op2(c1, c2) - o3 = op1(c1, y) - o2 = op1(o1, c2, x, o3, o1) - - assert o2.owner.inputs[0].owner is not None - assert o2.owner.inputs[4].owner is not None - - # This should fold `o1`, because it has only `Constant` arguments, and - # replace it with the `Constant` result - cst = pre_greedy_node_rewriter(empty_fgraph, [constant_folding], o2) - - assert cst.owner.inputs[0].owner is None - assert cst.owner.inputs[1] is c2 - assert cst.owner.inputs[2] is x - assert cst.owner.inputs[3] is o3 - assert cst.owner.inputs[4] is cst.owner.inputs[0] - - # We're going to do it again, except this time `o1` is - # in the `fgraph`, so it shouldn't be folded - fg = FunctionGraph([], [o1], clone=False) - o2 = op1(o1, c2, x, o3, o1) - - cst = pre_greedy_node_rewriter(fg, [constant_folding], o2) - - assert cst.owner.inputs[0] is o1 - assert cst.owner.inputs[4] is cst.owner.inputs[0] - - # What exactly is this supposed to test? - ms = MakeSlice()(1) - cst = pre_greedy_node_rewriter(empty_fgraph, [constant_folding], ms) - - assert isinstance(cst, SliceConstant) - - # Make sure constant of slice signature is hashable. - assert isinstance(hash(cst.signature()), int) - - @pytest.mark.parametrize("tracks", [True, False]) @pytest.mark.parametrize("out_pattern", [(op2, "x"), "x", 1.0]) def test_patternsub_values_eq_approx(out_pattern, tracks): From 50a57bfe4c40b339332e614e7efb2fd4003fd1cc Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Sat, 17 Sep 2022 13:33:10 -0500 Subject: [PATCH 07/11] Remove is_same_graph --- aesara/compile/builders.py | 4 +- aesara/graph/rewriting/utils.py | 161 +--------------------------- tests/graph/rewriting/test_utils.py | 137 +---------------------- tests/tensor/rewriting/test_math.py | 8 +- tests/tensor/test_subtensor.py | 6 +- 5 files changed, 10 insertions(+), 306 deletions(-) diff --git a/aesara/compile/builders.py b/aesara/compile/builders.py index cca88ffbc7..0ea6064a21 100644 --- a/aesara/compile/builders.py +++ b/aesara/compile/builders.py @@ -168,9 +168,7 @@ class OpFromGraph(Op, HasInnerGraph): .. TODO: - examples for a multi-layer mlp. where? - - __hash__, __eq__ otherwise won't merge, try - is_same_graph_with_merge(op1.local_outputs, op2, - local_outputs) + - __hash__, __eq__ otherwise won't merge - c_code() to remove the double overhead? - grad() make it support DisconnectedType and the new interface - add support for NullType and DisconnectedType when R_op supports them diff --git a/aesara/graph/rewriting/utils.py b/aesara/graph/rewriting/utils.py index 536c45620b..aa8dd66a3e 100644 --- a/aesara/graph/rewriting/utils.py +++ b/aesara/graph/rewriting/utils.py @@ -1,15 +1,7 @@ -import copy import warnings from typing import TYPE_CHECKING, Generator, Optional, Sequence, Union, cast -import aesara -from aesara.graph.basic import ( - Apply, - Variable, - equal_computations, - graph_inputs, - vars_between, -) +from aesara.graph.basic import Apply, Variable from aesara.graph.fg import FunctionGraph from aesara.graph.rewriting.db import RewriteDatabaseQuery @@ -82,157 +74,6 @@ def rewrite_graph( return fgraph.outputs[0] -def is_same_graph_with_merge(var1, var2, givens=None): - """ - Merge-based implementation of `aesara.graph.basic.is_same_graph`. - - See help on `aesara.graph.basic.is_same_graph` for additional documentation. - - """ - from aesara.graph.rewriting.basic import MergeOptimizer - - if givens is None: - givens = {} - # Copy variables since the MergeOptimizer will modify them. - copied = copy.deepcopy([var1, var2, givens]) - vars = copied[0:2] - givens = copied[2] - # Create FunctionGraph. - inputs = list(graph_inputs(vars)) - # The clone isn't needed as we did a deepcopy and we cloning will - # break the mapping in givens. - fgraph = aesara.graph.fg.FunctionGraph(inputs, vars, clone=False) - # Perform Variable substitution. - for to_replace, replace_by in givens.items(): - fgraph.replace(to_replace, replace_by) - # Perform merge optimization. - MergeOptimizer().rewrite(fgraph) - # When two variables perform the same computations, they will have the same - # owner in the rewritten graph. - # We need to be careful with the special case where the owner is None, - # which happens when the graph is made of a single Variable. - # We also need to make sure we replace a Variable if it is present in - # `givens`. - vars_replaced = [givens.get(v, v) for v in fgraph.outputs] - o1, o2 = [v.owner for v in vars_replaced] - if o1 is None and o2 is None: - # Comparing two single-Variable graphs: they are equal if they are - # the same Variable. - return vars_replaced[0] == vars_replaced[1] - else: - return o1 is o2 - - -def is_same_graph(var1, var2, givens=None): - """ - Return True iff Variables `var1` and `var2` perform the same computation. - - By 'performing the same computation', we mean that they must share the same - graph, so that for instance this function will return False when comparing - (x * (y * z)) with ((x * y) * z). - - The current implementation is not efficient since, when possible, it - verifies equality by calling two different functions that are expected to - return the same output. The goal is to verify this assumption, to - eventually get rid of one of them in the future. - - Parameters - ---------- - var1 - The first Variable to compare. - var2 - The second Variable to compare. - givens - Similar to the `givens` argument of `aesara.function`, it can be used - to perform substitutions in the computational graph of `var1` and - `var2`. This argument is associated to neither `var1` nor `var2`: - substitutions may affect both graphs if the substituted variable - is present in both. - - Examples - -------- - - ====== ====== ====== ====== - var1 var2 givens output - ====== ====== ====== ====== - x + 1 x + 1 {} True - x + 1 y + 1 {} False - x + 1 y + 1 {x: y} True - ====== ====== ====== ====== - - """ - use_equal_computations = True - - if givens is None: - givens = {} - - if not isinstance(givens, dict): - givens = dict(givens) - - # Get result from the merge-based function. - rval1 = is_same_graph_with_merge(var1=var1, var2=var2, givens=givens) - - if givens: - # We need to build the `in_xs` and `in_ys` lists. To do this, we need - # to be able to tell whether a variable belongs to the computational - # graph of `var1` or `var2`. - # The typical case we want to handle is when `to_replace` belongs to - # one of these graphs, and `replace_by` belongs to the other one. In - # other situations, the current implementation of `equal_computations` - # is probably not appropriate, so we do not call it. - ok = True - in_xs = [] - in_ys = [] - # Compute the sets of all variables found in each computational graph. - inputs_var = list(map(graph_inputs, ([var1], [var2]))) - all_vars = [ - set(vars_between(v_i, v_o)) - for v_i, v_o in ((inputs_var[0], [var1]), (inputs_var[1], [var2])) - ] - - def in_var(x, k): - # Return True iff `x` is in computation graph of variable `vark`. - return x in all_vars[k - 1] - - for to_replace, replace_by in givens.items(): - # Map a substitution variable to the computational graphs it - # belongs to. - inside = { - v: [in_var(v, k) for k in (1, 2)] for v in (to_replace, replace_by) - } - if ( - inside[to_replace][0] - and not inside[to_replace][1] - and inside[replace_by][1] - and not inside[replace_by][0] - ): - # Substitute variable in `var1` by one from `var2`. - in_xs.append(to_replace) - in_ys.append(replace_by) - elif ( - inside[to_replace][1] - and not inside[to_replace][0] - and inside[replace_by][0] - and not inside[replace_by][1] - ): - # Substitute variable in `var2` by one from `var1`. - in_xs.append(replace_by) - in_ys.append(to_replace) - else: - ok = False - break - if not ok: - # We cannot directly use `equal_computations`. - use_equal_computations = False - else: - in_xs = None - in_ys = None - if use_equal_computations: - rval2 = equal_computations(xs=[var1], ys=[var2], in_xs=in_xs, in_ys=in_ys) - assert rval2 == rval1 - return rval1 - - def get_clients_at_depth( fgraph: FunctionGraph, node: Apply, depth: int ) -> Generator[Apply, None, None]: diff --git a/tests/graph/rewriting/test_utils.py b/tests/graph/rewriting/test_utils.py index 08aaea250e..d2033a5555 100644 --- a/tests/graph/rewriting/test_utils.py +++ b/tests/graph/rewriting/test_utils.py @@ -4,145 +4,10 @@ from aesara.graph.fg import FunctionGraph from aesara.graph.rewriting.basic import graph_rewriter -from aesara.graph.rewriting.utils import is_same_graph, rewrite_graph -from aesara.tensor.math import neg +from aesara.graph.rewriting.utils import rewrite_graph from aesara.tensor.type import vectors -class TestIsSameGraph: - def check(self, expected): - """ - Core function to perform comparison. - - :param expected: A list of tuples (v1, v2, ((g1, o1), ..., (gN, oN))) - with: - - `v1` and `v2` two Variables (the graphs to be compared) - - `gj` a `givens` dictionary to give as input to `is_same_graph` - - `oj` the expected output of `is_same_graph(v1, v2, givens=gj)` - - This function also tries to call `is_same_graph` by inverting `v1` and - `v2`, and ensures the output remains the same. - """ - for v1, v2, go in expected: - for gj, oj in go: - r1 = is_same_graph(v1, v2, givens=gj) - assert r1 == oj - r2 = is_same_graph(v2, v1, givens=gj) - assert r2 == oj - - def test_single_var(self): - # Test `is_same_graph` with some trivial graphs (one Variable). - - x, y, z = vectors("x", "y", "z") - self.check( - [ - (x, x, (({}, True),)), - ( - x, - y, - ( - ({}, False), - ({y: x}, True), - ), - ), - (x, neg(x), (({}, False),)), - (x, neg(y), (({}, False),)), - ] - ) - - def test_full_graph(self): - # Test `is_same_graph` with more complex graphs. - - x, y, z = vectors("x", "y", "z") - t = x * y - self.check( - [ - (x * 2, x * 2, (({}, True),)), - ( - x * 2, - y * 2, - ( - ({}, False), - ({y: x}, True), - ), - ), - ( - x * 2, - y * 2, - ( - ({}, False), - ({x: y}, True), - ), - ), - ( - x * 2, - y * 3, - ( - ({}, False), - ({y: x}, False), - ), - ), - ( - t * 2, - z * 2, - ( - ({}, False), - ({t: z}, True), - ), - ), - ( - t * 2, - z * 2, - ( - ({}, False), - ({z: t}, True), - ), - ), - (x * (y * z), (x * y) * z, (({}, False),)), - ] - ) - - def test_merge_only(self): - # Test `is_same_graph` when `equal_computations` cannot be used. - - x, y, z = vectors("x", "y", "z") - t = x * y - self.check( - [ - (x, t, (({}, False), ({t: x}, True))), - ( - t * 2, - x * 2, - ( - ({}, False), - ({t: x}, True), - ), - ), - ( - x * x, - x * y, - ( - ({}, False), - ({y: x}, True), - ), - ), - ( - x * x, - x * y, - ( - ({}, False), - ({y: x}, True), - ), - ), - ( - x * x + z, - x * y + t, - (({}, False), ({y: x}, False), ({y: x, t: z}, True)), - ), - ], - ) - - def test_rewrite_graph(): x, y = vectors("xy") diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index 80e7ea5c45..dee3e271c4 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -26,7 +26,7 @@ out2in, ) from aesara.graph.rewriting.db import RewriteDatabaseQuery -from aesara.graph.rewriting.utils import is_same_graph, rewrite_graph +from aesara.graph.rewriting.utils import rewrite_graph from aesara.misc.safe_asarray import _asarray from aesara.tensor import inplace from aesara.tensor.basic import Alloc, join, switch @@ -4383,7 +4383,7 @@ def check(expr1, expr2): trees = [parse_mul_tree(e) for e in (expr1, expr2)] perform_sigm_times_exp(trees[0]) trees[0] = simplify_mul(trees[0]) - good = is_same_graph(compute_mul(trees[0]), compute_mul(trees[1])) + good = equal_computations([compute_mul(trees[0])], [compute_mul(trees[1])]) if not good: print(trees[0]) print(trees[1]) @@ -4552,7 +4552,7 @@ def test_compute_mul(self): tree = (x * y) * -z mul_tree = parse_mul_tree(tree) assert parse_mul_tree(compute_mul(mul_tree)) == mul_tree - assert is_same_graph(compute_mul(parse_mul_tree(tree)), tree) + assert equal_computations([compute_mul(parse_mul_tree(tree))], [tree]) def test_parse_mul_tree(self): x, y, z = vectors("x", "y", "z") @@ -4574,7 +4574,7 @@ def test_is_1pexp(self): lambda x: is_1pexp(x, only_process_constants=False), [(1 + exp_op(-x)), (exp_op(-x) + 1)], ): - assert not neg_ and is_same_graph(exp_arg, -x) + assert not neg_ and equal_computations([exp_arg], [-x]) assert is_1pexp(1 - exp_op(x), False) is None assert is_1pexp(2 + exp_op(x), False) is None assert is_1pexp(exp_op(x) + 2, False) is None diff --git a/tests/tensor/test_subtensor.py b/tests/tensor/test_subtensor.py index 1e6a3e99da..baf427ed04 100644 --- a/tests/tensor/test_subtensor.py +++ b/tests/tensor/test_subtensor.py @@ -12,8 +12,8 @@ from aesara.compile import DeepCopyOp, shared from aesara.compile.io import In from aesara.configdefaults import config +from aesara.graph.basic import equal_computations from aesara.graph.op import get_test_value -from aesara.graph.rewriting.utils import is_same_graph from aesara.printing import pprint from aesara.scalar.basic import as_scalar from aesara.tensor import get_vector_length @@ -1265,7 +1265,7 @@ def test_advanced1_inc_and_set(self): assert np.allclose(f_out, output_num), (params, f_out, output_num) def test_adv_constant_arg(self): - # Test case provided (and bug detected, gh-607) by John Salvatier + # gh-607 m = matrix("m") gv = np.array([0, 1, 3]) g = at.constant(gv) @@ -1275,7 +1275,7 @@ def test_adv_constant_arg(self): s1 = m[gv, i] s2 = m[g, i] - assert is_same_graph(s1, s2) + assert equal_computations([s1], [s2]) def test_adv1_inc_sub_notlastdim(self): # Test that taking 1-dimensional advanced indexing From 1603a0334c7633cfa3444e548d9b25cdb2ad7965 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Sat, 17 Sep 2022 15:05:32 -0500 Subject: [PATCH 08/11] Print outputs in Apply.__str__ --- aesara/graph/basic.py | 21 ++++++++++++++------- tests/graph/test_fg.py | 5 +---- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/aesara/graph/basic.py b/aesara/graph/basic.py index 9f8d0408a3..89f3913fb1 100644 --- a/aesara/graph/basic.py +++ b/aesara/graph/basic.py @@ -202,7 +202,7 @@ def default_output(self): return self.outputs[do] def __str__(self): - return op_as_string(self.inputs, self) + return node_as_string(self.inputs, self) def __repr__(self): return str(self) @@ -1409,8 +1409,11 @@ def compute_deps(obj): default_leaf_formatter = str -def default_node_formatter(op, argstrings): - return f"{op.op}({', '.join(argstrings)})" +def default_node_formatter(node, input_strs, output_strs=None): + if output_strs: + return f"{', '.join(output_strs)} <- {node.op}({', '.join(input_strs)})" + else: + return f"{node.op}({', '.join(input_strs)})" def io_connection_pattern(inputs, outputs): @@ -1479,12 +1482,16 @@ def io_connection_pattern(inputs, outputs): return global_connection_pattern -def op_as_string( - i, op, leaf_formatter=default_leaf_formatter, node_formatter=default_node_formatter +def node_as_string( + inputs, + node, + leaf_formatter=default_leaf_formatter, + node_formatter=default_node_formatter, ): """Return a function that returns a string representation of the subgraph between `i` and :attr:`op.inputs`""" - strs = as_string(i, op.inputs, leaf_formatter, node_formatter) - return node_formatter(op, strs) + in_strs = as_string(inputs, node.inputs, leaf_formatter, node_formatter) + out_strs = as_string(node.outputs, node.outputs, leaf_formatter, node_formatter) + return node_formatter(node, in_strs, out_strs) def as_string( diff --git a/tests/graph/test_fg.py b/tests/graph/test_fg.py index 8e495ff44c..d6975db647 100644 --- a/tests/graph/test_fg.py +++ b/tests/graph/test_fg.py @@ -278,10 +278,7 @@ def test_replace_verbose(self, capsys): capres = capsys.readouterr() assert capres.err == "" - assert ( - "rewriting: rewrite test-reason replaces Op1.0 of Op1(var2, var1) with var1 of None" - in capres.out - ) + assert capres.out.startswith("rewriting: rewrite test-reason replaces") def test_replace_circular(self): """`FunctionGraph` allows cycles--for better or worse.""" From e4451de1c573bbc7bbcceafc8659324d9fccbf65 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Sat, 17 Sep 2022 16:08:16 -0500 Subject: [PATCH 09/11] Add FunctionGraph callback checks to tests --- tests/graph/test_fg.py | 404 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 397 insertions(+), 7 deletions(-) diff --git a/tests/graph/test_fg.py b/tests/graph/test_fg.py index d6975db647..c145a99e24 100644 --- a/tests/graph/test_fg.py +++ b/tests/graph/test_fg.py @@ -1,10 +1,13 @@ import pickle +from typing import Any, Dict, List, Tuple import numpy as np import pytest +from typing_extensions import Literal from aesara.configdefaults import config from aesara.graph.basic import NominalVariable +from aesara.graph.features import Feature from aesara.graph.fg import FunctionGraph from aesara.graph.utils import MissingInputError from tests.graph.utils import ( @@ -19,6 +22,32 @@ ) +class CallbackTracker(Feature): + def __init__(self): + self.callback_history: List[ + Tuple[ + Literal["attach", "detach", "import", "change_input", "prune"], + Tuple[Any, ...], + Dict[Any, Any], + ] + ] = [] + + def on_attach(self, *args, **kwargs): + self.callback_history.append(("attach", args, kwargs)) + + def on_detach(self, *args, **kwargs): + self.callback_history.append(("detach", args, kwargs)) + + def on_import(self, *args, **kwargs): + self.callback_history.append(("import", args, kwargs)) + + def on_change_input(self, *args, **kwargs): + self.callback_history.append(("change_input", args, kwargs)) + + def on_prune(self, *args, **kwargs): + self.callback_history.append(("prune", args, kwargs)) + + class TestFunctionGraph: def test_pickle(self): var1 = op1() @@ -61,7 +90,11 @@ def test_init(self): var2 = MyVariable("var2") var3 = op1(var1) var4 = op2(var3, var2) - fg = FunctionGraph([var1, var2], [var3, var4], clone=False) + cb_tracker = CallbackTracker() + fg = FunctionGraph( + [var1, var2], [var3, var4], clone=False, features=[cb_tracker] + ) + assert fg.inputs == [var1, var2] assert fg.outputs == [var3, var4] assert fg.apply_nodes == {var3.owner, var4.owner} @@ -73,6 +106,19 @@ def test_init(self): assert fg.get_clients(var3) == [("output", 0), (var4.owner, 0)] assert fg.get_clients(var4) == [("output", 1)] + assert len(cb_tracker.callback_history) == 3 + assert cb_tracker.callback_history[0] == ("attach", (fg,), {}) + assert cb_tracker.callback_history[1] == ( + "import", + (fg, var3.owner, "init"), + {}, + ) + assert cb_tracker.callback_history[2] == ( + "import", + (fg, var4.owner, "init"), + {}, + ) + varC = MyConstant("varC") var5 = op1(var1, varC) fg = FunctionGraph(outputs=[var3, var4, var5], clone=False) @@ -94,7 +140,10 @@ def test_remove_client(self): var3 = op1(var2, var1) var4 = op2(var3, var2) var5 = op3(var4, var2, var2) - fg = FunctionGraph([var1, var2], [var3, var5], clone=False) + cb_tracker = CallbackTracker() + fg = FunctionGraph( + [var1, var2], [var3, var5], clone=False, features=[cb_tracker] + ) assert fg.variables == {var1, var2, var3, var4, var5} assert fg.get_clients(var2) == [ @@ -104,6 +153,25 @@ def test_remove_client(self): (var5.owner, 2), ] + assert len(cb_tracker.callback_history) == 4 + assert cb_tracker.callback_history[0] == ("attach", (fg,), {}) + assert cb_tracker.callback_history[1] == ( + "import", + (fg, var3.owner, "init"), + {}, + ) + assert cb_tracker.callback_history[2] == ( + "import", + (fg, var4.owner, "init"), + {}, + ) + assert cb_tracker.callback_history[3] == ( + "import", + (fg, var5.owner, "init"), + {}, + ) + cb_tracker.callback_history.clear() + fg.remove_client(var2, (var4.owner, 1)) assert fg.get_clients(var2) == [ @@ -112,12 +180,16 @@ def test_remove_client(self): (var5.owner, 2), ] + assert len(cb_tracker.callback_history) == 0 + fg.remove_client(var1, (var3.owner, 1)) assert fg.get_clients(var1) == [] assert var4.owner in fg.apply_nodes + assert len(cb_tracker.callback_history) == 0 + # This next `remove_client` should trigger a complete removal of `var4`'s # variables and `Apply` node from the `FunctionGraph`. # @@ -132,6 +204,13 @@ def test_remove_client(self): assert var4.owner.tag.removed_by == ["testing"] assert not any(o in fg.variables for o in var4.owner.outputs) + assert len(cb_tracker.callback_history) == 1 + assert cb_tracker.callback_history[0] == ( + "prune", + (fg, var4.owner, "testing"), + {}, + ) + def test_import_node(self): var1 = MyVariable("var1") @@ -139,7 +218,29 @@ def test_import_node(self): var3 = op1(var2, var1) var4 = op2(var3, var2) var5 = op3(var4, var2, var2) - fg = FunctionGraph([var1, var2], [var3, var5], clone=False) + cb_tracker = CallbackTracker() + fg = FunctionGraph( + [var1, var2], [var3, var5], clone=False, features=[cb_tracker] + ) + + assert len(cb_tracker.callback_history) == 4 + assert cb_tracker.callback_history[0] == ("attach", (fg,), {}) + assert cb_tracker.callback_history[1] == ( + "import", + (fg, var3.owner, "init"), + {}, + ) + assert cb_tracker.callback_history[2] == ( + "import", + (fg, var4.owner, "init"), + {}, + ) + assert cb_tracker.callback_history[3] == ( + "import", + (fg, var5.owner, "init"), + {}, + ) + cb_tracker.callback_history.clear() var8 = MyVariable("var8") var6 = op2(var8) @@ -148,11 +249,16 @@ def test_import_node(self): fg.import_node(var6.owner) assert var8 not in fg.variables + assert len(cb_tracker.callback_history) == 0 fg.import_node(var6.owner, import_missing=True) assert var8 in fg.inputs assert var6.owner in fg.apply_nodes + assert len(cb_tracker.callback_history) == 1 + assert cb_tracker.callback_history[0] == ("import", (fg, var6.owner, None), {}) + cb_tracker.callback_history.clear() + var7 = op2(var2) assert not hasattr(var7.owner.tag, "imported_by") fg.import_node(var7.owner) @@ -162,6 +268,9 @@ def test_import_node(self): assert var7.owner in fg.apply_nodes assert (var7.owner, 0) in fg.get_clients(var2) + assert len(cb_tracker.callback_history) == 1 + assert cb_tracker.callback_history[0] == ("import", (fg, var7.owner, None), {}) + def test_import_var(self): var1 = MyVariable("var1") @@ -200,7 +309,29 @@ def test_change_input(self): var3 = op1(var2, var1) var4 = op2(var3, var2) var5 = op3(var4, var2, var2) - fg = FunctionGraph([var1, var2], [var3, var5], clone=False) + cb_tracker = CallbackTracker() + fg = FunctionGraph( + [var1, var2], [var3, var5], clone=False, features=[cb_tracker] + ) + + assert len(cb_tracker.callback_history) == 4 + assert cb_tracker.callback_history[0] == ("attach", (fg,), {}) + assert cb_tracker.callback_history[1] == ( + "import", + (fg, var3.owner, "init"), + {}, + ) + assert cb_tracker.callback_history[2] == ( + "import", + (fg, var4.owner, "init"), + {}, + ) + assert cb_tracker.callback_history[3] == ( + "import", + (fg, var5.owner, "init"), + {}, + ) + cb_tracker.callback_history.clear() var6 = MyVariable2("var6") with pytest.raises(TypeError): @@ -209,6 +340,8 @@ def test_change_input(self): with pytest.raises(TypeError): fg.change_node_input(var5.owner, 1, var6) + assert len(cb_tracker.callback_history) == 0 + old_apply_nodes = set(fg.apply_nodes) old_variables = set(fg.variables) old_var5_clients = list(fg.get_clients(var5)) @@ -216,6 +349,8 @@ def test_change_input(self): # We're replacing with the same variable, so nothing should happen fg.change_node_input(var5.owner, 1, var2) + assert len(cb_tracker.callback_history) == 0 + assert old_apply_nodes == fg.apply_nodes assert old_variables == fg.variables assert old_var5_clients == fg.get_clients(var5) @@ -223,9 +358,35 @@ def test_change_input(self): # Perform a valid `Apply` node input change fg.change_node_input(var5.owner, 1, var1) - assert var5.owner.inputs[1] is var1 + assert var5.owner.inputs == [var4, var1, var2] + assert fg.outputs[1].owner == var5.owner assert (var5.owner, 1) not in fg.get_clients(var2) + assert len(cb_tracker.callback_history) == 1 + assert cb_tracker.callback_history[0] == ( + "change_input", + (fg, var5.owner, 1, var2, var1), + {"reason": None}, + ) + cb_tracker.callback_history.clear() + + # Perform a valid `Apply` node input change that results in a + # node removal (i.e. `var4.owner`) + fg.change_node_input(var5.owner, 0, var1) + + assert var5.owner.inputs[0] is var1 + assert not fg.get_clients(var4) + assert var4.owner not in fg.apply_nodes + assert var4 not in fg.variables + + assert len(cb_tracker.callback_history) == 2 + assert cb_tracker.callback_history[0] == ("prune", (fg, var4.owner, None), {}) + assert cb_tracker.callback_history[1] == ( + "change_input", + (fg, var5.owner, 0, var4, var1), + {"reason": None}, + ) + @config.change_flags(compute_test_value="raise") def test_replace_test_value(self): @@ -254,18 +415,212 @@ def test_replace(self): var3 = op1(var2, var1) var4 = op2(var3, var2) var5 = op3(var4, var2, var2) - fg = FunctionGraph([var1, var2], [var3, var5], clone=False) + cb_tracker = CallbackTracker() + fg = FunctionGraph( + [var1, var2], [var3, var5], clone=False, features=[cb_tracker] + ) + + assert len(cb_tracker.callback_history) == 4 + assert cb_tracker.callback_history[0] == ("attach", (fg,), {}) + assert cb_tracker.callback_history[1] == ( + "import", + (fg, var3.owner, "init"), + {}, + ) + assert cb_tracker.callback_history[2] == ( + "import", + (fg, var4.owner, "init"), + {}, + ) + assert cb_tracker.callback_history[3] == ( + "import", + (fg, var5.owner, "init"), + {}, + ) + cb_tracker.callback_history.clear() with pytest.raises(TypeError): var0 = MyVariable2("var0") # The types don't match and one cannot be converted to the other fg.replace(var3, var0) + assert len(cb_tracker.callback_history) == 0 + # Test a basic replacement fg.replace_all([(var3, var1)]) assert var3 not in fg.variables assert fg.apply_nodes == {var4.owner, var5.owner} assert var4.owner.inputs == [var1, var2] + assert fg.outputs == [var1, var5] + + assert len(cb_tracker.callback_history) == 3 + assert cb_tracker.callback_history[0] == ( + "change_input", + (fg, "output", 0, var3, var1), + {"reason": None}, + ) + assert cb_tracker.callback_history[1] == ("prune", (fg, var3.owner, None), {}) + assert cb_tracker.callback_history[2] == ( + "change_input", + (fg, var4.owner, 0, var3, var1), + {"reason": None}, + ) + + var3 = op1(var1) + var4 = op2(var3) + var5 = op3(var4) + cb_tracker = CallbackTracker() + fg = FunctionGraph([var1], [var5], clone=False, features=[cb_tracker]) + + # Test a replacement that would remove the replacement variable + # (i.e. `var3`) from the graph when the variable to be replaced + # (i.e. `var4`) is removed + fg.replace_all([(var4, var3)]) + + assert fg.apply_nodes == {var3.owner, var5.owner} + assert fg.inputs == [var1] + assert fg.outputs == [var5] + assert fg.variables == {var1, var3, var5} + + assert cb_tracker.callback_history == [ + ("attach", (fg,), {}), + ("import", (fg, var3.owner, "init"), {}), + ("import", (fg, var4.owner, "init"), {}), + ("import", (fg, var5.owner, "init"), {}), + ("prune", (fg, var4.owner, None), {}), + ("change_input", (fg, var5.owner, 0, var4, var3), {"reason": None}), + ] + + var3 = op1(var1) + var4 = op2(var3) + var5 = op3(var4, var4) + cb_tracker = CallbackTracker() + fg = FunctionGraph([var1], [var5], clone=False, features=[cb_tracker]) + + # Test multiple `change_node_input` calls on the same node + fg.replace_all([(var4, var3)]) + + assert fg.apply_nodes == {var3.owner, var5.owner} + assert fg.inputs == [var1] + assert fg.outputs == [var5] + assert fg.variables == {var1, var3, var5} + + assert cb_tracker.callback_history == [ + ("attach", (fg,), {}), + ("import", (fg, var3.owner, "init"), {}), + ("import", (fg, var4.owner, "init"), {}), + ("import", (fg, var5.owner, "init"), {}), + ("change_input", (fg, var5.owner, 0, var4, var3), {"reason": None}), + ("prune", (fg, var4.owner, None), {}), + ("change_input", (fg, var5.owner, 1, var4, var3), {"reason": None}), + ] + + def test_replace_outputs(self): + var1 = MyVariable("var1") + var2 = MyVariable("var2") + var3 = op1(var1) + var4 = op2(var2) + cb_tracker = CallbackTracker() + fg = FunctionGraph( + [var1, var2], [var3, var4, var3], clone=False, features=[cb_tracker] + ) + + fg.replace_all([(var3, var1)]) + assert var3 not in fg.variables + + assert fg.apply_nodes == {var4.owner} + assert fg.outputs == [var1, var4, var1] + + assert cb_tracker.callback_history == [ + ("attach", (fg,), {}), + ("import", (fg, var3.owner, "init"), {}), + ("import", (fg, var4.owner, "init"), {}), + ("change_input", (fg, "output", 0, var3, var1), {"reason": None}), + ("prune", (fg, var3.owner, None), {}), + ("change_input", (fg, "output", 2, var3, var1), {"reason": None}), + ] + + def test_replace_contract(self): + x = MyVariable("x") + v1 = op1(x) + v2 = op1(v1) + v3 = op1(v2) + v4 = op1(v3) + + v1.name = "v1" + v2.name = "v2" + v3.name = "v3" + v4.name = "v4" + + cb_tracker = CallbackTracker() + fg = FunctionGraph([x], [v4], clone=False, features=[cb_tracker]) + + # This replacement should produce a new `Apply` node that's equivalent + # to `v2` and try to replace `v3`'s node with that one. In other + # words, the replacement creates a new node that's already in the + # `FunctionGraph`. + # The end result is `v3 = v2`. + fg.replace_all([(v2, v1)]) + + assert v2 not in fg.variables + assert fg.clients == { + x: [(v1.owner, 0)], + v1: [(v3.owner, 0)], + v2: [], + v3: [(v4.owner, 0)], + v4: [("output", 0)], + } + assert fg.apply_nodes == {v4.owner, v3.owner, v1.owner} + assert v2 not in set(sum((n.outputs for n in fg.apply_nodes), [])) + + assert cb_tracker.callback_history == [ + ("attach", (fg,), {}), + ("import", (fg, v1.owner, "init"), {}), + ("import", (fg, v2.owner, "init"), {}), + ("import", (fg, v3.owner, "init"), {}), + ("import", (fg, v4.owner, "init"), {}), + ("prune", (fg, v2.owner, None), {}), + ("change_input", (fg, v3.owner, 0, v2, v1), {"reason": None}), + ] + + # Let's try the same thing at a different point in the chain + x = MyVariable("x") + v1 = op1(x) + v2 = op1(v1) + v3 = op1(v2) + v4 = op1(v3) + + v1.name = "v1" + v2.name = "v2" + v3.name = "v3" + v4.name = "v4" + + cb_tracker = CallbackTracker() + fg = FunctionGraph([x], [v4], clone=False, features=[cb_tracker]) + + fg.replace_all([(v3, v2)]) + + assert v3 not in fg.variables + assert fg.clients == { + x: [(v1.owner, 0)], + v1: [(v2.owner, 0)], + v2: [(v4.owner, 0)], + v3: [], + v4: [("output", 0)], + } + assert fg.apply_nodes == {v4.owner, v2.owner, v1.owner} + assert v3 not in set(sum((n.outputs for n in fg.apply_nodes), [])) + + exp_res = [ + ("attach", (fg,), {}), + ("import", (fg, v1.owner, "init"), {}), + ("import", (fg, v2.owner, "init"), {}), + ("import", (fg, v3.owner, "init"), {}), + ("import", (fg, v4.owner, "init"), {}), + ("prune", (fg, v3.owner, None), {}), + ("change_input", (fg, v4.owner, 0, v3, v2), {"reason": None}), + ] + assert cb_tracker.callback_history == exp_res def test_replace_verbose(self, capsys): @@ -288,7 +643,29 @@ def test_replace_circular(self): var3 = op1(var2, var1) var4 = op2(var3, var2) var5 = op3(var4, var2, var2) - fg = FunctionGraph([var1, var2], [var3, var5], clone=False) + cb_tracker = CallbackTracker() + fg = FunctionGraph( + [var1, var2], [var3, var5], clone=False, features=[cb_tracker] + ) + + assert len(cb_tracker.callback_history) == 4 + assert cb_tracker.callback_history[0] == ("attach", (fg,), {}) + assert cb_tracker.callback_history[1] == ( + "import", + (fg, var3.owner, "init"), + {}, + ) + assert cb_tracker.callback_history[2] == ( + "import", + (fg, var4.owner, "init"), + {}, + ) + assert cb_tracker.callback_history[3] == ( + "import", + (fg, var5.owner, "init"), + {}, + ) + cb_tracker.callback_history.clear() fg.replace_all([(var3, var4)]) @@ -297,6 +674,19 @@ def test_replace_circular(self): assert fg.apply_nodes == {var4.owner, var5.owner} assert var4.owner.inputs == [var4, var2] + assert len(cb_tracker.callback_history) == 3 + assert cb_tracker.callback_history[0] == ( + "change_input", + (fg, "output", 0, var3, var4), + {"reason": None}, + ) + assert cb_tracker.callback_history[1] == ("prune", (fg, var3.owner, None), {}) + assert cb_tracker.callback_history[2] == ( + "change_input", + (fg, var4.owner, 0, var3, var4), + {"reason": None}, + ) + def test_replace_bad_state(self): var1 = MyVariable("var1") From 48a249e776914d01c88fda9a5b7139ff0dd2f514 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Wed, 7 Sep 2022 15:50:00 -0500 Subject: [PATCH 10/11] Separate recursive importing from single node importing in FunctionGraph --- aesara/graph/fg.py | 94 ++++++++++++++++++++++++++++------------------ 1 file changed, 57 insertions(+), 37 deletions(-) diff --git a/aesara/graph/fg.py b/aesara/graph/fg.py index 68e7a6a26c..56c999f871 100644 --- a/aesara/graph/fg.py +++ b/aesara/graph/fg.py @@ -311,7 +311,7 @@ def import_var( if isinstance(var.type, NullType): raise TypeError( - f"Computation graph contains a NaN. {var.type.why_null}" + f"Computation graph contains a null type: {var} {var.type.why_null}" ) if import_missing: self.add_input(var) @@ -327,7 +327,7 @@ def import_node( reason: Optional[str] = None, import_missing: bool = False, ) -> None: - """Recursively import everything between an ``Apply`` node and the ``FunctionGraph``'s outputs. + """Recursively import everything between an `Apply` node and the `FunctionGraph`'s outputs. Parameters ---------- @@ -347,42 +347,62 @@ def import_node( # to know where to stop going down.) new_nodes = io_toposort(self.variables, apply_node.outputs) - if check: - for node in new_nodes: - for var in node.inputs: - if ( - var.owner is None - and not isinstance(var, AtomicVariable) - and var not in self.inputs - ): - if import_missing: - self.add_input(var) - else: - error_msg = ( - f"Input {node.inputs.index(var)} ({var})" - " of the graph (indices start " - f"from 0), used to compute {node}, was not " - "provided and not given a value. Use the " - "Aesara flag exception_verbosity='high', " - "for more information on this error." - ) - raise MissingInputError(error_msg, variable=var) - for node in new_nodes: - assert node not in self.apply_nodes - self.apply_nodes.add(node) - if not hasattr(node.tag, "imported_by"): - node.tag.imported_by = [] - node.tag.imported_by.append(str(reason)) - for output in node.outputs: - self.setup_var(output) - self.variables.add(output) - for i, input in enumerate(node.inputs): - if input not in self.variables: - self.setup_var(input) - self.variables.add(input) - self.add_client(input, (node, i)) - self.execute_callbacks("on_import", node, reason) + self._import_node( + node, check=check, reason=reason, import_missing=import_missing + ) + + def _import_node( + self, + apply_node: Apply, + check: bool = True, + reason: Optional[str] = None, + import_missing: bool = False, + ) -> None: + """Import a single node. + + See `FunctionGraph.import_node`. + """ + assert apply_node not in self.apply_nodes + + for i, inp in enumerate(apply_node.inputs): + if ( + check + and inp.owner is None + and not isinstance(inp, AtomicVariable) + and inp not in self.inputs + ): + if import_missing: + self.add_input(inp) + else: + error_msg = ( + f"Input {apply_node.inputs.index(inp)} ({inp})" + " of the graph (indices start " + f"from 0), used to compute {apply_node}, was not " + "provided and not given a value. Use the " + "Aesara flag exception_verbosity='high', " + "for more information on this error." + ) + raise MissingInputError(error_msg, variable=inp) + + if inp not in self.variables: + self.setup_var(inp) + self.variables.add(inp) + + self.add_client(inp, (apply_node, i)) + + for output in apply_node.outputs: + self.setup_var(output) + self.variables.add(output) + + self.apply_nodes.add(apply_node) + + if not hasattr(apply_node.tag, "imported_by"): + apply_node.tag.imported_by = [] + + apply_node.tag.imported_by.append(str(reason)) + + self.execute_callbacks("on_import", apply_node, reason) def change_node_input( self, From ad7259db7eff29975fdb3895119fcbd7d11c5cfc Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Mon, 29 Aug 2022 15:04:34 -0500 Subject: [PATCH 11/11] Hash-cons Apply, Constant and change node input replacement semantics --- aesara/compile/debugmode.py | 117 +++++++------- aesara/graph/basic.py | 235 +++++++++++++++++----------- aesara/graph/destroyhandler.py | 229 +++++++++++++-------------- aesara/graph/features.py | 116 +++++++++----- aesara/graph/fg.py | 154 +++++++++++++++--- aesara/graph/rewriting/basic.py | 76 +++++---- aesara/link/c/basic.py | 4 +- aesara/link/c/params_type.py | 8 +- aesara/link/c/type.py | 10 +- aesara/sparse/basic.py | 45 +++--- aesara/tensor/basic.py | 2 + aesara/tensor/rewriting/shape.py | 108 +++++++------ aesara/tensor/type.py | 11 +- aesara/tensor/type_other.py | 49 +++--- aesara/tensor/var.py | 144 +---------------- tests/graph/rewriting/test_basic.py | 15 +- tests/graph/test_basic.py | 209 ++++++++++++++++--------- tests/graph/test_destroyhandler.py | 59 +++++-- tests/graph/test_fg.py | 208 +++++++++++++++++------- 19 files changed, 1034 insertions(+), 765 deletions(-) diff --git a/aesara/compile/debugmode.py b/aesara/compile/debugmode.py index c3c116b091..41e45d2eec 100644 --- a/aesara/compile/debugmode.py +++ b/aesara/compile/debugmode.py @@ -14,10 +14,11 @@ from itertools import chain from itertools import product as itertools_product from logging import Logger -from typing import Optional +from typing import TYPE_CHECKING, Optional, Union from warnings import warn import numpy as np +from typing_extensions import Literal import aesara from aesara.compile.function.types import ( @@ -42,7 +43,9 @@ from aesara.utils import NoDuplicateOptWarningFilter, difference, get_unbound_function -__docformat__ = "restructuredtext en" +if TYPE_CHECKING: + from aesara.graph.basic import Apply + _logger: Logger = logging.getLogger("aesara.compile.debugmode") _logger.addFilter(NoDuplicateOptWarningFilter()) @@ -1109,43 +1112,32 @@ class _FunctionGraphEvent: """ - kind = "" - """ - One of 'import', 'change', 'prune'. - - """ - - node = None - """ - Either 'output' or an Apply instance. - - """ - - op = None - """Either 'output' or an Op instance""" + kind: Literal["import", "change", "prune"] + old_node: Optional[Union[Literal["output"], "Apply"]] + new_node: Optional[Union[Literal["output"], "Apply"]] + op: Optional[Union[Literal["output"], Op]] + idx: Optional[int] + reason: Optional[str] - idx = None - """ - Change events involve an position index of the input variable. - - """ - - reason = None - """ - Change events sometimes have a reason. - - """ - - def __init__(self, kind, node, idx=None, reason=None): + def __init__( + self, + kind: Literal["import", "change", "prune"], + old_node: Union[Literal["output"], "Apply"], + new_node: Union[Literal["output"], "Apply"] = None, + idx: Optional[int] = None, + reason: Optional[str] = None, + ): self.kind = kind - if node == "output": - self.node = "output" + if old_node == "output": + self.old_node = "output" + self.new_node = "output" self.op = "output" else: - self.node = node - self.op = node.op + self.old_node = old_node + self.new_node = new_node + self.op = old_node.op self.idx = idx - self.reason = str(reason) + self.reason = str(reason) if reason else None def __str__(self): if self.kind == "change": @@ -1219,21 +1211,21 @@ def on_attach(self, fgraph): self.replaced_by = {} self.event_list = [] for node in fgraph.toposort(): - self.on_import(fgraph, node, "on_attach") + self.on_import(fgraph, node, reason="on_attach") def on_detach(self, fgraph): assert fgraph is self.fgraph self.fgraph = None def on_prune(self, fgraph, node, reason): - self.event_list.append(_FunctionGraphEvent("prune", node, reason=str(reason))) + self.event_list.append(_FunctionGraphEvent("prune", node, reason=reason)) assert node in self.active_nodes assert node not in self.inactive_nodes self.active_nodes.remove(node) self.inactive_nodes.add(node) def on_import(self, fgraph, node, reason): - self.event_list.append(_FunctionGraphEvent("import", node, reason=str(reason))) + self.event_list.append(_FunctionGraphEvent("import", node, reason=reason)) assert node not in self.active_nodes self.active_nodes.add(node) @@ -1253,18 +1245,23 @@ def on_import(self, fgraph, node, reason): self.reasons.setdefault(r, []) self.replaced_by.setdefault(r, []) - def on_change_input(self, fgraph, node, i, r, new_r, reason=None): + def on_change_input( + self, fgraph, old_node, new_node, i, old_var, new_var, reason=None + ): reason = str(reason) self.event_list.append( - _FunctionGraphEvent("change", node, reason=reason, idx=i) + _FunctionGraphEvent("change", old_node, new_node, idx=i, reason=reason) ) - self.reasons.setdefault(new_r, []) - self.replaced_by.setdefault(new_r, []) + self.on_import(fgraph, new_node, reason=reason) + self.on_prune(fgraph, old_node, reason=reason) + + self.reasons.setdefault(new_var, []) + self.replaced_by.setdefault(new_var, []) append_reason = True - for tup in self.reasons[new_r]: - if tup[0] == reason and tup[1] is r: + for tup in self.reasons[new_var]: + if tup[0] == reason and tup[1] is old_var: append_reason = False if append_reason: @@ -1272,12 +1269,12 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None): # optimizations will change the graph done = dict() used_ids = dict() - self.reasons[new_r].append( + self.reasons[new_var].append( ( reason, - r, + old_var, _debugprint( - r, + old_var, prefix=" ", depth=6, file=StringIO(), @@ -1286,7 +1283,7 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None): used_ids=used_ids, ).getvalue(), _debugprint( - new_r, + new_var, prefix=" ", depth=6, file=StringIO(), @@ -1296,22 +1293,22 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None): ).getvalue(), ) ) - self.replaced_by[r].append((reason, new_r)) + self.replaced_by[old_var].append((reason, new_var)) - if r in self.equiv: - r_set = self.equiv[r] + if old_var in self.equiv: + r_set = self.equiv[old_var] else: - r_set = self.equiv.setdefault(r, {r}) - self.all_variables_ever.append(r) + r_set = self.equiv.setdefault(old_var, {old_var}) + self.all_variables_ever.append(old_var) - if new_r in self.equiv: - new_r_set = self.equiv[new_r] + if new_var in self.equiv: + new_r_set = self.equiv[new_var] else: - new_r_set = self.equiv.setdefault(new_r, {new_r}) - self.all_variables_ever.append(new_r) + new_r_set = self.equiv.setdefault(new_var, {new_var}) + self.all_variables_ever.append(new_var) - assert new_r in new_r_set - assert r in r_set + assert new_var in new_r_set + assert old_var in r_set # update one equivalence set to contain the other # transfer all the elements of the old one to the new one @@ -1320,8 +1317,8 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None): self.equiv[like_new_r] = r_set assert like_new_r in r_set - assert self.equiv[r] is r_set - assert self.equiv[new_r] is r_set + assert self.equiv[old_var] is r_set + assert self.equiv[new_var] is r_set def printstuff(self): for key in self.equiv: diff --git a/aesara/graph/basic.py b/aesara/graph/basic.py index 89f3913fb1..f02a7b1141 100644 --- a/aesara/graph/basic.py +++ b/aesara/graph/basic.py @@ -1,5 +1,4 @@ """Core graph classes.""" -import abc import warnings from collections import deque from copy import copy @@ -32,7 +31,6 @@ from aesara.configdefaults import config from aesara.graph.utils import ( - MetaObject, MethodNotDefined, Scratchpad, TestValueError, @@ -53,32 +51,48 @@ _TypeType = TypeVar("_TypeType", bound="Type") _IdType = TypeVar("_IdType", bound=Hashable) -T = TypeVar("T", bound="Node") +T = TypeVar("T", bound=Union["Apply", "Variable"]) NoParams = object() NodeAndChildren = Tuple[T, Optional[Iterable[T]]] -class Node(MetaObject): - r"""A `Node` in an Aesara graph. +class UniqueInstanceFactory(type): - Currently, graphs contain two kinds of `Nodes`: `Variable`\s and `Apply`\s. - Edges in the graph are not explicitly represented. Instead each `Node` - keeps track of its parents via `Variable.owner` / `Apply.inputs`. + __instances__: WeakValueDictionary - """ - name: Optional[str] + def __new__(cls, name, bases, dct): + dct["__instances__"] = WeakValueDictionary() - def get_parents(self): - """ - Return a list of the parents of this node. - Should return a copy--i.e., modifying the return - value should not modify the graph structure. + if "_post_call" not in dct: - """ - raise NotImplementedError() + def _post_call(self, *args, **kwargs): + return self + + dct["_post_call"] = _post_call + + res = super().__new__(cls, name, bases, dct) + return res + + def __call__( + cls, + *args, + **kwargs, + ): + idp = cls.create_key(*args, **kwargs) + res = cls.__instances__.get(idp) -class Apply(Node, Generic[OpType]): + if res is None: + res = super(UniqueInstanceFactory, cls).__call__(*args, **kwargs) + cls.__instances__[idp] = res + + return res._post_call(*args, **kwargs) + + +class Apply( + Generic[OpType], + metaclass=UniqueInstanceFactory, +): """A `Node` representing the application of an operation to inputs. Basically, an `Apply` instance is an object that represents the @@ -113,12 +127,38 @@ class Apply(Node, Generic[OpType]): """ + __slots__ = ("op", "inputs", "outputs", "__weakref__", "tag") + + @classmethod + def create_key(cls, op, inputs, outputs): + return (op,) + tuple(inputs) + def __init__( self, op: OpType, inputs: Sequence["Variable"], outputs: Sequence["Variable"], ): + r""" + + Parameters + ---------- + op + The operation that produces `outputs` given `inputs`. + inputs + The arguments of the expression modeled by the `Apply` node. + outputs + The outputs of the expression modeled by the `Apply` node. If a + node already exists for the given `op` and `inputs` combination, + each `Variable` in `outputs` will be associated with the node + (i.e. `Variable.owner` will be (re)set), and the `Apply.outputs` + values for the returned node will consist of the original outputs + and not the new `outputs`. + In other words, `Apply.outputs` is always a consistent, unique list + of `Variable`\s for each `op` and `inputs` pair. + + """ + if not isinstance(inputs, Sequence): raise TypeError("The inputs of an Apply must be a sequence type") @@ -129,7 +169,6 @@ def __init__( self.inputs: List[Variable] = [] self.tag = Scratchpad() - # filter inputs to make sure each element is a Variable for input in inputs: if isinstance(input, Variable): self.inputs.append(input) @@ -137,22 +176,34 @@ def __init__( raise TypeError( f"The 'inputs' argument to Apply must contain Variable instances, not {input}" ) - self.outputs: List[Variable] = [] - # filter outputs to make sure each element is a Variable + + self.outputs: List[Variable] = list(outputs) + + def _post_call(self, op, inputs, outputs): + + # If a user passes new outputs to an existing `Apply` node, those + # outputs will be updated and associated with the node, but the + # returned node's outputs will still be the original `Variable`s. for i, output in enumerate(outputs): - if isinstance(output, Variable): - if output.owner is None: - output.owner = self - output.index = i - elif output.owner is not self or output.index != i: - raise ValueError( - "All output variables passed to Apply must belong to it." - ) - self.outputs.append(output) - else: + if not isinstance(output, Variable): raise TypeError( f"The 'outputs' argument to Apply must contain Variable instances with no owner, not {output}" ) + output.owner = self + output.index = i + + return self + + def __eq__(self, other): + if isinstance(other, type(self)): + if self.op == other.op and self.inputs == other.inputs: + return True + return False + + return NotImplemented + + def __hash__(self): + return hash((type(self), self.op, tuple(self.inputs))) def run_params(self): """ @@ -165,8 +216,7 @@ def run_params(self): return NoParams def __getstate__(self): - d = self.__dict__ - # ufunc don't pickle/unpickle well + d = {k: getattr(self, k) for k in self.__slots__ if k not in ("__weakref__",)} if hasattr(self.tag, "ufunc"): d = copy(self.__dict__) t = d["tag"] @@ -174,6 +224,11 @@ def __getstate__(self): d["tag"] = t return d + def __setstate__(self, dct): + for k in self.__slots__: + if k in dct: + setattr(self, k, dct[k]) + def default_output(self): """ Returns the default output for this node. @@ -267,6 +322,7 @@ def clone_with_new_inputs( from aesara.graph.op import HasInnerGraph assert isinstance(inputs, (list, tuple)) + remake_node = False new_inputs: List["Variable"] = list(inputs) for i, (curr, new) in enumerate(zip(self.inputs, new_inputs)): @@ -280,17 +336,22 @@ def clone_with_new_inputs( else: remake_node = True - if remake_node: - new_op = self.op + new_op = self.op - if isinstance(new_op, HasInnerGraph) and clone_inner_graph: # type: ignore - new_op = new_op.clone() # type: ignore + if isinstance(new_op, HasInnerGraph) and clone_inner_graph: # type: ignore + new_op = new_op.clone() # type: ignore + if remake_node: new_node = new_op.make_node(*new_inputs) new_node.tag = copy(self.tag).__update__(new_node.tag) + elif new_op == self.op and new_inputs == self.inputs: + new_node = self else: - new_node = self.clone(clone_inner_graph=clone_inner_graph) - new_node.inputs = new_inputs + new_node = self.__class__( + new_op, new_inputs, [output.clone() for output in self.outputs] + ) + new_node.tag = copy(self.tag) + return new_node def get_parents(self): @@ -316,7 +377,7 @@ def params_type(self): return self.op.params_type -class Variable(Node, Generic[_TypeType, OptionalApplyType]): +class Variable(Generic[_TypeType, OptionalApplyType]): r""" A :term:`Variable` is a node in an expression graph that represents a variable. @@ -411,7 +472,7 @@ class Variable(Node, Generic[_TypeType, OptionalApplyType]): """ - # __slots__ = ['type', 'owner', 'index', 'name'] + __slots__ = ("_owner", "_index", "name", "type", "__weakref__", "tag", "auto_name") __count__ = count(0) _owner: OptionalApplyType @@ -487,26 +548,17 @@ def __str__(self): else: return f"<{self.type}>" - def __repr_test_value__(self): - """Return a ``repr`` of the test value. - - Return a printable representation of the test value. It can be - overridden by classes with non printable test_value to provide a - suitable representation of the test_value. - """ - return repr(self.get_test_value()) - def __repr__(self, firstPass=True): """Return a ``repr`` of the `Variable`. - Return a printable name or description of the Variable. If - ``config.print_test_value`` is ``True`` it will also print the test - value, if any. + Return a printable name or description of the `Variable`. If + `aesara.config.print_test_value` is ``True``, it will also print the + test value, if any. """ to_print = [str(self)] if config.print_test_value and firstPass: try: - to_print.append(self.__repr_test_value__()) + to_print.append(repr(self.get_test_value())) except TestValueError: pass return "\n".join(to_print) @@ -534,26 +586,6 @@ def clone(self, **kwargs): cp.tag = copy(self.tag) return cp - def __lt__(self, other): - raise NotImplementedError( - "Subclasses of Variable must provide __lt__", self.__class__.__name__ - ) - - def __le__(self, other): - raise NotImplementedError( - "Subclasses of Variable must provide __le__", self.__class__.__name__ - ) - - def __gt__(self, other): - raise NotImplementedError( - "Subclasses of Variable must provide __gt__", self.__class__.__name__ - ) - - def __ge__(self, other): - raise NotImplementedError( - "Subclasses of Variable must provide __ge__", self.__class__.__name__ - ) - def get_parents(self): if self.owner is not None: return [self.owner] @@ -611,7 +643,7 @@ def eval(self, inputs_to_values=None): return rval def __getstate__(self): - d = self.__dict__.copy() + d = {k: getattr(self, k) for k in self.__slots__ if k not in ("__weakref__",)} d.pop("_fn_cache", None) if (not config.pickle_test_value) and (hasattr(self.tag, "test_value")): if not type(config).pickle_test_value.is_default: @@ -624,6 +656,11 @@ def __getstate__(self): d["tag"] = t return d + def __setstate__(self, dct): + for k in self.__slots__: + if k in dct: + setattr(self, k, dct[k]) + class AtomicVariable(Variable[_TypeType, None]): """A node type that has no ancestors and should never be considered an input to a graph.""" @@ -631,19 +668,12 @@ class AtomicVariable(Variable[_TypeType, None]): def __init__(self, type: _TypeType, name: Optional[str] = None, **kwargs): super().__init__(type=type, owner=None, index=None, name=name, **kwargs) - @abc.abstractmethod - def signature(self): - ... - - def merge_signature(self): - return self.signature() - def equals(self, other): """ This does what `__eq__` would normally do, but `Variable` and `Apply` should always be hashable by `id`. """ - return isinstance(other, type(self)) and self.signature() == other.signature() + return self == other @property def owner(self): @@ -677,7 +707,10 @@ class NominalVariable(AtomicVariable[_TypeType]): __instances__: WeakValueDictionary = WeakValueDictionary() def __new__(cls, id: _IdType, typ: _TypeType, **kwargs): - if (typ, id) not in cls.__instances__: + + idp = (typ, id) + + if idp not in cls.__instances__: var_type = typ.variable_type type_name = f"Nominal{var_type.__name__}" @@ -692,9 +725,9 @@ def _str(self): ) res: NominalVariable = super().__new__(new_type) - cls.__instances__[(typ, id)] = res + cls.__instances__[idp] = res - return cls.__instances__[(typ, id)] + return cls.__instances__[idp] def __init__(self, id: _IdType, typ: _TypeType, name: Optional[str] = None): self.id = id @@ -720,11 +753,11 @@ def __hash__(self): def __repr__(self): return f"{type(self).__name__}({repr(self.id)}, {repr(self.type)})" - def signature(self) -> Tuple[_TypeType, _IdType]: - return (self.type, self.id) - -class Constant(AtomicVariable[_TypeType]): +class Constant( + AtomicVariable[_TypeType], + metaclass=UniqueInstanceFactory, +): """A `Variable` with a fixed `data` field. `Constant` nodes make numerous optimizations possible (e.g. constant @@ -737,19 +770,22 @@ class Constant(AtomicVariable[_TypeType]): """ - # __slots__ = ['data'] + __slots__ = ("type", "data") + + @classmethod + def create_key(cls, type, data, *args, **kwargs): + # TODO FIXME: This filters the data twice: once here, and again in + # `cls.__init__`. This might not be a big deal, though. + return (type, type.filter(data)) def __init__(self, type: _TypeType, data: Any, name: Optional[str] = None): - super().__init__(type, name=name) + AtomicVariable.__init__(self, type, name=name) self.data = type.filter(data) add_tag_trace(self) def get_test_value(self): return self.data - def signature(self): - return (self.type, self.data) - def __str__(self): if self.name is not None: return self.name @@ -775,6 +811,15 @@ def owner(self, value) -> None: def value(self): return self.data + def __hash__(self): + return hash((type(self), self.type, self.data)) + + def __eq__(self, other): + if isinstance(other, type(self)): + return self.type == other.type and self.data == other.data + + return NotImplemented + def walk( nodes: Iterable[T], diff --git a/aesara/graph/destroyhandler.py b/aesara/graph/destroyhandler.py index abc4894715..1f6c4956ce 100644 --- a/aesara/graph/destroyhandler.py +++ b/aesara/graph/destroyhandler.py @@ -14,15 +14,6 @@ from aesara.misc.ordered_set import OrderedSet -class ProtocolError(Exception): - """ - Raised when FunctionGraph calls DestroyHandler callbacks in - an invalid way, for example, pruning or changing a node that has - never been imported. - - """ - - def _contains_cycle(fgraph, orderings): """ Function to check if the given graph contains a cycle @@ -180,7 +171,8 @@ def _build_droot_impact(destroy_handler): impact = {} # destroyed nonview variable -> it + all views of it root_destroyer = {} # root -> destroyer apply - for app in destroy_handler.destroyers: + for ref_out in destroy_handler.destroyers: + app = ref_out.owner for output_idx, input_idx_list in app.op.destroy_map.items(): if len(input_idx_list) != 1: raise NotImplementedError() @@ -250,7 +242,7 @@ def fast_inplace_check(fgraph, inputs): return inputs -class DestroyHandler(Bookkeeper): # noqa +class DestroyHandler(Bookkeeper): """ The DestroyHandler class detects when a graph is impossible to evaluate because of aliasing and destructive operations. @@ -319,8 +311,8 @@ def __init__(self, do_imports_on_attach=True, algo=None): self.impact = OrderedDict() """ - If a var is destroyed, then this dict will map - droot[var] to the apply node that destroyed var + If a ``var`` is destroyed, then this dict will map + ``droot[var]`` to the `Variable` that's owner destroyed ``var`` TODO: rename to vroot_to_destroyer """ @@ -334,21 +326,6 @@ def clone(self): return type(self)(self.do_imports_on_attach, self.algo) def on_attach(self, fgraph): - """ - When attaching to a new fgraph, check that - 1) This DestroyHandler wasn't already attached to some fgraph - (its data structures are only set up to serve one). - 2) The FunctionGraph doesn't already have a DestroyHandler. - This would result in it validating everything twice, causing - compilation to be slower. - - Give the FunctionGraph instance: - 1) A new method "destroyers(var)" - TODO: what does this do exactly? - 2) A new attribute, "destroy_handler" - TODO: WRITEME: what does this do besides the checks? - - """ if any(hasattr(fgraph, attr) for attr in ("destroyers", "destroy_handler")): raise AlreadyThere("DestroyHandler feature is already present") @@ -358,20 +335,28 @@ def on_attach(self, fgraph): "A DestroyHandler instance can only serve one FunctionGraph" ) - # Annotate the FunctionGraph # self.unpickle(fgraph) + fgraph.destroy_handler = self self.fgraph = fgraph - self.destroyers = ( - OrderedSet() - ) # set of Apply instances with non-null destroy_map self.view_i = {} # variable -> variable used in calculation self.view_o = ( {} ) # variable -> set of variables that use this one as a direct input - # clients: how many times does an apply use a given variable - self.clients = OrderedDict() # variable -> apply -> ninputs + + # The following map tracks how many times a variable is referenced by an `Apply` node. + # It doesn't actually use `Apply` nodes, though, because doing so would require + # that we update the `Apply` nodes on every replacement in the graph. Instead, we use + # the first output `Variable` of an `Apply` node. Since `Variable.owner` is updated + # whenever a replacement is made, these representative output `Variable`s will always + # point to the appropriate `Apply` node. + self.clients = OrderedDict() + + # Set of output `Variable`s representing `Apply` nodes (see the + # description for `self.clients`) with non-null `Op.destroy_map`s. + self.destroyers = OrderedSet() + self.stale_droot = True self.debug_all_apps = set() @@ -497,72 +482,75 @@ def fast_destroy(self, fgraph, app, reason): # assert len(v) <= 1 # assert len(d) <= 1 - def on_import(self, fgraph, app, reason): + def on_import(self, fgraph, node, reason): """ Add Apply instance to set which must be computed. """ - if app in self.debug_all_apps: - raise ProtocolError("double import") - self.debug_all_apps.add(app) - # print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps) + # Choose an output to represent the `Apply` node + rep_out = node.outputs[0] + + if rep_out in self.debug_all_apps: + return + + self.debug_all_apps.add(rep_out) # If it's a destructive op, add it to our watch list - dmap = app.op.destroy_map - vmap = app.op.view_map + dmap = node.op.destroy_map + vmap = node.op.view_map if dmap: - self.destroyers.add(app) + self.destroyers.add(rep_out) if self.algo == "fast": - self.fast_destroy(fgraph, app, reason) + self.fast_destroy(fgraph, node, reason) # add this symbol to the forward and backward maps for o_idx, i_idx_list in vmap.items(): if len(i_idx_list) > 1: raise NotImplementedError( - "destroying this output invalidates multiple inputs", (app.op) + "destroying this output invalidates multiple inputs", (node.op) ) - o = app.outputs[o_idx] - i = app.inputs[i_idx_list[0]] + o = node.outputs[o_idx] + i = node.inputs[i_idx_list[0]] self.view_i[o] = i self.view_o.setdefault(i, OrderedSet()).add(o) - # update self.clients - for i, input in enumerate(app.inputs): - self.clients.setdefault(input, OrderedDict()).setdefault(app, 0) - self.clients[input][app] += 1 + for i, input in enumerate(node.inputs): + self.clients.setdefault(input, OrderedDict()).setdefault(rep_out, 0) + self.clients[input][rep_out] += 1 - for i, output in enumerate(app.outputs): + for i, output in enumerate(node.outputs): self.clients.setdefault(output, OrderedDict()) self.stale_droot = True - def on_prune(self, fgraph, app, reason): + def on_prune(self, fgraph, node, reason): """ Remove Apply instance from set which must be computed. """ - if app not in self.debug_all_apps: - raise ProtocolError("prune without import") - self.debug_all_apps.remove(app) + # Choose an output to represent the `Apply` node + rep_out = node.outputs[0] + + assert rep_out in self.debug_all_apps + + self.debug_all_apps.remove(rep_out) - # UPDATE self.clients - for input in set(app.inputs): - del self.clients[input][app] + for input in set(node.inputs): + del self.clients[input][rep_out] - if app.op.destroy_map: - self.destroyers.remove(app) + if node.op.destroy_map: + self.destroyers.remove(rep_out) # Note: leaving empty client dictionaries in the struct. # Why? It's a pain to remove them. I think they aren't doing any harm, they will be # deleted on_detach(). - # UPDATE self.view_i, self.view_o - for o_idx, i_idx_list in app.op.view_map.items(): + for o_idx, i_idx_list in node.op.view_map.items(): if len(i_idx_list) > 1: # destroying this output invalidates multiple inputs raise NotImplementedError() - o = app.outputs[o_idx] - i = app.inputs[i_idx_list[0]] + o = node.outputs[o_idx] + i = node.inputs[i_idx_list[0]] del self.view_i[o] @@ -571,53 +559,61 @@ def on_prune(self, fgraph, app, reason): del self.view_o[i] self.stale_droot = True - if app in self.fail_validate: - del self.fail_validate[app] + if rep_out in self.fail_validate: + del self.fail_validate[rep_out] - def on_change_input(self, fgraph, app, i, old_r, new_r, reason): - """ - app.inputs[i] changed from old_r to new_r. + def on_change_input( + self, fgraph, old_node, new_node, input_idx, old_var, new_var, reason + ): + """Update the clients and view mappings.""" - """ - if app == "output": + if old_node != "output": # app == 'output' is special key that means FunctionGraph is redefining which nodes are being # considered 'outputs' of the graph. - pass - else: - if app not in self.debug_all_apps: - raise ProtocolError("change without import") - - # UPDATE self.clients - self.clients[old_r][app] -= 1 - if self.clients[old_r][app] == 0: - del self.clients[old_r][app] - - self.clients.setdefault(new_r, OrderedDict()).setdefault(app, 0) - self.clients[new_r][app] += 1 - - # UPDATE self.view_i, self.view_o - for o_idx, i_idx_list in app.op.view_map.items(): - if len(i_idx_list) > 1: - # destroying this output invalidates multiple inputs - raise NotImplementedError() + + # Use the first output to represent the `Apply` node. + # N.B. The old node's outputs should be the same as the new node's + # outputs. + rep_out = old_node.outputs[0] + + assert rep_out in self.debug_all_apps + + new_count = self.clients[old_var][rep_out] - 1 + + assert new_count >= 0 + + if new_count == 0: + del self.clients[old_var][rep_out] + else: + self.clients[old_var][rep_out] = new_count + + self.clients.setdefault(new_var, OrderedDict()).setdefault(rep_out, 0) + self.clients[new_var][rep_out] += 1 + + for o_idx, i_idx_list in new_node.op.view_map.items(): + + # Destroying this output would invalidate multiple inputs, and + # that's not currently supported + assert len(i_idx_list) == 1 + i_idx = i_idx_list[0] - output = app.outputs[o_idx] - if i_idx == i: - if app.inputs[i_idx] is not new_r: - raise ProtocolError("wrong new_r on change") + output = new_node.outputs[o_idx] + if i_idx == input_idx: + assert new_node.inputs[i_idx] is new_var - self.view_i[output] = new_r + self.view_i[output] = new_var - self.view_o[old_r].remove(output) - if not self.view_o[old_r]: - del self.view_o[old_r] + self.view_o[old_var].remove(output) + if not self.view_o[old_var]: + del self.view_o[old_var] - self.view_o.setdefault(new_r, OrderedSet()).add(output) + self.view_o.setdefault(new_var, OrderedSet()).add(output) if self.algo == "fast": - if app in self.fail_validate: - del self.fail_validate[app] - self.fast_destroy(fgraph, app, reason) + if rep_out in self.fail_validate: + del self.fail_validate[rep_out] + self.fast_destroy(fgraph, old_node, reason) + self.stale_droot = True def validate(self, fgraph): @@ -632,7 +628,7 @@ def validate(self, fgraph): if self.destroyers: if self.algo == "fast": if self.fail_validate: - app_err_pairs = self.fail_validate + rep_out_err_pairs = self.fail_validate self.fail_validate = OrderedDict() # self.fail_validate can only be a hint that maybe/probably # there is a cycle.This is because inside replace() we could @@ -641,12 +637,14 @@ def validate(self, fgraph): # graph might have already changed when we raise the # self.fail_validate error. So before raising the error, we # double check here. - for app in app_err_pairs: + for rep_out in rep_out_err_pairs: + app = rep_out.owner if app in fgraph.apply_nodes: self.fast_destroy(fgraph, app, "validate") + if self.fail_validate: - self.fail_validate = app_err_pairs - raise app_err_pairs[app] + self.fail_validate = rep_out_err_pairs + raise rep_out_err_pairs[rep_out] else: ords = self.orderings(fgraph, ordered=False) if _contains_cycle(fgraph, ords): @@ -700,13 +698,14 @@ def orderings(self, fgraph, ordered=True): ) # add destroyed variable clients as computational dependencies - for app in self.destroyers: + for rep_out in self.destroyers: + destroyer_node = rep_out.owner # keep track of clients that should run before the current Apply root_clients = set_type() # for each destroyed input... - for output_idx, input_idx_list in app.op.destroy_map.items(): + for output_idx, input_idx_list in destroyer_node.op.destroy_map.items(): destroyed_idx = input_idx_list[0] - destroyed_variable = app.inputs[destroyed_idx] + destroyed_variable = destroyer_node.inputs[destroyed_idx] root = droot[destroyed_variable] root_impact = impact[root] # we generally want to put all clients of things which depend on root @@ -744,27 +743,29 @@ def orderings(self, fgraph, ordered=True): # CHECK FOR INPUT ALIASING # OPT: pre-compute this on import - tolerate_same = getattr(app.op, "destroyhandler_tolerate_same", []) + tolerate_same = getattr( + destroyer_node.op, "destroyhandler_tolerate_same", [] + ) assert isinstance(tolerate_same, list) tolerated = { idx1 for idx0, idx1 in tolerate_same if idx0 == destroyed_idx } tolerated.add(destroyed_idx) tolerate_aliased = getattr( - app.op, "destroyhandler_tolerate_aliased", [] + destroyer_node.op, "destroyhandler_tolerate_aliased", [] ) assert isinstance(tolerate_aliased, list) ignored = { idx1 for idx0, idx1 in tolerate_aliased if idx0 == destroyed_idx } - for i, input in enumerate(app.inputs): + for i, input in enumerate(destroyer_node.inputs): if i in ignored: continue if input in root_impact and ( i not in tolerated or input is not destroyed_variable ): raise InconsistencyError( - f"Input aliasing: {app} ({destroyed_idx}, {i})" + f"Input aliasing: {destroyer_node} ({destroyed_idx}, {i})" ) # add the rule: app must be preceded by all other Apply instances that @@ -777,8 +778,8 @@ def orderings(self, fgraph, ordered=True): # app itself is a client of the destroyed inputs, # but should not run before itself - root_clients.remove(app) + root_clients.remove(rep_out) if root_clients: - rval[app] = root_clients + rval[destroyer_node] = root_clients return rval diff --git a/aesara/graph/features.py b/aesara/graph/features.py index 73a625409f..af33468438 100644 --- a/aesara/graph/features.py +++ b/aesara/graph/features.py @@ -5,6 +5,7 @@ from collections import OrderedDict from functools import partial from io import StringIO +from typing import TYPE_CHECKING, Mapping, Optional, Sequence import numpy as np @@ -14,6 +15,11 @@ from aesara.graph.utils import InconsistencyError +if TYPE_CHECKING: + from aesara.graph.basic import Apply + from aesara.graph.fg import FunctionGraph + + class AlreadyThere(Exception): """ Raised by a Feature's on_attach callback method if the FunctionGraph @@ -262,31 +268,31 @@ class Feature: """ - def on_attach(self, fgraph): - """ + def on_attach(self, fgraph) -> None: + """Handle the association of an `FunctionGraph` with this `Feature`. + Called by `FunctionGraph.attach_feature`, the method that attaches the feature to the `FunctionGraph`. Since this is called after the `FunctionGraph` is initially populated, this is where you should run checks on the initial contents of the `FunctionGraph`. - The on_attach method may raise the `AlreadyThere` exception to cancel - the attach operation if it detects that another Feature instance - implementing the same functionality is already attached to the + This method may raise an `AlreadyThere` exception to cancel the + attachment operation, e.g. if it detects that another `Feature` + instance implementing the same functionality is already attached to the `FunctionGraph`. - The feature has great freedom in what it can do with the `fgraph`: it - may, for example, add methods to it dynamically. - """ - def on_detach(self, fgraph): + def on_detach(self, fgraph: "FunctionGraph") -> None: """ Called by `FunctionGraph.remove_feature`. Should remove any dynamically-added functionality that it installed into the fgraph. """ - def on_import(self, fgraph, node, reason): + def on_import( + self, fgraph: "FunctionGraph", node: "Apply", reason: Optional[str] = None + ) -> None: """ Called whenever a node is imported into `fgraph`, which is just before the node is actually connected to the graph. @@ -297,36 +303,55 @@ def on_import(self, fgraph, node, reason): """ - def on_change_input(self, fgraph, node, i, var, new_var, reason=None): - """ - Called whenever ``node.inputs[i]`` is changed from `var` to `new_var`. - At the moment the callback is done, the change has already taken place. + def on_change_input( + self, + fgraph: "FunctionGraph", + old_node: "Apply", + new_node: "Apply", + i: int, + old_var: Variable, + new_var: Variable, + reason: Optional[str] = None, + ) -> None: + """Handle node and input replacements. - If you raise an exception in this function, the state of the graph - might be broken for all intents and purposes. + This is called whenever ``node.inputs[i]`` is changed from `old_var` to + `new_var`, and, since `Apply` nodes represent a distinct set of inputs, + a new node is created to replace the old one. - """ + When this method is called, the change has already been made. + + Warning: If an exception is raised in this function, the state of the + graph could become invalid. - def on_prune(self, fgraph, node, reason): """ + + def on_prune( + self, fgraph: "FunctionGraph", node: "Apply", reason: Optional[str] = None + ) -> None: + """Handle removal of an `Apply` node. + Called whenever a node is pruned (removed) from the `fgraph`, after it is disconnected from the graph. """ - def orderings(self, fgraph): - """ - Called by `FunctionGraph.toposort`. It should return a dictionary of + def orderings(self, fgraph: "FunctionGraph") -> Mapping["Apply", Sequence["Apply"]]: + """Return a dictionary mapping nodes to their predecessors. + + It should return a dictionary of ``{node: predecessors}`` where ``predecessors`` is a list of nodes that should be computed before the key node. - If you raise an exception in this function, the state of the graph - might be broken for all intents and purposes. + This is called by `FunctionGraph.toposort`. + + Warning: If an exception is raised in this function, the state of the + graph could become invalid. """ return OrderedDict() - def clone(self): + def clone(self) -> "Feature": """Create a clone that can be attached to a new `FunctionGraph`. This default implementation returns `self`, which carries the @@ -361,16 +386,23 @@ def __call__(self): class LambdaExtract: - def __init__(self, fgraph, node, i, r, reason=None): + """A class that represents `change_node_input` calls.""" + + def __init__(self, fgraph, old_node, new_node, i, old_var, reason=None): self.fgraph = fgraph - self.node = node + self.old_node = old_node + self.new_node = new_node self.i = i - self.r = r + self.old_var = old_var self.reason = reason def __call__(self): return self.fgraph.change_node_input( - self.node, self.i, self.r, reason=("Revert", self.reason), check=False + self.new_node, + self.i, + self.old_var, + reason=f"Revert: {self.reason}", + check=False, ) @@ -417,11 +449,13 @@ def on_detach(self, fgraph): del fgraph.revert del self.history[fgraph] - def on_change_input(self, fgraph, node, i, r, new_r, reason=None): + def on_change_input( + self, fgraph, old_node, new_node, i, old_var, new_var, reason=None + ): if self.history[fgraph] is None: return h = self.history[fgraph] - h.append(LambdaExtract(fgraph, node, i, r, reason)) + h.append(LambdaExtract(fgraph, old_node, new_node, i, old_var, reason)) def revert(self, fgraph, checkpoint): """ @@ -742,9 +776,13 @@ def on_prune(self, fgraph, node, reason): if self.active: print(f"-- pruning: {node}, reason: {reason}") - def on_change_input(self, fgraph, node, i, r, new_r, reason=None): + def on_change_input( + self, fgraph, old_node, new_node, i, old_var, new_var, reason=None + ): if self.active: - print(f"-- changing ({node}.inputs[{i}]) from {r} to {new_r}") + print( + f"-- changing {old_node}.inputs[{i}] from {old_var} to {new_var} resulting in {new_node}" + ) class PreserveVariableAttributes(Feature): @@ -752,14 +790,16 @@ class PreserveVariableAttributes(Feature): This preserve some variables attributes and tag during optimization. """ - def on_change_input(self, fgraph, node, i, r, new_r, reason=None): - if r.name is not None and new_r.name is None: - new_r.name = r.name + def on_change_input( + self, fgraph, old_node, new_node, i, old_var, new_var, reason=None + ): + if old_var.name is not None and new_var.name is None: + new_var.name = old_var.name if ( - getattr(r.tag, "nan_guard_mode_check", False) - and getattr(new_r.tag, "nan_guard_mode_check", False) is False + getattr(old_var.tag, "nan_guard_mode_check", False) + and getattr(new_var.tag, "nan_guard_mode_check", False) is False ): - new_r.tag.nan_guard_mode_check = r.tag.nan_guard_mode_check + new_var.tag.nan_guard_mode_check = old_var.tag.nan_guard_mode_check class NoOutputFromInplace(Feature): diff --git a/aesara/graph/fg.py b/aesara/graph/fg.py index 56c999f871..57800a3457 100644 --- a/aesara/graph/fg.py +++ b/aesara/graph/fg.py @@ -206,7 +206,11 @@ def add_client(self, var: Variable, new_client: ClientType) -> None: raise TypeError( 'The first entry of `new_client` must be an `Apply` node or the string `"output"`' ) - self.clients[var].append(new_client) + var_clients = self.clients[var] + # TODO: This might be another reason to use a type like + # `Dict[Variable, Set[Tuple[Apply, int]]]` for `FeatureGraph.clients` + if new_client not in var_clients: + var_clients.append(new_client) def remove_client( self, @@ -412,15 +416,15 @@ def change_node_input( reason: Optional[str] = None, import_missing: bool = False, check: bool = True, - ) -> None: - """Change ``node.inputs[i]`` to `new_var`. + ) -> Optional[Apply]: + """Create a clone of `node` in which ``node.inputs[i]`` is equal to `new_var`. ``new_var.type.is_super(old_var.type)`` must be ``True``, where ``old_var`` is the current value of ``node.inputs[i]`` which we want to replace. - For each feature that has an `on_change_input` method, this method calls: - ``feature.on_change_input(function_graph, node, i, old_var, new_var, reason)`` + For each feature that has an `Feature.on_change_input` method, this method calls: + ``feature.on_change_input(function_graph, old_node, new_node, i, old_var, new_var, reason)`` Parameters ---------- @@ -440,35 +444,129 @@ def change_node_input( `History` `Feature`, which needs to revert types that have been narrowed and would otherwise fail this check. """ - # TODO: ERROR HANDLING FOR LISTENERS (should it complete the change or revert it?) - if node == "output": - r = self.outputs[i] - if check and not r.type.is_super(new_var.type): + + is_output = node == "output" + + if is_output: + old_var = self.outputs[i] + + if old_var is new_var: + return None + + if check and not old_var.type.is_super(new_var.type): raise TypeError( f"The type of the replacement ({new_var.type}) must be " - f"compatible with the type of the original Variable ({r.type})." + f"compatible with the type of the original Variable ({old_var.type})." ) + self.outputs[i] = new_var + new_node: Optional[Apply] = new_var.owner + + self.import_var(new_var, reason=reason, import_missing=import_missing) + self.add_client(new_var, (node, i)) + self.remove_client(old_var, (node, i), reason=reason) + self.execute_callbacks( + "on_change_input", node, node, i, old_var, new_var, reason=reason + ) else: assert isinstance(node, Apply) - r = node.inputs[i] - if check and not r.type.is_super(new_var.type): + old_var = node.inputs[i] + + if old_var is new_var: + return None + + if check and not old_var.type.is_super(new_var.type): raise TypeError( f"The type of the replacement ({new_var.type}) must be " - f"compatible with the type of the original Variable ({r.type})." + f"compatible with the type of the original Variable ({old_var.type})." ) - node.inputs[i] = new_var - if r is new_var: - return + self.import_var(new_var, reason=reason, import_missing=import_missing) + + # In this case, we need to construct a new `Apply` node with + # `node.inputs[i] = new_var` + new_inputs = list(node.inputs) + new_inputs[i] = new_var + + # By passing `node.outputs` we're assigning those variables + # to this new node (i.e. by resetting `Variable.owner`). + # TODO: Perhaps a `change_owner` callback would be suitable. + new_node = Apply(node.op, new_inputs, node.outputs) + + old_outputs = new_node.outputs + new_node.outputs = node.outputs + + # This is just a sanity check + assert all(o.owner is new_node for o in node.outputs) + + # Next, we need to swap the old `node` with `new_node` in + # `FunctionGraph.clients`, as well as remove any now unused + # nodes and variables induced by the replacement itself. + + if new_node in self.apply_nodes: + # In this case, `new_node` isn't actually new to the graph, so + # all the entries connecting `new_node.inputs` to `new_node` + # are already present in `FunctionGraph.clients`. All we need + # to do is replace references to `new_node.outputs` (i.e. the + # pre-existing node) with `node.outputs`. + for old_out in old_outputs: + for o_node, o_i in self.clients[old_out]: + self.apply_nodes.remove( + o_node if o_node != "output" else self.outputs[o_i].owner + ) + + del self.clients[old_out] + self.variables.remove(old_out) + + else: + self.apply_nodes.add(new_node) + # self._import_node(new_node, reason=reason) + + self.add_client(new_var, (new_node, i)) + + # We need to replace all client references to the old node with the + # new node + for j, inp in enumerate(node.inputs): + if j != i: + self.add_client(inp, (new_node, j)) + # The old variable and node needs to be removed + self.remove_client( + inp, (node, j), reason=reason, remove_if_empty=True + ) + + # TODO: If we know that no intermediate nodes need to be + # removed, then we could perform the node replacements much + # more efficiently + # old_clients = self.clients[inp] + # # TODO: Were the clients list a `dict` mapping nodes to input + # # positions, we could simplify this considerably. + # for k, (client_, input_id) in enumerate(old_clients): + # # client = self.outputs[input_id] if client_ == "output" else client_ + # if client_ == node: + # old_clients[k] = (new_node, input_id) + + self.apply_nodes.remove(node) + + if not hasattr(node.tag, "removed_by"): + node.tag.removed_by = [] + + node.tag.removed_by.append(str(reason)) + + # This is here to simulate the old behavior + self.execute_callbacks("on_prune", node, reason) + self.execute_callbacks("on_import", new_node, reason) + + self.execute_callbacks( + "on_change_input", + node, + new_node, + i, + old_var, + new_var, + reason=reason, + ) - self.import_var(new_var, reason=reason, import_missing=import_missing) - self.add_client(new_var, (node, i)) - self.remove_client(r, (node, i), reason=reason) - # Precondition: the substitution is semantically valid However it may - # introduce cycles to the graph, in which case the transaction will be - # reverted later. - self.execute_callbacks("on_change_input", node, i, r, new_var, reason=reason) + return new_node def replace( self, @@ -532,10 +630,16 @@ def replace( f"test value. Original: {tval_shape}, new: {new_tval_shape}" ) + new_nodes: Dict[ApplyOrOutput, ApplyOrOutput] = {} for node, i in list(self.clients[var]): - self.change_node_input( - node, i, new_var, reason=reason, import_missing=import_missing + new_node = self.change_node_input( + new_nodes.get(node, node), + i, + new_var, + reason=reason, + import_missing=import_missing, ) + new_nodes[node] = new_node or node def replace_all(self, pairs: Iterable[Tuple[Variable, Variable]], **kwargs) -> None: """Replace variables in the `FunctionGraph` according to ``(var, new_var)`` pairs in a list.""" diff --git a/aesara/graph/rewriting/basic.py b/aesara/graph/rewriting/basic.py index 135581820b..ac421e68d4 100644 --- a/aesara/graph/rewriting/basic.py +++ b/aesara/graph/rewriting/basic.py @@ -33,7 +33,7 @@ from aesara.graph.features import AlreadyThere, Feature, NodeFinder from aesara.graph.fg import FunctionGraph from aesara.graph.op import Op -from aesara.graph.utils import AssocList, InconsistencyError +from aesara.graph.utils import InconsistencyError from aesara.misc.ordered_set import OrderedSet from aesara.utils import flatten @@ -531,8 +531,7 @@ def on_attach(self, fgraph): fgraph.merge_feature = self self.seen_atomics = set() - self.atomic_sig = AssocList() - self.atomic_sig_inv = AssocList() + self.canonical_atomics = {} # For all Apply nodes # Set of distinct (not mergeable) nodes @@ -562,15 +561,15 @@ def on_attach(self, fgraph): def clone(self): return type(self)() - def on_change_input(self, fgraph, node, i, r, new_r, reason): - if node in self.nodes_seen: + def on_change_input(self, fgraph, old_node, new_node, i, old_var, new_var, reason): + if old_node in self.nodes_seen: # If inputs to a node change, it's not guaranteed that the node is # distinct from the other nodes in `self.nodes_seen`. - self.nodes_seen.discard(node) - self.process_node(fgraph, node) + self.nodes_seen.discard(old_node) + self.process_node(fgraph, new_node) - if isinstance(new_r, AtomicVariable): - self.process_atomic(fgraph, new_r) + if isinstance(new_var, AtomicVariable): + self.process_atomic(fgraph, new_var) def on_import(self, fgraph, node, reason): for c in node.inputs: @@ -586,17 +585,14 @@ def on_prune(self, fgraph, node, reason): for c in node.inputs: if isinstance(c, AtomicVariable) and len(fgraph.clients[c]) <= 1: # This was the last node using this constant - sig = self.atomic_sig[c] - self.atomic_sig.discard(c) - self.atomic_sig_inv.discard(sig) + self.canonical_atomics.pop(c) self.seen_atomics.discard(id(c)) def process_atomic(self, fgraph, c): """Check if an atomic `c` can be merged, and queue that replacement.""" if id(c) in self.seen_atomics: return - sig = c.merge_signature() - other_c = self.atomic_sig_inv.get(sig, None) + other_c = self.canonical_atomics.get(c, None) if other_c is not None: # multiple names will clobber each other.. # we adopt convention to keep the last name @@ -605,8 +601,7 @@ def process_atomic(self, fgraph, c): self.scheduled.append([[(c, other_c, "merge")]]) else: # this is a new constant - self.atomic_sig[c] = sig - self.atomic_sig_inv[sig] = c + self.canonical_atomics[c] = c self.seen_atomics.add(id(c)) def process_node(self, fgraph, node): @@ -1662,9 +1657,9 @@ def on_prune(self, fgraph, node, reason): if self.pruner: self.pruner(node) - def on_change_input(self, fgraph, node, i, r, new_r, reason): + def on_change_input(self, fgraph, old_node, new_node, i, old_var, new_var, reason): if self.chin: - self.chin(node, i, r, new_r, reason) + self.chin(old_node, new_node, i, old_var, new_var, reason) def on_detach(self, fgraph): # To allow pickling this object @@ -1798,7 +1793,7 @@ def attach_updater( if self.ignore_newtrees: importer = None - if importer is None and pruner is None: + if importer is None and pruner is None and chin is None: return None u = DispatchingFeature(importer, pruner, chin, name=name) @@ -1909,7 +1904,7 @@ def process_node( return False try: fgraph.replace_all_validate_remove( # type: ignore - repl_pairs, reason=node_rewriter, remove=remove + repl_pairs, reason=str(node_rewriter), remove=remove ) return True except Exception as e: @@ -1966,8 +1961,11 @@ def importer(node): if node is not current_node: q.append(node) + def change_input(old_node, new_node, i, old_var, new_var, reason): + q.append(new_node) + u = self.attach_updater( - fgraph, importer, None, name=getattr(self, "name", None) + fgraph, importer, None, chin=change_input, name=getattr(self, "name", None) ) nb = 0 try: @@ -2108,12 +2106,19 @@ def apply(self, fgraph): q = list(fgraph.get_nodes(op)) def importer(node): - if node is not current_node: - if node.op == op: - q.append(node) + if node is not current_node and node.op == op: + q.append(node) + + def change_input(old_node, new_node, i, r, new_r, reason): + if ( + node is not current_node + and isinstance(new_node, Apply) + and new_node.op == op + ): + q.append(new_node) u = self.attach_updater( - fgraph, importer, None, name=getattr(self, "name", None) + fgraph, importer, None, chin=change_input, name=getattr(self, "name", None) ) try: while q: @@ -2142,7 +2147,7 @@ def on_import(self, fgraph, node, reason): self.nb_imported += 1 self.changed = True - def on_change_input(self, fgraph, node, i, r, new_r, reason): + def on_change_input(self, fgraph, old_node, new_node, i, old_var, new_var, reason): self.changed = True def reset(self): @@ -2357,15 +2362,24 @@ def importer(node): if node is not current_node: q.append(node) - chin = None if self.tracks_on_change_inputs: - def chin(node, i, r, new_r, reason): - if node is not current_node and not isinstance(node, str): - q.append(node) + def change_input(old_node, new_node, i, r, new_r, reason): + if old_node is not current_node and isinstance(new_node, Apply): + q.append(new_node) + + else: + + def change_input(old_node, new_node, i, r, new_r, reason): + if isinstance(new_node, Apply): + q.append(new_node) u = self.attach_updater( - fgraph, importer, None, chin=chin, name=getattr(self, "name", None) + fgraph, + importer, + None, + chin=change_input, + name=getattr(self, "name", None), ) try: while q: diff --git a/aesara/link/c/basic.py b/aesara/link/c/basic.py index 8aed25cd13..f601404086 100644 --- a/aesara/link/c/basic.py +++ b/aesara/link/c/basic.py @@ -1416,15 +1416,13 @@ def in_sig(i, topological_pos, i_idx): # yield a 'position' that reflects its role in code_gen() if isinstance(i, AtomicVariable): # orphans if id(i) not in constant_ids: - isig = (i.signature(), topological_pos, i_idx) + isig = (hash(i), topological_pos, i_idx) # If the Aesara constant provides a strong hash # (no collision for transpose, 2, 1, 0, -1, -2, # 2 element swapped...) we put this hash in the signature # instead of the value. This makes the key file much # smaller for big constant arrays. Before this, we saw key # files up to 80M. - if hasattr(isig[0], "aesara_hash"): - isig = (isig[0].aesara_hash(), topological_pos, i_idx) try: hash(isig) except Exception: diff --git a/aesara/link/c/params_type.py b/aesara/link/c/params_type.py index 928b92ed2c..6e4007ef1c 100644 --- a/aesara/link/c/params_type.py +++ b/aesara/link/c/params_type.py @@ -291,9 +291,11 @@ def __hash__(self): # NB: For writing, we must bypass setattr() which is always called by default by Python. self.__dict__["__signatures__"] = tuple( # NB: Params object should have been already filtered. - self.__params_type__.types[i] - .make_constant(self[self.__params_type__.fields[i]]) - .signature() + hash( + self.__params_type__.types[i].make_constant( + self[self.__params_type__.fields[i]] + ) + ) for i in range(self.__params_type__.length) ) return hash((type(self), self.__params_type__) + self.__signatures__) diff --git a/aesara/link/c/type.py b/aesara/link/c/type.py index 33632fa1a6..e3dc2f4c9e 100644 --- a/aesara/link/c/type.py +++ b/aesara/link/c/type.py @@ -292,15 +292,7 @@ def __setstate__(self, dct): class CDataTypeConstant(Constant[T]): - def merge_signature(self): - # We don't want to merge constants that don't point to the - # same object. - return id(self.data) - - def signature(self): - # There is no way to put the data in the signature, so we - # don't even try - return (self.type,) + pass CDataType.constant_type = CDataTypeConstant diff --git a/aesara/sparse/basic.py b/aesara/sparse/basic.py index 46ac71d8ce..46bda1dc93 100644 --- a/aesara/sparse/basic.py +++ b/aesara/sparse/basic.py @@ -24,7 +24,6 @@ from aesara.link.c.type import generic from aesara.misc.safe_asarray import _asarray from aesara.sparse.type import SparseTensorType, _is_sparse -from aesara.sparse.utils import hash_from_sparse from aesara.tensor import basic as at from aesara.tensor.basic import Split from aesara.tensor.math import _conj @@ -441,35 +440,33 @@ def __repr__(self): return str(self) -class SparseConstantSignature(tuple): +class SparseConstant(TensorConstant, _sparse_py_operators): + format = property(lambda self: self.type.format) + + # def __init__(self, *args): + # .view(HashableNDArray) + def __eq__(self, other): - (a, b), (x, y) = self, other - return ( - a == x - and (b.dtype == y.dtype) - and (type(b) == type(y)) - and (b.shape == y.shape) - and (abs(b - y).sum() < 1e-6 * b.nnz) - ) + if isinstance(other, type(self)): + b = self.data + y = other.data + if ( + self.type == other.type + and (b.dtype == y.dtype) + and (type(b) == type(y)) + and (b.shape == y.shape) + and (abs(b - y).sum() < 1e-6 * b.nnz) + ): + return True + return False + + return NotImplemented def __ne__(self, other): return not self == other def __hash__(self): - (a, b) = self - return hash(type(self)) ^ hash(a) ^ hash(type(b)) - - def aesara_hash(self): - (_, d) = self - return hash_from_sparse(d) - - -class SparseConstant(TensorConstant, _sparse_py_operators): - format = property(lambda self: self.type.format) - - def signature(self): - assert self.data is not None - return SparseConstantSignature((self.type, self.data)) + return hash((type(self), self.type, self.data)) def __str__(self): return "{}{{{},{},shape={},nnz={}}}".format( diff --git a/aesara/tensor/basic.py b/aesara/tensor/basic.py index 4762d903d2..7009d60a01 100644 --- a/aesara/tensor/basic.py +++ b/aesara/tensor/basic.py @@ -228,6 +228,8 @@ def constant(x, name=None, ndim=None, dtype=None) -> TensorConstant: ttype = TensorType(dtype=x_.dtype, shape=x_.shape) + x_.setflags(write=0) + return TensorConstant(ttype, x_, name=name) diff --git a/aesara/tensor/rewriting/shape.py b/aesara/tensor/rewriting/shape.py index 87d77b1322..da0ea410d4 100644 --- a/aesara/tensor/rewriting/shape.py +++ b/aesara/tensor/rewriting/shape.py @@ -1,6 +1,6 @@ import traceback from io import StringIO -from typing import Optional +from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple from typing import cast as type_cast from warnings import warn @@ -51,6 +51,10 @@ from aesara.tensor.type_other import NoneConst +if TYPE_CHECKING: + from aesara.graph.basic import Apply + + class ShapeFeature(Feature): r"""A `Feature` that tracks shape information in a graph. @@ -366,8 +370,8 @@ def set_shape(self, r, s, override=False): assert all( not hasattr(r.type, "shape") or r.type.shape[i] != 1 - or self.lscalar_one.equals(shape_vars[i]) - or self.lscalar_one.equals(extract_constant(shape_vars[i])) + or self.lscalar_one == shape_vars[i] + or self.lscalar_one == extract_constant(shape_vars[i]) for i in range(r.type.ndim) ) self.shape_of[r] = tuple(shape_vars) @@ -508,13 +512,13 @@ def on_attach(self, fgraph): self.lscalar_one = constant(1, dtype="int64") assert self.lscalar_one.type.dtype == "int64" - self.fgraph = fgraph + self.fgraph: FunctionGraph = fgraph # Variable -> tuple(scalars) or None (All tensor vars map to tuple) - self.shape_of = {} + self.shape_of: Dict[Variable, Optional[Tuple[Variable]]] = {} # Variable -> - self.scheduled = {} + self.scheduled: Dict["Apply", Variable] = {} # shape var -> graph v - self.shape_of_reverse_index = {} + self.shape_of_reverse_index: Dict[Variable, Set[Variable]] = {} for node in fgraph.toposort(): self.on_import(fgraph, node, reason="on_attach") @@ -586,34 +590,37 @@ def on_import(self, fgraph, node, reason): for r, s in zip(node.outputs, o_shapes): self.set_shape(r, s) - def on_change_input(self, fgraph, node, i, r, new_r, reason): - if new_r not in self.shape_of: - # It happen that the fgraph didn't called on_import for some - # new_r. This happen when new_r don't have an - # owner(i.e. it is a constant or an input of the graph) - # update_shape suppose that r and new_r are in shape_of. - self.init_r(new_r) + def on_change_input(self, fgraph, old_node, new_node, i, old_var, new_var, reason): + if new_var not in self.shape_of: + # It happen that the fgraph didn't call `ShapeFeature.on_import` for some + # `new_var`. This can happen when `new_var` doesn't have an + # owner (i.e. it is a constant or an input of the graph). + # FYI: `ShapeFeature.update_shape` suppose that `old_var` and `new_var` are in shape_of. + self.init_r(new_var) - # This tells us that r and new_r must have the same shape if + # This tells us that `old_var` and `new_var` must have the same shape if # we didn't know that the shapes are related, now we do. - self.update_shape(new_r, r) + self.update_shape(new_var, old_var) - # change_input happens in two cases: - # 1) we are trying to get rid of r, or + # Let's consider two (mutually exclusive?) cases: + # 1) we are trying to get rid of `old_var`, or # 2) we are putting things back after a failed transaction. - - # In case 1, if r has a shape_i client, we will want to - # replace the shape_i of r with the shape of new_r. Say that - # r is *scheduled*. - # At that point, node is no longer a client of r, but of new_r - for (shpnode, idx) in fgraph.clients[r] + [(node, i)]: + # + # In case 1, if `old_var` has a `ShapeFeature.shape_i` client, we will want to + # replace the shape_i of `old_var` with the shape of `new_var` (i.e. we say that + # `old_var` is *scheduled*). + # + # At that point, `old_node` is no longer a client of `old_var`, and all the clients + # of `old_node` now belong to `new_node`. + + for (shpnode, idx) in fgraph.clients.get(old_var, []) + [(new_node, i)]: if isinstance(getattr(shpnode, "op", None), Shape_i): idx = shpnode.op.i - repl = self.shape_of[new_r][idx] + repl = self.shape_of[new_var][idx] if repl.owner is shpnode: - # This mean the replacement shape object is - # exactly the same as the current shape object. So - # no need for replacement. + # This means the replacement shape object is exactly the + # same as the current shape object, so no need for + # replacement. continue if ( repl.owner @@ -629,30 +636,31 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason): if shpnode.outputs[0] in ancestors([repl]): raise InconsistencyError( "This substitution would insert a cycle in the graph:" - f"node: {node}, i: {i}, r: {r}, new_r: {new_r}" + f"old_node: {old_node}, new_node: {new_node}, i: {i}, " + f"old_var: {old_var}, new_var: {new_var}" ) - self.scheduled[shpnode] = new_r - # In case 2, if r is a variable that we've scheduled for shape update, + self.scheduled[shpnode] = new_var + # In case 2, if `old_var` is a variable that we've scheduled for shape update, # then we should cancel it. - unscheduled = [k for k, v in self.scheduled.items() if v == r] + unscheduled = [k for k, v in self.scheduled.items() if v == old_var] for k in unscheduled: del self.scheduled[k] - # In either case, r could be in shape_of.values(), that is, r itself - # is the shape of something. In that case, we want to update - # the value in shape_of, to keep it up-to-date. - for v in self.shape_of_reverse_index.get(r, []): + # In either case, `old_var` could be in shape_of.values(), that is, + # `old_var` itself is the shape of something. In that case, we want to + # update the value in shape_of, to keep it up-to-date. + for v in self.shape_of_reverse_index.get(old_var, ()): # The reverse index is only approximate. It is not updated on - # deletion of variables, or on change_input so it might be the - # case that there are a few extra `v`'s in it that no longer have - # a shape of r or possibly have been deleted from shape_of - # entirely. The important thing is that it permits to recall - # all variables with r in their shape. - for ii, svi in enumerate(self.shape_of.get(v, [])): - if svi == r: - self.set_shape_i(v, ii, new_r) - self.shape_of_reverse_index[r] = set() + # deletion of variables or `Feature.on_change_input`, so it might + # be the case that there are a few extra `v`'s in it that no longer + # have a shape of `old_var` or possibly have been deleted from + # `ShapeFeature.shape_of` entirely. The important thing is that it + # permits to recall all variables with `old_var` in their shape. + for ii, svi in enumerate(self.shape_of.get(v, ())): + if svi == old_var: + self.set_shape_i(v, ii, new_var) + self.shape_of_reverse_index[old_var] = set() def same_shape( self, @@ -684,10 +692,10 @@ def same_shape( return False if dim_x is not None: - sx = [sx[dim_x]] + sx = (sx[dim_x],) if dim_y is not None: - sy = [sy[dim_y]] + sy = (sy[dim_y],) if len(sx) != len(sy): return False @@ -710,11 +718,7 @@ def same_shape( rewrite_graph(shapes_fg, custom_rewrite=topo_constant_folding), ) canon_shapes = canon_shapes_fg.outputs - - sx = canon_shapes[: len(sx)] - sy = canon_shapes[len(sx) :] - - for dx, dy in zip(sx, sy): + for dx, dy in zip(canon_shapes[: len(sx)], canon_shapes[len(sx) :]): if not equal_computations([dx], [dy]): return False diff --git a/aesara/tensor/type.py b/aesara/tensor/type.py index 5890b6e22e..74d7ae02a7 100644 --- a/aesara/tensor/type.py +++ b/aesara/tensor/type.py @@ -12,7 +12,7 @@ from aesara.graph.utils import MetaType from aesara.link.c.type import CType from aesara.misc.safe_asarray import _asarray -from aesara.utils import apply_across_args +from aesara.utils import HashableNDArray, apply_across_args if TYPE_CHECKING: @@ -64,7 +64,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape): filter_checks_isfinite = False """ When this is ``True``, strict filtering rejects data containing - ``numpy.nan`` or ``numpy.inf`` entries. (Used in `DebugMode`) + `numpy.nan` or `numpy.inf` entries. (Used in `DebugMode`) """ def __init__( @@ -253,6 +253,13 @@ def filter(self, data, strict=False, allow_downcast=None): if self.filter_checks_isfinite and not np.all(np.isfinite(data)): raise ValueError("Non-finite elements not allowed") + + if not isinstance(data, HashableNDArray): + return data.view(HashableNDArray) + + # Make sure it's read-only so that we can cache hash values and such + data.setflags(write=0) + return data def filter_variable(self, other, allow_convert=True): diff --git a/aesara/tensor/type_other.py b/aesara/tensor/type_other.py index e0c438c5e5..00c7ed3048 100644 --- a/aesara/tensor/type_other.py +++ b/aesara/tensor/type_other.py @@ -57,12 +57,25 @@ def clone(self, **kwargs): def filter(self, x, strict=False, allow_downcast=None): if isinstance(x, slice): + + if isinstance(x.start, np.ndarray): + assert str(x.start.dtype) in integer_dtypes + x = slice(x.start.item(), x.stop, x.step) + + if isinstance(x.stop, np.ndarray): + assert str(x.stop.dtype) in integer_dtypes + x = slice(x.start, x.stop.item(), x.step) + + if isinstance(x.step, np.ndarray): + assert str(x.step.dtype) in integer_dtypes + x = slice(x.start, x.stop, x.step.item()) + return x else: raise TypeError("Expected a slice!") def __str__(self): - return "slice" + return f"{type(self)}()" def __eq__(self, other): return type(self) == type(other) @@ -80,25 +93,23 @@ def may_share_memory(a, b): class SliceConstant(Constant): + @classmethod + def create_key(cls, type, data, *args, **kwargs): + return (type, data.start, data.stop, data.step) + def __init__(self, type, data, name=None): - assert isinstance(data, slice) - # Numpy ndarray aren't hashable, so get rid of them. - if isinstance(data.start, np.ndarray): - assert data.start.ndim == 0 - assert str(data.start.dtype) in integer_dtypes - data = slice(int(data.start), data.stop, data.step) - elif isinstance(data.stop, np.ndarray): - assert data.stop.ndim == 0 - assert str(data.stop.dtype) in integer_dtypes - data = slice(data.start, int(data.stop), data.step) - elif isinstance(data.step, np.ndarray): - assert data.step.ndim == 0 - assert str(data.step.dtype) in integer_dtypes - data = slice(data.start, int(data.stop), data.step) - Constant.__init__(self, type, data, name) - - def signature(self): - return (SliceConstant, self.data.start, self.data.stop, self.data.step) + super().__init__(type, data, name) + + def __eq__(self, other): + if isinstance(other, type(self)): + if self.data == other.data: + return True + return False + + return NotImplemented + + def __hash__(self): + return hash(self.data.__reduce__()) def __str__(self): return "{}{{{}, {}, {}}}".format( diff --git a/aesara/tensor/var.py b/aesara/tensor/var.py index 8b281e6bd0..3a7d3bb02d 100644 --- a/aesara/tensor/var.py +++ b/aesara/tensor/var.py @@ -1,4 +1,3 @@ -import copy import traceback as tb import warnings from collections.abc import Iterable @@ -16,7 +15,6 @@ from aesara.tensor.exceptions import AdvancedIndexingError from aesara.tensor.type import TensorType from aesara.tensor.type_other import NoneConst -from aesara.tensor.utils import hash_from_ndarray _TensorTypeType = TypeVar("_TensorTypeType", bound=TensorType) @@ -877,119 +875,6 @@ def _get_vector_length_TensorVariable(op_or_var, var): TensorType.variable_type = TensorVariable -class TensorConstantSignature(tuple): - r"""A signature object for comparing `TensorConstant` instances. - - An instance is a pair with the type ``(Type, ndarray)``. - - TODO FIXME: Subclassing `tuple` is unnecessary, and it appears to be - preventing the use of a much more convenient `__init__` that removes the - need for all these lazy computations and their safety checks. - - Also, why do we even need this signature stuff? We could simply implement - good `Constant.__eq__` and `Constant.__hash__` implementations. - - We could also produce plain `tuple`\s with hashable values. - - """ - - def __eq__(self, other): - if type(self) != type(other): - return False - try: - (t0, d0), (t1, d1) = self, other - except Exception: - return False - - # N.B. compare shape to ensure no broadcasting in == - if t0 != t1 or d0.shape != d1.shape: - return False - - self.no_nan # Ensure has_nan is computed. - # Note that in the comparisons below, the elementwise comparisons - # come last because they are the most expensive checks. - if self.has_nan: - other.no_nan # Ensure has_nan is computed. - return ( - other.has_nan - and self.sum == other.sum - and (self.no_nan.mask == other.no_nan.mask).all() - and - # Note that the second test below (==) may crash e.g. for - # a single scalar NaN value, so we do not run it when all - # values are missing. - (self.no_nan.mask.all() or (self.no_nan == other.no_nan).all()) - ) - else: - # Simple case where we do not need to worry about NaN values. - # (note that if there are NaN values in d1, this will return - # False, which is why we do not bother with testing `other.has_nan` - # here). - return (self.sum == other.sum) and np.all(d0 == d1) - - def __ne__(self, other): - return not self == other - - def __hash__(self): - t, d = self - return hash((type(self), t, d.shape, self.sum)) - - def aesara_hash(self): - _, d = self - return hash_from_ndarray(d) - - @property - def sum(self): - """Compute sum of non NaN / Inf values in the array.""" - try: - return self._sum - except AttributeError: - - # Prevent warnings when there are `inf`s and `-inf`s present - with warnings.catch_warnings(): - warnings.simplefilter("ignore", category=RuntimeWarning) - self._sum = self.no_nan.sum() - - # The following 2 lines are needed as in Python 3.3 with NumPy - # 1.7.1, numpy.ndarray and numpy.memmap aren't hashable. - if isinstance(self._sum, np.memmap): - self._sum = np.asarray(self._sum).item() - - if self.has_nan and self.no_nan.mask.all(): - # In this case the sum is not properly computed by numpy. - self._sum = 0 - - if np.isinf(self._sum) or np.isnan(self._sum): - # NaN may happen when there are both -inf and +inf values. - if self.has_nan: - # Filter both NaN and Inf values. - mask = self.no_nan.mask + np.isinf(self[1]) - else: - # Filter only Inf values. - mask = np.isinf(self[1]) - if mask.all(): - self._sum = 0 - else: - self._sum = np.ma.masked_array(self[1], mask).sum() - # At this point there should be no more NaN. - assert not np.isnan(self._sum) - - if isinstance(self._sum, np.ma.core.MaskedConstant): - self._sum = 0 - - return self._sum - - @property - def no_nan(self): - try: - return self._no_nan - except AttributeError: - nans = np.isnan(self[1]) - self._no_nan = np.ma.masked_array(self[1], nans) - self.has_nan = np.any(nans) - return self._no_nan - - def get_unique_value(x: TensorVariable) -> Optional[Number]: """Return the unique value of a tensor, if there is one""" if isinstance(x, Constant): @@ -998,7 +883,7 @@ def get_unique_value(x: TensorVariable) -> Optional[Number]: if isinstance(data, np.ndarray) and data.ndim > 0: flat_data = data.ravel() if flat_data.shape[0]: - if (flat_data == flat_data[0]).all(): + if np.all(flat_data == flat_data[0]): return flat_data[0] return None @@ -1022,6 +907,8 @@ def __init__(self, type: _TensorTypeType, data, name=None): assert not any(s is None for s in new_type.shape) + data.setflags(write=0) + Constant.__init__(self, new_type, data, name) def __str__(self): @@ -1039,31 +926,6 @@ def __str__(self): name = "TensorConstant" return "%s{%s}" % (name, val) - def signature(self): - return TensorConstantSignature((self.type, self.data)) - - def equals(self, other): - # Override Constant.equals to allow to compare with - # numpy.ndarray, and python type. - if isinstance(other, (np.ndarray, int, float)): - # Make a TensorConstant to be able to compare - other = at.basic.constant(other) - return ( - isinstance(other, TensorConstant) and self.signature() == other.signature() - ) - - def __copy__(self): - # We need to do this to remove the cached attribute - return type(self)(self.type, self.data, self.name) - - def __deepcopy__(self, memo): - # We need to do this to remove the cached attribute - return type(self)( - copy.deepcopy(self.type, memo), - copy.deepcopy(self.data, memo), - copy.deepcopy(self.name, memo), - ) - TensorType.constant_type = TensorConstant diff --git a/tests/graph/rewriting/test_basic.py b/tests/graph/rewriting/test_basic.py index d0d85030a6..0350ed28bd 100644 --- a/tests/graph/rewriting/test_basic.py +++ b/tests/graph/rewriting/test_basic.py @@ -42,7 +42,9 @@ class AssertNoChanges(Feature): """A `Feature` that raises an error when nodes are changed in a graph.""" - def on_change_input(self, fgraph, node, i, r, new_r, reason=None): + def on_change_input( + self, fgraph, old_node, new_node, i, old_var, new_var, reason=None + ): raise AssertionError() @@ -591,14 +593,9 @@ def local_rewrite_2(fgraph, node): capres = capsys.readouterr() assert capres.err == "" - assert ( - "rewriting: rewrite local_rewrite_1 replaces node Op1(x, y) with [Op2.0]" - in capres.out - ) - assert ( - "rewriting: rewrite local_rewrite_2 replaces node Op2(y, y) with [Op2.0]" - in capres.out - ) + out1, out2 = capres.out.split("\n", maxsplit=1) + assert out1.startswith("rewriting: rewrite local_rewrite_1 replaces") + assert out2.startswith("rewriting: rewrite local_rewrite_2 replaces") def test_node_rewriter_str(): diff --git a/tests/graph/test_basic.py b/tests/graph/test_basic.py index cdd362b00b..aae868b259 100644 --- a/tests/graph/test_basic.py +++ b/tests/graph/test_basic.py @@ -8,6 +8,7 @@ from aesara import tensor as at from aesara.graph.basic import ( Apply, + Constant, NominalVariable, Variable, ancestors, @@ -41,10 +42,14 @@ ) from aesara.tensor.type_other import NoneConst from aesara.tensor.var import TensorVariable +from aesara.utils import HashableNDArray from tests import unittest_tools as utt from tests.graph.utils import MyInnerGraphOp +pytestmark = pytest.mark.filterwarnings("error") + + class MyType(Type): def __init__(self, thingy): self.thingy = thingy @@ -84,7 +89,7 @@ def perform(self, *args, **kwargs): raise NotImplementedError("No Python implementation available.") -MyOp = MyOp() +my_op = MyOp() def leaf_formatter(leaf): @@ -107,29 +112,29 @@ def format_graph(inputs, outputs): class TestStr: def test_as_string(self): r1, r2 = MyVariable(1), MyVariable(2) - node = MyOp.make_node(r1, r2) + node = my_op.make_node(r1, r2) s = format_graph([r1, r2], node.outputs) assert s == ["MyOp(R1, R2)"] def test_as_string_deep(self): r1, r2, r5 = MyVariable(1), MyVariable(2), MyVariable(5) - node = MyOp.make_node(r1, r2) - node2 = MyOp.make_node(node.outputs[0], r5) + node = my_op.make_node(r1, r2) + node2 = my_op.make_node(node.outputs[0], r5) s = format_graph([r1, r2, r5], node2.outputs) assert s == ["MyOp(MyOp(R1, R2), R5)"] def test_multiple_references(self): r1, r2, r5 = MyVariable(1), MyVariable(2), MyVariable(5) - node = MyOp.make_node(r1, r2) - node2 = MyOp.make_node(node.outputs[0], node.outputs[0]) + node = my_op.make_node(r1, r2) + node2 = my_op.make_node(node.outputs[0], node.outputs[0]) assert format_graph([r1, r2, r5], node2.outputs) == [ "MyOp(*1 -> MyOp(R1, R2), *1)" ] def test_cutoff(self): r1, r2 = MyVariable(1), MyVariable(2) - node = MyOp.make_node(r1, r2) - node2 = MyOp.make_node(node.outputs[0], node.outputs[0]) + node = my_op.make_node(r1, r2) + node2 = my_op.make_node(node.outputs[0], node.outputs[0]) assert format_graph(node.outputs, node2.outputs) == ["MyOp(R3, R3)"] assert format_graph(node2.inputs, node2.outputs) == ["MyOp(R3, R3)"] @@ -137,43 +142,27 @@ def test_cutoff(self): class TestClone: def test_accurate(self): r1, r2 = MyVariable(1), MyVariable(2) - node = MyOp.make_node(r1, r2) - _, new = clone([r1, r2], node.outputs, False) + node = my_op.make_node(r1, r2) + _, new = clone([r1, r2], node.outputs, copy_inputs=False) assert format_graph([r1, r2], new) == ["MyOp(R1, R2)"] def test_copy(self): r1, r2, r5 = MyVariable(1), MyVariable(2), MyVariable(5) - node = MyOp.make_node(r1, r2) - node2 = MyOp.make_node(node.outputs[0], r5) - _, new = clone([r1, r2, r5], node2.outputs, False) - assert ( - node2.outputs[0].type == new[0].type and node2.outputs[0] is not new[0] - ) # the new output is like the old one but not the same object - assert node2 is not new[0].owner # the new output has a new owner + node = my_op.make_node(r1, r2) + node2 = my_op.make_node(node.outputs[0], r5) + _, new = clone([r1, r2, r5], node2.outputs, copy_inputs=False) + assert node2.outputs[0].type == new[0].type and node2.outputs[0] is new[0] + assert node2 is new[0].owner assert new[0].owner.inputs[1] is r5 # the inputs are not copied assert ( new[0].owner.inputs[0].type == node.outputs[0].type - and new[0].owner.inputs[0] is not node.outputs[0] - ) # check that we copied deeper too - - def test_not_destructive(self): - # Checks that manipulating a cloned graph leaves the original unchanged. - r1, r2, r5 = MyVariable(1), MyVariable(2), MyVariable(5) - node = MyOp.make_node(MyOp.make_node(r1, r2).outputs[0], r5) - _, new = clone([r1, r2, r5], node.outputs, False) - new_node = new[0].owner - new_node.inputs = [MyVariable(7), MyVariable(8)] - assert format_graph(graph_inputs(new_node.outputs), new_node.outputs) == [ - "MyOp(R7, R8)" - ] - assert format_graph(graph_inputs(node.outputs), node.outputs) == [ - "MyOp(MyOp(R1, R2), R5)" - ] + and new[0].owner.inputs[0] is node.outputs[0] + ) def test_constant(self): r1, r2, r5 = MyVariable(1), MyVariable(2), MyVariable(5) - node = MyOp.make_node(MyOp.make_node(r1, r2).outputs[0], r5) - _, new = clone([r1, r2, r5], node.outputs, False) + node = my_op.make_node(my_op.make_node(r1, r2).outputs[0], r5) + _, new = clone([r1, r2, r5], node.outputs, copy_inputs=False) new_node = new[0].owner new_node.inputs = [MyVariable(7), MyVariable(8)] c1 = at.constant(1.5) @@ -192,13 +181,13 @@ def test_constant(self): def test_clone_inner_graph(self): r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3) - o1 = MyOp(r1, r2) + o1 = my_op(r1, r2) o1.name = "o1" # Inner graph igo_in_1 = MyVariable(4) igo_in_2 = MyVariable(5) - igo_out_1 = MyOp(igo_in_1, igo_in_2) + igo_out_1 = my_op(igo_in_1, igo_in_2) igo_out_1.name = "igo1" igo = MyInnerGraphOp([igo_in_1, igo_in_2], [igo_out_1]) @@ -209,8 +198,8 @@ def test_clone_inner_graph(self): o2_node = o2.owner o2_node_clone = o2_node.clone(clone_inner_graph=True) - assert o2_node_clone is not o2_node - assert o2_node_clone.op.fgraph is not o2_node.op.fgraph + assert o2_node_clone is o2_node + assert o2_node_clone.op.fgraph is o2_node.op.fgraph assert equal_computations( o2_node_clone.op.fgraph.outputs, o2_node.op.fgraph.outputs ) @@ -228,9 +217,9 @@ class TestToposort: def test_simple(self): # Test a simple graph r1, r2, r5 = MyVariable(1), MyVariable(2), MyVariable(5) - o = MyOp(r1, r2) + o = my_op(r1, r2) o.name = "o1" - o2 = MyOp(o, r5) + o2 = my_op(o, r5) o2.name = "o2" clients = {} @@ -257,49 +246,50 @@ def test_simple(self): def test_double_dependencies(self): # Test a graph with double dependencies r1, r5 = MyVariable(1), MyVariable(5) - o = MyOp.make_node(r1, r1) - o2 = MyOp.make_node(o.outputs[0], r5) + o = my_op.make_node(r1, r1) + o2 = my_op.make_node(o.outputs[0], r5) all = general_toposort(o2.outputs, prenode) assert all == [r5, r1, o, o.outputs[0], o2, o2.outputs[0]] def test_inputs_owners(self): # Test a graph where the inputs have owners r1, r5 = MyVariable(1), MyVariable(5) - o = MyOp.make_node(r1, r1) + o = my_op.make_node(r1, r1) r2b = o.outputs[0] - o2 = MyOp.make_node(r2b, r2b) + o2 = my_op.make_node(r2b, r2b) all = io_toposort([r2b], o2.outputs) assert all == [o2] - o2 = MyOp.make_node(r2b, r5) + o2 = my_op.make_node(r2b, r5) all = io_toposort([r2b], o2.outputs) assert all == [o2] def test_not_connected(self): # Test a graph which is not connected r1, r2, r3, r4 = MyVariable(1), MyVariable(2), MyVariable(3), MyVariable(4) - o0 = MyOp.make_node(r1, r2) - o1 = MyOp.make_node(r3, r4) + o0 = my_op.make_node(r1, r2) + o1 = my_op.make_node(r3, r4) all = io_toposort([r1, r2, r3, r4], o0.outputs + o1.outputs) assert all == [o1, o0] or all == [o0, o1] def test_io_chain(self): # Test inputs and outputs mixed together in a chain graph r1, r2 = MyVariable(1), MyVariable(2) - o0 = MyOp.make_node(r1, r2) - o1 = MyOp.make_node(o0.outputs[0], r1) + o0 = my_op.make_node(r1, r2) + o1 = my_op.make_node(o0.outputs[0], r1) all = io_toposort([r1, o0.outputs[0]], [o0.outputs[0], o1.outputs[0]]) assert all == [o1] def test_outputs_clients(self): # Test when outputs have clients r1, r2, r4 = MyVariable(1), MyVariable(2), MyVariable(4) - o0 = MyOp.make_node(r1, r2) - MyOp.make_node(o0.outputs[0], r4) + o0 = my_op.make_node(r1, r2) + my_op.make_node(o0.outputs[0], r4) all = io_toposort([], o0.outputs) assert all == [o0] +@pytest.mark.skip(reason="Not finished") class TestEval: def setup_method(self): self.x, self.y = scalars("x", "y") @@ -397,9 +387,9 @@ def test_equal_computations(): def test_walk(): r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3) - o1 = MyOp(r1, r2) + o1 = my_op(r1, r2) o1.name = "o1" - o2 = MyOp(r3, o1) + o2 = my_op(r3, o1) o2.name = "o2" def expand(r): @@ -428,9 +418,9 @@ def expand(r): def test_ancestors(): r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3) - o1 = MyOp(r1, r2) + o1 = my_op(r1, r2) o1.name = "o1" - o2 = MyOp(r3, o1) + o2 = my_op(r3, o1) o2.name = "o2" res = ancestors([o2], blockers=None) @@ -450,9 +440,9 @@ def test_ancestors(): def test_graph_inputs(): r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3) - o1 = MyOp(r1, r2) + o1 = my_op(r1, r2) o1.name = "o1" - o2 = MyOp(r3, o1) + o2 = my_op(r3, o1) o2.name = "o2" res = graph_inputs([o2], blockers=None) @@ -463,9 +453,9 @@ def test_graph_inputs(): def test_variables_and_orphans(): r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3) - o1 = MyOp(r1, r2) + o1 = my_op(r1, r2) o1.name = "o1" - o2 = MyOp(r3, o1) + o2 = my_op(r3, o1) o2.name = "o2" vars_res = vars_between([r1, r2], [o2]) @@ -480,11 +470,11 @@ def test_variables_and_orphans(): def test_ops(): r1, r2, r3, r4 = MyVariable(1), MyVariable(2), MyVariable(3), MyVariable(4) - o1 = MyOp(r1, r2) + o1 = my_op(r1, r2) o1.name = "o1" - o2 = MyOp(r3, r4) + o2 = my_op(r3, r4) o2.name = "o2" - o3 = MyOp(r3, o1, o2) + o3 = my_op(r3, o1, o2) o3.name = "o3" res = applys_between([r1, r2], [o3]) @@ -495,9 +485,9 @@ def test_ops(): def test_list_of_nodes(): r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3) - o1 = MyOp(r1, r2) + o1 = my_op(r1, r2) o1.name = "o1" - o2 = MyOp(r3, o1) + o2 = my_op(r3, o1) o2.name = "o2" res = list_of_nodes([r1, r2], [o2]) @@ -507,9 +497,9 @@ def test_list_of_nodes(): def test_is_in_ancestors(): r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3) - o1 = MyOp(r1, r2) + o1 = my_op(r1, r2) o1.name = "o1" - o2 = MyOp(r3, o1) + o2 = my_op(r3, o1) o2.name = "o2" assert is_in_ancestors(o2.owner, o1.owner) @@ -528,13 +518,13 @@ def test_view_roots(): def test_get_var_by_name(): r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3) - o1 = MyOp(r1, r2) + o1 = my_op(r1, r2) o1.name = "o1" # Inner graph igo_in_1 = MyVariable(4) igo_in_2 = MyVariable(5) - igo_out_1 = MyOp(igo_in_1, igo_in_2) + igo_out_1 = my_op(igo_in_1, igo_in_2) igo_out_1.name = "igo1" igo = MyInnerGraphOp([igo_in_1, igo_in_2], [igo_out_1]) @@ -667,6 +657,7 @@ def test_cloning_replace_not_strict_not_copy_inputs(self): assert x not in f2_inp assert y2 not in f2_inp + @pytest.mark.skip(reason="Not finished") def test_clone(self): def test(x, y, mention_y): if mention_y: @@ -771,7 +762,7 @@ def test_NominalVariable(): assert repr(nv5) == f"NominalVariable(2, {repr(type3)})" - assert nv5.signature() == (type3, 2) + assert hash(nv5) == hash((type(nv5), 2, type3)) nv5_pkld = pickle.dumps(nv5) nv5_unpkld = pickle.loads(nv5_pkld) @@ -807,5 +798,81 @@ def test_NominalVariable_create_variable_type(): ntv_unpkld = pickle.loads(ntv_pkld) assert type(ntv_unpkld) is type(ntv) - assert ntv_unpkld.equals(ntv) + assert ntv_unpkld == ntv assert ntv_unpkld is ntv + + +def test_Apply_equivalence(): + + type1 = MyType(1) + + in_1 = Variable(type1, None, name="in_1") + in_2 = Variable(type1, None, name="in_2") + out_10 = Variable(type1, None, name="out_10") + out_11 = Variable(type1, None, name="out_11") + out_12 = Variable(type1, None, name="out_12") + + apply_1 = Apply(my_op, [in_1], [out_10]) + apply_2 = Apply(my_op, [in_1], [out_11]) + apply_3 = Apply(my_op, [in_2], [out_12]) + + assert apply_1 is apply_2 + assert apply_1 == apply_2 + assert apply_1 != apply_3 + assert hash(apply_1) == hash(apply_2) + assert hash(apply_1) != hash(apply_3) + + assert apply_1.inputs == apply_2.inputs + + assert apply_1.outputs == [out_10] + assert apply_2.outputs == [out_10] + # Output `Variable`s should be updated when the constructor is called with + # the same inputs but different outputs. + assert out_10.owner is apply_1 + assert out_11.owner is apply_1 + + apply_1_pkl = pickle.dumps(apply_1) + apply_1_2 = pickle.loads(apply_1_pkl) + + assert apply_1.op == apply_1_2.op + assert len(apply_1.inputs) == len(apply_1_2.inputs) + assert len(apply_1.outputs) == len(apply_1_2.outputs) + assert apply_1.inputs[0].type == apply_1_2.inputs[0].type + assert apply_1.inputs[0].name == apply_1_2.inputs[0].name + assert apply_1.outputs[0].type == apply_1_2.outputs[0].type + assert apply_1.outputs[0].name == apply_1_2.outputs[0].name + + +class MyType2(MyType): + def filter(self, value, **kwargs): + value = np.asarray(value).view(HashableNDArray) + value.setflags(write=0) + return value + + +def test_Constant_equivalence(): + type1 = MyType2(1) + x = Constant(type1, 1.0) + y = Constant(type1, 1.0) + + assert x == y + assert x is y + + rng = np.random.default_rng(3209) + a_val = rng.normal(size=(2, 3)) + c_val = rng.normal(size=(2, 3)) + + a = Constant(type1, a_val) + b = Constant(type1, a_val) + c = Constant(type1, c_val) + + assert a == b + assert a is b + assert a != x + assert a != c + + a_pkl = pickle.dumps(a) + a_2 = pickle.loads(a_pkl) + + assert a.type == a_2.type + assert a.data == a_2.data diff --git a/tests/graph/test_destroyhandler.py b/tests/graph/test_destroyhandler.py index 3470284e66..62e2f22b9c 100644 --- a/tests/graph/test_destroyhandler.py +++ b/tests/graph/test_destroyhandler.py @@ -44,6 +44,9 @@ def filter(self, data): def __eq__(self, other): return isinstance(other, MyType) + def __hash__(self): + return id(self) + def MyVariable(name): return Variable(MyType(), None, None, name=name) @@ -156,12 +159,16 @@ def test_misc(): @assertFailure_fast def test_aliased_inputs_replacement(): - x, y, z = inputs() + x, *_ = inputs() tv = transpose_view(x) + tv.name = "tv" tvv = transpose_view(tv) + tvv.name = "tvv" sx = sigmoid(x) + sx.name = "sx" e = add_in_place(x, tv) - g = create_fgraph([x, y], [e], False) + e.name = "e" + g = create_fgraph([x], [e], False) assert not g.consistent() g.replace(tv, sx) assert g.consistent() @@ -310,16 +317,48 @@ def test_indirect_2(): @assertFailure_fast def test_long_destroyers_loop(): x, y, z = inputs() - e = dot(dot(add_in_place(x, y), add_in_place(y, z)), add(z, x)) + add_xy = add_in_place(x, y) + add_xy.name = "add_i_xy" + add_yz = add_in_place(y, z) + add_yz.name = "add_i_yz" + add_zx = add(z, x) + add_zx.name = "add_zx" + dot_add_xy_yz = dot(add_xy, add_yz) + dot_add_xy_yz.name = "dot_add_xy_yz" + e = dot(dot_add_xy_yz, add_zx) + e.name = "e" g = create_fgraph([x, y, z], [e]) + + orderings = g.destroy_handler.orderings(g, ordered=False) + exp_orderings = {add_yz.owner: {add_xy}, add_xy.owner: {add_zx}} + assert orderings == exp_orderings + assert g.consistent() + + # This apparently introduces a cycle into the graph? + # That means it should fail validation and revert the replacement. + # TODO FIXME: We need tests that directly confirm the results of the + # functions in `DestroyHandler`, and not these extremely indirect + # integration-like tests that assert almost to nothing about the results + # produced by the code we're testing. + # TODO FIXME: Also, why are we even allowing `FunctionGraph`s to take + # these broken states? A quick cycle check in `FunctionGraph.replace` + # would be a lot better. TopoSubstitutionNodeRewriter(add, add_in_place).rewrite(g) + + # When `g` is in its inconsistent state the orderings are as follows: + # {AddInPlace(y, z): {AddInPlace(x, y)}, + # AddInPlace(x, y): {AddInPlace(z, x)}, + # AddInPlace(z, x): {AddInPlace(y, z)}} + + # Make sure the replacement was reverted + assert g.outputs[0].owner.inputs[-1].owner.op == add + + orderings = g.destroy_handler.orderings(g, ordered=False) + assert orderings == exp_orderings + assert g.consistent() - # we don't want to see that! - assert ( - str(g) - != "FunctionGraph(Dot(Dot(AddInPlace(x, y), AddInPlace(y, z)), AddInPlace(z, x)))" - ) + e2 = dot(dot(add_in_place(x, y), add_in_place(y, z)), add_in_place(z, x)) with pytest.raises(InconsistencyError): create_fgraph(*clone([x, y, z], [e2])) @@ -337,8 +376,8 @@ def test_misc_2(): def test_multi_destroyers(): x, y, z = inputs() - e = add(add_in_place(x, y), add_in_place(x, y)) - with pytest.raises(InconsistencyError): + e = add(add_in_place(x, y), add_in_place(x, z)) + with pytest.raises(InconsistencyError, match="Multiple destroyers of"): create_fgraph([x, y, z], [e]) diff --git a/tests/graph/test_fg.py b/tests/graph/test_fg.py index c145a99e24..50a57bef63 100644 --- a/tests/graph/test_fg.py +++ b/tests/graph/test_fg.py @@ -6,7 +6,7 @@ from typing_extensions import Literal from aesara.configdefaults import config -from aesara.graph.basic import NominalVariable +from aesara.graph.basic import Apply, NominalVariable from aesara.graph.features import Feature from aesara.graph.fg import FunctionGraph from aesara.graph.utils import MissingInputError @@ -307,8 +307,11 @@ def test_change_input(self): var1 = MyVariable("var1") var2 = MyVariable("var2") var3 = op1(var2, var1) + var3.name = "var3" var4 = op2(var3, var2) + var4.name = "var4" var5 = op3(var4, var2, var2) + var5.name = "var5" cb_tracker = CallbackTracker() fg = FunctionGraph( [var1, var2], [var3, var5], clone=False, features=[cb_tracker] @@ -345,6 +348,7 @@ def test_change_input(self): old_apply_nodes = set(fg.apply_nodes) old_variables = set(fg.variables) old_var5_clients = list(fg.get_clients(var5)) + old_var5_node = var5.owner # We're replacing with the same variable, so nothing should happen fg.change_node_input(var5.owner, 1, var2) @@ -362,28 +366,40 @@ def test_change_input(self): assert fg.outputs[1].owner == var5.owner assert (var5.owner, 1) not in fg.get_clients(var2) - assert len(cb_tracker.callback_history) == 1 - assert cb_tracker.callback_history[0] == ( - "change_input", - (fg, var5.owner, 1, var2, var1), - {"reason": None}, - ) + assert len(cb_tracker.callback_history) == 3 + assert cb_tracker.callback_history == [ + ("prune", (fg, old_var5_node, None), {}), + ("import", (fg, var5.owner, None), {}), + ( + "change_input", + (fg, old_var5_node, var5.owner, 1, var2, var1), + {"reason": None}, + ), + ] cb_tracker.callback_history.clear() + old_var5_node = var5.owner + # Perform a valid `Apply` node input change that results in a # node removal (i.e. `var4.owner`) fg.change_node_input(var5.owner, 0, var1) assert var5.owner.inputs[0] is var1 - assert not fg.get_clients(var4) + assert var4 not in fg.clients assert var4.owner not in fg.apply_nodes assert var4 not in fg.variables - assert len(cb_tracker.callback_history) == 2 + assert len(cb_tracker.callback_history) == 4 assert cb_tracker.callback_history[0] == ("prune", (fg, var4.owner, None), {}) assert cb_tracker.callback_history[1] == ( + "prune", + (fg, old_var5_node, None), + {}, + ) + assert cb_tracker.callback_history[2] == ("import", (fg, var5.owner, None), {}) + assert cb_tracker.callback_history[3] == ( "change_input", - (fg, var5.owner, 0, var4, var1), + (fg, old_var5_node, var5.owner, 0, var4, var1), {"reason": None}, ) @@ -446,23 +462,32 @@ def test_replace(self): assert len(cb_tracker.callback_history) == 0 + old_var4_node = var4.owner + # Test a basic replacement fg.replace_all([(var3, var1)]) assert var3 not in fg.variables + assert fg.apply_nodes == {var4.owner, var5.owner} assert var4.owner.inputs == [var1, var2] assert fg.outputs == [var1, var5] - assert len(cb_tracker.callback_history) == 3 + assert len(cb_tracker.callback_history) == 5 assert cb_tracker.callback_history[0] == ( "change_input", - (fg, "output", 0, var3, var1), + (fg, "output", "output", 0, var3, var1), {"reason": None}, ) assert cb_tracker.callback_history[1] == ("prune", (fg, var3.owner, None), {}) assert cb_tracker.callback_history[2] == ( + "prune", + (fg, old_var4_node, None), + {}, + ) + assert cb_tracker.callback_history[3] == ("import", (fg, var4.owner, None), {}) + assert cb_tracker.callback_history[4] == ( "change_input", - (fg, var4.owner, 0, var3, var1), + (fg, old_var4_node, var4.owner, 0, var3, var1), {"reason": None}, ) @@ -472,6 +497,16 @@ def test_replace(self): cb_tracker = CallbackTracker() fg = FunctionGraph([var1], [var5], clone=False, features=[cb_tracker]) + assert cb_tracker.callback_history == [ + ("attach", (fg,), {}), + ("import", (fg, var3.owner, "init"), {}), + ("import", (fg, var4.owner, "init"), {}), + ("import", (fg, var5.owner, "init"), {}), + ] + cb_tracker.callback_history.clear() + + old_var5_node = var5.owner + # Test a replacement that would remove the replacement variable # (i.e. `var3`) from the graph when the variable to be replaced # (i.e. `var4`) is removed @@ -483,12 +518,14 @@ def test_replace(self): assert fg.variables == {var1, var3, var5} assert cb_tracker.callback_history == [ - ("attach", (fg,), {}), - ("import", (fg, var3.owner, "init"), {}), - ("import", (fg, var4.owner, "init"), {}), - ("import", (fg, var5.owner, "init"), {}), ("prune", (fg, var4.owner, None), {}), - ("change_input", (fg, var5.owner, 0, var4, var3), {"reason": None}), + ("prune", (fg, old_var5_node, None), {}), + ("import", (fg, var5.owner, None), {}), + ( + "change_input", + (fg, old_var5_node, var5.owner, 0, var4, var3), + {"reason": None}, + ), ] var3 = op1(var1) @@ -497,6 +534,16 @@ def test_replace(self): cb_tracker = CallbackTracker() fg = FunctionGraph([var1], [var5], clone=False, features=[cb_tracker]) + assert cb_tracker.callback_history == [ + ("attach", (fg,), {}), + ("import", (fg, var3.owner, "init"), {}), + ("import", (fg, var4.owner, "init"), {}), + ("import", (fg, var5.owner, "init"), {}), + ] + cb_tracker.callback_history.clear() + + old_var5_node = var5.owner + # Test multiple `change_node_input` calls on the same node fg.replace_all([(var4, var3)]) @@ -505,14 +552,24 @@ def test_replace(self): assert fg.outputs == [var5] assert fg.variables == {var1, var3, var5} + tmp_var5_node = Apply(op3, [var3, var4], [MyVariable("var5_tmp")]) + assert cb_tracker.callback_history == [ - ("attach", (fg,), {}), - ("import", (fg, var3.owner, "init"), {}), - ("import", (fg, var4.owner, "init"), {}), - ("import", (fg, var5.owner, "init"), {}), - ("change_input", (fg, var5.owner, 0, var4, var3), {"reason": None}), + ("prune", (fg, old_var5_node, None), {}), + ("import", (fg, tmp_var5_node, None), {}), + ( + "change_input", + (fg, old_var5_node, tmp_var5_node, 0, var4, var3), + {"reason": None}, + ), ("prune", (fg, var4.owner, None), {}), - ("change_input", (fg, var5.owner, 1, var4, var3), {"reason": None}), + ("prune", (fg, tmp_var5_node, None), {}), + ("import", (fg, var5.owner, None), {}), + ( + "change_input", + (fg, tmp_var5_node, var5.owner, 1, var4, var3), + {"reason": None}, + ), ] def test_replace_outputs(self): @@ -535,9 +592,9 @@ def test_replace_outputs(self): ("attach", (fg,), {}), ("import", (fg, var3.owner, "init"), {}), ("import", (fg, var4.owner, "init"), {}), - ("change_input", (fg, "output", 0, var3, var1), {"reason": None}), + ("change_input", (fg, "output", "output", 0, var3, var1), {"reason": None}), ("prune", (fg, var3.owner, None), {}), - ("change_input", (fg, "output", 2, var3, var1), {"reason": None}), + ("change_input", (fg, "output", "output", 2, var3, var1), {"reason": None}), ] def test_replace_contract(self): @@ -555,6 +612,18 @@ def test_replace_contract(self): cb_tracker = CallbackTracker() fg = FunctionGraph([x], [v4], clone=False, features=[cb_tracker]) + assert cb_tracker.callback_history == [ + ("attach", (fg,), {}), + ("import", (fg, v1.owner, "init"), {}), + ("import", (fg, v2.owner, "init"), {}), + ("import", (fg, v3.owner, "init"), {}), + ("import", (fg, v4.owner, "init"), {}), + ] + cb_tracker.callback_history.clear() + + old_v3_node = v3.owner + old_v4_node = v4.owner + # This replacement should produce a new `Apply` node that's equivalent # to `v2` and try to replace `v3`'s node with that one. In other # words, the replacement creates a new node that's already in the @@ -566,7 +635,7 @@ def test_replace_contract(self): assert fg.clients == { x: [(v1.owner, 0)], v1: [(v3.owner, 0)], - v2: [], + # v2: [], v3: [(v4.owner, 0)], v4: [("output", 0)], } @@ -574,13 +643,9 @@ def test_replace_contract(self): assert v2 not in set(sum((n.outputs for n in fg.apply_nodes), [])) assert cb_tracker.callback_history == [ - ("attach", (fg,), {}), - ("import", (fg, v1.owner, "init"), {}), - ("import", (fg, v2.owner, "init"), {}), - ("import", (fg, v3.owner, "init"), {}), - ("import", (fg, v4.owner, "init"), {}), - ("prune", (fg, v2.owner, None), {}), - ("change_input", (fg, v3.owner, 0, v2, v1), {"reason": None}), + ("prune", (fg, old_v3_node, None), {}), + ("import", (fg, v3.owner, None), {}), + ("change_input", (fg, old_v3_node, v3.owner, 0, v2, v1), {"reason": None}), ] # Let's try the same thing at a different point in the chain @@ -598,6 +663,17 @@ def test_replace_contract(self): cb_tracker = CallbackTracker() fg = FunctionGraph([x], [v4], clone=False, features=[cb_tracker]) + assert cb_tracker.callback_history == [ + ("attach", (fg,), {}), + ("import", (fg, v1.owner, "init"), {}), + ("import", (fg, v2.owner, "init"), {}), + ("import", (fg, v3.owner, "init"), {}), + ("import", (fg, v4.owner, "init"), {}), + ] + cb_tracker.callback_history.clear() + + old_v4_node = v4.owner + fg.replace_all([(v3, v2)]) assert v3 not in fg.variables @@ -605,20 +681,16 @@ def test_replace_contract(self): x: [(v1.owner, 0)], v1: [(v2.owner, 0)], v2: [(v4.owner, 0)], - v3: [], + # v3: [], v4: [("output", 0)], } assert fg.apply_nodes == {v4.owner, v2.owner, v1.owner} assert v3 not in set(sum((n.outputs for n in fg.apply_nodes), [])) exp_res = [ - ("attach", (fg,), {}), - ("import", (fg, v1.owner, "init"), {}), - ("import", (fg, v2.owner, "init"), {}), - ("import", (fg, v3.owner, "init"), {}), - ("import", (fg, v4.owner, "init"), {}), - ("prune", (fg, v3.owner, None), {}), - ("change_input", (fg, v4.owner, 0, v3, v2), {"reason": None}), + ("prune", (fg, old_v4_node, None), {}), + ("import", (fg, v4.owner, None), {}), + ("change_input", (fg, old_v4_node, v4.owner, 0, v3, v2), {"reason": None}), ] assert cb_tracker.callback_history == exp_res @@ -667,25 +739,31 @@ def test_replace_circular(self): ) cb_tracker.callback_history.clear() + old_var4_owner = var4.owner + fg.replace_all([(var3, var4)]) # The following works (and is kind of gross), because `var4` has been # mutated in-place - assert fg.apply_nodes == {var4.owner, var5.owner} assert var4.owner.inputs == [var4, var2] + assert fg.apply_nodes == {var4.owner, var5.owner} + assert fg.outputs == [var4, var5] - assert len(cb_tracker.callback_history) == 3 - assert cb_tracker.callback_history[0] == ( - "change_input", - (fg, "output", 0, var3, var4), - {"reason": None}, - ) - assert cb_tracker.callback_history[1] == ("prune", (fg, var3.owner, None), {}) - assert cb_tracker.callback_history[2] == ( - "change_input", - (fg, var4.owner, 0, var3, var4), - {"reason": None}, - ) + assert cb_tracker.callback_history == [ + ( + "change_input", + (fg, "output", "output", 0, var3, var4), + {"reason": None}, + ), + ("prune", (fg, var3.owner, None), {}), + ("prune", (fg, old_var4_owner, None), {}), + ("import", (fg, var4.owner, None), {}), + ( + "change_input", + (fg, old_var4_owner, var4.owner, 0, var3, var4), + {"reason": None}, + ), + ] def test_replace_bad_state(self): @@ -708,8 +786,11 @@ def test_check_integrity(self): var1 = MyVariable("var1") var2 = MyVariable("var2") var3 = op1(var2, var1) + var3.name = "var3" var4 = op2(var3, var2) + var4.name = "var4" var5 = op3(var4, var2, var2) + var5.name = "var5" fg = FunctionGraph([var1, var2], [var3, var5], clone=False) with pytest.raises(Exception, match="The following nodes are .*"): @@ -733,15 +814,24 @@ def test_check_integrity(self): fg.variables.add(var4) with pytest.raises(Exception, match="Undeclared input.*"): - var6 = MyVariable2("var6") - fg.clients[var6] = [(var5.owner, 3)] + var6 = MyVariable("var6") + var7 = op1(var6) + var7.name = "var7" + fg.clients[var6] = [(var7.owner, 0)] fg.variables.add(var6) - var5.owner.inputs.append(var6) + fg.clients[var7] = [("output", 2)] + fg.variables.add(var7) + fg.outputs.append(var7) + fg.apply_nodes.add(var7.owner) fg.check_integrity() fg.variables.remove(var6) - var5.owner.inputs.remove(var6) + fg.variables.remove(var7) + del fg.clients[var6] + del fg.clients[var7] + fg.outputs.remove(var7) + fg.apply_nodes.remove(var7.owner) # TODO: What if the index value is greater than 1? It will throw an # `IndexError`, but that doesn't sound like anything we'd want.