Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add opt_sample that performs optimizations prior to calling pymc.sample #396

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ methods in the current release of PyMC experimental.
MarginalModel
marginalize
model_builder.ModelBuilder
opt_sample

Inference
=========
Expand Down
1 change: 1 addition & 0 deletions pymc_extras/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from pymc_extras.inference.fit import fit
from pymc_extras.model.marginal.marginal_model import MarginalModel, marginalize
from pymc_extras.model.model_api import as_model
from pymc_extras.sampling.mcmc import opt_sample
from pymc_extras.version import __version__

_log = logging.getLogger("pmx")
Expand Down
14 changes: 1 addition & 13 deletions pymc_extras/model/marginal/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from pymc.logprob.abstract import MeasurableOp, _logprob
from pymc.logprob.basic import conditional_logp, logp
from pymc.pytensorf import constant_fold
from pytensor import Variable
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.mode import Mode
from pytensor.graph import Op, vectorize_graph
Expand All @@ -17,6 +16,7 @@
from pytensor.tensor import TensorVariable

from pymc_extras.distributions import DiscreteMarkovChain
from pymc_extras.utils.ofg import inline_ofg_outputs


class MarginalRV(OpFromGraph, MeasurableOp):
Expand Down Expand Up @@ -126,18 +126,6 @@ def align_logp_dims(dims: tuple[tuple[int, None]], logp: TensorVariable) -> Tens
return logp.transpose(*dims_alignment)


def inline_ofg_outputs(op: OpFromGraph, inputs: Sequence[Variable]) -> tuple[Variable]:
"""Inline the inner graph (outputs) of an OpFromGraph Op.

Whereas `OpFromGraph` "wraps" a graph inside a single Op, this function "unwraps"
the inner graph.
"""
return clone_replace(
op.inner_outputs,
replace=tuple(zip(op.inner_inputs, inputs)),
)


DUMMY_ZERO = pt.constant(0, name="dummy_zero")


Expand Down
Empty file.
114 changes: 114 additions & 0 deletions pymc_extras/sampling/mcmc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import sys

from pymc.model.core import Model
from pymc.sampling.mcmc import sample
from pytensor.graph.rewriting.basic import GraphRewriter

from pymc_extras.sampling.optimizations.optimize import (
TAGS_TYPE,
optimize_model_for_mcmc_sampling,
)


def opt_sample(
    *args,
    model: Model | None = None,
    include: TAGS_TYPE = ("default",),
    exclude: TAGS_TYPE = None,
    rewriter: GraphRewriter | None = None,
    verbose: bool = False,
    **kwargs,
):
    """Sample from a model after applying optimizations.

    Parameters
    ----------
    model : Model, optional
        The model to sample from. If None, use the model associated with the context.
    include : TAGS_TYPE
        The tags to include in the optimizations. Ignored if `rewriter` is not None.
    exclude : TAGS_TYPE
        The tags to exclude from the optimizations. Ignored if `rewriter` is not None.
    rewriter : GraphRewriter, optional
        The rewriter to use. If None, use the default rewriter with the given `include` and `exclude` tags.
    verbose : bool, default=False
        Print information about the optimizations applied.
    *args, **kwargs:
        Passed to `pm.sample`

    Returns
    -------
    sample_output:
        The output of `pm.sample`

    Raises
    ------
    ValueError
        If a `step` keyword is passed: custom steps would reference the original
        (un-optimized) model, so they are not supported here.

    Examples
    --------
    .. code:: python
        import pymc as pm
        import pymc_extras as pmx

        with pm.Model() as m:
            p = pm.Beta("p", 1, 1, shape=(1000,))
            y = pm.Binomial("y", n=100, p=p, observed=[1, 50, 99, 50]*250)

            idata = pmx.opt_sample(verbose=True)

        # Applied optimization: beta_binomial_conjugacy 1x
        # ConjugateRVSampler: [p]


    You can control which optimizations are applied using the `include` and `exclude` arguments:

    .. code:: python
        import pymc as pm
        import pymc_extras as pmx

        with pm.Model() as m:
            p = pm.Beta("p", 1, 1, shape=(1000,))
            y = pm.Binomial("y", n=100, p=p, observed=[1, 50, 99, 50]*250)

            idata = pmx.opt_sample(exclude="conjugacy", verbose=True)

        # No optimizations applied
        # NUTS: [p]

    .. code:: python
        import pymc as pm
        import pymc_extras as pmx

        with pm.Model() as m:
            a = pm.InverseGamma("a", 1, 1)
            b = pm.InverseGamma("b", 1, 1)
            p = pm.Beta("p", a, b, shape=(1000,))
            y = pm.Binomial("y", n=100, p=p, observed=[1, 50, 99, 50]*250)

            # By default, the conjugacy of p will not be applied because it depends on other free variables
            idata = pmx.opt_sample(include="conjugacy-eager", verbose=True)

        # Applied optimization: beta_binomial_conjugacy_eager 1x
        # CompoundStep
        # >NUTS: [a, b]
        # >ConjugateRVSampler: [p]

    """
    if kwargs.get("step", None) is not None:
        raise ValueError(
            "The `step` argument is not supported in `opt_sample`, as custom steps would refer to the original model.\n"
            # Point users at the current package path (the project was renamed from pymc_experimental)
            "You can manually transform the model with `pymc_extras.sampling.optimizations.optimize_model_for_mcmc_sampling` "
            "and then define the custom steps and forward them to `pymc.sample`."
        )

    # Rewrite the model graph before sampling; rewrite_counters records which
    # rewrites fired and how many times.
    opt_model, rewrite_counters = optimize_model_for_mcmc_sampling(
        model, include=include, exclude=exclude, rewriter=rewriter
    )

    if verbose:
        applied_opt = False
        for rewrite_counter in rewrite_counters:
            for rewrite, counts in rewrite_counter.items():
                applied_opt = True
                print(f"Applied optimization: {rewrite} {counts}x", file=sys.stdout)
        if not applied_opt:
            print("No optimizations applied", file=sys.stdout)

    # Sample from the optimized model, never the original one
    return sample(*args, model=opt_model, **kwargs)
13 changes: 13 additions & 0 deletions pymc_extras/sampling/optimizations/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# ruff: noqa: F401
# Add rewrites to the optimization DBs

from pymc_extras.sampling.optimizations import conjugacy, summary_stats
from pymc_extras.sampling.optimizations.optimize import (
optimize_model_for_mcmc_sampling,
posterior_optimization_db,
)

__all__ = [
"posterior_optimization_db",
"optimize_model_for_mcmc_sampling",
]
205 changes: 205 additions & 0 deletions pymc_extras/sampling/optimizations/conjugacy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
from collections.abc import Sequence
from functools import partial

from pymc.distributions import Beta, Binomial
from pymc.model.fgraph import ModelFreeRV, ModelValuedVar, model_free_rv
from pymc.pytensorf import collect_default_updates
from pytensor.graph.basic import Variable, ancestors
from pytensor.graph.fg import FunctionGraph, Output
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.subtensor import _sum_grad_over_bcasted_dims as sum_bcasted_dims

from pymc_extras.sampling.optimizations.conjugate_sampler import (
ConjugateRV,
)
from pymc_extras.sampling.optimizations.optimize import posterior_optimization_db


def register_conjugacy_rewrites_variants(rewrite_fn, tracks=(ModelFreeRV,)):
    """Register a rewrite function and its eager variant in the posterior optimization DB.

    The default variant (eager=False) is registered under the "default" and
    "conjugacy" tags; the eager variant under "non-default" and "conjugacy-eager".
    Returns the two registered node rewriters (default, eager).
    """
    base_name = rewrite_fn.__name__

    def _build_variant(variant_name, *, eager):
        # partial objects have no __name__, so set one explicitly before wrapping
        fn = partial(rewrite_fn, eager=eager)
        fn.__name__ = variant_name
        return node_rewriter(tracks=tracks)(fn)

    rewrite_default = _build_variant(base_name, eager=False)
    rewrite_eager = _build_variant(f"{base_name}_eager", eager=True)

    posterior_optimization_db.register(
        rewrite_default.__name__,
        rewrite_default,
        "default",
        "conjugacy",
    )
    posterior_optimization_db.register(
        rewrite_eager.__name__,
        rewrite_eager,
        "non-default",
        "conjugacy-eager",
    )

    return rewrite_default, rewrite_eager


def has_free_rv_ancestor(vars: Variable | Sequence[Variable]) -> bool:
    """Return True if any of the variables have a model free RV as an ancestor."""
    roots = tuple(vars) if isinstance(vars, Sequence) else (vars,)

    # TODO: It should stop at observed RVs, it doesn't matter if they have a free RV above
    # Did not implement due to laziness and it being a rare case
    for anc in ancestors(roots):
        owner = anc.owner
        if owner is not None and isinstance(owner.op, ModelFreeRV):
            return True
    return False


def get_model_var_of_rv(fgraph: FunctionGraph, rv: Variable) -> Variable:
    """Return the Model dummy var that wraps the RV (implicitly None if not wrapped)."""
    wrapping_vars = (
        client.outputs[0]
        for client, _ in fgraph.clients[rv]
        if isinstance(client.op, ModelValuedVar)
    )
    return next(wrapping_vars, None)


def get_dist_params(rv: Variable) -> tuple[Variable]:
    """Return the distribution parameters of `rv`, as reported by its owner Op."""
    node = rv.owner
    return node.op.dist_params(node)


def get_size_param(rv: Variable) -> Variable:
    """Return the size argument of `rv`, as reported by its owner Op."""
    node = rv.owner
    return node.op.size_param(node)


def rv_used_by(
    fgraph: FunctionGraph,
    rv: Variable,
    used_by_type: type,
    used_as_arg_idx: int | Sequence[int],
    arg_idx_offset: int = 2,  # Ignore the first two arguments (rng and size)
    strict: bool = True,
) -> list[Variable]:
    """Return the RVs that use `rv` as an argument in an operation of type `used_by_type`.

    RV may be used directly or broadcasted (via a left expand_dims DimShuffle)
    before being used.

    Parameters
    ----------
    fgraph : FunctionGraph
        The function graph containing the RVs
    rv : Variable
        The RV to check for uses.
    used_by_type : type
        The type of operation that may use the RV.
    used_as_arg_idx : int | Sequence[int]
        The index of the RV in the operation's inputs.
    arg_idx_offset : int, default=2
        Offset added to each index in `used_as_arg_idx`; the default of 2 skips
        the rng and size inputs of RV Ops.
    strict : bool, default=True
        If True, return no results when the RV is used in an unrecognized way.

    Returns
    -------
    list[Variable]
        Default outputs of the matching client nodes. Empty when no client
        matches, or (in strict mode) when any unrecognized use is found.
    """
    if isinstance(used_as_arg_idx, int):
        used_as_arg_idx = (used_as_arg_idx,)
    used_as_arg_idx = tuple(arg_idx + arg_idx_offset for arg_idx in used_as_arg_idx)

    clients = fgraph.clients
    used_by: list[Variable] = []
    for client, inp_idx in clients[rv]:
        # Output clients are graph bookkeeping, not real uses
        if isinstance(client.op, Output):
            continue

        if isinstance(client.op, used_by_type) and inp_idx in used_as_arg_idx:
            # RV is directly used by the RV type
            used_by.append(client.default_output())

        elif isinstance(client.op, DimShuffle) and client.op.is_left_expand_dims:
            for sub_client, sub_inp_idx in clients[client.outputs[0]]:
                if isinstance(sub_client.op, used_by_type) and sub_inp_idx in used_as_arg_idx:
                    # RV is broadcasted and then used by the RV type
                    used_by.append(sub_client.default_output())
                elif strict:
                    # Some other unrecognized use, bail out
                    return []
        elif strict:
            # Some other unrecognized use, bail out
            return []

    return used_by


