Update distribution shape inference to handle independent dims #402
Merged

Commits (18):
0e95347  Update distribution shape inference to handle independent dims
5f83573  update eager_log_prob and add a test
13d94f4  remove pdb, add mvnormal test
7b8c2ee  fix alignment in eager_log_prob
88cb603  attempting a fix, and ops.unsqueeze for jax
2de25be  fix dirichletmultinomial
9a0ccf3  comments
0102210  fix tests for jax
9f85c50  add extra broadcasting condition
9f5f391  patch jax compound dists
a3a4c4b  align to align_tensors
260d2a2  lint
204b357  tweak tolerance
b8f7bf6  Merge branch 'master' into infer-independent-dims
aa56a2d  tolerance
51d7bef  rtol
e7f4441  address comments
925cdfb  fix test
```diff
@@ -16,7 +16,7 @@
 import funsor.ops as ops
 from funsor.affine import is_affine
 from funsor.cnf import Contraction, GaussianMixture
-from funsor.domains import Array, Real, Reals
+from funsor.domains import Array, Real, Reals, RealsType
 from funsor.gaussian import Gaussian
 from funsor.interpreter import gensym
 from funsor.tensor import (Tensor, align_tensors, dummy_numeric_array, get_default_prototype,
```
```diff
@@ -57,12 +57,34 @@ class DistributionMeta(FunsorMeta):
     """
     def __call__(cls, *args, **kwargs):
         kwargs.update(zip(cls._ast_fields, args))
-        value = kwargs.pop('value', 'value')
-        kwargs = OrderedDict(
-            (k, to_funsor(kwargs[k], output=cls._infer_param_domain(k, getattr(kwargs[k], "shape", ()))))
-            for k in cls._ast_fields if k != 'value')
-        value = to_funsor(value, output=cls._infer_value_domain(**{k: v.output for k, v in kwargs.items()}))
-        args = numbers_to_tensors(*(tuple(kwargs.values()) + (value,)))
+        kwargs["value"] = kwargs.get("value", "value")
+        kwargs = OrderedDict((k, kwargs[k]) for k in cls._ast_fields)  # make sure args are sorted
+
+        domains = OrderedDict()
+        for k, v in kwargs.items():
+            if k == "value":
+                continue
+
+            # compute unbroadcasted param domains
+            domain = cls._infer_param_domain(k, getattr(kwargs[k], "shape", ()))
+            # use to_funsor to infer output dimensions of e.g. tensors
+            domains[k] = domain if domain is not None else to_funsor(v).output
+
+            # broadcast individual param domains with Funsor inputs
+            # this avoids .expand-ing underlying parameter tensors
+            if isinstance(v, Funsor) and isinstance(v.output, RealsType):
+                domains[k] = Reals[broadcast_shape(v.shape, domains[k].shape)]
+
+        # now use the broadcasted parameter shapes to infer the event_shape
+        domains["value"] = cls._infer_value_domain(**domains)
+        if isinstance(kwargs["value"], Funsor) and isinstance(kwargs["value"].output, RealsType):
+            # try to broadcast the event shape with the value, in case they disagree
+            domains["value"] = Reals[broadcast_shape(domains["value"].shape, kwargs["value"].output.shape)]
+
+        # finally, perform conversions to funsors
+        kwargs = OrderedDict((k, to_funsor(v, output=domains[k])) for k, v in kwargs.items())
+        args = numbers_to_tensors(*kwargs.values())
+
         return super(DistributionMeta, cls).__call__(*args)
```
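The diff above relies on a `broadcast_shape` helper to merge parameter output shapes. Assuming it follows NumPy-style broadcasting semantics (shapes right-aligned, each paired dimension equal or 1), a minimal self-contained sketch of that behavior looks like:

```python
import itertools

def broadcast_shape(*shapes):
    """NumPy-style broadcasting sketch: right-align the shapes; at each
    position, sizes must be equal or 1, and the result takes the larger size."""
    reversed_dims = []
    for sizes in itertools.zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        non_one = {size for size in sizes if size != 1}
        if len(non_one) > 1:
            raise ValueError("shape mismatch: %r" % (shapes,))
        reversed_dims.append(non_one.pop() if non_one else 1)
    return tuple(reversed(reversed_dims))
```

For example, a `loc` with output shape `(2,)` and a scalar domain shape `()` broadcast to `(2,)`, which is how a parameter's extra output dimensions survive into the inferred domains without `.expand`-ing the underlying tensors.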
```diff
@@ -191,7 +213,13 @@ def _infer_value_domain(cls, **kwargs):
         # rely on the underlying distribution's logic to infer the event_shape given param domains
         instance = cls.dist_class(**{k: dummy_numeric_array(domain) for k, domain in kwargs.items()},
                                   validate_args=False)
-        out_shape = instance.event_shape
+
+        # Note inclusion of batch_shape here to handle independent event dimensions.
+        # The arguments to _infer_value_domain are the .output shapes of parameters,
+        # so any extra batch dimensions that aren't part of the instance event_shape
+        # must be broadcasted output dimensions by construction.
+        out_shape = instance.batch_shape + instance.event_shape

         if type(instance.support).__name__ == "_IntegerInterval":
             out_dtype = int(instance.support.upper_bound + 1)
         else:
```
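To see why folding in `batch_shape` matters, consider a hypothetical `Normal` whose `loc` domain is `Reals[2]` and whose `scale` is `Real`: the dummy instance built from those parameter domains has `batch_shape == (2,)` and `event_shape == ()`, and that extra batch dimension is exactly the independent output dimension the value domain should keep. A toy, torch-free stand-in for the underlying distribution's shape logic (`MockNormal` and `DummyArray` are illustrative names, not funsor API):

```python
import itertools

class MockNormal:
    """Toy stand-in for a scalar-event distribution such as Normal:
    batch_shape is the broadcast of the parameter shapes, event_shape is empty."""
    event_shape = ()

    def __init__(self, loc, scale):
        shapes = (tuple(getattr(loc, "shape", ())), tuple(getattr(scale, "shape", ())))
        dims = []
        for sizes in itertools.zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
            dims.append(max(sizes))  # assume the shapes are broadcast-compatible
        self.batch_shape = tuple(reversed(dims))

class DummyArray:
    """Stands in for dummy_numeric_array(domain): carries only a shape."""
    def __init__(self, shape):
        self.shape = shape

# parameter domains with output shapes (2,) and ():
instance = MockNormal(DummyArray((2,)), DummyArray(()))

# old behavior: out_shape = instance.event_shape                 -> ()  (loses the dim)
# new behavior: out_shape = batch_shape + event_shape            -> (2,)
out_shape = instance.batch_shape + instance.event_shape
```

Under the old rule the inferred value domain would be `Real`, silently dropping the independent dimension; under the new rule it is `Reals[2]`.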
Review comment:

What is the expected domain of `scale` for `Normal(Reals[2], 1.)` and `Normal(Reals[2], torch.ones(2))`? Currently, `domains["scale"]` will be `Real` in both cases. The second case will trigger an error at `to_funsor(v, output=domains[k])` below. In either case, I guess we need to rewrite `eager_normal` or `eager_mvn` to address the `Reals[2]` loc. Maybe there is some trick to avoid doing so. cc @fritzo
Reply:

In the first case, it's `Real`, and in the second, it's `Reals[2]`. I guess I should add a second broadcasting condition below to handle the case where the parameter is a raw tensor:
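A rough sketch of what such a second condition could look like, with stand-in names throughout (`infer_param_domain_shape` and `FakeTensor` are hypothetical, not funsor API): a plain number leaves the inferred domain shape alone, while a raw tensor-like value (something with a `.shape` that is not yet a Funsor) broadcasts its shape into the domain shape:

```python
import itertools

def _broadcast(*shapes):
    # NumPy-style right-aligned broadcasting of shape tuples
    dims = []
    for sizes in itertools.zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        non_one = {size for size in sizes if size != 1}
        if len(non_one) > 1:
            raise ValueError("shape mismatch: %r" % (shapes,))
        dims.append(non_one.pop() if non_one else 1)
    return tuple(reversed(dims))

def infer_param_domain_shape(unbroadcast_shape, value):
    """Hypothetical second condition: broadcast the inferred domain shape
    against the shape of a raw tensor-like parameter value."""
    value_shape = tuple(getattr(value, "shape", ()))
    return _broadcast(unbroadcast_shape, value_shape)

class FakeTensor:
    """Stands in for torch.ones(2)."""
    shape = (2,)

# Normal(Reals[2], 1.): scalar scale keeps domain shape ()  -> Real
# Normal(Reals[2], torch.ones(2)): raw tensor scale         -> Reals[2]
```

This would make the two cases in the comment above resolve to `Real` and `Reals[2]` respectively, matching the reply.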