From 285263d6b8cb8901d143f4e2c5547953925c92fb Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Mon, 11 Nov 2024 13:20:40 +0100 Subject: [PATCH] Allow ordering of multi-output OpFromGraph variables in `toposort_replace` --- pymc/pytensorf.py | 34 ++++++++++++++++++++++++++++++++-- tests/test_initial_point.py | 37 ++++++++++++++++++++++++++++++++++++- 2 files changed, 68 insertions(+), 3 deletions(-) diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py index 213831c9f11..527a1c8e0c4 100644 --- a/pymc/pytensorf.py +++ b/pymc/pytensorf.py @@ -1122,10 +1122,40 @@ def toposort_replace( fgraph: FunctionGraph, replacements: Sequence[tuple[Variable, Variable]], reverse: bool = False ) -> None: """Replace multiple variables in place in topological order.""" - toposort = fgraph.toposort() + fgraph_toposort = {node: i for i, node in enumerate(fgraph.toposort())} + _inner_fgraph_toposorts = {} # Cache inner toposorts + + def _nested_toposort_index(var, fgraph_toposort) -> tuple[int]: + """Compute position of variable in fgraph toposort. + + When a variable is an OpFromGraph output, extend output with the toposort index of the inner graph(s). + + This allows ordering variables that come from the same OpFromGraph. + """ + if not var.owner: + return (-1,) + + index = fgraph_toposort[var.owner] + + # Recurse into OpFromGraphs + # TODO: Could also recurse into Scans + if isinstance(var.owner.op, OpFromGraph): + inner_fgraph = var.owner.op.fgraph + + if inner_fgraph not in _inner_fgraph_toposorts: + _inner_fgraph_toposorts[inner_fgraph] = { + node: i for i, node in enumerate(inner_fgraph.toposort()) + } + + inner_fgraph_toposort = _inner_fgraph_toposorts[inner_fgraph] + inner_var = inner_fgraph.outputs[var.owner.outputs.index(var)] + return (index, *_nested_toposort_index(inner_var, inner_fgraph_toposort)) + else: + return (index,) + sorted_replacements = sorted( replacements, - key=lambda pair: toposort.index(pair[0].owner) if pair[0].owner else -1, + key=lambda pair: _nested_toposort_index(pair[0], fgraph_toposort), reverse=reverse, ) fgraph.replace_all(sorted_replacements, import_missing=True) diff --git a/tests/test_initial_point.py b/tests/test_initial_point.py index e8ace32557c..9138f37b3e7 100644 --- a/tests/test_initial_point.py +++ b/tests/test_initial_point.py @@ -17,11 +17,12 @@ import pytensor.tensor as pt import pytest +from pytensor.compile.builders import OpFromGraph from pytensor.tensor.random.op import RandomVariable import pymc as pm -from pymc.distributions.distribution import support_point +from pymc.distributions.distribution import _support_point, support_point from pymc.initial_point import make_initial_point_fn, make_initial_point_fns_per_chain @@ -192,6 +193,40 @@ def test_string_overrides_work(self): assert np.isclose(iv["B_log__"], 0) assert iv["C_log__"] == 0 + @pytest.mark.parametrize("reverse_rvs", [False, True]) + def test_dependent_initval_from_OFG(self, reverse_rvs): + class MyTestOp(OpFromGraph): + pass + + @_support_point.register(MyTestOp) + def my_test_op_support_point(op, out): + out1, out2 = out.owner.outputs + if out is out1: + return out1 + else: + return out1 * 4 + + out1 = pt.zeros(()) + out2 = out1 * 2 + rv_op = MyTestOp([], [out1, out2]) + + with pm.Model() as model: + A, B = rv_op() + if reverse_rvs: + model.register_rv(B, "B") + model.register_rv(A, "A") + else: + model.register_rv(A, "A") + model.register_rv(B, "B") + + assert model.initial_point() == {"A": 0, "B": 0} + + model.set_initval(A, 1) + assert model.initial_point() == {"A": 1, "B": 4} + + model.set_initval(B, 3) + assert model.initial_point() == {"A": 1, "B": 3} + class TestSupportPoint: def test_basic(self):