From 4676a721e3b155b6cc4c2f672f18c62c42c637d0 Mon Sep 17 00:00:00 2001
From: Fritz Obermeyer
Date: Fri, 15 Jan 2021 20:27:15 -0500
Subject: [PATCH 01/15] Implement dynamic creation of TransformOps

---
 funsor/domains.py             | 14 +++++++++++-
 funsor/ops/op.py              | 41 +++++++++++++++++++++++++++++++++++
 funsor/torch/distributions.py |  4 +++-
 test/test_distribution.py     | 14 ++++++++++++
 4 files changed, 71 insertions(+), 2 deletions(-)

diff --git a/funsor/domains.py b/funsor/domains.py
index da951a1ba..d11649043 100644
--- a/funsor/domains.py
+++ b/funsor/domains.py
@@ -192,7 +192,6 @@ def find_domain(op, *domains):
 
 
 @find_domain.register(ops.Op)  # TODO this is too general, register all ops
-@find_domain.register(ops.TransformOp)  # TODO too general, may be wrong for some
 @find_domain.register(ops.ReciprocalOp)
 @find_domain.register(ops.SigmoidOp)
 @find_domain.register(ops.TanhOp)
@@ -276,6 +275,19 @@ def _find_domain_associative_generic(op, *domains):
     return Array[dtype, shape]
 
 
+@find_domain.register(ops.TransformOp)
+def _transform_find_domain(op, domain):
+    fn = op.dispatch(object)
+    shape = fn.forward_shape(domain.shape)
+    return Array[domain.dtype, shape]
+
+
+@find_domain.register(ops.LogAbsDetJacobian)
+def _transform_log_abs_det_jacobian(op, domain, codomain):
+    # TODO do we need to handle batch shape here?
+    return Real
+
+
 __all__ = [
     'Bint',
     'BintType',
diff --git a/funsor/ops/op.py b/funsor/ops/op.py
index d790133f8..8d97d777a 100644
--- a/funsor/ops/op.py
+++ b/funsor/ops/op.py
@@ -114,6 +114,47 @@ def log_abs_det_jacobian(x, y):
         raise NotImplementedError
 
 
+class LogAbsDetJacobianOp(Op):
+    pass
+
+
+# TODO memoize or use weakrefs
+def make_transform_op(backend_transform):
+    name = backend_transform.__name__
+
+    # Check that the op is not batched.
+    if backend_transform.batch_shape:
+        raise ValueError("Cannot create an op from a transform "
+                         f"{name} with nontrivial batch shape "
+                         f"{backend_transform.batch_shape}.")
+
+    # Create four ops.
+    op = make_op(backend_transform, TransformOp, name=name)
+    op_ldaj = make_op(backend_transform.log_abs_det_jacobian,
+                      LogAbsDetJacobianOp,
+                      name=name + "_log_abs_det_jacobian")
+    inv = make_op(backend_transform.inv, TransformOp,
+                  name=name + "_inv")
+    inv_ldaj = make_op(backend_transform.log_abs_det_jacobian,
+                       LogAbsDetJacobianOp,
+                       name=name + "_inv_log_abs_det_jacobian")
+
+    # Register relationships.
+    op.set_inv(inv)
+    op.set_log_det_abs_jacobian(op_ldaj)
+    inv.set_inv(op)
+    inv.set_log_det_abs_jacobian(inv_ldaj)
+
+    # Register funsor conversions.
+    from funsor.terms import Binary, Funsor, Unary
+    op.register(Funsor)(functools.partial(Unary, op))
+    inv.register(Funsor)(functools.partial(Unary, inv))
+    op_ladj.register(Funsor, Funsor)(functools.partial(Binary, op_ladj))
+    inv_ladj.register(Funsor, Funsor)(functools.partial(Binary, ladj_ladj))
+
+    return op
+
+
 # Op registration tables.
 DISTRIBUTIVE_OPS = set()  # (add, mul) pairs
 UNITS = {}  # op -> value
diff --git a/funsor/torch/distributions.py b/funsor/torch/distributions.py
index 6fe1b3186..f1671ca08 100644
--- a/funsor/torch/distributions.py
+++ b/funsor/torch/distributions.py
@@ -269,7 +269,9 @@ def transform_to_data(expr, name_to_dim=None):
 
 @to_funsor.register(torch.distributions.Transform)
 def transform_to_funsor(tfm, output=None, dim_to_name=None, real_inputs=None):
-    raise NotImplementedError("{} is not a currently supported transform".format(tfm))
+    op = make_transform_op(tfm)
+    name = next(real_inputs.keys()) if real_inputs else "value"
+    return op(Variable(name, output))
 
 
 @to_funsor.register(torch.distributions.transforms.ExpTransform)
diff --git a/test/test_distribution.py b/test/test_distribution.py
index f79202b0b..fbe6fe3db 100644
--- a/test/test_distribution.py
+++ b/test/test_distribution.py
@@ -1215,3 +1215,17 @@ def test_categorical_event_dim_conversion(batch_shape, event_shape):
                                        funsor.to_data(data, name_to_dim=name_to_dim))
     assert actual_log_prob.shape == expected_log_prob.shape
     assert_close(actual_log_prob, expected_log_prob)
+
+
+@pytest.mark.parametrize("shape", [(10,), (4, 3)], ids=str)
+def test_haar_transform(shape):
+    d = backend_dist.TransformedDistribution(
+        backend_dist.Normal(0, 1).expand(shape),
+        backend_dist.transforms.HaarTransform(dim=-len(shape)))
+    data = ops.randn(shape)
+    expected_log_prob = d.log_prob(data)
+
+    f = to_funsor(d, output=Real)
+    log_prob = f(data)
+    actual_log_prob = funsor.to_data(log_prob)
+    assert_close(actual_log_prob, expected_log_prob)

From 32b0c2cba29b8093a2f0131823eb9cb708656ac5 Mon Sep 17 00:00:00 2001
From: Fritz Obermeyer
Date: Sun, 17 Jan 2021 09:17:30 -0500
Subject: [PATCH 02/15] Fixes

---
 funsor/distribution.py        |  1 +
 funsor/domains.py             |  2 +-
 funsor/ops/op.py              | 10 +++++++---
 funsor/torch/distributions.py | 12 ++++++------
 test/test_distribution.py     |  8 +++++---
 5 files changed, 20 insertions(+), 13 deletions(-)

diff --git a/funsor/distribution.py b/funsor/distribution.py
index d5a6561f2..ad509541f 100644
--- a/funsor/distribution.py
+++ b/funsor/distribution.py
@@ -402,6 +402,7 @@ def maskeddist_to_funsor(backend_dist, output=None, dim_to_name=None):
     return mask * funsor_base_dist
 
 
+# TODO make this work with transforms with nontrivial event_dim logic
 # converts TransformedDistributions
 def transformeddist_to_funsor(backend_dist, output=None, dim_to_name=None):
     dist_module = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()]).dist
diff --git a/funsor/domains.py b/funsor/domains.py
index d11649043..5a0557c02 100644
--- a/funsor/domains.py
+++ b/funsor/domains.py
@@ -282,7 +282,7 @@ def _transform_find_domain(op, domain):
     return Array[domain.dtype, shape]
 
 
-@find_domain.register(ops.LogAbsDetJacobian)
+@find_domain.register(ops.LogAbsDetJacobianOp)
 def _transform_log_abs_det_jacobian(op, domain, codomain):
     # TODO do we need to handle batch shape here?
     return Real
diff --git a/funsor/ops/op.py b/funsor/ops/op.py
index 8d97d777a..2dede06a3 100644
--- a/funsor/ops/op.py
+++ b/funsor/ops/op.py
@@ -1,6 +1,8 @@
 # Copyright Contributors to the Pyro project.
 # SPDX-License-Identifier: Apache-2.0
 
+import functools
+
 from multipledispatch import Dispatcher
 
 
@@ -120,7 +122,7 @@ class LogAbsDetJacobianOp(Op):
     pass
 
 
 # TODO memoize or use weakrefs
 def make_transform_op(backend_transform):
-    name = backend_transform.__name__
+    name = type(backend_transform).__name__
 
     # Check that the op is not batched.
     if backend_transform.batch_shape:
@@ -149,8 +151,8 @@ def make_transform_op(backend_transform):
     from funsor.terms import Binary, Funsor, Unary
     op.register(Funsor)(functools.partial(Unary, op))
     inv.register(Funsor)(functools.partial(Unary, inv))
-    op_ladj.register(Funsor, Funsor)(functools.partial(Binary, op_ladj))
-    inv_ladj.register(Funsor, Funsor)(functools.partial(Binary, ladj_ladj))
+    op_ldaj.register(Funsor, Funsor)(functools.partial(Binary, op_ldaj))
+    inv_ldaj.register(Funsor, Funsor)(functools.partial(Binary, inv_ldaj))
 
     return op
 
@@ -163,10 +165,12 @@ def make_transform_op(backend_transform):
 __all__ = [
     'CachedOpMeta',
     'DISTRIBUTIVE_OPS',
+    'LogAbsDetJacobianOp',
     'Op',
     'PRODUCT_INVERSES',
     'TransformOp',
     'UNITS',
     'declare_op_types',
     'make_op',
+    'make_transform_op',
 ]
diff --git a/funsor/torch/distributions.py b/funsor/torch/distributions.py
index f1671ca08..5a25c3085 100644
--- a/funsor/torch/distributions.py
+++ b/funsor/torch/distributions.py
@@ -7,13 +7,14 @@
 
 import pyro.distributions as dist
 import pyro.distributions.testing.fakes as fakes
-from pyro.distributions.torch_distribution import ExpandedDistribution, MaskedDistribution
 import torch
+from pyro.distributions.torch_distribution import ExpandedDistribution, MaskedDistribution
 
+import funsor.ops as ops
 from funsor.cnf import Contraction
 from funsor.distribution import (  # noqa: F401
-    Bernoulli,
     FUNSOR_DIST_NAMES,
+    Bernoulli,
     LogNormal,
     backenddist_to_funsor,
     eager_beta,
@@ -24,10 +25,10 @@
     eager_delta_funsor_funsor,
     eager_delta_funsor_variable,
     eager_delta_tensor,
+    eager_delta_variable_variable,
     eager_dirichlet_categorical,
     eager_dirichlet_multinomial,
     eager_dirichlet_posterior,
-    eager_delta_variable_variable,
     eager_gamma_gamma,
     eager_gamma_poisson,
     eager_multinomial,
@@ -38,15 +39,14 @@
     indepdist_to_funsor,
     make_dist,
     maskeddist_to_funsor,
-    transformeddist_to_funsor,
+    transformeddist_to_funsor
 )
 from funsor.domains import Real, Reals
-import funsor.ops as ops
+from funsor.ops.op import make_transform_op
 from funsor.tensor import Tensor
 from funsor.terms import Binary, Funsor, Reduce, Unary, Variable, eager, to_data, to_funsor
 from funsor.util import methodof
 
-
 __all__ = list(x[0] for x in FUNSOR_DIST_NAMES)
diff --git a/test/test_distribution.py b/test/test_distribution.py
index fbe6fe3db..69f16fe93 100644
--- a/test/test_distribution.py
+++ b/test/test_distribution.py
@@ -1220,12 +1220,14 @@ def test_categorical_event_dim_conversion(batch_shape, event_shape):
 @pytest.mark.parametrize("shape", [(10,), (4, 3)], ids=str)
 def test_haar_transform(shape):
     d = backend_dist.TransformedDistribution(
-        backend_dist.Normal(0, 1).expand(shape),
+        # backend_dist.Normal(0, 1).expand(shape),  # FIXME
+        backend_dist.Normal(0, 1).expand(shape).to_event(),
         backend_dist.transforms.HaarTransform(dim=-len(shape)))
-    data = ops.randn(shape)
+    data = randn(shape)
     expected_log_prob = d.log_prob(data)
 
-    f = to_funsor(d, output=Real)
+    # f = to_funsor(d, output=Real)  # FIXME
+    f = to_funsor(d, output=Real, dim_to_name={})
     log_prob = f(data)
     actual_log_prob = funsor.to_data(log_prob)
     assert_close(actual_log_prob, expected_log_prob)

From 81cf81681b57441b8814cf254b6ae5090160866b Mon Sep 17 00:00:00 2001
From: Fritz Obermeyer
Date: Sun, 17 Jan 2021 10:27:30 -0500
Subject: [PATCH 03/15] Get tests working on a Pyro branch

---
 funsor/distribution.py    |  2 ++
 funsor/ops/op.py          | 45 ++++++++++++++++++++++++++++-----------
 test/test_distribution.py | 41 +++++++++++++++++++++++++++++------
 3 files changed, 69 insertions(+), 19 deletions(-)

diff --git a/funsor/distribution.py b/funsor/distribution.py
index ad509541f..7d6054652 100644
--- a/funsor/distribution.py
+++ b/funsor/distribution.py
@@ -366,6 +366,8 @@ def backenddist_to_funsor(funsor_dist_class, backend_dist, output=None, dim_to_n
 
 
 def indepdist_to_funsor(backend_dist, output=None, dim_to_name=None):
+    if dim_to_name is None:
+        dim_to_name = {}
     dim_to_name = OrderedDict((dim - backend_dist.reinterpreted_batch_ndims, name)
                               for dim, name in dim_to_name.items())
     dim_to_name.update(OrderedDict((i, "_pyro_event_dim_{}".format(i))
diff --git a/funsor/ops/op.py b/funsor/ops/op.py
index 2dede06a3..04e7b3270 100644
--- a/funsor/ops/op.py
+++ b/funsor/ops/op.py
@@ -2,6 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import functools
+import warnings
 
 from multipledispatch import Dispatcher
 
@@ -17,6 +18,12 @@ def __call__(cls, *args, **kwargs):
             instance = super(CachedOpMeta, cls).__call__(*args, **kwargs)
             cls._cache[args] = instance
             return instance
+        except TypeError as e:
+            if str(e).startswith("unhashable type"):
+                warnings.warn(f"Cannot memoize Op ({e})")
+                instance = super(CachedOpMeta, cls).__call__(*args, **kwargs)
+                return instance
+            raise e from None
 
 
 class Op(Dispatcher):
@@ -54,6 +61,9 @@ def __str__(self):
 
 
 def make_op(fn=None, parent=None, *, name=None, module_name="funsor.ops"):
+    """
+    Factory to create a new :class:`Op` subclass and a new instance of that class.
+    """
     # Support use as decorator.
     if fn is None:
         return lambda fn: make_op(fn, parent, name=name, module_name=module_name)
@@ -66,10 +76,11 @@ def make_op(fn=None, parent=None, *, name=None, module_name="funsor.ops"):
         name = fn if isinstance(fn, str) else fn.__name__
     assert isinstance(name, str)
 
-    classname = name[0].upper() + name[1:].rstrip("_") + "Op"  # e.g. add -> AddOp
-    new_type = CachedOpMeta(classname, (parent,), {})
-    new_type.__module__ = module_name
-    return new_type(fn, name=name)
+    classname = name.capitalize().rstrip("_") + "Op"  # e.g. add -> AddOp
+    cls = CachedOpMeta(classname, (parent,), {})
+    cls.__module__ = module_name
+    op = cls(fn, name=name)
+    return op
 
 
 def declare_op_types(locals_, all_, name_):
@@ -120,15 +131,25 @@ class LogAbsDetJacobianOp(Op):
     pass
 
 
-# TODO memoize or use weakrefs
+# TODO Memoize or use weakrefs. This currently creates reference cycles,
+# delaying gc of backend_transform which may be a large expensive per-iteration
+# data structure such as a normalizing flow defined by a hyper net.
 def make_transform_op(backend_transform):
+    """
+    Factory to create a collection four new ops corresponding to a bijective
+    transform.
+
+    :returns: An ``op`` from which can be accessed ``op.inv``,
+        ``op.log_abs_det_jacobian``, and ``op.inv.log_abs_det_jacobian``.
+    :rtype: TransformOp
+    """
     name = type(backend_transform).__name__
 
-    # Check that the op is not batched.
-    if backend_transform.batch_shape:
-        raise ValueError("Cannot create an op from a transform "
-                         f"{name} with nontrivial batch shape "
-                         f"{backend_transform.batch_shape}.")
+    # TODO Check that the op is not batched. Something like:
+    # if backend_transform.batch_shape:
+    #     raise ValueError("Cannot create an op from a transform "
+    #                      f"{name} with nontrivial batch shape "
+    #                      f"{backend_transform.batch_shape}.")
 
     # Create four ops.
     op = make_op(backend_transform, TransformOp, name=name)
     op_ldaj = make_op(backend_transform.log_abs_det_jacobian,
                       LogAbsDetJacobianOp,
                       name=name + "_log_abs_det_jacobian")
     inv = make_op(backend_transform.inv, TransformOp,
                   name=name + "_inv")
     inv_ldaj = make_op(backend_transform.log_abs_det_jacobian,
                        LogAbsDetJacobianOp,
                        name=name + "_inv_log_abs_det_jacobian")
 
     # Register relationships.
     op.set_inv(inv)
-    op.set_log_det_abs_jacobian(op_ldaj)
+    op.set_log_abs_det_jacobian(op_ldaj)
     inv.set_inv(op)
-    inv.set_log_det_abs_jacobian(inv_ldaj)
+    inv.set_log_abs_det_jacobian(inv_ldaj)
 
     # Register funsor conversions.
     from funsor.terms import Binary, Funsor, Unary
diff --git a/test/test_distribution.py b/test/test_distribution.py
index 69f16fe93..7144c7795 100644
--- a/test/test_distribution.py
+++ b/test/test_distribution.py
@@ -1217,17 +1217,44 @@ def test_categorical_event_dim_conversion(batch_shape, event_shape):
     assert_close(actual_log_prob, expected_log_prob)
 
 
+@pytest.mark.parametrize("shape", [(), (4,), (3, 2)], ids=str)
+def test_power_transform(shape):
+    transform = backend_dist.transforms.PowerTransform(1.5)
+    base_dist = backend_dist.Exponential(1)
+    d = backend_dist.TransformedDistribution(base_dist, transform)
+
+    data = randn(shape).exp()
+    expected_log_prob = d.log_prob(data)
+
+    name_to_dim = dict(i=-3, j=-2, k=-1)
+    dim_to_name = {v: k for k, v in name_to_dim.items()}
+    x = to_funsor(data, Real, dim_to_name)
+    f = to_funsor(d, output=Real)
+    log_prob = f(x)
+    actual_log_prob = funsor.to_data(log_prob, name_to_dim)
+    assert_close(actual_log_prob, expected_log_prob)
+
+
 @pytest.mark.parametrize("shape", [(10,), (4, 3)], ids=str)
+@pytest.mark.parametrize("to_event", [
+    True,
+    xfail_param(False, reason="bug in to_funsor(TransformedDistribution)"),
+])
-def test_haar_transform(shape):
-    d = backend_dist.TransformedDistribution(
-        # backend_dist.Normal(0, 1).expand(shape),  # FIXME
-        backend_dist.Normal(0, 1).expand(shape).to_event(),
-        backend_dist.transforms.HaarTransform(dim=-len(shape)))
+def test_haar_transform(shape, to_event):
+    try:
+        transform = backend_dist.transforms.HaarTransform(dim=-len(shape))
+    except AttributeError:
+        pytest.xfail(reason="backend missing HaarTransform")
+    base_dist = backend_dist.Normal(0, 1).expand(shape)
+    if to_event:
+        # Work around a bug in to_funsor(TransformedDistribution)
+        base_dist = base_dist.to_event()
+    d = backend_dist.TransformedDistribution(base_dist, transform)
+
     data = randn(shape)
     expected_log_prob = d.log_prob(data)
 
-    # f = to_funsor(d, output=Real)  # FIXME
-    f = to_funsor(d, output=Real, dim_to_name={})
+    f = to_funsor(d, output=Real)
     log_prob = f(data)
     actual_log_prob = funsor.to_data(log_prob)
     assert_close(actual_log_prob, expected_log_prob)

From 7617c4168b7201363503e3bb710ab7abd7599839 Mon Sep 17 00:00:00 2001
From: Fritz Obermeyer
Date: Sun, 17 Jan 2021 11:11:58 -0500
Subject: [PATCH 04/15] Do not require hashability

---
 setup.cfg | 1 +
 1 file changed, 1 insertion(+)

diff --git a/setup.cfg b/setup.cfg
index b30e5bb06..b1775b96b 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -18,6 +18,7 @@ filterwarnings = error
     ignore:numpy.ufunc size changed:RuntimeWarning
     ignore:numpy.dtype size changed:RuntimeWarning
     ignore:Mixed memory format:UserWarning
+    ignore:Cannot memoize Op:UserWarning
    ignore::DeprecationWarning
    once::DeprecationWarning

From 8fba42492946c4ec4b415d8f6d4c8d7a63cd09cc Mon Sep 17 00:00:00 2001
From: Fritz Obermeyer
Date: Wed, 20 Jan 2021 12:28:59 -0500
Subject: [PATCH 05/15] Validate transform is not batched

---
 funsor/ops/builtin.py | 12 ++----------
 funsor/ops/op.py      | 31 +++++++++++++++++++++++--------
 2 files changed, 25 insertions(+), 18 deletions(-)

diff --git a/funsor/ops/builtin.py b/funsor/ops/builtin.py
index 70000704c..a2af21e92 100644
--- a/funsor/ops/builtin.py
+++ b/funsor/ops/builtin.py
@@ -137,11 +137,8 @@ def log_abs_det_jacobian(x, y):
 
 exp.set_inv(log)
 log.set_inv(exp)
-
-
-@tanh.set_inv
-def tanh_inv(y):
-    return atanh(y)
+tanh.set_inv(atanh)
+atanh.set_inv(tanh)
 
 
 @tanh.set_log_abs_det_jacobian
 def tanh_log_abs_det_jacobian(x, y):
     return 2. * (math.log(2.) - x - softplus(-2. * x))
 
 
-@atanh.set_inv
-def atanh_inv(y):
-    return tanh(y)
-
-
 @atanh.set_log_abs_det_jacobian
 def atanh_log_abs_det_jacobian(x, y):
     return -tanh.log_abs_det_jacobian(y, x)
diff --git a/funsor/ops/op.py b/funsor/ops/op.py
index e8da464ee..71b89b634 100644
--- a/funsor/ops/op.py
+++ b/funsor/ops/op.py
@@ -154,6 +154,27 @@ def log_abs_det_jacobian(x, y):
         raise NotImplementedError
 
 
+class ValidatedTransformOp(TransformOp):
+    """
+    Like :class:`TransformOp` but additionally checks that the backing tranforn
+    is not batched. This check is performed only on the first :meth:`__call__`.
+    """
+    def __init__(self, fn, *, name=None):
+        super().__init__(fn, name=name)
+        self._pending_validation = [fn]
+
+    def __call__(self, x):
+        while self._pending_validation:
+            fn = self._pending_validation.pop()
+            if len(x.shape) < fn.domain.event_dim:
+                raise ValueError(f"Too few dimensions for input, in {self.name}")
+            event_shape = x.shape[len(x.shape) - fn.domain.event_dim:]
+            shape = fn.forward_shape(event_shape)
+            if len(shape) > fn.codomain.event_dim:
+                raise ValueError(f"Cannot treat batched transform {self.name} as an Op")
+        return super().__call__(x)
+
+
 class LogAbsDetJacobianOp(Op):
     pass
 
@@ -172,18 +193,12 @@ def make_transform_op(backend_transform):
     """
     name = type(backend_transform).__name__
 
-    # TODO Check that the op is not batched. Something like:
-    # if backend_transform.batch_shape:
-    #     raise ValueError("Cannot create an op from a transform "
-    #                      f"{name} with nontrivial batch shape "
-    #                      f"{backend_transform.batch_shape}.")
-
     # Create four ops.
-    op = make_op(backend_transform, TransformOp, name=name)
+    op = make_op(backend_transform, ValidatedTransformOp, name=name)
     op_ldaj = make_op(backend_transform.log_abs_det_jacobian,
                       LogAbsDetJacobianOp,
                       name=name + "_log_abs_det_jacobian")
-    inv = make_op(backend_transform.inv, TransformOp,
+    inv = make_op(backend_transform.inv, ValidatedTransformOp,
                   name=name + "_inv")
     inv_ldaj = make_op(backend_transform.log_abs_det_jacobian,
                        LogAbsDetJacobianOp,

From e13e2b8fffb51860e09d1c124022587f14abc492 Mon Sep 17 00:00:00 2001
From: Fritz Obermeyer
Date: Wed, 20 Jan 2021 12:40:04 -0500
Subject: [PATCH 06/15] Avoid import cycle

---
 funsor/ops/op.py | 7 -------
 funsor/terms.py  | 7 +++++--
 2 files changed, 5 insertions(+), 9 deletions(-)

diff --git a/funsor/ops/op.py b/funsor/ops/op.py
index 71b89b634..ae7e1f348 100644
--- a/funsor/ops/op.py
+++ b/funsor/ops/op.py
@@ -210,13 +210,6 @@ def make_transform_op(backend_transform):
     inv.set_inv(op)
     inv.set_log_abs_det_jacobian(inv_ldaj)
 
-    # Register funsor conversions.
-    from funsor.terms import Binary, Funsor, Unary
-    op.register(Funsor)(functools.partial(Unary, op))
-    inv.register(Funsor)(functools.partial(Unary, inv))
-    op_ldaj.register(Funsor, Funsor)(functools.partial(Binary, op_ldaj))
-    inv_ldaj.register(Funsor, Funsor)(functools.partial(Binary, inv_ldaj))
-
     return op
diff --git a/funsor/terms.py b/funsor/terms.py
index 7cf6d1f51..030e8dc74 100644
--- a/funsor/terms.py
+++ b/funsor/terms.py
@@ -1703,14 +1703,17 @@ def quote_inplace_first_arg_on_first_line(arg, indent, out):
 
 ops.UnaryOp.subclass_register(Funsor)(Unary)
 ops.AssociativeOp.subclass_register(Funsor)(Unary)  # Reductions.
 ops.AssociativeOp.subclass_register(Funsor, Funsor)(Binary)
+ops.LogAbsDetJacobianOp.subclass_register(Funsor, Funsor)(Binary)
 
 
-@AssociativeOp.subclass_register(object, Funsor)
+@ops.AssociativeOp.subclass_register(object, Funsor)
+@ops.LogAbsDetJacobianOp.subclass_register(object, Funsor)
 def binary_object_funsor(op, x, y):
     return Binary(op, to_funsor(x), y)
 
 
-@AssociativeOp.subclass_register(Funsor, object)
+@ops.AssociativeOp.subclass_register(Funsor, object)
+@ops.LogAbsDetJacobianOp.subclass_register(Funsor, object)
 def binary_funsor_object(op, x, y):
     return Binary(op, x, to_funsor(y))

From cdb7e7ea879873169a20111757286ecd18c6aebe Mon Sep 17 00:00:00 2001
From: Fritz Obermeyer
Date: Wed, 20 Jan 2021 12:48:09 -0500
Subject: [PATCH 07/15] Fix typo

---
 funsor/ops/op.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/funsor/ops/op.py b/funsor/ops/op.py
index ae7e1f348..55ef93394 100644
--- a/funsor/ops/op.py
+++ b/funsor/ops/op.py
@@ -156,8 +156,9 @@ def log_abs_det_jacobian(x, y):
 
 class ValidatedTransformOp(TransformOp):
     """
-    Like :class:`TransformOp` but additionally checks that the backing tranforn
-    is not batched. This check is performed only on the first :meth:`__call__`.
+    Like :class:`TransformOp` but additionally checks that the backing
+    transform is not batched. This check is performed only on the first
+    :meth:`__call__`.
     """
     def __init__(self, fn, *, name=None):
         super().__init__(fn, name=name)
@@ -171,7 +172,8 @@ def __call__(self, x):
             event_shape = x.shape[len(x.shape) - fn.domain.event_dim:]
             shape = fn.forward_shape(event_shape)
             if len(shape) > fn.codomain.event_dim:
-                raise ValueError(f"Cannot treat batched transform {self.name} as an Op")
+                raise ValueError(f"Cannot treat transform {self.name} as an Op "
+                                 "because it is batched")
         return super().__call__(x)

From 4418590b8de34f8d335ef15e3997cdb0632de079 Mon Sep 17 00:00:00 2001
From: Fritz Obermeyer
Date: Wed, 20 Jan 2021 14:49:45 -0500
Subject: [PATCH 08/15] Refactor to create a WrappedTransformOp

---
 funsor/ops/op.py              | 118 ++++++++++++++++------------------
 funsor/terms.py               |   8 +--
 funsor/torch/distributions.py |   3 +-
 3 files changed, 62 insertions(+), 67 deletions(-)

diff --git a/funsor/ops/op.py b/funsor/ops/op.py
index 55ef93394..3d3defe85 100644
--- a/funsor/ops/op.py
+++ b/funsor/ops/op.py
@@ -3,7 +3,6 @@
 
 import functools
 import inspect
-import warnings
 import weakref
 
 from multipledispatch import Dispatcher
@@ -12,20 +11,38 @@
 class CachedOpMeta(type):
     """
     Metaclass for caching op instance construction.
+    Caching strategy is to key on ``*args`` and retain values forever.
     """
+    def __init__(cls, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        cls._instance_cache = {}
+
     def __call__(cls, *args, **kwargs):
         try:
-            return cls._cache[args]
+            return cls._instance_cache[args]
         except KeyError:
             instance = super(CachedOpMeta, cls).__call__(*args, **kwargs)
-            cls._cache[args] = instance
+            cls._instance_cache[args] = instance
             return instance
-        except TypeError as e:
-            if str(e).startswith("unhashable type"):
-                warnings.warn(f"Cannot memoize Op ({e})")
-                instance = super(CachedOpMeta, cls).__call__(*args, **kwargs)
-                return instance
-            raise e from None
+
+
+class WrappedOpMeta(type):
+    """
+    Metaclass for ops that wrap temporary backend ops.
+    Caching strategy is to key on ``id(backend_op)`` and forget values asap.
+ """ + def __init__(cls, *args, **kwargs): + super().__init__(*args, **kwargs) + cls._instance_cache = weakref.WeakValueDictionary() + + def __call__(cls, fn): + try: + return cls._instance_cache[id(fn)] + except KeyError: + op = super().__call__(fn) + op.fn = fn # Ensures the key id(fn) is not reused. + cls._instance_cache[id(fn)] = op + return op class Op(Dispatcher): @@ -33,7 +50,6 @@ class Op(Dispatcher): def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) - cls._cache = {} cls._subclass_registry = [] def __init__(self, fn, *, name=None): @@ -100,7 +116,7 @@ def make_op(fn=None, parent=None, *, name=None, module_name="funsor.ops"): assert isinstance(name, str) classname = name.capitalize().rstrip("_") + "Op" # e.g. add -> AddOp - cls = CachedOpMeta(classname, (parent,), {}) + cls = type(classname, (parent,), {}) cls.__module__ = module_name op = cls(fn, name=name) return op @@ -126,6 +142,10 @@ class UnaryOp(Op): pass +class BinaryOp(Op): + pass + + class TransformOp(UnaryOp): def set_inv(self, fn): """ @@ -154,65 +174,40 @@ def log_abs_det_jacobian(x, y): raise NotImplementedError -class ValidatedTransformOp(TransformOp): +class WrappedTransformOp(TransformOp, metaclass=WrappedOpMeta): """ - Like :class:`TransformOp` but additionally checks that the backing - transform is not batched. This check is performed only on the first + Wrapper for a backend ``Transform`` object that provides ``.inv`` and + ``.log_abs_det_jacobian``. This additionally validates shapes on the first :meth:`__call__`. """ - def __init__(self, fn, *, name=None): - super().__init__(fn, name=name) - self._pending_validation = [fn] + def __init__(self, fn): + super().__init__(fn, name=type(fn).__name__) + self._is_validated = False def __call__(self, x): - while self._pending_validation: - fn = self._pending_validation.pop() - if len(x.shape) < fn.domain.event_dim: - raise ValueError(f"Too few dimensions for input, in {self.name}") - event_shape = x.shape[len(x.shape) - fn.domain.event_dim:] - shape = fn.forward_shape(event_shape) - if len(shape) > fn.codomain.event_dim: - raise ValueError(f"Cannot treat transform {self.name} as an Op " - "because it is batched") + if self._is_validated: + return super().__call__(x) + if len(x.shape) < self.fn.domain.event_dim: + raise ValueError(f"Too few dimensions for input, in {self.name}") + event_shape = x.shape[len(x.shape) - self.fn.domain.event_dim:] + shape = self.fn.forward_shape(event_shape) + if len(shape) > self.fn.codomain.event_dim: + raise ValueError(f"Cannot treat transform {self.name} as an Op " + "because it is batched") + self._is_validated = True return super().__call__(x) + @property + def inv(self): + return WrappedTransformOp(self.fn.inv) -class LogAbsDetJacobianOp(Op): - pass + @property + def log_abs_det_jacobian(self): + return LogAbsDetJacobianOp(self.fn.log_abs_det_jacobian) -# TODO Memoize or use weakrefs. This currently creates reference cycles, -# delaying gc of backend_transform which may be a large expensive per-iteration -# data structure such as a normalizing flow defined by a hyper net. -def make_transform_op(backend_transform): - """ - Factory to create a collection four new ops corresponding to a bijective - transform. - - :returns: An ``op`` from which can be accessed ``op.inv``, - ``op.log_abs_det_jacobian``, and ``op.inv.log_abs_det_jacobian``. - :rtype: TransformOp - """ - name = type(backend_transform).__name__ - - # Create four ops. 
-    op = make_op(backend_transform, ValidatedTransformOp, name=name)
-    op_ldaj = make_op(backend_transform.log_abs_det_jacobian,
-                      LogAbsDetJacobianOp,
-                      name=name + "_log_abs_det_jacobian")
-    inv = make_op(backend_transform.inv, ValidatedTransformOp,
-                  name=name + "_inv")
-    inv_ldaj = make_op(backend_transform.log_abs_det_jacobian,
-                       LogAbsDetJacobianOp,
-                       name=name + "_inv_log_abs_det_jacobian")
-
-    # Register relationships.
-    op.set_inv(inv)
-    op.set_log_abs_det_jacobian(op_ldaj)
-    inv.set_inv(op)
-    inv.set_log_abs_det_jacobian(inv_ldaj)
-
-    return op
+class LogAbsDetJacobianOp(BinaryOp, metaclass=WrappedOpMeta):
+    pass
 
 
 # Op registration tables.
@@ -221,6 +216,7 @@
 PRODUCT_INVERSES = {}  # op -> inverse op
 
 __all__ = [
+    'BinaryOp',
     'CachedOpMeta',
     'DISTRIBUTIVE_OPS',
     'LogAbsDetJacobianOp',
@@ -229,7 +225,7 @@
     'TransformOp',
     'UNITS',
     'UnaryOp',
+    'WrappedTransformOp',
     'declare_op_types',
     'make_op',
-    'make_transform_op',
 ]
diff --git a/funsor/terms.py b/funsor/terms.py
index 030e8dc74..e61ef4465 100644
--- a/funsor/terms.py
+++ b/funsor/terms.py
@@ -1701,19 +1701,19 @@ def quote_inplace_first_arg_on_first_line(arg, indent, out):
 
 ops.UnaryOp.subclass_register(Funsor)(Unary)
-ops.AssociativeOp.subclass_register(Funsor)(Unary)  # Reductions.
+ops.BinaryOp.subclass_register(Funsor, Funsor)(Binary)
 ops.AssociativeOp.subclass_register(Funsor, Funsor)(Binary)
-ops.LogAbsDetJacobianOp.subclass_register(Funsor, Funsor)(Binary)
+ops.AssociativeOp.subclass_register(Funsor)(Unary)  # Reductions.
 
 
+@ops.BinaryOp.subclass_register(object, Funsor)
 @ops.AssociativeOp.subclass_register(object, Funsor)
-@ops.LogAbsDetJacobianOp.subclass_register(object, Funsor)
 def binary_object_funsor(op, x, y):
     return Binary(op, to_funsor(x), y)
 
 
+@ops.BinaryOp.subclass_register(Funsor, object)
 @ops.AssociativeOp.subclass_register(Funsor, object)
-@ops.LogAbsDetJacobianOp.subclass_register(Funsor, object)
 def binary_funsor_object(op, x, y):
     return Binary(op, x, to_funsor(y))
diff --git a/funsor/torch/distributions.py b/funsor/torch/distributions.py
index 5a25c3085..9dfaa94f4 100644
--- a/funsor/torch/distributions.py
+++ b/funsor/torch/distributions.py
@@ -42,7 +42,6 @@
     transformeddist_to_funsor
 )
 from funsor.domains import Real, Reals
-from funsor.ops.op import make_transform_op
 from funsor.tensor import Tensor
 from funsor.terms import Binary, Funsor, Reduce, Unary, Variable, eager, to_data, to_funsor
 from funsor.util import methodof
@@ -269,7 +268,7 @@ def transform_to_data(expr, name_to_dim=None):
 
 @to_funsor.register(torch.distributions.Transform)
 def transform_to_funsor(tfm, output=None, dim_to_name=None, real_inputs=None):
-    op = make_transform_op(tfm)
+    op = ops.WrappedTransformOp(tfm)
     name = next(real_inputs.keys()) if real_inputs else "value"
     return op(Variable(name, output))

From b392c3919c4ec45a0b004e962f0add17cf823 Mon Sep 17 00:00:00 2001
From: Fritz Obermeyer
Date: Wed, 20 Jan 2021 15:30:56 -0500
Subject: [PATCH 09/15] Add more tests

---
 funsor/ops/op.py |  8 ++++++--
 test/test_ops.py | 44 ++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 50 insertions(+), 2 deletions(-)
 create mode 100644 test/test_ops.py

diff --git a/funsor/ops/op.py b/funsor/ops/op.py
index 3d3defe85..36b700420 100644
--- a/funsor/ops/op.py
+++ b/funsor/ops/op.py
@@ -36,12 +36,16 @@ def __init__(cls, *args, **kwargs):
         cls._instance_cache = weakref.WeakValueDictionary()
 
     def __call__(cls, fn):
+        if inspect.ismethod(fn):
+            key = id(fn.__self__), fn.__func__  # e.g. t.log_abs_det_jacobian
+        else:
+            key = id(fn)  # e.g. t.inv
         try:
-            return cls._instance_cache[id(fn)]
+            return cls._instance_cache[key]
         except KeyError:
             op = super().__call__(fn)
             op.fn = fn  # Ensures the key id(fn) is not reused.
-            cls._instance_cache[id(fn)] = op
+            cls._instance_cache[key] = op
             return op
diff --git a/test/test_ops.py b/test/test_ops.py
new file mode 100644
index 000000000..3dd21f4f0
--- /dev/null
+++ b/test/test_ops.py
@@ -0,0 +1,44 @@
+# Copyright Contributors to the Pyro project.
+# SPDX-License-Identifier: Apache-2.0
+
+import weakref
+
+import pytest
+
+from funsor.distribution import BACKEND_TO_DISTRIBUTIONS_BACKEND
+from funsor.util import get_backend
+from funsor.ops import WrappedTransformOp
+
+
+@pytest.fixture
+def dist():
+    try:
+        module_name = BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()]
+    except KeyError:
+        pytest.skip(f"missing distributions module for {get_backend()}")
+    return pytest.importorskip(module_name).dist
+
+
+def test_transform_op_cache(dist):
+    t = dist.transforms.PowerTransform(0.5)
+    W = WrappedTransformOp
+    assert W(t) is W(t)
+    assert W(t).inv is W(t).inv
+    assert W(t.inv) is W(t).inv
+    assert W(t).log_abs_det_jacobian is W(t).log_abs_det_jacobian
+
+
+@pytest.mark.xfail(reason="reference cycle")
+def test_transform_op_gc(dist):
+    op_set = weakref.WeakSet()
+
+    t = dist.transforms.PowerTransform(0.5)
+    op_set.add(WrappedTransformOp(t))
+    assert len(op_set) == 1
+    op_set.add(WrappedTransformOp(t))
+    assert len(op_set) == 1
+    del t
+    for op in op_set:
+        import sys
+        print(sys.getrefcount(op), sys.getrefcount(op.fn))
+    assert len(op_set) == 0

From dc8b1d6413a2d842884d596150fc77a241cc49aa Mon Sep 17 00:00:00 2001
From: Fritz Obermeyer
Date: Wed, 20 Jan 2021 16:30:08 -0500
Subject: [PATCH 10/15] Eliminate reference cycles

---
 funsor/ops/op.py | 18 ++++++++++++++++--
 test/test_ops.py | 14 ++++----------
 2 files changed, 20 insertions(+), 12 deletions(-)

diff --git a/funsor/ops/op.py b/funsor/ops/op.py
index 36b700420..06c251af6 100644
--- a/funsor/ops/op.py
+++ b/funsor/ops/op.py
@@ -8,6 +8,20 @@
 from multipledispatch import Dispatcher
 
 
+class WeakPartial:
+    """
+    Like ``functools.partial(fn, arg)`` but weakly referencing ``arg``.
+    """
+    def __init__(self, fn, arg):
+        self.fn = fn
+        self.weak_arg = weakref.ref(arg)
+        functools.update_wrapper(self, fn)
+
+    def __call__(self, *args):
+        arg = self.weak_arg()
+        return self.fn(arg, *args)
+
+
 class CachedOpMeta(type):
     """
     Metaclass for caching op instance construction.
@@ -71,7 +85,7 @@ def __init__(self, fn, *, name=None):
         # Register all existing patterns.
         for supercls in reversed(inspect.getmro(type(self))):
             for pattern, fn in getattr(supercls, "_subclass_registry", ()):
-                self.register(*pattern)(functools.partial(fn, self))
+                self.add(pattern, WeakPartial(fn, self))
 
         # Save self for registering future patterns.
         Op._all_instances.add(self)
@@ -96,7 +110,7 @@ def decorator(fn):
             # Register with all existing instances.
             for op in Op._all_instances:
                 if isinstance(op, cls):
-                    op.register(*pattern)(functools.partial(fn, op))
+                    op.add(pattern, WeakPartial(fn, op))
             # Ensure registration with all future instances.
             cls._subclass_registry.append((pattern, fn))
             return fn
diff --git a/test/test_ops.py b/test/test_ops.py
index 3dd21f4f0..9db97165b 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -28,17 +28,11 @@ def test_transform_op_cache(dist):
     assert W(t).log_abs_det_jacobian is W(t).log_abs_det_jacobian
 
 
-@pytest.mark.xfail(reason="reference cycle")
 def test_transform_op_gc(dist):
-    op_set = weakref.WeakSet()
-
     t = dist.transforms.PowerTransform(0.5)
-    op_set.add(WrappedTransformOp(t))
-    assert len(op_set) == 1
-    op_set.add(WrappedTransformOp(t))
+    op = WrappedTransformOp(t)
+    op_set = weakref.WeakSet()
+    op_set.add(op)
     assert len(op_set) == 1
-    del t
-    for op in op_set:
-        import sys
-        print(sys.getrefcount(op), sys.getrefcount(op.fn))
+    del op
     assert len(op_set) == 0

From a5c747a7ee96d193df8f40a4f76404fc06985ca2 Mon Sep 17 00:00:00 2001
From: Fritz Obermeyer
Date: Thu, 21 Jan 2021 10:09:25 -0500
Subject: [PATCH 11/15] Raise NotImplementedError on missing shape metadata

---
 funsor/ops/op.py                  | 13 +++++++++++++
 test/test_distribution.py         | 16 ++++++++++++++--
 test/test_distribution_generic.py |  2 +-
 3 files changed, 28 insertions(+), 3 deletions(-)

diff --git a/funsor/ops/op.py b/funsor/ops/op.py
index 06c251af6..f91cafafe 100644
--- a/funsor/ops/op.py
+++ b/funsor/ops/op.py
@@ -205,6 +205,19 @@ def __call__(self, x):
         if self._is_validated:
             return super().__call__(x)
+        try:
+            # Check for shape metadata available after
+            # https://github.com/pytorch/pytorch/pull/50547
+            # https://github.com/pytorch/pytorch/pull/50581
+            # https://github.com/pyro-ppl/pyro/pull/2739
+            # https://github.com/pyro-ppl/numpyro/pull/876
+            self.fn.domain.event_dim
+            self.fn.codomain.event_dim
+            self.fn.forward_shape
+        except AttributeError:
+            backend = self.fn.__module__.split(">")[0]
+            raise NotImplementedError(f"{self.fn} is missing shape metadata; "
+                                      f"try upgrading backend {backend}")
         if len(x.shape) < self.fn.domain.event_dim:
             raise ValueError(f"Too few dimensions for input, in {self.name}")
         event_shape = x.shape[len(x.shape) - self.fn.domain.event_dim:]
diff --git a/test/test_distribution.py b/test/test_distribution.py
index 7144c7795..85af4657a 100644
--- a/test/test_distribution.py
+++ b/test/test_distribution.py
@@ -19,8 +19,18 @@
 from funsor.interpreter import interpretation, reinterpret
 from funsor.tensor import Einsum, Tensor, numeric_array, stack
 from funsor.terms import Independent, Variable, eager, lazy, to_funsor
-from funsor.testing import assert_close, check_funsor, rand, randint, randn, \
-    random_mvn, random_scale_tril, random_tensor, xfail_param
+from funsor.testing import (
+    assert_close,
+    check_funsor,
+    rand,
+    randint,
+    randn,
+    random_mvn,
+    random_scale_tril,
+    random_tensor,
+    xfail_if_not_implemented,
+    xfail_param
+)
 from funsor.util import get_backend
 
 pytestmark = pytest.mark.skipif(get_backend() == "numpy",
@@ -1227,6 +1237,7 @@ def test_categorical_event_dim_conversion(batch_shape, event_shape):
     assert_close(actual_log_prob, expected_log_prob)
 
 
+@xfail_if_not_implemented()
 @pytest.mark.parametrize("shape", [(), (4,), (3, 2)], ids=str)
 def test_power_transform(shape):
     transform = backend_dist.transforms.PowerTransform(1.5)
@@ -1246,7 +1257,8 @@ def test_power_transform(shape):
     assert_close(actual_log_prob, expected_log_prob)
 
 
+@xfail_if_not_implemented()
 @pytest.mark.parametrize("shape", [(10,), (4, 3)], ids=str)
 @pytest.mark.parametrize("to_event", [
     True,
     xfail_param(False, reason="bug in to_funsor(TransformedDistribution)"),
diff --git a/test/test_distribution_generic.py b/test/test_distribution_generic.py
index f7bee7f94..0bd53026f 100644
--- a/test/test_distribution_generic.py
+++ b/test/test_distribution_generic.py
@@ -523,7 +523,7 @@ def test_generic_log_prob(case, use_lazy):
     raw_value = raw_dist.sample()
     expected_logprob = to_funsor(raw_dist.log_prob(raw_value), output=funsor.Real, dim_to_name=dim_to_name)
     funsor_value = to_funsor(raw_value, output=expected_value_domain, dim_to_name=dim_to_name)
-    assert_close(funsor_dist(value=funsor_value), expected_logprob, rtol=1e-4 if use_lazy else 1e-3)
+    assert_close(funsor_dist(value=funsor_value), expected_logprob, rtol=1e-3)
 
 
 @pytest.mark.parametrize("case", TEST_CASES, ids=str)

From e3b18478347dfa28776cf93aff4551552caa72a9 Mon Sep 17 00:00:00 2001
From: Fritz Obermeyer
Date: Thu, 21 Jan 2021 10:10:12 -0500
Subject: [PATCH 12/15] fix typo

---
 funsor/ops/op.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/funsor/ops/op.py b/funsor/ops/op.py
index f91cafafe..155ad23f4 100644
--- a/funsor/ops/op.py
+++ b/funsor/ops/op.py
@@ -205,6 +205,7 @@ def __call__(self, x):
     def __call__(self, x):
         if self._is_validated:
             return super().__call__(x)
+
         try:
             # Check for shape metadata available after
@@ -215,9 +216,10 @@ def __call__(self, x):
             self.fn.codomain.event_dim
             self.fn.forward_shape
         except AttributeError:
-            backend = self.fn.__module__.split(">")[0]
+            backend = self.fn.__module__.split(".")[0]
             raise NotImplementedError(f"{self.fn} is missing shape metadata; "
                                       f"try upgrading backend {backend}")
+
         if len(x.shape) < self.fn.domain.event_dim:

From 31da249af0df5e887c2d8981e1fc1414e3615162 Mon Sep 17 00:00:00 2001
From: Fritz Obermeyer
Date: Thu, 21 Jan 2021 13:36:54 -0500
Subject: [PATCH 13/15] Fix JAX tests

---
 funsor/jax/distributions.py | 8 ++++++++
 test/test_distribution.py   | 2 +-
 2 files changed, 9 insertions(+), 1 deletion(-)

diff --git a/funsor/jax/distributions.py b/funsor/jax/distributions.py
index 15f685304..14a9d14e1 100644
--- a/funsor/jax/distributions.py
+++ b/funsor/jax/distributions.py
@@ -218,6 +218,14 @@ def deltadist_to_data(funsor_dist, name_to_dim=None):
 dist.TransformedDistribution.has_rsample = property(lambda self: self.base_dist.has_rsample)
 dist.TransformedDistribution.rsample = dist.TransformedDistribution.sample
 
+
+@to_funsor.register(dist.transforms.Transform)
+def transform_to_funsor(tfm, output=None, dim_to_name=None, real_inputs=None):
+    op = ops.WrappedTransformOp(tfm)
+    name = next(real_inputs.keys()) if real_inputs else "value"
+    return op(Variable(name, output))
+
+
 to_funsor.register(dist.ExpandedDistribution)(expandeddist_to_funsor)
 to_funsor.register(dist.Independent)(indepdist_to_funsor)
 if hasattr(dist, "MaskedDistribution"):
diff --git a/test/test_distribution.py b/test/test_distribution.py
index 85af4657a..adf110cdf 100644
--- a/test/test_distribution.py
+++ b/test/test_distribution.py
@@ -1234,7 +1234,7 @@ def test_power_transform(shape):
     base_dist = backend_dist.Exponential(1)
     d = backend_dist.TransformedDistribution(base_dist, transform)
 
-    data = randn(shape).exp()
+    data = ops.exp(randn(shape))
     expected_log_prob = d.log_prob(data)
 
     name_to_dim = dict(i=-3, j=-2, k=-1)

From 31a07d74aa8e0c5f92a502515f82568f2184fb03 Mon Sep 17 00:00:00 2001
From: Fritz Obermeyer
Date: Sun, 24 Jan 2021 19:02:37 -0500
Subject: [PATCH 14/15] Add more tests

---
 funsor/testing.py                 |  4 ++-
 funsor/torch/distributions.py     |  5 ++++
 test/test_distribution_generic.py | 43 +++++++++++++++++++++++++-----
 3 files changed, 43 insertions(+), 9 deletions(-)

diff --git a/funsor/testing.py b/funsor/testing.py
index cddb8b27b..b50701fb7 100644
--- a/funsor/testing.py
+++ b/funsor/testing.py
@@ -25,10 +25,12 @@
 
 
 @contextlib.contextmanager
-def xfail_if_not_implemented(msg="Not implemented"):
+def xfail_if_not_implemented(msg="Not implemented", *, match=None):
     try:
         yield
     except NotImplementedError as e:
+        if match is not None and match not in str(e):
+            raise e from None
         import pytest
         pytest.xfail(reason='{}:\n{}'.format(msg, e))
 
diff --git a/funsor/torch/distributions.py b/funsor/torch/distributions.py
index 9dfaa94f4..39e032524 100644
--- a/funsor/torch/distributions.py
+++ b/funsor/torch/distributions.py
@@ -227,6 +227,11 @@ def transform_to_torch_transform(op, name_to_dim=None):
     raise NotImplementedError("{} is not a currently supported transform".format(op))
 
 
+@op_to_torch_transform.register(ops.WrappedTransformOp)
+def transform_to_torch_transform(op, name_to_dim=None):
+    return op.fn
+
+
 @op_to_torch_transform.register(ops.ExpOp)
 def exp_to_torch_transform(op, name_to_dim=None):
     return torch.distributions.transforms.ExpTransform()
diff --git a/test/test_distribution_generic.py b/test/test_distribution_generic.py
index a2a3d21e8..39769325e 100644
--- a/test/test_distribution_generic.py
+++ b/test/test_distribution_generic.py
@@ -468,7 +468,6 @@ def __hash__(self):
         funsor.Real,
         xfail_reason="backend not supported" if get_backend() != "torch" else "",
     )
-
     # SigmoidTransform (inversion not working)
     DistTestCase(
         """
@@ -481,6 +480,28 @@ def __hash__(self):
         funsor.Real,
         xfail_reason="failure to re-invert ops.sigmoid.inv, which is not atomic",
     )
+    # PowerTransform
+    DistTestCase(
+        """
+        dist.TransformedDistribution(
+            dist.Exponential(rate=case.rate),
+            dist.transforms.PowerTransform(0.5))
+        """,
+        (("rate", f"rand({batch_shape})"),),
+        funsor.Real,
+        xfail_reason=("backend not supported" if get_backend() != "torch" else ""),
+    )
+    # HaarTransform
+    DistTestCase(
+        """
+        dist.TransformedDistribution(
+            dist.Normal(loc=case.loc, scale=1.).to_event(1),
+            dist.transforms.HaarTransform(dim=-1))
+        """,
+        (("loc", f"rand({batch_shape} + (3,))"),),
+        funsor.Reals[3],
+        xfail_reason=("backend not supported" if get_backend() != "torch" else ""),
+    )
 
     # Independent
     for indep_shape in [(3,), (2, 3)]:
@@ -559,7 +580,8 @@ def test_generic_distribution_to_funsor(case):
     dim_to_name, name_to_dim = _default_dim_to_name(raw_dist.batch_shape)
     with interpretation(normalize_with_subs):
-        funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name)
+        with xfail_if_not_implemented(match="try upgrading backend"):
+            funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name)
     assert funsor_dist.inputs["value"] == expected_value_domain
 
     while isinstance(funsor_dist, funsor.cnf.Contraction):
@@ -593,8 +615,9 @@ def test_generic_log_prob(case, use_lazy):
     dim_to_name, name_to_dim = _default_dim_to_name(raw_dist.batch_shape)
     with interpretation(normalize_with_subs if use_lazy else eager):
-        # some distributions have nontrivial eager patterns
-        funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name)
+        with xfail_if_not_implemented(match="try upgrading backend"):
+            # some distributions have nontrivial eager patterns
+            funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name)
 
     expected_inputs = {name: funsor.Bint[raw_dist.batch_shape[dim]] for dim, name in dim_to_name.items()}
     expected_inputs.update({"value": expected_value_domain})
@@ -616,7 +639,8 @@ def test_generic_enumerate_support(case, expand):
     dim_to_name, name_to_dim = _default_dim_to_name(raw_dist.batch_shape)
     with interpretation(normalize_with_subs):
-        funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name)
+        with xfail_if_not_implemented(match="try upgrading backend"):
+            funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name)
 
     assert getattr(raw_dist, "has_enumerate_support", False) == getattr(funsor_dist, "has_enumerate_support", False)
     if getattr(funsor_dist, "has_enumerate_support", False):
@@ -634,7 +658,8 @@ def test_generic_sample(case, sample_shape):
     dim_to_name, name_to_dim = _default_dim_to_name(sample_shape + raw_dist.batch_shape)
     with interpretation(normalize_with_subs):
-        funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name)
+        with xfail_if_not_implemented(match="try upgrading backend"):
+            funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name)
 
     sample_inputs = OrderedDict((dim_to_name[dim - len(raw_dist.batch_shape)], funsor.Bint[sample_shape[dim]])
                                 for dim in range(-len(sample_shape), 0))
@@ -656,13 +681,15 @@ def test_generic_stats(case, statistic):
     dim_to_name, name_to_dim = _default_dim_to_name(raw_dist.batch_shape)
     with interpretation(normalize_with_subs):
-        funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name)
+        with xfail_if_not_implemented(match="try upgrading backend"):
+            funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name)
 
     with xfail_if_not_implemented(msg="entropy not implemented for some distributions"), \
             xfail_if_not_found(msg="stats not implemented yet for TransformedDist"):
         actual_stat = getattr(funsor_dist, statistic)()
 
-    expected_stat_raw = getattr(raw_dist, statistic)
+    with xfail_if_not_implemented():
+        expected_stat_raw = getattr(raw_dist, statistic)
     if statistic == "entropy":
         expected_stat = to_funsor(expected_stat_raw(), output=funsor.Real, dim_to_name=dim_to_name)
     else:

From 3c2ff61f9a32a376b66d8ea163ad9bb9ef28be27 Mon Sep 17 00:00:00 2001
From: Fritz Obermeyer
Date: Mon, 25 Jan 2021 10:13:50 -0500
Subject: [PATCH 15/15] Remove unused environment variable

---
 test/test_distribution_generic.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/test/test_distribution_generic.py b/test/test_distribution_generic.py
index 39769325e..e979b0c24 100644
--- a/test/test_distribution_generic.py
+++ b/test/test_distribution_generic.py
@@ -1,7 +1,6 @@
 # Copyright Contributors to the Pyro project.
 # SPDX-License-Identifier: Apache-2.0
 
-import os
 import re
 from collections import OrderedDict
 from importlib import import_module
@@ -27,8 +26,6 @@
 )
 from funsor.util import get_backend
 
-_ENABLE_MC_DIST_TESTS = int(os.environ.get("FUNSOR_ENABLE_MC_DIST_TESTS", 0))
-
 pytestmark = pytest.mark.skipif(get_backend() == "numpy",
                                 reason="numpy does not have distributions backend")
 if get_backend() != "numpy":
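
Usage sketch (editorial addendum, not part of the patch series). After PATCH 15
the API converges on wrapping a backend ``Transform`` instance directly as a
funsor op. The snippet below is a hypothetical example, assuming the torch
backend and a torch/Pyro version that provides the ``domain``/``codomain``/
``forward_shape`` metadata checked in PATCH 11; it mirrors the usage exercised
by test/test_ops.py above.

    import torch
    import funsor
    from funsor.ops import WrappedTransformOp

    funsor.set_backend("torch")

    # Wrap a backend transform as an op. Ops are cached per transform
    # instance (keyed on id, weakly referenced), so wrapping the same
    # transform twice returns the identical op, as test_transform_op_cache
    # asserts.
    t = torch.distributions.transforms.PowerTransform(0.5)
    op = WrappedTransformOp(t)
    assert op is WrappedTransformOp(t)

    x = torch.randn(5).exp()             # strictly positive input
    y = op(x)                            # forward; shapes validated on first call
    x2 = op.inv(y)                       # op.inv wraps t.inv, another WrappedTransformOp
    ldj = op.log_abs_det_jacobian(x, y)  # a LogAbsDetJacobianOp (binary op)

The weakref-based ``WrappedOpMeta`` cache (PATCH 10) is what lets a large
per-iteration transform, e.g. a normalizing flow, be garbage collected once the
wrapping op is dropped, which test_transform_op_gc verifies.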