Skip to content

Commit

Permalink
added warning and refined types
Browse files Browse the repository at this point in the history
  • Loading branch information
SamWitty committed Jul 30, 2024
1 parent 7a1862f commit a127dc1
Showing 1 changed file with 44 additions and 46 deletions.
90 changes: 44 additions & 46 deletions pyciemss/mira_integration/distributions.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
from typing import Callable, Dict, Union
import warnings
from typing import Dict

import mira.metamodel
import pyro
import torch

ParameterDict = Dict[str, torch.Tensor]

def mira_uniform_to_pyro(
parameters: Dict[str, float]
) -> pyro.distributions.Distribution:

def mira_uniform_to_pyro(parameters: ParameterDict) -> pyro.distributions.Distribution:
low = parameters["minimum"]
high = parameters["maximum"]
return pyro.distributions.Uniform(low=low, high=high)


def mira_normal_to_pyro(
parameters: Dict[str, float]
) -> pyro.distributions.Distribution:
def mira_normal_to_pyro(parameters: ParameterDict) -> pyro.distributions.Distribution:
if "mean" in parameters.keys():
loc = parameters["mean"]
if "stdev" in parameters.keys():
Expand All @@ -28,7 +28,7 @@ def mira_normal_to_pyro(


def mira_lognormal_to_pyro(
parameters: Dict[str, float]
parameters: ParameterDict,
) -> pyro.distributions.Distribution:
if "meanLog" in parameters.keys():
loc = parameters["meanLog"]
Expand All @@ -42,7 +42,7 @@ def mira_lognormal_to_pyro(

# Provide either probs or logits, not both
def mira_bernoulli_to_pyro(
parameters: Dict[str, float]
parameters: ParameterDict,
) -> pyro.distributions.Distribution:
if "probability" in parameters.keys():
probs = parameters["probability"]
Expand All @@ -54,12 +54,12 @@ def mira_bernoulli_to_pyro(
return pyro.distributions.Bernoulli(probs=probs, logits=logits)


def mira_beta_to_pyro(parameters: Dict[str, float]) -> pyro.distributions.Distribution:
def mira_beta_to_pyro(parameters: ParameterDict) -> pyro.distributions.Distribution:
return pyro.distributions.Beta(alpha=parameters["alpha"], beta=parameters["beta"])


def mira_betabinomial_to_pyro(
parameters: Dict[str, Union[float, list]]
parameters: ParameterDict,
) -> pyro.distributions.Distribution:
raise NotImplementedError(
"Conversion from MIRA BetaBinomial distribution to Pyro distribution is not implemented."
Expand All @@ -73,9 +73,7 @@ def mira_betabinomial_to_pyro(
)


def mira_binomial_to_pyro(
parameters: Dict[str, Union[float, list]]
) -> pyro.distributions.Distribution:
def mira_binomial_to_pyro(parameters: ParameterDict) -> pyro.distributions.Distribution:
total_count = parameters["numberOfTrials"]
if "probability" in parameters.keys():
probs = parameters["probability"]
Expand All @@ -89,30 +87,28 @@ def mira_binomial_to_pyro(
)


def mira_cauchy_to_pyro(
parameters: Dict[str, float]
) -> pyro.distributions.Distribution:
def mira_cauchy_to_pyro(parameters: ParameterDict) -> pyro.distributions.Distribution:
loc = parameters["location"]
scale = parameters["scale"]
return pyro.distributions.Cauchy(loc=loc, scale=scale)


def mira_chisquared_to_pyro(
parameters: Dict[str, float]
parameters: ParameterDict,
) -> pyro.distributions.Distribution:
df = parameters["degreesOfFreedom"]
return pyro.distributions.Chi2(df=df)


def mira_dirichlet_to_pyro(
parameters: Dict[str, list]
parameters: ParameterDict,
) -> pyro.distributions.Distribution:
concentration = parameters["concentration"]
return pyro.distributions.Dirichlet(concentration=concentration)


def mira_exponential_to_pyro(
parameters: Dict[str, float]
parameters: ParameterDict,
) -> pyro.distributions.Distribution:
if "rate" in parameters.keys():
rate = parameters["rate"]
Expand All @@ -121,7 +117,7 @@ def mira_exponential_to_pyro(
return pyro.distributions.Exponential(rate=rate)


def mira_gamma_to_pyro(parameters: Dict[str, float]) -> pyro.distributions.Distribution:
def mira_gamma_to_pyro(parameters: ParameterDict) -> pyro.distributions.Distribution:
if "shape" in parameters.keys():
concentration = parameters["shape"]
if "scale" in parameters.keys():
Expand All @@ -132,25 +128,21 @@ def mira_gamma_to_pyro(parameters: Dict[str, float]) -> pyro.distributions.Distr


def mira_inversegamma_to_pyro(
parameters: Dict[str, float]
parameters: ParameterDict,
) -> pyro.distributions.Distribution:
raise NotImplementedError(
"Conversion from MIRA InverseGamma distribution to Pyro distribution is not implemented."
)
# TODO: Map parameters to Pyro distribution parameters


def mira_gumbel_to_pyro(
parameters: Dict[str, float]
) -> pyro.distributions.Distribution:
def mira_gumbel_to_pyro(parameters: ParameterDict) -> pyro.distributions.Distribution:
loc = parameters["location"]
scale = parameters["scale"]
return pyro.distributions.Gumbel(loc=loc, scale=scale)


def mira_laplace_to_pyro(
parameters: Dict[str, float]
) -> pyro.distributions.Distribution:
def mira_laplace_to_pyro(parameters: ParameterDict) -> pyro.distributions.Distribution:

if "location" in parameters.keys():
loc = parameters["location"]
Expand All @@ -166,47 +158,37 @@ def mira_laplace_to_pyro(


def mira_paretotypeI_to_pyro(
parameters: Dict[str, float]
parameters: ParameterDict,
) -> pyro.distributions.Distribution:
raise NotImplementedError(
"Conversion from MIRA ParetoTypeI distribution to Pyro distribution is not implemented."
)
# TODO: Confirm that parameters are mapped correctly
scale = parameters["scale"]
alpha = parameters["shape"]
return pyro.distributions.Pareto(scale=scale, alpha=alpha)


def mira_poisson_to_pyro(
parameters: Dict[str, float]
) -> pyro.distributions.Distribution:
def mira_poisson_to_pyro(parameters: ParameterDict) -> pyro.distributions.Distribution:
rate = parameters["rate"]
return pyro.distributions.Poisson(rate=rate)


def mira_studentt_to_pyro(
parameters: Dict[str, float]
) -> pyro.distributions.Distribution:
def mira_studentt_to_pyro(parameters: ParameterDict) -> pyro.distributions.Distribution:
if "mean" in parameters.keys():
loc = parameters["mean"]
elif "location" in parameters.keys():
loc = parameters["location"]
else:
loc = 0.0
loc = torch.tensor(0.0)

if "scale" in parameters.keys():
scale = parameters["scale"]
else:
scale = 1.0
scale = torch.tensor(1.0)

df = parameters["degreesOfFreedom"]

return pyro.distributions.StudentT(df=df, loc=loc, scale=scale)


def mira_weibull_to_pyro(
parameters: Dict[str, float]
) -> pyro.distributions.Distribution:
def mira_weibull_to_pyro(parameters: ParameterDict) -> pyro.distributions.Distribution:

if "scale" in parameters.keys():
scale = parameters["scale"]
Expand All @@ -219,7 +201,7 @@ def mira_weibull_to_pyro(


# Key - MIRA distribution type : str
# Value - MIRA -> Pyro function : Callable[[Dict[str, float]], pyro.distributions.Distribution]
# Value - MIRA -> Pyro function : Callable[[ParameterDict], pyro.distributions.Distribution]
# See https://github.com/indralab/mira/blob/main/mira/dkg/resources/probonto.json for MIRA distribution types
_MIRA_TO_PYRO = {
"Uniform1": mira_uniform_to_pyro,
Expand Down Expand Up @@ -256,6 +238,15 @@ def mira_weibull_to_pyro(
"Weibull2": mira_weibull_to_pyro,
}

_TESTED_DISTRIBUTIONS = [
"Uniform1",
"StandardUniform1",
"StandardNormal1",
"Normal1",
"Normal2",
"Normal3",
]


def mira_distribution_to_pyro(
mira_dist: mira.metamodel.template_model.Distribution,
Expand All @@ -265,4 +256,11 @@ def mira_distribution_to_pyro(
f"Conversion from MIRA distribution type {mira_dist.type} to Pyro distribution is not implemented."
)

return _MIRA_TO_PYRO[mira_dist.type](mira_dist.parameters)
if mira_dist.type not in _TESTED_DISTRIBUTIONS:
warnings.warn(
f"Conversion from MIRA distribution type {mira_dist.type} to Pyro distribution has not been tested."
)

parameters = {param.name: torch.as_tensor(param.value) for param in mira_dist}

return _MIRA_TO_PYRO[mira_dist.type](parameters)

0 comments on commit a127dc1

Please sign in to comment.