Skip to content

Commit

Permalink
Allow ordering of multi-output OpFromGraph variables in `toposort_rep…
Browse files Browse the repository at this point in the history
…lace`
  • Loading branch information
ricardoV94 committed Nov 11, 2024
1 parent 71d25e7 commit 285263d
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 3 deletions.
34 changes: 32 additions & 2 deletions pymc/pytensorf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1122,10 +1122,40 @@ def toposort_replace(
fgraph: FunctionGraph, replacements: Sequence[tuple[Variable, Variable]], reverse: bool = False
) -> None:
"""Replace multiple variables in place in topological order."""
toposort = fgraph.toposort()
fgraph_toposort = {node: i for i, node in enumerate(fgraph.toposort())}
_inner_fgraph_toposorts = {} # Cache inner toposorts

def _nested_toposort_index(var, fgraph_toposort) -> tuple[int]:
"""Compute position of variable in fgraph toposort.
When a variable is an OpFromGraph output, extend output with the toposort index of the inner graph(s).
This allows ordering variables that come from the same OpFromGraph.
"""
if not var.owner:
return (-1,)

index = fgraph_toposort[var.owner]

# Recurse into OpFromGraphs
# TODO: Could also recurse into Scans
if isinstance(var.owner.op, OpFromGraph):
inner_fgraph = var.owner.op.fgraph

if inner_fgraph not in _inner_fgraph_toposorts:
_inner_fgraph_toposorts[inner_fgraph] = {
node: i for i, node in enumerate(inner_fgraph.toposort())
}

inner_fgraph_toposort = _inner_fgraph_toposorts[inner_fgraph]
inner_var = inner_fgraph.outputs[var.owner.outputs.index(var)]
return (index, *_nested_toposort_index(inner_var, inner_fgraph_toposort))
else:
return (index,)

sorted_replacements = sorted(
replacements,
key=lambda pair: toposort.index(pair[0].owner) if pair[0].owner else -1,
key=lambda pair: _nested_toposort_index(pair[0], fgraph_toposort),
reverse=reverse,
)
fgraph.replace_all(sorted_replacements, import_missing=True)
Expand Down
37 changes: 36 additions & 1 deletion tests/test_initial_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@
import pytensor.tensor as pt
import pytest

from pytensor.compile.builders import OpFromGraph
from pytensor.tensor.random.op import RandomVariable

import pymc as pm

from pymc.distributions.distribution import support_point
from pymc.distributions.distribution import _support_point, support_point
from pymc.initial_point import make_initial_point_fn, make_initial_point_fns_per_chain


Expand Down Expand Up @@ -192,6 +193,40 @@ def test_string_overrides_work(self):
assert np.isclose(iv["B_log__"], 0)
assert iv["C_log__"] == 0

@pytest.mark.parametrize("reverse_rvs", [False, True])
def test_dependent_initval_from_OFG(self, reverse_rvs):
class MyTestOp(OpFromGraph):
pass

@_support_point.register(MyTestOp)
def my_test_op_support_point(op, out):
out1, out2 = out.owner.outputs
if out is out1:
return out1
else:
return out1 * 4

out1 = pt.zeros(())
out2 = out1 * 2
rv_op = MyTestOp([], [out1, out2])

with pm.Model() as model:
A, B = rv_op()
if reverse_rvs:
model.register_rv(B, "B")
model.register_rv(A, "A")
else:
model.register_rv(A, "A")
model.register_rv(B, "B")

assert model.initial_point() == {"A": 0, "B": 0}

model.set_initval(A, 1)
assert model.initial_point() == {"A": 1, "B": 4}

model.set_initval(B, 3)
assert model.initial_point() == {"A": 1, "B": 3}


class TestSupportPoint:
def test_basic(self):
Expand Down

0 comments on commit 285263d

Please sign in to comment.