Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update distribution shape inference to handle independent dims #402

Merged
merged 18 commits into from
Dec 17, 2020
44 changes: 36 additions & 8 deletions funsor/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import funsor.ops as ops
from funsor.affine import is_affine
from funsor.cnf import Contraction, GaussianMixture
from funsor.domains import Array, Real, Reals
from funsor.domains import Array, Real, Reals, RealsType
from funsor.gaussian import Gaussian
from funsor.interpreter import gensym
from funsor.tensor import (Tensor, align_tensors, dummy_numeric_array, get_default_prototype,
Expand Down Expand Up @@ -57,12 +57,34 @@ class DistributionMeta(FunsorMeta):
"""
def __call__(cls, *args, **kwargs):
kwargs.update(zip(cls._ast_fields, args))
value = kwargs.pop('value', 'value')
kwargs = OrderedDict(
(k, to_funsor(kwargs[k], output=cls._infer_param_domain(k, getattr(kwargs[k], "shape", ()))))
for k in cls._ast_fields if k != 'value')
value = to_funsor(value, output=cls._infer_value_domain(**{k: v.output for k, v in kwargs.items()}))
args = numbers_to_tensors(*(tuple(kwargs.values()) + (value,)))
kwargs["value"] = kwargs.get("value", "value")
kwargs = OrderedDict((k, kwargs[k]) for k in cls._ast_fields) # make sure args are sorted

domains = OrderedDict()
for k, v in kwargs.items():
if k == "value":
continue

# compute unbroadcasted param domains
domain = cls._infer_param_domain(k, getattr(kwargs[k], "shape", ()))
# use to_funsor to infer output dimensions of e.g. tensors
domains[k] = domain if domain is not None else to_funsor(v).output

# broadcast individual param domains with Funsor inputs
# this avoids .expand-ing underlying parameter tensors
Copy link
Member

@fehiepsi fehiepsi Dec 4, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the expected domain of scale for Normal(Reals[2], 1.) and Normal(Reals[2], torch.ones(2))? Currently, domains["scale"] will be Real in both cases. The second case will trigger an error at to_funsor(v, output=domains[k]) below.

In either case, I guess we need to rewrite eager_normal or eager_mvn to handle a Reals[2] loc. Maybe there is some trick to avoid doing so. cc @fritzo

Copy link
Member Author

@eb8680 eb8680 Dec 4, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the expected domain of scale for Normal(Reals[2], 1.) and Normal(Reals[2], torch.ones(2))?

In the first case, it's Real, and in the second, it's Reals[2]. I guess I should add a second broadcasting condition below to handle the case where the parameter is a raw tensor:

if ops.is_numeric_array(v):  # at this point we know all of v's dims are output dims
    domains[k] = Reals[broadcast_shape(v.shape, domains[k].shape)]

if isinstance(v, Funsor) and isinstance(v.output, RealsType):
domains[k] = Reals[broadcast_shape(v.shape, domains[k].shape)]

# now use the broadcasted parameter shapes to infer the event_shape
domains["value"] = cls._infer_value_domain(**domains)
if isinstance(kwargs["value"], Funsor) and isinstance(kwargs["value"].output, RealsType):
# try to broadcast the event shape with the value, in case they disagree
domains["value"] = Reals[broadcast_shape(domains["value"].shape, kwargs["value"].output.shape)]

# finally, perform conversions to funsors
kwargs = OrderedDict((k, to_funsor(v, output=domains[k])) for k, v in kwargs.items())
args = numbers_to_tensors(*kwargs.values())

return super(DistributionMeta, cls).__call__(*args)


Expand Down Expand Up @@ -191,7 +213,13 @@ def _infer_value_domain(cls, **kwargs):
# rely on the underlying distribution's logic to infer the event_shape given param domains
instance = cls.dist_class(**{k: dummy_numeric_array(domain) for k, domain in kwargs.items()},
validate_args=False)
out_shape = instance.event_shape

# Note inclusion of batch_shape here to handle independent event dimensions.
# The arguments to _infer_value_domain are the .output shapes of parameters,
# so any extra batch dimensions that aren't part of the instance event_shape
# must be broadcasted output dimensions by construction.
out_shape = instance.batch_shape + instance.event_shape
Copy link
Member Author

@eb8680 eb8680 Dec 1, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change to _infer_value_domain is the conceptual meat of the PR.


if type(instance.support).__name__ == "_IntegerInterval":
out_dtype = int(instance.support.upper_bound + 1)
else:
Expand Down