-
Notifications
You must be signed in to change notification settings - Fork 12
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Update type annotations for Pyro 1.9 #527
base: master
Are you sure you want to change the base?
Conversation
This should not be merged before the ASKEM evaluation ends in 2 weeks. It also seems likely that there will be a Pyro 1.9.1 release soon that will address some of the issues here, so maybe we should wait for that as well. |
@eb8680 , now that the ASKEM evaluation is over would you like me to review this PR? |
If possible I think we should aim to make this backward compatible with Pyro 1.8 before merging. That way we can unpin the Pyro version here in ChiRho and separately pin it to As I mentioned in a comment above, I had tried to enable backward compatibility when I made this PR, but it ended up being more work than I anticipated because there are backwards-incompatible changes to some types in Pyro 1.9.0. Unless I can come up with a simpler workaround for this problem, I'm not sure I'll have the time to fix it myself and land this PR in the next couple of weeks. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have a bunch of miscellaneous comments here, none of which really need to be addressed before this merged. That said, if I'm correct and the types for msg
throughout could be specified as Mapping[str, Any]
then that would be nice.
I'm approving, but I would appreciate some brief discussion on the points I brought up.
new_shape = [1] * ((event_dim - dim) - len(new_shape)) + new_shape | ||
new_shape[dim - event_dim] = rv.batch_shape[dim] | ||
|
||
msg["value"] = value.expand(tuple(new_shape)) | ||
|
||
def _pyro_observe(self, msg: dict) -> None: | ||
def _pyro_observe(self, msg) -> None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is msg
now, if not a dict
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see later it's typed as Mapping[str, Any]
. Would that work here?
) -> torch.Tensor: | ||
with pyro.poutine.infer_config(config_fn=no_ambiguity): | ||
with pyro.poutine.infer_config(config_fn=no_ambiguity): # type: ignore |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If possible, could we add a comment (or an issue) describing what will need to happen upsteam for this type: ignore
to be removed?
new_base_dist = dist.Delta(value, event_dim=obs_event_dim).mask(False) | ||
new_noise_dist = dist.TransformedDistribution(new_base_dist, tfm.inv) | ||
new_noise_dist = dist.TransformedDistribution(new_base_dist, tfm.inv) # type: ignore |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why can't it infer that new_noise_dist
is a dist.TorchDistribution
?
@@ -23,21 +25,21 @@ class FactualConditioningMessenger(pyro.poutine.messenger.Messenger): | |||
counterfactual semantics handlers such as :class:`MultiWorldCounterfactual` . | |||
""" | |||
|
|||
def _pyro_post_sample(self, msg: dict) -> None: | |||
def _pyro_post_sample(self, msg) -> None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See below comment about removal of dict
@@ -55,7 +56,7 @@ def __init__(self, times: torch.Tensor, is_traced: bool = False): | |||
|
|||
super().__init__() | |||
|
|||
def _pyro_post_simulate(self, msg: dict) -> None: | |||
def _pyro_post_simulate(self, msg) -> None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See comment in ambiguity.py
about dropping the msg
type. Could this be replaced with Mapping[str, Any]
?
from chirho.dynamical.handlers.interruption import StaticEvent | ||
|
||
dynamics: Dynamics[T] = msg["args"][0] | ||
state: State[T] = msg["args"][1] | ||
dynamics: Dynamics[T] = typing.cast(Dynamics[T], msg["args"][0]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is the type specification Dynamics[T]
in dynamics: Dynamics[T]
still necessary here now that you're using typing.cast
?
@@ -187,9 +200,9 @@ def _indices_of_tensor(value: torch.Tensor, **kwargs) -> IndexSet: | |||
return indices_of(value.shape, **kwargs) | |||
|
|||
|
|||
@indices_of.register | |||
@indices_of.register(pyro.distributions.Distribution) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Slightly annoying that singledispatch
dispatches on pyro.distributions.Distribution
but mypy
wants TorchDistribution
types. Nothing really to do here... just commenting for my own benefit.
@@ -241,7 +254,7 @@ def frame(self) -> CondIndepStackFrame: | |||
name=self.name, dim=self.dim, size=self.size, counter=0 | |||
) | |||
|
|||
def _process_message(self, msg): | |||
def _process_message(self, msg) -> None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What other return type might there be?
import pyro.distributions.constraints as constraints | ||
import pyro.distributions.torch as dist |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Did filepaths move upstream in pyro
?
observed_value = msg["value"] | ||
computed_value = msg["fn"].base_dist.v | ||
fn: _MaskedDelta = msg["fn"] | ||
is_observed: Literal[True] = msg["is_observed"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How do we know that msg["is_observed"]
is True
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we use a bool
type here and then assert is_observed
instead?
Supercedes #522