Add factory for dynamic TransformOp #427

Merged · 20 commits · Jan 27, 2021
Changes from 15 commits
3 changes: 3 additions & 0 deletions funsor/distribution.py
@@ -373,6 +373,8 @@ def backenddist_to_funsor(funsor_dist_class, backend_dist, output=None, dim_to_n


def indepdist_to_funsor(backend_dist, output=None, dim_to_name=None):
if dim_to_name is None:
dim_to_name = {}
dim_to_name = OrderedDict((dim - backend_dist.reinterpreted_batch_ndims, name)
for dim, name in dim_to_name.items())
dim_to_name.update(OrderedDict((i, "_pyro_event_dim_{}".format(i))
@@ -409,6 +411,7 @@ def maskeddist_to_funsor(backend_dist, output=None, dim_to_name=None):
return mask * funsor_base_dist


# TODO make this work with transforms with nontrivial event_dim logic
# converts TransformedDistributions
def transformeddist_to_funsor(backend_dist, output=None, dim_to_name=None):
dist_module = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()]).dist
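Usage sketch (not part of the diff): the `dim_to_name` handling added above shifts every named batch dim left by `reinterpreted_batch_ndims`, then labels the freed rightmost dims as synthetic event dims. A minimal standalone sketch with hypothetical inputs (the `range(-reinterpreted_batch_ndims, 0)` continuation is assumed from the truncated line above):

```python
from collections import OrderedDict

reinterpreted_batch_ndims = 1                 # hypothetical Independent(..., 1)
dim_to_name = OrderedDict([(-1, "plate_a")])  # hypothetical incoming mapping

# Shift named batch dims left to make room for the reinterpreted dims ...
shifted = OrderedDict((dim - reinterpreted_batch_ndims, name)
                      for dim, name in dim_to_name.items())
# ... then label the rightmost dims as synthetic event dims.
shifted.update(OrderedDict((i, "_pyro_event_dim_{}".format(i))
                           for i in range(-reinterpreted_batch_ndims, 0)))
print(shifted)  # OrderedDict([(-2, 'plate_a'), (-1, '_pyro_event_dim_-1')])
```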
14 changes: 13 additions & 1 deletion funsor/domains.py
@@ -223,7 +223,6 @@ def find_domain(op, *domains):


@find_domain.register(ops.Op) # TODO this is too general, register all ops
@find_domain.register(ops.TransformOp) # TODO too general, may be wrong for some
@find_domain.register(ops.ReciprocalOp)
@find_domain.register(ops.SigmoidOp)
@find_domain.register(ops.TanhOp)
@@ -312,6 +311,19 @@ def _find_domain_associative_generic(op, *domains):
return Array[dtype, shape]


@find_domain.register(ops.TransformOp)
def _transform_find_domain(op, domain):
fn = op.dispatch(object)
shape = fn.forward_shape(domain.shape)
return Array[domain.dtype, shape]


@find_domain.register(ops.LogAbsDetJacobianOp)
def _transform_log_abs_det_jacobian(op, domain, codomain):
# TODO do we need to handle batch shape here?
return Real


__all__ = [
'Bint',
'BintType',
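Usage sketch (not part of the diff) of the two `find_domain` rules added above; it assumes torch >= 1.8 so the transform carries `forward_shape` metadata, and uses `ExpTransform` purely for concreteness:

```python
import torch
import funsor.ops as ops
from funsor.domains import Reals, find_domain

t = torch.distributions.transforms.ExpTransform()
op = ops.WrappedTransformOp(t)

# TransformOp rule: the output shape comes from the transform's forward_shape.
print(find_domain(op, Reals[3]))                                 # Reals[3]

# LogAbsDetJacobianOp rule: always a scalar Real.
print(find_domain(op.log_abs_det_jacobian, Reals[3], Reals[3]))  # Real
```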
8 changes: 8 additions & 0 deletions funsor/jax/distributions.py
@@ -218,6 +218,14 @@ def deltadist_to_data(funsor_dist, name_to_dim=None):
dist.TransformedDistribution.has_rsample = property(lambda self: self.base_dist.has_rsample)
dist.TransformedDistribution.rsample = dist.TransformedDistribution.sample


@to_funsor.register(dist.transforms.Transform)
def transform_to_funsor(tfm, output=None, dim_to_name=None, real_inputs=None):
op = ops.WrappedTransformOp(tfm)
name = next(iter(real_inputs)) if real_inputs else "value"
return op(Variable(name, output))


to_funsor.register(dist.ExpandedDistribution)(expandeddist_to_funsor)
to_funsor.register(dist.Independent)(indepdist_to_funsor)
if hasattr(dist, "MaskedDistribution"):
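Usage sketch (not part of the diff) of the registration above, assuming the jax backend with a numpyro release that carries the linked shape metadata; note that transforms which already have a dedicated `to_funsor` rule may bypass the generic wrapper, so `ExpTransform` here is only for concreteness:

```python
import funsor
funsor.set_backend("jax")

import numpyro.distributions as dist
from funsor.domains import Real
from funsor.terms import to_funsor

# The transform becomes a lazy funsor: the wrapped op applied to a free
# Variable named "value" (the default when no real_inputs are given).
f = to_funsor(dist.transforms.ExpTransform(), output=Real)
print(f.inputs)  # OrderedDict([('value', Real)])
```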
12 changes: 2 additions & 10 deletions funsor/ops/builtin.py
@@ -137,23 +137,15 @@ def log_abs_det_jacobian(x, y):

exp.set_inv(log)
log.set_inv(exp)


@tanh.set_inv
def tanh_inv(y):
return atanh(y)
tanh.set_inv(atanh)
atanh.set_inv(tanh)


@tanh.set_log_abs_det_jacobian
def tanh_log_abs_det_jacobian(x, y):
return 2. * (math.log(2.) - x - softplus(-2. * x))


@atanh.set_inv
def atanh_inv(y):
return tanh(y)


@atanh.set_log_abs_det_jacobian
def atanh_log_abs_det_jacobian(x, y):
return -tanh.log_abs_det_jacobian(y, x)
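Sanity sketch (not part of the diff): the decorator-based inverses are replaced by direct `set_inv` calls, so each op's `.inv` now points at its partner op (assuming `set_inv` stores the callable on `.inv`, as funsor's `TransformOp` does):

```python
import funsor.ops as ops

assert ops.exp.inv is ops.log and ops.log.inv is ops.exp
assert ops.tanh.inv is ops.atanh and ops.atanh.inv is ops.tanh
# The log_abs_det_jacobian definitions above remain explicit functions of
# (x, y == op(x)) and are unchanged by this refactor.
```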
121 changes: 112 additions & 9 deletions funsor/ops/op.py
@@ -8,25 +8,66 @@
from multipledispatch import Dispatcher


class WeakPartial:
"""
Like ``functools.partial(fn, arg)`` but weakly referencing ``arg``.
"""
def __init__(self, fn, arg):
self.fn = fn
self.weak_arg = weakref.ref(arg)
functools.update_wrapper(self, fn)

def __call__(self, *args):
arg = self.weak_arg()
return self.fn(arg, *args)


class CachedOpMeta(type):
"""
Metaclass for caching op instance construction.
Caching strategy is to key on ``*args`` and retain values forever.
"""
def __init__(cls, *args, **kwargs):
super().__init__(*args, **kwargs)
cls._instance_cache = {}

def __call__(cls, *args, **kwargs):
try:
return cls._cache[args]
return cls._instance_cache[args]
except KeyError:
instance = super(CachedOpMeta, cls).__call__(*args, **kwargs)
cls._cache[args] = instance
cls._instance_cache[args] = instance
return instance


class WrappedOpMeta(type):
"""
Metaclass for ops that wrap temporary backend ops.
Caching strategy is to key on ``id(backend_op)`` and forget values asap.
"""
def __init__(cls, *args, **kwargs):
super().__init__(*args, **kwargs)
cls._instance_cache = weakref.WeakValueDictionary()

def __call__(cls, fn):
if inspect.ismethod(fn):
key = id(fn.__self__), fn.__func__ # e.g. t.log_abs_det_jacobian
else:
key = id(fn) # e.g. t.inv
try:
return cls._instance_cache[key]
except KeyError:
op = super().__call__(fn)
op.fn = fn # Ensures the key id(fn) is not reused.
cls._instance_cache[key] = op
return op


class Op(Dispatcher):
_all_instances = weakref.WeakSet()

def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
cls._cache = {}
cls._subclass_registry = []

def __init__(self, fn, *, name=None):
@@ -44,7 +85,7 @@ def __init__(self, fn, *, name=None):
# Register all existing patterns.
for supercls in reversed(inspect.getmro(type(self))):
for pattern, fn in getattr(supercls, "_subclass_registry", ()):
self.register(*pattern)(functools.partial(fn, self))
self.add(pattern, WeakPartial(fn, self))
# Save self for registering future patterns.
Op._all_instances.add(self)

@@ -69,14 +110,17 @@ def decorator(fn):
# Register with all existing instances.
for op in Op._all_instances:
if isinstance(op, cls):
op.register(*pattern)(functools.partial(fn, op))
op.add(pattern, WeakPartial(fn, op))
# Ensure registration with all future instances.
cls._subclass_registry.append((pattern, fn))
return fn
return decorator


def make_op(fn=None, parent=None, *, name=None, module_name="funsor.ops"):
"""
Factory to create a new :class:`Op` subclass and a new instance of that class.
"""
# Support use as decorator.
if fn is None:
return lambda fn: make_op(fn, parent, name=name, module_name=module_name)
@@ -89,10 +133,11 @@ def make_op(fn=None, parent=None, *, name=None, module_name="funsor.ops"):
name = fn if isinstance(fn, str) else fn.__name__
assert isinstance(name, str)

classname = name[0].upper() + name[1:].rstrip("_") + "Op" # e.g. add -> AddOp
new_type = CachedOpMeta(classname, (parent,), {})
new_type.__module__ = module_name
return new_type(fn, name=name)
classname = name.capitalize().rstrip("_") + "Op" # e.g. add -> AddOp
cls = type(classname, (parent,), {})
cls.__module__ = module_name
op = cls(fn, name=name)
return op


def declare_op_types(locals_, all_, name_):
@@ -115,6 +160,10 @@ class UnaryOp(Op):
pass


class BinaryOp(Op):
pass


class TransformOp(UnaryOp):
def set_inv(self, fn):
"""
@@ -143,19 +192,73 @@ def log_abs_det_jacobian(x, y):
raise NotImplementedError


class WrappedTransformOp(TransformOp, metaclass=WrappedOpMeta):
"""
Wrapper for a backend ``Transform`` object that provides ``.inv`` and
``.log_abs_det_jacobian``. This additionally validates shapes on the first
:meth:`__call__`.
"""
def __init__(self, fn):
super().__init__(fn, name=type(fn).__name__)
self._is_validated = False

def __call__(self, x):
if self._is_validated:
return super().__call__(x)

try:
# Check for shape metadata available after
# https://github.com/pytorch/pytorch/pull/50547
# https://github.com/pytorch/pytorch/pull/50581
# https://github.com/pyro-ppl/pyro/pull/2739
# https://github.com/pyro-ppl/numpyro/pull/876
self.fn.domain.event_dim
self.fn.codomain.event_dim
self.fn.forward_shape
except AttributeError:
backend = self.fn.__module__.split(".")[0]
raise NotImplementedError(f"{self.fn} is missing shape metadata; "
f"try upgrading backend {backend}")

if len(x.shape) < self.fn.domain.event_dim:
raise ValueError(f"Too few dimensions for input, in {self.name}")
event_shape = x.shape[len(x.shape) - self.fn.domain.event_dim:]
shape = self.fn.forward_shape(event_shape)
if len(shape) > self.fn.codomain.event_dim:
raise ValueError(f"Cannot treat transform {self.name} as an Op "
"because it is batched")
self._is_validated = True
return super().__call__(x)

@property
def inv(self):
return WrappedTransformOp(self.fn.inv)

@property
def log_abs_det_jacobian(self):
return LogAbsDetJacobianOp(self.fn.log_abs_det_jacobian)


class LogAbsDetJacobianOp(BinaryOp, metaclass=WrappedOpMeta):
pass


# Op registration tables.
DISTRIBUTIVE_OPS = set() # (add, mul) pairs
UNITS = {} # op -> value
PRODUCT_INVERSES = {} # op -> inverse op

__all__ = [
'BinaryOp',
'CachedOpMeta',
'DISTRIBUTIVE_OPS',
'LogAbsDetJacobianOp',
'Op',
'PRODUCT_INVERSES',
'TransformOp',
'UNITS',
'UnaryOp',
'WrappedTransformOp',
'declare_op_types',
'make_op',
]
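Usage sketch (not part of the diff) of the `make_op` factory; `sqr` is a hypothetical op, and this assumes `make_op` accepts an `Op` subclass as `parent`, as its signature suggests:

```python
import funsor.ops as ops

# make_op builds a fresh Op subclass plus an instance of it.
sqr = ops.make_op(lambda x: x * x, ops.UnaryOp, name="sqr")
print(type(sqr).__name__)  # 'SqrOp' -- the dynamically created class
print(sqr(3.0))            # 9.0 -- non-Funsor args fall through to the wrapped fn
```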
9 changes: 6 additions & 3 deletions funsor/terms.py
@@ -1753,16 +1753,19 @@ def quote_inplace_first_arg_on_first_line(arg, indent, out):


ops.UnaryOp.subclass_register(Funsor)(Unary)
ops.AssociativeOp.subclass_register(Funsor)(Unary) # Reductions.
ops.BinaryOp.subclass_register(Funsor, Funsor)(Binary)
ops.AssociativeOp.subclass_register(Funsor, Funsor)(Binary)
ops.AssociativeOp.subclass_register(Funsor)(Unary) # Reductions.


@AssociativeOp.subclass_register(object, Funsor)
@ops.BinaryOp.subclass_register(object, Funsor)
@ops.AssociativeOp.subclass_register(object, Funsor)
def binary_object_funsor(op, x, y):
return Binary(op, to_funsor(x), y)


@AssociativeOp.subclass_register(Funsor, object)
@ops.BinaryOp.subclass_register(Funsor, object)
@ops.AssociativeOp.subclass_register(Funsor, object)
def binary_funsor_object(op, x, y):
return Binary(op, x, to_funsor(y))

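Usage sketch (not part of the diff) of what the new `BinaryOp` registrations enable, using `LogAbsDetJacobianOp` (the only `BinaryOp` introduced here) and assuming torch >= 1.8 for the wrapped transform:

```python
import torch
import funsor.ops as ops
from funsor.domains import Real
from funsor.terms import Variable

t = torch.distributions.transforms.ExpTransform()
ladj = ops.WrappedTransformOp(t).log_abs_det_jacobian  # a LogAbsDetJacobianOp

x = Variable("x", Real)
expr = ladj(x, ops.exp(x))  # Binary(LogAbsDetJacobianOp, ...) via the rules above
print(expr.output)          # Real, from the new find_domain rule
print(expr.inputs)          # OrderedDict([('x', Real)])
```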
15 changes: 8 additions & 7 deletions funsor/torch/distributions.py
@@ -7,13 +7,14 @@

import pyro.distributions as dist
import pyro.distributions.testing.fakes as fakes
from pyro.distributions.torch_distribution import ExpandedDistribution, MaskedDistribution
import torch
from pyro.distributions.torch_distribution import ExpandedDistribution, MaskedDistribution

import funsor.ops as ops
from funsor.cnf import Contraction
from funsor.distribution import ( # noqa: F401
Bernoulli,
FUNSOR_DIST_NAMES,
Bernoulli,
LogNormal,
backenddist_to_funsor,
eager_beta,
@@ -24,10 +25,10 @@
eager_delta_funsor_funsor,
eager_delta_funsor_variable,
eager_delta_tensor,
eager_delta_variable_variable,
eager_dirichlet_categorical,
eager_dirichlet_multinomial,
eager_dirichlet_posterior,
eager_delta_variable_variable,
eager_gamma_gamma,
eager_gamma_poisson,
eager_multinomial,
@@ -38,15 +39,13 @@
indepdist_to_funsor,
make_dist,
maskeddist_to_funsor,
transformeddist_to_funsor,
transformeddist_to_funsor
)
from funsor.domains import Real, Reals
import funsor.ops as ops
from funsor.tensor import Tensor
from funsor.terms import Binary, Funsor, Reduce, Unary, Variable, eager, to_data, to_funsor
from funsor.util import methodof


__all__ = list(x[0] for x in FUNSOR_DIST_NAMES)


@@ -269,7 +268,9 @@ def transform_to_data(expr, name_to_dim=None):

@to_funsor.register(torch.distributions.Transform)
def transform_to_funsor(tfm, output=None, dim_to_name=None, real_inputs=None):
raise NotImplementedError("{} is not a currently supported transform".format(tfm))
op = ops.WrappedTransformOp(tfm)
name = next(iter(real_inputs)) if real_inputs else "value"
return op(Variable(name, output))


@to_funsor.register(torch.distributions.transforms.ExpTransform)
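Usage sketch (not part of the diff), mirroring the jax change; it assumes a torch/pyro pair new enough to carry the shape metadata checked in `WrappedTransformOp`, and `ExpTransform` may be routed through its pre-existing dedicated rule rather than the generic wrapper:

```python
import torch
import funsor
funsor.set_backend("torch")

from funsor.domains import Real
from funsor.tensor import Tensor
from funsor.terms import to_funsor

f = to_funsor(torch.distributions.transforms.ExpTransform(), output=Real)
print(f.inputs)                            # OrderedDict([('value', Real)])
# Substituting the free variable should evaluate eagerly: exp(0) == 1.
print(f(value=Tensor(torch.tensor(0.0))))  # expected: Tensor(tensor(1.))
```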
1 change: 1 addition & 0 deletions setup.cfg
@@ -18,6 +18,7 @@ filterwarnings = error
ignore:numpy.ufunc size changed:RuntimeWarning
ignore:numpy.dtype size changed:RuntimeWarning
ignore:Mixed memory format:UserWarning
ignore:Cannot memoize Op:UserWarning
ignore::DeprecationWarning
once::DeprecationWarning
