Skip to content

Commit

Permalink
test_hierarchical passes, but nothing else passes
Browse files Browse the repository at this point in the history
  • Loading branch information
djinnome committed Oct 30, 2024
1 parent 6a39869 commit 2e47acc
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 13 deletions.
8 changes: 7 additions & 1 deletion pyciemss/compiled_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ def __init__(self, src, **kwargs):
"The model parameters could not be compiled. Please check the model definition."
) from e

for k, v in params.items():
for k in _sort_dependencies(self.src):
v = params[get_name(k)]
if hasattr(self, get_name(k)):
continue

Expand Down Expand Up @@ -189,6 +190,11 @@ def _compile_deriv(src) -> Callable[..., Tuple[torch.Tensor]]:
raise NotImplementedError


@functools.singledispatch
def _sort_dependencies(src) -> list:
raise NotImplementedError


@functools.singledispatch
def _compile_initial_state(src) -> Callable[..., Tuple[torch.Tensor]]:
raise NotImplementedError
Expand Down
10 changes: 5 additions & 5 deletions pyciemss/mira_integration/compiled_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import mira.sources.amr
import numpy
import pyro
import pyro.nn
import sympy
import sympytorch
import torch
Expand All @@ -20,14 +21,14 @@
_compile_initial_state,
_compile_observables,
_compile_param_values,
_sort_dependencies,
eval_deriv,
eval_initial_state,
eval_observables,
get_name,
)
from pyciemss.mira_integration.distributions import (
mira_distribution_to_pyro,
sort_mira_dependencies
mira_distribution_to_pyro
)

S = TypeVar("S")
Expand Down Expand Up @@ -88,7 +89,7 @@ def _compile_param_values_mira(
src: mira.modeling.Model,
) -> Dict[str, Union[torch.Tensor, pyro.nn.PyroParam, pyro.nn.PyroSample]]:
values = {}
for param_name in sort_mira_dependencies(src):
for param_name in _sort_dependencies(src):
param_info = src.parameters[param_name]
if param_info.placeholder:
continue
Expand All @@ -102,8 +103,7 @@ def _compile_param_values_mira(
if isinstance(param_value, torch.nn.Parameter):
values[param_name] = pyro.nn.PyroParam(param_value)
elif isinstance(param_value, pyro.distributions.Distribution):
# call Distribution.sample() to get the sampled values
values[param_name] = param_value.sample()
values[param_name] = pyro.sample(param_name, param_value)
elif isinstance(param_value, (numbers.Number, numpy.ndarray, torch.Tensor)):
values[param_name] = torch.as_tensor(param_value, dtype=torch.float32)
else:
Expand Down
17 changes: 10 additions & 7 deletions pyciemss/mira_integration/distributions.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,27 @@
import warnings
from typing import Dict, Optional, Union
from pyciemss.compiled_dynamics import get_name

import mira.modeling
from pyciemss.compiled_dynamics import get_name, _sort_dependencies
import mira
import mira.metamodel
import networkx as nx
import pyro
import torch
import sympytorch
from mira.metamodel.utils import safe_parse_expr, SympyExprStr

from mira.metamodel.utils import SympyExprStr
ParameterDict = Dict[str, Union[torch.Tensor, SympyExprStr]]


def sort_mira_dependencies(src: mira.metamodel.TemplateModel) -> list:
@_sort_dependencies.register(mira.modeling.Model)
def sort_mira_dependencies(src: mira.modeling.Model) -> list:
"""
Sort the model parameters of a MIRA TemplateModel by their distribution parameter dependencies.
Parameters
----------
src : mira.metamodel.TemplateModel
The MIRA TemplateModel to sort.
src : mira.modeling.Model
The MIRA Model to sort.
Returns
-------
Expand All @@ -38,7 +40,8 @@ def sort_mira_dependencies(src: mira.metamodel.TemplateModel) -> list:
if isinstance(v, mira.metamodel.utils.SympyExprStr):
for free_symbol in v.free_symbols:
dependencies.add_edge(str(free_symbol), str(param_name))
return list(nx.topological_sort(dependencies))
return list(nx.topological_sort(dependencies))


def safe_sympytorch_parse_expr(expr: SympyExprStr, local_dict: Optional[Dict[str, torch.Tensor]]) -> torch.Tensor:
"""
Expand Down

0 comments on commit 2e47acc

Please sign in to comment.