From 6dff41ba7b1bd7969bca2de6c811bce04c65a1e7 Mon Sep 17 00:00:00 2001 From: sabinala Date: Mon, 26 Feb 2024 11:53:40 -0800 Subject: [PATCH] adding notebook for error message --- .../step_size_error_producing_ntbk.ipynb | 280 ++++++++++++++++++ 1 file changed, 280 insertions(+) create mode 100644 docs/source/step_size_error_producing_ntbk.ipynb diff --git a/docs/source/step_size_error_producing_ntbk.ipynb b/docs/source/step_size_error_producing_ntbk.ipynb new file mode 100644 index 000000000..6379218f5 --- /dev/null +++ b/docs/source/step_size_error_producing_ntbk.ipynb @@ -0,0 +1,280 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "76eca136", + "metadata": {}, + "outputs": [], + "source": [ + "# Notebook for reproducing Issue " + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "c2bc657d", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import pyciemss\n", + "import torch" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "2153cf4c", + "metadata": {}, + "outputs": [], + "source": [ + "MODEL_PATH = \"https://raw.githubusercontent.com/DARPA-ASKEM/simulation-integration/main/data/models/\"\n", + "model3 = os.path.join(MODEL_PATH, \"SIR_stockflow.json\")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "bb52bb9c", + "metadata": {}, + "outputs": [], + "source": [ + "start_time = 0.0\n", + "end_time = 100.\n", + "logging_step_size = 10.0\n", + "num_samples = 10" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "7a07d073", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "ERROR:root:\n", + " ###############################\n", + "\n", + " There was an exception in pyciemss\n", + "\n", + " Error occured in function: sample\n", + "\n", + " Function docs : \n", + " Load a model from a file, compile it into a probabilistic program, and sample from it.\n", + "\n", + " Args:\n", + " model_path_or_json: Union[str, Dict]\n", + " - A path to a AMR model file or JSON containing a model in AMR form.\n", + " end_time: float\n", + " - The end time of the sampled simulation.\n", + " logging_step_size: float\n", + " - The step size to use for logging the trajectory.\n", + " num_samples: int\n", + " - The number of samples to draw from the model.\n", + " solver_method: str\n", + " - The method to use for solving the ODE. See torchdiffeq's `odeint` method for more details.\n", + " - If performance is incredibly slow, we suggest using `euler` to debug.\n", + " If using `euler` results in faster simulation, the issue is likely that the model is stiff.\n", + " solver_options: Dict[str, Any]\n", + " - Options to pass to the solver. See torchdiffeq' `odeint` method for more details.\n", + " start_time: float\n", + " - The start time of the model. This is used to align the `start_state` from the\n", + " AMR model with the simulation timepoints.\n", + " - By default we set the `start_time` to be 0.\n", + " inferred_parameters: Optional[pyro.nn.PyroModule]\n", + " - A Pyro module that contains the inferred parameters of the model.\n", + " This is typically the result of `calibrate`.\n", + " - If not provided, we will use the default values from the AMR model.\n", + " static_state_interventions: Dict[float, Dict[str, Intervention]]\n", + " - A dictionary of static interventions to apply to the model.\n", + " - Each key is the time at which the intervention is applied.\n", + " - Each value is a dictionary of the form {state_variable_name: intervention_assignment}.\n", + " - Note that the `intervention_assignment` can be any type supported by\n", + " :func:`~chirho.interventional.ops.intervene`, including functions.\n", + " static_parameter_interventions: Dict[float, Dict[str, Intervention]]\n", + " - A dictionary of static interventions to apply to the model.\n", + " - Each key is the time at which the intervention is applied.\n", + " - Each value is a dictionary of the form {parameter_name: intervention_assignment}.\n", + " - Note that the `intervention_assignment` can be any type supported by\n", + " :func:`~chirho.interventional.ops.intervene`, including functions.\n", + " dynamic_state_interventions: Dict[\n", + " Callable[[torch.Tensor, Dict[str, torch.Tensor]], torch.Tensor],\n", + " Dict[str, Intervention]\n", + " ]\n", + " - A dictionary of dynamic interventions to apply to the model.\n", + " - Each key is a function that takes in the current state of the model and returns a tensor.\n", + " When this function crosses 0, the dynamic intervention is applied.\n", + " - Each value is a dictionary of the form {state_variable_name: intervention_assignment}.\n", + " - Note that the `intervention_assignment` can be any type supported by\n", + " :func:`~chirho.interventional.ops.intervene`, including functions.\n", + " dynamic_parameter_interventions: Dict[\n", + " Callable[[torch.Tensor, Dict[str, torch.Tensor]], torch.Tensor],\n", + " Dict[str, Intervention]\n", + " ]\n", + " - A dictionary of dynamic interventions to apply to the model.\n", + " - Each key is a function that takes in the current state of the model and returns a tensor.\n", + " When this function crosses 0, the dynamic intervention is applied.\n", + " - Each value is a dictionary of the form {parameter_name: intervention_assignment}.\n", + " - Note that the `intervention_assignment` can be any type supported by\n", + " :func:`~chirho.interventional.ops.intervene`, including functions.\n", + "\n", + " Returns:\n", + " result: Dict[str, torch.Tensor]\n", + " - Dictionary of outputs from the model.\n", + " - Each key is the name of a parameter or state variable in the model.\n", + " - Each value is a tensor of shape (num_samples, num_timepoints) for state variables\n", + " and (num_samples,) for parameters.\n", + " \n", + "\n", + " ################################\n", + " \n", + "Traceback (most recent call last):\n", + " File \"/Users/altu809/Projects/pyciemss/pyciemss/integration_utils/custom_decorators.py\", line 10, in wrapped\n", + " result = function(*args, **kwargs)\n", + " File \"/Users/altu809/Projects/pyciemss/pyciemss/interfaces.py\", line 298, in sample\n", + " samples = pyro.infer.Predictive(\n", + " File \"/Users/altu809/anaconda3/lib/python3.10/site-packages/torch/nn/modules/module.py\", line 1130, in _call_impl\n", + " return forward_call(*input, **kwargs)\n", + " File \"/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/infer/predictive.py\", line 273, in forward\n", + " return _predictive(\n", + " File \"/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/infer/predictive.py\", line 78, in _predictive\n", + " max_plate_nesting = _guess_max_plate_nesting(model, model_args, model_kwargs)\n", + " File \"/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/infer/predictive.py\", line 21, in _guess_max_plate_nesting\n", + " model_trace = poutine.trace(model).get_trace(*args, **kwargs)\n", + " File \"/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py\", line 198, in get_trace\n", + " self(*args, **kwargs)\n", + " File \"/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py\", line 174, in __call__\n", + " ret = self.fn(*args, **kwargs)\n", + " File \"/Users/altu809/anaconda3/lib/python3.10/site-packages/torch/autograd/grad_mode.py\", line 27, in decorate_context\n", + " return func(*args, **kwargs)\n", + " File \"/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py\", line 12, in _context_wrap\n", + " return fn(*args, **kwargs)\n", + " File \"/Users/altu809/Projects/pyciemss/pyciemss/interfaces.py\", line 282, in wrapped_model\n", + " full_trajectory = model(\n", + " File \"/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/nn/module.py\", line 449, in __call__\n", + " result = super().__call__(*args, **kwargs)\n", + " File \"/Users/altu809/anaconda3/lib/python3.10/site-packages/torch/nn/modules/module.py\", line 1130, in _call_impl\n", + " return forward_call(*input, **kwargs)\n", + " File \"/Users/altu809/Projects/pyciemss/pyciemss/compiled_dynamics.py\", line 77, in forward\n", + " simulate(self.deriv, self.initial_state(), start_time, end_time)\n", + " File \"/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py\", line 281, in _fn\n", + " apply_stack(msg)\n", + " File \"/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py\", line 212, in apply_stack\n", + " frame._process_message(msg)\n", + " File \"/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py\", line 162, in _process_message\n", + " return method(msg)\n", + " File \"/Users/altu809/anaconda3/lib/python3.10/site-packages/chirho/dynamical/internals/solver.py\", line 109, in _pyro_simulate\n", + " state, start_time, next_interruption = simulate_to_interruption(\n", + " File \"/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py\", line 281, in _fn\n", + " apply_stack(msg)\n", + " File \"/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py\", line 212, in apply_stack\n", + " frame._process_message(msg)\n", + " File \"/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py\", line 162, in _process_message\n", + " return method(msg)\n", + " File \"/Users/altu809/anaconda3/lib/python3.10/site-packages/chirho/dynamical/handlers/solver.py\", line 89, in _pyro_simulate_to_interruption\n", + " msg[\"value\"] = torchdiffeq_simulate_to_interruption(\n", + " File \"/Users/altu809/anaconda3/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py\", line 244, in torchdiffeq_simulate_to_interruption\n", + " (next_interruption,), interruption_time = _torchdiffeq_get_next_interruptions(\n", + " File \"/Users/altu809/anaconda3/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py\", line 192, in _torchdiffeq_get_next_interruptions\n", + " event_time, event_solutions = _batched_odeint( # torchdiffeq.odeint_event(\n", + " File \"/Users/altu809/anaconda3/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py\", line 119, in _batched_odeint\n", + " event_t, yt_raw = torchdiffeq.odeint_event(\n", + " File \"/Users/altu809/anaconda3/lib/python3.10/site-packages/torchdiffeq/_impl/odeint.py\", line 101, in odeint_event\n", + " event_t, solution = odeint_interface(func, y0, t, event_fn=event_fn, **kwargs)\n", + " File \"/Users/altu809/anaconda3/lib/python3.10/site-packages/torchdiffeq/_impl/odeint.py\", line 79, in odeint\n", + " event_t, solution = solver.integrate_until_event(t[0], event_fn)\n", + " File \"/Users/altu809/anaconda3/lib/python3.10/site-packages/torchdiffeq/_impl/solvers.py\", line 122, in integrate_until_event\n", + " assert self.step_size is not None, \"Event handling for fixed step solvers currently requires `step_size` to be provided in options.\"\n", + "AssertionError: Event handling for fixed step solvers currently requires `step_size` to be provided in options.\n" + ] + }, + { + "ename": "AssertionError", + "evalue": "Event handling for fixed step solvers currently requires `step_size` to be provided in options.", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[5], line 7\u001b[0m\n\u001b[1;32m 4\u001b[0m infection_threshold \u001b[38;5;241m=\u001b[39m make_var_threshold(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mI\u001b[39m\u001b[38;5;124m\"\u001b[39m, torch\u001b[38;5;241m.\u001b[39mtensor(\u001b[38;5;241m400.0\u001b[39m))\n\u001b[1;32m 5\u001b[0m dynamic_parameter_interventions1 \u001b[38;5;241m=\u001b[39m {infection_threshold: {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mp_cbeta\u001b[39m\u001b[38;5;124m\"\u001b[39m: torch\u001b[38;5;241m.\u001b[39mtensor(\u001b[38;5;241m0.3\u001b[39m)}}\n\u001b[0;32m----> 7\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[43mpyciemss\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msample\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel3\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mend_time\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlogging_step_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_samples\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstart_time\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstart_time\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\n\u001b[1;32m 8\u001b[0m \u001b[43m \u001b[49m\u001b[43mdynamic_parameter_interventions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdynamic_parameter_interventions1\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\n\u001b[1;32m 9\u001b[0m \u001b[43m \u001b[49m\u001b[43msolver_method\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43meuler\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/Projects/pyciemss/pyciemss/integration_utils/custom_decorators.py:29\u001b[0m, in \u001b[0;36mpyciemss_logging_wrapper..wrapped\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 17\u001b[0m log_message \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\"\"\u001b[39m\n\u001b[1;32m 18\u001b[0m \u001b[38;5;124m ###############################\u001b[39m\n\u001b[1;32m 19\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[38;5;124m ################################\u001b[39m\n\u001b[1;32m 27\u001b[0m \u001b[38;5;124m\u001b[39m\u001b[38;5;124m\"\"\"\u001b[39m\n\u001b[1;32m 28\u001b[0m logging\u001b[38;5;241m.\u001b[39mexception(log_message, function\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m, function\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__doc__\u001b[39m)\n\u001b[0;32m---> 29\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m e\n", + "File \u001b[0;32m~/Projects/pyciemss/pyciemss/integration_utils/custom_decorators.py:10\u001b[0m, in \u001b[0;36mpyciemss_logging_wrapper..wrapped\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 9\u001b[0m start_time \u001b[38;5;241m=\u001b[39m time\u001b[38;5;241m.\u001b[39mperf_counter()\n\u001b[0;32m---> 10\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[43mfunction\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 11\u001b[0m end_time \u001b[38;5;241m=\u001b[39m time\u001b[38;5;241m.\u001b[39mperf_counter()\n\u001b[1;32m 12\u001b[0m logging\u001b[38;5;241m.\u001b[39minfo(\n\u001b[1;32m 13\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mElapsed time for \u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m: \u001b[39m\u001b[38;5;132;01m%f\u001b[39;00m\u001b[38;5;124m\"\u001b[39m, function\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m, end_time \u001b[38;5;241m-\u001b[39m start_time\n\u001b[1;32m 14\u001b[0m )\n", + "File \u001b[0;32m~/Projects/pyciemss/pyciemss/interfaces.py:298\u001b[0m, in \u001b[0;36msample\u001b[0;34m(model_path_or_json, end_time, logging_step_size, num_samples, noise_model, noise_model_kwargs, solver_method, solver_options, start_time, inferred_parameters, static_state_interventions, static_parameter_interventions, dynamic_state_interventions, dynamic_parameter_interventions)\u001b[0m\n\u001b[1;32m 294\u001b[0m compiled_noise_model(full_trajectory)\n\u001b[1;32m 296\u001b[0m parallel \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(intervention_handlers) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[0;32m--> 298\u001b[0m samples \u001b[38;5;241m=\u001b[39m \u001b[43mpyro\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minfer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mPredictive\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 299\u001b[0m \u001b[43m \u001b[49m\u001b[43mwrapped_model\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 300\u001b[0m \u001b[43m \u001b[49m\u001b[43mguide\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minferred_parameters\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 301\u001b[0m \u001b[43m \u001b[49m\u001b[43mnum_samples\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mnum_samples\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 302\u001b[0m \u001b[43m \u001b[49m\u001b[43mparallel\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mparallel\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 303\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 305\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m prepare_interchange_dictionary(samples)\n", + "File \u001b[0;32m~/anaconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1131\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", + "File \u001b[0;32m~/anaconda3/lib/python3.10/site-packages/pyro/infer/predictive.py:273\u001b[0m, in \u001b[0;36mPredictive.forward\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 263\u001b[0m return_sites \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m return_sites \u001b[38;5;28;01melse\u001b[39;00m return_sites\n\u001b[1;32m 264\u001b[0m posterior_samples \u001b[38;5;241m=\u001b[39m _predictive(\n\u001b[1;32m 265\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mguide,\n\u001b[1;32m 266\u001b[0m posterior_samples,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 271\u001b[0m model_kwargs\u001b[38;5;241m=\u001b[39mkwargs,\n\u001b[1;32m 272\u001b[0m )\n\u001b[0;32m--> 273\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_predictive\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 274\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 275\u001b[0m \u001b[43m \u001b[49m\u001b[43mposterior_samples\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 276\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnum_samples\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 277\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_sites\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_sites\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 278\u001b[0m \u001b[43m \u001b[49m\u001b[43mparallel\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mparallel\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 279\u001b[0m \u001b[43m \u001b[49m\u001b[43mmodel_args\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 280\u001b[0m \u001b[43m \u001b[49m\u001b[43mmodel_kwargs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 281\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/anaconda3/lib/python3.10/site-packages/pyro/infer/predictive.py:78\u001b[0m, in \u001b[0;36m_predictive\u001b[0;34m(model, posterior_samples, num_samples, return_sites, return_trace, parallel, model_args, model_kwargs)\u001b[0m\n\u001b[1;32m 67\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_predictive\u001b[39m(\n\u001b[1;32m 68\u001b[0m model,\n\u001b[1;32m 69\u001b[0m posterior_samples,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 75\u001b[0m model_kwargs\u001b[38;5;241m=\u001b[39m{},\n\u001b[1;32m 76\u001b[0m ):\n\u001b[1;32m 77\u001b[0m model \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mno_grad()(poutine\u001b[38;5;241m.\u001b[39mmask(model, mask\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m))\n\u001b[0;32m---> 78\u001b[0m max_plate_nesting \u001b[38;5;241m=\u001b[39m \u001b[43m_guess_max_plate_nesting\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel_args\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel_kwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 79\u001b[0m vectorize \u001b[38;5;241m=\u001b[39m pyro\u001b[38;5;241m.\u001b[39mplate(\n\u001b[1;32m 80\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m_num_predictive_samples\u001b[39m\u001b[38;5;124m\"\u001b[39m, num_samples, dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m-\u001b[39mmax_plate_nesting \u001b[38;5;241m-\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 81\u001b[0m )\n\u001b[1;32m 82\u001b[0m model_trace \u001b[38;5;241m=\u001b[39m prune_subsample_sites(\n\u001b[1;32m 83\u001b[0m poutine\u001b[38;5;241m.\u001b[39mtrace(model)\u001b[38;5;241m.\u001b[39mget_trace(\u001b[38;5;241m*\u001b[39mmodel_args, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mmodel_kwargs)\n\u001b[1;32m 84\u001b[0m )\n", + "File \u001b[0;32m~/anaconda3/lib/python3.10/site-packages/pyro/infer/predictive.py:21\u001b[0m, in \u001b[0;36m_guess_max_plate_nesting\u001b[0;34m(model, args, kwargs)\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 16\u001b[0m \u001b[38;5;124;03mGuesses max_plate_nesting by running the model once\u001b[39;00m\n\u001b[1;32m 17\u001b[0m \u001b[38;5;124;03mwithout enumeration. This optimistically assumes static model\u001b[39;00m\n\u001b[1;32m 18\u001b[0m \u001b[38;5;124;03mstructure.\u001b[39;00m\n\u001b[1;32m 19\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 20\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m poutine\u001b[38;5;241m.\u001b[39mblock():\n\u001b[0;32m---> 21\u001b[0m model_trace \u001b[38;5;241m=\u001b[39m \u001b[43mpoutine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrace\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_trace\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 22\u001b[0m sites \u001b[38;5;241m=\u001b[39m [site \u001b[38;5;28;01mfor\u001b[39;00m site \u001b[38;5;129;01min\u001b[39;00m model_trace\u001b[38;5;241m.\u001b[39mnodes\u001b[38;5;241m.\u001b[39mvalues() \u001b[38;5;28;01mif\u001b[39;00m site[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtype\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msample\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m 24\u001b[0m dims \u001b[38;5;241m=\u001b[39m [\n\u001b[1;32m 25\u001b[0m frame\u001b[38;5;241m.\u001b[39mdim\n\u001b[1;32m 26\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m site \u001b[38;5;129;01min\u001b[39;00m sites\n\u001b[1;32m 27\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m frame \u001b[38;5;129;01min\u001b[39;00m site[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcond_indep_stack\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m 28\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m frame\u001b[38;5;241m.\u001b[39mvectorized\n\u001b[1;32m 29\u001b[0m ]\n", + "File \u001b[0;32m~/anaconda3/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py:198\u001b[0m, in \u001b[0;36mTraceHandler.get_trace\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 190\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mget_trace\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 191\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 192\u001b[0m \u001b[38;5;124;03m :returns: data structure\u001b[39;00m\n\u001b[1;32m 193\u001b[0m \u001b[38;5;124;03m :rtype: pyro.poutine.Trace\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 196\u001b[0m \u001b[38;5;124;03m Calls this poutine and returns its trace instead of the function's return value.\u001b[39;00m\n\u001b[1;32m 197\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 198\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 199\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmsngr\u001b[38;5;241m.\u001b[39mget_trace()\n", + "File \u001b[0;32m~/anaconda3/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py:174\u001b[0m, in \u001b[0;36mTraceHandler.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 170\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmsngr\u001b[38;5;241m.\u001b[39mtrace\u001b[38;5;241m.\u001b[39madd_node(\n\u001b[1;32m 171\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m_INPUT\u001b[39m\u001b[38;5;124m\"\u001b[39m, name\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m_INPUT\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28mtype\u001b[39m\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124margs\u001b[39m\u001b[38;5;124m\"\u001b[39m, args\u001b[38;5;241m=\u001b[39margs, kwargs\u001b[38;5;241m=\u001b[39mkwargs\n\u001b[1;32m 172\u001b[0m )\n\u001b[1;32m 173\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 174\u001b[0m ret \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 175\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m (\u001b[38;5;167;01mValueError\u001b[39;00m, \u001b[38;5;167;01mRuntimeError\u001b[39;00m) \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 176\u001b[0m exc_type, exc_value, traceback \u001b[38;5;241m=\u001b[39m sys\u001b[38;5;241m.\u001b[39mexc_info()\n", + "File \u001b[0;32m~/anaconda3/lib/python3.10/site-packages/torch/autograd/grad_mode.py:27\u001b[0m, in \u001b[0;36m_DecoratorContextManager.__call__..decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 25\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 26\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mclone():\n\u001b[0;32m---> 27\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py:12\u001b[0m, in \u001b[0;36m_context_wrap\u001b[0;34m(context, fn, *args, **kwargs)\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_context_wrap\u001b[39m(context, fn, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 11\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m context:\n\u001b[0;32m---> 12\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/Projects/pyciemss/pyciemss/interfaces.py:282\u001b[0m, in \u001b[0;36msample..wrapped_model\u001b[0;34m()\u001b[0m\n\u001b[1;32m 280\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m handler \u001b[38;5;129;01min\u001b[39;00m intervention_handlers:\n\u001b[1;32m 281\u001b[0m stack\u001b[38;5;241m.\u001b[39menter_context(handler)\n\u001b[0;32m--> 282\u001b[0m full_trajectory \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 283\u001b[0m \u001b[43m \u001b[49m\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mas_tensor\u001b[49m\u001b[43m(\u001b[49m\u001b[43mstart_time\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 284\u001b[0m \u001b[43m \u001b[49m\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mas_tensor\u001b[49m\u001b[43m(\u001b[49m\u001b[43mend_time\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 285\u001b[0m \u001b[43m \u001b[49m\u001b[43mlogging_times\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlogging_times\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 286\u001b[0m \u001b[43m \u001b[49m\u001b[43mis_traced\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 287\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 289\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m noise_model \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 290\u001b[0m compiled_noise_model \u001b[38;5;241m=\u001b[39m compile_noise_model(\n\u001b[1;32m 291\u001b[0m noise_model, \u001b[38;5;28mvars\u001b[39m\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mset\u001b[39m(full_trajectory\u001b[38;5;241m.\u001b[39mkeys()), \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mnoise_model_kwargs\n\u001b[1;32m 292\u001b[0m )\n", + "File \u001b[0;32m~/anaconda3/lib/python3.10/site-packages/pyro/nn/module.py:449\u001b[0m, in \u001b[0;36mPyroModule.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 447\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 448\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pyro_context:\n\u001b[0;32m--> 449\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__call__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 450\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[1;32m 451\u001b[0m pyro\u001b[38;5;241m.\u001b[39msettings\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mvalidate_poutine\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 452\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pyro_context\u001b[38;5;241m.\u001b[39mactive\n\u001b[1;32m 453\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m _is_module_local_param_enabled()\n\u001b[1;32m 454\u001b[0m ):\n\u001b[1;32m 455\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_check_module_local_param_usage()\n", + "File \u001b[0;32m~/anaconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1131\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", + "File \u001b[0;32m~/Projects/pyciemss/pyciemss/compiled_dynamics.py:77\u001b[0m, in \u001b[0;36mCompiledDynamics.forward\u001b[0;34m(self, start_time, end_time, logging_times, is_traced)\u001b[0m\n\u001b[1;32m 75\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m logging_times \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 76\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m LogTrajectory(logging_times) \u001b[38;5;28;01mas\u001b[39;00m lt:\n\u001b[0;32m---> 77\u001b[0m \u001b[43msimulate\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mderiv\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minitial_state\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstart_time\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mend_time\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 78\u001b[0m state \u001b[38;5;241m=\u001b[39m lt\u001b[38;5;241m.\u001b[39mtrajectory\n\u001b[1;32m 79\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n", + "File \u001b[0;32m~/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py:281\u001b[0m, in \u001b[0;36meffectful.._fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 264\u001b[0m msg \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m 265\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtype\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28mtype\u001b[39m,\n\u001b[1;32m 266\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mname\u001b[39m\u001b[38;5;124m\"\u001b[39m: name,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 278\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124minfer\u001b[39m\u001b[38;5;124m\"\u001b[39m: infer,\n\u001b[1;32m 279\u001b[0m }\n\u001b[1;32m 280\u001b[0m \u001b[38;5;66;03m# apply the stack and return its return value\u001b[39;00m\n\u001b[0;32m--> 281\u001b[0m \u001b[43mapply_stack\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmsg\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 282\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m msg[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mvalue\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n", + "File \u001b[0;32m~/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py:212\u001b[0m, in \u001b[0;36mapply_stack\u001b[0;34m(initial_msg)\u001b[0m\n\u001b[1;32m 209\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m frame \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mreversed\u001b[39m(stack):\n\u001b[1;32m 210\u001b[0m pointer \u001b[38;5;241m=\u001b[39m pointer \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[0;32m--> 212\u001b[0m \u001b[43mframe\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_process_message\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmsg\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 214\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m msg[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mstop\u001b[39m\u001b[38;5;124m\"\u001b[39m]:\n\u001b[1;32m 215\u001b[0m \u001b[38;5;28;01mbreak\u001b[39;00m\n", + "File \u001b[0;32m~/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py:162\u001b[0m, in \u001b[0;36mMessenger._process_message\u001b[0;34m(self, msg)\u001b[0m\n\u001b[1;32m 160\u001b[0m method \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mgetattr\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m_pyro_\u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mformat(msg[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtype\u001b[39m\u001b[38;5;124m\"\u001b[39m]), \u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[1;32m 161\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m method \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 162\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mmethod\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmsg\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 163\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n", + "File \u001b[0;32m~/anaconda3/lib/python3.10/site-packages/chirho/dynamical/internals/solver.py:109\u001b[0m, in \u001b[0;36mSolver._pyro_simulate\u001b[0;34m(self, msg)\u001b[0m\n\u001b[1;32m 106\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m ph\u001b[38;5;241m.\u001b[39mpriority \u001b[38;5;241m>\u001b[39m start_time:\n\u001b[1;32m 107\u001b[0m \u001b[38;5;28;01mbreak\u001b[39;00m\n\u001b[0;32m--> 109\u001b[0m state, start_time, next_interruption \u001b[38;5;241m=\u001b[39m \u001b[43msimulate_to_interruption\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 110\u001b[0m \u001b[43m \u001b[49m\u001b[43mpossible_interruptions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 111\u001b[0m \u001b[43m \u001b[49m\u001b[43mdynamics\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 112\u001b[0m \u001b[43m \u001b[49m\u001b[43mstate\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 113\u001b[0m \u001b[43m \u001b[49m\u001b[43mstart_time\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 114\u001b[0m \u001b[43m \u001b[49m\u001b[43mend_time\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 115\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mmsg\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mkwargs\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 116\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 118\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m next_interruption \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 119\u001b[0m dynamics, state \u001b[38;5;241m=\u001b[39m next_interruption\u001b[38;5;241m.\u001b[39mcallback(dynamics, state)\n", + "File \u001b[0;32m~/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py:281\u001b[0m, in \u001b[0;36meffectful.._fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 264\u001b[0m msg \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m 265\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtype\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28mtype\u001b[39m,\n\u001b[1;32m 266\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mname\u001b[39m\u001b[38;5;124m\"\u001b[39m: name,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 278\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124minfer\u001b[39m\u001b[38;5;124m\"\u001b[39m: infer,\n\u001b[1;32m 279\u001b[0m }\n\u001b[1;32m 280\u001b[0m \u001b[38;5;66;03m# apply the stack and return its return value\u001b[39;00m\n\u001b[0;32m--> 281\u001b[0m \u001b[43mapply_stack\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmsg\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 282\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m msg[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mvalue\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n", + "File \u001b[0;32m~/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py:212\u001b[0m, in \u001b[0;36mapply_stack\u001b[0;34m(initial_msg)\u001b[0m\n\u001b[1;32m 209\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m frame \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mreversed\u001b[39m(stack):\n\u001b[1;32m 210\u001b[0m pointer \u001b[38;5;241m=\u001b[39m pointer \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[0;32m--> 212\u001b[0m \u001b[43mframe\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_process_message\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmsg\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 214\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m msg[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mstop\u001b[39m\u001b[38;5;124m\"\u001b[39m]:\n\u001b[1;32m 215\u001b[0m \u001b[38;5;28;01mbreak\u001b[39;00m\n", + "File \u001b[0;32m~/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py:162\u001b[0m, in \u001b[0;36mMessenger._process_message\u001b[0;34m(self, msg)\u001b[0m\n\u001b[1;32m 160\u001b[0m method \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mgetattr\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m_pyro_\u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mformat(msg[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtype\u001b[39m\u001b[38;5;124m\"\u001b[39m]), \u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[1;32m 161\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m method \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 162\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mmethod\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmsg\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 163\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n", + "File \u001b[0;32m~/anaconda3/lib/python3.10/site-packages/chirho/dynamical/handlers/solver.py:89\u001b[0m, in \u001b[0;36mTorchDiffEq._pyro_simulate_to_interruption\u001b[0;34m(self, msg)\u001b[0m\n\u001b[1;32m 87\u001b[0m interruptions, dynamics, initial_state, start_time, end_time \u001b[38;5;241m=\u001b[39m msg[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124margs\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m 88\u001b[0m msg[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mkwargs\u001b[39m\u001b[38;5;124m\"\u001b[39m]\u001b[38;5;241m.\u001b[39mupdate(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39modeint_kwargs)\n\u001b[0;32m---> 89\u001b[0m msg[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mvalue\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[43mtorchdiffeq_simulate_to_interruption\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 90\u001b[0m \u001b[43m \u001b[49m\u001b[43minterruptions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 91\u001b[0m \u001b[43m \u001b[49m\u001b[43mdynamics\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 92\u001b[0m \u001b[43m \u001b[49m\u001b[43minitial_state\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 93\u001b[0m \u001b[43m \u001b[49m\u001b[43mstart_time\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 94\u001b[0m \u001b[43m \u001b[49m\u001b[43mend_time\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 95\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mmsg\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mkwargs\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 96\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 97\u001b[0m msg[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdone\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n", + "File \u001b[0;32m~/anaconda3/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py:244\u001b[0m, in \u001b[0;36mtorchdiffeq_simulate_to_interruption\u001b[0;34m(interruptions, dynamics, initial_state, start_time, end_time, **kwargs)\u001b[0m\n\u001b[1;32m 234\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mtorchdiffeq_simulate_to_interruption\u001b[39m(\n\u001b[1;32m 235\u001b[0m interruptions: List[Interruption[torch\u001b[38;5;241m.\u001b[39mTensor]],\n\u001b[1;32m 236\u001b[0m dynamics: Dynamics[torch\u001b[38;5;241m.\u001b[39mTensor],\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 240\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs,\n\u001b[1;32m 241\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tuple[State[torch\u001b[38;5;241m.\u001b[39mTensor], torch\u001b[38;5;241m.\u001b[39mTensor, Optional[Interruption[torch\u001b[38;5;241m.\u001b[39mTensor]]]:\n\u001b[1;32m 242\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(interruptions) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mshould have at least one interruption here\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m--> 244\u001b[0m (next_interruption,), interruption_time \u001b[38;5;241m=\u001b[39m \u001b[43m_torchdiffeq_get_next_interruptions\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 245\u001b[0m \u001b[43m \u001b[49m\u001b[43mdynamics\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minitial_state\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstart_time\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minterruptions\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\n\u001b[1;32m 246\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 248\u001b[0m value \u001b[38;5;241m=\u001b[39m simulate_point(\n\u001b[1;32m 249\u001b[0m dynamics, initial_state, start_time, interruption_time, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs\n\u001b[1;32m 250\u001b[0m )\n\u001b[1;32m 251\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m value, interruption_time, next_interruption\n", + "File \u001b[0;32m~/anaconda3/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py:192\u001b[0m, in \u001b[0;36m_torchdiffeq_get_next_interruptions\u001b[0;34m(dynamics, start_state, start_time, interruptions, **kwargs)\u001b[0m\n\u001b[1;32m 189\u001b[0m combined_event_f \u001b[38;5;241m=\u001b[39m torchdiffeq_combined_event_f(interruptions, var_order)\n\u001b[1;32m 191\u001b[0m \u001b[38;5;66;03m# Simulate to the event execution.\u001b[39;00m\n\u001b[0;32m--> 192\u001b[0m event_time, event_solutions \u001b[38;5;241m=\u001b[39m \u001b[43m_batched_odeint\u001b[49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# torchdiffeq.odeint_event(\u001b[39;49;00m\n\u001b[1;32m 193\u001b[0m \u001b[43m \u001b[49m\u001b[43mfunctools\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpartial\u001b[49m\u001b[43m(\u001b[49m\u001b[43m_deriv\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdynamics\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvar_order\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 194\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mtuple\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mstart_state\u001b[49m\u001b[43m[\u001b[49m\u001b[43mv\u001b[49m\u001b[43m]\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mv\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mvar_order\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 195\u001b[0m \u001b[43m \u001b[49m\u001b[43mstart_time\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 196\u001b[0m \u001b[43m \u001b[49m\u001b[43mevent_fn\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcombined_event_f\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 197\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 198\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 200\u001b[0m \u001b[38;5;66;03m# event_state has both the first and final state of the interrupted simulation. We just want the last.\u001b[39;00m\n\u001b[1;32m 201\u001b[0m event_solution: Tuple[torch\u001b[38;5;241m.\u001b[39mTensor, \u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mtuple\u001b[39m(\n\u001b[1;32m 202\u001b[0m s[\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m] \u001b[38;5;28;01mfor\u001b[39;00m s \u001b[38;5;129;01min\u001b[39;00m event_solutions\n\u001b[1;32m 203\u001b[0m ) \u001b[38;5;66;03m# TODO support event_dim > 0\u001b[39;00m\n", + "File \u001b[0;32m~/anaconda3/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py:119\u001b[0m, in \u001b[0;36m_batched_odeint\u001b[0;34m(func, y0, t, event_fn, **odeint_kwargs)\u001b[0m\n\u001b[1;32m 112\u001b[0m y0_expanded \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mtuple\u001b[39m(\n\u001b[1;32m 113\u001b[0m \u001b[38;5;66;03m# y0_[(None,) * (len(y0_batch_shape) - (len(y0_.shape) - event_dim)) + (...,)]\u001b[39;00m\n\u001b[1;32m 114\u001b[0m y0_\u001b[38;5;241m.\u001b[39mexpand(y0_batch_shape \u001b[38;5;241m+\u001b[39m y0_\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;28mlen\u001b[39m(y0_\u001b[38;5;241m.\u001b[39mshape) \u001b[38;5;241m-\u001b[39m event_dim :])\n\u001b[1;32m 115\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m y0_ \u001b[38;5;129;01min\u001b[39;00m y0\n\u001b[1;32m 116\u001b[0m )\n\u001b[1;32m 118\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m event_fn \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 119\u001b[0m event_t, yt_raw \u001b[38;5;241m=\u001b[39m \u001b[43mtorchdiffeq\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43modeint_event\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 120\u001b[0m \u001b[43m \u001b[49m\u001b[43mfunc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my0_expanded\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mt\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mevent_fn\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mevent_fn\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43modeint_kwargs\u001b[49m\n\u001b[1;32m 121\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 122\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 123\u001b[0m yt_raw \u001b[38;5;241m=\u001b[39m torchdiffeq\u001b[38;5;241m.\u001b[39modeint(func, y0_expanded, t, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39modeint_kwargs)\n", + "File \u001b[0;32m~/anaconda3/lib/python3.10/site-packages/torchdiffeq/_impl/odeint.py:101\u001b[0m, in \u001b[0;36modeint_event\u001b[0;34m(func, y0, t0, event_fn, reverse_time, odeint_interface, **kwargs)\u001b[0m\n\u001b[1;32m 98\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 99\u001b[0m t \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mcat([t0\u001b[38;5;241m.\u001b[39mreshape(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m), t0\u001b[38;5;241m.\u001b[39mreshape(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\u001b[38;5;241m.\u001b[39mdetach() \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1.0\u001b[39m])\n\u001b[0;32m--> 101\u001b[0m event_t, solution \u001b[38;5;241m=\u001b[39m \u001b[43modeint_interface\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfunc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my0\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mt\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mevent_fn\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mevent_fn\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 103\u001b[0m \u001b[38;5;66;03m# Dummy values for rtol, atol, method, and options.\u001b[39;00m\n\u001b[1;32m 104\u001b[0m shapes, _func, _, t, _, _, _, _, event_fn, _ \u001b[38;5;241m=\u001b[39m _check_inputs(func, y0, t, \u001b[38;5;241m0.0\u001b[39m, \u001b[38;5;241m0.0\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;28;01mNone\u001b[39;00m, event_fn, SOLVERS)\n", + "File \u001b[0;32m~/anaconda3/lib/python3.10/site-packages/torchdiffeq/_impl/odeint.py:79\u001b[0m, in \u001b[0;36modeint\u001b[0;34m(func, y0, t, rtol, atol, method, options, event_fn)\u001b[0m\n\u001b[1;32m 77\u001b[0m solution \u001b[38;5;241m=\u001b[39m solver\u001b[38;5;241m.\u001b[39mintegrate(t)\n\u001b[1;32m 78\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 79\u001b[0m event_t, solution \u001b[38;5;241m=\u001b[39m \u001b[43msolver\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mintegrate_until_event\u001b[49m\u001b[43m(\u001b[49m\u001b[43mt\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mevent_fn\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 80\u001b[0m event_t \u001b[38;5;241m=\u001b[39m event_t\u001b[38;5;241m.\u001b[39mto(t)\n\u001b[1;32m 81\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m t_is_reversed:\n", + "File \u001b[0;32m~/anaconda3/lib/python3.10/site-packages/torchdiffeq/_impl/solvers.py:122\u001b[0m, in \u001b[0;36mFixedGridODESolver.integrate_until_event\u001b[0;34m(self, t0, event_fn)\u001b[0m\n\u001b[1;32m 121\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mintegrate_until_event\u001b[39m(\u001b[38;5;28mself\u001b[39m, t0, event_fn):\n\u001b[0;32m--> 122\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstep_size \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mEvent handling for fixed step solvers currently requires `step_size` to be provided in options.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 124\u001b[0m t0 \u001b[38;5;241m=\u001b[39m t0\u001b[38;5;241m.\u001b[39mtype_as(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39my0)\n\u001b[1;32m 125\u001b[0m y0 \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39my0\n", + "\u001b[0;31mAssertionError\u001b[0m: Event handling for fixed step solvers currently requires `step_size` to be provided in options." + ] + } + ], + "source": [ + "def make_var_threshold(var: str, threshold: torch.Tensor):\n", + " return lambda time, state: state[var] - threshold \n", + " \n", + "infection_threshold = make_var_threshold(\"I\", torch.tensor(400.0))\n", + "dynamic_parameter_interventions1 = {infection_threshold: {\"p_cbeta\": torch.tensor(0.3)}}\n", + "\n", + "result = pyciemss.sample(model3, end_time, logging_step_size, num_samples, start_time=start_time, \n", + " dynamic_parameter_interventions=dynamic_parameter_interventions1, \n", + " solver_method=\"euler\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6008c385", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python (pyciemss)", + "language": "python", + "name": "pyciemss" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}