Skip to content

Commit

Permalink
create distribution module
Browse files Browse the repository at this point in the history
  • Loading branch information
sbidari committed Sep 5, 2024
1 parent 59ccf64 commit 8e1df35
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 105 deletions.
7 changes: 7 additions & 0 deletions docs/source/msei_reference/distributions.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Distributions
===========

.. automodule:: pyrenew.distributions
:members:
:undoc-members:
:show-inheritance:
7 changes: 7 additions & 0 deletions pyrenew/distributions/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# numpydoc ignore=GL08

from pyrenew.distributions.censorednormal import CensoredNormal

__all__ = [
"CensoredNormal",
]
101 changes: 101 additions & 0 deletions pyrenew/distributions/censorednormal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import jax
import jax.numpy as jnp
import numpyro
import numpyro.util
from numpyro.distributions import constraints
from numpyro.distributions.util import promote_shapes, validate_sample


class CensoredNormal(numpyro.distributions.Distribution):
"""
Censored normal distribution under which samples
are truncated to lie within a specified interval.
This implementation is adapted from
https://github.com/dylanhmorris/host-viral-determinants/blob/main/src/distributions.py
"""

arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
support = constraints.real

def __init__(
self,
loc=0,
scale=1,
lower_limit=-jnp.inf,
upper_limit=jnp.inf,
validate_args=None,
):
"""
Default constructor
Parameters
----------
loc : ArrayLike or float, optional
The mean of the normal distribution.
Defaults to 0.
scale : ArrayLike or float, optional
The standard deviation of the normal
distribution. Must be positive. Defaults to 1.
lower_limit : float, optional
The lower bound of the interval for censoring.
Defaults to -inf (no lower bound).
upper_limit : float, optional
The upper bound of the interval for censoring.
Defaults to inf (no upper bound).
validate_args : bool, optional
If True, checks if parameters are valid.
Defaults to None.
Returns
-------
None
"""
self.loc, self.scale = promote_shapes(loc, scale)
self.lower_limit = lower_limit
self.upper_limit = upper_limit

batch_shape = jax.lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
self.normal_ = numpyro.distributions.Normal(
loc=loc, scale=scale, validate_args=validate_args
)
super().__init__(batch_shape=batch_shape, validate_args=validate_args)

def sample(self, key, sample_shape=()):
"""
Generates samples from the censored normal distribution.
Returns
-------
Array
Containing samples from the censored normal distribution.
"""
assert numpyro.util.is_prng_key(key)
result = self.normal_.sample(key, sample_shape)
return jnp.clip(result, min=self.lower_limit, max=self.upper_limit)

@validate_sample
def log_prob(self, value):
"""
Computes the log probability density of a given value(s) under
the censored normal distribution.
Returns
-------
Array
Containing log probability of the given value(s)
under the censored normal distribution
"""
rescaled_ulim = (self.upper_limit - self.loc) / self.scale
rescaled_llim = (self.lower_limit - self.loc) / self.scale
lim_val = jnp.where(
value <= self.lower_limit,
jax.scipy.special.log_ndtr(rescaled_llim),
jax.scipy.special.log_ndtr(-rescaled_ulim),
)
# we exploit the fact that for the
# standard normal, P(x > a) = P(-x < a)
# to compute the log complementary CDF
inbounds = jnp.logical_and(value > self.lower_limit, value < self.upper_limit)
result = jnp.where(inbounds, self.normal_.log_prob(value), lim_val)

return result
104 changes: 0 additions & 104 deletions pyrenew/metaclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,11 @@
from typing import NamedTuple, get_type_hints

import jax
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import numpy as np
import numpyro
import numpyro.util
import polars as pl
from jax.typing import ArrayLike
from numpyro.distributions import constraints
from numpyro.distributions.util import promote_shapes, validate_sample
from numpyro.infer import MCMC, NUTS, Predictive

from pyrenew.mcmcutils import plot_posterior, spread_draws
Expand Down Expand Up @@ -561,102 +556,3 @@ def prior_predictive(
)

return predictive(rng_key, **kwargs)


class CensoredNormal(numpyro.distributions.Distribution):
"""
Censored normal distribution under which samples
are truncated to lie within a specified interval.
This implementation is adapted from
https://github.com/dylanhmorris/host-viral-determinants/blob/main/src/distributions.py
"""

arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
support = constraints.real

def __init__(
self,
loc=0,
scale=1,
lower_limit=-jnp.inf,
upper_limit=jnp.inf,
validate_args=None,
):
"""
Default constructor
Parameters
----------
loc : ArrayLike or float, optional
The mean of the normal distribution.
Defaults to 0.
scale : ArrayLike or float, optional
The standard deviation of the normal
distribution. Must be positive. Defaults to 1.
lower_limit : float, optional
The lower bound of the interval for censoring.
Defaults to -inf (no lower bound).
upper_limit : float, optional
The upper bound of the interval for censoring.
Defaults to inf (no upper bound).
validate_args : bool, optional
If True, checks if parameters are valid.
Defaults to None.
Returns
-------
None
"""
self.loc, self.scale = promote_shapes(loc, scale)
self.lower_limit = lower_limit
self.upper_limit = upper_limit

batch_shape = jax.lax.broadcast_shapes(
jnp.shape(loc), jnp.shape(scale)
)
self.normal_ = numpyro.distributions.Normal(
loc=loc, scale=scale, validate_args=validate_args
)
super().__init__(batch_shape=batch_shape, validate_args=validate_args)

def sample(self, key, sample_shape=()):
"""
Generates samples from the censored normal distribution.
Returns
-------
Array
Containing samples from the censored normal distribution.
"""
assert numpyro.util.is_prng_key(key)
result = self.normal_.sample(key, sample_shape)
return jnp.clip(result, min=self.lower_limit, max=self.upper_limit)

@validate_sample
def log_prob(self, value):
"""
Computes the log probability density of a given value(s) under
the censored normal distribution.
Returns
-------
Array
Containing log probability of the given value(s)
under the censored normal distribution
"""
rescaled_ulim = (self.upper_limit - self.loc) / self.scale
rescaled_llim = (self.lower_limit - self.loc) / self.scale
lim_val = jnp.where(
value <= self.lower_limit,
jax.scipy.special.log_ndtr(rescaled_llim),
jax.scipy.special.log_ndtr(-rescaled_ulim),
)
# we exploit the fact that for the
# standard normal, P(x > a) = P(-x < a)
# to compute the log complementary CDF
inbounds = jnp.logical_and(
value > self.lower_limit, value < self.upper_limit
)
result = jnp.where(inbounds, self.normal_.log_prob(value), lim_val)

return result
2 changes: 1 addition & 1 deletion test/test_censorednormal.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pytest
from numpy.testing import assert_array_almost_equal, assert_array_equal

from pyrenew.metaclass import CensoredNormal
from pyrenew.distributions import CensoredNormal


@pytest.mark.parametrize(
Expand Down

0 comments on commit 8e1df35

Please sign in to comment.