diff --git a/pyciemss/compiled_dynamics.py b/pyciemss/compiled_dynamics.py index 85248440..cb3b630e 100644 --- a/pyciemss/compiled_dynamics.py +++ b/pyciemss/compiled_dynamics.py @@ -29,7 +29,8 @@ def __init__(self, src, **kwargs): "The model parameters could not be compiled. Please check the model definition." ) from e - for k, v in params.items(): + for k in _sort_dependencies(self.src): + v = params[get_name(k)] if hasattr(self, get_name(k)): continue @@ -189,6 +190,11 @@ def _compile_deriv(src) -> Callable[..., Tuple[torch.Tensor]]: raise NotImplementedError +@functools.singledispatch +def _sort_dependencies(src) -> list: + raise NotImplementedError + + @functools.singledispatch def _compile_initial_state(src) -> Callable[..., Tuple[torch.Tensor]]: raise NotImplementedError diff --git a/pyciemss/mira_integration/compiled_dynamics.py b/pyciemss/mira_integration/compiled_dynamics.py index a427c110..617bf474 100644 --- a/pyciemss/mira_integration/compiled_dynamics.py +++ b/pyciemss/mira_integration/compiled_dynamics.py @@ -10,6 +10,7 @@ import mira.sources.amr import numpy import pyro +import pyro.nn import sympy import sympytorch import torch @@ -20,14 +21,14 @@ _compile_initial_state, _compile_observables, _compile_param_values, + _sort_dependencies, eval_deriv, eval_initial_state, eval_observables, get_name, ) from pyciemss.mira_integration.distributions import ( - mira_distribution_to_pyro, - sort_mira_dependencies + mira_distribution_to_pyro ) S = TypeVar("S") @@ -88,7 +89,7 @@ def _compile_param_values_mira( src: mira.modeling.Model, ) -> Dict[str, Union[torch.Tensor, pyro.nn.PyroParam, pyro.nn.PyroSample]]: values = {} - for param_name in sort_mira_dependencies(src): + for param_name in _sort_dependencies(src): param_info = src.parameters[param_name] if param_info.placeholder: continue @@ -102,8 +103,7 @@ def _compile_param_values_mira( if isinstance(param_value, torch.nn.Parameter): values[param_name] = pyro.nn.PyroParam(param_value) elif isinstance(param_value, pyro.distributions.Distribution): - # call Distribution.sample() to get the sampled values - values[param_name] = param_value.sample() + values[param_name] = pyro.sample(param_name, param_value) elif isinstance(param_value, (numbers.Number, numpy.ndarray, torch.Tensor)): values[param_name] = torch.as_tensor(param_value, dtype=torch.float32) else: diff --git a/pyciemss/mira_integration/distributions.py b/pyciemss/mira_integration/distributions.py index d29a8570..305f9ec5 100644 --- a/pyciemss/mira_integration/distributions.py +++ b/pyciemss/mira_integration/distributions.py @@ -1,25 +1,27 @@ import warnings from typing import Dict, Optional, Union -from pyciemss.compiled_dynamics import get_name + +import mira.modeling +from pyciemss.compiled_dynamics import get_name, _sort_dependencies import mira import mira.metamodel import networkx as nx import pyro import torch import sympytorch -from mira.metamodel.utils import safe_parse_expr, SympyExprStr - +from mira.metamodel.utils import SympyExprStr ParameterDict = Dict[str, Union[torch.Tensor, SympyExprStr]] -def sort_mira_dependencies(src: mira.metamodel.TemplateModel) -> list: +@_sort_dependencies.register(mira.modeling.Model) +def sort_mira_dependencies(src: mira.modeling.Model) -> list: """ Sort the model parameters of a MIRA TemplateModel by their distribution parameter dependencies. Parameters ---------- - src : mira.metamodel.TemplateModel - The MIRA TemplateModel to sort. + src : mira.modeling.Model + The MIRA Model to sort. Returns ------- @@ -38,7 +40,8 @@ def sort_mira_dependencies(src: mira.metamodel.TemplateModel) -> list: if isinstance(v, mira.metamodel.utils.SympyExprStr): for free_symbol in v.free_symbols: dependencies.add_edge(str(free_symbol), str(param_name)) - return list(nx.topological_sort(dependencies)) + return list(nx.topological_sort(dependencies)) + def safe_sympytorch_parse_expr(expr: SympyExprStr, local_dict: Optional[Dict[str, torch.Tensor]]) -> torch.Tensor: """