Skip to content

Commit

Permalink
unsqueeze time to handle shape errors
Browse files Browse the repository at this point in the history
  • Loading branch information
SamWitty committed Nov 21, 2024
1 parent f3e0730 commit 5f3ed9f
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions pyciemss/compiled_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import pyro
import torch
from chirho.dynamical.handlers import LogTrajectory
from chirho.dynamical.internals._utils import _squeeze_time_dim
from chirho.dynamical.ops import State, simulate

S = TypeVar("S")
Expand Down Expand Up @@ -236,9 +237,8 @@ def __init__(self, times: torch.Tensor, model: CompiledDynamics):
super().__init__()
self.model = model
self.observables_names: List[str] = []
self.lt = LogTrajectory(
times
) # This gets around the issue of the LogTrajectory handler blocking `self`
# This gets around the issue of the LogTrajectory handler blocking `self`
self.lt = LogTrajectory(times)

def _pyro_simulate_point(self, msg):
self.lt._pyro_simulate_point(msg)
Expand All @@ -249,14 +249,18 @@ def _pyro_post_simulate_trajectory(self, msg):
msg["value"] = {**msg["value"], **observables}

def _pyro_post_simulate(self, msg):

initial_state = msg["args"][1]
initial_observables = _squeeze_time_dim(self.model.observables(initial_state))
msg["args"] = (
msg["args"][0],
{**initial_state, **self.model.observables(initial_state)},
{**initial_state, **initial_observables},
msg["args"][2],
msg["args"][3],
)

self.lt._pyro_post_simulate(msg)

self.observables = {
k: v for k, v in self.lt.trajectory.items() if k in self.observables_names
}
Expand Down

0 comments on commit 5f3ed9f

Please sign in to comment.