Implement marginalization as a model transformation and deprecate MarginalModel #388

Merged
merged 2 commits into from Jan 3, 2025
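For orientation, below is a minimal sketch of the workflow this PR moves toward: `marginalize` acting as a plain model-to-model transformation rather than a method of the dedicated `MarginalModel` subclass. The call signatures shown (string variable names, `recover_marginals(model, idata)`) are assumptions inferred from the exports in this diff, not guaranteed by it.

```python
import numpy as np
import pymc as pm
import pymc_extras as pmx

# A two-component mixture where the component indicator is a finite discrete RV.
with pm.Model() as mixture_model:
    idx = pm.Bernoulli("idx", p=0.7)
    mu = pm.Normal("mu", mu=[-1.0, 1.0], sigma=1.0, shape=2)
    y = pm.Normal("y", mu=mu[idx], sigma=1.0, observed=np.array([-0.9, 1.2, 1.1]))

# Marginalization as a transformation: returns a new model with `idx` summed out,
# instead of requiring the model to be built as a MarginalModel up front.
marginal_model = pmx.marginalize(mixture_model, ["idx"])

with marginal_model:
    idata = pm.sample()

# Assumed signature: draw the marginalized discrete variable back from its
# conditional distribution given the posterior samples.
idata = pmx.recover_marginals(marginal_model, idata)
```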
5 changes: 2 additions & 3 deletions README.md
@@ -26,10 +26,9 @@ import pymc as pm
import pymc_extras as pmx

with pm.Model():
alpha = pmx.ParabolicFractal('alpha', b=1, c=1)

alpha = pmx.ParabolicFractal('alpha', b=1, c=1)

...
...

```

3 changes: 2 additions & 1 deletion docs/api_reference.rst
@@ -12,8 +12,8 @@ methods in the current release of PyMC experimental.
:toctree: generated/

as_model
MarginalModel
marginalize
recover_marginals
model_builder.ModelBuilder

Inference
@@ -53,6 +53,7 @@ Utils

spline.bspline_interpolation
prior.prior_from_idata
model_equivalence.equivalent_models


Statespace Models
6 changes: 5 additions & 1 deletion pymc_extras/__init__.py
@@ -16,7 +16,11 @@
from pymc_extras import gp, statespace, utils
from pymc_extras.distributions import *
from pymc_extras.inference.fit import fit
from pymc_extras.model.marginal.marginal_model import MarginalModel, marginalize
from pymc_extras.model.marginal.marginal_model import (
MarginalModel,
marginalize,
recover_marginals,
)
from pymc_extras.model.model_api import as_model
from pymc_extras.version import __version__

2 changes: 1 addition & 1 deletion pymc_extras/distributions/timeseries.py
@@ -214,8 +214,8 @@ def transition(*args):
discrete_mc_op = DiscreteMarkovChainRV(
inputs=[P_, steps_, init_dist_, state_rng],
outputs=[state_next_rng, discrete_mc_],
ndim_supp=1,
n_lags=n_lags,
extended_signature="(p,p),(),(p),[rng]->[rng],(t)",
)

discrete_mc = discrete_mc_op(P, steps, init_dist, state_rng)
103 changes: 100 additions & 3 deletions pymc_extras/model/marginal/distributions.py
@@ -1,29 +1,41 @@
import warnings

from collections.abc import Sequence

import numpy as np
import pytensor.tensor as pt

from pymc.distributions import Bernoulli, Categorical, DiscreteUniform
from pymc.distributions.distribution import _support_point, support_point
from pymc.logprob.abstract import MeasurableOp, _logprob
from pymc.logprob.basic import conditional_logp, logp
from pymc.pytensorf import constant_fold
from pytensor import Variable
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.mode import Mode
from pytensor.graph import Op, vectorize_graph
from pytensor.graph import FunctionGraph, Op, vectorize_graph
from pytensor.graph.basic import equal_computations
from pytensor.graph.replace import clone_replace, graph_replace
from pytensor.scan import map as scan_map
from pytensor.scan import scan
from pytensor.tensor import TensorVariable
from pytensor.tensor.random.type import RandomType

from pymc_extras.distributions import DiscreteMarkovChain


class MarginalRV(OpFromGraph, MeasurableOp):
"""Base class for Marginalized RVs"""

def __init__(self, *args, dims_connections: tuple[tuple[int | None]], **kwargs) -> None:
def __init__(
self,
*args,
dims_connections: tuple[tuple[int | None], ...],
dims: tuple[Variable, ...],
**kwargs,
) -> None:
self.dims_connections = dims_connections
self.dims = dims
super().__init__(*args, **kwargs)

@property
@@ -43,6 +55,74 @@ def support_axes(self) -> tuple[tuple[int]]:
)
return tuple(support_axes_vars)

def __eq__(self, other):
# Just to allow easy testing of equivalent models,
# This can be removed once https://github.com/pymc-devs/pytensor/issues/1114 is fixed
if type(self) is not type(other):
return False

return equal_computations(
self.inner_outputs,
other.inner_outputs,
self.inner_inputs,
other.inner_inputs,
)

def __hash__(self):
Contributor: I don't like this. Would it be possible to move this logic into equal_computations or a similar place? Also, is this actually used anywhere as is?

Member (Author): As the comment mentions, this will be moved to the base class in PyTensor. It's used for the equal_computations check, which makes one or two tests much more straightforward to write.

# Just to allow easy testing of equivalent models,
# This can be removed once https://github.com/pymc-devs/pytensor/issues/1114 is fixed
return hash((type(self), len(self.inner_inputs), len(self.inner_outputs)))
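As context for the exchange above, here is a hedged sketch of the kind of equivalence test these `__eq__`/`__hash__` overrides are meant to simplify, assuming the new `equivalent_models` helper (listed under Utils in the API reference diff) compares two models structurally; the import path and signature are assumptions, not taken from this diff.

```python
import pymc as pm
import pymc_extras as pmx

# Assumed import path, based on the "model_equivalence.equivalent_models" entry
# added under Utils in docs/api_reference.rst.
from pymc_extras.utils.model_equivalence import equivalent_models


def build_model() -> pm.Model:
    with pm.Model() as m:
        idx = pm.Bernoulli("idx", p=0.5)
        y = pm.Normal("y", mu=idx, sigma=1.0)
    return m


def test_marginalize_is_deterministic():
    # Marginalizing two independently built copies of the same model should
    # yield structurally equal graphs; MarginalRV.__eq__ above lets the
    # equal_computations-based comparison treat the marginalized subgraphs as equal.
    m1 = pmx.marginalize(build_model(), ["idx"])
    m2 = pmx.marginalize(build_model(), ["idx"])
    assert equivalent_models(m1, m2)
```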


@_support_point.register
def support_point_marginal_rv(op: MarginalRV, rv, *inputs):
"""Support point for a marginalized RV.

The support point of a marginalized RV is the support point of the inner RV,
conditioned on the marginalized RV taking its support point.
"""
outputs = rv.owner.outputs

inner_rv = op.inner_outputs[outputs.index(rv)]
marginalized_inner_rv, *other_dependent_inner_rvs = (
out
for out in op.inner_outputs
if out is not inner_rv and not isinstance(out.type, RandomType)
)

# Replace references to inner rvs by the dummy variables (including the marginalized RV)
# This is necessary because the inner RVs may depend on each other
marginalized_inner_rv_dummy = marginalized_inner_rv.clone()
other_dependent_inner_rv_to_dummies = {
inner_rv: inner_rv.clone() for inner_rv in other_dependent_inner_rvs
}
inner_rv = clone_replace(
inner_rv,
replace={marginalized_inner_rv: marginalized_inner_rv_dummy}
| other_dependent_inner_rv_to_dummies,
)

# Get support point of inner RV and marginalized RV
inner_rv_support_point = support_point(inner_rv)
marginalized_inner_rv_support_point = support_point(marginalized_inner_rv)

replacements = [
# Replace the marginalized RV dummy by its support point
(marginalized_inner_rv_dummy, marginalized_inner_rv_support_point),
# Replace other dependent RVs dummies by the respective outer outputs.
# PyMC will replace them by their support points later
*(
(v, outputs[op.inner_outputs.index(k)])
for k, v in other_dependent_inner_rv_to_dummies.items()
),
# Replace outer input RVs
*zip(op.inner_inputs, inputs),
]
fgraph = FunctionGraph(outputs=[inner_rv_support_point], clone=False)
fgraph.replace_all(replacements, import_missing=True)
[rv_support_point] = fgraph.outputs
return rv_support_point


class MarginalFiniteDiscreteRV(MarginalRV):
"""Base class for Marginalized Finite Discrete RVs"""
@@ -132,12 +212,27 @@ def inline_ofg_outputs(op: OpFromGraph, inputs: Sequence[Variable]) -> tuple[Var
Whereas `OpFromGraph` "wraps" a graph inside a single Op, this function "unwraps"
the inner graph.
"""
return clone_replace(
return graph_replace(
op.inner_outputs,
replace=tuple(zip(op.inner_inputs, inputs)),
strict=False,
)


class NonSeparableLogpWarning(UserWarning):
pass


def warn_non_separable_logp(values):
if len(values) > 1:
warnings.warn(
"There are multiple dependent variables in a FiniteDiscreteMarginalRV. "
f"Their joint logp terms will be assigned to the first value: {values[0]}.",
NonSeparableLogpWarning,
stacklevel=2,
)


DUMMY_ZERO = pt.constant(0, name="dummy_zero")


@@ -199,6 +294,7 @@ def logp_fn(marginalized_rv_const, *non_sequences):
# Align logp with non-collapsed batch dimensions of first RV
joint_logp = align_logp_dims(dims=op.dims_connections[0], logp=joint_logp)

warn_non_separable_logp(values)
# We have to add dummy logps for the remaining value variables, otherwise PyMC will raise
dummy_logps = (DUMMY_ZERO,) * (len(values) - 1)
return joint_logp, *dummy_logps
@@ -272,5 +368,6 @@ def step_alpha(logp_emission, log_alpha, log_P):

# If there are multiple emission streams, we have to add dummy logps for the remaining value variables. The first
# return is the joint probability of everything together, but PyMC still expects one logp for each emission stream.
warn_non_separable_logp(values)
dummy_logps = (DUMMY_ZERO,) * (len(values) - 1)
return joint_logp, *dummy_logps
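A hedged illustration of the case the new warning targets: two observed variables depend on the same marginalized discrete RV, so their joint logp cannot be split per variable and is assigned to the first value. The `marginalize` call signature is assumed, as in the sketch near the top of this page.

```python
import pymc as pm
import pymc_extras as pmx

with pm.Model() as m:
    state = pm.Bernoulli("state", p=0.3)
    # Two dependent variables share the marginalized `state`.
    y1 = pm.Normal("y1", mu=state, sigma=1.0, observed=0.2)
    y2 = pm.Normal("y2", mu=2.0 * state, sigma=1.0, observed=1.8)

marginal_m = pmx.marginalize(m, ["state"])

# Compiling/evaluating the logp should emit NonSeparableLogpWarning, since the
# joint term for (y1, y2) is attached entirely to y1's value variable.
logp_fn = marginal_m.compile_logp()
print(logp_fn(marginal_m.initial_point()))
```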
17 changes: 8 additions & 9 deletions pymc_extras/model/marginal/graph_analysis.py
@@ -4,8 +4,8 @@
from itertools import zip_longest

from pymc import SymbolicRandomVariable
from pytensor.compile import SharedVariable
from pytensor.graph import Constant, Variable, ancestors
from pymc.model.fgraph import ModelVar
from pytensor.graph import Variable, ancestors
from pytensor.graph.basic import io_toposort
from pytensor.tensor import TensorType, TensorVariable
from pytensor.tensor.blockwise import Blockwise
Expand Down Expand Up @@ -35,13 +35,9 @@ def static_shape_ancestors(vars):

def find_conditional_input_rvs(output_rvs, all_rvs):
"""Find conditionally indepedent input RVs."""
blockers = [other_rv for other_rv in all_rvs if other_rv not in output_rvs]
blockers += static_shape_ancestors(tuple(all_rvs) + tuple(output_rvs))
return [
var
for var in ancestors(output_rvs, blockers=blockers)
if var in blockers or (var.owner is None and not isinstance(var, Constant | SharedVariable))
]
other_rvs = [other_rv for other_rv in all_rvs if other_rv not in output_rvs]
blockers = other_rvs + static_shape_ancestors(tuple(all_rvs) + tuple(output_rvs))
return [var for var in ancestors(output_rvs, blockers=blockers) if var in other_rvs]


def is_conditional_dependent(
@@ -141,6 +137,9 @@ def _subgraph_batch_dim_connection(var_dims: VAR_DIMS, input_vars, output_vars)
# None of the inputs are related to the batch_axes of the input_vars
continue

elif isinstance(node.op, ModelVar):
var_dims[node.outputs[0]] = inputs_dims[0]

elif isinstance(node.op, DimShuffle):
[input_dims] = inputs_dims
output_dims = tuple(None if i == "x" else input_dims[i] for i in node.op.new_order)