From cd34e7412e27a0b8650147218bfbe1e569951a82 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Thu, 5 Dec 2024 14:45:42 +0100
Subject: [PATCH 1/4] Add opt_sample

---
 docs/api_reference.rst                       |  1 +
 pymc_experimental/sampling/__init__.py       |  0
 pymc_experimental/sampling/mcmc.py           | 76 +++++++++++++++++++
 .../sampling/optimizations/optimize.py       | 37 +++++++++
 pymc_extras/__init__.py                      |  1 +
 tests/sampling/mcmc/__init__.py              |  0
 tests/sampling/mcmc/test_mcmc.py             | 20 +++++
 7 files changed, 135 insertions(+)
 create mode 100644 pymc_experimental/sampling/__init__.py
 create mode 100644 pymc_experimental/sampling/mcmc.py
 create mode 100644 pymc_experimental/sampling/optimizations/optimize.py
 create mode 100644 tests/sampling/mcmc/__init__.py
 create mode 100644 tests/sampling/mcmc/test_mcmc.py

diff --git a/docs/api_reference.rst b/docs/api_reference.rst
index 15c3d346..80421f89 100644
--- a/docs/api_reference.rst
+++ b/docs/api_reference.rst
@@ -15,6 +15,7 @@ methods in the current release of PyMC experimental.
     MarginalModel
     marginalize
     model_builder.ModelBuilder
+    opt_sample

 Inference
 =========

diff --git a/pymc_experimental/sampling/__init__.py b/pymc_experimental/sampling/__init__.py
new file mode 100644
index 00000000..e69de29b

diff --git a/pymc_experimental/sampling/mcmc.py b/pymc_experimental/sampling/mcmc.py
new file mode 100644
index 00000000..0d1085f4
--- /dev/null
+++ b/pymc_experimental/sampling/mcmc.py
@@ -0,0 +1,76 @@
+import sys
+
+from pymc.model.core import Model
+from pymc.sampling.mcmc import sample
+from pytensor.graph.rewriting.basic import GraphRewriter
+
+from pymc_experimental.sampling.optimizations.optimize import (
+    TAGS_TYPE,
+    optimize_model_for_mcmc_sampling,
+)
+
+
+def opt_sample(
+    *args,
+    model: Model | None = None,
+    include: TAGS_TYPE = ("default",),
+    exclude: TAGS_TYPE = None,
+    rewriter: GraphRewriter | None = None,
+    verbose: bool = False,
+    **kwargs,
+):
+    """Sample from a model after applying optimizations.
+
+    Parameters
+    ----------
+    model : Model, optional
+        The model to sample from. If None, use the model associated with the context.
+    include : TAGS_TYPE
+        The tags to include in the optimizations. Ignored if `rewriter` is not None.
+    exclude : TAGS_TYPE
+        The tags to exclude from the optimizations. Ignored if `rewriter` is not None.
+    rewriter : GraphRewriter, optional
+        The rewriter to use. If None, use the default rewriter with the given `include` and `exclude` tags.
+    verbose : bool, default=False
+        Print information about the optimizations applied.
+    *args, **kwargs:
+        Passed to `pm.sample`
+
+    Returns
+    -------
+    sample_output:
+        The output of `pm.sample`
+
+    Examples
+    --------
+    .. code:: python
+
+        import pymc as pm
+        import pymc_experimental as pmx
+
+        with pm.Model() as m:
+            p = pm.Beta("p", 1, 1, shape=(1000,))
+            y = pm.Binomial("y", n=100, p=p, observed=[1, 50, 99, 50]*250)
+
+            idata = pmx.opt_sample(verbose=True)
+    """
+    if kwargs.get("step", None) is not None:
+        raise ValueError(
+            "The `step` argument is not supported in `opt_sample`, as custom steps would refer to the original model.\n"
+            "You can manually transform the model with `pymc_experimental.sampling.optimizations.optimize_model_for_mcmc_sampling` "
+            "and then define the custom steps and forward them to `pymc.sample`."
+ ) + + opt_model, rewrite_counters = optimize_model_for_mcmc_sampling( + model, include=include, exclude=exclude, rewriter=rewriter + ) + + if verbose: + applied_opt = False + for rewrite_counter in rewrite_counters: + for rewrite, counts in rewrite_counter.items(): + applied_opt = True + print(f"Applied optimization: {rewrite} {counts}x", file=sys.stdout) + if not applied_opt: + print("No optimizations applied", file=sys.stdout) + + return sample(*args, model=opt_model, **kwargs) diff --git a/pymc_experimental/sampling/optimizations/optimize.py b/pymc_experimental/sampling/optimizations/optimize.py new file mode 100644 index 00000000..18485a8b --- /dev/null +++ b/pymc_experimental/sampling/optimizations/optimize.py @@ -0,0 +1,37 @@ +from collections import Counter +from collections.abc import Sequence +from typing import TypeAlias + +from pymc.model.core import Model, modelcontext +from pymc.model.fgraph import fgraph_from_model, model_from_fgraph +from pytensor.graph.rewriting.db import EquilibriumDB, RewriteDatabaseQuery + +posterior_optimization_db = EquilibriumDB() +posterior_optimization_db.failure_callback = None # Raise an error if an optimization fails +posterior_optimization_db.name = "posterior_optimization_db" + +TAGS_TYPE: TypeAlias = str | Sequence[str] | None + + +def optimize_model_for_mcmc_sampling( + model: Model, + include: TAGS_TYPE = ("default",), + exclude: TAGS_TYPE = None, + rewriter=None, +) -> tuple[Model, Sequence[Counter]]: + if isinstance(include, str): + include = (include,) + if isinstance(exclude, str): + exclude = (exclude,) + + model = modelcontext(model) + fgraph, _ = fgraph_from_model(model) + + if rewriter is None: + rewriter = posterior_optimization_db.query( + RewriteDatabaseQuery(include=include, exclude=exclude) + ) + _, _, rewrite_counters, *_ = rewriter.rewrite(fgraph) + + opt_model = model_from_fgraph(fgraph, mutate_fgraph=True) + return opt_model, rewrite_counters diff --git a/pymc_extras/__init__.py b/pymc_extras/__init__.py index 7df570f5..cc693141 100644 --- a/pymc_extras/__init__.py +++ b/pymc_extras/__init__.py @@ -18,6 +18,7 @@ from pymc_extras.inference.fit import fit from pymc_extras.model.marginal.marginal_model import MarginalModel, marginalize from pymc_extras.model.model_api import as_model +from pymc_extras.sampling.mcmc import opt_sample from pymc_extras.version import __version__ _log = logging.getLogger("pmx") diff --git a/tests/sampling/mcmc/__init__.py b/tests/sampling/mcmc/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/sampling/mcmc/test_mcmc.py b/tests/sampling/mcmc/test_mcmc.py new file mode 100644 index 00000000..9debf85f --- /dev/null +++ b/tests/sampling/mcmc/test_mcmc.py @@ -0,0 +1,20 @@ +import pytest + +from pymc.distributions import Beta, Binomial, InverseGamma +from pymc.model.core import Model +from pymc.step_methods import Slice + +from pymc_experimental import opt_sample + + +def test_custom_step_raises(): + with Model() as m: + a = InverseGamma("a", 1, 1) + b = InverseGamma("b", 1, 1) + p = Beta("p", a, b) + y = Binomial("y", n=100, p=p, observed=99) + + with pytest.raises( + ValueError, match="The `step` argument is not supported in `opt_sample`" + ): + opt_sample(step=Slice([a, b])) From cbdb4040ff0612e103312a8d25eaf51d2213b66c Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 5 Dec 2024 14:57:26 +0100 Subject: [PATCH 2/4] Add Normal summary stats optimization --- .../sampling/optimizations/__init__.py | 13 +++ .../sampling/optimizations/summary_stats.py | 79 
+++++++++++++++++++
 tests/sampling/__init__.py                          |  0
 tests/sampling/mcmc/test_mcmc.py                    | 36 ++++++++-
 tests/sampling/optimizations/__init__.py            |  0
 .../optimizations/test_summary_stats.py             | 43 ++++++++++
 6 files changed, 170 insertions(+), 1 deletion(-)
 create mode 100644 pymc_experimental/sampling/optimizations/__init__.py
 create mode 100644 pymc_experimental/sampling/optimizations/summary_stats.py
 create mode 100644 tests/sampling/__init__.py
 create mode 100644 tests/sampling/optimizations/__init__.py
 create mode 100644 tests/sampling/optimizations/test_summary_stats.py

diff --git a/pymc_experimental/sampling/optimizations/__init__.py b/pymc_experimental/sampling/optimizations/__init__.py
new file mode 100644
index 00000000..363535b4
--- /dev/null
+++ b/pymc_experimental/sampling/optimizations/__init__.py
@@ -0,0 +1,13 @@
+# ruff: noqa: F401
+# Add rewrites to the optimization DBs
+import pymc_experimental.sampling.optimizations.summary_stats
+
+from pymc_experimental.sampling.optimizations.optimize import (
+    optimize_model_for_mcmc_sampling,
+    posterior_optimization_db,
+)
+
+__all__ = [
+    "posterior_optimization_db",
+    "optimize_model_for_mcmc_sampling",
+]

diff --git a/pymc_experimental/sampling/optimizations/summary_stats.py b/pymc_experimental/sampling/optimizations/summary_stats.py
new file mode 100644
index 00000000..58a9c806
--- /dev/null
+++ b/pymc_experimental/sampling/optimizations/summary_stats.py
@@ -0,0 +1,79 @@
+import pytensor.tensor as pt
+
+from pymc.distributions import Gamma, Normal
+from pymc.model.fgraph import ModelObservedRV, model_observed_rv
+from pytensor.graph.fg import FunctionGraph
+from pytensor.graph.rewriting.basic import node_rewriter
+
+from pymc_experimental.sampling.optimizations.optimize import posterior_optimization_db
+
+
+@node_rewriter(tracks=[ModelObservedRV])
+def summary_stats_normal(fgraph: FunctionGraph, node):
+    """Applies the equivalence (up to a normalizing constant) described in:
+
+    https://mc-stan.org/docs/stan-users-guide/efficiency-tuning.html#exploiting-sufficient-statistics
+    """
+    [observed_rv] = node.outputs
+    [rv, data] = node.inputs
+
+    if not isinstance(rv.owner.op, Normal):
+        return None
+
+    # Check the normal RV is not just a scalar
+    if all(rv.type.broadcastable):
+        return None
+
+    # Check that the observed RV is not used anywhere else (like a Potential or Deterministic)
+    # There should be only one use: as an "output"
+    if len(fgraph.clients[observed_rv]) > 1:
+        return None
+
+    mu, sigma = rv.owner.op.dist_params(rv.owner)
+
+    # Check that both mu and sigma are scalar RVs: the rewrite is only valid for iid observations
+    if not (all(mu.type.broadcastable) and all(sigma.type.broadcastable)):
+        return None
+
+    # Check that mu and sigma are not used anywhere else
+    # Note: This is too restrictive, it's fine if they're used in Deterministics!
+ # There should only be two uses: as an "output" and as the param of the `rv` + if len(fgraph.clients[mu]) > 2 or len(fgraph.clients[sigma]) > 2: + return None + + # Remove expand_dims + mu = mu.squeeze() + sigma = sigma.squeeze() + + # Apply the rewrite + mean_data = pt.mean(data) + mean_data.name = None + var_data = pt.var(data, ddof=1) + var_data.name = None + N = data.size + sqrt_N = pt.sqrt(N) + nm1_over2 = (N - 1) / 2 + + observed_mean = model_observed_rv( + Normal.dist(mu=mu, sigma=sigma / sqrt_N), + mean_data, + ) + observed_mean.name = f"{rv.name}_mean" + + observed_var = model_observed_rv( + Gamma.dist(alpha=nm1_over2, beta=nm1_over2 / (sigma**2)), + var_data, + ) + observed_var.name = f"{rv.name}_var" + + fgraph.add_output(observed_mean, import_missing=True) + fgraph.add_output(observed_var, import_missing=True) + fgraph.remove_node(node) + # Just so it shows in the profile for verbose=True, + # It won't do anything because node is not in the fgraph anymore + return [node.out.copy()] + + +posterior_optimization_db.register( + summary_stats_normal.__name__, summary_stats_normal, "default", "summary_stats" +) diff --git a/tests/sampling/__init__.py b/tests/sampling/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/sampling/mcmc/test_mcmc.py b/tests/sampling/mcmc/test_mcmc.py index 9debf85f..dc8f0f30 100644 --- a/tests/sampling/mcmc/test_mcmc.py +++ b/tests/sampling/mcmc/test_mcmc.py @@ -1,7 +1,9 @@ +import numpy as np import pytest -from pymc.distributions import Beta, Binomial, InverseGamma +from pymc.distributions import Beta, Binomial, HalfNormal, InverseGamma, Normal from pymc.model.core import Model +from pymc.sampling.mcmc import sample from pymc.step_methods import Slice from pymc_experimental import opt_sample @@ -18,3 +20,35 @@ def test_custom_step_raises(): ValueError, match="The `step` argument is not supported in `opt_sample`" ): opt_sample(step=Slice([a, b])) + + +def test_sample_opt_summary_stats(capsys): + rng = np.random.default_rng(3) + y_data = rng.normal(loc=1, scale=0.5, size=(1000,)) + + with Model() as m: + mu = Normal("mu") + sigma = HalfNormal("sigma") + y = Normal("y", mu=mu, sigma=sigma, observed=y_data) + + sample_kwargs = dict( + chains=1, tune=500, draws=500, compute_convergence_checks=False, progressbar=False + ) + idata = sample(**sample_kwargs) + # TODO: Make extract_data more robust to avoid this warning/error + # Or alternatively extract data on the original model, not the optimized one + with pytest.warns(UserWarning, match="Could not extract data from symbolic observation"): + opt_idata = opt_sample(**sample_kwargs, verbose=True) + + captured_out = capsys.readouterr().out + assert "Applied optimization: summary_stats_normal 1x" in captured_out + + assert opt_idata.posterior.sizes["chain"] == 1 + assert opt_idata.posterior.sizes["draw"] == 500 + np.testing.assert_allclose( + idata.posterior["mu"].mean(), opt_idata.posterior["mu"].mean(), rtol=1e-2 + ) + np.testing.assert_allclose( + idata.posterior["sigma"].mean(), opt_idata.posterior["sigma"].mean(), rtol=1e-2 + ) + assert idata.sample_stats.sampling_time > opt_idata.sample_stats.sampling_time diff --git a/tests/sampling/optimizations/__init__.py b/tests/sampling/optimizations/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/sampling/optimizations/test_summary_stats.py b/tests/sampling/optimizations/test_summary_stats.py new file mode 100644 index 00000000..bb390e0e --- /dev/null +++ b/tests/sampling/optimizations/test_summary_stats.py @@ 
-0,0 +1,43 @@ +import numpy as np + +from pymc.distributions import HalfNormal, Normal +from pymc.model.core import Model + +from pymc_experimental.sampling.optimizations.optimize import optimize_model_for_mcmc_sampling + + +def test_summary_stats_normal(): + rng = np.random.default_rng(3) + y_data = rng.normal(loc=1, scale=0.5, size=(1000,)) + + with Model() as m: + mu = Normal("mu") + sigma = HalfNormal("sigma") + y = Normal("y", mu=mu, sigma=sigma, observed=y_data) + + assert len(m.free_RVs) == 2 + assert len(m.observed_RVs) == 1 + + new_m, rewrite_counters = optimize_model_for_mcmc_sampling(m) + assert "summary_stats_normal" in (r.name for rc in rewrite_counters for r in rc) + + assert len(new_m.free_RVs) == 2 + assert len(new_m.observed_RVs) == 2 + + # Confirm equivalent (up to an additive normalization constant) + m_logp = m.compile_logp() + new_m_logp = new_m.compile_logp() + + ip = m.initial_point() + first_logp_diff = m_logp(ip) - new_m_logp(ip) + + ip["mu"] += 0.5 + ip["sigma_log__"] += 1.5 + second_logp_diff = m_logp(ip) - new_m_logp(ip) + + np.testing.assert_allclose(first_logp_diff, second_logp_diff) + + # dlogp should be the same + m_dlogp = m.compile_dlogp() + new_m_dlogp = new_m.compile_dlogp() + np.testing.assert_allclose(m_dlogp(ip), new_m_dlogp(ip)) From 8de53460fadf43f01a071974590e9a323bae1d84 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 5 Dec 2024 14:57:37 +0100 Subject: [PATCH 3/4] Add Beta-Binomial conjugacy optimization --- pymc_experimental/sampling/mcmc.py | 38 ++++ .../sampling/optimizations/__init__.py | 1 + .../sampling/optimizations/conjugacy.py | 205 ++++++++++++++++++ .../optimizations/conjugate_sampler.py | 115 ++++++++++ pymc_extras/model/marginal/distributions.py | 13 +- pymc_extras/utils/ofg.py | 17 ++ tests/sampling/mcmc/test_mcmc.py | 54 +++++ .../sampling/optimizations/test_conjugacy.py | 59 +++++ 8 files changed, 490 insertions(+), 12 deletions(-) create mode 100644 pymc_experimental/sampling/optimizations/conjugacy.py create mode 100644 pymc_experimental/sampling/optimizations/conjugate_sampler.py create mode 100644 pymc_extras/utils/ofg.py create mode 100644 tests/sampling/optimizations/test_conjugacy.py diff --git a/pymc_experimental/sampling/mcmc.py b/pymc_experimental/sampling/mcmc.py index 0d1085f4..c57e7fa2 100644 --- a/pymc_experimental/sampling/mcmc.py +++ b/pymc_experimental/sampling/mcmc.py @@ -52,6 +52,44 @@ def opt_sample( y = pm.Binomial("y", n=100, p=p, observed=[1, 50, 99, 50]*250) idata = pmx.opt_sample(verbose=True) + + # Applied optimization: beta_binomial_conjugacy 1x + # ConjugateRVSampler: [p] + + + You can control which optimizations are applied using the `include` and `exclude` arguments: + + .. code:: python + import pymc as pm + import pymc_experimental as pmx + + with pm.Model() as m: + p = pm.Beta("p", 1, 1, shape=(1000,)) + y = pm.Binomial("y", n=100, p=p, observed=[1, 50, 99, 50]*250) + + idata = pmx.opt_sample(exclude="conjugacy", verbose=True) + + # No optimizations applied + # NUTS: [p] + + .. 
code:: python + import pymc as pm + import pymc_experimental as pmx + + with pm.Model() as m: + a = pm.InverseGamma("a", 1, 1) + b = pm.InverseGamma("b", 1, 1) + p = pm.Beta("p", a, b, shape=(1000,)) + y = pm.Binomial("y", n=100, p=p, observed=[1, 50, 99, 50]*250) + + # By default, the conjugacy of p will not be applied because it depends on other free variables + idata = pmx.opt_sample(include="conjugacy-eager", verbose=True) + + # Applied optimization: beta_binomial_conjugacy_eager 1x + # CompoundStep + # >NUTS: [a, b] + # >ConjugateRVSampler: [p] + """ if kwargs.get("step", None) is not None: raise ValueError( diff --git a/pymc_experimental/sampling/optimizations/__init__.py b/pymc_experimental/sampling/optimizations/__init__.py index 363535b4..78d0c858 100644 --- a/pymc_experimental/sampling/optimizations/__init__.py +++ b/pymc_experimental/sampling/optimizations/__init__.py @@ -1,5 +1,6 @@ # ruff: noqa: F401 # Add rewrites to the optimization DBs +import pymc_experimental.sampling.optimizations.conjugacy import pymc_experimental.sampling.optimizations.summary_stats from pymc_experimental.sampling.optimizations.optimize import ( diff --git a/pymc_experimental/sampling/optimizations/conjugacy.py b/pymc_experimental/sampling/optimizations/conjugacy.py new file mode 100644 index 00000000..f4c3a360 --- /dev/null +++ b/pymc_experimental/sampling/optimizations/conjugacy.py @@ -0,0 +1,205 @@ +from collections.abc import Sequence +from functools import partial + +from pymc.distributions import Beta, Binomial +from pymc.model.fgraph import ModelFreeRV, ModelValuedVar, model_free_rv +from pymc.pytensorf import collect_default_updates +from pytensor.graph.basic import Variable, ancestors +from pytensor.graph.fg import FunctionGraph, Output +from pytensor.graph.rewriting.basic import node_rewriter +from pytensor.tensor.elemwise import DimShuffle +from pytensor.tensor.subtensor import _sum_grad_over_bcasted_dims as sum_bcasted_dims + +from pymc_experimental.sampling.optimizations.conjugate_sampler import ( + ConjugateRV, +) +from pymc_experimental.sampling.optimizations.optimize import posterior_optimization_db + + +def register_conjugacy_rewrites_variants(rewrite_fn, tracks=(ModelFreeRV,)): + """Register a rewrite function and its force variant in the posterior optimization DB.""" + name = rewrite_fn.__name__ + + rewrite_fn_default = partial(rewrite_fn, eager=False) + rewrite_fn_default.__name__ = name + rewrite_default = node_rewriter(tracks=tracks)(rewrite_fn_default) + + rewrite_fn_eager = partial(rewrite_fn, eager=True) + rewrite_fn_eager.__name__ = f"{name}_eager" + rewrite_eager = node_rewriter(tracks=tracks)(rewrite_fn_eager) + + posterior_optimization_db.register( + rewrite_default.__name__, + rewrite_default, + "default", + "conjugacy", + ) + + posterior_optimization_db.register( + rewrite_eager.__name__, + rewrite_eager, + "non-default", + "conjugacy-eager", + ) + + return rewrite_default, rewrite_eager + + +def has_free_rv_ancestor(vars: Variable | Sequence[Variable]) -> bool: + """Return True if any of the variables have a model variable as an ancestor.""" + if not isinstance(vars, Sequence): + vars = (vars,) + + # TODO: It should stop at observed RVs, it doesn't matter if they have a free RV above + # Did not implement due to laziness and it being a rare case + return any( + var.owner is not None and isinstance(var.owner.op, ModelFreeRV) for var in ancestors(vars) + ) + + +def get_model_var_of_rv(fgraph: FunctionGraph, rv: Variable) -> Variable: + """Return the Model dummy var that 
wraps the RV""" + for client, _ in fgraph.clients[rv]: + if isinstance(client.op, ModelValuedVar): + return client.outputs[0] + + +def get_dist_params(rv: Variable) -> tuple[Variable]: + return rv.owner.op.dist_params(rv.owner) + + +def get_size_param(rv: Variable) -> Variable: + return rv.owner.op.size_param(rv.owner) + + +def rv_used_by( + fgraph: FunctionGraph, + rv: Variable, + used_by_type: type, + used_as_arg_idx: int | Sequence[int], + arg_idx_offset: int = 2, # Ignore the first two arguments (rng and size) + strict: bool = True, +) -> list[Variable]: + """Return the RVs that use `rv` as an argument in an operation of type `used_by_type`. + + RV may be used directly or broadcasted before being used. + + Parameters + ---------- + fgraph : FunctionGraph + The function graph containing the RVs + rv : Variable + The RV to check for uses. + used_by_type : type + The type of operation that may use the RV. + used_as_arg_idx : int | Sequence[int] + The index of the RV in the operation's inputs. + strict : bool, default=True + If True, return no results when the RV is used in an unrecognized way. + + """ + if isinstance(used_as_arg_idx, int): + used_as_arg_idx = (used_as_arg_idx,) + used_as_arg_idx = tuple(arg_idx + arg_idx_offset for arg_idx in used_as_arg_idx) + + clients = fgraph.clients + used_by: list[Variable] = [] + for client, inp_idx in clients[rv]: + if isinstance(client.op, Output): + continue + + if isinstance(client.op, used_by_type) and inp_idx in used_as_arg_idx: + # RV is directly used by the RV type + used_by.append(client.default_output()) + + elif isinstance(client.op, DimShuffle) and client.op.is_left_expand_dims: + for sub_client, sub_inp_idx in clients[client.outputs[0]]: + if isinstance(sub_client.op, used_by_type) and sub_inp_idx in used_as_arg_idx: + # RV is broadcasted and then used by the RV type + used_by.append(sub_client.default_output()) + elif strict: + # Some other unrecognized use, bail out + return [] + elif strict: + # Some other unrecognized use, bail out + return [] + + return used_by + + +def wrap_rv_and_conjugate_rv( + fgraph: FunctionGraph, rv: Variable, conjugate_rv: Variable, inputs: Sequence[Variable] +) -> Variable: + """Wrap the RV and its conjugate posterior RV in a ConjugateRV node. + + Also takes care of handling the random number generators used in the conjugate posterior. 
+ """ + rngs, next_rngs = zip(*collect_default_updates(conjugate_rv, inputs=[rv, *inputs]).items()) + for rng in rngs: + if rng not in fgraph.inputs: + fgraph.add_input(rng) + conjugate_op = ConjugateRV(inputs=[rv, *inputs, *rngs], outputs=[rv, conjugate_rv, *next_rngs]) + return conjugate_op(rv, *inputs, *rngs)[0] + + +def create_untransformed_free_rv( + fgraph: FunctionGraph, rv: Variable, name: str, dims: Sequence[str | Variable] +) -> Variable: + """Create a model FreeRV without transform.""" + transform = None + value = rv.type(name=name) + fgraph.add_input(value) + free_rv = model_free_rv(rv, value, transform, *dims) + free_rv.name = name + return free_rv + + +def beta_binomial_conjugacy(fgraph: FunctionGraph, node, eager: bool = False): + if not isinstance(node.op, ModelFreeRV): + return None + + [beta_free_rv] = node.outputs + beta_rv, _, *beta_dims = node.inputs + + if not isinstance(beta_rv.owner.op, Beta): + return None + + a, b = get_dist_params(beta_rv) + if not eager and has_free_rv_ancestor([a, b]): + # Don't apply rewrite if a, b depend on other model variables as that will force a Gibbs sampling scheme + return None + + p_arg_idx = 1 # Params to the Binomial are (n, p) + binomial_rvs = rv_used_by(fgraph, beta_free_rv, Binomial, p_arg_idx) + + if len(binomial_rvs) != 1: + # Question: Can we apply conjugacy when RV is used by more than one binomial? + return None + + [binomial_rv] = binomial_rvs + + binomial_model_var = get_model_var_of_rv(fgraph, binomial_rv) + if binomial_model_var is None: + return None + + # We want to replace free_rv by ConjugateRV()->(free_rv, conjugate_posterior_rv) + n, _ = get_dist_params(binomial_rv) + + # Use value of y in new graph to avoid circularity + y = binomial_model_var.owner.inputs[1] + + conjugate_a = sum_bcasted_dims(beta_rv, a + y) + conjugate_b = sum_bcasted_dims(beta_rv, b + (n - y)) + + conjugate_beta_rv = Beta.dist(conjugate_a, conjugate_b, size=get_size_param(beta_rv)) + + new_beta_rv = wrap_rv_and_conjugate_rv(fgraph, beta_rv, conjugate_beta_rv, [a, b, n, y]) + new_beta_free_rv = create_untransformed_free_rv( + fgraph, new_beta_rv, beta_free_rv.name, beta_dims + ) + return [new_beta_free_rv] + + +beta_binomial_conjugacy_default, beta_binomial_conjugacy_force = ( + register_conjugacy_rewrites_variants(beta_binomial_conjugacy) +) diff --git a/pymc_experimental/sampling/optimizations/conjugate_sampler.py b/pymc_experimental/sampling/optimizations/conjugate_sampler.py new file mode 100644 index 00000000..9ddaa462 --- /dev/null +++ b/pymc_experimental/sampling/optimizations/conjugate_sampler.py @@ -0,0 +1,115 @@ +import numpy as np + +from pymc import STEP_METHODS +from pymc.distributions.distribution import _support_point +from pymc.initial_point import PointType +from pymc.logprob.abstract import MeasurableOp, _logprob +from pymc.model.core import modelcontext +from pymc.pytensorf import compile_pymc +from pymc.step_methods.compound import BlockedStep, Competence, StepMethodState +from pymc.util import get_value_vars_from_user_vars +from pytensor import shared +from pytensor.compile.builders import OpFromGraph +from pytensor.link.jax.linker import JAXLinker +from pytensor.tensor.random.type import RandomGeneratorType + +from pymc_experimental.utils.ofg import inline_ofg_outputs + + +class ConjugateRV(OpFromGraph, MeasurableOp): + """Wrapper for ConjugateRVs, that outputs the original RV and the conjugate posterior expression. 
+
+    For partial step samplers to work, the logp and initial point correspond to the original RV,
+    while the variable itself is by default sampled by the `ConjugateRVSampler`, which directly
+    evaluates the conjugate posterior expression (i.e., takes forward random draws).
+    """
+
+
+@_logprob.register(ConjugateRV)
+def conjugate_rv_logp(op, values, rv, *params, **kwargs):
+    # Logp is the same as the original RV
+    return _logprob(rv.owner.op, values, *rv.owner.inputs)
+
+
+@_support_point.register(ConjugateRV)
+def conjugate_rv_support_point(op, conjugate_rv, rv, *params):
+    # Support point is the same as the original RV
+    return _support_point(rv.owner.op, rv, *rv.owner.inputs)
+
+
+class ConjugateRVSampler(BlockedStep):
+    name = "conjugate_rv_sampler"
+    _state_class = StepMethodState
+
+    def __init__(self, vars, model=None, rng=None, compile_kwargs: dict | None = None, **kwargs):
+        if len(vars) != 1:
+            raise ValueError("ConjugateRVSampler can only be assigned to one variable at a time")
+
+        model = modelcontext(model)
+        [value] = get_value_vars_from_user_vars(vars, model=model)
+        rv = model.values_to_rvs[value]
+        self.vars = (value,)
+        self.rv_name = value.name
+
+        if model.rvs_to_transforms[rv] is not None:
+            raise ValueError("Variable assigned to ConjugateRVSampler cannot be transformed")
+
+        rv_and_posterior_rv_node = rv.owner
+        op = rv_and_posterior_rv_node.op
+        if not isinstance(op, ConjugateRV):
+            raise ValueError("Variable must be a ConjugateRV")
+
+        # Replace RVs in inputs of rv_and_posterior_rv_node by the corresponding value variables
+        value_inputs = model.replace_rvs_by_values(
+            [rv_and_posterior_rv_node.outputs[1]],
+        )[0].owner.inputs
+        # Inline the ConjugateRV graph to only compile `posterior_rv`
+        _, posterior_rv, *_ = inline_ofg_outputs(op, value_inputs)
+
+        if compile_kwargs is None:
+            compile_kwargs = {}
+        self.posterior_fn = compile_pymc(
+            model.value_vars,
+            posterior_rv,
+            random_seed=rng,
+            on_unused_input="ignore",
+            **compile_kwargs,
+        )
+        self.posterior_fn.trust_input = True
+        if isinstance(self.posterior_fn.maker.linker, JAXLinker):
+            # Reseeding RVs in the JAX backend requires different logic, because the SharedVariables
+            # used internally are not the ones that `function.get_shared()` returns.
+            raise ValueError("ConjugateRVSampler is not compatible with JAX backend")
+
+    def set_rng(self, rng: np.random.Generator):
+        # Copy the function and replace any shared RNGs
+        # This is needed so that it can work correctly with multiple traces
+        # This will be costly if set_rng is called too often!
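+        # Collect the RandomGeneratorType shared variables of the compiled function,
+        # so each can be swapped for a fresh generator spawned from `rng`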
+        shared_rngs = [
+            var
+            for var in self.posterior_fn.get_shared()
+            if isinstance(var.type, RandomGeneratorType)
+        ]
+        n_shared_rngs = len(shared_rngs)
+        swap = {
+            old_shared_rng: shared(rng, borrow=True)
+            for old_shared_rng, rng in zip(shared_rngs, rng.spawn(n_shared_rngs), strict=True)
+        }
+        self.posterior_fn = self.posterior_fn.copy(swap=swap)
+
+    def step(self, point: PointType) -> tuple[PointType, list]:
+        new_point = point.copy()
+        new_point[self.rv_name] = self.posterior_fn(**point)
+        return new_point, []
+
+    @staticmethod
+    def competence(var, has_grad):
+        """ConjugateRVSampler is only suitable for ConjugateRV variables."""
+        if isinstance(var.owner.op, ConjugateRV):
+            return Competence.IDEAL
+
+        return Competence.INCOMPATIBLE
+
+
+# Register the ConjugateRVSampler
+STEP_METHODS.append(ConjugateRVSampler)

diff --git a/pymc_extras/model/marginal/distributions.py b/pymc_extras/model/marginal/distributions.py
index 33f4bc7c..de6ec5ba 100644
--- a/pymc_extras/model/marginal/distributions.py
+++ b/pymc_extras/model/marginal/distributions.py
@@ -17,6 +17,7 @@
 from pytensor.tensor import TensorVariable

 from pymc_extras.distributions import DiscreteMarkovChain
+from pymc_extras.utils.ofg import inline_ofg_outputs


 class MarginalRV(OpFromGraph, MeasurableOp):
@@ -126,18 +127,6 @@ def align_logp_dims(dims: tuple[tuple[int, None]], logp: TensorVariable) -> Tens
     return logp.transpose(*dims_alignment)


-def inline_ofg_outputs(op: OpFromGraph, inputs: Sequence[Variable]) -> tuple[Variable]:
-    """Inline the inner graph (outputs) of an OpFromGraph Op.
-
-    Whereas `OpFromGraph` "wraps" a graph inside a single Op, this function "unwraps"
-    the inner graph.
-    """
-    return clone_replace(
-        op.inner_outputs,
-        replace=tuple(zip(op.inner_inputs, inputs)),
-    )
-
-
 DUMMY_ZERO = pt.constant(0, name="dummy_zero")

diff --git a/pymc_extras/utils/ofg.py b/pymc_extras/utils/ofg.py
new file mode 100644
index 00000000..6de8ed4a
--- /dev/null
+++ b/pymc_extras/utils/ofg.py
@@ -0,0 +1,17 @@
+from collections.abc import Sequence
+
+from pytensor.compile.builders import OpFromGraph
+from pytensor.graph.basic import Variable
+from pytensor.graph.replace import clone_replace
+
+
+def inline_ofg_outputs(op: OpFromGraph, inputs: Sequence[Variable]) -> tuple[Variable]:
+    """Inline the inner graph (outputs) of an OpFromGraph Op.
+
+    Whereas `OpFromGraph` "wraps" a graph inside a single Op, this function "unwraps"
+    the inner graph.
+ """ + return clone_replace( + op.inner_outputs, + replace=tuple(zip(op.inner_inputs, inputs)), + ) diff --git a/tests/sampling/mcmc/test_mcmc.py b/tests/sampling/mcmc/test_mcmc.py index dc8f0f30..f179e22d 100644 --- a/tests/sampling/mcmc/test_mcmc.py +++ b/tests/sampling/mcmc/test_mcmc.py @@ -1,3 +1,5 @@ +import logging + import numpy as np import pytest @@ -52,3 +54,55 @@ def test_sample_opt_summary_stats(capsys): idata.posterior["sigma"].mean(), opt_idata.posterior["sigma"].mean(), rtol=1e-2 ) assert idata.sample_stats.sampling_time > opt_idata.sample_stats.sampling_time + + +def test_sample_opt_conjugate(caplog, capsys): + caplog.set_level(logging.INFO, logger="pymc") + + sample_kwargs = dict( + tune=0, + draws=250, + chains=4, + progressbar=False, + compute_convergence_checks=False, + ) + + with Model() as m: + p = Beta("p", 1, 1) + y = Binomial("y", n=100, p=p, observed=99) + + idata = opt_sample( + **sample_kwargs, + random_seed=0, + verbose=True, + ) + + assert "Applied optimization: beta_binomial_conjugacy 1x" in capsys.readouterr().out + + # Test it used ConjugateRVSampler + assert "ConjugateRVSampler: [p]" in caplog.text + + np.testing.assert_allclose(idata.posterior["p"].mean(), 100 / 102, atol=1e-3) + np.testing.assert_allclose( + idata.posterior["p"].std(), np.sqrt(100 * 2 / (102**2 * 103)), atol=1e-3 + ) + + # Draws are different across chains + assert (np.diff(idata.posterior["p"].isel(draw=0).values) != 0).all() + + # Check draws respect random_seed + with m: + new_idata = opt_sample( + **sample_kwargs, + random_seed=0, + ) + np.testing.assert_allclose( + idata.posterior["p"].isel(draw=0), new_idata.posterior["p"].isel(draw=0) + ) + + with m: + new_idata = opt_sample( + **sample_kwargs, + random_seed=1, + ) + assert not np.allclose(idata.posterior["p"].isel(draw=0), new_idata.posterior["p"].isel(draw=0)) diff --git a/tests/sampling/optimizations/test_conjugacy.py b/tests/sampling/optimizations/test_conjugacy.py new file mode 100644 index 00000000..b45aabfc --- /dev/null +++ b/tests/sampling/optimizations/test_conjugacy.py @@ -0,0 +1,59 @@ +import numpy as np +import pytest + +from pymc.distributions import Beta, Binomial, DiracDelta +from pymc.model.core import Model +from pymc.model.transform.conditioning import remove_value_transforms +from pymc.sampling import draw + +from pymc_experimental.sampling.optimizations.conjugate_sampler import ConjugateRV +from pymc_experimental.sampling.optimizations.optimize import optimize_model_for_mcmc_sampling + + +@pytest.mark.parametrize("eager", [False, True]) +def test_beta_binomial_conjugacy(eager): + with Model() as m: + if eager: + a, b = DiracDelta("a,b", [1, 1]) + else: + a, b = 1, 1 + p = Beta("p", a, b) + y = Binomial("y", n=100, p=p, observed=99) + + assert m.rvs_to_transforms[p] is not None + assert isinstance(p.owner.op, Beta) + + new_m, rewrite_counters = optimize_model_for_mcmc_sampling(m) + rewrite_applied = "beta_binomial_conjugacy" in (r.name for rc in rewrite_counters for r in rc) + if eager: + assert not rewrite_applied + new_m, rewrite_counters = optimize_model_for_mcmc_sampling(m, include="conjugacy-eager") + assert "beta_binomial_conjugacy_eager" in (r.name for rc in rewrite_counters for r in rc) + else: + assert rewrite_applied + + new_p = new_m["p"] + assert isinstance(new_p.owner.op, ConjugateRV) + assert new_m.rvs_to_transforms[new_p] is None + beta_rv, conjugate_beta_rv, *_ = new_p.owner.outputs + + # Check it behaves like a beta and its conjugate + beta_draws, conjugate_beta_draws = draw( + [beta_rv, 
conjugate_beta_rv], draws=1000, random_seed=25 + ) + np.testing.assert_allclose(beta_draws.mean(), 1 / 2, atol=1e-2) + np.testing.assert_allclose(conjugate_beta_draws.mean(), 100 / 102, atol=1e-3) + np.testing.assert_allclose(beta_draws.std(), np.sqrt(1 / 12), atol=1e-2) + np.testing.assert_allclose( + conjugate_beta_draws.std(), np.sqrt(100 * 2 / (102**2 * 103)), atol=1e-3 + ) + + # Check if support point and logp is the same as the original model without transforms + untransformed_m = remove_value_transforms(m) + new_m_ip = new_m.initial_point() + for key, value in untransformed_m.initial_point().items(): + np.testing.assert_allclose(new_m_ip[key], value) + + new_m_logp = new_m.compile_logp()(new_m_ip) + untransformed_m_logp = untransformed_m.compile_logp()(new_m_ip) + np.testing.assert_allclose(new_m_logp, untransformed_m_logp) From 03593734c064f2677da02610725c191650a0c849 Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Thu, 2 Jan 2025 14:00:05 +0800 Subject: [PATCH 4/4] rename `pymc_experimental` -> `pymc_extras` --- pymc_extras/model/marginal/distributions.py | 1 - {pymc_experimental => pymc_extras}/sampling/__init__.py | 0 {pymc_experimental => pymc_extras}/sampling/mcmc.py | 2 +- .../sampling/optimizations/__init__.py | 5 ++--- .../sampling/optimizations/conjugacy.py | 4 ++-- .../sampling/optimizations/conjugate_sampler.py | 2 +- .../sampling/optimizations/optimize.py | 0 .../sampling/optimizations/summary_stats.py | 2 +- pymc_extras/utils/ofg.py | 2 +- tests/sampling/mcmc/test_mcmc.py | 2 +- tests/sampling/optimizations/test_conjugacy.py | 4 ++-- tests/sampling/optimizations/test_summary_stats.py | 2 +- 12 files changed, 12 insertions(+), 14 deletions(-) rename {pymc_experimental => pymc_extras}/sampling/__init__.py (100%) rename {pymc_experimental => pymc_extras}/sampling/mcmc.py (98%) rename {pymc_experimental => pymc_extras}/sampling/optimizations/__init__.py (54%) rename {pymc_experimental => pymc_extras}/sampling/optimizations/conjugacy.py (97%) rename {pymc_experimental => pymc_extras}/sampling/optimizations/conjugate_sampler.py (98%) rename {pymc_experimental => pymc_extras}/sampling/optimizations/optimize.py (100%) rename {pymc_experimental => pymc_extras}/sampling/optimizations/summary_stats.py (96%) diff --git a/pymc_extras/model/marginal/distributions.py b/pymc_extras/model/marginal/distributions.py index de6ec5ba..d1807f61 100644 --- a/pymc_extras/model/marginal/distributions.py +++ b/pymc_extras/model/marginal/distributions.py @@ -7,7 +7,6 @@ from pymc.logprob.abstract import MeasurableOp, _logprob from pymc.logprob.basic import conditional_logp, logp from pymc.pytensorf import constant_fold -from pytensor import Variable from pytensor.compile.builders import OpFromGraph from pytensor.compile.mode import Mode from pytensor.graph import Op, vectorize_graph diff --git a/pymc_experimental/sampling/__init__.py b/pymc_extras/sampling/__init__.py similarity index 100% rename from pymc_experimental/sampling/__init__.py rename to pymc_extras/sampling/__init__.py diff --git a/pymc_experimental/sampling/mcmc.py b/pymc_extras/sampling/mcmc.py similarity index 98% rename from pymc_experimental/sampling/mcmc.py rename to pymc_extras/sampling/mcmc.py index c57e7fa2..65b9dd2e 100644 --- a/pymc_experimental/sampling/mcmc.py +++ b/pymc_extras/sampling/mcmc.py @@ -4,7 +4,7 @@ from pymc.sampling.mcmc import sample from pytensor.graph.rewriting.basic import GraphRewriter -from pymc_experimental.sampling.optimizations.optimize import ( +from pymc_extras.sampling.optimizations.optimize 
import ( TAGS_TYPE, optimize_model_for_mcmc_sampling, ) diff --git a/pymc_experimental/sampling/optimizations/__init__.py b/pymc_extras/sampling/optimizations/__init__.py similarity index 54% rename from pymc_experimental/sampling/optimizations/__init__.py rename to pymc_extras/sampling/optimizations/__init__.py index 78d0c858..46fdf848 100644 --- a/pymc_experimental/sampling/optimizations/__init__.py +++ b/pymc_extras/sampling/optimizations/__init__.py @@ -1,9 +1,8 @@ # ruff: noqa: F401 # Add rewrites to the optimization DBs -import pymc_experimental.sampling.optimizations.conjugacy -import pymc_experimental.sampling.optimizations.summary_stats -from pymc_experimental.sampling.optimizations.optimize import ( +from pymc_extras.sampling.optimizations import conjugacy, summary_stats +from pymc_extras.sampling.optimizations.optimize import ( optimize_model_for_mcmc_sampling, posterior_optimization_db, ) diff --git a/pymc_experimental/sampling/optimizations/conjugacy.py b/pymc_extras/sampling/optimizations/conjugacy.py similarity index 97% rename from pymc_experimental/sampling/optimizations/conjugacy.py rename to pymc_extras/sampling/optimizations/conjugacy.py index f4c3a360..f1f6e246 100644 --- a/pymc_experimental/sampling/optimizations/conjugacy.py +++ b/pymc_extras/sampling/optimizations/conjugacy.py @@ -10,10 +10,10 @@ from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.subtensor import _sum_grad_over_bcasted_dims as sum_bcasted_dims -from pymc_experimental.sampling.optimizations.conjugate_sampler import ( +from pymc_extras.sampling.optimizations.conjugate_sampler import ( ConjugateRV, ) -from pymc_experimental.sampling.optimizations.optimize import posterior_optimization_db +from pymc_extras.sampling.optimizations.optimize import posterior_optimization_db def register_conjugacy_rewrites_variants(rewrite_fn, tracks=(ModelFreeRV,)): diff --git a/pymc_experimental/sampling/optimizations/conjugate_sampler.py b/pymc_extras/sampling/optimizations/conjugate_sampler.py similarity index 98% rename from pymc_experimental/sampling/optimizations/conjugate_sampler.py rename to pymc_extras/sampling/optimizations/conjugate_sampler.py index 9ddaa462..1502cfe1 100644 --- a/pymc_experimental/sampling/optimizations/conjugate_sampler.py +++ b/pymc_extras/sampling/optimizations/conjugate_sampler.py @@ -13,7 +13,7 @@ from pytensor.link.jax.linker import JAXLinker from pytensor.tensor.random.type import RandomGeneratorType -from pymc_experimental.utils.ofg import inline_ofg_outputs +from pymc_extras.utils.ofg import inline_ofg_outputs class ConjugateRV(OpFromGraph, MeasurableOp): diff --git a/pymc_experimental/sampling/optimizations/optimize.py b/pymc_extras/sampling/optimizations/optimize.py similarity index 100% rename from pymc_experimental/sampling/optimizations/optimize.py rename to pymc_extras/sampling/optimizations/optimize.py diff --git a/pymc_experimental/sampling/optimizations/summary_stats.py b/pymc_extras/sampling/optimizations/summary_stats.py similarity index 96% rename from pymc_experimental/sampling/optimizations/summary_stats.py rename to pymc_extras/sampling/optimizations/summary_stats.py index 58a9c806..167e2f1d 100644 --- a/pymc_experimental/sampling/optimizations/summary_stats.py +++ b/pymc_extras/sampling/optimizations/summary_stats.py @@ -5,7 +5,7 @@ from pytensor.graph.fg import FunctionGraph from pytensor.graph.rewriting.basic import node_rewriter -from pymc_experimental.sampling.optimizations.optimize import posterior_optimization_db +from 
pymc_extras.sampling.optimizations.optimize import posterior_optimization_db @node_rewriter(tracks=[ModelObservedRV]) diff --git a/pymc_extras/utils/ofg.py b/pymc_extras/utils/ofg.py index 6de8ed4a..c85c85cb 100644 --- a/pymc_extras/utils/ofg.py +++ b/pymc_extras/utils/ofg.py @@ -5,7 +5,7 @@ from pytensor.graph.replace import clone_replace -def inline_ofg_outputs(op: OpFromGraph, inputs: Sequence[Variable]) -> tuple[Variable]: +def inline_ofg_outputs(op: OpFromGraph, inputs: Sequence[Variable]) -> list[Variable]: """Inline the inner graph (outputs) of an OpFromGraph Op. Whereas `OpFromGraph` "wraps" a graph inside a single Op, this function "unwraps" diff --git a/tests/sampling/mcmc/test_mcmc.py b/tests/sampling/mcmc/test_mcmc.py index f179e22d..beda5bbe 100644 --- a/tests/sampling/mcmc/test_mcmc.py +++ b/tests/sampling/mcmc/test_mcmc.py @@ -8,7 +8,7 @@ from pymc.sampling.mcmc import sample from pymc.step_methods import Slice -from pymc_experimental import opt_sample +from pymc_extras import opt_sample def test_custom_step_raises(): diff --git a/tests/sampling/optimizations/test_conjugacy.py b/tests/sampling/optimizations/test_conjugacy.py index b45aabfc..658f56d4 100644 --- a/tests/sampling/optimizations/test_conjugacy.py +++ b/tests/sampling/optimizations/test_conjugacy.py @@ -6,8 +6,8 @@ from pymc.model.transform.conditioning import remove_value_transforms from pymc.sampling import draw -from pymc_experimental.sampling.optimizations.conjugate_sampler import ConjugateRV -from pymc_experimental.sampling.optimizations.optimize import optimize_model_for_mcmc_sampling +from pymc_extras.sampling.optimizations import optimize_model_for_mcmc_sampling +from pymc_extras.sampling.optimizations.conjugate_sampler import ConjugateRV @pytest.mark.parametrize("eager", [False, True]) diff --git a/tests/sampling/optimizations/test_summary_stats.py b/tests/sampling/optimizations/test_summary_stats.py index bb390e0e..bf439bbe 100644 --- a/tests/sampling/optimizations/test_summary_stats.py +++ b/tests/sampling/optimizations/test_summary_stats.py @@ -3,7 +3,7 @@ from pymc.distributions import HalfNormal, Normal from pymc.model.core import Model -from pymc_experimental.sampling.optimizations.optimize import optimize_model_for_mcmc_sampling +from pymc_extras.sampling.optimizations import optimize_model_for_mcmc_sampling def test_summary_stats_normal():
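
Note on patch 1: `opt_sample` rejects a `step` argument because a custom step built against the
original model would reference variables that no longer exist in the rewritten model. The error
message points to `optimize_model_for_mcmc_sampling` as the manual route. A minimal sketch of that
workflow, using the pre-rename `pymc_experimental` import path from these patches (the model and
the `Slice` step are illustrative, not part of the diff):

.. code:: python

    import pymc as pm

    from pymc_experimental.sampling.optimizations.optimize import (
        optimize_model_for_mcmc_sampling,
    )

    with pm.Model() as m:
        mu = pm.Normal("mu")
        sigma = pm.HalfNormal("sigma")
        y = pm.Normal("y", mu=mu, sigma=sigma, observed=[0.1, 0.9, 1.4])

    # Transform the model first, then build custom steps against the *optimized* model
    opt_m, rewrite_counters = optimize_model_for_mcmc_sampling(m)

    with opt_m:
        step = pm.Slice([opt_m["mu"]])  # refers to the optimized model's variable
        idata = pm.sample(step=step)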
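Note on patch 2: the `summary_stats_normal` rewrite implements the sufficient-statistics identity
from the linked Stan user's guide. With the sample mean and the sample variance (`ddof=1`) as
statistics, the iid normal likelihood factorizes, up to a constant independent of mu and sigma,
exactly as the rewrite encodes it via `Gamma(alpha=nm1_over2, beta=nm1_over2 / (sigma**2))`:

.. math::

    \prod_{i=1}^{N} \mathcal{N}(y_i \mid \mu, \sigma)
    \;\propto\;
    \mathcal{N}\!\left(\bar{y} \,\middle|\, \mu, \frac{\sigma}{\sqrt{N}}\right)
    \cdot
    \operatorname{Gamma}\!\left(s^{2} \,\middle|\, \frac{N-1}{2},\, \frac{N-1}{2\sigma^{2}}\right),
    \qquad
    \bar{y} = \frac{1}{N}\sum_i y_i,\quad
    s^{2} = \frac{1}{N-1}\sum_i (y_i - \bar{y})^{2}.

This is why `test_summary_stats_normal` asserts that the logp of the two models differs only by a
constant, while the dlogp matches exactly.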
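Note on patch 3: the rewrite relies on the standard Beta-Binomial conjugate update: if
p ~ Beta(a, b) and y ~ Binomial(n, p), then p | y ~ Beta(a + y, b + n - y). A small self-contained
check of the posterior moments asserted in `test_sample_opt_conjugate` for a = b = 1, n = 100,
y = 99 (plain NumPy arithmetic, not part of the patch):

.. code:: python

    import numpy as np

    a, b, n, y = 1, 1, 100, 99
    post_a, post_b = a + y, b + n - y  # p | y ~ Beta(100, 2)

    post_mean = post_a / (post_a + post_b)
    post_var = post_a * post_b / ((post_a + post_b) ** 2 * (post_a + post_b + 1))

    # Same values the test compares the MCMC draws against
    np.testing.assert_allclose(post_mean, 100 / 102)
    np.testing.assert_allclose(np.sqrt(post_var), np.sqrt(100 * 2 / (102**2 * 103)))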