Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update type annotations for Pyro 1.9 #527

Open
wants to merge 53 commits into
base: master
Choose a base branch
from
Open

Conversation

eb8680
Copy link
Contributor

@eb8680 eb8680 commented Feb 23, 2024

Supercedes #522

@eb8680 eb8680 added bug Something isn't working upstream status:WIP Work-in-progress not yet ready for review labels Feb 23, 2024
@eb8680
Copy link
Contributor Author

eb8680 commented Mar 14, 2024

This should not be merged before the ASKEM evaluation ends in 2 weeks. It also seems likely that there will be a Pyro 1.9.1 release soon that will address some of the issues here, so maybe we should wait for that as well.

@SamWitty
Copy link
Collaborator

@eb8680 , now that the ASKEM evaluation is over would you like me to review this PR?

@eb8680
Copy link
Contributor Author

eb8680 commented Apr 25, 2024

If possible I think we should aim to make this backward compatible with Pyro 1.8 before merging. That way we can unpin the Pyro version here in ChiRho and separately pin it to <1.9.0 downstream in PyCIEMSS.

As I mentioned in a comment above, I had tried to enable backward compatibility when I made this PR, but it ended up being more work than I anticipated because there are backwards-incompatible changes to some types in Pyro 1.9.0. Unless I can come up with a simpler workaround for this problem, I'm not sure I'll have the time to fix it myself and land this PR in the next couple of weeks.

@eb8680 eb8680 added status:WIP Work-in-progress not yet ready for review and removed status:awaiting review Awaiting response from reviewer labels Apr 25, 2024
@SamWitty SamWitty added status:awaiting review Awaiting response from reviewer and removed status:WIP Work-in-progress not yet ready for review labels Apr 26, 2024
@SamWitty SamWitty self-requested a review April 26, 2024 15:36
@eb8680 eb8680 removed the blocked label Apr 26, 2024
Copy link
Collaborator

@SamWitty SamWitty left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have a bunch of miscellaneous comments here, none of which really need to be addressed before this merged. That said, if I'm correct and the types for msg throughout could be specified as Mapping[str, Any] then that would be nice.

I'm approving, but I would appreciate some brief discussion on the points I brought up.

new_shape = [1] * ((event_dim - dim) - len(new_shape)) + new_shape
new_shape[dim - event_dim] = rv.batch_shape[dim]

msg["value"] = value.expand(tuple(new_shape))

def _pyro_observe(self, msg: dict) -> None:
def _pyro_observe(self, msg) -> None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is msg now, if not a dict?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see later it's typed as Mapping[str, Any]. Would that work here?

) -> torch.Tensor:
with pyro.poutine.infer_config(config_fn=no_ambiguity):
with pyro.poutine.infer_config(config_fn=no_ambiguity): # type: ignore
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If possible, could we add a comment (or an issue) describing what will need to happen upsteam for this type: ignore to be removed?

new_base_dist = dist.Delta(value, event_dim=obs_event_dim).mask(False)
new_noise_dist = dist.TransformedDistribution(new_base_dist, tfm.inv)
new_noise_dist = dist.TransformedDistribution(new_base_dist, tfm.inv) # type: ignore
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why can't it infer that new_noise_dist is a dist.TorchDistribution?

@@ -23,21 +25,21 @@ class FactualConditioningMessenger(pyro.poutine.messenger.Messenger):
counterfactual semantics handlers such as :class:`MultiWorldCounterfactual` .
"""

def _pyro_post_sample(self, msg: dict) -> None:
def _pyro_post_sample(self, msg) -> None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See below comment about removal of dict

@@ -55,7 +56,7 @@ def __init__(self, times: torch.Tensor, is_traced: bool = False):

super().__init__()

def _pyro_post_simulate(self, msg: dict) -> None:
def _pyro_post_simulate(self, msg) -> None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See comment in ambiguity.py about dropping the msg type. Could this be replaced with Mapping[str, Any]?

from chirho.dynamical.handlers.interruption import StaticEvent

dynamics: Dynamics[T] = msg["args"][0]
state: State[T] = msg["args"][1]
dynamics: Dynamics[T] = typing.cast(Dynamics[T], msg["args"][0])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the type specification Dynamics[T] in dynamics: Dynamics[T] still necessary here now that you're using typing.cast?

@@ -187,9 +200,9 @@ def _indices_of_tensor(value: torch.Tensor, **kwargs) -> IndexSet:
return indices_of(value.shape, **kwargs)


@indices_of.register
@indices_of.register(pyro.distributions.Distribution)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Slightly annoying that singledispatch dispatches on pyro.distributions.Distribution but mypy wants TorchDistribution types. Nothing really to do here... just commenting for my own benefit.

@@ -241,7 +254,7 @@ def frame(self) -> CondIndepStackFrame:
name=self.name, dim=self.dim, size=self.size, counter=0
)

def _process_message(self, msg):
def _process_message(self, msg) -> None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What other return type might there be?

import pyro.distributions.constraints as constraints
import pyro.distributions.torch as dist
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did filepaths move upstream in pyro?

observed_value = msg["value"]
computed_value = msg["fn"].base_dist.v
fn: _MaskedDelta = msg["fn"]
is_observed: Literal[True] = msg["is_observed"]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do we know that msg["is_observed"] is True?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we use a bool type here and then assert is_observed instead?

@SamWitty SamWitty added status:awaiting response Awaiting response from creator and removed status:awaiting review Awaiting response from reviewer labels Apr 29, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working status:awaiting response Awaiting response from creator upstream
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants