diff --git a/funsor/distribution.py b/funsor/distribution.py
index 6f7492ce2..00e8feb48 100644
--- a/funsor/distribution.py
+++ b/funsor/distribution.py
@@ -424,6 +424,7 @@ def maskeddist_to_funsor(backend_dist, output=None, dim_to_name=None):
     return mask * funsor_base_dist
 
 
+# TODO make this work with transforms with nontrivial event_dim logic
 # converts TransformedDistributions
 def transformeddist_to_funsor(backend_dist, output=None, dim_to_name=None):
     dist_module = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()]).dist
diff --git a/funsor/domains.py b/funsor/domains.py
index 87f2ce644..75fcfe3b4 100644
--- a/funsor/domains.py
+++ b/funsor/domains.py
@@ -223,7 +223,6 @@ def find_domain(op, *domains):
 
 
 @find_domain.register(ops.Op)  # TODO this is too general, register all ops
-@find_domain.register(ops.TransformOp)  # TODO too general, may be wrong for some
 @find_domain.register(ops.ReciprocalOp)
 @find_domain.register(ops.SigmoidOp)
 @find_domain.register(ops.TanhOp)
@@ -312,6 +311,19 @@ def _find_domain_associative_generic(op, *domains):
     return Array[dtype, shape]
 
 
+@find_domain.register(ops.TransformOp)
+def _transform_find_domain(op, domain):
+    fn = op.dispatch(object)
+    shape = fn.forward_shape(domain.shape)
+    return Array[domain.dtype, shape]
+
+
+@find_domain.register(ops.LogAbsDetJacobianOp)
+def _transform_log_abs_det_jacobian(op, domain, codomain):
+    # TODO do we need to handle batch shape here?
+    return Real
+
+
 __all__ = [
     'Bint',
     'BintType',
diff --git a/funsor/jax/distributions.py b/funsor/jax/distributions.py
index 15f685304..14a9d14e1 100644
--- a/funsor/jax/distributions.py
+++ b/funsor/jax/distributions.py
@@ -218,6 +218,14 @@ def deltadist_to_data(funsor_dist, name_to_dim=None):
 dist.TransformedDistribution.has_rsample = property(lambda self: self.base_dist.has_rsample)
 dist.TransformedDistribution.rsample = dist.TransformedDistribution.sample
 
+
+@to_funsor.register(dist.transforms.Transform)
+def transform_to_funsor(tfm, output=None, dim_to_name=None, real_inputs=None):
+    op = ops.WrappedTransformOp(tfm)
+    name = next(real_inputs.keys()) if real_inputs else "value"
+    return op(Variable(name, output))
+
+
 to_funsor.register(dist.ExpandedDistribution)(expandeddist_to_funsor)
 to_funsor.register(dist.Independent)(indepdist_to_funsor)
 if hasattr(dist, "MaskedDistribution"):
diff --git a/funsor/ops/builtin.py b/funsor/ops/builtin.py
index 70000704c..a2af21e92 100644
--- a/funsor/ops/builtin.py
+++ b/funsor/ops/builtin.py
@@ -137,11 +137,8 @@ def log_abs_det_jacobian(x, y):
 
 exp.set_inv(log)
 log.set_inv(exp)
-
-
-@tanh.set_inv
-def tanh_inv(y):
-    return atanh(y)
+tanh.set_inv(atanh)
+atanh.set_inv(tanh)
 
 
 @tanh.set_log_abs_det_jacobian
@@ -149,11 +146,6 @@ def tanh_log_abs_det_jacobian(x, y):
     return 2. * (math.log(2.) - x - softplus(-2. * x))
 
 
-@atanh.set_inv
-def atanh_inv(y):
-    return tanh(y)
-
-
 @atanh.set_log_abs_det_jacobian
 def atanh_log_abs_det_jacobian(x, y):
     return -tanh.log_abs_det_jacobian(y, x)
diff --git a/funsor/ops/op.py b/funsor/ops/op.py
index 700c70eae..155ad23f4 100644
--- a/funsor/ops/op.py
+++ b/funsor/ops/op.py
@@ -8,25 +8,66 @@
 from multipledispatch import Dispatcher
 
 
+class WeakPartial:
+    """
+    Like ``functools.partial(fn, arg)`` but weakly referencing ``arg``.
+    """
+    def __init__(self, fn, arg):
+        self.fn = fn
+        self.weak_arg = weakref.ref(arg)
+        functools.update_wrapper(self, fn)
+
+    def __call__(self, *args):
+        arg = self.weak_arg()
+        return self.fn(arg, *args)
+
+
 class CachedOpMeta(type):
     """
     Metaclass for caching op instance construction.
+    Caching strategy is to key on ``*args`` and retain values forever.
     """
+    def __init__(cls, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        cls._instance_cache = {}
+
     def __call__(cls, *args, **kwargs):
         try:
-            return cls._cache[args]
+            return cls._instance_cache[args]
         except KeyError:
             instance = super(CachedOpMeta, cls).__call__(*args, **kwargs)
-            cls._cache[args] = instance
+            cls._instance_cache[args] = instance
             return instance
 
 
+class WrappedOpMeta(type):
+    """
+    Metaclass for ops that wrap temporary backend ops.
+    Caching strategy is to key on ``id(backend_op)`` and forget values asap.
+    """
+    def __init__(cls, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        cls._instance_cache = weakref.WeakValueDictionary()
+
+    def __call__(cls, fn):
+        if inspect.ismethod(fn):
+            key = id(fn.__self__), fn.__func__  # e.g. t.log_abs_det_jacobian
+        else:
+            key = id(fn)  # e.g. t.inv
+        try:
+            return cls._instance_cache[key]
+        except KeyError:
+            op = super().__call__(fn)
+            op.fn = fn  # Ensures the key id(fn) is not reused.
+            cls._instance_cache[key] = op
+            return op
+
+
 class Op(Dispatcher):
     _all_instances = weakref.WeakSet()
 
     def __init_subclass__(cls, **kwargs):
         super().__init_subclass__(**kwargs)
-        cls._cache = {}
         cls._subclass_registry = []
 
     def __init__(self, fn, *, name=None):
@@ -44,7 +85,7 @@ def __init__(self, fn, *, name=None):
         # Register all existing patterns.
         for supercls in reversed(inspect.getmro(type(self))):
             for pattern, fn in getattr(supercls, "_subclass_registry", ()):
-                self.register(*pattern)(functools.partial(fn, self))
+                self.add(pattern, WeakPartial(fn, self))
         # Save self for registering future patterns.
         Op._all_instances.add(self)
 
@@ -69,7 +110,7 @@ def decorator(fn):
             # Register with all existing instances.
             for op in Op._all_instances:
                 if isinstance(op, cls):
-                    op.register(*pattern)(functools.partial(fn, op))
+                    op.add(pattern, WeakPartial(fn, op))
             # Ensure registration with all future instances.
             cls._subclass_registry.append((pattern, fn))
             return fn
@@ -77,6 +118,9 @@ def decorator(fn):
 
 
 def make_op(fn=None, parent=None, *, name=None, module_name="funsor.ops"):
+    """
+    Factory to create a new :class:`Op` subclass and a new instance of that class.
+    """
     # Support use as decorator.
     if fn is None:
         return lambda fn: make_op(fn, parent, name=name, module_name=module_name)
@@ -89,10 +133,11 @@ def make_op(fn=None, parent=None, *, name=None, module_name="funsor.ops"):
         name = fn if isinstance(fn, str) else fn.__name__
     assert isinstance(name, str)
 
-    classname = name[0].upper() + name[1:].rstrip("_") + "Op"  # e.g. add -> AddOp
-    new_type = CachedOpMeta(classname, (parent,), {})
-    new_type.__module__ = module_name
-    return new_type(fn, name=name)
+    classname = name.capitalize().rstrip("_") + "Op"  # e.g. add -> AddOp
+    cls = type(classname, (parent,), {})
+    cls.__module__ = module_name
+    op = cls(fn, name=name)
+    return op
 
 
 def declare_op_types(locals_, all_, name_):
@@ -115,6 +160,10 @@ class UnaryOp(Op):
     pass
 
 
+class BinaryOp(Op):
+    pass
+
+
 class TransformOp(UnaryOp):
     def set_inv(self, fn):
         """
@@ -143,19 +192,73 @@ def log_abs_det_jacobian(x, y):
         raise NotImplementedError
 
 
+class WrappedTransformOp(TransformOp, metaclass=WrappedOpMeta):
+    """
+    Wrapper for a backend ``Transform`` object that provides ``.inv`` and
+    ``.log_abs_det_jacobian``. This additionally validates shapes on the first
+    :meth:`__call__`.
+    """
+    def __init__(self, fn):
+        super().__init__(fn, name=type(fn).__name__)
+        self._is_validated = False
+
+    def __call__(self, x):
+        if self._is_validated:
+            return super().__call__(x)
+
+        try:
+            # Check for shape metadata available after
+            # https://github.com/pytorch/pytorch/pull/50547
+            # https://github.com/pytorch/pytorch/pull/50581
+            # https://github.com/pyro-ppl/pyro/pull/2739
+            # https://github.com/pyro-ppl/numpyro/pull/876
+            self.fn.domain.event_dim
+            self.fn.codomain.event_dim
+            self.fn.forward_shape
+        except AttributeError:
+            backend = self.fn.__module__.split(".")[0]
+            raise NotImplementedError(f"{self.fn} is missing shape metadata; "
+                                      f"try upgrading backend {backend}")
+
+        if len(x.shape) < self.fn.domain.event_dim:
+            raise ValueError(f"Too few dimensions for input, in {self.name}")
+        event_shape = x.shape[len(x.shape) - self.fn.domain.event_dim:]
+        shape = self.fn.forward_shape(event_shape)
+        if len(shape) > self.fn.codomain.event_dim:
+            raise ValueError(f"Cannot treat transform {self.name} as an Op "
+                             "because it is batched")
+        self._is_validated = True
+        return super().__call__(x)
+
+    @property
+    def inv(self):
+        return WrappedTransformOp(self.fn.inv)
+
+    @property
+    def log_abs_det_jacobian(self):
+        return LogAbsDetJacobianOp(self.fn.log_abs_det_jacobian)
+
+
+class LogAbsDetJacobianOp(BinaryOp, metaclass=WrappedOpMeta):
+    pass
+
+
 # Op registration tables.
 DISTRIBUTIVE_OPS = set()  # (add, mul) pairs
 UNITS = {}  # op -> value
 PRODUCT_INVERSES = {}  # op -> inverse op
 
 __all__ = [
+    'BinaryOp',
     'CachedOpMeta',
     'DISTRIBUTIVE_OPS',
+    'LogAbsDetJacobianOp',
     'Op',
     'PRODUCT_INVERSES',
     'TransformOp',
     'UNITS',
     'UnaryOp',
+    'WrappedTransformOp',
     'declare_op_types',
     'make_op',
 ]
diff --git a/funsor/terms.py b/funsor/terms.py
index 661b92c19..7ea2783d0 100644
--- a/funsor/terms.py
+++ b/funsor/terms.py
@@ -1753,16 +1753,19 @@ def quote_inplace_first_arg_on_first_line(arg, indent, out):
 
 
 ops.UnaryOp.subclass_register(Funsor)(Unary)
-ops.AssociativeOp.subclass_register(Funsor)(Unary)  # Reductions.
+ops.BinaryOp.subclass_register(Funsor, Funsor)(Binary)
 ops.AssociativeOp.subclass_register(Funsor, Funsor)(Binary)
+ops.AssociativeOp.subclass_register(Funsor)(Unary)  # Reductions.
 
 
-@AssociativeOp.subclass_register(object, Funsor)
+@ops.BinaryOp.subclass_register(object, Funsor)
+@ops.AssociativeOp.subclass_register(object, Funsor)
 def binary_object_funsor(op, x, y):
     return Binary(op, to_funsor(x), y)
 
 
-@AssociativeOp.subclass_register(Funsor, object)
+@ops.BinaryOp.subclass_register(Funsor, object)
+@ops.AssociativeOp.subclass_register(Funsor, object)
 def binary_funsor_object(op, x, y):
     return Binary(op, x, to_funsor(y))
 
diff --git a/funsor/testing.py b/funsor/testing.py
index cddb8b27b..b50701fb7 100644
--- a/funsor/testing.py
+++ b/funsor/testing.py
@@ -25,10 +25,12 @@
 
 
 @contextlib.contextmanager
-def xfail_if_not_implemented(msg="Not implemented"):
+def xfail_if_not_implemented(msg="Not implemented", *, match=None):
     try:
         yield
     except NotImplementedError as e:
+        if match is not None and match not in str(e):
+            raise e from None
         import pytest
         pytest.xfail(reason='{}:\n{}'.format(msg, e))
 
diff --git a/funsor/torch/distributions.py b/funsor/torch/distributions.py
index 6fe1b3186..39e032524 100644
--- a/funsor/torch/distributions.py
+++ b/funsor/torch/distributions.py
@@ -7,13 +7,14 @@
 
 import pyro.distributions as dist
 import pyro.distributions.testing.fakes as fakes
-from pyro.distributions.torch_distribution import ExpandedDistribution, MaskedDistribution
 import torch
+from pyro.distributions.torch_distribution import ExpandedDistribution, MaskedDistribution
 
+import funsor.ops as ops
 from funsor.cnf import Contraction
 from funsor.distribution import (  # noqa: F401
-    Bernoulli,
     FUNSOR_DIST_NAMES,
+    Bernoulli,
     LogNormal,
     backenddist_to_funsor,
     eager_beta,
@@ -24,10 +25,10 @@
     eager_delta_funsor_funsor,
     eager_delta_funsor_variable,
     eager_delta_tensor,
+    eager_delta_variable_variable,
     eager_dirichlet_categorical,
     eager_dirichlet_multinomial,
     eager_dirichlet_posterior,
-    eager_delta_variable_variable,
     eager_gamma_gamma,
     eager_gamma_poisson,
     eager_multinomial,
@@ -38,15 +39,13 @@
     indepdist_to_funsor,
     make_dist,
     maskeddist_to_funsor,
-    transformeddist_to_funsor,
+    transformeddist_to_funsor
 )
 from funsor.domains import Real, Reals
-import funsor.ops as ops
 from funsor.tensor import Tensor
 from funsor.terms import Binary, Funsor, Reduce, Unary, Variable, eager, to_data, to_funsor
 from funsor.util import methodof
 
-
 __all__ = list(x[0] for x in FUNSOR_DIST_NAMES)
 
 
@@ -228,6 +227,11 @@ def transform_to_torch_transform(op, name_to_dim=None):
     raise NotImplementedError("{} is not a currently supported transform".format(op))
 
 
+@op_to_torch_transform.register(ops.WrappedTransformOp)
+def transform_to_torch_transform(op, name_to_dim=None):
+    return op.fn
+
+
 @op_to_torch_transform.register(ops.ExpOp)
 def exp_to_torch_transform(op, name_to_dim=None):
     return torch.distributions.transforms.ExpTransform()
@@ -269,7 +273,9 @@ def transform_to_data(expr, name_to_dim=None):
 
 @to_funsor.register(torch.distributions.Transform)
 def transform_to_funsor(tfm, output=None, dim_to_name=None, real_inputs=None):
-    raise NotImplementedError("{} is not a currently supported transform".format(tfm))
+    op = ops.WrappedTransformOp(tfm)
+    name = next(real_inputs.keys()) if real_inputs else "value"
+    return op(Variable(name, output))
 
 
 @to_funsor.register(torch.distributions.transforms.ExpTransform)
diff --git a/setup.cfg b/setup.cfg
index b30e5bb06..b1775b96b 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -18,6 +18,7 @@
 filterwarnings = error
     ignore:numpy.ufunc size changed:RuntimeWarning
     ignore:numpy.dtype size changed:RuntimeWarning
     ignore:Mixed memory format:UserWarning
+    ignore:Cannot memoize Op:UserWarning
     ignore::DeprecationWarning
     once::DeprecationWarning
diff --git a/test/test_distribution.py b/test/test_distribution.py
index f79202b0b..adf110cdf 100644
--- a/test/test_distribution.py
+++ b/test/test_distribution.py
@@ -19,8 +19,18 @@
 from funsor.interpreter import interpretation, reinterpret
 from funsor.tensor import Einsum, Tensor, numeric_array, stack
 from funsor.terms import Independent, Variable, eager, lazy, to_funsor
-from funsor.testing import assert_close, check_funsor, rand, randint, randn, \
-    random_mvn, random_scale_tril, random_tensor, xfail_param
+from funsor.testing import (
+    assert_close,
+    check_funsor,
+    rand,
+    randint,
+    randn,
+    random_mvn,
+    random_scale_tril,
+    random_tensor,
+    xfail_if_not_implemented,
+    xfail_param
+)
 from funsor.util import get_backend
 pytestmark = pytest.mark.skipif(get_backend() == "numpy",
                                 reason="numpy does not have distributions backend")
@@ -1215,3 +1225,48 @@ def test_categorical_event_dim_conversion(batch_shape, event_shape):
                                      funsor.to_data(data, name_to_dim=name_to_dim))
     assert actual_log_prob.shape == expected_log_prob.shape
     assert_close(actual_log_prob, expected_log_prob)
+
+
+@xfail_if_not_implemented()
+@pytest.mark.parametrize("shape", [(), (4,), (3, 2)], ids=str)
+def test_power_transform(shape):
+    transform = backend_dist.transforms.PowerTransform(1.5)
+    base_dist = backend_dist.Exponential(1)
+    d = backend_dist.TransformedDistribution(base_dist, transform)
+
+    data = ops.exp(randn(shape))
+    expected_log_prob = d.log_prob(data)
+
+    name_to_dim = dict(i=-3, j=-2, k=-1)
+    dim_to_name = {v: k for k, v in name_to_dim.items()}
+    x = to_funsor(data, Real, dim_to_name)
+    f = to_funsor(d, output=Real)
+    log_prob = f(x)
+    actual_log_prob = funsor.to_data(log_prob, name_to_dim)
+    assert_close(actual_log_prob, expected_log_prob)
+
+
+@xfail_if_not_implemented()
+@pytest.mark.parametrize("shape", [(10,), (4, 3)], ids=str)
+@pytest.mark.parametrize("to_event", [
+    True,
+    xfail_param(False, reason="bug in to_funsor(TransformedDistribution)"),
+])
+def test_haar_transform(shape, to_event):
+    try:
+        transform = backend_dist.transforms.HaarTransform(dim=-len(shape))
+    except AttributeError:
+        pytest.xfail(reason="backend missing HaarTransform")
+    base_dist = backend_dist.Normal(0, 1).expand(shape)
+    if to_event:
+        # Work around a bug in to_funsor(TransformedDistribution)
+        base_dist = base_dist.to_event()
+    d = backend_dist.TransformedDistribution(base_dist, transform)
+
+    data = randn(shape)
+    expected_log_prob = d.log_prob(data)
+
+    f = to_funsor(d, output=Real)
+    log_prob = f(data)
+    actual_log_prob = funsor.to_data(log_prob)
+    assert_close(actual_log_prob, expected_log_prob)
diff --git a/test/test_distribution_generic.py b/test/test_distribution_generic.py
index 802f857ff..0fe0759f0 100644
--- a/test/test_distribution_generic.py
+++ b/test/test_distribution_generic.py
@@ -1,7 +1,6 @@
 # Copyright Contributors to the Pyro project.
 # SPDX-License-Identifier: Apache-2.0
 
-import os
 import re
 from collections import OrderedDict
 from importlib import import_module
@@ -27,8 +26,6 @@
 )
 from funsor.util import get_backend
 
-_ENABLE_MC_DIST_TESTS = int(os.environ.get("FUNSOR_ENABLE_MC_DIST_TESTS", 0))
-
 pytestmark = pytest.mark.skipif(get_backend() == "numpy",
                                 reason="numpy does not have distributions backend")
 if get_backend() != "numpy":
@@ -472,7 +469,6 @@ def __hash__(self):
         funsor.Real,
         xfail_reason="backend not supported" if get_backend() != "torch" else "",
     )
-
     # SigmoidTransform (inversion not working)
     DistTestCase(
         """
@@ -485,6 +481,28 @@ def __hash__(self):
         funsor.Real,
         xfail_reason="failure to re-invert ops.sigmoid.inv, which is not atomic",
     )
+    # PowerTransform
+    DistTestCase(
+        """
+        dist.TransformedDistribution(
+            dist.Exponential(rate=case.rate),
+            dist.transforms.PowerTransform(0.5))
+        """,
+        (("rate", f"rand({batch_shape})"),),
+        funsor.Real,
+        xfail_reason=("backend not supported" if get_backend() != "torch" else ""),
+    )
+    # HaarTransform
+    DistTestCase(
+        """
+        dist.TransformedDistribution(
+            dist.Normal(loc=case.loc, scale=1.).to_event(1),
+            dist.transforms.HaarTransform(dim=-1))
+        """,
+        (("loc", f"rand({batch_shape} + (3,))"),),
+        funsor.Reals[3],
+        xfail_reason=("backend not supported" if get_backend() != "torch" else ""),
+    )
 
     # Independent
     for indep_shape in [(3,), (2, 3)]:
@@ -563,7 +581,8 @@ def test_generic_distribution_to_funsor(case):
 
     dim_to_name, name_to_dim = _default_dim_to_name(raw_dist.batch_shape)
     with interpretation(eager_no_dists):
-        funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name)
+        with xfail_if_not_implemented(match="try upgrading backend"):
+            funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name)
 
     assert funsor_dist.inputs["value"] == expected_value_domain
     while isinstance(funsor_dist, funsor.cnf.Contraction):
@@ -597,8 +616,9 @@ def test_generic_log_prob(case, use_lazy):
 
     dim_to_name, name_to_dim = _default_dim_to_name(raw_dist.batch_shape)
    with interpretation(eager_no_dists if use_lazy else eager):
-        # some distributions have nontrivial eager patterns
-        funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name)
+        with xfail_if_not_implemented(match="try upgrading backend"):
+            # some distributions have nontrivial eager patterns
+            funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name)
 
     expected_inputs = {name: funsor.Bint[raw_dist.batch_shape[dim]] for dim, name in dim_to_name.items()}
     expected_inputs.update({"value": expected_value_domain})
@@ -620,7 +640,8 @@ def test_generic_enumerate_support(case, expand):
 
     dim_to_name, name_to_dim = _default_dim_to_name(raw_dist.batch_shape)
     with interpretation(eager_no_dists):
-        funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name)
+        with xfail_if_not_implemented(match="try upgrading backend"):
+            funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name)
 
     assert getattr(raw_dist, "has_enumerate_support", False) == getattr(funsor_dist, "has_enumerate_support", False)
     if getattr(funsor_dist, "has_enumerate_support", False):
@@ -638,7 +659,8 @@ def test_generic_sample(case, sample_shape):
 
     dim_to_name, name_to_dim = _default_dim_to_name(sample_shape + raw_dist.batch_shape)
     with interpretation(eager_no_dists):
-        funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name)
+        with xfail_if_not_implemented(match="try upgrading backend"):
+            funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name)
 
     sample_inputs = OrderedDict((dim_to_name[dim - len(raw_dist.batch_shape)], funsor.Bint[sample_shape[dim]])
                                 for dim in range(-len(sample_shape), 0))
@@ -660,13 +682,15 @@ def test_generic_stats(case, statistic):
 
     dim_to_name, name_to_dim = _default_dim_to_name(raw_dist.batch_shape)
     with interpretation(eager_no_dists):
-        funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name)
+        with xfail_if_not_implemented(match="try upgrading backend"):
+            funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name)
 
     with xfail_if_not_implemented(msg="entropy not implemented for some distributions"), \
             xfail_if_not_found(msg="stats not implemented yet for TransformedDist"):
         actual_stat = getattr(funsor_dist, statistic)()
 
-    expected_stat_raw = getattr(raw_dist, statistic)
+    with xfail_if_not_implemented():
+        expected_stat_raw = getattr(raw_dist, statistic)
     if statistic == "entropy":
         expected_stat = to_funsor(expected_stat_raw(), output=funsor.Real, dim_to_name=dim_to_name)
     else:
diff --git a/test/test_ops.py b/test/test_ops.py
new file mode 100644
index 000000000..9db97165b
--- /dev/null
+++ b/test/test_ops.py
@@ -0,0 +1,38 @@
+# Copyright Contributors to the Pyro project.
+# SPDX-License-Identifier: Apache-2.0
+
+import weakref
+
+import pytest
+
+from funsor.distribution import BACKEND_TO_DISTRIBUTIONS_BACKEND
+from funsor.util import get_backend
+from funsor.ops import WrappedTransformOp
+
+
+@pytest.fixture
+def dist():
+    try:
+        module_name = BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()]
+    except KeyError:
+        pytest.skip(f"missing distributions module for {get_backend()}")
+    return pytest.importorskip(module_name).dist
+
+
+def test_transform_op_cache(dist):
+    t = dist.transforms.PowerTransform(0.5)
+    W = WrappedTransformOp
+    assert W(t) is W(t)
+    assert W(t).inv is W(t).inv
+    assert W(t.inv) is W(t).inv
+    assert W(t).log_abs_det_jacobian is W(t).log_abs_det_jacobian
+
+
+def test_transform_op_gc(dist):
+    t = dist.transforms.PowerTransform(0.5)
+    op = WrappedTransformOp(t)
+    op_set = weakref.WeakSet()
+    op_set.add(op)
+    assert len(op_set) == 1
+    del op
+    assert len(op_set) == 0
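
Usage sketch (illustrative only, not part of the patch above): how the new WrappedTransformOp / to_funsor(Transform) path is expected to behave with the torch backend, assuming PyTorch/Pyro versions recent enough to expose the transform shape metadata (.domain, .codomain, .forward_shape) that WrappedTransformOp.__call__ checks.

    import pyro.distributions as dist
    import torch

    import funsor
    import funsor.ops as ops

    funsor.set_backend("torch")

    t = dist.transforms.PowerTransform(1.5)

    # Wrapping is keyed on id(t) and cached in a WeakValueDictionary, so
    # re-wrapping the same transform returns the identical op while the
    # wrapper does not outlive the transform it wraps.
    op = ops.WrappedTransformOp(t)
    assert op is ops.WrappedTransformOp(t)

    # to_funsor on a bare Transform now yields a lazy funsor over a free
    # "value" variable instead of raising NotImplementedError.
    f = funsor.to_funsor(t, output=funsor.Real)
    assert set(f.inputs) == {"value"}

    # Substituting data validates shapes once, then defers to the transform.
    x = torch.rand(())
    y = funsor.to_data(f(value=funsor.to_funsor(x, funsor.Real)))
    assert torch.allclose(y, t(x))

The weak caching in WrappedOpMeta is what lets the identity assertions above hold without retaining the transform after the op is dropped, mirroring test_transform_op_cache and test_transform_op_gc in test/test_ops.py.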