-
Notifications
You must be signed in to change notification settings - Fork 21
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Update distribution shape inference to handle independent dims #402
Changes from 8 commits
0e95347
5f83573
13d94f4
7b8c2ee
88cb603
2de25be
9a0ccf3
0102210
9f85c50
9f5f391
a3a4c4b
260d2a2
204b357
b8f7bf6
aa56a2d
51d7bef
e7f4441
925cdfb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,13 +16,13 @@ | |
import funsor.ops as ops | ||
from funsor.affine import is_affine | ||
from funsor.cnf import Contraction, GaussianMixture | ||
from funsor.domains import Array, Real, Reals | ||
from funsor.domains import Array, BintType, Real, Reals, RealsType | ||
from funsor.gaussian import Gaussian | ||
from funsor.interpreter import gensym | ||
from funsor.tensor import (Tensor, align_tensors, dummy_numeric_array, get_default_prototype, | ||
ignore_jit_warnings, numeric_array, stack) | ||
from funsor.terms import Funsor, FunsorMeta, Independent, Number, Variable, \ | ||
eager, to_data, to_funsor | ||
eager, reflect, to_data, to_funsor | ||
from funsor.util import broadcast_shape, get_backend, getargspec, lazy_property | ||
|
||
|
||
|
@@ -57,12 +57,34 @@ class DistributionMeta(FunsorMeta): | |
""" | ||
def __call__(cls, *args, **kwargs): | ||
kwargs.update(zip(cls._ast_fields, args)) | ||
value = kwargs.pop('value', 'value') | ||
kwargs = OrderedDict( | ||
(k, to_funsor(kwargs[k], output=cls._infer_param_domain(k, getattr(kwargs[k], "shape", ())))) | ||
for k in cls._ast_fields if k != 'value') | ||
value = to_funsor(value, output=cls._infer_value_domain(**{k: v.output for k, v in kwargs.items()})) | ||
args = numbers_to_tensors(*(tuple(kwargs.values()) + (value,))) | ||
kwargs["value"] = kwargs.get("value", "value") | ||
kwargs = OrderedDict((k, kwargs[k]) for k in cls._ast_fields) # make sure args are sorted | ||
|
||
domains = OrderedDict() | ||
for k, v in kwargs.items(): | ||
if k == "value": | ||
continue | ||
|
||
# compute unbroadcasted param domains | ||
domain = cls._infer_param_domain(k, getattr(kwargs[k], "shape", ())) | ||
# use to_funsor to infer output dimensions of e.g. tensors | ||
domains[k] = domain if domain is not None else to_funsor(v).output | ||
|
||
# broadcast individual param domains with Funsor inputs | ||
# this avoids .expand-ing underlying parameter tensors | ||
if isinstance(v, Funsor) and isinstance(v.output, RealsType): | ||
domains[k] = Reals[broadcast_shape(v.shape, domains[k].shape)] | ||
|
||
# now use the broadcasted parameter shapes to infer the event_shape | ||
domains["value"] = cls._infer_value_domain(**domains) | ||
if isinstance(kwargs["value"], Funsor) and isinstance(kwargs["value"].output, RealsType): | ||
# try to broadcast the event shape with the value, in case they disagree | ||
domains["value"] = Reals[broadcast_shape(domains["value"].shape, kwargs["value"].output.shape)] | ||
|
||
# finally, perform conversions to funsors | ||
kwargs = OrderedDict((k, to_funsor(v, output=domains[k])) for k, v in kwargs.items()) | ||
args = numbers_to_tensors(*kwargs.values()) | ||
|
||
return super(DistributionMeta, cls).__call__(*args) | ||
|
||
|
||
|
@@ -98,14 +120,6 @@ def eager_reduce(self, op, reduced_vars): | |
return Number(0.) # distributions are normalized | ||
return super(Distribution, self).eager_reduce(op, reduced_vars) | ||
|
||
@classmethod | ||
def eager_log_prob(cls, *params): | ||
inputs, tensors = align_tensors(*params) | ||
params = dict(zip(cls._ast_fields, tensors)) | ||
value = params.pop('value') | ||
data = cls.dist_class(**params).log_prob(value) | ||
return Tensor(data, inputs) | ||
|
||
def _get_raw_dist(self): | ||
""" | ||
Internal method for working with underlying distribution attributes | ||
|
@@ -129,6 +143,23 @@ def has_rsample(self): | |
def has_enumerate_support(self): | ||
return getattr(self.dist_class, "has_enumerate_support", False) | ||
|
||
@classmethod | ||
def eager_log_prob(cls, *params): | ||
params, value = params[:-1], params[-1] | ||
params = params + (Variable("value", value.output),) | ||
instance = reflect(cls, *params) | ||
raw_dist, value_name, value_output, dim_to_name = instance._get_raw_dist() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I had to refactor |
||
assert value.output == value_output | ||
name_to_dim = {v: k for k, v in dim_to_name.items()} | ||
dim_to_name.update({-1 - d - len(raw_dist.batch_shape): name | ||
for d, name in enumerate(value.inputs) if name not in name_to_dim}) | ||
name_to_dim.update({v: k for k, v in dim_to_name.items() if v not in name_to_dim}) | ||
raw_log_prob = raw_dist.log_prob(to_data(value, name_to_dim=name_to_dim)) | ||
log_prob = to_funsor(raw_log_prob, Real, dim_to_name=dim_to_name) | ||
inputs = value.inputs.copy() | ||
inputs.update(instance.inputs) | ||
return log_prob.align(tuple(k for k, v in inputs.items() if k in log_prob.inputs and isinstance(v, BintType))) | ||
|
||
def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None): | ||
|
||
# note this should handle transforms correctly via distribution_to_data | ||
|
@@ -142,7 +173,8 @@ def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None): | |
sample_shape = tuple(v.size for v in sample_inputs.values()) | ||
sample_args = (sample_shape,) if get_backend() == "torch" else (rng_key, sample_shape) | ||
if self.has_rsample: | ||
raw_value = raw_dist.rsample(*sample_args) | ||
# TODO fix this hack by adding rsample and has_rsample to Independent upstream in NumPyro | ||
raw_value = getattr(raw_dist, "rsample", raw_dist.sample)(*sample_args) | ||
eb8680 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
else: | ||
raw_value = ops.detach(raw_dist.sample(*sample_args)) | ||
|
||
|
@@ -191,7 +223,13 @@ def _infer_value_domain(cls, **kwargs): | |
# rely on the underlying distribution's logic to infer the event_shape given param domains | ||
instance = cls.dist_class(**{k: dummy_numeric_array(domain) for k, domain in kwargs.items()}, | ||
validate_args=False) | ||
out_shape = instance.event_shape | ||
|
||
# Note inclusion of batch_shape here to handle independent event dimensions. | ||
# The arguments to _infer_value_domain are the .output shapes of parameters, | ||
# so any extra batch dimensions that aren't part of the instance event_shape | ||
# must be broadcasted output dimensions by construction. | ||
out_shape = instance.batch_shape + instance.event_shape | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This change to |
||
|
||
if type(instance.support).__name__ == "_IntegerInterval": | ||
out_dtype = int(instance.support.upper_bound + 1) | ||
else: | ||
|
@@ -400,10 +438,32 @@ def __call__(self, cls, args, kwargs): | |
|
||
@to_data.register(Distribution) | ||
def distribution_to_data(funsor_dist, name_to_dim=None): | ||
params = [to_data(getattr(funsor_dist, param_name), name_to_dim=name_to_dim) | ||
for param_name in funsor_dist._ast_fields if param_name != 'value'] | ||
pyro_dist = funsor_dist.dist_class(**dict(zip(funsor_dist._ast_fields[:-1], params))) | ||
funsor_event_shape = funsor_dist.value.output.shape | ||
|
||
# attempt to generically infer the independent output dimensions | ||
instance = funsor_dist.dist_class(**{ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Beyond the scope of this PR, I'm concerned with the increasing overhead of shape computations that need to do tensor ops. I like @fehiepsi's recent suggestion of implementing (Indeed in theory an optimizing compiler could remove all this overhead, but in practice our tensor backends either incur super-linear compile time cost, or fail to cover the wide range of probabilistic models we would like to handle. And while these dummy tensor ops are cheap, they add noise to debugging efforts.) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, I agree the repeated creation of distribution instances here is not ideal. Perhaps we could add counterparts of some of the shape inference methods from TFP (e.g. |
||
k: dummy_numeric_array(domain) | ||
for k, domain in zip(funsor_dist._ast_fields, (v.output for v in funsor_dist._ast_values)) | ||
if k != "value" | ||
}, validate_args=False) | ||
eb8680 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
event_shape = broadcast_shape(instance.event_shape, funsor_dist.value.output.shape) | ||
# XXX is this final broadcast_shape necessary? should just be event_shape[...]? | ||
indep_shape = broadcast_shape(instance.batch_shape, event_shape[:len(event_shape) - len(instance.event_shape)]) | ||
eb8680 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
params = [] | ||
for param_name, funsor_param in zip(funsor_dist._ast_fields, funsor_dist._ast_values[:-1]): | ||
param = to_data(funsor_param, name_to_dim=name_to_dim) | ||
|
||
# infer the independent dimensions of each parameter separately, since we chose to keep them unbroadcasted | ||
param_event_shape = getattr(funsor_dist._infer_param_domain(param_name, funsor_param.output.shape), "shape", ()) | ||
param_indep_shape = funsor_param.output.shape[:len(funsor_param.output.shape) - len(param_event_shape)] | ||
for i in range(max(0, len(indep_shape) - len(param_indep_shape))): | ||
# add singleton event dimensions, leave broadcasting/expanding to backend | ||
param = ops.unsqueeze(param, -1 - len(funsor_param.output.shape)) | ||
|
||
params.append(param) | ||
|
||
pyro_dist = funsor_dist.dist_class(**dict(zip(funsor_dist._ast_fields[:-1], params))) | ||
pyro_dist = pyro_dist.to_event(max(len(funsor_event_shape) - len(pyro_dist.event_shape), 0)) | ||
|
||
# TODO get this working for all backends | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is the expected domain of
scale
forNormal(Reals[2], 1.)
andNormal(Reals[2], torch.ones(2))
? Currently,domains["scale"]
will beReal
in both case. The second case will trigger an error atto_funsor(v, output=domains[k])
below.In either case, I guess we need to rewrite
eager_normal
oreager_mvn
to addressReals[2]
loc. Maybe there is some trick to avoid doing so. cc @fritzoThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the first case, it's
Real
, and in the second, it'sReals[2]
. I guess I should add a second broadcasting condition below to handle the case where the parameter is a raw tensor: