Add more tests for collapse #2702
base: dev
```diff
@@ -2344,6 +2344,58 @@ def guide():
     assert_ok(model, guide, elbo)
 
 
+@pytest.mark.stage("funsor")
+@pytest.mark.parametrize("num_particles", [1, 2])
+def test_collapse_diag_normal_plate_normal(num_particles):
+    T, d = 5, 3
+    data = torch.randn((T, d))
+
+    def model():
+        x = pyro.sample("x", dist.Normal(0., 1.)).unsqueeze(-1)
+        with poutine.collapse():
+            with pyro.plate("data", T, dim=-1):
+                expand_shape = (d,) if num_particles == 1 else (num_particles, 1, d)
+                y = pyro.sample("y", dist.Normal(x, 1.).expand(expand_shape).to_event(1))
+                pyro.sample("z", dist.Normal(y, 1.).to_event(1), obs=data)
+
+    def guide():
+        loc = pyro.param("loc", torch.tensor(0.))
+        scale = pyro.param("scale", torch.tensor(1.), constraint=constraints.positive)
+        pyro.sample("x", dist.Normal(loc, scale))
+
+    elbo = Trace_ELBO(num_particles=num_particles, vectorize_particles=True,
+                      max_plate_nesting=1)
+    assert_ok(model, guide, elbo)
```
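As background for the `.expand(...).to_event(1)` calls in these tests: Pyro splits a distribution's full shape into batch dims and event dims, and `.to_event(n)` reinterprets the rightmost `n` batch dims as event dims. A minimal pure-Python sketch of that bookkeeping (`to_event_shapes` is a hypothetical helper for illustration, not a Pyro API):

```python
# Hypothetical helper (not a Pyro API): a pure-Python sketch of the shape
# bookkeeping behind dist.Normal(x, 1.).expand(expand_shape).to_event(1).
# .to_event(n) moves the rightmost n dims of the batch shape into the event shape.
def to_event_shapes(expanded_shape, reinterpreted_batch_ndims):
    split = len(expanded_shape) - reinterpreted_batch_ndims
    batch_shape = tuple(expanded_shape[:split])
    event_shape = tuple(expanded_shape[split:])
    return batch_shape, event_shape

# num_particles == 1: .expand((d,)).to_event(1) with d = 3
assert to_event_shapes((3,), 1) == ((), (3,))
# num_particles == 2: .expand((num_particles, 1, d)).to_event(1)
assert to_event_shapes((2, 1, 3), 1) == ((2, 1), (3,))
```

This batch/event split is exactly the information the funsor conversion discussed below needs to recover.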
```diff
+
+
+@pytest.mark.stage("funsor")
+@pytest.mark.parametrize("num_particles", [1, 2])
+def test_collapse_normal_mvn_mvn(num_particles):
+    T, d, S = 5, 3, 4
+    data = torch.randn((T, d, S))
+
+    def model():
+        x = pyro.sample("x", dist.Normal(0., 1.)).unsqueeze(-1)
+        with poutine.collapse():
+            with pyro.plate("d", d, dim=-1):
+                expand_shape = (d, S) if num_particles == 1 else (num_particles, d, S)
+                beta0 = pyro.sample("beta0", dist.Normal(x, 1.).expand(expand_shape).to_event(1))
+                beta = pyro.sample("beta", dist.MultivariateNormal(beta0, torch.eye(S)))
+
+            mean = torch.ones((T, d)) @ beta
```
Review comment (on `mean = torch.ones((T, d)) @ beta`):

This fails because […]

Reply:

Yeah, this seems like a major incompatibility between funsor-style operations, which have a clear batch/event split, and numpy-style operations, where everything is effectively an event dim. In particular, there's no way for […]. One workaround in this model would be to move […]

Reply:

It seems like if we had a way of knowing the plate context of a value, both at the time of its creation and each time it is accessed, we could handle this smoothly using:

```python
with plate("plate_var", d, dim=-1):
    beta0 = pyro.sample("beta0", dist.Normal(x, 1.).expand(expand_shape).to_event(1))
    # setting env.beta also records the current plate context of beta
    env.beta = pyro.sample("beta", dist.MultivariateNormal(beta0, torch.eye(S)))

# use funsor.terms.Lambda to convert the dead plate dimension to an output:
# now reading env.beta returns funsor.Lambda(plate_var, beta),
# where plate_var is the difference between the current and original plate contexts
mean = torch.ones((T, d)) @ env.beta
```

Reply:

Oh, I didn't know about […]

Reply:

I was suggesting it as a layer of automation on top of […]
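The `env` object in the comment above is hypothetical. The plate-context bookkeeping it would need can be sketched in plain Python (assumptions: `Env`, `active_plates`, and `dead_plates` are all invented names here, not Pyro or funsor APIs):

```python
# Hypothetical sketch: record the active plate context when a value is stored,
# so that each later read can see which plates have been exited since creation.
# A real implementation could apply funsor.terms.Lambda over those "dead" plate
# dims to turn them into output (event) dims.
class Env:
    def __init__(self):
        # bypass our __setattr__ for internal bookkeeping attributes
        object.__setattr__(self, "_values", {})
        object.__setattr__(self, "_contexts", {})
        object.__setattr__(self, "active_plates", set())

    def __setattr__(self, name, value):
        if name == "active_plates":
            object.__setattr__(self, name, value)
        else:
            # store the value together with the plates active right now
            self._values[name] = value
            self._contexts[name] = frozenset(self.active_plates)

    def __getattr__(self, name):
        # read a stored value; a real implementation would wrap it in
        # funsor.Lambda over self.dead_plates(name)
        return self._values[name]

    def dead_plates(self, name):
        # plates active at creation time but no longer active now
        return set(self._contexts[name] - frozenset(self.active_plates))

env = Env()
env.active_plates = {"plate_var"}  # entering plate("plate_var", ...)
env.beta = "beta-sample"           # recorded with context {"plate_var"}
env.active_plates = set()          # plate exited
assert env.beta == "beta-sample"
assert env.dead_plates("beta") == {"plate_var"}
```

The point of the sketch is only the two pieces of state the comment asks for: the plate context at creation time and the plate context at access time, whose difference identifies the dims that `funsor.terms.Lambda` would convert to outputs.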
```diff
+            with pyro.plate("data", T, dim=-1):
+                pyro.sample("obs", dist.MultivariateNormal(mean, torch.eye(S)), obs=data)
```
Review comment (on the `pyro.sample("obs", ...)` line):

This fails due to a similar reason to the […]
```diff
+
+    def guide():
+        loc = pyro.param("loc", torch.tensor(0.))
+        scale = pyro.param("scale", torch.tensor(1.), constraint=constraints.positive)
+        pyro.sample("x", dist.Normal(loc, scale))
+
+    elbo = Trace_ELBO(num_particles=num_particles, vectorize_particles=True,
+                      max_plate_nesting=1)
+    assert_ok(model, guide, elbo)
+
+
+@pytest.mark.stage("funsor")
+@pytest.mark.parametrize("num_particles", [1, 2])
+def test_collapse_beta_bernoulli(num_particles):
```
Review comment:

This fails due to 2 reasons: `y.output` is `Reals[d]`, which will raise an error in `infer_param_domain` in funsor (`Output mismatch: Reals[2] vs Real`); […]
Reply:

Hmm, I wonder if there is always enough information available in the args of expanded distributions to automatically determine the event dim, so that we could make `.to_event()` a no-op on funsors. For example, here we could deduce the event shape from `y.output`. @eb8680 would it be possible to support this in `to_funsor()` and `to_data()`?
Reply:

It should be possible, at least in principle, and if we want `collapse` to work seamlessly with models that use `.to_event`, we'll need something like that. I think it will require changing the way type inference works in `funsor.distribution.Distribution`, though.

The point of failure in Funsor is the `to_funsor` conversion of parameters in `funsor.distribution.DistributionMeta.__call__`: […]

In this test case, `kwargs[k]` is `loc = Variable("y", Reals[2])`, and `Normal._infer_param_domain(...)` thinks the output should be `Real` instead of `Reals[2]`.

I think a general solution (at least when all parameters have the same broadcasted `.output`) is to compute unbroadcasted parameter and value shapes up front, then broadcast: […]

Now the previously incorrect `to_funsor` conversions reduce to […]

WDYT?
Reply:

Sounds reasonable to me.
Reply:

Ok, I put up a draft Funsor PR here: pyro-ppl/funsor#402
Reply:

Thanks, Eli! Testing now...
Reply:

I believe what's in that PR after my recent push is sufficient for this test case, although it's not ready to merge because of some edge cases in the distributions. Let me know if it's not working.