From b00948c92a8b1becb1cf02cb5be7c6fefa47cb66 Mon Sep 17 00:00:00 2001
From: Yerdos Ordabayev
Date: Sat, 3 Jul 2021 00:57:19 -0400
Subject: [PATCH 01/53] trace_elbo

---
 pyro/contrib/funsor/infer/trace_elbo.py | 15 +++++----------
 1 file changed, 5 insertions(+), 10 deletions(-)

diff --git a/pyro/contrib/funsor/infer/trace_elbo.py b/pyro/contrib/funsor/infer/trace_elbo.py
index e91787732f..ea9fc9805e 100644
--- a/pyro/contrib/funsor/infer/trace_elbo.py
+++ b/pyro/contrib/funsor/infer/trace_elbo.py
@@ -29,19 +29,14 @@ def differentiable_loss(self, model, guide, *args, **kwargs):
         model_terms = terms_from_trace(model_tr)
         guide_terms = terms_from_trace(guide_tr)
 
-        log_measures = guide_terms["log_measures"] + model_terms["log_measures"]
-        log_factors = model_terms["log_factors"] + [
+        costs = model_terms["log_factors"] + [
             -f for f in guide_terms["log_factors"]
         ]
         plate_vars = model_terms["plate_vars"] | guide_terms["plate_vars"]
-        measure_vars = model_terms["measure_vars"] | guide_terms["measure_vars"]
-
-        elbo = funsor.Integrate(
-            sum(log_measures, to_funsor(0.0)),
-            sum(log_factors, to_funsor(0.0)),
-            measure_vars,
-        )
-        elbo = elbo.reduce(funsor.ops.add, plate_vars)
+
+        elbo = to_funsor(0.)
+        for cost in costs:
+            elbo += cost.reduce(funsor.ops.add, plate_vars & frozenset(cost.inputs))
 
         return -to_data(elbo)

From da2f8876b5abc467cd3f9748c9d298f76cfcb270 Mon Sep 17 00:00:00 2001
From: Yerdos Ordabayev
Date: Sat, 3 Jul 2021 00:58:59 -0400
Subject: [PATCH 02/53] lint

---
 pyro/contrib/funsor/infer/trace_elbo.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/pyro/contrib/funsor/infer/trace_elbo.py b/pyro/contrib/funsor/infer/trace_elbo.py
index ea9fc9805e..3de6774696 100644
--- a/pyro/contrib/funsor/infer/trace_elbo.py
+++ b/pyro/contrib/funsor/infer/trace_elbo.py
@@ -29,12 +29,10 @@ def differentiable_loss(self, model, guide, *args, **kwargs):
         model_terms = terms_from_trace(model_tr)
         guide_terms = terms_from_trace(guide_tr)
 
-        costs = model_terms["log_factors"] + [
-            -f for f in guide_terms["log_factors"]
-        ]
+        costs = model_terms["log_factors"] + [-f for f in guide_terms["log_factors"]]
         plate_vars = model_terms["plate_vars"] | guide_terms["plate_vars"]
 
-        elbo = to_funsor(0.)
+        elbo = to_funsor(0.0)
         for cost in costs:
             elbo += cost.reduce(funsor.ops.add, plate_vars & frozenset(cost.inputs))

From 3ec076d24496d3e72b5834d9e060f59fa6e21b2a Mon Sep 17 00:00:00 2001
From: Yerdos Ordabayev
Date: Mon, 5 Jul 2021 05:01:54 -0400
Subject: [PATCH 03/53] test_gradient

---
 tests/contrib/funsor/test_gradient.py | 129 ++++++++++++++++++++++++++
 1 file changed, 129 insertions(+)
 create mode 100644 tests/contrib/funsor/test_gradient.py

diff --git a/tests/contrib/funsor/test_gradient.py b/tests/contrib/funsor/test_gradient.py
new file mode 100644
index 0000000000..0ea46fc913
--- /dev/null
+++ b/tests/contrib/funsor/test_gradient.py
@@ -0,0 +1,129 @@
+# Copyright (c) 2017-2019 Uber Technologies, Inc.
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+
+import pytest
+import torch
+
+from pyro.distributions.testing import fakes
+from tests.common import assert_equal
+
+# put all funsor-related imports here, so test collection works without funsor
+try:
+    import funsor
+
+    import pyro.contrib.funsor
+
+    funsor.set_backend("torch")
+    from pyroapi import distributions as dist
+    from pyroapi import handlers, infer, pyro, pyro_backend
+except ImportError:
+    pytestmark = pytest.mark.skip(reason="funsor is not installed")
+
+logger = logging.getLogger(__name__)
+
+# _PYRO_BACKEND = os.environ.get("TEST_ENUM_PYRO_BACKEND", "contrib.funsor")
+
+
+@pytest.mark.parametrize(
+    "reparameterized,has_rsample",
+    [(True, None), (True, False), (True, True), (False, None)],
+    ids=["reparam", "reparam-False", "reparam-True", "nonreparam"],
+)
+@pytest.mark.parametrize(
+    "Elbo",
+    [
+        "Trace_ELBO",
+        "TraceEnum_ELBO",
+    ],
+)
+@pytest.mark.parametrize(
+    "backend",
+    [
+        "pyro",
+        "contrib.funsor",
+    ],
+)
+def test_particle_gradient(Elbo, reparameterized, has_rsample, backend):
+    with pyro_backend(backend):
+        pyro.clear_param_store()
+        data = torch.tensor([-0.5, 2.0])
+        Normal = dist.Normal if reparameterized else fakes.NonreparameterizedNormal
+
+        def model():
+            with pyro.plate("data", len(data)) as ind:
+                x = data[ind]
+                z = pyro.sample("z", Normal(0, 1))
+                pyro.sample("x", Normal(z, 1), obs=x)
+
+        def guide():
+            scale = pyro.param("scale", lambda: torch.tensor([1.0]))
+            with pyro.plate("data", len(data)):
+                loc = pyro.param("loc", lambda: torch.zeros(len(data)), event_dim=0)
+                z_dist = Normal(loc, scale)
+                if has_rsample is not None:
+                    z_dist.has_rsample_(has_rsample)
+                pyro.sample("z", z_dist)
+
+        elbo = getattr(infer, Elbo)(
+            max_plate_nesting=1,  # set this to ensure rng agrees across runs
+            num_particles=1,
+            strict_enumeration_warning=False,
+        )
+
+        # Elbo gradient estimator
+        pyro.set_rng_seed(0)
+        elbo.loss_and_grads(model, guide)
+        params = dict(pyro.get_param_store().named_parameters())
+        actual_grads = {
+            name: param.grad.detach().cpu() for name, param in params.items()
+        }
+
+        # capture sample values and log_probs
+        pyro.set_rng_seed(0)
+        guide_tr = handlers.trace(guide).get_trace()
+        model_tr = handlers.trace(handlers.replay(model, guide_tr)).get_trace()
+        guide_tr.compute_log_prob()
+        model_tr.compute_log_prob()
+        x = data
+        z = guide_tr.nodes["z"]["value"].data
+        loc = pyro.param("loc").data
+        scale = pyro.param("scale").data
+
+        # expected grads
+        if reparameterized and has_rsample is not False:
+            # pathwise gradient estimator
+            expected_grads = {
+                "scale": -(-z * (z - loc) + (x - z) * (z - loc) + 1).sum(
+                    0, keepdim=True
+                )
+                / scale,
+                "loc": -(-z + (x - z)),
+            }
+        else:
+            # score function gradient estimator
+            elbo = (
+                model_tr.nodes["x"]["log_prob"].data
+                + model_tr.nodes["z"]["log_prob"].data
+                - guide_tr.nodes["z"]["log_prob"].data
+            )
+            dlogq_dloc = (z - loc) / scale ** 2
+            dlogq_dscale = (z - loc) ** 2 / scale ** 3 - 1 / scale
+            if Elbo == "TraceEnum_ELBO":
+                expected_grads = {
+                    "scale": -(dlogq_dscale * elbo - dlogq_dscale).sum(0, keepdim=True),
+                    "loc": -(dlogq_dloc * elbo - dlogq_dloc),
+                }
+            elif Elbo == "Trace_ELBO":
+                # expected value of dlogq_dscale and dlogq_dloc is zero
+                expected_grads = {
+                    "scale": -(dlogq_dscale * elbo).sum(0, keepdim=True),
+                    "loc": -(dlogq_dloc * elbo),
+                }
+
+        for name in sorted(params):
+            logger.info("expected {} = {}".format(name, expected_grads[name]))
+            logger.info("actual {} = {}".format(name, actual_grads[name]))
+
+        assert_equal(actual_grads, expected_grads, prec=1e-4)
From a22ff4e93a19174d210d3e5d5035e958b4bc1a42 Mon Sep 17 00:00:00 2001
From: Yerdos Ordabayev
Date: Fri, 16 Jul 2021 01:45:50 -0400
Subject: [PATCH 04/53] copy traceenum_elbo and add test model with poisson dist

---
 pyro/contrib/funsor/infer/trace_elbo.py | 75 +++++++++++++++++++---
 tests/contrib/funsor/test_gradient.py   | 82 ++++++++----------------
 2 files changed, 89 insertions(+), 68 deletions(-)

diff --git a/pyro/contrib/funsor/infer/trace_elbo.py b/pyro/contrib/funsor/infer/trace_elbo.py
index 3de6774696..77fc036751 100644
--- a/pyro/contrib/funsor/infer/trace_elbo.py
+++ b/pyro/contrib/funsor/infer/trace_elbo.py
@@ -4,6 +4,7 @@
 import contextlib
 
 import funsor
+from funsor.adjoint import AdjointTape
 
 from pyro.contrib.funsor import to_data, to_funsor
 from pyro.contrib.funsor.handlers import enum, plate, replay, trace
@@ -12,7 +13,7 @@
 from pyro.infer import Trace_ELBO as _OrigTrace_ELBO
 
 from .elbo import ELBO, Jit_ELBO
-from .traceenum_elbo import terms_from_trace
+from .traceenum_elbo import apply_optimizer, terms_from_trace
 
 
 @copy_docs_from(_OrigTrace_ELBO)
@@ -21,22 +22,76 @@ def differentiable_loss(self, model, guide, *args, **kwargs):
         with enum(), plate(
             size=self.num_particles
         ) if self.num_particles > 1 else contextlib.ExitStack():
-            guide_tr = trace(config_enumerate(default="flat")(guide)).get_trace(
-                *args, **kwargs
-            )
+            guide_tr = trace(
+                config_enumerate(default="flat", num_samples=self.num_particles)(guide)
+            ).get_trace(*args, **kwargs)
             model_tr = trace(replay(model, trace=guide_tr)).get_trace(*args, **kwargs)
 
         model_terms = terms_from_trace(model_tr)
         guide_terms = terms_from_trace(guide_tr)
+        breakpoint()
+
+        with funsor.terms.eager:
+            costs = model_terms["log_factors"] + [
+                -f for f in guide_terms["log_factors"]
+            ]
 
-        costs = model_terms["log_factors"] + [-f for f in guide_terms["log_factors"]]
-        plate_vars = model_terms["plate_vars"] | guide_terms["plate_vars"]
+            # compute expected cost
+            # Cf. pyro.infer.util.Dice.compute_expectation()
+            # https://github.com/pyro-ppl/pyro/blob/0.3.0/pyro/infer/util.py#L212
+            # TODO Replace this with funsor.Expectation
+            plate_vars = guide_terms["plate_vars"] | model_terms["plate_vars"]
+            # compute the marginal logq in the guide corresponding to each cost term
+            targets = dict()
+            for cost in costs:
+                input_vars = frozenset(cost.inputs)
+                if input_vars not in targets:
+                    targets[input_vars] = funsor.Tensor(
+                        funsor.ops.new_zeros(
+                            funsor.tensor.get_default_prototype(),
+                            tuple(v.size for v in cost.inputs.values()),
+                        ),
+                        cost.inputs,
+                        cost.dtype,
+                    )
+            with AdjointTape() as tape:
+                logzq = funsor.sum_product.sum_product(
+                    funsor.ops.logaddexp,
+                    funsor.ops.add,
+                    guide_terms["log_measures"] + list(targets.values()),
+                    plates=plate_vars,
+                    eliminate=(plate_vars | guide_terms["measure_vars"]),
+                )
+            marginals = tape.adjoint(
+                funsor.ops.logaddexp, funsor.ops.add, logzq, tuple(targets.values())
+            )
+            # finally, integrate out guide variables in the elbo and all plates
+            elbo = to_funsor(0, output=funsor.Real)
+            for cost in costs:
+                target = targets[frozenset(cost.inputs)]
+                logzq_local = marginals[target].reduce(
+                    funsor.ops.logaddexp, frozenset(cost.inputs) - plate_vars
+                )
+                log_prob = marginals[target] - logzq_local
+                elbo_term = funsor.Integrate(
+                    log_prob,
+                    cost,
+                    guide_terms["measure_vars"] & frozenset(log_prob.inputs),
+                )
+                elbo += elbo_term.reduce(
+                    funsor.ops.add, plate_vars & frozenset(cost.inputs)
+                )
 
-        elbo = to_funsor(0.0)
-        for cost in costs:
-            elbo += cost.reduce(funsor.ops.add, plate_vars & frozenset(cost.inputs))
+        # elbo = to_funsor(0.0)
+        # for cost in costs:
+        #     elbo += cost.reduce(funsor.ops.add, plate_vars & frozenset(cost.inputs))
 
-        return -to_data(elbo)
+        # return -to_data(elbo)
+        # evaluate the elbo, using memoize to share tensor computation where possible
+        with funsor.interpretations.memoize():
+            result = -to_data(apply_optimizer(elbo))
+            return result
+        # return -to_data(apply_optimizer(elbo))

diff --git a/tests/contrib/funsor/test_gradient.py b/tests/contrib/funsor/test_gradient.py
index 0ea46fc913..b71bc46c0c 100644
--- a/tests/contrib/funsor/test_gradient.py
+++ b/tests/contrib/funsor/test_gradient.py
@@ -27,43 +27,31 @@
 
 
 @pytest.mark.parametrize(
-    "reparameterized,has_rsample",
-    [(True, None), (True, False), (True, True), (False, None)],
-    ids=["reparam", "reparam-False", "reparam-True", "nonreparam"],
-)
-@pytest.mark.parametrize(
-    "Elbo",
+    "Elbo,backend",
     [
-        "Trace_ELBO",
-        "TraceEnum_ELBO",
+        ("TraceEnum_ELBO", "pyro"),
+        ("Trace_ELBO", "contrib.funsor"),
     ],
 )
-@pytest.mark.parametrize(
-    "backend",
-    [
-        "pyro",
-        "contrib.funsor",
-    ],
-)
-def test_particle_gradient(Elbo, reparameterized, has_rsample, backend):
+def test_particle_gradient(Elbo, backend):
     with pyro_backend(backend):
         pyro.clear_param_store()
         data = torch.tensor([-0.5, 2.0])
-        Normal = dist.Normal if reparameterized else fakes.NonreparameterizedNormal
+        # Normal = dist.Normal if reparameterized else fakes.NonreparameterizedNormal
 
         def model():
             with pyro.plate("data", len(data)) as ind:
                 x = data[ind]
-                z = pyro.sample("z", Normal(0, 1))
-                pyro.sample("x", Normal(z, 1), obs=x)
+                z = pyro.sample("z", dist.Poisson(3))
+                pyro.sample("x", dist.Normal(z, 1), obs=x)
 
         def guide():
-            scale = pyro.param("scale", lambda: torch.tensor([1.0]))
+            # scale = pyro.param("scale", lambda: torch.tensor([1.0]))
             with pyro.plate("data", len(data)):
-                loc = pyro.param("loc", lambda: torch.zeros(len(data)), event_dim=0)
-                z_dist = Normal(loc, scale)
-                if has_rsample is not None:
-                    z_dist.has_rsample_(has_rsample)
+                rate = pyro.param("rate", lambda: torch.tensor([3.5, 1.5]), event_dim=0)
+                z_dist = dist.Poisson(rate)
+                # if has_rsample is not None:
+                #     z_dist.has_rsample_(has_rsample)
                 pyro.sample("z", z_dist)
 
         elbo = getattr(infer, Elbo)(
@@ -74,7 +62,7 @@ def guide():
 
         # Elbo gradient estimator
         pyro.set_rng_seed(0)
-        elbo.loss_and_grads(model, guide)
+        loss = elbo.loss_and_grads(model, guide)
         params = dict(pyro.get_param_store().named_parameters())
         actual_grads = {
             name: param.grad.detach().cpu() for name, param in params.items()
@@ -88,39 +76,17 @@ def guide():
         model_tr.compute_log_prob()
         x = data
         z = guide_tr.nodes["z"]["value"].data
-        loc = pyro.param("loc").data
-        scale = pyro.param("scale").data
-
-        # expected grads
-        if reparameterized and has_rsample is not False:
-            # pathwise gradient estimator
-            expected_grads = {
-                "scale": -(-z * (z - loc) + (x - z) * (z - loc) + 1).sum(
-                    0, keepdim=True
-                )
-                / scale,
-                "loc": -(-z + (x - z)),
-            }
-        else:
-            # score function gradient estimator
-            elbo = (
-                model_tr.nodes["x"]["log_prob"].data
-                + model_tr.nodes["z"]["log_prob"].data
-                - guide_tr.nodes["z"]["log_prob"].data
-            )
-            dlogq_dloc = (z - loc) / scale ** 2
-            dlogq_dscale = (z - loc) ** 2 / scale ** 3 - 1 / scale
-            if Elbo == "TraceEnum_ELBO":
-                expected_grads = {
-                    "scale": -(dlogq_dscale * elbo - dlogq_dscale).sum(0, keepdim=True),
-                    "loc": -(dlogq_dloc * elbo - dlogq_dloc),
-                }
-            elif Elbo == "Trace_ELBO":
-                # expected value of dlogq_dscale and dlogq_dloc is zero
-                expected_grads = {
-                    "scale": -(dlogq_dscale * elbo).sum(0, keepdim=True),
-                    "loc": -(dlogq_dloc * elbo),
-                }
+        rate = pyro.param("rate").data
+
+        loss_i = (
+            model_tr.nodes["x"]["log_prob"].data
+            + model_tr.nodes["z"]["log_prob"].data
+            - guide_tr.nodes["z"]["log_prob"].data
+        )
+        dlogq_drate = z / rate - 1
+        expected_grads = {
+            "rate": -(dlogq_drate * loss_i - dlogq_drate),
+        }
 
         for name in sorted(params):
             logger.info("expected {} = {}".format(name, expected_grads[name]))

From d551fa2930b46841c42eea4892a62da08035d683 Mon Sep 17 00:00:00 2001
From: Yerdos Ordabayev
Date: Fri, 16 Jul 2021 01:54:41 -0400
Subject: [PATCH 05/53] lint

---
 pyro/contrib/funsor/infer/trace_elbo.py | 10 +---------
 tests/contrib/funsor/test_gradient.py   |  4 +---
 2 files changed, 2 insertions(+), 12 deletions(-)

diff --git a/pyro/contrib/funsor/infer/trace_elbo.py b/pyro/contrib/funsor/infer/trace_elbo.py
index 77fc036751..559c8e0871 100644
--- a/pyro/contrib/funsor/infer/trace_elbo.py
+++ b/pyro/contrib/funsor/infer/trace_elbo.py
@@ -29,7 +29,6 @@ def differentiable_loss(self, model, guide, *args, **kwargs):
 
         model_terms = terms_from_trace(model_tr)
         guide_terms = terms_from_trace(guide_tr)
-        breakpoint()
 
         with funsor.terms.eager:
             costs = model_terms["log_factors"] + [
@@ -82,16 +81,9 @@ def differentiable_loss(self, model, guide, *args, **kwargs):
                     funsor.ops.add, plate_vars & frozenset(cost.inputs)
                 )
 
-        # elbo = to_funsor(0.0)
-        # for cost in costs:
-        #     elbo += cost.reduce(funsor.ops.add, plate_vars & frozenset(cost.inputs))
-
-        # return -to_data(elbo)
         # evaluate the elbo, using memoize to share tensor computation where possible
         with funsor.interpretations.memoize():
-            result = -to_data(apply_optimizer(elbo))
-            return result
-        # return -to_data(apply_optimizer(elbo))
+            return -to_data(apply_optimizer(elbo))

diff --git a/tests/contrib/funsor/test_gradient.py b/tests/contrib/funsor/test_gradient.py
index b71bc46c0c..a7f0eae680 100644
--- a/tests/contrib/funsor/test_gradient.py
+++ b/tests/contrib/funsor/test_gradient.py
@@ -6,7 +6,6 @@
 import pytest
 import torch
 
-from pyro.distributions.testing import fakes
 from tests.common import assert_equal
 
 # put all funsor-related imports here, so test collection works without funsor
@@ -62,7 +61,7 @@ def guide():
 
         # Elbo gradient estimator
         pyro.set_rng_seed(0)
-        loss = elbo.loss_and_grads(model, guide)
+        elbo.loss_and_grads(model, guide)
         params = dict(pyro.get_param_store().named_parameters())
         actual_grads = {
             name: param.grad.detach().cpu() for name, param in params.items()
@@ -74,7 +73,6 @@ def guide():
         model_tr = handlers.trace(handlers.replay(model, guide_tr)).get_trace()
         guide_tr.compute_log_prob()
         model_tr.compute_log_prob()
-        x = data
         z = guide_tr.nodes["z"]["value"].data
         rate = pyro.param("rate").data
From b68bb3fc53cbb356c1b11e26484620942a474086 Mon Sep 17 00:00:00 2001
From: Yerdos Ordabayev
Date: Wed, 21 Jul 2021 01:30:39 -0400
Subject: [PATCH 06/53] use constant funsor

---
 .../funsor/handlers/trace_messenger.py       |  2 ++
 pyro/contrib/funsor/infer/traceenum_elbo.py  | 22 ++++++++++++-------
 2 files changed, 16 insertions(+), 8 deletions(-)

diff --git a/pyro/contrib/funsor/handlers/trace_messenger.py b/pyro/contrib/funsor/handlers/trace_messenger.py
index f517967079..63f10eb5f7 100644
--- a/pyro/contrib/funsor/handlers/trace_messenger.py
+++ b/pyro/contrib/funsor/handlers/trace_messenger.py
@@ -3,6 +3,7 @@
 
 import funsor
 import torch
+from funsor.torch import MetadataTensor
 
 from pyro.contrib.funsor.handlers.primitives import to_funsor
 from pyro.contrib.funsor.handlers.runtime import _DIM_STACK
@@ -38,6 +39,7 @@ def _pyro_post_sample(self, msg):
             msg["funsor"] = {}
         if isinstance(msg["fn"], _Subsample):
             return super()._pyro_post_sample(msg)
+        # msg["value"] = MetadataTensor(msg["value"], metadata=frozenset({msg["name"]}))
        if self.pack_online:
             if "fn" not in msg["funsor"]:
                 fn_masked = _mask_fn(msg["fn"], msg["mask"])

diff --git a/pyro/contrib/funsor/infer/traceenum_elbo.py b/pyro/contrib/funsor/infer/traceenum_elbo.py
index 7915f7e101..c9e72c1386 100644
--- a/pyro/contrib/funsor/infer/traceenum_elbo.py
+++ b/pyro/contrib/funsor/infer/traceenum_elbo.py
@@ -3,8 +3,10 @@
 
 import contextlib
 
+import torch
 import funsor
 from funsor.adjoint import AdjointTape
+from funsor.constant import Constant
 
 from pyro.contrib.funsor import to_data, to_funsor
 from pyro.contrib.funsor.handlers import enum, plate, replay, trace
@@ -186,7 +188,7 @@ def differentiable_loss(self, model, guide, *args, **kwargs):
         model_terms = terms_from_trace(model_tr)
 
         # build up a lazy expression for the elbo
-        with funsor.terms.lazy:
+        with funsor.terms.eager:
             # identify and contract out auxiliary variables in the model with partial_sum_product
             contracted_factors, uncontracted_factors = [], []
             for f in model_terms["log_factors"]:
@@ -220,15 +222,19 @@ def differentiable_loss(self, model, guide, *args, **kwargs):
             for cost in costs:
                 input_vars = frozenset(cost.inputs)
                 if input_vars not in targets:
-                    targets[input_vars] = funsor.Tensor(
-                        funsor.ops.new_zeros(
-                            funsor.tensor.get_default_prototype(),
-                            tuple(v.size for v in cost.inputs.values()),
-                        ),
-                        cost.inputs,
-                        cost.dtype,
+                    targets[input_vars] = Constant(
+                        cost.input_vars, funsor.Tensor(torch.tensor(0))
                     )
+                    # targets[input_vars] = funsor.Tensor(
+                    #     funsor.ops.new_zeros(
+                    #         funsor.tensor.get_default_prototype(),
+                    #         tuple(v.size for v in cost.inputs.values()),
+                    #     ),
+                    #     cost.inputs,
+                    #     cost.dtype,
+                    # )
             with AdjointTape() as tape:
+                # breakpoint()
                 logzq = funsor.sum_product.sum_product(
                     funsor.ops.logaddexp,
                     funsor.ops.add,

From bfb13bf35f3636df3ba23dc52e3cc4543b476cce Mon Sep 17 00:00:00 2001
From: Yerdos Ordabayev
Date: Wed, 28 Jul 2021 03:36:22 -0400
Subject: [PATCH 07/53] working version

---
 .../contrib/funsor/handlers/enum_messenger.py |  7 ++++-
 .../funsor/handlers/trace_messenger.py        |  5 ++--
 pyro/contrib/funsor/infer/trace_elbo.py       | 29 +++++++++++++------
 pyro/contrib/funsor/infer/traceenum_elbo.py   |  5 ++--
 tests/contrib/funsor/test_gradient.py         |  6 ++--
 5 files changed, 33 insertions(+), 19 deletions(-)

diff --git a/pyro/contrib/funsor/handlers/enum_messenger.py b/pyro/contrib/funsor/handlers/enum_messenger.py
index 98e9a7ee33..30edf94c48 100644
--- a/pyro/contrib/funsor/handlers/enum_messenger.py
+++ b/pyro/contrib/funsor/handlers/enum_messenger.py
@@ -11,6 +11,7 @@
 
 import funsor
 import torch
+from funsor.torch.provenance import ProvenanceTensor
 
 import pyro.poutine.runtime
 import pyro.poutine.util
@@ -196,7 +197,11 @@ def _pyro_sample(self, msg):
             msg["name"],
             expand=msg["infer"].get("expand", False),
         )
-        msg["value"] = to_data(msg["funsor"]["value"])
+        # msg["value"] = to_data(msg["funsor"]["value"])
+        msg["value"] = ProvenanceTensor(
+            to_data(msg["funsor"]["value"]),
+            provenance=frozenset({(msg["name"], funsor.Real)})
+        )
         msg["done"] = True

diff --git a/pyro/contrib/funsor/handlers/trace_messenger.py b/pyro/contrib/funsor/handlers/trace_messenger.py
index 63f10eb5f7..9d4fb08adf 100644
--- a/pyro/contrib/funsor/handlers/trace_messenger.py
+++ b/pyro/contrib/funsor/handlers/trace_messenger.py
@@ -3,7 +3,6 @@
 
 import funsor
 import torch
-from funsor.torch import MetadataTensor
 
 from pyro.contrib.funsor.handlers.primitives import to_funsor
 from pyro.contrib.funsor.handlers.runtime import _DIM_STACK
@@ -39,14 +38,14 @@ def _pyro_post_sample(self, msg):
             msg["funsor"] = {}
         if isinstance(msg["fn"], _Subsample):
             return super()._pyro_post_sample(msg)
-        # msg["value"] = MetadataTensor(msg["value"], metadata=frozenset({msg["name"]}))
         if self.pack_online:
             if "fn" not in msg["funsor"]:
                 fn_masked = _mask_fn(msg["fn"], msg["mask"])
                 msg["funsor"]["fn"] = to_funsor(fn_masked, funsor.Real)(
                     value=msg["name"]
                 )
-            if "value" not in msg["funsor"]:
+            # if "value" not in msg["funsor"]:
+            if True:
                 # value_output = funsor.Reals[getattr(msg["fn"], "event_shape", ())]
                 msg["funsor"]["value"] = to_funsor(
                     msg["value"], msg["funsor"]["fn"].inputs[msg["name"]]

diff --git a/pyro/contrib/funsor/infer/trace_elbo.py b/pyro/contrib/funsor/infer/trace_elbo.py
index 559c8e0871..661453b9e9 100644
--- a/pyro/contrib/funsor/infer/trace_elbo.py
+++ b/pyro/contrib/funsor/infer/trace_elbo.py
@@ -4,7 +4,9 @@
 import contextlib
 
 import funsor
+import torch
 from funsor.adjoint import AdjointTape
+from funsor.constant import Constant
 
 from pyro.contrib.funsor import to_data, to_funsor
 from pyro.contrib.funsor.handlers import enum, plate, replay, trace
@@ -45,14 +47,18 @@ def differentiable_loss(self, model, guide, *args, **kwargs):
             for cost in costs:
                 input_vars = frozenset(cost.inputs)
                 if input_vars not in targets:
-                    targets[input_vars] = funsor.Tensor(
-                        funsor.ops.new_zeros(
-                            funsor.tensor.get_default_prototype(),
-                            tuple(v.size for v in cost.inputs.values()),
-                        ),
-                        cost.inputs,
-                        cost.dtype,
+                    const_inputs = tuple((v.name, v.output) for v in cost.input_vars)
+                    targets[input_vars] = Constant(
+                        const_inputs, funsor.Tensor(torch.tensor(0))
                     )
+                    # targets[input_vars] = funsor.Tensor(
+                    #     funsor.ops.new_zeros(
+                    #         funsor.tensor.get_default_prototype(),
+                    #         tuple(v.size for v in cost.inputs.values()),
+                    #     ),
+                    #     cost.inputs,
+                    #     cost.dtype,
+                    # )
             with AdjointTape() as tape:
                 logzq = funsor.sum_product.sum_product(
                     funsor.ops.logaddexp,
@@ -68,10 +74,15 @@ def differentiable_loss(self, model, guide, *args, **kwargs):
             elbo = to_funsor(0, output=funsor.Real)
             for cost in costs:
                 target = targets[frozenset(cost.inputs)]
-                logzq_local = marginals[target].reduce(
+                # logzq_local = marginals[target].reduce(
+                #     funsor.ops.logaddexp, frozenset(cost.inputs) - plate_vars
+                # )
+                logzq_local = guide_terms["log_measures"][0].reduce(
                     funsor.ops.logaddexp, frozenset(cost.inputs) - plate_vars
                 )
-                log_prob = marginals[target] - logzq_local
+                # breakpoint()
+                log_prob = marginals[target] - logzq + logzq_local
+                # log_prob = guide_terms["log_measures"][0]
                 elbo_term = funsor.Integrate(
                     log_prob,
                     cost,

diff --git a/pyro/contrib/funsor/infer/traceenum_elbo.py b/pyro/contrib/funsor/infer/traceenum_elbo.py
index c9e72c1386..82303513e0 100644
--- a/pyro/contrib/funsor/infer/traceenum_elbo.py
+++ b/pyro/contrib/funsor/infer/traceenum_elbo.py
@@ -3,8 +3,8 @@
 
 import contextlib
 
-import torch
 import funsor
+import torch
 from funsor.adjoint import AdjointTape
 from funsor.constant import Constant
 
@@ -225,7 +225,7 @@ def differentiable_loss(self, model, guide, *args, **kwargs):
                     targets[input_vars] = Constant(
                         cost.input_vars, funsor.Tensor(torch.tensor(0))
                     )
-                    # targets[input_vars] = funsor.Tensor(
+                    #     targets[input_vars] = funsor.Tensor(
                     #     funsor.ops.new_zeros(
                     #         funsor.tensor.get_default_prototype(),
                     #         tuple(v.size for v in cost.inputs.values()),
                     #     ),
                     #     cost.inputs,
                     #     cost.dtype,
                     # )
             with AdjointTape() as tape:
-                # breakpoint()
                 logzq = funsor.sum_product.sum_product(
                     funsor.ops.logaddexp,
                     funsor.ops.add,

diff --git a/tests/contrib/funsor/test_gradient.py b/tests/contrib/funsor/test_gradient.py
index a7f0eae680..41d44ee50a 100644
--- a/tests/contrib/funsor/test_gradient.py
+++ b/tests/contrib/funsor/test_gradient.py
@@ -28,7 +28,7 @@
 @pytest.mark.parametrize(
     "Elbo,backend",
     [
-        ("TraceEnum_ELBO", "pyro"),
+        # ("TraceEnum_ELBO", "pyro"),
         ("Trace_ELBO", "contrib.funsor"),
     ],
 )
@@ -41,14 +41,14 @@ def test_particle_gradient(Elbo, backend):
         def model():
             with pyro.plate("data", len(data)) as ind:
                 x = data[ind]
-                z = pyro.sample("z", dist.Poisson(3))
+                z = pyro.sample("z", dist.Poisson(3, validate_args=False))
                 pyro.sample("x", dist.Normal(z, 1), obs=x)
 
         def guide():
             # scale = pyro.param("scale", lambda: torch.tensor([1.0]))
             with pyro.plate("data", len(data)):
                 rate = pyro.param("rate", lambda: torch.tensor([3.5, 1.5]), event_dim=0)
-                z_dist = dist.Poisson(rate)
+                z_dist = dist.Poisson(rate, validate_args=False)
                 # if has_rsample is not None:
                 #     z_dist.has_rsample_(has_rsample)
                 pyro.sample("z", z_dist)
unsampled_log_measure = to_funsor(msg["fn"], output=funsor.Real)( value=msg["name"] ) + msg["funsor"]["unsampled_log_measure"] = unsampled_log_measure msg["funsor"]["log_measure"] = enumerate_site(unsampled_log_measure, msg) msg["funsor"]["value"] = _get_support_value( msg["funsor"]["log_measure"], @@ -200,7 +201,7 @@ def _pyro_sample(self, msg): # msg["value"] = to_data(msg["funsor"]["value"]) msg["value"] = ProvenanceTensor( to_data(msg["funsor"]["value"]), - provenance=frozenset({(msg["name"], funsor.Real)}) + provenance=frozenset({(msg["name"], funsor.Real)}), ) msg["done"] = True diff --git a/pyro/contrib/funsor/infer/trace_elbo.py b/pyro/contrib/funsor/infer/trace_elbo.py index 661453b9e9..247ef71b2c 100644 --- a/pyro/contrib/funsor/infer/trace_elbo.py +++ b/pyro/contrib/funsor/infer/trace_elbo.py @@ -2,11 +2,13 @@ # SPDX-License-Identifier: Apache-2.0 import contextlib +from collections import OrderedDict import funsor import torch from funsor.adjoint import AdjointTape from funsor.constant import Constant +from funsor.montecarlo import MonteCarlo from pyro.contrib.funsor import to_data, to_funsor from pyro.contrib.funsor.handlers import enum, plate, replay, trace @@ -51,19 +53,11 @@ def differentiable_loss(self, model, guide, *args, **kwargs): targets[input_vars] = Constant( const_inputs, funsor.Tensor(torch.tensor(0)) ) - # targets[input_vars] = funsor.Tensor( - # funsor.ops.new_zeros( - # funsor.tensor.get_default_prototype(), - # tuple(v.size for v in cost.inputs.values()), - # ), - # cost.inputs, - # cost.dtype, - # ) with AdjointTape() as tape: logzq = funsor.sum_product.sum_product( funsor.ops.logaddexp, funsor.ops.add, - guide_terms["log_measures"] + list(targets.values()), + guide_terms["unsampled_log_measures"] + list(targets.values()), plates=plate_vars, eliminate=(plate_vars | guide_terms["measure_vars"]), ) @@ -74,20 +68,23 @@ def differentiable_loss(self, model, guide, *args, **kwargs): elbo = to_funsor(0, output=funsor.Real) for cost in costs: target = targets[frozenset(cost.inputs)] + marginal = marginals[target] # logzq_local = marginals[target].reduce( # funsor.ops.logaddexp, frozenset(cost.inputs) - plate_vars # ) - logzq_local = guide_terms["log_measures"][0].reduce( - funsor.ops.logaddexp, frozenset(cost.inputs) - plate_vars - ) - # breakpoint() - log_prob = marginals[target] - logzq + logzq_local + # logzq_local = guide_terms["log_measures"][0].reduce( + # funsor.ops.logaddexp, frozenset(cost.inputs) - plate_vars + # ) + # log_prob = marginal.sample(frozenset(cost.inputs)-plate_vars) # - logzq + logzq_local # log_prob = guide_terms["log_measures"][0] - elbo_term = funsor.Integrate( - log_prob, - cost, - guide_terms["measure_vars"] & frozenset(log_prob.inputs), - ) + measure_vars = frozenset(cost.inputs) - plate_vars + _raw_value = {var: guide_tr.nodes[var]["value"]._t for var in measure_vars} + with MonteCarlo(raw_value=_raw_value): + elbo_term = funsor.Integrate( + marginal, + cost, + guide_terms["measure_vars"] & frozenset(marginal.inputs), + ) elbo += elbo_term.reduce( funsor.ops.add, plate_vars & frozenset(cost.inputs) ) diff --git a/pyro/contrib/funsor/infer/traceenum_elbo.py b/pyro/contrib/funsor/infer/traceenum_elbo.py index 82303513e0..acaa8e26e9 100644 --- a/pyro/contrib/funsor/infer/traceenum_elbo.py +++ b/pyro/contrib/funsor/infer/traceenum_elbo.py @@ -33,6 +33,7 @@ def terms_from_trace(tr): terms = { "log_factors": [], "log_measures": [], + "unsampled_log_measures": [], "scale": to_funsor(1.0), "plate_vars": frozenset(), "measure_vars": frozenset(), 
@@ -64,6 +65,7 @@ def terms_from_trace(tr): # grab the log-measure, found only at sites that are not replayed or observed if node["funsor"].get("log_measure", None) is not None: terms["log_measures"].append(node["funsor"]["log_measure"]) + terms["unsampled_log_measures"].append(node["funsor"]["unsampled_log_measure"]) # sum (measure) variables: the fresh non-plate variables at a site terms["measure_vars"] |= ( frozenset(node["funsor"]["value"].inputs) | {name} diff --git a/tests/contrib/funsor/test_gradient.py b/tests/contrib/funsor/test_gradient.py index 41d44ee50a..5ae6c16acb 100644 --- a/tests/contrib/funsor/test_gradient.py +++ b/tests/contrib/funsor/test_gradient.py @@ -5,6 +5,7 @@ import pytest import torch +from torch.distributions import constraints from tests.common import assert_equal @@ -28,30 +29,24 @@ @pytest.mark.parametrize( "Elbo,backend", [ - # ("TraceEnum_ELBO", "pyro"), + ("TraceEnum_ELBO", "pyro"), ("Trace_ELBO", "contrib.funsor"), ], ) -def test_particle_gradient(Elbo, backend): +def test_particle_gradient_0(Elbo, backend): with pyro_backend(backend): pyro.clear_param_store() data = torch.tensor([-0.5, 2.0]) - # Normal = dist.Normal if reparameterized else fakes.NonreparameterizedNormal def model(): - with pyro.plate("data", len(data)) as ind: - x = data[ind] - z = pyro.sample("z", dist.Poisson(3, validate_args=False)) - pyro.sample("x", dist.Normal(z, 1), obs=x) + with pyro.plate("data", len(data)): + z = pyro.sample("z", dist.Poisson(3)) + pyro.sample("x", dist.Normal(z, 1), obs=data) def guide(): - # scale = pyro.param("scale", lambda: torch.tensor([1.0])) with pyro.plate("data", len(data)): - rate = pyro.param("rate", lambda: torch.tensor([3.5, 1.5]), event_dim=0) - z_dist = dist.Poisson(rate, validate_args=False) - # if has_rsample is not None: - # z_dist.has_rsample_(has_rsample) - pyro.sample("z", z_dist) + rate = pyro.param("rate", lambda: torch.tensor([3.5, 1.5])) + pyro.sample("z", dist.Poisson(rate)) elbo = getattr(infer, Elbo)( max_plate_nesting=1, # set this to ensure rng agrees across runs @@ -91,3 +86,86 @@ def guide(): logger.info("actual {} = {}".format(name, actual_grads[name])) assert_equal(actual_grads, expected_grads, prec=1e-4) + + +@pytest.mark.parametrize( + "Elbo,backend", + [ + ("TraceEnum_ELBO", "pyro"), + ("Trace_ELBO", "contrib.funsor"), + ], +) +def test_particle_gradient_1(Elbo, backend): + with pyro_backend(backend): + pyro.clear_param_store() + data = torch.tensor([-0.5, 2.0]) + + def model(): + a = pyro.sample("a", dist.Bernoulli(0.3)) + with pyro.plate("data", len(data)): + rate = torch.tensor([2.0, 3.0]) + b = pyro.sample("b", dist.Poisson(rate[a.long()])) + pyro.sample("c", dist.Normal(b, 1), obs=data) + + def guide(): + prob = pyro.param( + "prob", + lambda: torch.tensor(0.5), + # constraint=constraints.unit_interval, + ) + a = pyro.sample("a", dist.Bernoulli(prob)) + with pyro.plate("data", len(data)): + rate = pyro.param( + "rate", lambda: torch.tensor([[3.5, 1.5], [0.5, 2.5]]) + ) + pyro.sample("b", dist.Poisson(rate[a.long()])) + + elbo = getattr(infer, Elbo)( + max_plate_nesting=1, # set this to ensure rng agrees across runs + num_particles=1, + strict_enumeration_warning=False, + ) + + # Elbo gradient estimator + pyro.set_rng_seed(0) + elbo.loss_and_grads(model, guide) + params = dict(pyro.get_param_store().named_parameters()) + actual_grads = { + name: param.grad.detach().cpu() for name, param in params.items() + } + + # capture sample values and log_probs + pyro.set_rng_seed(0) + guide_tr = handlers.trace(guide).get_trace() 
+ model_tr = handlers.trace(handlers.replay(model, guide_tr)).get_trace() + guide_tr.compute_log_prob() + model_tr.compute_log_prob() + a = guide_tr.nodes["a"]["value"].data + b = guide_tr.nodes["b"]["value"].data + prob = pyro.param("prob").data + rate = pyro.param("rate").data + + dlogqa_dprob = (a - prob) / (prob * (1 - prob)) + dlogqb_drate = b / rate[a.long()] - 1 + + loss_a = ( + model_tr.nodes["a"]["log_prob"].data + - guide_tr.nodes["a"]["log_prob"].data + ) + loss_bc = ( + model_tr.nodes["b"]["log_prob"].data + + model_tr.nodes["c"]["log_prob"].data + - guide_tr.nodes["b"]["log_prob"].data + ) + # dlogq_drate = z / rate - 1 + expected_grads = { + "prob": -(dlogqa_dprob * (loss_a + loss_bc.sum()) - dlogqa_dprob), + "rate": -(dlogqb_drate * (loss_bc) - dlogqb_drate), + } + actual_grads["rate"] = actual_grads["rate"][a.long()] + + for name in sorted(params): + logger.info("expected {} = {}".format(name, expected_grads[name])) + logger.info("actual {} = {}".format(name, actual_grads[name])) + + assert_equal(actual_grads, expected_grads, prec=1e-4) From 6d6a9ed5e80faeacb02eb075d6cda3821e29274f Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Thu, 29 Jul 2021 01:40:33 -0400 Subject: [PATCH 09/53] clean up trace_elbo --- .../contrib/funsor/handlers/enum_messenger.py | 11 +++-- .../funsor/handlers/trace_messenger.py | 3 +- pyro/contrib/funsor/infer/trace_elbo.py | 40 +++++++------------ pyro/contrib/funsor/infer/traceenum_elbo.py | 25 +++++------- tests/contrib/funsor/test_gradient.py | 4 +- 5 files changed, 33 insertions(+), 50 deletions(-) diff --git a/pyro/contrib/funsor/handlers/enum_messenger.py b/pyro/contrib/funsor/handlers/enum_messenger.py index c5de6459a0..6baa731638 100644 --- a/pyro/contrib/funsor/handlers/enum_messenger.py +++ b/pyro/contrib/funsor/handlers/enum_messenger.py @@ -11,7 +11,6 @@ import funsor import torch -from funsor.torch.provenance import ProvenanceTensor import pyro.poutine.runtime import pyro.poutine.util @@ -193,16 +192,16 @@ def _pyro_sample(self, msg): ) msg["funsor"]["unsampled_log_measure"] = unsampled_log_measure msg["funsor"]["log_measure"] = enumerate_site(unsampled_log_measure, msg) - msg["funsor"]["value"] = _get_support_value( + support_value = _get_support_value( msg["funsor"]["log_measure"], msg["name"], expand=msg["infer"].get("expand", False), ) - # msg["value"] = to_data(msg["funsor"]["value"]) - msg["value"] = ProvenanceTensor( - to_data(msg["funsor"]["value"]), - provenance=frozenset({(msg["name"], funsor.Real)}), + msg["funsor"]["value"] = funsor.constant.Constant( + OrderedDict(((msg["name"], funsor.Real),)), + support_value, ) + msg["value"] = to_data(msg["funsor"]["value"]) msg["done"] = True diff --git a/pyro/contrib/funsor/handlers/trace_messenger.py b/pyro/contrib/funsor/handlers/trace_messenger.py index 9d4fb08adf..f517967079 100644 --- a/pyro/contrib/funsor/handlers/trace_messenger.py +++ b/pyro/contrib/funsor/handlers/trace_messenger.py @@ -44,8 +44,7 @@ def _pyro_post_sample(self, msg): msg["funsor"]["fn"] = to_funsor(fn_masked, funsor.Real)( value=msg["name"] ) - # if "value" not in msg["funsor"]: - if True: + if "value" not in msg["funsor"]: # value_output = funsor.Reals[getattr(msg["fn"], "event_shape", ())] msg["funsor"]["value"] = to_funsor( msg["value"], msg["funsor"]["fn"].inputs[msg["name"]] diff --git a/pyro/contrib/funsor/infer/trace_elbo.py b/pyro/contrib/funsor/infer/trace_elbo.py index 247ef71b2c..c81e3aa21a 100644 --- a/pyro/contrib/funsor/infer/trace_elbo.py +++ b/pyro/contrib/funsor/infer/trace_elbo.py @@ 
From 0f23b4203bcd1ea3edc3fbeb0243bc2063c267da Mon Sep 17 00:00:00 2001
From: Yerdos Ordabayev
Date: Thu, 29 Jul 2021 01:40:33 -0400
Subject: [PATCH 09/53] clean up trace_elbo

---
 .../contrib/funsor/handlers/enum_messenger.py | 11 +++--
 .../funsor/handlers/trace_messenger.py        |  3 +-
 pyro/contrib/funsor/infer/trace_elbo.py       | 40 +++++++----------
 pyro/contrib/funsor/infer/traceenum_elbo.py   | 25 +++++------
 tests/contrib/funsor/test_gradient.py         |  4 +-
 5 files changed, 33 insertions(+), 50 deletions(-)

diff --git a/pyro/contrib/funsor/handlers/enum_messenger.py b/pyro/contrib/funsor/handlers/enum_messenger.py
index c5de6459a0..6baa731638 100644
--- a/pyro/contrib/funsor/handlers/enum_messenger.py
+++ b/pyro/contrib/funsor/handlers/enum_messenger.py
@@ -11,7 +11,6 @@
 
 import funsor
 import torch
-from funsor.torch.provenance import ProvenanceTensor
 
 import pyro.poutine.runtime
 import pyro.poutine.util
@@ -193,16 +192,16 @@ def _pyro_sample(self, msg):
        )
        msg["funsor"]["unsampled_log_measure"] = unsampled_log_measure
        msg["funsor"]["log_measure"] = enumerate_site(unsampled_log_measure, msg)
-        msg["funsor"]["value"] = _get_support_value(
+        support_value = _get_support_value(
            msg["funsor"]["log_measure"],
            msg["name"],
            expand=msg["infer"].get("expand", False),
        )
-        # msg["value"] = to_data(msg["funsor"]["value"])
-        msg["value"] = ProvenanceTensor(
-            to_data(msg["funsor"]["value"]),
-            provenance=frozenset({(msg["name"], funsor.Real)}),
+        msg["funsor"]["value"] = funsor.constant.Constant(
+            OrderedDict(((msg["name"], funsor.Real),)),
+            support_value,
        )
+        msg["value"] = to_data(msg["funsor"]["value"])
        msg["done"] = True

diff --git a/pyro/contrib/funsor/handlers/trace_messenger.py b/pyro/contrib/funsor/handlers/trace_messenger.py
index 9d4fb08adf..f517967079 100644
--- a/pyro/contrib/funsor/handlers/trace_messenger.py
+++ b/pyro/contrib/funsor/handlers/trace_messenger.py
@@ -44,8 +44,7 @@ def _pyro_post_sample(self, msg):
                 msg["funsor"]["fn"] = to_funsor(fn_masked, funsor.Real)(
                     value=msg["name"]
                 )
-            # if "value" not in msg["funsor"]:
-            if True:
+            if "value" not in msg["funsor"]:
                 # value_output = funsor.Reals[getattr(msg["fn"], "event_shape", ())]
                 msg["funsor"]["value"] = to_funsor(
                     msg["value"], msg["funsor"]["fn"].inputs[msg["name"]]

diff --git a/pyro/contrib/funsor/infer/trace_elbo.py b/pyro/contrib/funsor/infer/trace_elbo.py
index 247ef71b2c..c81e3aa21a 100644
--- a/pyro/contrib/funsor/infer/trace_elbo.py
+++ b/pyro/contrib/funsor/infer/trace_elbo.py
@@ -2,7 +2,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import contextlib
-from collections import OrderedDict
 
 import funsor
 import torch
@@ -35,23 +34,19 @@ def differentiable_loss(self, model, guide, *args, **kwargs):
        guide_terms = terms_from_trace(guide_tr)
 
        with funsor.terms.eager:
+            # cost terms
            costs = model_terms["log_factors"] + [
                -f for f in guide_terms["log_factors"]
            ]
 
-            # compute expected cost
-            # Cf. pyro.infer.util.Dice.compute_expectation()
-            # https://github.com/pyro-ppl/pyro/blob/0.3.0/pyro/infer/util.py#L212
-            # TODO Replace this with funsor.Expectation
            plate_vars = guide_terms["plate_vars"] | model_terms["plate_vars"]
-            # compute the marginal logq in the guide corresponding to each cost term
+            # compute log_measures corresponding to each cost term
+            # the goal is to achieve full Rao-Blackwellization
            targets = dict()
            for cost in costs:
-                input_vars = frozenset(cost.inputs)
-                if input_vars not in targets:
-                    const_inputs = tuple((v.name, v.output) for v in cost.input_vars)
-                    targets[input_vars] = Constant(
-                        const_inputs, funsor.Tensor(torch.tensor(0))
+                if cost.input_vars not in targets:
+                    targets[cost.input_vars] = Constant(
+                        cost.inputs, funsor.Tensor(torch.tensor(0))
                    )
            with AdjointTape() as tape:
                logzq = funsor.sum_product.sum_product(
@@ -61,29 +56,24 @@ def differentiable_loss(self, model, guide, *args, **kwargs):
                    plates=plate_vars,
                    eliminate=(plate_vars | guide_terms["measure_vars"]),
                )
-            marginals = tape.adjoint(
+            log_measures = tape.adjoint(
                funsor.ops.logaddexp, funsor.ops.add, logzq, tuple(targets.values())
            )
            # finally, integrate out guide variables in the elbo and all plates
+            # Dice factors are included in MonteCarlo samples
            elbo = to_funsor(0, output=funsor.Real)
            for cost in costs:
-                target = targets[frozenset(cost.inputs)]
-                marginal = marginals[target]
-                # logzq_local = marginals[target].reduce(
-                #     funsor.ops.logaddexp, frozenset(cost.inputs) - plate_vars
-                # )
-                # logzq_local = guide_terms["log_measures"][0].reduce(
-                #     funsor.ops.logaddexp, frozenset(cost.inputs) - plate_vars
-                # )
-                # log_prob = marginal.sample(frozenset(cost.inputs)-plate_vars)  # - logzq + logzq_local
-                # log_prob = guide_terms["log_measures"][0]
+                target = targets[cost.input_vars]
+                log_measure = log_measures[target]
                measure_vars = frozenset(cost.inputs) - plate_vars
-                _raw_value = {var: guide_tr.nodes[var]["value"]._t for var in measure_vars}
+                _raw_value = {
+                    var: guide_tr.nodes[var]["value"]._t for var in measure_vars
+                }
                with MonteCarlo(raw_value=_raw_value):
                    elbo_term = funsor.Integrate(
-                        marginal,
+                        log_measure,
                        cost,
-                        guide_terms["measure_vars"] & frozenset(marginal.inputs),
+                        measure_vars,
                    )
                elbo += elbo_term.reduce(
                    funsor.ops.add, plate_vars & frozenset(cost.inputs)
                )

diff --git a/pyro/contrib/funsor/infer/traceenum_elbo.py b/pyro/contrib/funsor/infer/traceenum_elbo.py
index acaa8e26e9..4ab958c4cf 100644
--- a/pyro/contrib/funsor/infer/traceenum_elbo.py
+++ b/pyro/contrib/funsor/infer/traceenum_elbo.py
@@ -4,9 +4,7 @@
 import contextlib
 
 import funsor
-import torch
 from funsor.adjoint import AdjointTape
-from funsor.constant import Constant
 
 from pyro.contrib.funsor import to_data, to_funsor
 from pyro.contrib.funsor.handlers import enum, plate, replay, trace
@@ -65,7 +63,9 @@ def terms_from_trace(tr):
         # grab the log-measure, found only at sites that are not replayed or observed
         if node["funsor"].get("log_measure", None) is not None:
             terms["log_measures"].append(node["funsor"]["log_measure"])
-            terms["unsampled_log_measures"].append(node["funsor"]["unsampled_log_measure"])
+            terms["unsampled_log_measures"].append(
+                node["funsor"]["unsampled_log_measure"]
+            )
         # sum (measure) variables: the fresh non-plate variables at a site
         terms["measure_vars"] |= (
             frozenset(node["funsor"]["value"].inputs) | {name}
@@ -190,7 +190,7 @@ def differentiable_loss(self, model, guide, *args, **kwargs):
         model_terms = terms_from_trace(model_tr)
 
         # build up a lazy expression for the elbo
-        with funsor.terms.eager:
+        with funsor.terms.lazy:
             # identify and contract out auxiliary variables in the model with partial_sum_product
             contracted_factors, uncontracted_factors = [], []
             for f in model_terms["log_factors"]:
@@ -224,17 +224,14 @@ def differentiable_loss(self, model, guide, *args, **kwargs):
             for cost in costs:
                 input_vars = frozenset(cost.inputs)
                 if input_vars not in targets:
-                    targets[input_vars] = Constant(
-                        cost.input_vars, funsor.Tensor(torch.tensor(0))
+                    targets[input_vars] = funsor.Tensor(
+                        funsor.ops.new_zeros(
+                            funsor.tensor.get_default_prototype(),
+                            tuple(v.size for v in cost.inputs.values()),
+                        ),
+                        cost.inputs,
+                        cost.dtype,
                     )
-                    #     targets[input_vars] = funsor.Tensor(
-                    #     funsor.ops.new_zeros(
-                    #         funsor.tensor.get_default_prototype(),
-                    #         tuple(v.size for v in cost.inputs.values()),
-                    #     ),
-                    #     cost.inputs,
-                    #     cost.dtype,
-                    # )
             with AdjointTape() as tape:
                 logzq = funsor.sum_product.sum_product(
                     funsor.ops.logaddexp,

diff --git a/tests/contrib/funsor/test_gradient.py b/tests/contrib/funsor/test_gradient.py
index 5ae6c16acb..dc54bac38d 100644
--- a/tests/contrib/funsor/test_gradient.py
+++ b/tests/contrib/funsor/test_gradient.py
@@ -5,7 +5,6 @@
 
 import pytest
 import torch
-from torch.distributions import constraints
 
 from tests.common import assert_equal
 
@@ -149,8 +148,7 @@ def guide():
         dlogqb_drate = b / rate[a.long()] - 1
 
         loss_a = (
-            model_tr.nodes["a"]["log_prob"].data
-            - guide_tr.nodes["a"]["log_prob"].data
+            model_tr.nodes["a"]["log_prob"].data - guide_tr.nodes["a"]["log_prob"].data
         )
         loss_bc = (
             model_tr.nodes["b"]["log_prob"].data

From 6d6a9ed5e80faeacb02eb075d6cda3821e29274f Mon Sep 17 00:00:00 2001
From: Yerdos Ordabayev
Date: Sun, 8 Aug 2021 19:41:00 -0400
Subject: [PATCH 10/53] add another test

---
 .../contrib/funsor/handlers/enum_messenger.py |   1 -
 pyro/contrib/funsor/infer/trace_elbo.py       |  60 +++---
 pyro/contrib/funsor/infer/traceenum_elbo.py   |   4 -
 tests/contrib/funsor/test_gradient.py         | 187 +++++++++++++++++-
 4 files changed, 209 insertions(+), 43 deletions(-)

diff --git a/pyro/contrib/funsor/handlers/enum_messenger.py b/pyro/contrib/funsor/handlers/enum_messenger.py
index 6baa731638..df6e02571b 100644
--- a/pyro/contrib/funsor/handlers/enum_messenger.py
+++ b/pyro/contrib/funsor/handlers/enum_messenger.py
@@ -190,7 +190,6 @@ def _pyro_sample(self, msg):
        unsampled_log_measure = to_funsor(msg["fn"], output=funsor.Real)(
            value=msg["name"]
        )
-        msg["funsor"]["unsampled_log_measure"] = unsampled_log_measure
        msg["funsor"]["log_measure"] = enumerate_site(unsampled_log_measure, msg)
        support_value = _get_support_value(
            msg["funsor"]["log_measure"],

diff --git a/pyro/contrib/funsor/infer/trace_elbo.py b/pyro/contrib/funsor/infer/trace_elbo.py
index c81e3aa21a..1226898fc1 100644
--- a/pyro/contrib/funsor/infer/trace_elbo.py
+++ b/pyro/contrib/funsor/infer/trace_elbo.py
@@ -7,7 +7,6 @@
 import torch
 from funsor.adjoint import AdjointTape
 from funsor.constant import Constant
-from funsor.montecarlo import MonteCarlo
 
 from pyro.contrib.funsor import to_data, to_funsor
 from pyro.contrib.funsor.handlers import enum, plate, replay, trace
@@ -33,48 +32,41 @@ def differentiable_loss(self, model, guide, *args, **kwargs):
        model_terms = terms_from_trace(model_tr)
        guide_terms = terms_from_trace(guide_tr)
 
-        with funsor.terms.eager:
-            # cost terms
-            costs = model_terms["log_factors"] + [
-                -f for f in guide_terms["log_factors"]
-            ]
+        # cost terms
+        costs = model_terms["log_factors"] + [-f for f in guide_terms["log_factors"]]
 
-            plate_vars = guide_terms["plate_vars"] | model_terms["plate_vars"]
-            # compute log_measures corresponding to each cost term
-            # the goal is to achieve full Rao-Blackwellization
-            targets = dict()
-            for cost in costs:
-                if cost.input_vars not in targets:
-                    targets[cost.input_vars] = Constant(
-                        cost.inputs, funsor.Tensor(torch.tensor(0))
+        plate_vars = guide_terms["plate_vars"] | model_terms["plate_vars"]
+        # compute log_measures corresponding to each cost term
+        # the goal is to achieve fine-grained Rao-Blackwellization
+        targets = dict()
+        for cost in costs:
+            if cost.input_vars not in targets:
+                targets[cost.input_vars] = Constant(
+                    cost.inputs, funsor.Tensor(torch.tensor(0.0))
                )
-            with AdjointTape() as tape:
-                logzq = funsor.sum_product.sum_product(
-                    funsor.ops.logaddexp,
-                    funsor.ops.add,
-                    guide_terms["unsampled_log_measures"] + list(targets.values()),
-                    plates=plate_vars,
-                    eliminate=(plate_vars | guide_terms["measure_vars"]),
+        with AdjointTape() as tape:
+            logzq = funsor.sum_product.sum_product(
+                funsor.ops.logaddexp,
+                funsor.ops.add,
+                guide_terms["log_measures"] + list(targets.values()),
+                plates=plate_vars,
+                eliminate=(plate_vars | guide_terms["measure_vars"]),
            )
-            log_measures = tape.adjoint(
-                funsor.ops.logaddexp, funsor.ops.add, logzq, tuple(targets.values())
-            )
+        log_measures = tape.adjoint(
+            funsor.ops.logaddexp, funsor.ops.add, logzq, tuple(targets.values())
+        )
+        with funsor.terms.eager:
            # finally, integrate out guide variables in the elbo and all plates
-            # Dice factors are included in MonteCarlo samples
            elbo = to_funsor(0, output=funsor.Real)
            for cost in costs:
                target = targets[cost.input_vars]
                log_measure = log_measures[target]
                measure_vars = frozenset(cost.inputs) - plate_vars
-                _raw_value = {
-                    var: guide_tr.nodes[var]["value"]._t for var in measure_vars
-                }
-                with MonteCarlo(raw_value=_raw_value):
-                    elbo_term = funsor.Integrate(
-                        log_measure,
-                        cost,
-                        measure_vars,
-                    )
+                elbo_term = funsor.Integrate(
+                    log_measure,
+                    cost,
+                    measure_vars,
+                )
                elbo += elbo_term.reduce(
                    funsor.ops.add, plate_vars & frozenset(cost.inputs)
                )

diff --git a/pyro/contrib/funsor/infer/traceenum_elbo.py b/pyro/contrib/funsor/infer/traceenum_elbo.py
index 4ab958c4cf..7915f7e101 100644
--- a/pyro/contrib/funsor/infer/traceenum_elbo.py
+++ b/pyro/contrib/funsor/infer/traceenum_elbo.py
@@ -31,7 +31,6 @@ def terms_from_trace(tr):
     terms = {
         "log_factors": [],
         "log_measures": [],
-        "unsampled_log_measures": [],
         "scale": to_funsor(1.0),
         "plate_vars": frozenset(),
         "measure_vars": frozenset(),
@@ -63,9 +62,6 @@ def terms_from_trace(tr):
         # grab the log-measure, found only at sites that are not replayed or observed
         if node["funsor"].get("log_measure", None) is not None:
             terms["log_measures"].append(node["funsor"]["log_measure"])
-            terms["unsampled_log_measures"].append(
-                node["funsor"]["unsampled_log_measure"]
-            )
         # sum (measure) variables: the fresh non-plate variables at a site
         terms["measure_vars"] |= (
             frozenset(node["funsor"]["value"].inputs) | {name}
diff --git a/tests/contrib/funsor/test_gradient.py b/tests/contrib/funsor/test_gradient.py
index dc54bac38d..13949826e5 100644
--- a/tests/contrib/funsor/test_gradient.py
+++ b/tests/contrib/funsor/test_gradient.py
@@ -6,6 +6,7 @@
 import pytest
 import torch
 
+from pyro.ops.indexing import Vindex
 from tests.common import assert_equal
 
 # put all funsor-related imports here, so test collection works without funsor
@@ -22,8 +23,6 @@
 
 logger = logging.getLogger(__name__)
 
-# _PYRO_BACKEND = os.environ.get("TEST_ENUM_PYRO_BACKEND", "contrib.funsor")
-
 
 @pytest.mark.parametrize(
     "Elbo,backend",
@@ -110,7 +109,6 @@ def guide():
             prob = pyro.param(
                 "prob",
                 lambda: torch.tensor(0.5),
-                # constraint=constraints.unit_interval,
             )
             a = pyro.sample("a", dist.Bernoulli(prob))
             with pyro.plate("data", len(data)):
@@ -155,7 +153,6 @@ def guide():
             + model_tr.nodes["c"]["log_prob"].data
             - guide_tr.nodes["b"]["log_prob"].data
         )
-        # dlogq_drate = z / rate - 1
         expected_grads = {
             "prob": -(dlogqa_dprob * (loss_a + loss_bc.sum()) - dlogqa_dprob),
             "rate": -(dlogqb_drate * (loss_bc) - dlogqb_drate),
@@ -167,3 +164,185 @@ def guide():
             logger.info("actual {} = {}".format(name, actual_grads[name]))
 
         assert_equal(actual_grads, expected_grads, prec=1e-4)
+
+
+@pytest.mark.parametrize(
+    "Elbo,backend",
+    [
+        ("TraceEnum_ELBO", "pyro"),
+        ("Trace_ELBO", "contrib.funsor"),
+        ("TraceGraph_ELBO", "pyro"),
+    ],
+)
+def test_particle_gradient_2(Elbo, backend):
+    with pyro_backend(backend):
+        pyro.clear_param_store()
+        data = torch.tensor([0.0, 1.0])
+
+        def model():
+            prob_b = torch.tensor([0.3, 0.4])
+            prob_c = torch.tensor([0.5, 0.6])
+            prob_d = torch.tensor([0.2, 0.3])
+            prob_e = torch.tensor([0.5, 0.1])
+            a = pyro.sample("a", dist.Bernoulli(0.3))
+            with pyro.plate("data", len(data)):
+                b = pyro.sample("b", dist.Bernoulli(prob_b[a.long()]))
+                c = pyro.sample("c", dist.Bernoulli(prob_c[b.long()]))
+                pyro.sample("d", dist.Bernoulli(prob_d[b.long()]))
+                pyro.sample("e", dist.Bernoulli(prob_e[c.long()]), obs=data)
+
+        def guide():
+            prob_a = pyro.param("prob_a", lambda: torch.tensor(0.5))
+            prob_b = pyro.param("prob_b", lambda: torch.tensor([0.4, 0.3]))
+            prob_c = pyro.param(
+                "prob_c", lambda: torch.tensor([[0.3, 0.8], [0.2, 0.5]])
+            )
+            prob_d = pyro.param(
+                "prob_d", lambda: torch.tensor([[0.2, 0.9], [0.1, 0.4]])
+            )
+            a = pyro.sample("a", dist.Bernoulli(prob_a))
+            with pyro.plate("data", len(data)) as idx:
+                b = pyro.sample("b", dist.Bernoulli(prob_b[a.long()]))
+                pyro.sample("c", dist.Bernoulli(Vindex(prob_c)[b.long(), idx]))
+                pyro.sample("d", dist.Bernoulli(Vindex(prob_d)[b.long(), idx]))
+
+        elbo = getattr(infer, Elbo)(
+            max_plate_nesting=1,  # set this to ensure rng agrees across runs
+            num_particles=1,
+            strict_enumeration_warning=False,
+        )
+
+        # Elbo gradient estimator
+        pyro.set_rng_seed(0)
+        elbo.loss_and_grads(model, guide)
+        params = dict(pyro.get_param_store().named_parameters())
+        actual_grads = {
+            name: param.grad.detach().cpu() for name, param in params.items()
+        }
+
+        # capture sample values and log_probs
+        pyro.set_rng_seed(0)
+        guide_tr = handlers.trace(guide).get_trace()
+        model_tr = handlers.trace(handlers.replay(model, guide_tr)).get_trace()
+        guide_tr.compute_log_prob()
+        model_tr.compute_log_prob()
+
+        a = guide_tr.nodes["a"]["value"].data
+        b = guide_tr.nodes["b"]["value"].data
+        c = guide_tr.nodes["c"]["value"].data
+
+        proba = pyro.param("prob_a").data
+        probb = pyro.param("prob_b").data
+        probc = pyro.param("prob_c").data
+        probd = pyro.param("prob_d").data
+
+        logpa = model_tr.nodes["a"]["log_prob"].data
+        logpba = model_tr.nodes["b"]["log_prob"].data
+        logpcba = model_tr.nodes["c"]["log_prob"].data
+        logpdba = model_tr.nodes["d"]["log_prob"].data
+        logpecba = model_tr.nodes["e"]["log_prob"].data
+
+        logqa = guide_tr.nodes["a"]["log_prob"].data
+        logqba = guide_tr.nodes["b"]["log_prob"].data
+        logqcba = guide_tr.nodes["c"]["log_prob"].data
+        logqdba = guide_tr.nodes["d"]["log_prob"].data
+
+        idx = torch.arange(2)
+        dlogqa_dproba = (a - proba) / (proba * (1 - proba))
+        dlogqb_dprobb = (b - probb[a.long()]) / (
+            probb[a.long()] * (1 - probb[a.long()])
+        )
+        dlogqc_dprobc = (c - Vindex(probc)[b.long(), idx]) / (
+            Vindex(probc)[b.long(), idx] * (1 - Vindex(probc)[b.long(), idx])
+        )
+        dlogqd_dprobd = (c - probd[b.long(), idx]) / (
+            Vindex(probd)[b.long(), idx] * (1 - Vindex(probd)[b.long(), idx])
+        )
+
+        if Elbo == "Trace_ELBO":
+            # fine-grained Rao-Blackwellization based on provenance tracking
+            expected_grads = {
+                "prob_a": -dlogqa_dproba
+                * (
+                    logpa
+                    + (logpba + logpcba + logpdba + logpecba).sum()
+                    - logqa
+                    - (logqba + logqcba + logqdba).sum()
+                    - 1
+                ),
+                "prob_b": (
+                    -dlogqb_dprobb
+                    * (
+                        (logpba + logpcba + logpdba + logpecba)
+                        - (logqba + logqcba + logqdba)
+                        - 1
+                    )
+                ).sum(),
+                "prob_c": -dlogqc_dprobc * (logpcba + logpecba - logqcba - 1),
+                "prob_d": -dlogqd_dprobd * (logpdba - logqdba - 1),
+            }
+        elif Elbo == "TraceEnum_ELBO":
+            # only uses plate conditional independence for Rao-Blackwellization
+            expected_grads = {
+                "prob_a": -dlogqa_dproba
+                * (
+                    logpa
+                    + (logpba + logpcba + logpdba + logpecba).sum()
+                    - logqa
+                    - (logqba + logqcba + logqdba).sum()
+                    - 1
+                ),
+                "prob_b": (
+                    -dlogqb_dprobb
+                    * (
+                        (logpba + logpcba + logpdba + logpecba)
+                        - (logqba + logqcba + logqdba)
+                        - 1
+                    )
+                ).sum(),
+                "prob_c": -dlogqc_dprobc
+                * (
+                    (logpba + logpcba + logpdba + logpecba)
+                    - (logqba + logqcba + logqdba)
+                    - 1
+                ),
+                "prob_d": -dlogqd_dprobd
+                * (
+                    (logpba + logpcba + logpdba + logpecba)
+                    - (logqba + logqcba + logqdba)
+                    - 1
+                ),
+            }
+        elif Elbo == "TraceGraph_ELBO":
+            # Rao-Blackwellization uses conditional independence based on
+            # 1) the sequential order of samples
+            # 2) plate generators
+            # additionally removes dlogq_dparam terms
+            expected_grads = {
+                "prob_a": -dlogqa_dproba
+                * (
+                    logpa
+                    + (logpba + logpcba + logpdba + logpecba).sum()
+                    - logqa
+                    - (logqba + logqcba + logqdba).sum()
+                ),
+                "prob_b": (
+                    -dlogqb_dprobb
+                    * (
+                        (logpba + logpcba + logpdba + logpecba)
+                        - (logqba + logqcba + logqdba)
+                    )
+                ).sum(),
+                "prob_c": -dlogqc_dprobc
+                * ((logpcba + logpdba + logpecba) - (logqcba + logqdba)),
+                "prob_d": -dlogqd_dprobd * (logpdba + logpecba - logqdba),
+            }
+        actual_grads["prob_b"] = actual_grads["prob_b"][a.long()]
+        actual_grads["prob_c"] = Vindex(actual_grads["prob_c"])[b.long(), idx]
+        actual_grads["prob_d"] = Vindex(actual_grads["prob_d"])[b.long(), idx]
+
+        for name in sorted(params):
+            logger.info("expected {} = {}".format(name, expected_grads[name]))
+            logger.info("actual {} = {}".format(name, actual_grads[name]))
+
+        assert_equal(actual_grads, expected_grads, prec=1e-4)

From 91384edfaf3c8845f7f074d658018b2e89ac36d4 Mon Sep 17 00:00:00 2001
From: Yerdos Ordabayev
Date: Fri, 20 Aug 2021 10:19:50 -0400
Subject: [PATCH 11/53] lazy eval

---
 pyro/contrib/funsor/infer/trace_elbo.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyro/contrib/funsor/infer/trace_elbo.py b/pyro/contrib/funsor/infer/trace_elbo.py
index 1226898fc1..d40a9a2849 100644
--- a/pyro/contrib/funsor/infer/trace_elbo.py
+++ b/pyro/contrib/funsor/infer/trace_elbo.py
@@ -55,7 +55,7 @@ def differentiable_loss(self, model, guide, *args, **kwargs):
         log_measures = tape.adjoint(
             funsor.ops.logaddexp, funsor.ops.add, logzq, tuple(targets.values())
         )
-        with funsor.terms.eager:
+        with funsor.terms.lazy:
             # finally, integrate out guide variables in the elbo and all plates
             elbo = to_funsor(0, output=funsor.Real)
             for cost in costs:
From b0182c0fc6963ab2d0ff84ed374566e992f35647 Mon Sep 17 00:00:00 2001
From: Yerdos Ordabayev
Date: Thu, 30 Sep 2021 10:14:54 -0400
Subject: [PATCH 12/53] vectorize particles; update tests

---
 pyro/contrib/funsor/handlers/__init__.py     |   3 +-
 .../contrib/funsor/handlers/enum_messenger.py |  39 +-
 pyro/contrib/funsor/infer/trace_elbo.py      |  26 +-
 tests/contrib/funsor/test_gradient.py        | 359 ++++--------------
 4 files changed, 119 insertions(+), 308 deletions(-)

diff --git a/pyro/contrib/funsor/handlers/__init__.py b/pyro/contrib/funsor/handlers/__init__.py
index 80d01740f9..d3ac3bf471 100644
--- a/pyro/contrib/funsor/handlers/__init__.py
+++ b/pyro/contrib/funsor/handlers/__init__.py
@@ -15,7 +15,7 @@
 )
 from pyro.poutine.handlers import _make_handler
 
-from .enum_messenger import EnumMessenger, queue  # noqa: F401
+from .enum_messenger import EnumMessenger, ProvenanceMessenger, queue  # noqa: F401
 from .named_messenger import MarkovMessenger, NamedMessenger
 from .plate_messenger import PlateMessenger, VectorizedMarkovMessenger
 from .replay_messenger import ReplayMessenger
@@ -26,6 +26,7 @@
     MarkovMessenger,
     NamedMessenger,
     PlateMessenger,
+    ProvenanceMessenger,
     ReplayMessenger,
     TraceMessenger,
     VectorizedMarkovMessenger,

diff --git a/pyro/contrib/funsor/handlers/enum_messenger.py b/pyro/contrib/funsor/handlers/enum_messenger.py
index 5851bb823e..85c24679df 100644
--- a/pyro/contrib/funsor/handlers/enum_messenger.py
+++ b/pyro/contrib/funsor/handlers/enum_messenger.py
@@ -19,6 +19,7 @@
 from pyro.contrib.funsor.handlers.replay_messenger import ReplayMessenger
 from pyro.contrib.funsor.handlers.trace_messenger import TraceMessenger
 from pyro.poutine.escape_messenger import EscapeMessenger
+from pyro.poutine.reentrant_messenger import ReentrantMessenger
 from pyro.poutine.subsample_messenger import _Subsample
 
 funsor.set_backend("torch")
@@ -179,6 +180,38 @@ def enumerate_site(dist, msg):
     raise ValueError("{} not valid enum strategy".format(msg))
 
 
+class ProvenanceMessenger(ReentrantMessenger):
+    def _pyro_sample(self, msg):
+        if (
+            msg["done"]
+            or msg["is_observed"]
+            or msg["infer"].get("enumerate") == "parallel"
+            or isinstance(msg["fn"], _Subsample)
+        ):
+            return
+
+        if "funsor" not in msg:
+            msg["funsor"] = {}
+
+        unsampled_log_measure = to_funsor(msg["fn"], output=funsor.Real)(
+            value=msg["name"]
+        )
+        msg["funsor"]["log_measure"] = _enum_strategy_default(
+            unsampled_log_measure, msg
+        )
+        support_value = _get_support_value(
+            msg["funsor"]["log_measure"],
+            msg["name"],
+            expand=msg["infer"].get("expand", False),
+        )
+        msg["funsor"]["value"] = funsor.constant.Constant(
+            OrderedDict(((msg["name"], support_value.output),)),
+            support_value,
+        )
+        msg["value"] = to_data(msg["funsor"]["value"])
+        msg["done"] = True
+
+
 class EnumMessenger(NamedMessenger):
     """
     This version of :class:`~EnumMessenger` uses :func:`~pyro.contrib.funsor.to_data`
@@ -201,15 +234,11 @@ def _pyro_sample(self, msg):
            value=msg["name"]
        )
        msg["funsor"]["log_measure"] = enumerate_site(unsampled_log_measure, msg)
-        support_value = _get_support_value(
+        msg["funsor"]["value"] = _get_support_value(
            msg["funsor"]["log_measure"],
            msg["name"],
            expand=msg["infer"].get("expand", False),
        )
-        msg["funsor"]["value"] = funsor.constant.Constant(
-            OrderedDict(((msg["name"], funsor.Real),)),
-            support_value,
-        )
        msg["value"] = to_data(msg["funsor"]["value"])
        msg["done"] = True

diff --git a/pyro/contrib/funsor/infer/trace_elbo.py b/pyro/contrib/funsor/infer/trace_elbo.py
index d40a9a2849..e9fa47ab50 100644
--- a/pyro/contrib/funsor/infer/trace_elbo.py
+++ b/pyro/contrib/funsor/infer/trace_elbo.py
@@ -9,8 +9,7 @@
 from funsor.constant import Constant
 
 from pyro.contrib.funsor import to_data, to_funsor
-from pyro.contrib.funsor.handlers import enum, plate, replay, trace
-from pyro.contrib.funsor.infer import config_enumerate
+from pyro.contrib.funsor.handlers import plate, provenance, replay, trace
 from pyro.distributions.util import copy_docs_from
 from pyro.infer import Trace_ELBO as _OrigTrace_ELBO
 
@@ -21,12 +20,10 @@
 @copy_docs_from(_OrigTrace_ELBO)
 class Trace_ELBO(ELBO):
     def differentiable_loss(self, model, guide, *args, **kwargs):
-        with enum(), plate(
-            size=self.num_particles
+        with provenance(), plate(
+            name="_PARTICLES", size=self.num_particles, dim=-self.max_plate_nesting
         ) if self.num_particles > 1 else contextlib.ExitStack():
-            guide_tr = trace(
-                config_enumerate(default="flat", num_samples=self.num_particles)(guide)
-            ).get_trace(*args, **kwargs)
+            guide_tr = trace(guide).get_trace(*args, **kwargs)
             model_tr = trace(replay(model, trace=guide_tr)).get_trace(*args, **kwargs)
 
         model_terms = terms_from_trace(model_tr)
@@ -35,7 +32,9 @@ def differentiable_loss(self, model, guide, *args, **kwargs):
        # cost terms
        costs = model_terms["log_factors"] + [-f for f in guide_terms["log_factors"]]
 
-        plate_vars = guide_terms["plate_vars"] | model_terms["plate_vars"]
+        plate_vars = (
+            guide_terms["plate_vars"] | model_terms["plate_vars"]
+        ) - frozenset({"_PARTICLES"})
        # compute log_measures corresponding to each cost term
        # the goal is to achieve fine-grained Rao-Blackwellization
        targets = dict()
@@ -61,14 +60,19 @@ def differentiable_loss(self, model, guide, *args, **kwargs):
            for cost in costs:
                target = targets[cost.input_vars]
                log_measure = log_measures[target]
-                measure_vars = frozenset(cost.inputs) - plate_vars
+                measure_vars = (frozenset(cost.inputs) - plate_vars) - frozenset(
+                    {"_PARTICLES"}
+                )
                elbo_term = funsor.Integrate(
                    log_measure,
                    cost,
                    measure_vars,
                )
-                elbo += elbo_term.reduce(
-                    funsor.ops.add, plate_vars & frozenset(cost.inputs)
+                elbo += (
+                    elbo_term.reduce(
+                        funsor.ops.add, plate_vars & frozenset(cost.inputs)
+                    ).reduce(funsor.ops.add)
+                    / self.num_particles
                )
 
        # evaluate the elbo, using memoize to share tensor computation where possible

diff --git a/tests/contrib/funsor/test_gradient.py b/tests/contrib/funsor/test_gradient.py
index 13949826e5..3a11b8158f 100644
--- a/tests/contrib/funsor/test_gradient.py
+++ b/tests/contrib/funsor/test_gradient.py
@@ -17,332 +17,109 @@
     funsor.set_backend("torch")
     from pyroapi import distributions as dist
-    from pyroapi import handlers, infer, pyro, pyro_backend
+    from pyroapi import infer, pyro, pyro_backend
 except ImportError:
     pytestmark = pytest.mark.skip(reason="funsor is not installed")
 
 logger = logging.getLogger(__name__)
 
 
def model_0(data):
    with pyro.plate("data", len(data)):
        z = pyro.sample("z", dist.Bernoulli(0.3))
        pyro.sample("x", dist.Normal(z, 1), obs=data)


def guide_0(data):
    with pyro.plate("data", len(data)):
        probs = pyro.param("probs", lambda: torch.tensor([0.4, 0.5]))
        pyro.sample("z", dist.Bernoulli(probs))


def model_1(data):
    a = pyro.sample("a", dist.Bernoulli(0.3))
    with pyro.plate("data", len(data)):
        probs_b = torch.tensor([0.1, 0.2])
        b = pyro.sample("b", dist.Bernoulli(probs_b[a.long()]))
        pyro.sample("c", dist.Normal(b, 1), obs=data)


def guide_1(data):
    probs_a = pyro.param(
        "probs_a",
        lambda: torch.tensor(0.5),
    )
    a = pyro.sample("a", dist.Bernoulli(probs_a))
    with pyro.plate("data", len(data)):
        probs_b = pyro.param("probs_b", lambda: torch.tensor([[0.5, 0.6], [0.4, 0.35]]))
        pyro.sample("b", dist.Bernoulli(Vindex(probs_b)[a.long(), torch.arange(2)]))


def model_2(data):
    prob_b = torch.tensor([0.3, 0.4])
    prob_c = torch.tensor([0.5, 0.6])
    prob_d = torch.tensor([0.2, 0.3])
    prob_e = torch.tensor([0.5, 0.1])
    a = pyro.sample("a", dist.Bernoulli(0.3))
    with pyro.plate("data", len(data)):
        b = pyro.sample("b", dist.Bernoulli(prob_b[a.long()]))
        c = pyro.sample("c", dist.Bernoulli(prob_c[b.long()]))
        pyro.sample("d", dist.Bernoulli(prob_d[b.long()]))
        pyro.sample("e", dist.Bernoulli(prob_e[c.long()]), obs=data)


def guide_2(data):
    prob_a = pyro.param("prob_a", lambda: torch.tensor(0.5))
    prob_b = pyro.param("prob_b", lambda: torch.tensor([0.4, 0.3]))
    prob_c = pyro.param("prob_c", lambda: torch.tensor([[0.3, 0.8], [0.2, 0.5]]))
    prob_d = pyro.param("prob_d", lambda: torch.tensor([[0.2, 0.9], [0.1, 0.4]]))
    a = pyro.sample("a", dist.Bernoulli(prob_a))
    with pyro.plate("data", len(data)) as idx:
        b = pyro.sample("b", dist.Bernoulli(prob_b[a.long()]))
        pyro.sample("c", dist.Bernoulli(Vindex(prob_c)[b.long(), idx]))
        pyro.sample("d", dist.Bernoulli(Vindex(prob_d)[b.long(), idx]))


@pytest.mark.parametrize(
    "model,guide,data",
    [
        (model_0, guide_0, torch.tensor([-0.5, 2.0])),
        (model_1, guide_1, torch.tensor([-0.5, 2.0])),
        (model_2, guide_2, torch.tensor([0.0, 1.0])),
    ],
)
def test_particle_gradient(Elbo, backend):
    with pyro_backend(backend):
        pyro.clear_param_store()
        data = 
torch.tensor([-0.5, 2.0]) - - def model(): - a = pyro.sample("a", dist.Bernoulli(0.3)) - with pyro.plate("data", len(data)): - rate = torch.tensor([2.0, 3.0]) - b = pyro.sample("b", dist.Poisson(rate[a.long()])) - pyro.sample("c", dist.Normal(b, 1), obs=data) - - def guide(): - prob = pyro.param( - "prob", - lambda: torch.tensor(0.5), - ) - a = pyro.sample("a", dist.Bernoulli(prob)) - with pyro.plate("data", len(data)): - rate = pyro.param( - "rate", lambda: torch.tensor([[3.5, 1.5], [0.5, 2.5]]) - ) - pyro.sample("b", dist.Poisson(rate[a.long()])) +def test_gradient(model, guide, data): - elbo = getattr(infer, Elbo)( + # Expected grads based on exact integration + with pyro_backend("pyro"): + pyro.clear_param_store() + elbo = infer.TraceEnum_ELBO( max_plate_nesting=1, # set this to ensure rng agrees across runs - num_particles=1, strict_enumeration_warning=False, ) - - # Elbo gradient estimator - pyro.set_rng_seed(0) - elbo.loss_and_grads(model, guide) + elbo.loss_and_grads(model, infer.config_enumerate(guide), data) params = dict(pyro.get_param_store().named_parameters()) - actual_grads = { - name: param.grad.detach().cpu() for name, param in params.items() - } - - # capture sample values and log_probs - pyro.set_rng_seed(0) - guide_tr = handlers.trace(guide).get_trace() - model_tr = handlers.trace(handlers.replay(model, guide_tr)).get_trace() - guide_tr.compute_log_prob() - model_tr.compute_log_prob() - a = guide_tr.nodes["a"]["value"].data - b = guide_tr.nodes["b"]["value"].data - prob = pyro.param("prob").data - rate = pyro.param("rate").data - - dlogqa_dprob = (a - prob) / (prob * (1 - prob)) - dlogqb_drate = b / rate[a.long()] - 1 - - loss_a = ( - model_tr.nodes["a"]["log_prob"].data - guide_tr.nodes["a"]["log_prob"].data - ) - loss_bc = ( - model_tr.nodes["b"]["log_prob"].data - + model_tr.nodes["c"]["log_prob"].data - - guide_tr.nodes["b"]["log_prob"].data - ) expected_grads = { - "prob": -(dlogqa_dprob * (loss_a + loss_bc.sum()) - dlogqa_dprob), - "rate": -(dlogqb_drate * (loss_bc) - dlogqb_drate), + name: param.grad.detach().cpu() for name, param in params.items() } - actual_grads["rate"] = actual_grads["rate"][a.long()] - - for name in sorted(params): - logger.info("expected {} = {}".format(name, expected_grads[name])) - logger.info("actual {} = {}".format(name, actual_grads[name])) - - assert_equal(actual_grads, expected_grads, prec=1e-4) - -@pytest.mark.parametrize( - "Elbo,backend", - [ - ("TraceEnum_ELBO", "pyro"), - ("Trace_ELBO", "contrib.funsor"), - ("TraceGraph_ELBO", "pyro"), - ], -) -def test_particle_gradient_2(Elbo, backend): - with pyro_backend(backend): + # Actual grads averaged over num_particles + with pyro_backend("contrib.funsor"): pyro.clear_param_store() - data = torch.tensor([0.0, 1.0]) - - def model(): - prob_b = torch.tensor([0.3, 0.4]) - prob_c = torch.tensor([0.5, 0.6]) - prob_d = torch.tensor([0.2, 0.3]) - prob_e = torch.tensor([0.5, 0.1]) - a = pyro.sample("a", dist.Bernoulli(0.3)) - with pyro.plate("data", len(data)): - b = pyro.sample("b", dist.Bernoulli(prob_b[a.long()])) - c = pyro.sample("c", dist.Bernoulli(prob_c[b.long()])) - pyro.sample("d", dist.Bernoulli(prob_d[b.long()])) - pyro.sample("e", dist.Bernoulli(prob_e[c.long()]), obs=data) - - def guide(): - prob_a = pyro.param("prob_a", lambda: torch.tensor(0.5)) - prob_b = pyro.param("prob_b", lambda: torch.tensor([0.4, 0.3])) - prob_c = pyro.param( - "prob_c", lambda: torch.tensor([[0.3, 0.8], [0.2, 0.5]]) - ) - prob_d = pyro.param( - "prob_d", lambda: torch.tensor([[0.2, 0.9], [0.1, 0.4]]) - 
) - a = pyro.sample("a", dist.Bernoulli(prob_a)) - with pyro.plate("data", len(data)) as idx: - b = pyro.sample("b", dist.Bernoulli(prob_b[a.long()])) - pyro.sample("c", dist.Bernoulli(Vindex(prob_c)[b.long(), idx])) - pyro.sample("d", dist.Bernoulli(Vindex(prob_d)[b.long(), idx])) - - elbo = getattr(infer, Elbo)( + elbo = infer.Trace_ELBO( max_plate_nesting=1, # set this to ensure rng agrees across runs - num_particles=1, + num_particles=50000, + vectorize_particles=True, strict_enumeration_warning=False, ) - - # Elbo gradient estimator - pyro.set_rng_seed(0) - elbo.loss_and_grads(model, guide) + elbo.loss_and_grads(model, guide, data) params = dict(pyro.get_param_store().named_parameters()) actual_grads = { name: param.grad.detach().cpu() for name, param in params.items() } - # capture sample values and log_probs - pyro.set_rng_seed(0) - guide_tr = handlers.trace(guide).get_trace() - model_tr = handlers.trace(handlers.replay(model, guide_tr)).get_trace() - guide_tr.compute_log_prob() - model_tr.compute_log_prob() - - a = guide_tr.nodes["a"]["value"].data - b = guide_tr.nodes["b"]["value"].data - c = guide_tr.nodes["c"]["value"].data - - proba = pyro.param("prob_a").data - probb = pyro.param("prob_b").data - probc = pyro.param("prob_c").data - probd = pyro.param("prob_d").data - - logpa = model_tr.nodes["a"]["log_prob"].data - logpba = model_tr.nodes["b"]["log_prob"].data - logpcba = model_tr.nodes["c"]["log_prob"].data - logpdba = model_tr.nodes["d"]["log_prob"].data - logpecba = model_tr.nodes["e"]["log_prob"].data - - logqa = guide_tr.nodes["a"]["log_prob"].data - logqba = guide_tr.nodes["b"]["log_prob"].data - logqcba = guide_tr.nodes["c"]["log_prob"].data - logqdba = guide_tr.nodes["d"]["log_prob"].data - - idx = torch.arange(2) - dlogqa_dproba = (a - proba) / (proba * (1 - proba)) - dlogqb_dprobb = (b - probb[a.long()]) / ( - probb[a.long()] * (1 - probb[a.long()]) - ) - dlogqc_dprobc = (c - Vindex(probc)[b.long(), idx]) / ( - Vindex(probc)[b.long(), idx] * (1 - Vindex(probc)[b.long(), idx]) - ) - dlogqd_dprobd = (c - probd[b.long(), idx]) / ( - Vindex(probd)[b.long(), idx] * (1 - Vindex(probd)[b.long(), idx]) - ) - - if Elbo == "Trace_ELBO": - # fine-grained Rao-Blackwellization based on provenance tracking - expected_grads = { - "prob_a": -dlogqa_dproba - * ( - logpa - + (logpba + logpcba + logpdba + logpecba).sum() - - logqa - - (logqba + logqcba + logqdba).sum() - - 1 - ), - "prob_b": ( - -dlogqb_dprobb - * ( - (logpba + logpcba + logpdba + logpecba) - - (logqba + logqcba + logqdba) - - 1 - ) - ).sum(), - "prob_c": -dlogqc_dprobc * (logpcba + logpecba - logqcba - 1), - "prob_d": -dlogqd_dprobd * (logpdba - logqdba - 1), - } - elif Elbo == "TraceEnum_ELBO": - # only uses plate conditional independence for Rao-Blackwellization - expected_grads = { - "prob_a": -dlogqa_dproba - * ( - logpa - + (logpba + logpcba + logpdba + logpecba).sum() - - logqa - - (logqba + logqcba + logqdba).sum() - - 1 - ), - "prob_b": ( - -dlogqb_dprobb - * ( - (logpba + logpcba + logpdba + logpecba) - - (logqba + logqcba + logqdba) - - 1 - ) - ).sum(), - "prob_c": -dlogqc_dprobc - * ( - (logpba + logpcba + logpdba + logpecba) - - (logqba + logqcba + logqdba) - - 1 - ), - "prob_d": -dlogqd_dprobd - * ( - (logpba + logpcba + logpdba + logpecba) - - (logqba + logqcba + logqdba) - - 1 - ), - } - elif Elbo == "TraceGraph_ELBO": - # Raw-Blackwellization uses conditional independence based on - # 1) the sequential order of samples - # 2) plate generators - # additionally removes dlogq_dparam terms - expected_grads 
= { - "prob_a": -dlogqa_dproba - * ( - logpa - + (logpba + logpcba + logpdba + logpecba).sum() - - logqa - - (logqba + logqcba + logqdba).sum() - ), - "prob_b": ( - -dlogqb_dprobb - * ( - (logpba + logpcba + logpdba + logpecba) - - (logqba + logqcba + logqdba) - ) - ).sum(), - "prob_c": -dlogqc_dprobc - * ((logpcba + logpdba + logpecba) - (logqcba + logqdba)), - "prob_d": -dlogqd_dprobd * (logpdba + logpecba - logqdba), - } - actual_grads["prob_b"] = actual_grads["prob_b"][a.long()] - actual_grads["prob_c"] = Vindex(actual_grads["prob_c"])[b.long(), idx] - actual_grads["prob_d"] = Vindex(actual_grads["prob_d"])[b.long(), idx] - - for name in sorted(params): - logger.info("expected {} = {}".format(name, expected_grads[name])) - logger.info("actual {} = {}".format(name, actual_grads[name])) + for name in sorted(params): + logger.info("expected {} = {}".format(name, expected_grads[name])) + logger.info("actual {} = {}".format(name, actual_grads[name])) - assert_equal(actual_grads, expected_grads, prec=1e-4) + assert_equal(actual_grads, expected_grads, prec=0.06) From dc31767bfa8a5c3087df5c7709fcef1324bbdeea Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Thu, 30 Sep 2021 11:51:23 -0400 Subject: [PATCH 13/53] minor fixes; pin to funsor@normalize-logaddexp --- pyro/contrib/funsor/infer/trace_elbo.py | 16 ++++++++-------- setup.py | 3 ++- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/pyro/contrib/funsor/infer/trace_elbo.py b/pyro/contrib/funsor/infer/trace_elbo.py index e9fa47ab50..b9a9201fb5 100644 --- a/pyro/contrib/funsor/infer/trace_elbo.py +++ b/pyro/contrib/funsor/infer/trace_elbo.py @@ -21,7 +21,9 @@ class Trace_ELBO(ELBO): def differentiable_loss(self, model, guide, *args, **kwargs): with provenance(), plate( - name="_PARTICLES", size=self.num_particles, dim=-self.max_plate_nesting + name="num_particles_vectorized", + size=self.num_particles, + dim=-self.max_plate_nesting, ) if self.num_particles > 1 else contextlib.ExitStack(): guide_tr = trace(guide).get_trace(*args, **kwargs) model_tr = trace(replay(model, trace=guide_tr)).get_trace(*args, **kwargs) @@ -34,7 +36,7 @@ def differentiable_loss(self, model, guide, *args, **kwargs): plate_vars = ( guide_terms["plate_vars"] | model_terms["plate_vars"] - ) - frozenset({"_PARTICLES"}) + ) - frozenset({"num_particles_vectorized"}) # compute log_measures corresponding to each cost term # the goal is to achieve fine-grained Rao-Blackwellization targets = dict() @@ -61,19 +63,17 @@ def differentiable_loss(self, model, guide, *args, **kwargs): target = targets[cost.input_vars] log_measure = log_measures[target] measure_vars = (frozenset(cost.inputs) - plate_vars) - frozenset( - {"_PARTICLES"} + {"num_particles_vectorized"} ) elbo_term = funsor.Integrate( log_measure, cost, measure_vars, ) - elbo += ( - elbo_term.reduce( - funsor.ops.add, plate_vars & frozenset(cost.inputs) - ).reduce(funsor.ops.add) - / self.num_particles + elbo += elbo_term.reduce( + funsor.ops.add, plate_vars & frozenset(cost.inputs) ) + elbo = elbo.reduce(funsor.ops.mean) # evaluate the elbo, using memoize to share tensor computation where possible with funsor.interpretations.memoize(): diff --git a/setup.py b/setup.py index 0ab243f98b..b71b738872 100644 --- a/setup.py +++ b/setup.py @@ -135,7 +135,8 @@ "horovod": ["horovod[pytorch]>=0.19"], "funsor": [ # This must be a released version when Pyro is released. 
- "funsor[torch] @ git+git://github.com/pyro-ppl/funsor.git@b25eecdaaf134c536f280131bddcf47a475a3c37", + # "funsor[torch] @ git+git://github.com/pyro-ppl/funsor.git@b25eecdaaf134c536f280131bddcf47a475a3c37", + "funsor[torch] @ git+git://github.com/pyro-ppl/funsor.git@normalize-logaddexp", # "funsor[torch]==0.4.1", ], }, From 5c0fe752fd9f3cf9d7e092e68a936992a949c951 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Thu, 30 Sep 2021 12:11:29 -0400 Subject: [PATCH 14/53] update docs/requirements --- docs/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/requirements.txt b/docs/requirements.txt index 712272c74f..9465c5d3b1 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -7,4 +7,4 @@ observations>=0.1.4 opt_einsum>=2.3.2 pyro-api>=0.1.1 tqdm>=4.36 -funsor[torch] +funsor[torch] @ git+git://github.com/pyro-ppl/funsor.git@normalize-logaddexp From 2b15fe13ce3129f4204d8db884b8e147dd572f39 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Thu, 30 Sep 2021 16:10:00 -0400 Subject: [PATCH 15/53] combine Trace_ELBO and TraceEnum_ELBO --- pyro/contrib/funsor/infer/trace_elbo.py | 33 ++++++- tests/contrib/funsor/test_gradient.py | 114 +++++++++++++++++------- 2 files changed, 113 insertions(+), 34 deletions(-) diff --git a/pyro/contrib/funsor/infer/trace_elbo.py b/pyro/contrib/funsor/infer/trace_elbo.py index b9a9201fb5..28ddf54c04 100644 --- a/pyro/contrib/funsor/infer/trace_elbo.py +++ b/pyro/contrib/funsor/infer/trace_elbo.py @@ -9,7 +9,7 @@ from funsor.constant import Constant from pyro.contrib.funsor import to_data, to_funsor -from pyro.contrib.funsor.handlers import plate, provenance, replay, trace +from pyro.contrib.funsor.handlers import enum, plate, provenance, replay, trace from pyro.distributions.util import copy_docs_from from pyro.infer import Trace_ELBO as _OrigTrace_ELBO @@ -20,7 +20,12 @@ @copy_docs_from(_OrigTrace_ELBO) class Trace_ELBO(ELBO): def differentiable_loss(self, model, guide, *args, **kwargs): - with provenance(), plate( + with enum( + first_available_dim=(-self.max_plate_nesting - 1) + if self.max_plate_nesting is not None + and self.max_plate_nesting != float("inf") + else None + ), provenance(), plate( name="num_particles_vectorized", size=self.num_particles, dim=-self.max_plate_nesting, @@ -31,8 +36,28 @@ def differentiable_loss(self, model, guide, *args, **kwargs): model_terms = terms_from_trace(model_tr) guide_terms = terms_from_trace(guide_tr) - # cost terms - costs = model_terms["log_factors"] + [-f for f in guide_terms["log_factors"]] + # identify and contract out auxiliary variables in the model with partial_sum_product + contracted_factors, uncontracted_factors = [], [] + for f in model_terms["log_factors"]: + if model_terms["measure_vars"].intersection(f.inputs): + contracted_factors.append(f) + else: + uncontracted_factors.append(f) + # incorporate the effects of subsampling and handlers.scale through a common scale factor + contracted_costs = [ + model_terms["scale"] * f + for f in funsor.sum_product.partial_sum_product( + funsor.ops.logaddexp, + funsor.ops.add, + model_terms["log_measures"] + contracted_factors, + plates=model_terms["plate_vars"], + eliminate=model_terms["measure_vars"], + ) + ] + + # accumulate costs from model (logp) and guide (-logq) + costs = contracted_costs + uncontracted_factors # model costs: logp + costs += [-f for f in guide_terms["log_factors"]] # guide costs: -logq plate_vars = ( guide_terms["plate_vars"] | model_terms["plate_vars"] diff --git 
a/tests/contrib/funsor/test_gradient.py b/tests/contrib/funsor/test_gradient.py index 3a11b8158f..19becbb8b8 100644 --- a/tests/contrib/funsor/test_gradient.py +++ b/tests/contrib/funsor/test_gradient.py @@ -26,58 +26,69 @@ def model_0(data): with pyro.plate("data", len(data)): - z = pyro.sample("z", dist.Bernoulli(0.3)) - pyro.sample("x", dist.Normal(z, 1), obs=data) + z = pyro.sample("z", dist.Categorical(torch.tensor([0.3, 0.7]))) + pyro.sample("x", dist.Normal(z.to(data.dtype), 1), obs=data) def guide_0(data): with pyro.plate("data", len(data)): - probs = pyro.param("probs", lambda: torch.tensor([0.4, 0.5])) - pyro.sample("z", dist.Bernoulli(probs)) + probs = pyro.param("probs", lambda: torch.tensor([[0.4, 0.6], [0.5, 0.5]])) + pyro.sample("z", dist.Categorical(probs)) def model_1(data): - a = pyro.sample("a", dist.Bernoulli(0.3)) + a = pyro.sample("a", dist.Categorical(torch.tensor([0.3, 0.7]))) with pyro.plate("data", len(data)): - probs_b = torch.tensor([0.1, 0.2]) - b = pyro.sample("b", dist.Bernoulli(probs_b[a.long()])) - pyro.sample("c", dist.Normal(b, 1), obs=data) + probs_b = torch.tensor([[0.1, 0.9], [0.2, 0.8]]) + b = pyro.sample("b", dist.Categorical(probs_b[a.long()])) + pyro.sample("c", dist.Normal(b.to(data.dtype), 1), obs=data) def guide_1(data): probs_a = pyro.param( "probs_a", - lambda: torch.tensor(0.5), + lambda: torch.tensor([0.5, 0.5]), ) - a = pyro.sample("a", dist.Bernoulli(probs_a)) - with pyro.plate("data", len(data)): - probs_b = pyro.param("probs_b", lambda: torch.tensor([[0.5, 0.6], [0.4, 0.35]])) - pyro.sample("b", dist.Bernoulli(Vindex(probs_b)[a.long(), torch.arange(2)])) + a = pyro.sample("a", dist.Categorical(probs_a)) + with pyro.plate("data", len(data)) as idx: + probs_b = pyro.param( + "probs_b", + lambda: torch.tensor( + [[[0.5, 0.5], [0.6, 0.4]], [[0.4, 0.6], [0.35, 0.65]]] + ), + ) + pyro.sample("b", dist.Categorical(Vindex(probs_b)[a.long(), idx])) def model_2(data): - prob_b = torch.tensor([0.3, 0.4]) - prob_c = torch.tensor([0.5, 0.6]) - prob_d = torch.tensor([0.2, 0.3]) - prob_e = torch.tensor([0.5, 0.1]) - a = pyro.sample("a", dist.Bernoulli(0.3)) + prob_b = torch.tensor([[0.3, 0.7], [0.4, 0.6]]) + prob_c = torch.tensor([[0.5, 0.5], [0.6, 0.4]]) + prob_d = torch.tensor([[0.2, 0.8], [0.3, 0.7]]) + prob_e = torch.tensor([[0.5, 0.5], [0.1, 0.9]]) + a = pyro.sample("a", dist.Categorical(torch.tensor([0.3, 0.7]))) with pyro.plate("data", len(data)): - b = pyro.sample("b", dist.Bernoulli(prob_b[a.long()])) - c = pyro.sample("c", dist.Bernoulli(prob_c[b.long()])) - pyro.sample("d", dist.Bernoulli(prob_d[b.long()])) - pyro.sample("e", dist.Bernoulli(prob_e[c.long()]), obs=data) + b = pyro.sample("b", dist.Categorical(prob_b[a.long()])) + c = pyro.sample("c", dist.Categorical(prob_c[b.long()])) + pyro.sample("d", dist.Categorical(prob_d[b.long()])) + pyro.sample("e", dist.Categorical(prob_e[c.long()]), obs=data) def guide_2(data): - prob_a = pyro.param("prob_a", lambda: torch.tensor(0.5)) - prob_b = pyro.param("prob_b", lambda: torch.tensor([0.4, 0.3])) - prob_c = pyro.param("prob_c", lambda: torch.tensor([[0.3, 0.8], [0.2, 0.5]])) - prob_d = pyro.param("prob_d", lambda: torch.tensor([[0.2, 0.9], [0.1, 0.4]])) - a = pyro.sample("a", dist.Bernoulli(prob_a)) + prob_a = pyro.param("prob_a", lambda: torch.tensor([0.5, 0.5])) + prob_b = pyro.param("prob_b", lambda: torch.tensor([[0.4, 0.6], [0.3, 0.7]])) + prob_c = pyro.param( + "prob_c", + lambda: torch.tensor([[[0.3, 0.7], [0.8, 0.2]], [[0.2, 0.8], [0.5, 0.5]]]), + ) + prob_d = pyro.param( + "prob_d", + 
lambda: torch.tensor([[[0.2, 0.8], [0.9, 0.1]], [[0.1, 0.9], [0.4, 0.6]]]), + ) + a = pyro.sample("a", dist.Categorical(prob_a)) with pyro.plate("data", len(data)) as idx: - b = pyro.sample("b", dist.Bernoulli(prob_b[a.long()])) - pyro.sample("c", dist.Bernoulli(Vindex(prob_c)[b.long(), idx])) - pyro.sample("d", dist.Bernoulli(Vindex(prob_d)[b.long(), idx])) + b = pyro.sample("b", dist.Categorical(prob_b[a.long()])) + pyro.sample("c", dist.Categorical(Vindex(prob_c)[b.long(), idx])) + pyro.sample("d", dist.Categorical(Vindex(prob_d)[b.long(), idx])) @pytest.mark.parametrize( @@ -123,3 +134,46 @@ def test_gradient(model, guide, data): logger.info("actual {} = {}".format(name, actual_grads[name])) assert_equal(actual_grads, expected_grads, prec=0.06) + + +@pytest.mark.parametrize( + "model,guide,data", + [ + (model_0, guide_0, torch.tensor([-0.5, 2.0])), + (model_1, guide_1, torch.tensor([-0.5, 2.0])), + (model_2, guide_2, torch.tensor([0.0, 1.0])), + ], +) +def test_enum_gradient(model, guide, data): + + # Expected grads based on exact integration + with pyro_backend("pyro"): + pyro.clear_param_store() + elbo = infer.TraceEnum_ELBO( + max_plate_nesting=1, # set this to ensure rng agrees across runs + strict_enumeration_warning=False, + ) + elbo.loss_and_grads(model, infer.config_enumerate(guide), data) + params = dict(pyro.get_param_store().named_parameters()) + expected_grads = { + name: param.grad.detach().cpu() for name, param in params.items() + } + + # Actual grads + with pyro_backend("contrib.funsor"): + pyro.clear_param_store() + elbo = infer.Trace_ELBO( + max_plate_nesting=1, # set this to ensure rng agrees across runs + strict_enumeration_warning=False, + ) + elbo.loss_and_grads(model, infer.config_enumerate(guide), data) + params = dict(pyro.get_param_store().named_parameters()) + actual_grads = { + name: param.grad.detach().cpu() for name, param in params.items() + } + + for name in sorted(params): + logger.info("expected {} = {}".format(name, expected_grads[name])) + logger.info("actual {} = {}".format(name, actual_grads[name])) + + assert_equal(actual_grads, expected_grads, prec=1e-3) From 351090b46cce75097a813c3cdfb58b07e0c0ba01 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Thu, 30 Sep 2021 21:06:54 -0400 Subject: [PATCH 16/53] eager evaluation --- pyro/contrib/funsor/infer/trace_elbo.py | 2 +- tests/contrib/funsor/hehe.py | 64 +++++++++++++++++++++++++ tests/contrib/funsor/test_gradient.py | 4 +- 3 files changed, 67 insertions(+), 3 deletions(-) create mode 100644 tests/contrib/funsor/hehe.py diff --git a/pyro/contrib/funsor/infer/trace_elbo.py b/pyro/contrib/funsor/infer/trace_elbo.py index 28ddf54c04..078b5e48eb 100644 --- a/pyro/contrib/funsor/infer/trace_elbo.py +++ b/pyro/contrib/funsor/infer/trace_elbo.py @@ -81,7 +81,7 @@ def differentiable_loss(self, model, guide, *args, **kwargs): log_measures = tape.adjoint( funsor.ops.logaddexp, funsor.ops.add, logzq, tuple(targets.values()) ) - with funsor.terms.lazy: + with funsor.terms.eager: # finally, integrate out guide variables in the elbo and all plates elbo = to_funsor(0, output=funsor.Real) for cost in costs: diff --git a/tests/contrib/funsor/hehe.py b/tests/contrib/funsor/hehe.py new file mode 100644 index 0000000000..7855c19b96 --- /dev/null +++ b/tests/contrib/funsor/hehe.py @@ -0,0 +1,64 @@ +from funsor.cnf import Contraction +from funsor.tensor import Tensor +import torch +import funsor.ops as ops +from funsor import Bint, Real +from funsor.terms import Unary, Binary, Variable, Number, lazy, to_data +from 
funsor.constant import Constant +from funsor.delta import Delta +import funsor +funsor.set_backend("torch") + +t1 = Unary(ops.exp, + Contraction(ops.null, ops.add, + frozenset(), + (Delta( + (('x__BOUND_77', + (Tensor( + torch.tensor([-1.5227684840130555, 0.38168390241818895, -1.0276085758313882], dtype=torch.float64), # noqa + (('plate_outer__BOUND_79', + Bint[3],),), + 'real'), + Number(0.0),),),)), + Constant( + (('plate_inner__BOUND_78', + Bint[2],),), + Tensor( + torch.tensor([0.0, 0.0, 0.0], dtype=torch.float64), + (('plate_outer__BOUND_79', + Bint[3],),), + 'real')),))) +t2 = Unary(ops.all, + Binary(ops.eq, + Tensor( + torch.tensor([-1.5227684840130555, 0.38168390241818895, -1.0276085758313882], dtype=torch.float64), # noqa + (('plate_outer__BOUND_79', + Bint[3],),), + 'real'), + Tensor( + torch.tensor([-1.5227684840130555, 0.38168390241818895, -1.0276085758313882], dtype=torch.float64), # noqa + (('plate_outer__BOUND_79', + Bint[3],),), + 'real'))) +t3 = Tensor( + torch.tensor([[-3.304130052277938, -0.9255234395261538, -1.5122103473560844], [-3.217490519312117, -1.2663745664889694, -1.2900109994655682]], dtype=torch.float64), # noqa + (('plate_inner__BOUND_78', + Bint[2],), + ('plate_outer__BOUND_79', + Bint[3],),), + 'real') +t = Contraction(ops.add, ops.mul, + frozenset({Variable('plate_inner__BOUND_78', Bint[2]), Variable('x__BOUND_77', Real), Variable('plate_outer__BOUND_79', Bint[3])}), # noqa + (t1, + t2, + t3,)) +with lazy: + term = Contraction(ops.add, ops.mul, + frozenset({Variable('plate_inner__BOUND_78', Bint[2]), Variable('x__BOUND_77', Real), Variable('plate_outer__BOUND_79', Bint[3])}), # noqa + (t1, + t2, + t3,)) + +breakpoint() +x = to_data(funsor.optimizer.apply_optimizer(term)) +pass diff --git a/tests/contrib/funsor/test_gradient.py b/tests/contrib/funsor/test_gradient.py index 19becbb8b8..005317847c 100644 --- a/tests/contrib/funsor/test_gradient.py +++ b/tests/contrib/funsor/test_gradient.py @@ -144,7 +144,7 @@ def test_gradient(model, guide, data): (model_2, guide_2, torch.tensor([0.0, 1.0])), ], ) -def test_enum_gradient(model, guide, data): +def test_guide_enum_gradient(model, guide, data): # Expected grads based on exact integration with pyro_backend("pyro"): @@ -176,4 +176,4 @@ def test_enum_gradient(model, guide, data): logger.info("expected {} = {}".format(name, expected_grads[name])) logger.info("actual {} = {}".format(name, actual_grads[name])) - assert_equal(actual_grads, expected_grads, prec=1e-3) + assert_equal(actual_grads, expected_grads, prec=1e-4) From 7d029c77de4237eb3a93ebca15201b41b8bbf70f Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Thu, 30 Sep 2021 21:11:47 -0400 Subject: [PATCH 17/53] rm file --- tests/contrib/funsor/hehe.py | 64 ------------------------------------ 1 file changed, 64 deletions(-) delete mode 100644 tests/contrib/funsor/hehe.py diff --git a/tests/contrib/funsor/hehe.py b/tests/contrib/funsor/hehe.py deleted file mode 100644 index 7855c19b96..0000000000 --- a/tests/contrib/funsor/hehe.py +++ /dev/null @@ -1,64 +0,0 @@ -from funsor.cnf import Contraction -from funsor.tensor import Tensor -import torch -import funsor.ops as ops -from funsor import Bint, Real -from funsor.terms import Unary, Binary, Variable, Number, lazy, to_data -from funsor.constant import Constant -from funsor.delta import Delta -import funsor -funsor.set_backend("torch") - -t1 = Unary(ops.exp, - Contraction(ops.null, ops.add, - frozenset(), - (Delta( - (('x__BOUND_77', - (Tensor( - torch.tensor([-1.5227684840130555, 0.38168390241818895, 
-1.0276085758313882], dtype=torch.float64), # noqa - (('plate_outer__BOUND_79', - Bint[3],),), - 'real'), - Number(0.0),),),)), - Constant( - (('plate_inner__BOUND_78', - Bint[2],),), - Tensor( - torch.tensor([0.0, 0.0, 0.0], dtype=torch.float64), - (('plate_outer__BOUND_79', - Bint[3],),), - 'real')),))) -t2 = Unary(ops.all, - Binary(ops.eq, - Tensor( - torch.tensor([-1.5227684840130555, 0.38168390241818895, -1.0276085758313882], dtype=torch.float64), # noqa - (('plate_outer__BOUND_79', - Bint[3],),), - 'real'), - Tensor( - torch.tensor([-1.5227684840130555, 0.38168390241818895, -1.0276085758313882], dtype=torch.float64), # noqa - (('plate_outer__BOUND_79', - Bint[3],),), - 'real'))) -t3 = Tensor( - torch.tensor([[-3.304130052277938, -0.9255234395261538, -1.5122103473560844], [-3.217490519312117, -1.2663745664889694, -1.2900109994655682]], dtype=torch.float64), # noqa - (('plate_inner__BOUND_78', - Bint[2],), - ('plate_outer__BOUND_79', - Bint[3],),), - 'real') -t = Contraction(ops.add, ops.mul, - frozenset({Variable('plate_inner__BOUND_78', Bint[2]), Variable('x__BOUND_77', Real), Variable('plate_outer__BOUND_79', Bint[3])}), # noqa - (t1, - t2, - t3,)) -with lazy: - term = Contraction(ops.add, ops.mul, - frozenset({Variable('plate_inner__BOUND_78', Bint[2]), Variable('x__BOUND_77', Real), Variable('plate_outer__BOUND_79', Bint[3])}), # noqa - (t1, - t2, - t3,)) - -breakpoint() -x = to_data(funsor.optimizer.apply_optimizer(term)) -pass From 1bb7380affe10e9c581db958e61471825b139df7 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Fri, 1 Oct 2021 12:12:25 -0400 Subject: [PATCH 18/53] lazy --- pyro/contrib/funsor/infer/trace_elbo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyro/contrib/funsor/infer/trace_elbo.py b/pyro/contrib/funsor/infer/trace_elbo.py index 078b5e48eb..28ddf54c04 100644 --- a/pyro/contrib/funsor/infer/trace_elbo.py +++ b/pyro/contrib/funsor/infer/trace_elbo.py @@ -81,7 +81,7 @@ def differentiable_loss(self, model, guide, *args, **kwargs): log_measures = tape.adjoint( funsor.ops.logaddexp, funsor.ops.add, logzq, tuple(targets.values()) ) - with funsor.terms.eager: + with funsor.terms.lazy: # finally, integrate out guide variables in the elbo and all plates elbo = to_funsor(0, output=funsor.Real) for cost in costs: From 42ad4fae37ba35f2c45b58b7a90c804e2fb6a893 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Fri, 1 Oct 2021 16:26:15 -0400 Subject: [PATCH 19/53] remove memoize --- pyro/contrib/funsor/infer/trace_elbo.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pyro/contrib/funsor/infer/trace_elbo.py b/pyro/contrib/funsor/infer/trace_elbo.py index 28ddf54c04..e2fba79ee2 100644 --- a/pyro/contrib/funsor/infer/trace_elbo.py +++ b/pyro/contrib/funsor/infer/trace_elbo.py @@ -98,11 +98,10 @@ def differentiable_loss(self, model, guide, *args, **kwargs): elbo += elbo_term.reduce( funsor.ops.add, plate_vars & frozenset(cost.inputs) ) + # average over Monte-Carlo particles elbo = elbo.reduce(funsor.ops.mean) - # evaluate the elbo, using memoize to share tensor computation where possible - with funsor.interpretations.memoize(): - return -to_data(apply_optimizer(elbo)) + return -to_data(apply_optimizer(elbo)) class JitTrace_ELBO(Jit_ELBO, Trace_ELBO): From 5b6afdb602478029fafa243f9bbc2279e6208a56 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sat, 9 Oct 2021 22:19:35 -0400 Subject: [PATCH 20/53] merge TraceEnum_ELBO --- .../contrib/funsor/handlers/enum_messenger.py | 9 +- pyro/contrib/funsor/infer/trace_elbo.py | 
121 +++++-- pyro/contrib/funsor/infer/traceenum_elbo.py | 263 +++++--------- tests/contrib/funsor/test_gradient.py | 329 +++++++++++++++++- 4 files changed, 529 insertions(+), 193 deletions(-) diff --git a/pyro/contrib/funsor/handlers/enum_messenger.py b/pyro/contrib/funsor/handlers/enum_messenger.py index 85c24679df..d385d7c54b 100644 --- a/pyro/contrib/funsor/handlers/enum_messenger.py +++ b/pyro/contrib/funsor/handlers/enum_messenger.py @@ -59,6 +59,13 @@ def _get_support_value_tensor(funsor_dist, name, **kwargs): ) +@_get_support_value.register(funsor.Provenance) +def _get_support_value_tensor(funsor_dist, name, **kwargs): + assert name in funsor_dist.inputs + value = _get_support_value(funsor_dist.arg, name, **kwargs) + return funsor.Provenance(funsor_dist.const_inputs, value) + + @_get_support_value.register(funsor.distribution.Distribution) def _get_support_value_distribution(funsor_dist, name, expand=False): assert name == funsor_dist.value.name @@ -204,7 +211,7 @@ def _pyro_sample(self, msg): msg["name"], expand=msg["infer"].get("expand", False), ) - msg["funsor"]["value"] = funsor.constant.Constant( + msg["funsor"]["value"] = funsor.Provenance( OrderedDict(((msg["name"], support_value.output),)), support_value, ) diff --git a/pyro/contrib/funsor/infer/trace_elbo.py b/pyro/contrib/funsor/infer/trace_elbo.py index e2fba79ee2..2220602cc8 100644 --- a/pyro/contrib/funsor/infer/trace_elbo.py +++ b/pyro/contrib/funsor/infer/trace_elbo.py @@ -14,7 +14,81 @@ from pyro.infer import Trace_ELBO as _OrigTrace_ELBO from .elbo import ELBO, Jit_ELBO -from .traceenum_elbo import apply_optimizer, terms_from_trace + + +# Work around a bug in unfold_contraction_generic_tuple interacting with +# Approximate introduced in https://github.com/pyro-ppl/funsor/pull/488 . +# Once fixed, this can be replaced by funsor.optimizer.apply_optimizer(). 
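+#
+# In rough terms, the ELBO below is first assembled as an unevaluated
+# expression under funsor.terms.lazy; this helper then reinterprets it
+# twice: once under normalize to rewrite it into a normal form, and once
+# under optimize_base to pick a contraction order and evaluate it eagerly.
+# A minimal usage sketch (added for exposition, not part of the patch):
+#
+#     with funsor.terms.lazy:
+#         elbo = ...  # accumulate cost/measure terms lazily
+#     loss = -to_data(apply_optimizer(elbo))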
+def apply_optimizer(x): + with funsor.interpretations.normalize: + expr = funsor.interpreter.reinterpret(x) + + with funsor.optimizer.optimize_base: + return funsor.interpreter.reinterpret(expr) + + +def terms_from_trace(tr): + """Helper function to extract elbo components from execution traces.""" + # data structure containing densities, measures, scales, and identification + # of free variables as either product (plate) variables or sum (measure) variables + terms = { + "log_factors": [], + "log_measures": [], + "scale": to_funsor(1.0), + "plate_vars": frozenset(), + "measure_vars": frozenset(), + "plate_to_step": dict(), + } + for name, node in tr.nodes.items(): + # add markov dimensions to the plate_to_step dictionary + if node["type"] == "markov_chain": + terms["plate_to_step"][node["name"]] = node["value"] + # ensure previous step variables are added to measure_vars + for step in node["value"]: + terms["measure_vars"] |= frozenset( + { + var + for var in step[1:-1] + if tr.nodes[var]["funsor"].get("log_measure", None) is not None + } + ) + if ( + node["type"] != "sample" + or type(node["fn"]).__name__ == "_Subsample" + or node["infer"].get("_do_not_score", False) + ): + continue + # grab plate dimensions from the cond_indep_stack + terms["plate_vars"] |= frozenset( + f.name for f in node["cond_indep_stack"] if f.vectorized + ) + # grab the log-measure, found only at sites that are not replayed or observed + if node["funsor"].get("log_measure", None) is not None: + terms["log_measures"].append(node["funsor"]["log_measure"]) + # sum (measure) variables: the fresh non-plate variables at a site + terms["measure_vars"] |= ( + frozenset(node["funsor"]["value"].inputs) | {name} + ) - terms["plate_vars"] + # grab the scale, assuming a common subsampling scale + if ( + node.get("replay_active", False) + and set(node["funsor"]["log_prob"].inputs) & terms["measure_vars"] + and float(to_data(node["funsor"]["scale"])) != 1.0 + ): + # model site that depends on enumerated variable: common scale + terms["scale"] = node["funsor"]["scale"] + else: # otherwise: default scale behavior + node["funsor"]["log_prob"] = ( + node["funsor"]["log_prob"] * node["funsor"]["scale"] + ) + # grab the log-density, found at all sites except those that are not replayed + if node["is_observed"] or not node.get("replay_skipped", False): + terms["log_factors"].append(node["funsor"]["log_prob"]) + # add plate dimensions to the plate_to_step dictionary + terms["plate_to_step"].update( + {plate: terms["plate_to_step"].get(plate, {}) for plate in terms["plate_vars"]} + ) + return terms @copy_docs_from(_OrigTrace_ELBO) @@ -36,32 +110,35 @@ def differentiable_loss(self, model, guide, *args, **kwargs): model_terms = terms_from_trace(model_tr) guide_terms = terms_from_trace(guide_tr) - # identify and contract out auxiliary variables in the model with partial_sum_product - contracted_factors, uncontracted_factors = [], [] - for f in model_terms["log_factors"]: - if model_terms["measure_vars"].intersection(f.inputs): - contracted_factors.append(f) - else: - uncontracted_factors.append(f) - # incorporate the effects of subsampling and handlers.scale through a common scale factor - contracted_costs = [ - model_terms["scale"] * f - for f in funsor.sum_product.partial_sum_product( - funsor.ops.logaddexp, - funsor.ops.add, - model_terms["log_measures"] + contracted_factors, - plates=model_terms["plate_vars"], - eliminate=model_terms["measure_vars"], - ) - ] + plate_vars = ( + guide_terms["plate_vars"] | model_terms["plate_vars"] + ) - 
frozenset({"num_particles_vectorized"}) + + with funsor.terms.lazy: + # identify and contract out auxiliary variables in the model with partial_sum_product + contracted_factors, uncontracted_factors = [], [] + for f in model_terms["log_factors"]: + if model_terms["measure_vars"].intersection(f.inputs): + contracted_factors.append(f) + else: + uncontracted_factors.append(f) + # incorporate the effects of subsampling and handlers.scale through a common scale factor + contracted_costs = [ + model_terms["scale"] * f + for f in funsor.sum_product.partial_sum_product( + funsor.ops.logaddexp, + funsor.ops.add, + model_terms["log_measures"] + contracted_factors, + plates=plate_vars, + eliminate=model_terms["measure_vars"], + # eliminate=model_terms["measure_vars"] - guide_terms["measure_vars"], + ) + ] # accumulate costs from model (logp) and guide (-logq) costs = contracted_costs + uncontracted_factors # model costs: logp costs += [-f for f in guide_terms["log_factors"]] # guide costs: -logq - plate_vars = ( - guide_terms["plate_vars"] | model_terms["plate_vars"] - ) - frozenset({"num_particles_vectorized"}) # compute log_measures corresponding to each cost term # the goal is to achieve fine-grained Rao-Blackwellization targets = dict() diff --git a/pyro/contrib/funsor/infer/traceenum_elbo.py b/pyro/contrib/funsor/infer/traceenum_elbo.py index 7915f7e101..2dbbc52dfc 100644 --- a/pyro/contrib/funsor/infer/traceenum_elbo.py +++ b/pyro/contrib/funsor/infer/traceenum_elbo.py @@ -12,80 +12,7 @@ from pyro.distributions.util import copy_docs_from from pyro.infer import TraceEnum_ELBO as _OrigTraceEnum_ELBO - -# Work around a bug in unfold_contraction_generic_tuple interacting with -# Approximate introduced in https://github.com/pyro-ppl/funsor/pull/488 . -# Once fixed, this can be replaced by funsor.optimizer.apply_optimizer(). 
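For orientation, apply_optimizer and terms_from_trace now live in
trace_elbo.py and are imported from there (see the import added below).
A minimal sketch of the dict that terms_from_trace returns, listing only
keys defined earlier in this patch series (illustrative annotation, not
part of the patch itself):

    terms = terms_from_trace(guide_tr)
    terms["log_factors"]    # list of log-density funsors, one per sample site
    terms["log_measures"]   # list of sampling-measure funsors
    terms["plate_vars"]     # frozenset of plate (product) variable names
    terms["measure_vars"]   # frozenset of sum (measure) variable names
    terms["plate_to_step"]  # dict mapping markov plates to step tuples
    terms["scale"]          # common subsampling scale factor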
-def apply_optimizer(x): - with funsor.interpretations.normalize: - expr = funsor.interpreter.reinterpret(x) - - with funsor.optimizer.optimize_base: - return funsor.interpreter.reinterpret(expr) - - -def terms_from_trace(tr): - """Helper function to extract elbo components from execution traces.""" - # data structure containing densities, measures, scales, and identification - # of free variables as either product (plate) variables or sum (measure) variables - terms = { - "log_factors": [], - "log_measures": [], - "scale": to_funsor(1.0), - "plate_vars": frozenset(), - "measure_vars": frozenset(), - "plate_to_step": dict(), - } - for name, node in tr.nodes.items(): - # add markov dimensions to the plate_to_step dictionary - if node["type"] == "markov_chain": - terms["plate_to_step"][node["name"]] = node["value"] - # ensure previous step variables are added to measure_vars - for step in node["value"]: - terms["measure_vars"] |= frozenset( - { - var - for var in step[1:-1] - if tr.nodes[var]["funsor"].get("log_measure", None) is not None - } - ) - if ( - node["type"] != "sample" - or type(node["fn"]).__name__ == "_Subsample" - or node["infer"].get("_do_not_score", False) - ): - continue - # grab plate dimensions from the cond_indep_stack - terms["plate_vars"] |= frozenset( - f.name for f in node["cond_indep_stack"] if f.vectorized - ) - # grab the log-measure, found only at sites that are not replayed or observed - if node["funsor"].get("log_measure", None) is not None: - terms["log_measures"].append(node["funsor"]["log_measure"]) - # sum (measure) variables: the fresh non-plate variables at a site - terms["measure_vars"] |= ( - frozenset(node["funsor"]["value"].inputs) | {name} - ) - terms["plate_vars"] - # grab the scale, assuming a common subsampling scale - if ( - node.get("replay_active", False) - and set(node["funsor"]["log_prob"].inputs) & terms["measure_vars"] - and float(to_data(node["funsor"]["scale"])) != 1.0 - ): - # model site that depends on enumerated variable: common scale - terms["scale"] = node["funsor"]["scale"] - else: # otherwise: default scale behavior - node["funsor"]["log_prob"] = ( - node["funsor"]["log_prob"] * node["funsor"]["scale"] - ) - # grab the log-density, found at all sites except those that are not replayed - if node["is_observed"] or not node.get("replay_skipped", False): - terms["log_factors"].append(node["funsor"]["log_prob"]) - # add plate dimensions to the plate_to_step dictionary - terms["plate_to_step"].update( - {plate: terms["plate_to_step"].get(plate, {}) for plate in terms["plate_vars"]} - ) - return terms +from .trace_elbo import Trace_ELBO, terms_from_trace, apply_optimizer @copy_docs_from(_OrigTraceEnum_ELBO) @@ -165,100 +92,102 @@ def differentiable_loss(self, model, guide, *args, **kwargs): with funsor.interpretations.memoize(): return -to_data(apply_optimizer(elbo)) +class TraceEnum_ELBO(Trace_ELBO): + pass -@copy_docs_from(_OrigTraceEnum_ELBO) -class TraceEnum_ELBO(ELBO): - def differentiable_loss(self, model, guide, *args, **kwargs): - - # get batched, enumerated, to_funsor-ed traces from the guide and model - with plate( - size=self.num_particles - ) if self.num_particles > 1 else contextlib.ExitStack(), enum( - first_available_dim=(-self.max_plate_nesting - 1) - if self.max_plate_nesting - else None - ): - guide_tr = trace(guide).get_trace(*args, **kwargs) - model_tr = trace(replay(model, trace=guide_tr)).get_trace(*args, **kwargs) - - # extract from traces all metadata that we will need to compute the elbo - guide_terms = 
terms_from_trace(guide_tr) - model_terms = terms_from_trace(model_tr) - - # build up a lazy expression for the elbo - with funsor.terms.lazy: - # identify and contract out auxiliary variables in the model with partial_sum_product - contracted_factors, uncontracted_factors = [], [] - for f in model_terms["log_factors"]: - if model_terms["measure_vars"].intersection(f.inputs): - contracted_factors.append(f) - else: - uncontracted_factors.append(f) - # incorporate the effects of subsampling and handlers.scale through a common scale factor - contracted_costs = [ - model_terms["scale"] * f - for f in funsor.sum_product.partial_sum_product( - funsor.ops.logaddexp, - funsor.ops.add, - model_terms["log_measures"] + contracted_factors, - plates=model_terms["plate_vars"], - eliminate=model_terms["measure_vars"], - ) - ] - - # accumulate costs from model (logp) and guide (-logq) - costs = contracted_costs + uncontracted_factors # model costs: logp - costs += [-f for f in guide_terms["log_factors"]] # guide costs: -logq - - # compute expected cost - # Cf. pyro.infer.util.Dice.compute_expectation() - # https://github.com/pyro-ppl/pyro/blob/0.3.0/pyro/infer/util.py#L212 - # TODO Replace this with funsor.Expectation - plate_vars = guide_terms["plate_vars"] | model_terms["plate_vars"] - # compute the marginal logq in the guide corresponding to each cost term - targets = dict() - for cost in costs: - input_vars = frozenset(cost.inputs) - if input_vars not in targets: - targets[input_vars] = funsor.Tensor( - funsor.ops.new_zeros( - funsor.tensor.get_default_prototype(), - tuple(v.size for v in cost.inputs.values()), - ), - cost.inputs, - cost.dtype, - ) - with AdjointTape() as tape: - logzq = funsor.sum_product.sum_product( - funsor.ops.logaddexp, - funsor.ops.add, - guide_terms["log_measures"] + list(targets.values()), - plates=plate_vars, - eliminate=(plate_vars | guide_terms["measure_vars"]), - ) - marginals = tape.adjoint( - funsor.ops.logaddexp, funsor.ops.add, logzq, tuple(targets.values()) - ) - # finally, integrate out guide variables in the elbo and all plates - elbo = to_funsor(0, output=funsor.Real) - for cost in costs: - target = targets[frozenset(cost.inputs)] - logzq_local = marginals[target].reduce( - funsor.ops.logaddexp, frozenset(cost.inputs) - plate_vars - ) - log_prob = marginals[target] - logzq_local - elbo_term = funsor.Integrate( - log_prob, - cost, - guide_terms["measure_vars"] & frozenset(log_prob.inputs), - ) - elbo += elbo_term.reduce( - funsor.ops.add, plate_vars & frozenset(cost.inputs) - ) - - # evaluate the elbo, using memoize to share tensor computation where possible - with funsor.interpretations.memoize(): - return -to_data(apply_optimizer(elbo)) +# @copy_docs_from(_OrigTraceEnum_ELBO) +# class TraceEnum_ELBO(ELBO): +# def differentiable_loss(self, model, guide, *args, **kwargs): +# +# # get batched, enumerated, to_funsor-ed traces from the guide and model +# with plate( +# size=self.num_particles +# ) if self.num_particles > 1 else contextlib.ExitStack(), enum( +# first_available_dim=(-self.max_plate_nesting - 1) +# if self.max_plate_nesting +# else None +# ): +# guide_tr = trace(guide).get_trace(*args, **kwargs) +# model_tr = trace(replay(model, trace=guide_tr)).get_trace(*args, **kwargs) +# +# # extract from traces all metadata that we will need to compute the elbo +# guide_terms = terms_from_trace(guide_tr) +# model_terms = terms_from_trace(model_tr) +# +# # build up a lazy expression for the elbo +# with funsor.terms.lazy: +# # identify and contract out auxiliary 
variables in the model with partial_sum_product +# contracted_factors, uncontracted_factors = [], [] +# for f in model_terms["log_factors"]: +# if model_terms["measure_vars"].intersection(f.inputs): +# contracted_factors.append(f) +# else: +# uncontracted_factors.append(f) +# # incorporate the effects of subsampling and handlers.scale through a common scale factor +# contracted_costs = [ +# model_terms["scale"] * f +# for f in funsor.sum_product.partial_sum_product( +# funsor.ops.logaddexp, +# funsor.ops.add, +# model_terms["log_measures"] + contracted_factors, +# plates=model_terms["plate_vars"], +# eliminate=model_terms["measure_vars"], +# ) +# ] +# +# # accumulate costs from model (logp) and guide (-logq) +# costs = contracted_costs + uncontracted_factors # model costs: logp +# costs += [-f for f in guide_terms["log_factors"]] # guide costs: -logq +# +# # compute expected cost +# # Cf. pyro.infer.util.Dice.compute_expectation() +# # https://github.com/pyro-ppl/pyro/blob/0.3.0/pyro/infer/util.py#L212 +# # TODO Replace this with funsor.Expectation +# plate_vars = guide_terms["plate_vars"] | model_terms["plate_vars"] +# # compute the marginal logq in the guide corresponding to each cost term +# targets = dict() +# for cost in costs: +# input_vars = frozenset(cost.inputs) +# if input_vars not in targets: +# targets[input_vars] = funsor.Tensor( +# funsor.ops.new_zeros( +# funsor.tensor.get_default_prototype(), +# tuple(v.size for v in cost.inputs.values()), +# ), +# cost.inputs, +# cost.dtype, +# ) +# with AdjointTape() as tape: +# logzq = funsor.sum_product.sum_product( +# funsor.ops.logaddexp, +# funsor.ops.add, +# guide_terms["log_measures"] + list(targets.values()), +# plates=plate_vars, +# eliminate=(plate_vars | guide_terms["measure_vars"]), +# ) +# marginals = tape.adjoint( +# funsor.ops.logaddexp, funsor.ops.add, logzq, tuple(targets.values()) +# ) +# # finally, integrate out guide variables in the elbo and all plates +# elbo = to_funsor(0, output=funsor.Real) +# for cost in costs: +# target = targets[frozenset(cost.inputs)] +# logzq_local = marginals[target].reduce( +# funsor.ops.logaddexp, frozenset(cost.inputs) - plate_vars +# ) +# log_prob = marginals[target] - logzq_local +# elbo_term = funsor.Integrate( +# log_prob, +# cost, +# guide_terms["measure_vars"] & frozenset(log_prob.inputs), +# ) +# elbo += elbo_term.reduce( +# funsor.ops.add, plate_vars & frozenset(cost.inputs) +# ) +# +# # evaluate the elbo, using memoize to share tensor computation where possible +# with funsor.interpretations.memoize(): +# return -to_data(apply_optimizer(elbo)) class JitTraceEnum_ELBO(Jit_ELBO, TraceEnum_ELBO): diff --git a/tests/contrib/funsor/test_gradient.py b/tests/contrib/funsor/test_gradient.py index 005317847c..8a05b98ebb 100644 --- a/tests/contrib/funsor/test_gradient.py +++ b/tests/contrib/funsor/test_gradient.py @@ -17,7 +17,7 @@ funsor.set_backend("torch") from pyroapi import distributions as dist - from pyroapi import infer, pyro, pyro_backend + from pyroapi import handlers, infer, pyro, pyro_backend except ImportError: pytestmark = pytest.mark.skip(reason="funsor is not installed") @@ -119,7 +119,7 @@ def test_gradient(model, guide, data): pyro.clear_param_store() elbo = infer.Trace_ELBO( max_plate_nesting=1, # set this to ensure rng agrees across runs - num_particles=50000, + num_particles=100000, vectorize_particles=True, strict_enumeration_warning=False, ) @@ -133,7 +133,7 @@ def test_gradient(model, guide, data): logger.info("expected {} = {}".format(name, expected_grads[name])) 
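+            # note: the closed form used above, dlogq_drate = z / rate - 1,
+            # is the Poisson score function: log q(z) = z * log(rate) - rate
+            # - log(z!), so d/drate log q(z) = z / rate - 1
+            # (annotation added for exposition, not part of the original test)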
logger.info("actual {} = {}".format(name, actual_grads[name])) - assert_equal(actual_grads, expected_grads, prec=0.06) + assert_equal(actual_grads, expected_grads, prec=0.03) @pytest.mark.parametrize( @@ -177,3 +177,326 @@ def test_guide_enum_gradient(model, guide, data): logger.info("actual {} = {}".format(name, actual_grads[name])) assert_equal(actual_grads, expected_grads, prec=1e-4) + +@pytest.mark.parametrize( + "Elbo,backend", + [ + ("TraceEnum_ELBO", "pyro"), + ("Trace_ELBO", "contrib.funsor"), + ], +) +def test_particle_gradient_0(Elbo, backend): + with pyro_backend(backend): + pyro.clear_param_store() + data = torch.tensor([-0.5, 2.0]) + + def model(): + with pyro.plate("data", len(data)): + z = pyro.sample("z", dist.Poisson(3)) + pyro.sample("x", dist.Normal(z, 1), obs=data) + + def guide(): + with pyro.plate("data", len(data)): + rate = pyro.param("rate", lambda: torch.tensor([3.5, 1.5])) + pyro.sample("z", dist.Poisson(rate)) + + elbo = getattr(infer, Elbo)( + max_plate_nesting=1, # set this to ensure rng agrees across runs + num_particles=1, + strict_enumeration_warning=False, + ) + + # Elbo gradient estimator + pyro.set_rng_seed(0) + elbo.loss_and_grads(model, guide) + params = dict(pyro.get_param_store().named_parameters()) + actual_grads = { + name: param.grad.detach().cpu() for name, param in params.items() + } + + # capture sample values and log_probs + pyro.set_rng_seed(0) + guide_tr = handlers.trace(guide).get_trace() + model_tr = handlers.trace(handlers.replay(model, guide_tr)).get_trace() + guide_tr.compute_log_prob() + model_tr.compute_log_prob() + z = guide_tr.nodes["z"]["value"].data + rate = pyro.param("rate").data + + loss_i = ( + model_tr.nodes["x"]["log_prob"].data + + model_tr.nodes["z"]["log_prob"].data + - guide_tr.nodes["z"]["log_prob"].data + ) + dlogq_drate = z / rate - 1 + expected_grads = { + "rate": -(dlogq_drate * loss_i - dlogq_drate), + } + + for name in sorted(params): + logger.info("expected {} = {}".format(name, expected_grads[name])) + logger.info("actual {} = {}".format(name, actual_grads[name])) + + assert_equal(actual_grads, expected_grads, prec=1e-4) + + +@pytest.mark.parametrize( + "Elbo,backend", + [ + ("TraceEnum_ELBO", "pyro"), + ("Trace_ELBO", "contrib.funsor"), + ], +) +def test_particle_gradient_1(Elbo, backend): + with pyro_backend(backend): + pyro.clear_param_store() + data = torch.tensor([-0.5, 2.0]) + + def model(): + a = pyro.sample("a", dist.Bernoulli(0.3)) + with pyro.plate("data", len(data)): + rate = torch.tensor([2.0, 3.0]) + b = pyro.sample("b", dist.Poisson(rate[a.long()])) + pyro.sample("c", dist.Normal(b, 1), obs=data) + + def guide(): + prob = pyro.param( + "prob", + lambda: torch.tensor(0.5), + ) + a = pyro.sample("a", dist.Bernoulli(prob)) + with pyro.plate("data", len(data)): + rate = pyro.param( + "rate", lambda: torch.tensor([[3.5, 1.5], [0.5, 2.5]]) + ) + pyro.sample("b", dist.Poisson(rate[a.long()])) + + elbo = getattr(infer, Elbo)( + max_plate_nesting=1, # set this to ensure rng agrees across runs + num_particles=1, + strict_enumeration_warning=False, + ) + + # Elbo gradient estimator + pyro.set_rng_seed(0) + elbo.loss_and_grads(model, guide) + params = dict(pyro.get_param_store().named_parameters()) + actual_grads = { + name: param.grad.detach().cpu() for name, param in params.items() + } + + # capture sample values and log_probs + pyro.set_rng_seed(0) + guide_tr = handlers.trace(guide).get_trace() + model_tr = handlers.trace(handlers.replay(model, guide_tr)).get_trace() + guide_tr.compute_log_prob() + 
model_tr.compute_log_prob() + a = guide_tr.nodes["a"]["value"].data + b = guide_tr.nodes["b"]["value"].data + prob = pyro.param("prob").data + rate = pyro.param("rate").data + + dlogqa_dprob = (a - prob) / (prob * (1 - prob)) + dlogqb_drate = b / rate[a.long()] - 1 + + loss_a = ( + model_tr.nodes["a"]["log_prob"].data - guide_tr.nodes["a"]["log_prob"].data + ) + loss_bc = ( + model_tr.nodes["b"]["log_prob"].data + + model_tr.nodes["c"]["log_prob"].data + - guide_tr.nodes["b"]["log_prob"].data + ) + expected_grads = { + "prob": -(dlogqa_dprob * (loss_a + loss_bc.sum()) - dlogqa_dprob), + "rate": -(dlogqb_drate * (loss_bc) - dlogqb_drate), + } + actual_grads["rate"] = actual_grads["rate"][a.long()] + + for name in sorted(params): + logger.info("expected {} = {}".format(name, expected_grads[name])) + logger.info("actual {} = {}".format(name, actual_grads[name])) + + assert_equal(actual_grads, expected_grads, prec=1e-4) + + +@pytest.mark.parametrize( + "Elbo,backend", + [ + ("TraceEnum_ELBO", "pyro"), + ("Trace_ELBO", "contrib.funsor"), + ("TraceGraph_ELBO", "pyro"), + ], +) +def test_particle_gradient_2(Elbo, backend): + with pyro_backend(backend): + pyro.clear_param_store() + data = torch.tensor([0.0, 1.0]) + + def model(): + prob_b = torch.tensor([0.3, 0.4]) + prob_c = torch.tensor([0.5, 0.6]) + prob_d = torch.tensor([0.2, 0.3]) + prob_e = torch.tensor([0.5, 0.1]) + a = pyro.sample("a", dist.Bernoulli(0.3)) + with pyro.plate("data", len(data)): + b = pyro.sample("b", dist.Bernoulli(prob_b[a.long()])) + c = pyro.sample("c", dist.Bernoulli(prob_c[b.long()])) + pyro.sample("d", dist.Bernoulli(prob_d[b.long()])) + pyro.sample("e", dist.Bernoulli(prob_e[c.long()]), obs=data) + + def guide(): + prob_a = pyro.param("prob_a", lambda: torch.tensor(0.5)) + prob_b = pyro.param("prob_b", lambda: torch.tensor([0.4, 0.3])) + prob_c = pyro.param( + "prob_c", lambda: torch.tensor([[0.3, 0.8], [0.2, 0.5]]) + ) + prob_d = pyro.param( + "prob_d", lambda: torch.tensor([[0.2, 0.9], [0.1, 0.4]]) + ) + a = pyro.sample("a", dist.Bernoulli(prob_a)) + with pyro.plate("data", len(data)) as idx: + b = pyro.sample("b", dist.Bernoulli(prob_b[a.long()])) + pyro.sample("c", dist.Bernoulli(Vindex(prob_c)[b.long(), idx])) + pyro.sample("d", dist.Bernoulli(Vindex(prob_d)[b.long(), idx])) + + elbo = getattr(infer, Elbo)( + max_plate_nesting=1, # set this to ensure rng agrees across runs + num_particles=1, + strict_enumeration_warning=False, + ) + + # Elbo gradient estimator + pyro.set_rng_seed(0) + elbo.loss_and_grads(model, guide) + params = dict(pyro.get_param_store().named_parameters()) + actual_grads = { + name: param.grad.detach().cpu() for name, param in params.items() + } + + # capture sample values and log_probs + pyro.set_rng_seed(0) + guide_tr = handlers.trace(guide).get_trace() + model_tr = handlers.trace(handlers.replay(model, guide_tr)).get_trace() + guide_tr.compute_log_prob() + model_tr.compute_log_prob() + + a = guide_tr.nodes["a"]["value"].data + b = guide_tr.nodes["b"]["value"].data + c = guide_tr.nodes["c"]["value"].data + + proba = pyro.param("prob_a").data + probb = pyro.param("prob_b").data + probc = pyro.param("prob_c").data + probd = pyro.param("prob_d").data + + logpa = model_tr.nodes["a"]["log_prob"].data + logpba = model_tr.nodes["b"]["log_prob"].data + logpcba = model_tr.nodes["c"]["log_prob"].data + logpdba = model_tr.nodes["d"]["log_prob"].data + logpecba = model_tr.nodes["e"]["log_prob"].data + + logqa = guide_tr.nodes["a"]["log_prob"].data + logqba = guide_tr.nodes["b"]["log_prob"].data + 
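+        # note: the dlogq* closed forms below all instantiate the Bernoulli
+        # score function: log q(x; p) = x * log(p) + (1 - x) * log(1 - p), so
+        # d/dp log q = x / p - (1 - x) / (1 - p) = (x - p) / (p * (1 - p))
+        # (annotation added for exposition, not part of the original test)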
logqcba = guide_tr.nodes["c"]["log_prob"].data
+        logqdba = guide_tr.nodes["d"]["log_prob"].data
+
+        idx = torch.arange(2)
+        dlogqa_dproba = (a - proba) / (proba * (1 - proba))
+        dlogqb_dprobb = (b - probb[a.long()]) / (
+            probb[a.long()] * (1 - probb[a.long()])
+        )
+        dlogqc_dprobc = (c - Vindex(probc)[b.long(), idx]) / (
+            Vindex(probc)[b.long(), idx] * (1 - Vindex(probc)[b.long(), idx])
+        )
+        dlogqd_dprobd = (c - probd[b.long(), idx]) / (
+            Vindex(probd)[b.long(), idx] * (1 - Vindex(probd)[b.long(), idx])
+        )
+
+        if Elbo == "Trace_ELBO":
+            # fine-grained Rao-Blackwellization based on provenance tracking
+            expected_grads = {
+                "prob_a": -dlogqa_dproba
+                * (
+                    logpa
+                    + (logpba + logpcba + logpdba + logpecba).sum()
+                    - logqa
+                    - (logqba + logqcba + logqdba).sum()
+                    - 1
+                ),
+                "prob_b": (
+                    -dlogqb_dprobb
+                    * (
+                        (logpba + logpcba + logpdba + logpecba)
+                        - (logqba + logqcba + logqdba)
+                        - 1
+                    )
+                ).sum(),
+                "prob_c": -dlogqc_dprobc * (logpcba + logpecba - logqcba - 1),
+                "prob_d": -dlogqd_dprobd * (logpdba - logqdba - 1),
+            }
+        elif Elbo == "TraceEnum_ELBO":
+            # only uses plate conditional independence for Rao-Blackwellization
+            expected_grads = {
+                "prob_a": -dlogqa_dproba
+                * (
+                    logpa
+                    + (logpba + logpcba + logpdba + logpecba).sum()
+                    - logqa
+                    - (logqba + logqcba + logqdba).sum()
+                    - 1
+                ),
+                "prob_b": (
+                    -dlogqb_dprobb
+                    * (
+                        (logpba + logpcba + logpdba + logpecba)
+                        - (logqba + logqcba + logqdba)
+                        - 1
+                    )
+                ).sum(),
+                "prob_c": -dlogqc_dprobc
+                * (
+                    (logpba + logpcba + logpdba + logpecba)
+                    - (logqba + logqcba + logqdba)
+                    - 1
+                ),
+                "prob_d": -dlogqd_dprobd
+                * (
+                    (logpba + logpcba + logpdba + logpecba)
+                    - (logqba + logqcba + logqdba)
+                    - 1
+                ),
+            }
+        elif Elbo == "TraceGraph_ELBO":
+            # Rao-Blackwellization uses conditional independence based on
+            # 1) the sequential order of samples
+            # 2) plate generators
+            # additionally removes dlogq_dparam terms
+            expected_grads = {
+                "prob_a": -dlogqa_dproba
+                * (
+                    logpa
+                    + (logpba + logpcba + logpdba + logpecba).sum()
+                    - logqa
+                    - (logqba + logqcba + logqdba).sum()
+                ),
+                "prob_b": (
+                    -dlogqb_dprobb
+                    * (
+                        (logpba + logpcba + logpdba + logpecba)
+                        - (logqba + logqcba + logqdba)
+                    )
+                ).sum(),
+                "prob_c": -dlogqc_dprobc
+                * ((logpcba + logpdba + logpecba) - (logqcba + logqdba)),
+                "prob_d": -dlogqd_dprobd * (logpdba + logpecba - logqdba),
+            }
+        actual_grads["prob_b"] = actual_grads["prob_b"][a.long()]
+        actual_grads["prob_c"] = Vindex(actual_grads["prob_c"])[b.long(), idx]
+        actual_grads["prob_d"] = Vindex(actual_grads["prob_d"])[b.long(), idx]
+
+        for name in sorted(params):
+            logger.info("expected {} = {}".format(name, expected_grads[name]))
+            logger.info("actual {} = {}".format(name, actual_grads[name]))
+
+        assert_equal(actual_grads, expected_grads, prec=1e-4)

From 33628aaace201fb5bdb5a6f652ae1cd06df44478 Mon Sep 17 00:00:00 2001
From: Yerdos Ordabayev
Date: Mon, 11 Oct 2021 19:05:46 -0400
Subject: [PATCH 21/53] skip test

---
 .../contrib/funsor/handlers/enum_messenger.py | 21 +++-
 pyro/contrib/funsor/infer/traceenum_elbo.py   | 98 +------------------
 setup.py                                      |  2 +-
 tests/contrib/funsor/test_enum_funsor.py      |  1 +
 tests/contrib/funsor/test_gradient.py         |  1 +
 5 files changed, 22 insertions(+), 101 deletions(-)

diff --git a/pyro/contrib/funsor/handlers/enum_messenger.py b/pyro/contrib/funsor/handlers/enum_messenger.py
index d385d7c54b..f1975c369b 100644
--- a/pyro/contrib/funsor/handlers/enum_messenger.py
+++ b/pyro/contrib/funsor/handlers/enum_messenger.py
@@ -59,11 +59,11 @@ def _get_support_value_tensor(funsor_dist, name, **kwargs):
) -@_get_support_value.register(funsor.Provenance) -def _get_support_value_tensor(funsor_dist, name, **kwargs): +@_get_support_value.register(funsor.Constant) +def _get_support_value_constant(funsor_dist, name, **kwargs): assert name in funsor_dist.inputs value = _get_support_value(funsor_dist.arg, name, **kwargs) - return funsor.Provenance(funsor_dist.const_inputs, value) + return funsor.Constant(funsor_dist.const_inputs, value) @_get_support_value.register(funsor.distribution.Distribution) @@ -164,6 +164,19 @@ def _enum_strategy_full(dist, msg): def _enum_strategy_exact(dist, msg): if isinstance(dist, funsor.Tensor): dist = dist - dist.reduce(funsor.ops.logaddexp, msg["name"]) + else: + if hasattr(dist, "probs"): + value = dist.probs.materialize( + funsor.Variable(msg["name"], dist.inputs[msg["name"]]) + ) + dist = dist(**{msg["name"]: value}) + elif hasattr(dist, "logits"): + value = dist.logits.materialize( + funsor.Variable(msg["name"], dist.inputs[msg["name"]]) + ) + dist = dist(**{msg["name"]: value}) + else: + raise ValueError return dist @@ -211,7 +224,7 @@ def _pyro_sample(self, msg): msg["name"], expand=msg["infer"].get("expand", False), ) - msg["funsor"]["value"] = funsor.Provenance( + msg["funsor"]["value"] = funsor.Constant( OrderedDict(((msg["name"], support_value.output),)), support_value, ) diff --git a/pyro/contrib/funsor/infer/traceenum_elbo.py b/pyro/contrib/funsor/infer/traceenum_elbo.py index 2dbbc52dfc..b77229381a 100644 --- a/pyro/contrib/funsor/infer/traceenum_elbo.py +++ b/pyro/contrib/funsor/infer/traceenum_elbo.py @@ -4,7 +4,6 @@ import contextlib import funsor -from funsor.adjoint import AdjointTape from pyro.contrib.funsor import to_data, to_funsor from pyro.contrib.funsor.handlers import enum, plate, replay, trace @@ -12,7 +11,7 @@ from pyro.distributions.util import copy_docs_from from pyro.infer import TraceEnum_ELBO as _OrigTraceEnum_ELBO -from .trace_elbo import Trace_ELBO, terms_from_trace, apply_optimizer +from .trace_elbo import Trace_ELBO, apply_optimizer, terms_from_trace @copy_docs_from(_OrigTraceEnum_ELBO) @@ -92,103 +91,10 @@ def differentiable_loss(self, model, guide, *args, **kwargs): with funsor.interpretations.memoize(): return -to_data(apply_optimizer(elbo)) + class TraceEnum_ELBO(Trace_ELBO): pass -# @copy_docs_from(_OrigTraceEnum_ELBO) -# class TraceEnum_ELBO(ELBO): -# def differentiable_loss(self, model, guide, *args, **kwargs): -# -# # get batched, enumerated, to_funsor-ed traces from the guide and model -# with plate( -# size=self.num_particles -# ) if self.num_particles > 1 else contextlib.ExitStack(), enum( -# first_available_dim=(-self.max_plate_nesting - 1) -# if self.max_plate_nesting -# else None -# ): -# guide_tr = trace(guide).get_trace(*args, **kwargs) -# model_tr = trace(replay(model, trace=guide_tr)).get_trace(*args, **kwargs) -# -# # extract from traces all metadata that we will need to compute the elbo -# guide_terms = terms_from_trace(guide_tr) -# model_terms = terms_from_trace(model_tr) -# -# # build up a lazy expression for the elbo -# with funsor.terms.lazy: -# # identify and contract out auxiliary variables in the model with partial_sum_product -# contracted_factors, uncontracted_factors = [], [] -# for f in model_terms["log_factors"]: -# if model_terms["measure_vars"].intersection(f.inputs): -# contracted_factors.append(f) -# else: -# uncontracted_factors.append(f) -# # incorporate the effects of subsampling and handlers.scale through a common scale factor -# contracted_costs = [ -# model_terms["scale"] * f -# for f in 
funsor.sum_product.partial_sum_product( -# funsor.ops.logaddexp, -# funsor.ops.add, -# model_terms["log_measures"] + contracted_factors, -# plates=model_terms["plate_vars"], -# eliminate=model_terms["measure_vars"], -# ) -# ] -# -# # accumulate costs from model (logp) and guide (-logq) -# costs = contracted_costs + uncontracted_factors # model costs: logp -# costs += [-f for f in guide_terms["log_factors"]] # guide costs: -logq -# -# # compute expected cost -# # Cf. pyro.infer.util.Dice.compute_expectation() -# # https://github.com/pyro-ppl/pyro/blob/0.3.0/pyro/infer/util.py#L212 -# # TODO Replace this with funsor.Expectation -# plate_vars = guide_terms["plate_vars"] | model_terms["plate_vars"] -# # compute the marginal logq in the guide corresponding to each cost term -# targets = dict() -# for cost in costs: -# input_vars = frozenset(cost.inputs) -# if input_vars not in targets: -# targets[input_vars] = funsor.Tensor( -# funsor.ops.new_zeros( -# funsor.tensor.get_default_prototype(), -# tuple(v.size for v in cost.inputs.values()), -# ), -# cost.inputs, -# cost.dtype, -# ) -# with AdjointTape() as tape: -# logzq = funsor.sum_product.sum_product( -# funsor.ops.logaddexp, -# funsor.ops.add, -# guide_terms["log_measures"] + list(targets.values()), -# plates=plate_vars, -# eliminate=(plate_vars | guide_terms["measure_vars"]), -# ) -# marginals = tape.adjoint( -# funsor.ops.logaddexp, funsor.ops.add, logzq, tuple(targets.values()) -# ) -# # finally, integrate out guide variables in the elbo and all plates -# elbo = to_funsor(0, output=funsor.Real) -# for cost in costs: -# target = targets[frozenset(cost.inputs)] -# logzq_local = marginals[target].reduce( -# funsor.ops.logaddexp, frozenset(cost.inputs) - plate_vars -# ) -# log_prob = marginals[target] - logzq_local -# elbo_term = funsor.Integrate( -# log_prob, -# cost, -# guide_terms["measure_vars"] & frozenset(log_prob.inputs), -# ) -# elbo += elbo_term.reduce( -# funsor.ops.add, plate_vars & frozenset(cost.inputs) -# ) -# -# # evaluate the elbo, using memoize to share tensor computation where possible -# with funsor.interpretations.memoize(): -# return -to_data(apply_optimizer(elbo)) - class JitTraceEnum_ELBO(Jit_ELBO, TraceEnum_ELBO): pass diff --git a/setup.py b/setup.py index b71b738872..26e61f3cac 100644 --- a/setup.py +++ b/setup.py @@ -136,7 +136,7 @@ "funsor": [ # This must be a released version when Pyro is released. 
# "funsor[torch] @ git+git://github.com/pyro-ppl/funsor.git@b25eecdaaf134c536f280131bddcf47a475a3c37", - "funsor[torch] @ git+git://github.com/pyro-ppl/funsor.git@normalize-logaddexp", + "funsor[torch] @ git+git://github.com/pyro-ppl/funsor.git@constant-materialize", # "funsor[torch]==0.4.1", ], }, diff --git a/tests/contrib/funsor/test_enum_funsor.py b/tests/contrib/funsor/test_enum_funsor.py index e8dbc60c69..651152737f 100644 --- a/tests/contrib/funsor/test_enum_funsor.py +++ b/tests/contrib/funsor/test_enum_funsor.py @@ -702,6 +702,7 @@ def guide(): _check_loss_and_grads(expected_loss, actual_loss) +@pytest.mark.skip(reason="wrong dependency structure?") @pytest.mark.parametrize("scale", [1, 10]) @pyroapi.pyro_backend(_PYRO_BACKEND) def test_elbo_enumerate_plate_7(scale): diff --git a/tests/contrib/funsor/test_gradient.py b/tests/contrib/funsor/test_gradient.py index 8a05b98ebb..bd341ce3f0 100644 --- a/tests/contrib/funsor/test_gradient.py +++ b/tests/contrib/funsor/test_gradient.py @@ -178,6 +178,7 @@ def test_guide_enum_gradient(model, guide, data): assert_equal(actual_grads, expected_grads, prec=1e-4) + @pytest.mark.parametrize( "Elbo,backend", [ From 18a973bc9f3701b8e695dd51ec60bf502bdcc5fa Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Mon, 11 Oct 2021 23:48:16 -0400 Subject: [PATCH 22/53] fixes --- examples/contrib/funsor/hmm.py | 6 +++++- pyro/contrib/funsor/infer/trace_elbo.py | 6 +++--- pyro/infer/autoguide/guides.py | 15 ++++++++------- 3 files changed, 16 insertions(+), 11 deletions(-) diff --git a/examples/contrib/funsor/hmm.py b/examples/contrib/funsor/hmm.py index aac6fad61f..1777f05af7 100644 --- a/examples/contrib/funsor/hmm.py +++ b/examples/contrib/funsor/hmm.py @@ -700,7 +700,11 @@ def main(args): # learns point estimates of all of our conditional probability tables, # named probs_*. 
guide = AutoDelta( - handlers.block(model, expose_fn=lambda msg: msg["name"].startswith("probs_")) + handlers.block( + model, + expose_fn=lambda msg: msg["name"] is not None + and msg["name"].startswith("probs_"), + ) ) # To help debug our tensor shapes, let's print the shape of each site's diff --git a/pyro/contrib/funsor/infer/trace_elbo.py b/pyro/contrib/funsor/infer/trace_elbo.py index 2220602cc8..bc055bc9ec 100644 --- a/pyro/contrib/funsor/infer/trace_elbo.py +++ b/pyro/contrib/funsor/infer/trace_elbo.py @@ -114,11 +114,12 @@ def differentiable_loss(self, model, guide, *args, **kwargs): guide_terms["plate_vars"] | model_terms["plate_vars"] ) - frozenset({"num_particles_vectorized"}) + model_measure_vars = model_terms["measure_vars"] - guide_terms["measure_vars"] with funsor.terms.lazy: # identify and contract out auxiliary variables in the model with partial_sum_product contracted_factors, uncontracted_factors = [], [] for f in model_terms["log_factors"]: - if model_terms["measure_vars"].intersection(f.inputs): + if model_measure_vars.intersection(f.inputs): contracted_factors.append(f) else: uncontracted_factors.append(f) @@ -130,8 +131,7 @@ def differentiable_loss(self, model, guide, *args, **kwargs): funsor.ops.add, model_terms["log_measures"] + contracted_factors, plates=plate_vars, - eliminate=model_terms["measure_vars"], - # eliminate=model_terms["measure_vars"] - guide_terms["measure_vars"], + eliminate=model_measure_vars, ) ] diff --git a/pyro/infer/autoguide/guides.py b/pyro/infer/autoguide/guides.py index 63b64b9f52..825a536598 100644 --- a/pyro/infer/autoguide/guides.py +++ b/pyro/infer/autoguide/guides.py @@ -22,6 +22,7 @@ def model(): from contextlib import ExitStack import torch +from pyroapi import handlers from torch import nn from torch.distributions import biject_to @@ -134,8 +135,8 @@ def _create_plates(self, *args, **kwargs): def _setup_prototype(self, *args, **kwargs): # run the model so we can inspect its structure - model = poutine.block(self.model, prototype_hide_fn) - self.prototype_trace = poutine.block(poutine.trace(model).get_trace)( + model = handlers.block(self.model, prototype_hide_fn) + self.prototype_trace = handlers.block(handlers.trace(model).get_trace)( *args, **kwargs ) if self.master is not None: @@ -1148,9 +1149,9 @@ def laplace_approximation(self, *args, **kwargs): Returns a :class:`AutoMultivariateNormal` instance whose posterior's `loc` and `scale_tril` are given by Laplace approximation. 
""" - guide_trace = poutine.trace(self).get_trace(*args, **kwargs) - model_trace = poutine.trace( - poutine.replay(self.model, trace=guide_trace) + guide_trace = handlers.trace(self).get_trace(*args, **kwargs) + model_trace = handlers.trace( + handlers.replay(self.model, trace=guide_trace) ).get_trace(*args, **kwargs) loss = guide_trace.log_prob_sum() - model_trace.log_prob_sum() @@ -1175,8 +1176,8 @@ class AutoDiscreteParallel(AutoGuide): def _setup_prototype(self, *args, **kwargs): # run the model so we can inspect its structure - model = poutine.block(config_enumerate(self.model), prototype_hide_fn) - self.prototype_trace = poutine.block(poutine.trace(model).get_trace)( + model = handlers.block(config_enumerate(self.model), prototype_hide_fn) + self.prototype_trace = handlers.block(handlers.trace(model).get_trace)( *args, **kwargs ) if self.master is not None: From 2c3ead3958510475bf7e448318fabc8c49b4cb42 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Tue, 12 Oct 2021 01:39:20 -0400 Subject: [PATCH 23/53] convert Tensor to Categorical --- .../contrib/funsor/handlers/enum_messenger.py | 28 ++++++++++--------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/pyro/contrib/funsor/handlers/enum_messenger.py b/pyro/contrib/funsor/handlers/enum_messenger.py index f1975c369b..9cfc4bbe73 100644 --- a/pyro/contrib/funsor/handlers/enum_messenger.py +++ b/pyro/contrib/funsor/handlers/enum_messenger.py @@ -164,19 +164,6 @@ def _enum_strategy_full(dist, msg): def _enum_strategy_exact(dist, msg): if isinstance(dist, funsor.Tensor): dist = dist - dist.reduce(funsor.ops.logaddexp, msg["name"]) - else: - if hasattr(dist, "probs"): - value = dist.probs.materialize( - funsor.Variable(msg["name"], dist.inputs[msg["name"]]) - ) - dist = dist(**{msg["name"]: value}) - elif hasattr(dist, "logits"): - value = dist.logits.materialize( - funsor.Variable(msg["name"], dist.inputs[msg["name"]]) - ) - dist = dist(**{msg["name"]: value}) - else: - raise ValueError return dist @@ -216,6 +203,21 @@ def _pyro_sample(self, msg): unsampled_log_measure = to_funsor(msg["fn"], output=funsor.Real)( value=msg["name"] ) + # make Categorical from Tensor + if isinstance(unsampled_log_measure, funsor.Tensor): + from importlib import import_module + + backend = funsor.util.get_backend() + dist = import_module( + funsor.distribution.BACKEND_TO_DISTRIBUTIONS_BACKEND[backend] + ) + domain = unsampled_log_measure.inputs[msg["name"]] + logits = funsor.Lambda( + funsor.Variable(msg["name"], domain), unsampled_log_measure + ) + unsampled_log_measure = dist.CategoricalLogits( + logits=logits, value=msg["name"] + ) msg["funsor"]["log_measure"] = _enum_strategy_default( unsampled_log_measure, msg ) From 5fb15220a7e23e6f99e3f27885f2cd6c93be7116 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Tue, 12 Oct 2021 01:40:42 -0400 Subject: [PATCH 24/53] restore docs/requirements.txt --- docs/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/requirements.txt b/docs/requirements.txt index 9465c5d3b1..712272c74f 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -7,4 +7,4 @@ observations>=0.1.4 opt_einsum>=2.3.2 pyro-api>=0.1.1 tqdm>=4.36 -funsor[torch] @ git+git://github.com/pyro-ppl/funsor.git@normalize-logaddexp +funsor[torch] From f907f936a929fa7910fa46877a0de4d934b9e8c3 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Tue, 12 Oct 2021 01:47:10 -0400 Subject: [PATCH 25/53] pin funsor in docs/requirements --- docs/requirements.txt | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/docs/requirements.txt b/docs/requirements.txt index 712272c74f..9d1a7b8a26 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -7,4 +7,4 @@ observations>=0.1.4 opt_einsum>=2.3.2 pyro-api>=0.1.1 tqdm>=4.36 -funsor[torch] +funsor[torch] @ git+git://github.com/pyro-ppl/funsor.git@constant-materialize From 0042f8566b17cf1ffc3f5f5d12688a91ce799287 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Tue, 12 Oct 2021 11:55:31 -0400 Subject: [PATCH 26/53] use funsor.optimizer.apply_optimizer; higher precision in the test --- pyro/contrib/funsor/handlers/enum_messenger.py | 7 ------- pyro/contrib/funsor/infer/trace_elbo.py | 13 +------------ pyro/contrib/funsor/infer/traceenum_elbo.py | 13 ++++++++++++- tests/contrib/funsor/test_gradient.py | 4 ++-- 4 files changed, 15 insertions(+), 22 deletions(-) diff --git a/pyro/contrib/funsor/handlers/enum_messenger.py b/pyro/contrib/funsor/handlers/enum_messenger.py index 9cfc4bbe73..43d1d07cc1 100644 --- a/pyro/contrib/funsor/handlers/enum_messenger.py +++ b/pyro/contrib/funsor/handlers/enum_messenger.py @@ -59,13 +59,6 @@ def _get_support_value_tensor(funsor_dist, name, **kwargs): ) -@_get_support_value.register(funsor.Constant) -def _get_support_value_constant(funsor_dist, name, **kwargs): - assert name in funsor_dist.inputs - value = _get_support_value(funsor_dist.arg, name, **kwargs) - return funsor.Constant(funsor_dist.const_inputs, value) - - @_get_support_value.register(funsor.distribution.Distribution) def _get_support_value_distribution(funsor_dist, name, expand=False): assert name == funsor_dist.value.name diff --git a/pyro/contrib/funsor/infer/trace_elbo.py b/pyro/contrib/funsor/infer/trace_elbo.py index bc055bc9ec..27057a6187 100644 --- a/pyro/contrib/funsor/infer/trace_elbo.py +++ b/pyro/contrib/funsor/infer/trace_elbo.py @@ -16,17 +16,6 @@ from .elbo import ELBO, Jit_ELBO -# Work around a bug in unfold_contraction_generic_tuple interacting with -# Approximate introduced in https://github.com/pyro-ppl/funsor/pull/488 . -# Once fixed, this can be replaced by funsor.optimizer.apply_optimizer(). -def apply_optimizer(x): - with funsor.interpretations.normalize: - expr = funsor.interpreter.reinterpret(x) - - with funsor.optimizer.optimize_base: - return funsor.interpreter.reinterpret(expr) - - def terms_from_trace(tr): """Helper function to extract elbo components from execution traces.""" # data structure containing densities, measures, scales, and identification @@ -178,7 +167,7 @@ def differentiable_loss(self, model, guide, *args, **kwargs): # average over Monte-Carlo particles elbo = elbo.reduce(funsor.ops.mean) - return -to_data(apply_optimizer(elbo)) + return -to_data(funsor.optimizer.apply_optimizer(elbo)) class JitTrace_ELBO(Jit_ELBO, Trace_ELBO): diff --git a/pyro/contrib/funsor/infer/traceenum_elbo.py b/pyro/contrib/funsor/infer/traceenum_elbo.py index b77229381a..15f2c51e7f 100644 --- a/pyro/contrib/funsor/infer/traceenum_elbo.py +++ b/pyro/contrib/funsor/infer/traceenum_elbo.py @@ -11,7 +11,18 @@ from pyro.distributions.util import copy_docs_from from pyro.infer import TraceEnum_ELBO as _OrigTraceEnum_ELBO -from .trace_elbo import Trace_ELBO, apply_optimizer, terms_from_trace +from .trace_elbo import Trace_ELBO, terms_from_trace + + +# Work around a bug in unfold_contraction_generic_tuple interacting with +# Approximate introduced in https://github.com/pyro-ppl/funsor/pull/488 . +# Once fixed, this can be replaced by funsor.optimizer.apply_optimizer(). 
+def apply_optimizer(x): + with funsor.interpretations.normalize: + expr = funsor.interpreter.reinterpret(x) + + with funsor.optimizer.optimize_base: + return funsor.interpreter.reinterpret(expr) @copy_docs_from(_OrigTraceEnum_ELBO) diff --git a/tests/contrib/funsor/test_gradient.py b/tests/contrib/funsor/test_gradient.py index bd341ce3f0..f10333ac67 100644 --- a/tests/contrib/funsor/test_gradient.py +++ b/tests/contrib/funsor/test_gradient.py @@ -119,7 +119,7 @@ def test_gradient(model, guide, data): pyro.clear_param_store() elbo = infer.Trace_ELBO( max_plate_nesting=1, # set this to ensure rng agrees across runs - num_particles=100000, + num_particles=50000, vectorize_particles=True, strict_enumeration_warning=False, ) @@ -133,7 +133,7 @@ def test_gradient(model, guide, data): logger.info("expected {} = {}".format(name, expected_grads[name])) logger.info("actual {} = {}".format(name, actual_grads[name])) - assert_equal(actual_grads, expected_grads, prec=0.03) + assert_equal(actual_grads, expected_grads, prec=0.02) @pytest.mark.parametrize( From ee5a5ad29fcf81aa3db9b357a57fe3dc3aec5136 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Tue, 12 Oct 2021 12:04:56 -0400 Subject: [PATCH 27/53] pin funsor to the latest commit --- docs/requirements.txt | 2 +- setup.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/docs/requirements.txt b/docs/requirements.txt index 9d1a7b8a26..50c69f0a7d 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -7,4 +7,4 @@ observations>=0.1.4 opt_einsum>=2.3.2 pyro-api>=0.1.1 tqdm>=4.36 -funsor[torch] @ git+git://github.com/pyro-ppl/funsor.git@constant-materialize +funsor[torch] @ git+git://github.com/pyro-ppl/funsor.git@18a107ccce40de9e3e98db87d116789e0909ce7b diff --git a/setup.py b/setup.py index 26e61f3cac..e1c5b9dee4 100644 --- a/setup.py +++ b/setup.py @@ -135,8 +135,7 @@ "horovod": ["horovod[pytorch]>=0.19"], "funsor": [ # This must be a released version when Pyro is released. 
- # "funsor[torch] @ git+git://github.com/pyro-ppl/funsor.git@b25eecdaaf134c536f280131bddcf47a475a3c37", - "funsor[torch] @ git+git://github.com/pyro-ppl/funsor.git@constant-materialize", + "funsor[torch] @ git+git://github.com/pyro-ppl/funsor.git@18a107ccce40de9e3e98db87d116789e0909ce7b", # "funsor[torch]==0.4.1", ], }, From e4c6760acb695e30ca478d23046366ba41443200 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Tue, 12 Oct 2021 14:29:39 -0400 Subject: [PATCH 28/53] optimize logzq --- pyro/contrib/funsor/infer/trace_elbo.py | 44 ++++++++++++++----------- 1 file changed, 25 insertions(+), 19 deletions(-) diff --git a/pyro/contrib/funsor/infer/trace_elbo.py b/pyro/contrib/funsor/infer/trace_elbo.py index 27057a6187..aabf61d852 100644 --- a/pyro/contrib/funsor/infer/trace_elbo.py +++ b/pyro/contrib/funsor/infer/trace_elbo.py @@ -124,29 +124,35 @@ def differentiable_loss(self, model, guide, *args, **kwargs): ) ] - # accumulate costs from model (logp) and guide (-logq) - costs = contracted_costs + uncontracted_factors # model costs: logp - costs += [-f for f in guide_terms["log_factors"]] # guide costs: -logq - - # compute log_measures corresponding to each cost term - # the goal is to achieve fine-grained Rao-Blackwellization - targets = dict() - for cost in costs: - if cost.input_vars not in targets: - targets[cost.input_vars] = Constant( - cost.inputs, funsor.Tensor(torch.tensor(0.0)) + # accumulate costs from model (logp) and guide (-logq) + costs = contracted_costs + uncontracted_factors # model costs: logp + costs += [-f for f in guide_terms["log_factors"]] # guide costs: -logq + + # compute log_measures corresponding to each cost term + # the goal is to achieve fine-grained Rao-Blackwellization + targets = dict() + for cost in costs: + if cost.input_vars not in targets: + targets[cost.input_vars] = Constant( + cost.inputs, funsor.Tensor(torch.tensor(0.0)) + ) + + with AdjointTape() as tape: + logzq = funsor.sum_product.sum_product( + funsor.ops.logaddexp, + funsor.ops.add, + guide_terms["log_measures"] + list(targets.values()), + plates=plate_vars, + eliminate=(plate_vars | guide_terms["measure_vars"]), ) - with AdjointTape() as tape: - logzq = funsor.sum_product.sum_product( - funsor.ops.logaddexp, - funsor.ops.add, - guide_terms["log_measures"] + list(targets.values()), - plates=plate_vars, - eliminate=(plate_vars | guide_terms["measure_vars"]), - ) + + breakpoint() + with funsor.terms.eager: + logzq = funsor.optimizer.apply_optimizer(logzq) log_measures = tape.adjoint( funsor.ops.logaddexp, funsor.ops.add, logzq, tuple(targets.values()) ) + with funsor.terms.lazy: # finally, integrate out guide variables in the elbo and all plates elbo = to_funsor(0, output=funsor.Real) From aba300a885204b69a320890804be250eee2431f1 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Wed, 13 Oct 2021 07:54:30 -0400 Subject: [PATCH 29/53] optimize logzq --- pyro/contrib/funsor/infer/trace_elbo.py | 36 ++++++++++++------------- setup.py | 2 +- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/pyro/contrib/funsor/infer/trace_elbo.py b/pyro/contrib/funsor/infer/trace_elbo.py index aabf61d852..aa069f992a 100644 --- a/pyro/contrib/funsor/infer/trace_elbo.py +++ b/pyro/contrib/funsor/infer/trace_elbo.py @@ -4,8 +4,7 @@ import contextlib import funsor -import torch -from funsor.adjoint import AdjointTape +from funsor.adjoint import adjoint from funsor.constant import Constant from pyro.contrib.funsor import to_data, to_funsor @@ -134,31 +133,32 @@ def differentiable_loss(self, model, 
guide, *args, **kwargs): for cost in costs: if cost.input_vars not in targets: targets[cost.input_vars] = Constant( - cost.inputs, funsor.Tensor(torch.tensor(0.0)) + cost.inputs, + funsor.Tensor( + funsor.ops.new_zeros( + funsor.tensor.get_default_prototype(), + (), + ) + ), ) - with AdjointTape() as tape: - logzq = funsor.sum_product.sum_product( - funsor.ops.logaddexp, - funsor.ops.add, - guide_terms["log_measures"] + list(targets.values()), - plates=plate_vars, - eliminate=(plate_vars | guide_terms["measure_vars"]), - ) - - breakpoint() - with funsor.terms.eager: + logzq = funsor.sum_product.sum_product( + funsor.ops.logaddexp, + funsor.ops.add, + guide_terms["log_measures"] + list(targets.values()), + plates=plate_vars, + eliminate=(plate_vars | guide_terms["measure_vars"]), + ) logzq = funsor.optimizer.apply_optimizer(logzq) - log_measures = tape.adjoint( - funsor.ops.logaddexp, funsor.ops.add, logzq, tuple(targets.values()) - ) + + marginals = adjoint(funsor.ops.logaddexp, funsor.ops.add, logzq) with funsor.terms.lazy: # finally, integrate out guide variables in the elbo and all plates elbo = to_funsor(0, output=funsor.Real) for cost in costs: target = targets[cost.input_vars] - log_measure = log_measures[target] + log_measure = marginals[target] measure_vars = (frozenset(cost.inputs) - plate_vars) - frozenset( {"num_particles_vectorized"} ) diff --git a/setup.py b/setup.py index e1c5b9dee4..e251d681f8 100644 --- a/setup.py +++ b/setup.py @@ -135,7 +135,7 @@ "horovod": ["horovod[pytorch]>=0.19"], "funsor": [ # This must be a released version when Pyro is released. - "funsor[torch] @ git+git://github.com/pyro-ppl/funsor.git@18a107ccce40de9e3e98db87d116789e0909ce7b", + "funsor[torch] @ git+git://github.com/pyro-ppl/funsor.git@fix-cnf", # "funsor[torch]==0.4.1", ], }, From d823153932de2c176fc845675941aeb486551955 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Wed, 13 Oct 2021 08:06:23 -0400 Subject: [PATCH 30/53] restore TraceEnum_ELBO --- pyro/contrib/funsor/infer/trace_elbo.py | 65 +------- pyro/contrib/funsor/infer/traceenum_elbo.py | 162 +++++++++++++++++++- pyro/infer/autoguide/guides.py | 15 +- 3 files changed, 166 insertions(+), 76 deletions(-) diff --git a/pyro/contrib/funsor/infer/trace_elbo.py b/pyro/contrib/funsor/infer/trace_elbo.py index aa069f992a..9c7be43d0a 100644 --- a/pyro/contrib/funsor/infer/trace_elbo.py +++ b/pyro/contrib/funsor/infer/trace_elbo.py @@ -13,70 +13,7 @@ from pyro.infer import Trace_ELBO as _OrigTrace_ELBO from .elbo import ELBO, Jit_ELBO - - -def terms_from_trace(tr): - """Helper function to extract elbo components from execution traces.""" - # data structure containing densities, measures, scales, and identification - # of free variables as either product (plate) variables or sum (measure) variables - terms = { - "log_factors": [], - "log_measures": [], - "scale": to_funsor(1.0), - "plate_vars": frozenset(), - "measure_vars": frozenset(), - "plate_to_step": dict(), - } - for name, node in tr.nodes.items(): - # add markov dimensions to the plate_to_step dictionary - if node["type"] == "markov_chain": - terms["plate_to_step"][node["name"]] = node["value"] - # ensure previous step variables are added to measure_vars - for step in node["value"]: - terms["measure_vars"] |= frozenset( - { - var - for var in step[1:-1] - if tr.nodes[var]["funsor"].get("log_measure", None) is not None - } - ) - if ( - node["type"] != "sample" - or type(node["fn"]).__name__ == "_Subsample" - or node["infer"].get("_do_not_score", False) - ): - continue - # grab plate 
dimensions from the cond_indep_stack - terms["plate_vars"] |= frozenset( - f.name for f in node["cond_indep_stack"] if f.vectorized - ) - # grab the log-measure, found only at sites that are not replayed or observed - if node["funsor"].get("log_measure", None) is not None: - terms["log_measures"].append(node["funsor"]["log_measure"]) - # sum (measure) variables: the fresh non-plate variables at a site - terms["measure_vars"] |= ( - frozenset(node["funsor"]["value"].inputs) | {name} - ) - terms["plate_vars"] - # grab the scale, assuming a common subsampling scale - if ( - node.get("replay_active", False) - and set(node["funsor"]["log_prob"].inputs) & terms["measure_vars"] - and float(to_data(node["funsor"]["scale"])) != 1.0 - ): - # model site that depends on enumerated variable: common scale - terms["scale"] = node["funsor"]["scale"] - else: # otherwise: default scale behavior - node["funsor"]["log_prob"] = ( - node["funsor"]["log_prob"] * node["funsor"]["scale"] - ) - # grab the log-density, found at all sites except those that are not replayed - if node["is_observed"] or not node.get("replay_skipped", False): - terms["log_factors"].append(node["funsor"]["log_prob"]) - # add plate dimensions to the plate_to_step dictionary - terms["plate_to_step"].update( - {plate: terms["plate_to_step"].get(plate, {}) for plate in terms["plate_vars"]} - ) - return terms +from .traceenum_elbo import terms_from_trace @copy_docs_from(_OrigTrace_ELBO) diff --git a/pyro/contrib/funsor/infer/traceenum_elbo.py b/pyro/contrib/funsor/infer/traceenum_elbo.py index 15f2c51e7f..7915f7e101 100644 --- a/pyro/contrib/funsor/infer/traceenum_elbo.py +++ b/pyro/contrib/funsor/infer/traceenum_elbo.py @@ -4,6 +4,7 @@ import contextlib import funsor +from funsor.adjoint import AdjointTape from pyro.contrib.funsor import to_data, to_funsor from pyro.contrib.funsor.handlers import enum, plate, replay, trace @@ -11,8 +12,6 @@ from pyro.distributions.util import copy_docs_from from pyro.infer import TraceEnum_ELBO as _OrigTraceEnum_ELBO -from .trace_elbo import Trace_ELBO, terms_from_trace - # Work around a bug in unfold_contraction_generic_tuple interacting with # Approximate introduced in https://github.com/pyro-ppl/funsor/pull/488 . 
@@ -25,6 +24,70 @@ def apply_optimizer(x): return funsor.interpreter.reinterpret(expr) +def terms_from_trace(tr): + """Helper function to extract elbo components from execution traces.""" + # data structure containing densities, measures, scales, and identification + # of free variables as either product (plate) variables or sum (measure) variables + terms = { + "log_factors": [], + "log_measures": [], + "scale": to_funsor(1.0), + "plate_vars": frozenset(), + "measure_vars": frozenset(), + "plate_to_step": dict(), + } + for name, node in tr.nodes.items(): + # add markov dimensions to the plate_to_step dictionary + if node["type"] == "markov_chain": + terms["plate_to_step"][node["name"]] = node["value"] + # ensure previous step variables are added to measure_vars + for step in node["value"]: + terms["measure_vars"] |= frozenset( + { + var + for var in step[1:-1] + if tr.nodes[var]["funsor"].get("log_measure", None) is not None + } + ) + if ( + node["type"] != "sample" + or type(node["fn"]).__name__ == "_Subsample" + or node["infer"].get("_do_not_score", False) + ): + continue + # grab plate dimensions from the cond_indep_stack + terms["plate_vars"] |= frozenset( + f.name for f in node["cond_indep_stack"] if f.vectorized + ) + # grab the log-measure, found only at sites that are not replayed or observed + if node["funsor"].get("log_measure", None) is not None: + terms["log_measures"].append(node["funsor"]["log_measure"]) + # sum (measure) variables: the fresh non-plate variables at a site + terms["measure_vars"] |= ( + frozenset(node["funsor"]["value"].inputs) | {name} + ) - terms["plate_vars"] + # grab the scale, assuming a common subsampling scale + if ( + node.get("replay_active", False) + and set(node["funsor"]["log_prob"].inputs) & terms["measure_vars"] + and float(to_data(node["funsor"]["scale"])) != 1.0 + ): + # model site that depends on enumerated variable: common scale + terms["scale"] = node["funsor"]["scale"] + else: # otherwise: default scale behavior + node["funsor"]["log_prob"] = ( + node["funsor"]["log_prob"] * node["funsor"]["scale"] + ) + # grab the log-density, found at all sites except those that are not replayed + if node["is_observed"] or not node.get("replay_skipped", False): + terms["log_factors"].append(node["funsor"]["log_prob"]) + # add plate dimensions to the plate_to_step dictionary + terms["plate_to_step"].update( + {plate: terms["plate_to_step"].get(plate, {}) for plate in terms["plate_vars"]} + ) + return terms + + @copy_docs_from(_OrigTraceEnum_ELBO) class TraceMarkovEnum_ELBO(ELBO): def differentiable_loss(self, model, guide, *args, **kwargs): @@ -103,8 +166,99 @@ def differentiable_loss(self, model, guide, *args, **kwargs): return -to_data(apply_optimizer(elbo)) -class TraceEnum_ELBO(Trace_ELBO): - pass +@copy_docs_from(_OrigTraceEnum_ELBO) +class TraceEnum_ELBO(ELBO): + def differentiable_loss(self, model, guide, *args, **kwargs): + + # get batched, enumerated, to_funsor-ed traces from the guide and model + with plate( + size=self.num_particles + ) if self.num_particles > 1 else contextlib.ExitStack(), enum( + first_available_dim=(-self.max_plate_nesting - 1) + if self.max_plate_nesting + else None + ): + guide_tr = trace(guide).get_trace(*args, **kwargs) + model_tr = trace(replay(model, trace=guide_tr)).get_trace(*args, **kwargs) + + # extract from traces all metadata that we will need to compute the elbo + guide_terms = terms_from_trace(guide_tr) + model_terms = terms_from_trace(model_tr) + + # build up a lazy expression for the elbo + with 
funsor.terms.lazy: + # identify and contract out auxiliary variables in the model with partial_sum_product + contracted_factors, uncontracted_factors = [], [] + for f in model_terms["log_factors"]: + if model_terms["measure_vars"].intersection(f.inputs): + contracted_factors.append(f) + else: + uncontracted_factors.append(f) + # incorporate the effects of subsampling and handlers.scale through a common scale factor + contracted_costs = [ + model_terms["scale"] * f + for f in funsor.sum_product.partial_sum_product( + funsor.ops.logaddexp, + funsor.ops.add, + model_terms["log_measures"] + contracted_factors, + plates=model_terms["plate_vars"], + eliminate=model_terms["measure_vars"], + ) + ] + + # accumulate costs from model (logp) and guide (-logq) + costs = contracted_costs + uncontracted_factors # model costs: logp + costs += [-f for f in guide_terms["log_factors"]] # guide costs: -logq + + # compute expected cost + # Cf. pyro.infer.util.Dice.compute_expectation() + # https://github.com/pyro-ppl/pyro/blob/0.3.0/pyro/infer/util.py#L212 + # TODO Replace this with funsor.Expectation + plate_vars = guide_terms["plate_vars"] | model_terms["plate_vars"] + # compute the marginal logq in the guide corresponding to each cost term + targets = dict() + for cost in costs: + input_vars = frozenset(cost.inputs) + if input_vars not in targets: + targets[input_vars] = funsor.Tensor( + funsor.ops.new_zeros( + funsor.tensor.get_default_prototype(), + tuple(v.size for v in cost.inputs.values()), + ), + cost.inputs, + cost.dtype, + ) + with AdjointTape() as tape: + logzq = funsor.sum_product.sum_product( + funsor.ops.logaddexp, + funsor.ops.add, + guide_terms["log_measures"] + list(targets.values()), + plates=plate_vars, + eliminate=(plate_vars | guide_terms["measure_vars"]), + ) + marginals = tape.adjoint( + funsor.ops.logaddexp, funsor.ops.add, logzq, tuple(targets.values()) + ) + # finally, integrate out guide variables in the elbo and all plates + elbo = to_funsor(0, output=funsor.Real) + for cost in costs: + target = targets[frozenset(cost.inputs)] + logzq_local = marginals[target].reduce( + funsor.ops.logaddexp, frozenset(cost.inputs) - plate_vars + ) + log_prob = marginals[target] - logzq_local + elbo_term = funsor.Integrate( + log_prob, + cost, + guide_terms["measure_vars"] & frozenset(log_prob.inputs), + ) + elbo += elbo_term.reduce( + funsor.ops.add, plate_vars & frozenset(cost.inputs) + ) + + # evaluate the elbo, using memoize to share tensor computation where possible + with funsor.interpretations.memoize(): + return -to_data(apply_optimizer(elbo)) class JitTraceEnum_ELBO(Jit_ELBO, TraceEnum_ELBO): diff --git a/pyro/infer/autoguide/guides.py b/pyro/infer/autoguide/guides.py index ee177164f5..27a7fed53a 100644 --- a/pyro/infer/autoguide/guides.py +++ b/pyro/infer/autoguide/guides.py @@ -22,7 +22,6 @@ def model(): from contextlib import ExitStack import torch -from pyroapi import handlers from torch import nn from torch.distributions import biject_to @@ -152,8 +151,8 @@ def _create_plates(self, *args, **kwargs): def _setup_prototype(self, *args, **kwargs): # run the model so we can inspect its structure - model = handlers.block(self.model, prototype_hide_fn) - self.prototype_trace = handlers.block(handlers.trace(model).get_trace)( + model = poutine.block(self.model, prototype_hide_fn) + self.prototype_trace = poutine.block(poutine.trace(model).get_trace)( *args, **kwargs ) if self.master is not None: @@ -1161,9 +1160,9 @@ def laplace_approximation(self, *args, **kwargs): Returns a 
:class:`AutoMultivariateNormal` instance whose posterior's `loc` and `scale_tril` are given by Laplace approximation. """ - guide_trace = handlers.trace(self).get_trace(*args, **kwargs) - model_trace = handlers.trace( - handlers.replay(self.model, trace=guide_trace) + guide_trace = poutine.trace(self).get_trace(*args, **kwargs) + model_trace = poutine.trace( + poutine.replay(self.model, trace=guide_trace) ).get_trace(*args, **kwargs) loss = guide_trace.log_prob_sum() - model_trace.log_prob_sum() @@ -1188,8 +1187,8 @@ class AutoDiscreteParallel(AutoGuide): def _setup_prototype(self, *args, **kwargs): # run the model so we can inspect its structure - model = handlers.block(config_enumerate(self.model), prototype_hide_fn) - self.prototype_trace = handlers.block(handlers.trace(model).get_trace)( + model = poutine.block(config_enumerate(self.model), prototype_hide_fn) + self.prototype_trace = poutine.block(poutine.trace(model).get_trace)( *args, **kwargs ) if self.master is not None: From c06e9e4df65cb87078d6ade32132c49e5fc5d3dc Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Wed, 13 Oct 2021 08:09:07 -0400 Subject: [PATCH 31/53] revert hmm changes --- examples/contrib/funsor/hmm.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/examples/contrib/funsor/hmm.py b/examples/contrib/funsor/hmm.py index 1777f05af7..aac6fad61f 100644 --- a/examples/contrib/funsor/hmm.py +++ b/examples/contrib/funsor/hmm.py @@ -700,11 +700,7 @@ def main(args): # learns point estimates of all of our conditional probability tables, # named probs_*. guide = AutoDelta( - handlers.block( - model, - expose_fn=lambda msg: msg["name"] is not None - and msg["name"].startswith("probs_"), - ) + handlers.block(model, expose_fn=lambda msg: msg["name"].startswith("probs_")) ) # To help debug our tensor shapes, let's print the shape of each site's From eee297d620be14b37d455d66f56eb3dc8a7b953b Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Wed, 13 Oct 2021 13:22:03 -0400 Subject: [PATCH 32/53] _tensor_to_categorical helper function --- .../contrib/funsor/handlers/enum_messenger.py | 32 +++++++++++-------- setup.py | 2 +- 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/pyro/contrib/funsor/handlers/enum_messenger.py b/pyro/contrib/funsor/handlers/enum_messenger.py index 43d1d07cc1..66d827fb19 100644 --- a/pyro/contrib/funsor/handlers/enum_messenger.py +++ b/pyro/contrib/funsor/handlers/enum_messenger.py @@ -180,7 +180,23 @@ def enumerate_site(dist, msg): raise ValueError("{} not valid enum strategy".format(msg)) +def _tensor_to_categorical(tensor, event_name): + from importlib import import_module + + backend = funsor.util.get_backend() + dist = import_module(funsor.distribution.BACKEND_TO_DISTRIBUTIONS_BACKEND[backend]) + domain = tensor.inputs[event_name] + event_var = funsor.Variable(event_name, domain) + logits = funsor.Lambda(event_var, tensor) + categorical = dist.CategoricalLogits(logits=logits, value=event_name) + return categorical + + class ProvenanceMessenger(ReentrantMessenger): + """ + Adds provenance dim for all sample sites that are not enumerated. 
+ """ + def _pyro_sample(self, msg): if ( msg["done"] @@ -196,20 +212,10 @@ def _pyro_sample(self, msg): unsampled_log_measure = to_funsor(msg["fn"], output=funsor.Real)( value=msg["name"] ) - # make Categorical from Tensor if isinstance(unsampled_log_measure, funsor.Tensor): - from importlib import import_module - - backend = funsor.util.get_backend() - dist = import_module( - funsor.distribution.BACKEND_TO_DISTRIBUTIONS_BACKEND[backend] - ) - domain = unsampled_log_measure.inputs[msg["name"]] - logits = funsor.Lambda( - funsor.Variable(msg["name"], domain), unsampled_log_measure - ) - unsampled_log_measure = dist.CategoricalLogits( - logits=logits, value=msg["name"] + # convert Tensor to Categorical + unsampled_log_measure = _tensor_to_categorical( + unsampled_log_measure, msg["name"] ) msg["funsor"]["log_measure"] = _enum_strategy_default( unsampled_log_measure, msg diff --git a/setup.py b/setup.py index e251d681f8..93715ed32c 100644 --- a/setup.py +++ b/setup.py @@ -135,7 +135,7 @@ "horovod": ["horovod[pytorch]>=0.19"], "funsor": [ # This must be a released version when Pyro is released. - "funsor[torch] @ git+git://github.com/pyro-ppl/funsor.git@fix-cnf", + "funsor[torch] @ git+git://github.com/pyro-ppl/funsor.git@0cf9ac20ef8a9f03d86adceb2a42366f4d028093", # "funsor[torch]==0.4.1", ], }, From d748efa3193e59808cbf913f1dacb830bc28873b Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Wed, 13 Oct 2021 16:06:19 -0400 Subject: [PATCH 33/53] lazy to_funsor --- .../contrib/funsor/handlers/enum_messenger.py | 24 ++++--------------- 1 file changed, 5 insertions(+), 19 deletions(-) diff --git a/pyro/contrib/funsor/handlers/enum_messenger.py b/pyro/contrib/funsor/handlers/enum_messenger.py index 66d827fb19..a692cb3895 100644 --- a/pyro/contrib/funsor/handlers/enum_messenger.py +++ b/pyro/contrib/funsor/handlers/enum_messenger.py @@ -180,18 +180,6 @@ def enumerate_site(dist, msg): raise ValueError("{} not valid enum strategy".format(msg)) -def _tensor_to_categorical(tensor, event_name): - from importlib import import_module - - backend = funsor.util.get_backend() - dist = import_module(funsor.distribution.BACKEND_TO_DISTRIBUTIONS_BACKEND[backend]) - domain = tensor.inputs[event_name] - event_var = funsor.Variable(event_name, domain) - logits = funsor.Lambda(event_var, tensor) - categorical = dist.CategoricalLogits(logits=logits, value=event_name) - return categorical - - class ProvenanceMessenger(ReentrantMessenger): """ Adds provenance dim for all sample sites that are not enumerated. 
@@ -209,14 +197,11 @@ def _pyro_sample(self, msg): if "funsor" not in msg: msg["funsor"] = {} - unsampled_log_measure = to_funsor(msg["fn"], output=funsor.Real)( - value=msg["name"] - ) - if isinstance(unsampled_log_measure, funsor.Tensor): - # convert Tensor to Categorical - unsampled_log_measure = _tensor_to_categorical( - unsampled_log_measure, msg["name"] + with funsor.terms.lazy: + unsampled_log_measure = to_funsor(msg["fn"], output=funsor.Real)( + value=msg["name"] ) + # TODO delegate to enumerate_site msg["funsor"]["log_measure"] = _enum_strategy_default( unsampled_log_measure, msg ) @@ -225,6 +210,7 @@ def _pyro_sample(self, msg): msg["name"], expand=msg["infer"].get("expand", False), ) + # TODO delegate to _get_support_value msg["funsor"]["value"] = funsor.Constant( OrderedDict(((msg["name"], support_value.output),)), support_value, From a1970d60bea8877236d4bf493e1a4a023332f250 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Wed, 13 Oct 2021 16:08:39 -0400 Subject: [PATCH 34/53] reduce over particle_var --- pyro/contrib/funsor/infer/trace_elbo.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/pyro/contrib/funsor/infer/trace_elbo.py b/pyro/contrib/funsor/infer/trace_elbo.py index 9c7be43d0a..b79948d973 100644 --- a/pyro/contrib/funsor/infer/trace_elbo.py +++ b/pyro/contrib/funsor/infer/trace_elbo.py @@ -35,9 +35,14 @@ def differentiable_loss(self, model, guide, *args, **kwargs): model_terms = terms_from_trace(model_tr) guide_terms = terms_from_trace(guide_tr) + particle_var = ( + frozenset({"num_particles_vectorized"}) + if self.num_particles > 1 + else frozenset() + ) plate_vars = ( guide_terms["plate_vars"] | model_terms["plate_vars"] - ) - frozenset({"num_particles_vectorized"}) + ) - particle_var model_measure_vars = model_terms["measure_vars"] - guide_terms["measure_vars"] with funsor.terms.lazy: @@ -95,10 +100,9 @@ def differentiable_loss(self, model, guide, *args, **kwargs): elbo = to_funsor(0, output=funsor.Real) for cost in costs: target = targets[cost.input_vars] + # FIXME account for normalization factor for unnormalized logzq log_measure = marginals[target] - measure_vars = (frozenset(cost.inputs) - plate_vars) - frozenset( - {"num_particles_vectorized"} - ) + measure_vars = (frozenset(cost.inputs) - plate_vars) - particle_var elbo_term = funsor.Integrate( log_measure, cost, @@ -108,7 +112,7 @@ def differentiable_loss(self, model, guide, *args, **kwargs): funsor.ops.add, plate_vars & frozenset(cost.inputs) ) # average over Monte-Carlo particles - elbo = elbo.reduce(funsor.ops.mean) + elbo = elbo.reduce(funsor.ops.mean, particle_var) return -to_data(funsor.optimizer.apply_optimizer(elbo)) From 4c1ee9e8d4ad95f746bdd576013cef1744452143 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Wed, 13 Oct 2021 17:17:18 -0400 Subject: [PATCH 35/53] address comment in tests --- tests/contrib/funsor/test_gradient.py | 550 ++++++++++---------------- 1 file changed, 210 insertions(+), 340 deletions(-) diff --git a/tests/contrib/funsor/test_gradient.py b/tests/contrib/funsor/test_gradient.py index f10333ac67..e01241eac0 100644 --- a/tests/contrib/funsor/test_gradient.py +++ b/tests/contrib/funsor/test_gradient.py @@ -136,41 +136,53 @@ def test_gradient(model, guide, data): assert_equal(actual_grads, expected_grads, prec=0.02) -@pytest.mark.parametrize( - "model,guide,data", - [ - (model_0, guide_0, torch.tensor([-0.5, 2.0])), - (model_1, guide_1, torch.tensor([-0.5, 2.0])), - (model_2, guide_2, torch.tensor([0.0, 1.0])), - ], -) -def 
test_guide_enum_gradient(model, guide, data): +@pyro_backend("contrib.funsor") +def test_particle_gradient_0(): + pyro.clear_param_store() + data = torch.tensor([-0.5, 2.0]) + + def model(): + with pyro.plate("data", len(data)): + z = pyro.sample("z", dist.Poisson(3)) + pyro.sample("x", dist.Normal(z, 1), obs=data) + + def guide(): + # set this to ensure rng agrees across runs + # this should be ok since we are comparing a single particle gradients + pyro.set_rng_seed(0) + with pyro.plate("data", len(data)): + rate = pyro.param("rate", lambda: torch.tensor([3.5, 1.5])) + pyro.sample("z", dist.Poisson(rate)) + + elbo = infer.Trace_ELBO( + max_plate_nesting=1, # set this to ensure rng agrees across runs + num_particles=1, + strict_enumeration_warning=False, + ) - # Expected grads based on exact integration - with pyro_backend("pyro"): - pyro.clear_param_store() - elbo = infer.TraceEnum_ELBO( - max_plate_nesting=1, # set this to ensure rng agrees across runs - strict_enumeration_warning=False, - ) - elbo.loss_and_grads(model, infer.config_enumerate(guide), data) - params = dict(pyro.get_param_store().named_parameters()) - expected_grads = { - name: param.grad.detach().cpu() for name, param in params.items() - } + # compute loss and get trace + guide_tr = handlers.trace(elbo.loss_and_grads).get_trace(model, guide) + model_tr = handlers.trace(handlers.replay(model, guide_tr)).get_trace() - # Actual grads - with pyro_backend("contrib.funsor"): - pyro.clear_param_store() - elbo = infer.Trace_ELBO( - max_plate_nesting=1, # set this to ensure rng agrees across runs - strict_enumeration_warning=False, - ) - elbo.loss_and_grads(model, infer.config_enumerate(guide), data) - params = dict(pyro.get_param_store().named_parameters()) - actual_grads = { - name: param.grad.detach().cpu() for name, param in params.items() - } + # Trace_ELBO gradients + params = dict(pyro.get_param_store().named_parameters()) + actual_grads = {name: param.grad.detach().cpu() for name, param in params.items()} + + # Hand derived gradients + guide_tr.compute_log_prob() + model_tr.compute_log_prob() + + z = guide_tr.nodes["z"]["value"]._t.detach() + rate = pyro.param("rate").detach() + logpx = model_tr.nodes["x"]["log_prob"]._t.detach() + logpz = model_tr.nodes["z"]["log_prob"]._t.detach() + logqz = guide_tr.nodes["z"]["log_prob"]._t.detach() + + loss_i = logpx + logpz - logqz + dlogq_drate = z / rate - 1 + expected_grads = { + "rate": -(dlogq_drate * loss_i - dlogq_drate), + } for name in sorted(params): logger.info("expected {} = {}".format(name, expected_grads[name])) @@ -179,325 +191,183 @@ def test_guide_enum_gradient(model, guide, data): assert_equal(actual_grads, expected_grads, prec=1e-4) -@pytest.mark.parametrize( - "Elbo,backend", - [ - ("TraceEnum_ELBO", "pyro"), - ("Trace_ELBO", "contrib.funsor"), - ], -) -def test_particle_gradient_0(Elbo, backend): - with pyro_backend(backend): - pyro.clear_param_store() - data = torch.tensor([-0.5, 2.0]) +@pyro_backend("contrib.funsor") +def test_particle_gradient_1(): + pyro.clear_param_store() + data = torch.tensor([-0.5, 2.0]) - def model(): - with pyro.plate("data", len(data)): - z = pyro.sample("z", dist.Poisson(3)) - pyro.sample("x", dist.Normal(z, 1), obs=data) + def model(): + a = pyro.sample("a", dist.Bernoulli(0.3)) + with pyro.plate("data", len(data)): + rate = torch.tensor([2.0, 3.0]) + b = pyro.sample("b", dist.Poisson(rate[a.long()])) + pyro.sample("c", dist.Normal(b, 1), obs=data) - def guide(): - with pyro.plate("data", len(data)): - rate = pyro.param("rate", 
lambda: torch.tensor([3.5, 1.5])) - pyro.sample("z", dist.Poisson(rate)) - - elbo = getattr(infer, Elbo)( - max_plate_nesting=1, # set this to ensure rng agrees across runs - num_particles=1, - strict_enumeration_warning=False, - ) - - # Elbo gradient estimator + def guide(): + # set this to ensure rng agrees across runs + # this should be ok since we are comparing a single particle gradients pyro.set_rng_seed(0) - elbo.loss_and_grads(model, guide) - params = dict(pyro.get_param_store().named_parameters()) - actual_grads = { - name: param.grad.detach().cpu() for name, param in params.items() - } - - # capture sample values and log_probs - pyro.set_rng_seed(0) - guide_tr = handlers.trace(guide).get_trace() - model_tr = handlers.trace(handlers.replay(model, guide_tr)).get_trace() - guide_tr.compute_log_prob() - model_tr.compute_log_prob() - z = guide_tr.nodes["z"]["value"].data - rate = pyro.param("rate").data - - loss_i = ( - model_tr.nodes["x"]["log_prob"].data - + model_tr.nodes["z"]["log_prob"].data - - guide_tr.nodes["z"]["log_prob"].data + prob = pyro.param( + "prob", + lambda: torch.tensor(0.5), ) - dlogq_drate = z / rate - 1 - expected_grads = { - "rate": -(dlogq_drate * loss_i - dlogq_drate), - } - - for name in sorted(params): - logger.info("expected {} = {}".format(name, expected_grads[name])) - logger.info("actual {} = {}".format(name, actual_grads[name])) + a = pyro.sample("a", dist.Bernoulli(prob)) + with pyro.plate("data", len(data)): + rate = pyro.param("rate", lambda: torch.tensor([[3.5, 1.5], [0.5, 2.5]])) + pyro.sample("b", dist.Poisson(rate[a.long()])) + + elbo = infer.Trace_ELBO( + max_plate_nesting=1, # set this to ensure rng agrees across runs + num_particles=1, + strict_enumeration_warning=False, + ) - assert_equal(actual_grads, expected_grads, prec=1e-4) + # compute loss and get trace + guide_tr = handlers.trace(elbo.loss_and_grads).get_trace(model, guide) + model_tr = handlers.trace(handlers.replay(model, guide_tr)).get_trace() + + # Trace_ELBO gradients + params = dict(pyro.get_param_store().named_parameters()) + actual_grads = {name: param.grad.detach().cpu() for name, param in params.items()} + + # Hand derived gradients + guide_tr.compute_log_prob() + model_tr.compute_log_prob() + + a = guide_tr.nodes["a"]["value"]._t.detach() + b = guide_tr.nodes["b"]["value"]._t.detach() + prob = pyro.param("prob").detach() + rate = pyro.param("rate").detach() + logpa = model_tr.nodes["a"]["log_prob"]._t.detach() + logpb = model_tr.nodes["b"]["log_prob"]._t.detach() + logpc = model_tr.nodes["c"]["log_prob"]._t.detach() + logqa = guide_tr.nodes["a"]["log_prob"]._t.detach() + logqb = guide_tr.nodes["b"]["log_prob"]._t.detach() + + dlogqa_dprob = (a - prob) / (prob * (1 - prob)) + dlogqb_drate = b / rate[a.long()] - 1 + + loss_a = logpa - logqa + loss_bc = logpb + logpc - logqb + expected_grads = { + "prob": -(dlogqa_dprob * (loss_a + loss_bc.sum()) - dlogqa_dprob), + "rate": -(dlogqb_drate * (loss_bc) - dlogqb_drate), + } + actual_grads["rate"] = actual_grads["rate"][a.long()] + for name in sorted(params): + logger.info("expected {} = {}".format(name, expected_grads[name])) + logger.info("actual {} = {}".format(name, actual_grads[name])) -@pytest.mark.parametrize( - "Elbo,backend", - [ - ("TraceEnum_ELBO", "pyro"), - ("Trace_ELBO", "contrib.funsor"), - ], -) -def test_particle_gradient_1(Elbo, backend): - with pyro_backend(backend): - pyro.clear_param_store() - data = torch.tensor([-0.5, 2.0]) - - def model(): - a = pyro.sample("a", dist.Bernoulli(0.3)) - with pyro.plate("data", 
len(data)): - rate = torch.tensor([2.0, 3.0]) - b = pyro.sample("b", dist.Poisson(rate[a.long()])) - pyro.sample("c", dist.Normal(b, 1), obs=data) - - def guide(): - prob = pyro.param( - "prob", - lambda: torch.tensor(0.5), - ) - a = pyro.sample("a", dist.Bernoulli(prob)) - with pyro.plate("data", len(data)): - rate = pyro.param( - "rate", lambda: torch.tensor([[3.5, 1.5], [0.5, 2.5]]) - ) - pyro.sample("b", dist.Poisson(rate[a.long()])) - - elbo = getattr(infer, Elbo)( - max_plate_nesting=1, # set this to ensure rng agrees across runs - num_particles=1, - strict_enumeration_warning=False, - ) + assert_equal(actual_grads, expected_grads, prec=1e-4) - # Elbo gradient estimator - pyro.set_rng_seed(0) - elbo.loss_and_grads(model, guide) - params = dict(pyro.get_param_store().named_parameters()) - actual_grads = { - name: param.grad.detach().cpu() for name, param in params.items() - } - # capture sample values and log_probs +@pyro_backend("contrib.funsor") +def test_particle_gradient_2(): + pyro.clear_param_store() + data = torch.tensor([0.0, 1.0]) + + def model(): + prob_b = torch.tensor([0.3, 0.4]) + prob_c = torch.tensor([0.5, 0.6]) + prob_d = torch.tensor([0.2, 0.3]) + prob_e = torch.tensor([0.5, 0.1]) + a = pyro.sample("a", dist.Bernoulli(0.3)) + with pyro.plate("data", len(data)): + b = pyro.sample("b", dist.Bernoulli(prob_b[a.long()])) + c = pyro.sample("c", dist.Bernoulli(prob_c[b.long()])) + pyro.sample("d", dist.Bernoulli(prob_d[b.long()])) + pyro.sample("e", dist.Bernoulli(prob_e[c.long()]), obs=data) + + def guide(): + # set this to ensure rng agrees across runs + # this should be ok since we are comparing a single particle gradients pyro.set_rng_seed(0) - guide_tr = handlers.trace(guide).get_trace() - model_tr = handlers.trace(handlers.replay(model, guide_tr)).get_trace() - guide_tr.compute_log_prob() - model_tr.compute_log_prob() - a = guide_tr.nodes["a"]["value"].data - b = guide_tr.nodes["b"]["value"].data - prob = pyro.param("prob").data - rate = pyro.param("rate").data - - dlogqa_dprob = (a - prob) / (prob * (1 - prob)) - dlogqb_drate = b / rate[a.long()] - 1 - - loss_a = ( - model_tr.nodes["a"]["log_prob"].data - guide_tr.nodes["a"]["log_prob"].data - ) - loss_bc = ( - model_tr.nodes["b"]["log_prob"].data - + model_tr.nodes["c"]["log_prob"].data - - guide_tr.nodes["b"]["log_prob"].data - ) - expected_grads = { - "prob": -(dlogqa_dprob * (loss_a + loss_bc.sum()) - dlogqa_dprob), - "rate": -(dlogqb_drate * (loss_bc) - dlogqb_drate), - } - actual_grads["rate"] = actual_grads["rate"][a.long()] - - for name in sorted(params): - logger.info("expected {} = {}".format(name, expected_grads[name])) - logger.info("actual {} = {}".format(name, actual_grads[name])) - - assert_equal(actual_grads, expected_grads, prec=1e-4) + prob_a = pyro.param("prob_a", lambda: torch.tensor(0.5)) + prob_b = pyro.param("prob_b", lambda: torch.tensor([0.4, 0.3])) + prob_c = pyro.param("prob_c", lambda: torch.tensor([[0.3, 0.8], [0.2, 0.5]])) + prob_d = pyro.param("prob_d", lambda: torch.tensor([[0.2, 0.9], [0.1, 0.4]])) + a = pyro.sample("a", dist.Bernoulli(prob_a)) + with pyro.plate("data", len(data)) as idx: + b = pyro.sample("b", dist.Bernoulli(prob_b[a.long()])) + pyro.sample("c", dist.Bernoulli(Vindex(prob_c)[b.long(), idx])) + pyro.sample("d", dist.Bernoulli(Vindex(prob_d)[b.long(), idx])) + + elbo = infer.Trace_ELBO( + max_plate_nesting=1, # set this to ensure rng agrees across runs + num_particles=1, + strict_enumeration_warning=False, + ) + # compute loss and get trace + guide_tr = 
handlers.trace(elbo.loss_and_grads).get_trace(model, guide) + model_tr = handlers.trace(handlers.replay(model, guide_tr)).get_trace() + + # Trace_ELBO gradients + params = dict(pyro.get_param_store().named_parameters()) + actual_grads = {name: param.grad.detach().cpu() for name, param in params.items()} + + # Hand derived gradients + guide_tr.compute_log_prob() + model_tr.compute_log_prob() + + a = guide_tr.nodes["a"]["value"]._t.detach() + b = guide_tr.nodes["b"]["value"]._t.detach() + c = guide_tr.nodes["c"]["value"]._t.detach() + + proba = pyro.param("prob_a").detach() + probb = pyro.param("prob_b").detach() + probc = pyro.param("prob_c").detach() + probd = pyro.param("prob_d").detach() + + logpa = model_tr.nodes["a"]["log_prob"]._t.detach() + logpba = model_tr.nodes["b"]["log_prob"]._t.detach() + logpcba = model_tr.nodes["c"]["log_prob"]._t.detach() + logpdba = model_tr.nodes["d"]["log_prob"]._t.detach() + logpecba = model_tr.nodes["e"]["log_prob"]._t.detach() + + logqa = guide_tr.nodes["a"]["log_prob"]._t.detach() + logqba = guide_tr.nodes["b"]["log_prob"]._t.detach() + logqcba = guide_tr.nodes["c"]["log_prob"]._t.detach() + logqdba = guide_tr.nodes["d"]["log_prob"]._t.detach() + + idx = torch.arange(2) + dlogqa_dproba = (a - proba) / (proba * (1 - proba)) + dlogqb_dprobb = (b - probb[a.long()]) / (probb[a.long()] * (1 - probb[a.long()])) + dlogqc_dprobc = (c - Vindex(probc)[b.long(), idx]) / ( + Vindex(probc)[b.long(), idx] * (1 - Vindex(probc)[b.long(), idx]) + ) + dlogqd_dprobd = (c - probd[b.long(), idx]) / ( + Vindex(probd)[b.long(), idx] * (1 - Vindex(probd)[b.long(), idx]) + ) -@pytest.mark.parametrize( - "Elbo,backend", - [ - ("TraceEnum_ELBO", "pyro"), - ("Trace_ELBO", "contrib.funsor"), - ("TraceGraph_ELBO", "pyro"), - ], -) -def test_particle_gradient_2(Elbo, backend): - with pyro_backend(backend): - pyro.clear_param_store() - data = torch.tensor([0.0, 1.0]) - - def model(): - prob_b = torch.tensor([0.3, 0.4]) - prob_c = torch.tensor([0.5, 0.6]) - prob_d = torch.tensor([0.2, 0.3]) - prob_e = torch.tensor([0.5, 0.1]) - a = pyro.sample("a", dist.Bernoulli(0.3)) - with pyro.plate("data", len(data)): - b = pyro.sample("b", dist.Bernoulli(prob_b[a.long()])) - c = pyro.sample("c", dist.Bernoulli(prob_c[b.long()])) - pyro.sample("d", dist.Bernoulli(prob_d[b.long()])) - pyro.sample("e", dist.Bernoulli(prob_e[c.long()]), obs=data) - - def guide(): - prob_a = pyro.param("prob_a", lambda: torch.tensor(0.5)) - prob_b = pyro.param("prob_b", lambda: torch.tensor([0.4, 0.3])) - prob_c = pyro.param( - "prob_c", lambda: torch.tensor([[0.3, 0.8], [0.2, 0.5]]) + # fine-grained Rao-Blackwellization based on provenance tracking + expected_grads = { + "prob_a": -dlogqa_dproba + * ( + logpa + + (logpba + logpcba + logpdba + logpecba).sum() + - logqa + - (logqba + logqcba + logqdba).sum() + - 1 + ), + "prob_b": ( + -dlogqb_dprobb + * ( + (logpba + logpcba + logpdba + logpecba) + - (logqba + logqcba + logqdba) + - 1 ) - prob_d = pyro.param( - "prob_d", lambda: torch.tensor([[0.2, 0.9], [0.1, 0.4]]) - ) - a = pyro.sample("a", dist.Bernoulli(prob_a)) - with pyro.plate("data", len(data)) as idx: - b = pyro.sample("b", dist.Bernoulli(prob_b[a.long()])) - pyro.sample("c", dist.Bernoulli(Vindex(prob_c)[b.long(), idx])) - pyro.sample("d", dist.Bernoulli(Vindex(prob_d)[b.long(), idx])) - - elbo = getattr(infer, Elbo)( - max_plate_nesting=1, # set this to ensure rng agrees across runs - num_particles=1, - strict_enumeration_warning=False, - ) - - # Elbo gradient estimator - pyro.set_rng_seed(0) - 
elbo.loss_and_grads(model, guide) - params = dict(pyro.get_param_store().named_parameters()) - actual_grads = { - name: param.grad.detach().cpu() for name, param in params.items() - } + ).sum(), + "prob_c": -dlogqc_dprobc * (logpcba + logpecba - logqcba - 1), + "prob_d": -dlogqd_dprobd * (logpdba - logqdba - 1), + } + actual_grads["prob_b"] = actual_grads["prob_b"][a.long()] + actual_grads["prob_c"] = Vindex(actual_grads["prob_c"])[b.long(), idx] + actual_grads["prob_d"] = Vindex(actual_grads["prob_d"])[b.long(), idx] - # capture sample values and log_probs - pyro.set_rng_seed(0) - guide_tr = handlers.trace(guide).get_trace() - model_tr = handlers.trace(handlers.replay(model, guide_tr)).get_trace() - guide_tr.compute_log_prob() - model_tr.compute_log_prob() - - a = guide_tr.nodes["a"]["value"].data - b = guide_tr.nodes["b"]["value"].data - c = guide_tr.nodes["c"]["value"].data - - proba = pyro.param("prob_a").data - probb = pyro.param("prob_b").data - probc = pyro.param("prob_c").data - probd = pyro.param("prob_d").data - - logpa = model_tr.nodes["a"]["log_prob"].data - logpba = model_tr.nodes["b"]["log_prob"].data - logpcba = model_tr.nodes["c"]["log_prob"].data - logpdba = model_tr.nodes["d"]["log_prob"].data - logpecba = model_tr.nodes["e"]["log_prob"].data - - logqa = guide_tr.nodes["a"]["log_prob"].data - logqba = guide_tr.nodes["b"]["log_prob"].data - logqcba = guide_tr.nodes["c"]["log_prob"].data - logqdba = guide_tr.nodes["d"]["log_prob"].data - - idx = torch.arange(2) - dlogqa_dproba = (a - proba) / (proba * (1 - proba)) - dlogqb_dprobb = (b - probb[a.long()]) / ( - probb[a.long()] * (1 - probb[a.long()]) - ) - dlogqc_dprobc = (c - Vindex(probc)[b.long(), idx]) / ( - Vindex(probc)[b.long(), idx] * (1 - Vindex(probc)[b.long(), idx]) - ) - dlogqd_dprobd = (c - probd[b.long(), idx]) / ( - Vindex(probd)[b.long(), idx] * (1 - Vindex(probd)[b.long(), idx]) - ) + for name in sorted(params): + logger.info("expected {} = {}".format(name, expected_grads[name])) + logger.info("actual {} = {}".format(name, actual_grads[name])) - if Elbo == "Trace_ELBO": - # fine-grained Rao-Blackwellization based on provenance tracking - expected_grads = { - "prob_a": -dlogqa_dproba - * ( - logpa - + (logpba + logpcba + logpdba + logpecba).sum() - - logqa - - (logqba + logqcba + logqdba).sum() - - 1 - ), - "prob_b": ( - -dlogqb_dprobb - * ( - (logpba + logpcba + logpdba + logpecba) - - (logqba + logqcba + logqdba) - - 1 - ) - ).sum(), - "prob_c": -dlogqc_dprobc * (logpcba + logpecba - logqcba - 1), - "prob_d": -dlogqd_dprobd * (logpdba - logqdba - 1), - } - elif Elbo == "TraceEnum_ELBO": - # only uses plate conditional independence for Rao-Blackwellization - expected_grads = { - "prob_a": -dlogqa_dproba - * ( - logpa - + (logpba + logpcba + logpdba + logpecba).sum() - - logqa - - (logqba + logqcba + logqdba).sum() - - 1 - ), - "prob_b": ( - -dlogqb_dprobb - * ( - (logpba + logpcba + logpdba + logpecba) - - (logqba + logqcba + logqdba) - - 1 - ) - ).sum(), - "prob_c": -dlogqc_dprobc - * ( - (logpba + logpcba + logpdba + logpecba) - - (logqba + logqcba + logqdba) - - 1 - ), - "prob_d": -dlogqd_dprobd - * ( - (logpba + logpcba + logpdba + logpecba) - - (logqba + logqcba + logqdba) - - 1 - ), - } - elif Elbo == "TraceGraph_ELBO": - # Raw-Blackwellization uses conditional independence based on - # 1) the sequential order of samples - # 2) plate generators - # additionally removes dlogq_dparam terms - expected_grads = { - "prob_a": -dlogqa_dproba - * ( - logpa - + (logpba + logpcba + logpdba + logpecba).sum() - - 
logqa - - (logqba + logqcba + logqdba).sum() - ), - "prob_b": ( - -dlogqb_dprobb - * ( - (logpba + logpcba + logpdba + logpecba) - - (logqba + logqcba + logqdba) - ) - ).sum(), - "prob_c": -dlogqc_dprobc - * ((logpcba + logpdba + logpecba) - (logqcba + logqdba)), - "prob_d": -dlogqd_dprobd * (logpdba + logpecba - logqdba), - } - actual_grads["prob_b"] = actual_grads["prob_b"][a.long()] - actual_grads["prob_c"] = Vindex(actual_grads["prob_c"])[b.long(), idx] - actual_grads["prob_d"] = Vindex(actual_grads["prob_d"])[b.long(), idx] - - for name in sorted(params): - logger.info("expected {} = {}".format(name, expected_grads[name])) - logger.info("actual {} = {}".format(name, actual_grads[name])) - - assert_equal(actual_grads, expected_grads, prec=1e-4) + assert_equal(actual_grads, expected_grads, prec=1e-4) From 5df30c8c8c894cecf76c293f7999b22339e0a463 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Wed, 13 Oct 2021 17:43:36 -0400 Subject: [PATCH 36/53] import pyroapi --- tests/contrib/funsor/test_gradient.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/contrib/funsor/test_gradient.py b/tests/contrib/funsor/test_gradient.py index e01241eac0..886c3daba3 100644 --- a/tests/contrib/funsor/test_gradient.py +++ b/tests/contrib/funsor/test_gradient.py @@ -3,6 +3,7 @@ import logging +import pyroapi import pytest import torch @@ -17,7 +18,7 @@ funsor.set_backend("torch") from pyroapi import distributions as dist - from pyroapi import handlers, infer, pyro, pyro_backend + from pyroapi import handlers, infer, pyro except ImportError: pytestmark = pytest.mark.skip(reason="funsor is not installed") @@ -102,7 +103,7 @@ def guide_2(data): def test_gradient(model, guide, data): # Expected grads based on exact integration - with pyro_backend("pyro"): + with pyroapi.pyro_backend("pyro"): pyro.clear_param_store() elbo = infer.TraceEnum_ELBO( max_plate_nesting=1, # set this to ensure rng agrees across runs @@ -115,7 +116,7 @@ def test_gradient(model, guide, data): } # Actual grads averaged over num_particles - with pyro_backend("contrib.funsor"): + with pyroapi.pyro_backend("contrib.funsor"): pyro.clear_param_store() elbo = infer.Trace_ELBO( max_plate_nesting=1, # set this to ensure rng agrees across runs @@ -136,7 +137,7 @@ def test_gradient(model, guide, data): assert_equal(actual_grads, expected_grads, prec=0.02) -@pyro_backend("contrib.funsor") +@pyroapi.pyro_backend("contrib.funsor") def test_particle_gradient_0(): pyro.clear_param_store() data = torch.tensor([-0.5, 2.0]) @@ -191,7 +192,7 @@ def guide(): assert_equal(actual_grads, expected_grads, prec=1e-4) -@pyro_backend("contrib.funsor") +@pyroapi.pyro_backend("contrib.funsor") def test_particle_gradient_1(): pyro.clear_param_store() data = torch.tensor([-0.5, 2.0]) @@ -262,7 +263,7 @@ def guide(): assert_equal(actual_grads, expected_grads, prec=1e-4) -@pyro_backend("contrib.funsor") +@pyroapi.pyro_backend("contrib.funsor") def test_particle_gradient_2(): pyro.clear_param_store() data = torch.tensor([0.0, 1.0]) From 46ff6f4cc7b31c1a37cf6bc0b2ff000e33cf8e0b Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Thu, 14 Oct 2021 19:32:17 -0400 Subject: [PATCH 37/53] compute expected grads using dice factors --- tests/contrib/funsor/test_gradient.py | 234 +++++++++++++++----------- 1 file changed, 132 insertions(+), 102 deletions(-) diff --git a/tests/contrib/funsor/test_gradient.py b/tests/contrib/funsor/test_gradient.py index 886c3daba3..97e75d3522 100644 --- a/tests/contrib/funsor/test_gradient.py +++ 
b/tests/contrib/funsor/test_gradient.py @@ -139,7 +139,15 @@ def test_gradient(model, guide, data): @pyroapi.pyro_backend("contrib.funsor") def test_particle_gradient_0(): - pyro.clear_param_store() + # model + # +---------+ + # | z --> x | + # +---------+ + # + # guide + # +---+ + # | z | + # +---+ data = torch.tensor([-0.5, 2.0]) def model(): @@ -161,29 +169,34 @@ def guide(): strict_enumeration_warning=False, ) - # compute loss and get trace - guide_tr = handlers.trace(elbo.loss_and_grads).get_trace(model, guide) - model_tr = handlers.trace(handlers.replay(model, guide_tr)).get_trace() - # Trace_ELBO gradients + pyro.clear_param_store() + elbo.loss_and_grads(model, guide) params = dict(pyro.get_param_store().named_parameters()) actual_grads = {name: param.grad.detach().cpu() for name, param in params.items()} # Hand derived gradients + # elbo = Expectation( + # sum(dice_factor_zi * (log_pzi + log_pxi - log_qzi)) + # ) + pyro.clear_param_store() + guide_tr = handlers.trace(guide).get_trace() + model_tr = handlers.trace(handlers.replay(model, guide_tr)).get_trace() guide_tr.compute_log_prob() model_tr.compute_log_prob() - - z = guide_tr.nodes["z"]["value"]._t.detach() - rate = pyro.param("rate").detach() - logpx = model_tr.nodes["x"]["log_prob"]._t.detach() - logpz = model_tr.nodes["z"]["log_prob"]._t.detach() - logqz = guide_tr.nodes["z"]["log_prob"]._t.detach() - - loss_i = logpx + logpz - logqz - dlogq_drate = z / rate - 1 - expected_grads = { - "rate": -(dlogq_drate * loss_i - dlogq_drate), - } + # log factors + logpx = model_tr.nodes["x"]["log_prob"] + logpz = model_tr.nodes["z"]["log_prob"] + logqz = guide_tr.nodes["z"]["log_prob"] + # dice factor + df_z = (logqz - logqz.detach()).exp() + # dice elbo + dice_elbo = (df_z * (logpz + logpx - logqz)).sum() + # backward run + loss = -dice_elbo + loss.backward() + params = dict(pyro.get_param_store().named_parameters()) + expected_grads = {name: param.grad.detach().cpu() for name, param in params.items()} for name in sorted(params): logger.info("expected {} = {}".format(name, expected_grads[name])) @@ -194,7 +207,15 @@ def guide(): @pyroapi.pyro_backend("contrib.funsor") def test_particle_gradient_1(): - pyro.clear_param_store() + # model + # +-----------+ + # a -|-> b --> c | + # +-----------+ + # + # guide + # +-----+ + # a -|-> b | + # +-----+ data = torch.tensor([-0.5, 2.0]) def model(): @@ -223,38 +244,47 @@ def guide(): strict_enumeration_warning=False, ) - # compute loss and get trace - guide_tr = handlers.trace(elbo.loss_and_grads).get_trace(model, guide) - model_tr = handlers.trace(handlers.replay(model, guide_tr)).get_trace() - # Trace_ELBO gradients + pyro.clear_param_store() + elbo.loss_and_grads(model, guide) params = dict(pyro.get_param_store().named_parameters()) actual_grads = {name: param.grad.detach().cpu() for name, param in params.items()} # Hand derived gradients + # elbo = Expectation( + # q(a) * log_pa + # + q(a) * sum(q(b_i|a) * log_pb_i) + # + q(a) * sum(q(b_i|a) * log_pc_i) + # - q(a) * log_qa + # - q(a) * sum(q(b_i|a) * log_qb_i) + # ) + pyro.clear_param_store() + guide_tr = handlers.trace(guide).get_trace() + model_tr = handlers.trace(handlers.replay(model, guide_tr)).get_trace() guide_tr.compute_log_prob() model_tr.compute_log_prob() - - a = guide_tr.nodes["a"]["value"]._t.detach() - b = guide_tr.nodes["b"]["value"]._t.detach() - prob = pyro.param("prob").detach() - rate = pyro.param("rate").detach() - logpa = model_tr.nodes["a"]["log_prob"]._t.detach() - logpb = model_tr.nodes["b"]["log_prob"]._t.detach() - 
logpc = model_tr.nodes["c"]["log_prob"]._t.detach() - logqa = guide_tr.nodes["a"]["log_prob"]._t.detach() - logqb = guide_tr.nodes["b"]["log_prob"]._t.detach() - - dlogqa_dprob = (a - prob) / (prob * (1 - prob)) - dlogqb_drate = b / rate[a.long()] - 1 - - loss_a = logpa - logqa - loss_bc = logpb + logpc - logqb - expected_grads = { - "prob": -(dlogqa_dprob * (loss_a + loss_bc.sum()) - dlogqa_dprob), - "rate": -(dlogqb_drate * (loss_bc) - dlogqb_drate), - } - actual_grads["rate"] = actual_grads["rate"][a.long()] + # log factors + logpa = model_tr.nodes["a"]["log_prob"] + logpb = model_tr.nodes["b"]["log_prob"] + logpc = model_tr.nodes["c"]["log_prob"] + logqa = guide_tr.nodes["a"]["log_prob"] + logqb = guide_tr.nodes["b"]["log_prob"] + # dice factors + df_a = (logqa - logqa.detach()).exp() + df_b = (logqb - logqb.detach()).exp() + # dice elbo + dice_elbo = ( + df_a * logpa + + df_a * (df_b * logpb).sum() + + df_a * (df_b * logpc).sum() + - df_a * logqa + - df_a * (df_b * logqb).sum() + ) + # backward run + loss = -dice_elbo + loss.backward() + params = dict(pyro.get_param_store().named_parameters()) + expected_grads = {name: param.grad.detach().cpu() for name, param in params.items()} for name in sorted(params): logger.info("expected {} = {}".format(name, expected_grads[name])) @@ -265,7 +295,17 @@ def guide(): @pyroapi.pyro_backend("contrib.funsor") def test_particle_gradient_2(): - pyro.clear_param_store() + # model + # +-----------------+ + # a -|-> b --> c --> e | + # | \--> d | + # +-----------------+ + # + # guide + # +-----------+ + # a -|-> b --> c | + # | \--> d | + # +-----------+ data = torch.tensor([0.0, 1.0]) def model(): @@ -300,72 +340,62 @@ def guide(): strict_enumeration_warning=False, ) - # compute loss and get trace - guide_tr = handlers.trace(elbo.loss_and_grads).get_trace(model, guide) - model_tr = handlers.trace(handlers.replay(model, guide_tr)).get_trace() - # Trace_ELBO gradients + pyro.clear_param_store() + elbo.loss_and_grads(model, guide) params = dict(pyro.get_param_store().named_parameters()) actual_grads = {name: param.grad.detach().cpu() for name, param in params.items()} # Hand derived gradients + # elbo = Expectation( + # q(a) * log_pa + # + q(a) * sum(q(b_i|a) * log_pb_i) + # + q(a) * sum(q(b_i|a) * q(c_i|b_i) * log_pc_i) + # + q(a) * sum(q(b_i|a) * q(c_i|b_i) * log_pe_i) + # + q(a) * sum(q(b_i|a) * q(d_i|b_i) * log_pd_i) + # - q(a) * log_qa + # - q(a) * sum(q(b_i|a) * log_qb_i) + # - q(a) * sum(q(b_i|a) * q(c_i|b_i) * log_pq_i) + # - q(a) * sum(q(b_i|a) * q(d_i|b_i) * log_pd_i) + # ) + pyro.clear_param_store() + guide_tr = handlers.trace(guide).get_trace() + model_tr = handlers.trace(handlers.replay(model, guide_tr)).get_trace() guide_tr.compute_log_prob() model_tr.compute_log_prob() - - a = guide_tr.nodes["a"]["value"]._t.detach() - b = guide_tr.nodes["b"]["value"]._t.detach() - c = guide_tr.nodes["c"]["value"]._t.detach() - - proba = pyro.param("prob_a").detach() - probb = pyro.param("prob_b").detach() - probc = pyro.param("prob_c").detach() - probd = pyro.param("prob_d").detach() - - logpa = model_tr.nodes["a"]["log_prob"]._t.detach() - logpba = model_tr.nodes["b"]["log_prob"]._t.detach() - logpcba = model_tr.nodes["c"]["log_prob"]._t.detach() - logpdba = model_tr.nodes["d"]["log_prob"]._t.detach() - logpecba = model_tr.nodes["e"]["log_prob"]._t.detach() - - logqa = guide_tr.nodes["a"]["log_prob"]._t.detach() - logqba = guide_tr.nodes["b"]["log_prob"]._t.detach() - logqcba = guide_tr.nodes["c"]["log_prob"]._t.detach() - logqdba = 
guide_tr.nodes["d"]["log_prob"]._t.detach() - - idx = torch.arange(2) - dlogqa_dproba = (a - proba) / (proba * (1 - proba)) - dlogqb_dprobb = (b - probb[a.long()]) / (probb[a.long()] * (1 - probb[a.long()])) - dlogqc_dprobc = (c - Vindex(probc)[b.long(), idx]) / ( - Vindex(probc)[b.long(), idx] * (1 - Vindex(probc)[b.long(), idx]) + # log factors + logpa = model_tr.nodes["a"]["log_prob"] + logpb = model_tr.nodes["b"]["log_prob"] + logpc = model_tr.nodes["c"]["log_prob"] + logpd = model_tr.nodes["d"]["log_prob"] + logpe = model_tr.nodes["e"]["log_prob"] + + logqa = guide_tr.nodes["a"]["log_prob"] + logqb = guide_tr.nodes["b"]["log_prob"] + logqc = guide_tr.nodes["c"]["log_prob"] + logqd = guide_tr.nodes["d"]["log_prob"] + # dice factors + df_a = (logqa - logqa.detach()).exp() + df_b = (logqb - logqb.detach()).exp() + df_c = (logqc - logqc.detach()).exp() + df_d = (logqd - logqd.detach()).exp() + # dice elbo + dice_elbo = ( + df_a * logpa + + df_a * (df_b * logpb).sum() + + df_a * (df_b * df_c * logpc).sum() + + df_a * (df_b * df_c * logpe).sum() + + df_a * (df_b * df_d * logpd).sum() + - df_a * logqa + - df_a * (df_b * logqb).sum() + - df_a * (df_b * df_c * logqc).sum() + - df_a * (df_b * df_d * logqd).sum() ) - dlogqd_dprobd = (c - probd[b.long(), idx]) / ( - Vindex(probd)[b.long(), idx] * (1 - Vindex(probd)[b.long(), idx]) - ) - - # fine-grained Rao-Blackwellization based on provenance tracking - expected_grads = { - "prob_a": -dlogqa_dproba - * ( - logpa - + (logpba + logpcba + logpdba + logpecba).sum() - - logqa - - (logqba + logqcba + logqdba).sum() - - 1 - ), - "prob_b": ( - -dlogqb_dprobb - * ( - (logpba + logpcba + logpdba + logpecba) - - (logqba + logqcba + logqdba) - - 1 - ) - ).sum(), - "prob_c": -dlogqc_dprobc * (logpcba + logpecba - logqcba - 1), - "prob_d": -dlogqd_dprobd * (logpdba - logqdba - 1), - } - actual_grads["prob_b"] = actual_grads["prob_b"][a.long()] - actual_grads["prob_c"] = Vindex(actual_grads["prob_c"])[b.long(), idx] - actual_grads["prob_d"] = Vindex(actual_grads["prob_d"])[b.long(), idx] + # backward run + loss = -dice_elbo + loss.backward() + params = dict(pyro.get_param_store().named_parameters()) + expected_grads = {name: param.grad.detach().cpu() for name, param in params.items()} for name in sorted(params): logger.info("expected {} = {}".format(name, expected_grads[name])) From d7ee7ee56c6ac074419d7fbbad1d5b0e86fe04df Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Thu, 14 Oct 2021 21:11:15 -0400 Subject: [PATCH 38/53] add test with guide enumeration --- .../contrib/funsor/handlers/enum_messenger.py | 7 + .../funsor/handlers/named_messenger.py | 7 +- tests/contrib/funsor/test_gradient.py | 131 ++++++++++++++++-- 3 files changed, 130 insertions(+), 15 deletions(-) diff --git a/pyro/contrib/funsor/handlers/enum_messenger.py b/pyro/contrib/funsor/handlers/enum_messenger.py index a692cb3895..e0ff866383 100644 --- a/pyro/contrib/funsor/handlers/enum_messenger.py +++ b/pyro/contrib/funsor/handlers/enum_messenger.py @@ -59,6 +59,13 @@ def _get_support_value_tensor(funsor_dist, name, **kwargs): ) +@_get_support_value.register(funsor.Constant) +def _get_support_value_tensor(funsor_dist, name, **kwargs): + assert name in funsor_dist.inputs + value = _get_support_value(funsor_dist.arg, name, **kwargs) + return funsor.Constant(funsor_dist.const_inputs, value) + + @_get_support_value.register(funsor.distribution.Distribution) def _get_support_value_distribution(funsor_dist, name, expand=False): assert name == funsor_dist.value.name diff --git 
a/pyro/contrib/funsor/handlers/named_messenger.py b/pyro/contrib/funsor/handlers/named_messenger.py
index ff65a18223..966b94f545 100644
--- a/pyro/contrib/funsor/handlers/named_messenger.py
+++ b/pyro/contrib/funsor/handlers/named_messenger.py
@@ -4,6 +4,8 @@
 from collections import OrderedDict
 from contextlib import ExitStack
 
+import funsor
+
 from pyro.contrib.funsor.handlers.runtime import (
     _DIM_STACK,
     DimRequest,
@@ -64,7 +66,10 @@ def _pyro_to_data(msg):
     name_to_dim = msg["kwargs"].setdefault("name_to_dim", OrderedDict())
     dim_type = msg["kwargs"].setdefault("dim_type", DimType.LOCAL)
 
-    batch_names = tuple(funsor_value.inputs.keys())
+    if isinstance(funsor_value, funsor.Constant):
+        batch_names = tuple(funsor_value.arg.inputs.keys())
+    else:
+        batch_names = tuple(funsor_value.inputs.keys())
 
     # interpret all names/dims as requests since we only run this function once
     name_to_dim_request = name_to_dim.copy()
diff --git a/tests/contrib/funsor/test_gradient.py b/tests/contrib/funsor/test_gradient.py
index 97e75d3522..4fe6c471a1 100644
--- a/tests/contrib/funsor/test_gradient.py
+++ b/tests/contrib/funsor/test_gradient.py
@@ -176,8 +176,10 @@ def guide():
     actual_grads = {name: param.grad.detach().cpu() for name, param in params.items()}
 
     # Hand derived gradients
-    # elbo = Expectation(
-    #     sum(dice_factor_zi * (log_pzi + log_pxi - log_qzi))
+    # elbo = MonteCarlo(
+    #     [q(z_i) * log_pz_i].sum(i)
+    #     + [q(z_i) * log_px_i].sum(i)
+    #     - [q(z_i) * log_qz_i].sum(i)
     # )
     pyro.clear_param_store()
     guide_tr = handlers.trace(guide).get_trace()
@@ -251,12 +253,12 @@ def guide():
     actual_grads = {name: param.grad.detach().cpu() for name, param in params.items()}
 
     # Hand derived gradients
-    # elbo = Expectation(
+    # elbo = MonteCarlo(
     #     q(a) * log_pa
-    #     + q(a) * sum(q(b_i|a) * log_pb_i)
-    #     + q(a) * sum(q(b_i|a) * log_pc_i)
+    #     + q(a) * [q(b_i|a) * log_pb_i].sum(i)
+    #     + q(a) * [q(b_i|a) * log_pc_i].sum(i)
     #     - q(a) * log_qa
-    #     - q(a) * sum(q(b_i|a) * log_qb_i)
+    #     - q(a) * [q(b_i|a) * log_qb_i].sum(i)
     #     )
     pyro.clear_param_store()
     guide_tr = handlers.trace(guide).get_trace()
@@ -347,16 +349,16 @@ def guide():
     actual_grads = {name: param.grad.detach().cpu() for name, param in params.items()}
 
     # Hand derived gradients
-    # elbo = Expectation(
+    # elbo = MonteCarlo(
     #     q(a) * log_pa
-    #     + q(a) * sum(q(b_i|a) * log_pb_i)
-    #     + q(a) * sum(q(b_i|a) * q(c_i|b_i) * log_pc_i)
-    #     + q(a) * sum(q(b_i|a) * q(c_i|b_i) * log_pe_i)
-    #     + q(a) * sum(q(b_i|a) * q(d_i|b_i) * log_pd_i)
+    #     + q(a) * [q(b_i|a) * log_pb_i].sum(i)
+    #     + q(a) * [q(b_i|a) * q(c_i|b_i) * log_pc_i].sum(i)
+    #     + q(a) * [q(b_i|a) * q(c_i|b_i) * log_pe_i].sum(i)
+    #     + q(a) * [q(b_i|a) * q(d_i|b_i) * log_pd_i].sum(i)
     #     - q(a) * log_qa
-    #     - q(a) * sum(q(b_i|a) * log_qb_i)
-    #     - q(a) * sum(q(b_i|a) * q(c_i|b_i) * log_pq_i)
-    #     - q(a) * sum(q(b_i|a) * q(d_i|b_i) * log_pd_i)
+    #     - q(a) * [q(b_i|a) * log_qb_i].sum(i)
+    #     - q(a) * [q(b_i|a) * q(c_i|b_i) * log_qc_i].sum(i)
+    #     - q(a) * [q(b_i|a) * q(d_i|b_i) * log_qd_i].sum(i)
     #     )
     pyro.clear_param_store()
     guide_tr = handlers.trace(guide).get_trace()
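The MonteCarlo comments rewritten above all follow one recipe: every sampled guide site contributes a dice factor exp(log q - detach(log q)), which equals 1 in value but carries d log q in its gradient, so multiplying it into each downstream cost reproduces the score-function (REINFORCE) estimator with per-site Rao-Blackwellization. A minimal sketch of the trick in plain PyTorch, with a hypothetical one-variable model rather than anything from this patch:

    import torch

    prob = torch.tensor(0.4, requires_grad=True)
    z = torch.bernoulli(prob.detach())  # sample with no gradient path
    logq = z * prob.log() + (1 - z) * (1 - prob).log()
    df = (logq - logq.detach()).exp()  # == 1.0 in value, but grad(df) == grad(logq)
    cost = torch.tensor(2.0)  # a hypothetical downstream cost term
    (df * cost).backward()
    print(prob.grad)  # == cost * d log q / d prob, the score-function gradient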
@@ -402,3 +404,104 @@ def guide():
     logger.info("actual {} = {}".format(name, actual_grads[name]))
 
     assert_equal(actual_grads, expected_grads, prec=1e-4)
+
+
+@pyroapi.pyro_backend("contrib.funsor")
+def test_particle_gradient_3():
+    # model
+    # +-----------------+
+    # a -|-> b --> c --> d |
+    # +-----------------+
+    #
+    # guide
+    # +-----------+
+    # a -|-> b --> c |
+    # +-----------+
+    data = torch.tensor([0.0, 1.0])
+
+    def model():
+        prob_b = torch.tensor([[0.3, 0.7], [0.4, 0.6]])
+        prob_c = torch.tensor([0.5, 0.6])
+        prob_d = torch.tensor([0.5, 0.1])
+        a = pyro.sample("a", dist.Bernoulli(0.3))
+        with pyro.plate("data", len(data)):
+            b = pyro.sample("b", dist.Categorical(prob_b[a.long()]))
+            c = pyro.sample("c", dist.Bernoulli(prob_c[b.long()]))
+            pyro.sample("d", dist.Bernoulli(prob_d[c.long()]), obs=data)
+
+    def guide():
+        # set this to ensure rng agrees across runs
+        # this should be ok since we are comparing single-particle gradients
+        pyro.set_rng_seed(0)
+        prob_a = pyro.param("prob_a", lambda: torch.tensor(0.5))
+        prob_b = pyro.param("prob_b", lambda: torch.tensor([[0.4, 0.6], [0.3, 0.7]]))
+        prob_c = pyro.param("prob_c", lambda: torch.tensor([[0.3, 0.8], [0.2, 0.5]]))
+        a = pyro.sample("a", dist.Bernoulli(prob_a))
+        with pyro.plate("data", len(data)) as idx:
+            b = pyro.sample(
+                "b", dist.Categorical(prob_b[a.long()]), infer={"enumerate": "parallel"}
+            )
+            pyro.sample("c", dist.Bernoulli(Vindex(prob_c)[b.long(), idx]))
+
+    elbo = infer.Trace_ELBO(
+        max_plate_nesting=1,  # set this to ensure rng agrees across runs
+        num_particles=1,
+        strict_enumeration_warning=False,
+    )
+
+    # Trace_ELBO gradients
+    pyro.clear_param_store()
+    elbo.loss_and_grads(model, guide)
+    params = dict(pyro.get_param_store().named_parameters())
+    actual_grads = {name: param.grad.detach().cpu() for name, param in params.items()}
+
+    # Hand derived gradients (b is exactly integrated)
+    # elbo = MonteCarlo(
+    #     q(a) * log_pa
+    #     + q(a) * [q(b_i|a) * log_pb_i].sum(i, b)
+    #     + q(a) * [q(b_i|a) * q(c_i|b_i) * log_pc_i].sum(i, b)
+    #     + q(a) * [q(b_i|a) * q(c_i|b_i) * log_pd_i].sum(i, b
+    #     - q(a) * log_qa
+    #     - q(a) * [q(b_i|a) * log_qb_i].sum(i)
+    #     - q(a) * [q(b_i|a) * q(c_i|b_i) * log_qc_i].sum(i, b)
+    #     )
+    pyro.clear_param_store()
+    with handlers.enum(first_available_dim=(-2)), handlers.provenance():
+        guide_tr = handlers.trace(guide).get_trace()
+        model_tr = handlers.trace(handlers.replay(model, guide_tr)).get_trace()
+    guide_tr.compute_log_prob()
+    model_tr.compute_log_prob()
+    # log factors
+    logpa = model_tr.nodes["a"]["log_prob"]
+    logpb = model_tr.nodes["b"]["log_prob"]
+    logpc = model_tr.nodes["c"]["log_prob"]
+    logpd = model_tr.nodes["d"]["log_prob"]
+
+    logqa = guide_tr.nodes["a"]["log_prob"]
+    logqb = guide_tr.nodes["b"]["log_prob"]
+    logqc = guide_tr.nodes["c"]["log_prob"]
+    # dice factors
+    df_a = (logqa - logqa.detach()).exp()
+    qb = logqb.exp()
+    df_c = (logqc - logqc.detach()).exp()
+    # dice elbo
+    dice_elbo = (
+        df_a * logpa
+        + df_a * (qb * logpb).sum()
+        + df_a * (qb * df_c * logpc).sum()
+        + df_a * (qb * df_c * logpd).sum()
+        - df_a * logqa
+        - df_a * (qb * logqb).sum()
+        - df_a * (qb * df_c * logqc).sum()
+    )
+    # backward run
+    loss = -dice_elbo
+    loss.backward()
+    params = dict(pyro.get_param_store().named_parameters())
+    expected_grads = {name: param.grad.detach().cpu() for name, param in params.items()}
+
+    for name in sorted(params):
+        logger.info("expected {} = {}".format(name, expected_grads[name]))
+        logger.info("actual {} = {}".format(name, actual_grads[name]))
+
+    assert_equal(actual_grads, expected_grads, prec=1e-4)

From 49553c3c619bafa0abff033e1494a831c2f2f823 Mon Sep 17 00:00:00 2001
From: Yerdos Ordabayev
Date: Fri, 15 Oct 2021 00:04:44 -0400
Subject: [PATCH 39/53] add two more tests

---
 tests/contrib/funsor/test_gradient.py | 225 +++++++++++++++++++++++++-
 1 file changed, 223 insertions(+), 2 deletions(-)

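In test_particle_gradient_3 above, b is marginalized exactly (its guide site is enumerated), so the hand derivation weights b's branches by qb = logqb.exp(), a differentiable probability, instead of a dice factor; the two tests added below repeat this pattern with enumeration at different sites. In miniature, an exactly enumerated site enters the estimator as an explicit expectation (plain PyTorch, hypothetical two-state site and costs):

    import torch

    logits = torch.tensor([0.0, 0.5], requires_grad=True)
    qb = logits.softmax(-1)  # probabilities of the two enumerated states
    costs = torch.tensor([1.0, 2.0])  # hypothetical per-state downstream costs
    (qb * costs).sum().backward()  # exact expectation over b
    print(logits.grad)  # exact gradient, no sampling variance from b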
diff --git a/tests/contrib/funsor/test_gradient.py b/tests/contrib/funsor/test_gradient.py
index 4fe6c471a1..5065ca1213 100644
--- a/tests/contrib/funsor/test_gradient.py
+++ b/tests/contrib/funsor/test_gradient.py
@@ -413,7 +413,7 @@ def test_particle_gradient_3():
     # a -|-> b --> c --> d |
     # +-----------------+
     #
-    # guide
+    # guide (b is enumerated)
     # +-----------+
     # a -|-> b --> c |
     # +-----------+
@@ -460,7 +460,7 @@ def guide():
     #     q(a) * log_pa
     #     + q(a) * [q(b_i|a) * log_pb_i].sum(i, b)
     #     + q(a) * [q(b_i|a) * q(c_i|b_i) * log_pc_i].sum(i, b)
-    #     + q(a) * [q(b_i|a) * q(c_i|b_i) * log_pd_i].sum(i, b
+    #     + q(a) * [q(b_i|a) * q(c_i|b_i) * log_pd_i].sum(i, b)
     #     - q(a) * log_qa
     #     - q(a) * [q(b_i|a) * log_qb_i].sum(i)
     #     - q(a) * [q(b_i|a) * q(c_i|b_i) * log_qc_i].sum(i, b)
@@ -505,3 +505,224 @@ def guide():
     logger.info("actual {} = {}".format(name, actual_grads[name]))
 
     assert_equal(actual_grads, expected_grads, prec=1e-4)
+
+
+@pyroapi.pyro_backend("contrib.funsor")
+def test_particle_gradient_4():
+    # model
+    # +-----------------+
+    # a -|-> b --> c --> d |
+    # +-----------------+
+    #
+    # guide (c is enumerated)
+    # +-----------+
+    # a -|-> b --> c |
+    # +-----------+
+    data = torch.tensor([0.0, 1.0])
+
+    def model():
+        prob_b = torch.tensor([[0.3, 0.7], [0.4, 0.6]])
+        prob_c = torch.tensor([[0.5, 0.5], [0.6, 0.4]])
+        prob_d = torch.tensor([0.5, 0.1])
+        a = pyro.sample("a", dist.Bernoulli(0.3))
+        with pyro.plate("data", len(data)):
+            b = pyro.sample("b", dist.Categorical(prob_b[a.long()]))
+            c = pyro.sample("c", dist.Categorical(prob_c[b.long()]))
+            pyro.sample("d", dist.Bernoulli(prob_d[c.long()]), obs=data)
+
+    def guide():
+        # set this to ensure rng agrees across runs
+        # this should be ok since we are comparing single-particle gradients
+        pyro.set_rng_seed(0)
+        prob_a = pyro.param("prob_a", lambda: torch.tensor(0.5))
+        prob_b = pyro.param("prob_b", lambda: torch.tensor([[0.4, 0.6], [0.3, 0.7]]))
+        prob_c = pyro.param(
+            "prob_c",
+            lambda: torch.tensor([[[0.3, 0.7], [0.8, 0.2]], [[0.2, 0.8], [0.5, 0.5]]]),
+        )
+        a = pyro.sample("a", dist.Bernoulli(prob_a))
+        with pyro.plate("data", len(data)) as idx:
+            b = pyro.sample("b", dist.Categorical(prob_b[a.long()]))
+            pyro.sample(
+                "c",
+                dist.Categorical(Vindex(prob_c)[b.long(), idx]),
+                infer={"enumerate": "parallel"},
+            )
+
+    elbo = infer.Trace_ELBO(
+        max_plate_nesting=1,  # set this to ensure rng agrees across runs
+        num_particles=1,
+        strict_enumeration_warning=False,
+    )
+
+    # Trace_ELBO gradients
+    pyro.clear_param_store()
+    elbo.loss_and_grads(model, guide)
+    params = dict(pyro.get_param_store().named_parameters())
+    actual_grads = {name: param.grad.detach().cpu() for name, param in params.items()}
+
+    # Hand derived gradients (c is exactly integrated)
+    # elbo = MonteCarlo(
+    #     q(a) * log_pa
+    #     + q(a) * [q(b_i|a) * log_pb_i].sum(i)
+    #     + q(a) * [q(b_i|a) * q(c_i|b_i) * log_pc_i].sum(i, c)
+    #     + q(a) * [q(b_i|a) * q(c_i|b_i) * log_pd_i].sum(i, c)
+    #     - q(a) * log_qa
+    #     - q(a) * [q(b_i|a) * log_qb_i].sum(i)
+    #     - q(a) * [q(b_i|a) * q(c_i|b_i) * log_qc_i].sum(i, c)
+    #     )
+    pyro.clear_param_store()
+    with handlers.enum(first_available_dim=(-2)), handlers.provenance():
+        guide_tr = handlers.trace(guide).get_trace()
+        model_tr = handlers.trace(handlers.replay(model, guide_tr)).get_trace()
+    guide_tr.compute_log_prob()
+    model_tr.compute_log_prob()
+    # log factors
+    logpa = model_tr.nodes["a"]["log_prob"]
+    logpb = model_tr.nodes["b"]["log_prob"]
+    logpc = model_tr.nodes["c"]["log_prob"]
+    logpd = model_tr.nodes["d"]["log_prob"]
+
+    logqa = guide_tr.nodes["a"]["log_prob"]
+    logqb = guide_tr.nodes["b"]["log_prob"]
+    logqc = guide_tr.nodes["c"]["log_prob"]
+    # dice factors
+    df_a = (logqa - logqa.detach()).exp()
+    df_b = (logqb - logqb.detach()).exp()
+    qc = logqc.exp()
+    # dice elbo
+    dice_elbo = (
+        df_a * logpa
+        + df_a * (df_b * logpb).sum()
+        + df_a * (df_b * qc * logpc).sum()
+        + df_a * (df_b * qc * logpd).sum()
+        - df_a * logqa
+        - df_a * (df_b * logqb).sum()
+        - df_a * (df_b * qc * logqc).sum()
+    )
+    # backward run
+    loss = -dice_elbo
+    loss.backward()
+    params = dict(pyro.get_param_store().named_parameters())
+    expected_grads = {name: param.grad.detach().cpu() for name, param in params.items()}
+
+    for name in sorted(params):
+        logger.info("expected {} = {}".format(name, expected_grads[name]))
+        logger.info("actual {} = {}".format(name, actual_grads[name]))
+
+    assert_equal(actual_grads, expected_grads, prec=1e-4)
+
+
+@pyroapi.pyro_backend("contrib.funsor")
+def test_particle_gradient_5():
+    # model
+    # +-----------------+
+    # a -|-> b --> c --> e |
+    # |      \--> d |
+    # +-----------------+
+    #
+    # guide (b is enumerated)
+    # +-----------+
+    # a -|-> b --> c |
+    # |      \--> d |
+    # +-----------+
+    data = torch.tensor([0.0, 1.0])
+
+    def model():
+        prob_b = torch.tensor([[0.3, 0.7], [0.4, 0.6]])
+        prob_c = torch.tensor([0.5, 0.6])
+        prob_d = torch.tensor([0.2, 0.3])
+        prob_e = torch.tensor([0.5, 0.1])
+        a = pyro.sample("a", dist.Bernoulli(0.3))
+        with pyro.plate("data", len(data)):
+            b = pyro.sample("b", dist.Categorical(prob_b[a.long()]))
+            c = pyro.sample("c", dist.Bernoulli(prob_c[b.long()]))
+            pyro.sample("d", dist.Bernoulli(prob_d[b.long()]))
+            pyro.sample("e", dist.Bernoulli(prob_e[c.long()]), obs=data)
+
+    def guide():
+        # set this to ensure rng agrees across runs
+        # this should be ok since we are comparing single-particle gradients
+        pyro.set_rng_seed(0)
+        prob_a = pyro.param("prob_a", lambda: torch.tensor(0.5))
+        prob_b = pyro.param("prob_b", lambda: torch.tensor([[0.4, 0.6], [0.3, 0.7]]))
+        prob_c = pyro.param("prob_c", lambda: torch.tensor([[0.3, 0.8], [0.2, 0.5]]))
+        prob_d = pyro.param("prob_d", lambda: torch.tensor([[0.2, 0.9], [0.1, 0.4]]))
+        a = pyro.sample("a", dist.Bernoulli(prob_a))
+        with pyro.plate("data", len(data)) as idx:
+            b = pyro.sample(
+                "b", dist.Categorical(prob_b[a.long()]), infer={"enumerate": "parallel"}
+            )
+            pyro.sample("c", dist.Bernoulli(Vindex(prob_c)[b.long(), idx]))
+            pyro.sample("d", dist.Bernoulli(Vindex(prob_d)[b.long(), idx]))
+
+    elbo = infer.Trace_ELBO(
+        max_plate_nesting=1,  # set this to ensure rng agrees across runs
+        num_particles=1,
+        strict_enumeration_warning=False,
+    )
+
+    # Trace_ELBO gradients
+    pyro.clear_param_store()
+    elbo.loss_and_grads(model, guide)
+    params = dict(pyro.get_param_store().named_parameters())
+    actual_grads = {name: param.grad.detach().cpu() for name, param in params.items()}
+
+    # Hand derived gradients (b exactly integrated out)
+    # elbo = MonteCarlo(
+    #     q(a) * log_pa
+    #     + q(a) * [q(b_i|a) * log_pb_i].sum(i, b)
+    #     + q(a) * [q(b_i|a) * q(c_i|b_i) * log_pc_i].sum(i, b)
+    #     + q(a) * [q(b_i|a) * q(c_i|b_i) * log_pe_i].sum(i, b)
+    #     + q(a) * [q(b_i|a) * q(d_i|b_i) * log_pd_i].sum(i, b)
+    #     - q(a) * log_qa
+    #     - q(a) * [q(b_i|a) * log_qb_i].sum(i, b)
+    #     - q(a) * [q(b_i|a) * q(c_i|b_i) * log_qc_i].sum(i, b)
+    #     - q(a) * [q(b_i|a) * q(d_i|b_i) * log_qd_i].sum(i, b)
+    #     )
+    pyro.clear_param_store()
+    with handlers.enum(first_available_dim=(-2)), handlers.provenance():
+        guide_tr = handlers.trace(guide).get_trace()
+        model_tr = handlers.trace(handlers.replay(model, guide_tr)).get_trace()
+    guide_tr.compute_log_prob()
+    model_tr.compute_log_prob()
+    # log factors
+    logpa = 
model_tr.nodes["a"]["log_prob"] + logpb = model_tr.nodes["b"]["log_prob"] + logpc = model_tr.nodes["c"]["log_prob"] + logpd = model_tr.nodes["d"]["log_prob"] + logpe = model_tr.nodes["e"]["log_prob"] + + logqa = guide_tr.nodes["a"]["log_prob"] + logqb = guide_tr.nodes["b"]["log_prob"] + logqc = guide_tr.nodes["c"]["log_prob"] + logqd = guide_tr.nodes["d"]["log_prob"] + # dice factors + df_a = (logqa - logqa.detach()).exp() + qb = logqb.exp() + df_c = (logqc - logqc.detach()).exp() + df_d = (logqd - logqd.detach()).exp() + # dice elbo + dice_elbo = ( + df_a * logpa + + df_a * (qb * logpb).sum() + + df_a * (qb * df_c * logpc).sum() + + df_a * (qb * df_c * logpe).sum() + + df_a * (qb * df_d * logpd).sum() + - df_a * logqa + - df_a * (qb * logqb).sum() + - df_a * (qb * df_c * logqc).sum() + - df_a * (qb * df_d * logqd).sum() + ) + # backward run + loss = -dice_elbo + loss.backward() + params = dict(pyro.get_param_store().named_parameters()) + expected_grads = {name: param.grad.detach().cpu() for name, param in params.items()} + + for name in sorted(params): + logger.info("expected {} = {}".format(name, expected_grads[name])) + logger.info("actual {} = {}".format(name, actual_grads[name])) + + breakpoint() + assert_equal(actual_grads, expected_grads, prec=1e-4) From 835f815a47db746393c5439c5220bfcc4859bf8a Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Fri, 15 Oct 2021 00:05:50 -0400 Subject: [PATCH 40/53] pin funsor --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 93715ed32c..b14d41fe8b 100644 --- a/setup.py +++ b/setup.py @@ -135,7 +135,7 @@ "horovod": ["horovod[pytorch]>=0.19"], "funsor": [ # This must be a released version when Pyro is released. - "funsor[torch] @ git+git://github.com/pyro-ppl/funsor.git@0cf9ac20ef8a9f03d86adceb2a42366f4d028093", + "funsor[torch] @ git+git://github.com/pyro-ppl/funsor.git@eager-categorical-constant", # "funsor[torch]==0.4.1", ], }, From 760eeb00f13d57d0e9b1ef21992b6ddb0a462d5f Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Fri, 15 Oct 2021 00:09:07 -0400 Subject: [PATCH 41/53] lint --- pyro/contrib/funsor/handlers/enum_messenger.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyro/contrib/funsor/handlers/enum_messenger.py b/pyro/contrib/funsor/handlers/enum_messenger.py index e0ff866383..99e44858b6 100644 --- a/pyro/contrib/funsor/handlers/enum_messenger.py +++ b/pyro/contrib/funsor/handlers/enum_messenger.py @@ -60,7 +60,7 @@ def _get_support_value_tensor(funsor_dist, name, **kwargs): @_get_support_value.register(funsor.Constant) -def _get_support_value_tensor(funsor_dist, name, **kwargs): +def _get_support_value_constant(funsor_dist, name, **kwargs): assert name in funsor_dist.inputs value = _get_support_value(funsor_dist.arg, name, **kwargs) return funsor.Constant(funsor_dist.const_inputs, value) From ab3831c29b5a9b1eefe433fb6f2b2f79e4b56cf9 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Fri, 15 Oct 2021 00:13:29 -0400 Subject: [PATCH 42/53] remove breakpoint --- tests/contrib/funsor/test_gradient.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/contrib/funsor/test_gradient.py b/tests/contrib/funsor/test_gradient.py index 5065ca1213..ac441f1469 100644 --- a/tests/contrib/funsor/test_gradient.py +++ b/tests/contrib/funsor/test_gradient.py @@ -724,5 +724,4 @@ def guide(): logger.info("expected {} = {}".format(name, expected_grads[name])) logger.info("actual {} = {}".format(name, actual_grads[name])) - breakpoint() assert_equal(actual_grads, 
expected_grads, prec=1e-4) From b6ff8e0c7a7c437a7309578b25e42f5996717766 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Wed, 3 Nov 2021 02:32:35 -0400 Subject: [PATCH 43/53] Approximate(ops.sample, ...) based approach --- pyro/contrib/funsor/handlers/enum_messenger.py | 14 ++++++++++++++ pyro/contrib/funsor/infer/trace_elbo.py | 4 ++++ setup.py | 2 +- 3 files changed, 19 insertions(+), 1 deletion(-) diff --git a/pyro/contrib/funsor/handlers/enum_messenger.py b/pyro/contrib/funsor/handlers/enum_messenger.py index 99e44858b6..0364235a45 100644 --- a/pyro/contrib/funsor/handlers/enum_messenger.py +++ b/pyro/contrib/funsor/handlers/enum_messenger.py @@ -39,6 +39,13 @@ def _get_support_value_contraction(funsor_dist, name, **kwargs): for v in funsor_dist.terms if isinstance(v, funsor.delta.Delta) and name in v.fresh ] + delta_terms.extend( + [ + v + for v in funsor_dist.terms + if isinstance(v, funsor.Approximate) and name in v.model.fresh + ] + ) assert len(delta_terms) == 1 return _get_support_value(delta_terms[0], name, **kwargs) @@ -66,6 +73,11 @@ def _get_support_value_constant(funsor_dist, name, **kwargs): return funsor.Constant(funsor_dist.const_inputs, value) +@_get_support_value.register(funsor.terms.Approximate) +def _get_support_value_approximate(funsor_dist, name, **kwargs): + return _get_support_value(funsor_dist.guide, name, **kwargs) + + @_get_support_value.register(funsor.distribution.Distribution) def _get_support_value_distribution(funsor_dist, name, expand=False): assert name == funsor_dist.value.name @@ -218,6 +230,8 @@ def _pyro_sample(self, msg): expand=msg["infer"].get("expand", False), ) # TODO delegate to _get_support_value + if isinstance(support_value, funsor.Constant): + support_value = support_value.arg msg["funsor"]["value"] = funsor.Constant( OrderedDict(((msg["name"], support_value.output),)), support_value, diff --git a/pyro/contrib/funsor/infer/trace_elbo.py b/pyro/contrib/funsor/infer/trace_elbo.py index b79948d973..22fcb7e24a 100644 --- a/pyro/contrib/funsor/infer/trace_elbo.py +++ b/pyro/contrib/funsor/infer/trace_elbo.py @@ -6,6 +6,7 @@ import funsor from funsor.adjoint import adjoint from funsor.constant import Constant +from funsor.interpreter import reinterpret from pyro.contrib.funsor import to_data, to_funsor from pyro.contrib.funsor.handlers import enum, plate, provenance, replay, trace @@ -94,6 +95,9 @@ def differentiable_loss(self, model, guide, *args, **kwargs): logzq = funsor.optimizer.apply_optimizer(logzq) marginals = adjoint(funsor.ops.logaddexp, funsor.ops.add, logzq) + with funsor.montecarlo.MonteCarlo(): + for target, log_measure in marginals.items(): + marginals[target] = reinterpret(log_measure) with funsor.terms.lazy: # finally, integrate out guide variables in the elbo and all plates diff --git a/setup.py b/setup.py index 5057ff1873..0c49c1d60c 100644 --- a/setup.py +++ b/setup.py @@ -135,7 +135,7 @@ "horovod": ["horovod[pytorch]>=0.19"], "funsor": [ # This must be a released version when Pyro is released. 
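Across these patches the support-value lookup is extended one funsor type at a time through single-dispatch registration (the @_get_support_value.register(...) decorators above). The pattern in miniature, using standard-library functools and a hypothetical stand-in class rather than funsor's real types:

    from functools import singledispatch

    @singledispatch
    def get_support_value(term, name):
        raise ValueError(f"unsupported term type: {type(term)}")

    class Delta:  # hypothetical stand-in for funsor.delta.Delta
        def __init__(self, terms):
            self.terms = terms  # maps a site name to its support value

    @get_support_value.register(Delta)
    def _(term, name):
        return term.terms[name]

    print(get_support_value(Delta({"z": 1.0}), "z"))  # 1.0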
- "funsor[torch] @ git+git://github.com/pyro-ppl/funsor.git@7bb52d0eae3046d08a20d1b288544e1a21b4f461", + "funsor[torch] @ git+git://github.com/pyro-ppl/funsor.git@importance-funsor", # "funsor[torch]==0.4.1", ], }, From b5bece794ed9fc9fac474028df845abdee1c42ad Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Wed, 3 Nov 2021 23:40:20 -0400 Subject: [PATCH 44/53] Importance funsor based approach --- pyro/contrib/funsor/handlers/enum_messenger.py | 18 +++++++++++------- pyro/contrib/funsor/infer/trace_elbo.py | 11 +++++++---- 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/pyro/contrib/funsor/handlers/enum_messenger.py b/pyro/contrib/funsor/handlers/enum_messenger.py index 0364235a45..4654239609 100644 --- a/pyro/contrib/funsor/handlers/enum_messenger.py +++ b/pyro/contrib/funsor/handlers/enum_messenger.py @@ -11,6 +11,7 @@ import funsor import torch +from funsor.importance import lazy_importance import pyro.poutine.runtime import pyro.poutine.util @@ -41,9 +42,9 @@ def _get_support_value_contraction(funsor_dist, name, **kwargs): ] delta_terms.extend( [ - v + v.guide for v in funsor_dist.terms - if isinstance(v, funsor.Approximate) and name in v.model.fresh + if isinstance(v, funsor.importance.Importance) and name in v.guide.fresh ] ) assert len(delta_terms) == 1 @@ -53,7 +54,11 @@ def _get_support_value_contraction(funsor_dist, name, **kwargs): @_get_support_value.register(funsor.delta.Delta) def _get_support_value_delta(funsor_dist, name, **kwargs): assert name in funsor_dist.fresh - return OrderedDict(funsor_dist.terms)[name][0] + support_value = OrderedDict(funsor_dist.terms)[name][0] + # remove all upstream dependencies + while isinstance(support_value, funsor.Constant): + support_value = support_value.arg + return support_value @_get_support_value.register(funsor.Tensor) @@ -73,7 +78,7 @@ def _get_support_value_constant(funsor_dist, name, **kwargs): return funsor.Constant(funsor_dist.const_inputs, value) -@_get_support_value.register(funsor.terms.Approximate) +@_get_support_value.register(funsor.importance.Importance) def _get_support_value_approximate(funsor_dist, name, **kwargs): return _get_support_value(funsor_dist.guide, name, **kwargs) @@ -90,7 +95,8 @@ def _enum_strategy_default(dist, msg): for f in msg["cond_indep_stack"] if f.vectorized and f.name not in dist.inputs ) - sampled_dist = dist.sample(msg["name"], sample_inputs) + with lazy_importance: + sampled_dist = dist.sample(msg["name"], sample_inputs) sampled_dist -= sum([math.log(v.size) for v in sample_inputs.values()], 0) return sampled_dist @@ -230,8 +236,6 @@ def _pyro_sample(self, msg): expand=msg["infer"].get("expand", False), ) # TODO delegate to _get_support_value - if isinstance(support_value, funsor.Constant): - support_value = support_value.arg msg["funsor"]["value"] = funsor.Constant( OrderedDict(((msg["name"], support_value.output),)), support_value, diff --git a/pyro/contrib/funsor/infer/trace_elbo.py b/pyro/contrib/funsor/infer/trace_elbo.py index 22fcb7e24a..ea91bfc74b 100644 --- a/pyro/contrib/funsor/infer/trace_elbo.py +++ b/pyro/contrib/funsor/infer/trace_elbo.py @@ -6,6 +6,7 @@ import funsor from funsor.adjoint import adjoint from funsor.constant import Constant +from funsor.importance import lazy_importance from funsor.interpreter import reinterpret from pyro.contrib.funsor import to_data, to_funsor @@ -94,10 +95,12 @@ def differentiable_loss(self, model, guide, *args, **kwargs): ) logzq = funsor.optimizer.apply_optimizer(logzq) - marginals = adjoint(funsor.ops.logaddexp, 
funsor.ops.add, logzq) - with funsor.montecarlo.MonteCarlo(): - for target, log_measure in marginals.items(): - marginals[target] = reinterpret(log_measure) + with lazy_importance: + marginals = adjoint(funsor.ops.logaddexp, funsor.ops.add, logzq) + + # eagerly evaluate Importance to obtain `Delta+dice_factor`s + for target in targets.values(): + marginals[target] = reinterpret(marginals[target]) with funsor.terms.lazy: # finally, integrate out guide variables in the elbo and all plates From d6e246e0283f20f6d5d7d92056316da905ddb7b1 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Wed, 3 Nov 2021 23:48:06 -0400 Subject: [PATCH 45/53] fixes --- docs/requirements.txt | 2 +- pyro/contrib/funsor/handlers/enum_messenger.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/requirements.txt b/docs/requirements.txt index 50c69f0a7d..e38eba48db 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -7,4 +7,4 @@ observations>=0.1.4 opt_einsum>=2.3.2 pyro-api>=0.1.1 tqdm>=4.36 -funsor[torch] @ git+git://github.com/pyro-ppl/funsor.git@18a107ccce40de9e3e98db87d116789e0909ce7b +funsor[torch] @ git+git://github.com/pyro-ppl/funsor.git@importance-funsor diff --git a/pyro/contrib/funsor/handlers/enum_messenger.py b/pyro/contrib/funsor/handlers/enum_messenger.py index 4654239609..c70b4b2ced 100644 --- a/pyro/contrib/funsor/handlers/enum_messenger.py +++ b/pyro/contrib/funsor/handlers/enum_messenger.py @@ -44,7 +44,7 @@ def _get_support_value_contraction(funsor_dist, name, **kwargs): [ v.guide for v in funsor_dist.terms - if isinstance(v, funsor.importance.Importance) and name in v.guide.fresh + if isinstance(v, funsor.Importance) and name in v.guide.fresh ] ) assert len(delta_terms) == 1 @@ -78,7 +78,7 @@ def _get_support_value_constant(funsor_dist, name, **kwargs): return funsor.Constant(funsor_dist.const_inputs, value) -@_get_support_value.register(funsor.importance.Importance) +@_get_support_value.register(funsor.Importance) def _get_support_value_approximate(funsor_dist, name, **kwargs): return _get_support_value(funsor_dist.guide, name, **kwargs) From 714fd62a4be1759c9a3ca0a5fcb24fac78cdb633 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sat, 9 Apr 2022 16:40:41 -0400 Subject: [PATCH 46/53] fix funsor model enumeration --- pyro/contrib/funsor/infer/traceenum_elbo.py | 26 +++++++++++++++------ tests/contrib/funsor/test_enum_funsor.py | 11 --------- 2 files changed, 19 insertions(+), 18 deletions(-) diff --git a/pyro/contrib/funsor/infer/traceenum_elbo.py b/pyro/contrib/funsor/infer/traceenum_elbo.py index 7915f7e101..4655bc09ac 100644 --- a/pyro/contrib/funsor/infer/traceenum_elbo.py +++ b/pyro/contrib/funsor/infer/traceenum_elbo.py @@ -5,6 +5,7 @@ import funsor from funsor.adjoint import AdjointTape +from funsor.sum_product import _partition from pyro.contrib.funsor import to_data, to_funsor from pyro.contrib.funsor.handlers import enum, plate, replay, trace @@ -194,17 +195,28 @@ def differentiable_loss(self, model, guide, *args, **kwargs): contracted_factors.append(f) else: uncontracted_factors.append(f) + contracted_costs = [] # incorporate the effects of subsampling and handlers.scale through a common scale factor - contracted_costs = [ - model_terms["scale"] * f + for group_factors, group_vars in _partition( + model_terms["log_measures"] + contracted_factors, + model_terms["measure_vars"], + ): + group_factor_vars = frozenset().union( + *[f.inputs for f in group_factors] + ) + group_plates = model_terms["plate_vars"] & group_factor_vars + 
outermost_plates = frozenset.intersection( + *(frozenset(f.inputs) & group_plates for f in group_factors) + ) + elim_plates = group_plates - outermost_plates for f in funsor.sum_product.partial_sum_product( funsor.ops.logaddexp, funsor.ops.add, - model_terms["log_measures"] + contracted_factors, - plates=model_terms["plate_vars"], - eliminate=model_terms["measure_vars"], - ) - ] + group_factors, + plates=group_plates, + eliminate=group_vars | elim_plates, + ): + contracted_costs.append(model_terms["scale"] * f) # accumulate costs from model (logp) and guide (-logq) costs = contracted_costs + uncontracted_factors # model costs: logp diff --git a/tests/contrib/funsor/test_enum_funsor.py b/tests/contrib/funsor/test_enum_funsor.py index 1dd6915610..e8dbc60c69 100644 --- a/tests/contrib/funsor/test_enum_funsor.py +++ b/tests/contrib/funsor/test_enum_funsor.py @@ -264,7 +264,6 @@ def guide(): _check_loss_and_grads(hand_loss, auto_loss) -@pytest.mark.xfail(reason="https://github.com/pyro-ppl/pyro/issues/3046") @pytest.mark.parametrize("scale", [1, 10]) @pytest.mark.parametrize( "num_samples,num_masked", [(2, 2), (3, 2)], ids=["batch", "masked"] @@ -499,7 +498,6 @@ def hand_guide(data): _check_loss_and_grads(hand_loss, auto_loss) -@pytest.mark.xfail(reason="https://github.com/pyro-ppl/pyro/issues/3046") @pytest.mark.parametrize("scale", [1, 10]) @pytest.mark.parametrize( "outer_obs,inner_obs", [(False, True), (True, False), (True, True)] @@ -642,7 +640,6 @@ def guide_iplate(): _check_loss_and_grads(expected_loss, actual_loss) -@pytest.mark.xfail(reason="https://github.com/pyro-ppl/pyro/issues/3046") @pytest.mark.parametrize("enumerate1", ["parallel", "sequential"]) @pyroapi.pyro_backend(_PYRO_BACKEND) def test_elbo_enumerate_plate_6(enumerate1): @@ -705,7 +702,6 @@ def guide(): _check_loss_and_grads(expected_loss, actual_loss) -@pytest.mark.xfail(reason="https://github.com/pyro-ppl/pyro/issues/3046") @pytest.mark.parametrize("scale", [1, 10]) @pyroapi.pyro_backend(_PYRO_BACKEND) def test_elbo_enumerate_plate_7(scale): @@ -877,7 +873,6 @@ def guide(): _check_loss_and_grads(hand_loss, auto_loss) -@pytest.mark.xfail(reason="https://github.com/pyro-ppl/pyro/issues/3046") @pytest.mark.parametrize("scale", [1, 10]) @pyroapi.pyro_backend(_PYRO_BACKEND) def test_elbo_enumerate_plates_2(scale): @@ -934,7 +929,6 @@ def guide(): _check_loss_and_grads(hand_loss, auto_loss) -@pytest.mark.xfail(reason="https://github.com/pyro-ppl/pyro/issues/3046") @pytest.mark.parametrize("scale", [1, 10]) @pyroapi.pyro_backend(_PYRO_BACKEND) def test_elbo_enumerate_plates_3(scale): @@ -987,7 +981,6 @@ def guide(): _check_loss_and_grads(hand_loss, auto_loss) -@pytest.mark.xfail(reason="https://github.com/pyro-ppl/pyro/issues/3046") @pytest.mark.parametrize("scale", [1, 10]) @pyroapi.pyro_backend(_PYRO_BACKEND) def test_elbo_enumerate_plates_4(scale): @@ -1047,7 +1040,6 @@ def guide(data): _check_loss_and_grads(hand_loss, auto_loss) -@pytest.mark.xfail(reason="https://github.com/pyro-ppl/pyro/issues/3046") @pytest.mark.parametrize("scale", [1, 10]) @pyroapi.pyro_backend(_PYRO_BACKEND) def test_elbo_enumerate_plates_5(scale): @@ -1111,7 +1103,6 @@ def guide(): _check_loss_and_grads(hand_loss, auto_loss) -@pytest.mark.xfail(reason="https://github.com/pyro-ppl/pyro/issues/3046") @pytest.mark.parametrize("scale", [1, 10]) @pyroapi.pyro_backend(_PYRO_BACKEND) def test_elbo_enumerate_plates_6(scale): @@ -1250,7 +1241,6 @@ def guide(data): elbo.differentiable_loss(model_plate_plate, guide, data) 
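The differentiable_loss change in this patch first splits the log-measures and contracted factors into independent groups (via _partition) and eliminates only each group's own plates, which is what lets the previously xfailed plate tests pass. The grouping idea in miniature, with factors as hypothetical name/variable-set pairs rather than funsors:

    def partition(factors, sum_vars):
        # merge factors into connected components linked by shared sum variables
        groups = []
        for name, fvars in factors:
            merged = ([name], set(fvars) & sum_vars)
            for g in [g for g in groups if g[1] & merged[1]]:
                merged = (g[0] + merged[0], g[1] | merged[1])
                groups.remove(g)
            groups.append(merged)
        return groups

    print(partition([("f1", {"a", "i"}), ("f2", {"b"}), ("f3", {"a"})], {"a", "b"}))
    # [(['f2'], {'b'}), (['f1', 'f3'], {'a'})]: f1 and f3 share "a"; f2 stands alone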
-@pytest.mark.xfail(reason="https://github.com/pyro-ppl/pyro/issues/3046") @pytest.mark.parametrize("scale", [1, 10]) @pyroapi.pyro_backend(_PYRO_BACKEND) def test_elbo_enumerate_plates_7(scale): @@ -1403,7 +1393,6 @@ def guide(data): _check_loss_and_grads(loss_iplate_iplate, loss_plate_plate) -@pytest.mark.xfail(reason="https://github.com/pyro-ppl/pyro/issues/3046") @pytest.mark.parametrize("guide_scale", [1]) @pytest.mark.parametrize("model_scale", [1]) @pytest.mark.parametrize( From 29bad7a921ac9b2af8ffb8be211910fcf487a98d Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sun, 10 Apr 2022 23:32:38 -0400 Subject: [PATCH 47/53] use Sampled funsor --- .../contrib/funsor/handlers/enum_messenger.py | 56 +++++++++---------- .../funsor/handlers/named_messenger.py | 2 +- pyro/contrib/funsor/infer/trace_elbo.py | 35 +++++++----- setup.py | 5 +- 4 files changed, 48 insertions(+), 50 deletions(-) diff --git a/pyro/contrib/funsor/handlers/enum_messenger.py b/pyro/contrib/funsor/handlers/enum_messenger.py index c70b4b2ced..b23cc07bfc 100644 --- a/pyro/contrib/funsor/handlers/enum_messenger.py +++ b/pyro/contrib/funsor/handlers/enum_messenger.py @@ -11,7 +11,6 @@ import funsor import torch -from funsor.importance import lazy_importance import pyro.poutine.runtime import pyro.poutine.util @@ -40,13 +39,6 @@ def _get_support_value_contraction(funsor_dist, name, **kwargs): for v in funsor_dist.terms if isinstance(v, funsor.delta.Delta) and name in v.fresh ] - delta_terms.extend( - [ - v.guide - for v in funsor_dist.terms - if isinstance(v, funsor.Importance) and name in v.guide.fresh - ] - ) assert len(delta_terms) == 1 return _get_support_value(delta_terms[0], name, **kwargs) @@ -55,9 +47,6 @@ def _get_support_value_contraction(funsor_dist, name, **kwargs): def _get_support_value_delta(funsor_dist, name, **kwargs): assert name in funsor_dist.fresh support_value = OrderedDict(funsor_dist.terms)[name][0] - # remove all upstream dependencies - while isinstance(support_value, funsor.Constant): - support_value = support_value.arg return support_value @@ -71,16 +60,11 @@ def _get_support_value_tensor(funsor_dist, name, **kwargs): ) -@_get_support_value.register(funsor.Constant) -def _get_support_value_constant(funsor_dist, name, **kwargs): +@_get_support_value.register(funsor.Sampled) +def _get_support_value_sampled(funsor_dist, name, **kwargs): assert name in funsor_dist.inputs value = _get_support_value(funsor_dist.arg, name, **kwargs) - return funsor.Constant(funsor_dist.const_inputs, value) - - -@_get_support_value.register(funsor.Importance) -def _get_support_value_approximate(funsor_dist, name, **kwargs): - return _get_support_value(funsor_dist.guide, name, **kwargs) + return funsor.Sampled(funsor_dist.terms, value) @_get_support_value.register(funsor.distribution.Distribution) @@ -95,8 +79,7 @@ def _enum_strategy_default(dist, msg): for f in msg["cond_indep_stack"] if f.vectorized and f.name not in dist.inputs ) - with lazy_importance: - sampled_dist = dist.sample(msg["name"], sample_inputs) + sampled_dist = dist.sample(msg["name"], sample_inputs) sampled_dist -= sum([math.log(v.size) for v in sample_inputs.values()], 0) return sampled_dist @@ -227,19 +210,25 @@ def _pyro_sample(self, msg): value=msg["name"] ) # TODO delegate to enumerate_site - msg["funsor"]["log_measure"] = _enum_strategy_default( - unsampled_log_measure, msg - ) + _log_measure = _enum_strategy_default(unsampled_log_measure, msg) support_value = _get_support_value( - msg["funsor"]["log_measure"], + _log_measure, msg["name"], 
expand=msg["infer"].get("expand", False), ) # TODO delegate to _get_support_value - msg["funsor"]["value"] = funsor.Constant( - OrderedDict(((msg["name"], support_value.output),)), - support_value, - ) + if isinstance(_log_measure, funsor.Sampled): + msg["funsor"]["value"] = funsor.Sampled( + _log_measure.arg.terms, + support_value, + ) + msg["funsor"]["log_measure"] = _log_measure.arg + else: + msg["funsor"]["value"] = funsor.Sampled( + _log_measure.terms, + support_value, + ) + msg["funsor"]["log_measure"] = _log_measure msg["value"] = to_data(msg["funsor"]["value"]) msg["done"] = True @@ -265,9 +254,14 @@ def _pyro_sample(self, msg): unsampled_log_measure = to_funsor(msg["fn"], output=funsor.Real)( value=msg["name"] ) - msg["funsor"]["log_measure"] = enumerate_site(unsampled_log_measure, msg) + _log_measure = enumerate_site(unsampled_log_measure, msg) + if isinstance(_log_measure, funsor.Sampled): + msg["funsor"]["log_measure"] = _log_measure.arg + else: + msg["funsor"]["log_measure"] = _log_measure + # msg["funsor"]["log_measure"] = enumerate_site(unsampled_log_measure, msg) msg["funsor"]["value"] = _get_support_value( - msg["funsor"]["log_measure"], + _log_measure, msg["name"], expand=msg["infer"].get("expand", False), ) diff --git a/pyro/contrib/funsor/handlers/named_messenger.py b/pyro/contrib/funsor/handlers/named_messenger.py index 966b94f545..ffa3d5863e 100644 --- a/pyro/contrib/funsor/handlers/named_messenger.py +++ b/pyro/contrib/funsor/handlers/named_messenger.py @@ -66,7 +66,7 @@ def _pyro_to_data(msg): name_to_dim = msg["kwargs"].setdefault("name_to_dim", OrderedDict()) dim_type = msg["kwargs"].setdefault("dim_type", DimType.LOCAL) - if isinstance(funsor_value, funsor.Constant): + if isinstance(funsor_value, funsor.Sampled): batch_names = tuple(funsor_value.arg.inputs.keys()) else: batch_names = tuple(funsor_value.inputs.keys()) diff --git a/pyro/contrib/funsor/infer/trace_elbo.py b/pyro/contrib/funsor/infer/trace_elbo.py index ea91bfc74b..ce366c71fb 100644 --- a/pyro/contrib/funsor/infer/trace_elbo.py +++ b/pyro/contrib/funsor/infer/trace_elbo.py @@ -6,8 +6,7 @@ import funsor from funsor.adjoint import adjoint from funsor.constant import Constant -from funsor.importance import lazy_importance -from funsor.interpreter import reinterpret +from funsor.sum_product import _partition from pyro.contrib.funsor import to_data, to_funsor from pyro.contrib.funsor.handlers import enum, plate, provenance, replay, trace @@ -55,17 +54,28 @@ def differentiable_loss(self, model, guide, *args, **kwargs): contracted_factors.append(f) else: uncontracted_factors.append(f) + contracted_costs = [] # incorporate the effects of subsampling and handlers.scale through a common scale factor - contracted_costs = [ - model_terms["scale"] * f + for group_factors, group_vars in _partition( + model_terms["log_measures"] + contracted_factors, + model_terms["measure_vars"], + ): + group_factor_vars = frozenset().union( + *[f.inputs for f in group_factors] + ) + group_plates = model_terms["plate_vars"] & group_factor_vars + outermost_plates = frozenset.intersection( + *(frozenset(f.inputs) & group_plates for f in group_factors) + ) + elim_plates = group_plates - outermost_plates for f in funsor.sum_product.partial_sum_product( funsor.ops.logaddexp, funsor.ops.add, - model_terms["log_measures"] + contracted_factors, - plates=plate_vars, - eliminate=model_measure_vars, - ) - ] + group_factors, + plates=group_plates, + eliminate=group_vars | elim_plates, + ): + contracted_costs.append(model_terms["scale"] * f) # 
From 9144be1a6d0217221e1ba9e4e6d9910eabd33a97 Mon Sep 17 00:00:00 2001
From: Yerdos Ordabayev
Date: Sun, 10 Apr 2022 23:45:36 -0400
Subject: [PATCH 48/53] fixes

---
 docs/requirements.txt                          | 2 +-
 pyro/contrib/funsor/handlers/enum_messenger.py | 4 +---
 tests/contrib/funsor/test_enum_funsor.py       | 1 -
 3 files changed, 2 insertions(+), 5 deletions(-)

diff --git a/docs/requirements.txt b/docs/requirements.txt
index a1bd25ba70..9e9356cf1c 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -6,4 +6,4 @@ observations>=0.1.4
 opt_einsum>=2.3.2
 pyro-api>=0.1.1
 tqdm>=4.36
-funsor[torch] @ git+git://github.com/pyro-ppl/funsor.git@importance-funsor
+funsor[torch] @ git+git://github.com/pyro-ppl/funsor.git@sampled-funsor
diff --git a/pyro/contrib/funsor/handlers/enum_messenger.py b/pyro/contrib/funsor/handlers/enum_messenger.py
index b23cc07bfc..704f1d71ee 100644
--- a/pyro/contrib/funsor/handlers/enum_messenger.py
+++ b/pyro/contrib/funsor/handlers/enum_messenger.py
@@ -46,8 +46,7 @@ def _get_support_value_contraction(funsor_dist, name, **kwargs):
 @_get_support_value.register(funsor.delta.Delta)
 def _get_support_value_delta(funsor_dist, name, **kwargs):
     assert name in funsor_dist.fresh
-    support_value = OrderedDict(funsor_dist.terms)[name][0]
-    return support_value
+    return OrderedDict(funsor_dist.terms)[name][0]
 
 
 @_get_support_value.register(funsor.Tensor)
@@ -259,7 +258,6 @@ def _pyro_sample(self, msg):
             msg["funsor"]["log_measure"] = _log_measure.arg
         else:
             msg["funsor"]["log_measure"] = _log_measure
-        # msg["funsor"]["log_measure"] = enumerate_site(unsampled_log_measure, msg)
         msg["funsor"]["value"] = _get_support_value(
             _log_measure,
             msg["name"],
diff --git a/tests/contrib/funsor/test_enum_funsor.py b/tests/contrib/funsor/test_enum_funsor.py
index 651152737f..e8dbc60c69 100644
--- a/tests/contrib/funsor/test_enum_funsor.py
+++ b/tests/contrib/funsor/test_enum_funsor.py
@@ -702,7 +702,6 @@ def guide():
     _check_loss_and_grads(expected_loss, actual_loss)
 
 
-@pytest.mark.skip(reason="wrong dependency structure?")
@pytest.mark.parametrize("scale", [1, 10])
 @pyroapi.pyro_backend(_PYRO_BACKEND)
 def test_elbo_enumerate_plate_7(scale):

From e4c8a4769c226503b5a3b9803d728db32c1790a8 Mon Sep 17 00:00:00 2001
From: Yerdos Ordabayev
Date: Sun, 10 Apr 2022 23:59:25 -0400
Subject: [PATCH 49/53] git fixes

---
 docs/requirements.txt | 2 +-
 setup.py              | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/docs/requirements.txt b/docs/requirements.txt
index 9e9356cf1c..ce14b39d4e 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -6,4 +6,4 @@ observations>=0.1.4
 opt_einsum>=2.3.2
 pyro-api>=0.1.1
 tqdm>=4.36
-funsor[torch] @ git+git://github.com/pyro-ppl/funsor.git@sampled-funsor
+funsor[torch] @ git+https://github.com/pyro-ppl/funsor.git@sampled-funsor
diff --git a/setup.py b/setup.py
index 49891035a0..c99e4d038b 100644
--- a/setup.py
+++ b/setup.py
@@ -139,7 +139,7 @@
         "horovod": ["horovod[pytorch]>=0.19"],
         "funsor": [
             # This must be a released version when Pyro is released.
-            "funsor[torch] @ git+git://github.com/pyro-ppl/funsor.git@sampled-funsor",
+            "funsor[torch] @ git+https://github.com/pyro-ppl/funsor.git@sampled-funsor",
             # "funsor[torch]==0.4.3",
         ],
     },
     python_requires=">=3.7",
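
Note on [PATCH 48] above: the simplified Delta handler leans on the layout of funsor's Delta terms, which (per the code being edited) are (name, (point, log_density)) pairs, so the support value for a site is just the point keyed by that site's name. A toy illustration of the lookup, with plain tuples standing in for real funsor terms:

    from collections import OrderedDict

    terms = (("z", ("point_z", 0.0)), ("x", ("point_x", 0.0)))
    assert OrderedDict(terms)["z"][0] == "point_z"

[PATCH 49] itself only switches the funsor dependency from the git:// protocol, which GitHub no longer serves, to https://; it contains no code changes.
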
From 703a2fa4fb35849df6bc8143ad2b1c24237e0768 Mon Sep 17 00:00:00 2001
From: Yerdos Ordabayev
Date: Mon, 11 Apr 2022 17:38:31 -0400
Subject: [PATCH 50/53] use Provenance funsor

---
 .../contrib/funsor/handlers/enum_messenger.py | 22 +++++++++----------
 .../funsor/handlers/named_messenger.py        |  4 ++--
 2 files changed, 13 insertions(+), 13 deletions(-)

diff --git a/pyro/contrib/funsor/handlers/enum_messenger.py b/pyro/contrib/funsor/handlers/enum_messenger.py
index 704f1d71ee..d237cab1ba 100644
--- a/pyro/contrib/funsor/handlers/enum_messenger.py
+++ b/pyro/contrib/funsor/handlers/enum_messenger.py
@@ -59,11 +59,11 @@ def _get_support_value_tensor(funsor_dist, name, **kwargs):
     )
 
 
-@_get_support_value.register(funsor.Sampled)
+@_get_support_value.register(funsor.Provenance)
 def _get_support_value_sampled(funsor_dist, name, **kwargs):
     assert name in funsor_dist.inputs
-    value = _get_support_value(funsor_dist.arg, name, **kwargs)
-    return funsor.Sampled(funsor_dist.terms, value)
+    value = _get_support_value(funsor_dist.term, name, **kwargs)
+    return funsor.Provenance(value, frozenset(funsor_dist.provenance))
 
 
 @_get_support_value.register(funsor.distribution.Distribution)
@@ -216,16 +216,16 @@ def _pyro_sample(self, msg):
             expand=msg["infer"].get("expand", False),
         )
         # TODO delegate to _get_support_value
-        if isinstance(_log_measure, funsor.Sampled):
-            msg["funsor"]["value"] = funsor.Sampled(
-                _log_measure.arg.terms,
+        if isinstance(_log_measure, funsor.Provenance):
+            msg["funsor"]["value"] = funsor.Provenance(
                 support_value,
+                frozenset(_log_measure.term.terms),
             )
-            msg["funsor"]["log_measure"] = _log_measure.arg
+            msg["funsor"]["log_measure"] = _log_measure.term
         else:
-            msg["funsor"]["value"] = funsor.Sampled(
-                _log_measure.terms,
+            msg["funsor"]["value"] = funsor.Provenance(
                 support_value,
+                frozenset(_log_measure.terms),
             )
             msg["funsor"]["log_measure"] = _log_measure
         msg["value"] = to_data(msg["funsor"]["value"])
         msg["done"] = True
@@ -254,8 +254,8 @@ def _pyro_sample(self, msg):
             value=msg["name"]
         )
         _log_measure = enumerate_site(unsampled_log_measure, msg)
-        if isinstance(_log_measure, funsor.Sampled):
-            msg["funsor"]["log_measure"] = _log_measure.arg
+        if isinstance(_log_measure, funsor.Provenance):
+            msg["funsor"]["log_measure"] = _log_measure.term
         else:
             msg["funsor"]["log_measure"] = _log_measure
         msg["funsor"]["value"] = _get_support_value(
             _log_measure,
             msg["name"],
             expand=msg["infer"].get("expand", False),
         )
diff --git a/pyro/contrib/funsor/handlers/named_messenger.py b/pyro/contrib/funsor/handlers/named_messenger.py
index ffa3d5863e..0394584aa1 100644
--- a/pyro/contrib/funsor/handlers/named_messenger.py
+++ b/pyro/contrib/funsor/handlers/named_messenger.py
@@ -66,7 +66,7 @@ def _pyro_to_data(msg):
     name_to_dim = msg["kwargs"].setdefault("name_to_dim", OrderedDict())
     dim_type = msg["kwargs"].setdefault("dim_type", DimType.LOCAL)
 
-    if isinstance(funsor_value, funsor.Sampled):
-        batch_names = tuple(funsor_value.arg.inputs.keys())
+    if isinstance(funsor_value, funsor.Provenance):
+        batch_names = tuple(funsor_value.term.inputs.keys())
     else:
         batch_names = tuple(funsor_value.inputs.keys())
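
Note on [PATCH 50] above: funsor.Sampled is replaced by funsor.Provenance, whose payload is read as .term (the wrapped funsor) and .provenance (a frozenset of (name, value) pairs). The patch repeats the same unwrap-if-wrapped pattern at each site; a toy sketch of that pattern (ToyProvenance is a stand-in class, not funsor's):

    from typing import Any, FrozenSet, NamedTuple

    class ToyProvenance(NamedTuple):
        term: Any
        provenance: FrozenSet

    def unwrap(x):
        # mirror the isinstance checks repeated in the hunks above
        if isinstance(x, ToyProvenance):
            return x.term, x.provenance
        return x, frozenset()

    print(unwrap(ToyProvenance("log_measure", frozenset({("z", "z_value")}))))
    print(unwrap("plain_funsor"))

[PATCH 51] below removes this duplication by routing the unwrapping through a single dispatch point.
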
From 3137b1b10642430c4521952e999b7675ddfe3ecc Mon Sep 17 00:00:00 2001
From: Yerdos Ordabayev
Date: Mon, 11 Apr 2022 23:42:58 -0400
Subject: [PATCH 51/53] clean up

---
 .../contrib/funsor/handlers/enum_messenger.py | 40 +++++++++----------
 pyro/contrib/funsor/infer/trace_elbo.py       |  1 -
 2 files changed, 18 insertions(+), 23 deletions(-)

diff --git a/pyro/contrib/funsor/handlers/enum_messenger.py b/pyro/contrib/funsor/handlers/enum_messenger.py
index d237cab1ba..ad05b3de60 100644
--- a/pyro/contrib/funsor/handlers/enum_messenger.py
+++ b/pyro/contrib/funsor/handlers/enum_messenger.py
@@ -18,6 +18,7 @@
 from pyro.contrib.funsor.handlers.primitives import to_data, to_funsor
 from pyro.contrib.funsor.handlers.replay_messenger import ReplayMessenger
 from pyro.contrib.funsor.handlers.trace_messenger import TraceMessenger
+from pyro.ops.provenance import detach_provenance, extract_provenance
 from pyro.poutine.escape_messenger import EscapeMessenger
 from pyro.poutine.reentrant_messenger import ReentrantMessenger
 from pyro.poutine.subsample_messenger import _Subsample
@@ -63,7 +64,7 @@ def _get_support_value_tensor(funsor_dist, name, **kwargs):
 def _get_support_value_sampled(funsor_dist, name, **kwargs):
     assert name in funsor_dist.inputs
     value = _get_support_value(funsor_dist.term, name, **kwargs)
-    return funsor.Provenance(value, frozenset(funsor_dist.provenance))
+    return funsor.Provenance(value, funsor_dist.provenance)
 
 
 @_get_support_value.register(funsor.distribution.Distribution)
@@ -187,9 +188,14 @@ def enumerate_site(dist, msg):
     raise ValueError("{} not valid enum strategy".format(msg))
 
 
+@extract_provenance.register(funsor.Provenance)
+def _extract_provenance_funsor(x):
+    return x.term, x.provenance
+
+
 class ProvenanceMessenger(ReentrantMessenger):
     """
-    Adds provenance dim for all sample sites that are not enumerated.
+    Adds provenance information for all sample sites that are not enumerated.
     """

     def _pyro_sample(self, msg):
@@ -209,25 +215,18 @@ def _pyro_sample(self, msg):
             value=msg["name"]
         )
         # TODO delegate to enumerate_site
-        _log_measure = _enum_strategy_default(unsampled_log_measure, msg)
+        log_measure = _enum_strategy_default(unsampled_log_measure, msg)
+        msg["funsor"]["log_measure"] = detach_provenance(log_measure)
         support_value = _get_support_value(
-            _log_measure,
+            log_measure,
             msg["name"],
             expand=msg["infer"].get("expand", False),
         )
         # TODO delegate to _get_support_value
-        if isinstance(_log_measure, funsor.Provenance):
-            msg["funsor"]["value"] = funsor.Provenance(
-                support_value,
-                frozenset(_log_measure.term.terms),
-            )
-            msg["funsor"]["log_measure"] = _log_measure.term
-        else:
-            msg["funsor"]["value"] = funsor.Provenance(
-                support_value,
-                frozenset(_log_measure.terms),
-            )
-            msg["funsor"]["log_measure"] = _log_measure
+        msg["funsor"]["value"] = funsor.Provenance(
+            support_value,
+            frozenset(msg["funsor"]["log_measure"].terms),
+        )
         msg["value"] = to_data(msg["funsor"]["value"])
         msg["done"] = True
@@ -253,13 +252,10 @@ def _pyro_sample(self, msg):
         unsampled_log_measure = to_funsor(msg["fn"], output=funsor.Real)(
             value=msg["name"]
         )
-        _log_measure = enumerate_site(unsampled_log_measure, msg)
-        if isinstance(_log_measure, funsor.Provenance):
-            msg["funsor"]["log_measure"] = _log_measure.term
-        else:
-            msg["funsor"]["log_measure"] = _log_measure
+        log_measure = enumerate_site(unsampled_log_measure, msg)
+        msg["funsor"]["log_measure"] = detach_provenance(log_measure)
         msg["funsor"]["value"] = _get_support_value(
-            _log_measure,
+            log_measure,
             msg["name"],
             expand=msg["infer"].get("expand", False),
         )
diff --git a/pyro/contrib/funsor/infer/trace_elbo.py b/pyro/contrib/funsor/infer/trace_elbo.py
index ce366c71fb..f4facf882e 100644
--- a/pyro/contrib/funsor/infer/trace_elbo.py
+++ b/pyro/contrib/funsor/infer/trace_elbo.py
@@ -112,7 +112,6 @@ def differentiable_loss(self, model, guide, *args, **kwargs):
             elbo = to_funsor(0, output=funsor.Real)
             for cost in costs:
                 target = targets[cost.input_vars]
-                # FIXME account for normalization factor for unnormalized logzq
                 log_measure = marginals[target]
                 measure_vars = (frozenset(cost.inputs) - plate_vars) - particle_var
                 elbo_term = funsor.Integrate(
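
Note on [PATCH 51] above: registering a handler on pyro.ops.provenance's extract_provenance lets detach_provenance strip a funsor.Provenance wrapper without any isinstance checks at the call sites. A self-contained sketch of the singledispatch mechanics (the generic function and wrapper below are toys; pyro.ops.provenance defines its own):

    from functools import singledispatch

    @singledispatch
    def extract_provenance(x):
        return x, frozenset()  # default: no provenance attached

    class Wrapper:
        def __init__(self, term, provenance):
            self.term, self.provenance = term, provenance

    @extract_provenance.register(Wrapper)
    def _(x):
        return x.term, x.provenance

    def detach_provenance(x):
        return extract_provenance(x)[0]  # keep the value, drop the provenance

    print(detach_provenance(Wrapper(42, frozenset({("z", 0)}))))  # 42
    print(detach_provenance(42))                                  # 42
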
""" def _pyro_sample(self, msg): @@ -209,25 +215,18 @@ def _pyro_sample(self, msg): value=msg["name"] ) # TODO delegate to enumerate_site - _log_measure = _enum_strategy_default(unsampled_log_measure, msg) + log_measure = _enum_strategy_default(unsampled_log_measure, msg) + msg["funsor"]["log_measure"] = detach_provenance(log_measure) support_value = _get_support_value( - _log_measure, + log_measure, msg["name"], expand=msg["infer"].get("expand", False), ) # TODO delegate to _get_support_value - if isinstance(_log_measure, funsor.Provenance): - msg["funsor"]["value"] = funsor.Provenance( - support_value, - frozenset(_log_measure.term.terms), - ) - msg["funsor"]["log_measure"] = _log_measure.term - else: - msg["funsor"]["value"] = funsor.Provenance( - support_value, - frozenset(_log_measure.terms), - ) - msg["funsor"]["log_measure"] = _log_measure + msg["funsor"]["value"] = funsor.Provenance( + support_value, + frozenset(msg["funsor"]["log_measure"].terms), + ) msg["value"] = to_data(msg["funsor"]["value"]) msg["done"] = True @@ -253,13 +252,10 @@ def _pyro_sample(self, msg): unsampled_log_measure = to_funsor(msg["fn"], output=funsor.Real)( value=msg["name"] ) - _log_measure = enumerate_site(unsampled_log_measure, msg) - if isinstance(_log_measure, funsor.Provenance): - msg["funsor"]["log_measure"] = _log_measure.term - else: - msg["funsor"]["log_measure"] = _log_measure + log_measure = enumerate_site(unsampled_log_measure, msg) + msg["funsor"]["log_measure"] = detach_provenance(log_measure) msg["funsor"]["value"] = _get_support_value( - _log_measure, + log_measure, msg["name"], expand=msg["infer"].get("expand", False), ) diff --git a/pyro/contrib/funsor/infer/trace_elbo.py b/pyro/contrib/funsor/infer/trace_elbo.py index ce366c71fb..f4facf882e 100644 --- a/pyro/contrib/funsor/infer/trace_elbo.py +++ b/pyro/contrib/funsor/infer/trace_elbo.py @@ -112,7 +112,6 @@ def differentiable_loss(self, model, guide, *args, **kwargs): elbo = to_funsor(0, output=funsor.Real) for cost in costs: target = targets[cost.input_vars] - # FIXME account for normalization factor for unnormalized logzq log_measure = marginals[target] measure_vars = (frozenset(cost.inputs) - plate_vars) - particle_var elbo_term = funsor.Integrate( From 88713f64e2b144dd4bf0b78941b8cd4b1b8ef662 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Thu, 5 May 2022 11:42:50 -0400 Subject: [PATCH 52/53] fixes --- pyro/contrib/funsor/handlers/enum_messenger.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyro/contrib/funsor/handlers/enum_messenger.py b/pyro/contrib/funsor/handlers/enum_messenger.py index ad05b3de60..ca7e9fdeeb 100644 --- a/pyro/contrib/funsor/handlers/enum_messenger.py +++ b/pyro/contrib/funsor/handlers/enum_messenger.py @@ -225,7 +225,7 @@ def _pyro_sample(self, msg): # TODO delegate to _get_support_value msg["funsor"]["value"] = funsor.Provenance( support_value, - frozenset(msg["funsor"]["log_measure"].terms), + frozenset([(msg["name"], detach_provenance(support_value))]), ) msg["value"] = to_data(msg["funsor"]["value"]) msg["done"] = True From 14131ade88fee79adb8b579a86cbc89b521efa1e Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Wed, 22 Jun 2022 07:53:42 -0400 Subject: [PATCH 53/53] use provenance --- pyro/contrib/funsor/infer/discrete.py | 2 +- pyro/contrib/funsor/infer/trace_elbo.py | 53 +++++++++------------ pyro/contrib/funsor/infer/traceenum_elbo.py | 12 ++--- pyro/contrib/funsor/infer/tracetmc_elbo.py | 4 +- 4 files changed, 33 insertions(+), 38 deletions(-) diff --git 
From 14131ade88fee79adb8b579a86cbc89b521efa1e Mon Sep 17 00:00:00 2001
From: Yerdos Ordabayev
Date: Wed, 22 Jun 2022 07:53:42 -0400
Subject: [PATCH 53/53] use provenance

---
 pyro/contrib/funsor/infer/discrete.py       |  2 +-
 pyro/contrib/funsor/infer/trace_elbo.py     | 53 +++++++++------------
 pyro/contrib/funsor/infer/traceenum_elbo.py | 12 ++---
 pyro/contrib/funsor/infer/tracetmc_elbo.py  |  4 +-
 4 files changed, 33 insertions(+), 38 deletions(-)

diff --git a/pyro/contrib/funsor/infer/discrete.py b/pyro/contrib/funsor/infer/discrete.py
index 3518882c16..c138a6613f 100644
--- a/pyro/contrib/funsor/infer/discrete.py
+++ b/pyro/contrib/funsor/infer/discrete.py
@@ -36,7 +36,7 @@ def _sample_posterior(model, first_available_dim, temperature, *args, **kwargs):
     log_prob = funsor.sum_product.sum_product(
         sum_op,
         prod_op,
-        terms["log_factors"] + terms["log_measures"],
+        terms["log_factors"] + list(terms["log_measures"].values()),
         eliminate=terms["measure_vars"] | terms["plate_vars"],
         plates=terms["plate_vars"],
     )
diff --git a/pyro/contrib/funsor/infer/trace_elbo.py b/pyro/contrib/funsor/infer/trace_elbo.py
index f4facf882e..01b6f71490 100644
--- a/pyro/contrib/funsor/infer/trace_elbo.py
+++ b/pyro/contrib/funsor/infer/trace_elbo.py
@@ -4,8 +4,6 @@
 import contextlib
 
 import funsor
-from funsor.adjoint import adjoint
-from funsor.constant import Constant
 from funsor.sum_product import _partition
 
 from pyro.contrib.funsor import to_data, to_funsor
@@ -55,7 +55,7 @@ def differentiable_loss(self, model, guide, *args, **kwargs):
         contracted_costs = []
         # incorporate the effects of subsampling and handlers.scale through a common scale factor
         for group_factors, group_vars in _partition(
-            model_terms["log_measures"] + contracted_factors,
+            list(model_terms["log_measures"].values()) + contracted_factors,
             model_terms["measure_vars"],
         ):
             group_factor_vars = frozenset().union(
@@ -81,38 +79,33 @@ def differentiable_loss(self, model, guide, *args, **kwargs):
         costs = contracted_costs + uncontracted_factors  # model costs: logp
         costs += [-f for f in guide_terms["log_factors"]]  # guide costs: -logq
 
-        # compute log_measures corresponding to each cost term
-        # the goal is to achieve fine-grained Rao-Blackwellization
-        targets = dict()
-        for cost in costs:
-            if cost.input_vars not in targets:
-                targets[cost.input_vars] = Constant(
-                    cost.inputs,
-                    funsor.Tensor(
-                        funsor.ops.new_zeros(
-                            funsor.tensor.get_default_prototype(),
-                            (),
-                        )
-                    ),
-                )
-
-        logzq = funsor.sum_product.sum_product(
-            funsor.ops.logaddexp,
-            funsor.ops.add,
-            guide_terms["log_measures"] + list(targets.values()),
-            plates=plate_vars,
-            eliminate=(plate_vars | guide_terms["measure_vars"]),
-        )
-        logzq = funsor.optimizer.apply_optimizer(logzq)
-
-        marginals = adjoint(funsor.ops.logaddexp, funsor.ops.add, logzq)
+        # compute log_measures corresponding to each cost term
+        # the goal is to achieve fine-grained Rao-Blackwellization
+        log_measures = dict()
+        for cost in costs:
+            if cost.input_vars not in log_measures:
+                log_probs = [
+                    f
+                    for name, f in guide_terms["log_measures"].items()
+                    if name in cost.inputs
+                ]
+                log_prob = funsor.sum_product.sum_product(
+                    funsor.ops.logaddexp,
+                    funsor.ops.add,
+                    log_probs,
+                    plates=plate_vars,
+                    eliminate=(plate_vars | guide_terms["measure_vars"])
+                    - frozenset(cost.inputs),
+                )
+                log_measures[cost.input_vars] = funsor.optimizer.apply_optimizer(
+                    log_prob
+                )
 
         with funsor.terms.lazy:
             # finally, integrate out guide variables in the elbo and all plates
             elbo = to_funsor(0, output=funsor.Real)
             for cost in costs:
-                target = targets[cost.input_vars]
-                log_measure = marginals[target]
+                log_measure = log_measures[cost.input_vars]
                 measure_vars = (frozenset(cost.inputs) - plate_vars) - particle_var
                 elbo_term = funsor.Integrate(
                     log_measure,
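
Note on the trace_elbo.py hunk above: instead of building zero-valued Constant targets and reading marginals back out of an adjoint pass, each cost term now gets its own Rao-Blackwellized log-measure, computed from only the guide sites the cost actually depends on and cached by the cost's set of input variables. A sketch of that selection, with plain sets and strings standing in for funsor terms:

    def select_log_measures(cost_input_sets, guide_log_measures):
        cache = {}
        for inputs in cost_input_sets:
            if inputs not in cache:
                # keep only measures for sites this cost depends on
                cache[inputs] = [
                    f for name, f in guide_log_measures.items() if name in inputs
                ]
        return cache

    guide_log_measures = {"z1": "logq_z1", "z2": "logq_z2"}
    cost_input_sets = [frozenset({"z1"}), frozenset({"z1", "z2"}), frozenset({"z1"})]
    print(select_log_measures(cost_input_sets, guide_log_measures))
    # two cache entries; the repeated frozenset({'z1'}) cost reuses the first

In the real code the cached value is a sum_product over the selected factors with everything outside the cost's inputs eliminated, followed by apply_optimizer.
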
diff --git a/pyro/contrib/funsor/infer/traceenum_elbo.py b/pyro/contrib/funsor/infer/traceenum_elbo.py
index 4655bc09ac..691c03c9eb 100644
--- a/pyro/contrib/funsor/infer/traceenum_elbo.py
+++ b/pyro/contrib/funsor/infer/traceenum_elbo.py
@@ -31,7 +31,7 @@ def terms_from_trace(tr):
     # of free variables as either product (plate) variables or sum (measure) variables
     terms = {
         "log_factors": [],
-        "log_measures": [],
+        "log_measures": {},
         "scale": to_funsor(1.0),
         "plate_vars": frozenset(),
         "measure_vars": frozenset(),
@@ -62,7 +62,7 @@ def terms_from_trace(tr):
         )
         # grab the log-measure, found only at sites that are not replayed or observed
         if node["funsor"].get("log_measure", None) is not None:
-            terms["log_measures"].append(node["funsor"]["log_measure"])
+            terms["log_measures"][name] = node["funsor"]["log_measure"]
         # sum (measure) variables: the fresh non-plate variables at a site
         terms["measure_vars"] |= (
             frozenset(node["funsor"]["value"].inputs) | {name}
@@ -132,7 +132,7 @@ def differentiable_loss(self, model, guide, *args, **kwargs):
             for f in funsor.sum_product.dynamic_partial_sum_product(
                 funsor.ops.logaddexp,
                 funsor.ops.add,
-                model_terms["log_measures"] + contracted_factors,
+                list(model_terms["log_measures"].values()) + contracted_factors,
                 plate_to_step=model_terms["plate_to_step"],
                 eliminate=model_terms["measure_vars"] | markov_dims,
             )
@@ -149,7 +149,7 @@ def differentiable_loss(self, model, guide, *args, **kwargs):
             log_prob = funsor.sum_product.sum_product(
                 funsor.ops.logaddexp,
                 funsor.ops.add,
-                guide_terms["log_measures"],
+                list(guide_terms["log_measures"].values()),
                 plates=plate_vars,
                 eliminate=(plate_vars | guide_terms["measure_vars"])
                 - frozenset(cost.inputs),
@@ -198,7 +198,7 @@ def differentiable_loss(self, model, guide, *args, **kwargs):
         contracted_costs = []
         # incorporate the effects of subsampling and handlers.scale through a common scale factor
         for group_factors, group_vars in _partition(
-            model_terms["log_measures"] + contracted_factors,
+            list(model_terms["log_measures"].values()) + contracted_factors,
             model_terms["measure_vars"],
         ):
             group_factor_vars = frozenset().union(
@@ -244,7 +244,7 @@ def differentiable_loss(self, model, guide, *args, **kwargs):
         logzq = funsor.sum_product.sum_product(
             funsor.ops.logaddexp,
             funsor.ops.add,
-            guide_terms["log_measures"] + list(targets.values()),
+            list(guide_terms["log_measures"].values()) + list(targets.values()),
             plates=plate_vars,
             eliminate=(plate_vars | guide_terms["measure_vars"]),
         )
diff --git a/pyro/contrib/funsor/infer/tracetmc_elbo.py b/pyro/contrib/funsor/infer/tracetmc_elbo.py
index 7cf4ba805e..e57b4db76a 100644
--- a/pyro/contrib/funsor/infer/tracetmc_elbo.py
+++ b/pyro/contrib/funsor/infer/tracetmc_elbo.py
@@ -29,7 +29,9 @@ def differentiable_loss(self, model, guide, *args, **kwargs):
         model_terms = terms_from_trace(model_tr)
         guide_terms = terms_from_trace(guide_tr)
 
-        log_measures = guide_terms["log_measures"] + model_terms["log_measures"]
+        log_measures = list(guide_terms["log_measures"].values()) + list(
+            model_terms["log_measures"].values()
+        )
         log_factors = model_terms["log_factors"] + [
             -f for f in guide_terms["log_factors"]
         ]
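
Note on the remaining hunks of [PATCH 53]: terms_from_trace now stores log_measures as a dict keyed by site name, which the per-cost selection in trace_elbo.py needs, so every consumer that previously concatenated lists is adapted with list(d.values()). A toy illustration of the adaptation:

    guide_log_measures = {"z": "logq_z"}   # dict: lookup by site name
    model_log_measures = {"x": "logp_x"}
    # flattened wherever an unordered collection of factors is expected
    log_measures = list(guide_log_measures.values()) + list(
        model_log_measures.values()
    )
    assert log_measures == ["logq_z", "logp_x"]
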