Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow interdependent initial points from same OpFromGraph node #7569

Merged
merged 2 commits into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions pymc/initial_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,13 @@
from pytensor.tensor.variable import TensorVariable

from pymc.logprob.transforms import Transform
from pymc.pytensorf import compile_pymc, find_rng_nodes, replace_rng_nodes, reseed_rngs
from pymc.pytensorf import (
compile_pymc,
find_rng_nodes,
replace_rng_nodes,
reseed_rngs,
toposort_replace,
)
from pymc.util import get_transformed_name, get_untransformed_name, is_transformed_name

StartDict = dict[Variable | str, np.ndarray | Variable | str]
Expand Down Expand Up @@ -288,8 +294,7 @@ def make_initial_point_expression(
# order, so that later nodes do not reintroduce expressions with earlier
# rvs that would need to once again be replaced by their initial_points
graph = FunctionGraph(outputs=free_rvs_clone, clone=False)
replacements = reversed(list(zip(free_rvs_clone, initial_values_clone)))
graph.replace_all(replacements, import_missing=True)
toposort_replace(graph, tuple(zip(free_rvs_clone, initial_values_clone)), reverse=True)

if not return_transformed:
return graph.outputs
Expand Down
34 changes: 32 additions & 2 deletions pymc/pytensorf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1122,10 +1122,40 @@ def toposort_replace(
fgraph: FunctionGraph, replacements: Sequence[tuple[Variable, Variable]], reverse: bool = False
) -> None:
"""Replace multiple variables in place in topological order."""
toposort = fgraph.toposort()
fgraph_toposort = {node: i for i, node in enumerate(fgraph.toposort())}
_inner_fgraph_toposorts = {} # Cache inner toposorts

def _nested_toposort_index(var, fgraph_toposort) -> tuple[int, ...]:
"""Compute position of variable in fgraph toposort.

The result is a tuple so that positions can be compared lexicographically.
A variable without an owner maps to ``(-1,)``; otherwise the first entry is
the index stored for ``var.owner`` in ``fgraph_toposort`` (a mapping from
node to toposort index).

When a variable is an OpFromGraph output, the tuple is extended with the
position of the matching inner output in the inner graph's toposort,
recursing for as long as inner outputs are themselves OpFromGraph outputs.

This allows ordering variables that come from the same OpFromGraph.
"""
# Variables with no owner get -1, which is below every enumerate() index
if not var.owner:
return (-1,)

index = fgraph_toposort[var.owner]

# Recurse into OpFromGraphs
# TODO: Could also recurse into Scans
if isinstance(var.owner.op, OpFromGraph):
inner_fgraph = var.owner.op.fgraph

# Build the node -> index mapping of each inner graph only once
if inner_fgraph not in _inner_fgraph_toposorts:
_inner_fgraph_toposorts[inner_fgraph] = {
node: i for i, node in enumerate(inner_fgraph.toposort())
}

inner_fgraph_toposort = _inner_fgraph_toposorts[inner_fgraph]
# The inner output sits at the same position as var among the node's outputs
inner_var = inner_fgraph.outputs[var.owner.outputs.index(var)]
return (index, *_nested_toposort_index(inner_var, inner_fgraph_toposort))
else:
return (index,)

sorted_replacements = sorted(
replacements,
key=lambda pair: toposort.index(pair[0].owner) if pair[0].owner else -1,
key=lambda pair: _nested_toposort_index(pair[0], fgraph_toposort),
reverse=reverse,
)
fgraph.replace_all(sorted_replacements, import_missing=True)
Expand Down
44 changes: 42 additions & 2 deletions tests/test_initial_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@
import pytensor.tensor as pt
import pytest

from pytensor.compile.builders import OpFromGraph
from pytensor.tensor.random.op import RandomVariable

import pymc as pm

from pymc.distributions.distribution import support_point
from pymc.distributions.distribution import _support_point, support_point
from pymc.initial_point import make_initial_point_fn, make_initial_point_fns_per_chain


Expand All @@ -47,12 +48,17 @@ def test_make_initial_point_fns_per_chain_checks_kwargs(self):
)
pass

def test_dependent_initvals(self):
@pytest.mark.parametrize("reverse_rvs", [False, True])
def test_dependent_initvals(self, reverse_rvs):
with pm.Model() as pmodel:
L = pm.Uniform("L", 0, 1, initval=0.5)
U = pm.Uniform("U", lower=9, upper=10, initval=9.5)
B1 = pm.Uniform("B1", lower=L, upper=U, initval=5)
B2 = pm.Uniform("B2", lower=L, upper=U, initval=(L + U) / 2)

if reverse_rvs:
pmodel.free_RVs = pmodel.free_RVs[::-1]

ip = pmodel.initial_point(random_seed=0)
assert ip["L_interval__"] == 0
assert ip["U_interval__"] == 0
Expand Down Expand Up @@ -187,6 +193,40 @@ def test_string_overrides_work(self):
assert np.isclose(iv["B_log__"], 0)
assert iv["C_log__"] == 0

@pytest.mark.parametrize("reverse_rvs", [False, True])
def test_dependent_initval_from_OFG(self, reverse_rvs):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this test indicative of why this feature is needed for the other PR?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup

# Dedicated OpFromGraph subclass so a support point can be registered for this test op only
class MyTestOp(OpFromGraph):
pass

@_support_point.register(MyTestOp)
def my_test_op_support_point(op, out):
    """Support point for the two outputs of a ``MyTestOp`` node.

    The first output is its own support point; the support point of any
    other output is four times the first output.
    """
    # Unpacking also enforces that the node has exactly two outputs
    first_out, _second_out = out.owner.outputs
    return first_out if out is first_out else first_out * 4

out1 = pt.zeros(())
out2 = out1 * 2
rv_op = MyTestOp([], [out1, out2])

with pm.Model() as model:
A, B = rv_op()
if reverse_rvs:
model.register_rv(B, "B")
model.register_rv(A, "A")
else:
model.register_rv(A, "A")
model.register_rv(B, "B")

assert model.initial_point() == {"A": 0, "B": 0}

model.set_initval(A, 1)
assert model.initial_point() == {"A": 1, "B": 4}

model.set_initval(B, 3)
assert model.initial_point() == {"A": 1, "B": 3}


class TestSupportPoint:
def test_basic(self):
Expand Down
Loading