From 5f3ed9f6295bfa734cb1bfb738a6ece78d317e9e Mon Sep 17 00:00:00 2001 From: Sam Witty Date: Wed, 20 Nov 2024 21:09:23 -0500 Subject: [PATCH] unsqueeze time to handle shape errors --- pyciemss/compiled_dynamics.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/pyciemss/compiled_dynamics.py b/pyciemss/compiled_dynamics.py index 70613df0..83b48784 100644 --- a/pyciemss/compiled_dynamics.py +++ b/pyciemss/compiled_dynamics.py @@ -11,6 +11,7 @@ import pyro import torch from chirho.dynamical.handlers import LogTrajectory +from chirho.dynamical.internals._utils import _squeeze_time_dim from chirho.dynamical.ops import State, simulate S = TypeVar("S") @@ -236,9 +237,8 @@ def __init__(self, times: torch.Tensor, model: CompiledDynamics): super().__init__() self.model = model self.observables_names: List[str] = [] - self.lt = LogTrajectory( - times - ) # This gets around the issue of the LogTrajectory handler blocking `self` + # This gets around the issue of the LogTrajectory handler blocking `self` + self.lt = LogTrajectory(times) def _pyro_simulate_point(self, msg): self.lt._pyro_simulate_point(msg) @@ -249,14 +249,18 @@ def _pyro_post_simulate_trajectory(self, msg): msg["value"] = {**msg["value"], **observables} def _pyro_post_simulate(self, msg): + initial_state = msg["args"][1] + initial_observables = _squeeze_time_dim(self.model.observables(initial_state)) msg["args"] = ( msg["args"][0], - {**initial_state, **self.model.observables(initial_state)}, + {**initial_state, **initial_observables}, msg["args"][2], msg["args"][3], ) + self.lt._pyro_post_simulate(msg) + self.observables = { k: v for k, v in self.lt.trajectory.items() if k in self.observables_names }