def wrap_rv_and_conjugate_rv(
    fgraph: FunctionGraph, rv: Variable, conjugate_rv: Variable, inputs: Sequence[Variable]
) -> Variable:
    """Wrap the RV and its conjugate posterior RV in a ConjugateRV node.

    Also takes care of handling the random number generators used in the conjugate posterior.
    """
    rng_updates = collect_default_updates(conjugate_rv, inputs=[rv, *inputs])
    rngs, next_rngs = zip(*rng_updates.items())
    # Any rng the conjugate posterior uses must also be an input of the fgraph
    for rng in rngs:
        if rng not in fgraph.inputs:
            fgraph.add_input(rng)

    op_inputs = [rv, *inputs, *rngs]
    op_outputs = [rv, conjugate_rv, *next_rngs]
    conjugate_op = ConjugateRV(inputs=op_inputs, outputs=op_outputs)
    wrapped_rv, *_ = conjugate_op(*op_inputs)
    return wrapped_rv


def create_untransformed_free_rv(
    fgraph: FunctionGraph, rv: Variable, name: str, dims: Sequence[str | Variable]
) -> Variable:
    """Create a model FreeRV without transform."""
    value_var = rv.type(name=name)
    fgraph.add_input(value_var)
    # transform=None means the variable is sampled on its natural support
    free_rv = model_free_rv(rv, value_var, None, *dims)
    free_rv.name = name
    return free_rv


def beta_binomial_conjugacy(fgraph: FunctionGraph, node, eager: bool = False):
    """Rewrite a Beta free RV observed through a single Binomial into a ConjugateRV.

    Detects the pattern ``p ~ Beta(a, b); y ~ Binomial(n, p)`` and replaces the
    Beta free RV by a ConjugateRV whose conjugate posterior is
    ``Beta(a + y, b + n - y)``, so a specialized sampler can draw ``p`` directly.

    Returns None when the pattern does not match (per the node_rewriter
    protocol), otherwise a single-element list with the replacement free RV.
    """
    if not isinstance(node.op, ModelFreeRV):
        return None

    # ModelFreeRV inputs are (rv, value, *dims)
    [beta_free_rv] = node.outputs
    beta_rv, _, *beta_dims = node.inputs

    if not isinstance(beta_rv.owner.op, Beta):
        return None

    a, b = get_dist_params(beta_rv)
    if not eager and has_free_rv_ancestor([a, b]):
        # Don't apply rewrite if a, b depend on other model variables as that will force a Gibbs sampling scheme
        return None

    p_arg_idx = 1  # Params to the Binomial are (n, p)
    binomial_rvs = rv_used_by(fgraph, beta_free_rv, Binomial, p_arg_idx)

    if len(binomial_rvs) != 1:
        # Question: Can we apply conjugacy when RV is used by more than one binomial?
        return None

    [binomial_rv] = binomial_rvs

    binomial_model_var = get_model_var_of_rv(fgraph, binomial_rv)
    if binomial_model_var is None:
        return None

    # We want to replace free_rv by ConjugateRV()->(free_rv, conjugate_posterior_rv)
    n, _ = get_dist_params(binomial_rv)

    # Use value of y in new graph to avoid circularity
    y = binomial_model_var.owner.inputs[1]

    # Conjugate posterior params; sum over dims where the Beta was broadcast against y
    conjugate_a = sum_bcasted_dims(beta_rv, a + y)
    conjugate_b = sum_bcasted_dims(beta_rv, b + (n - y))

    conjugate_beta_rv = Beta.dist(conjugate_a, conjugate_b, size=get_size_param(beta_rv))

    new_beta_rv = wrap_rv_and_conjugate_rv(fgraph, beta_rv, conjugate_beta_rv, [a, b, n, y])
    new_beta_free_rv = create_untransformed_free_rv(
        fgraph, new_beta_rv, beta_free_rv.name, beta_dims
    )
    return [new_beta_free_rv]


# Register the default (non-eager) and eager variants of the Beta-Binomial
# conjugacy rewrite in the posterior optimization DB.
beta_binomial_conjugacy_default, beta_binomial_conjugacy_force = (
    register_conjugacy_rewrites_variants(beta_binomial_conjugacy)
)
Loading
Loading