-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
116 additions
and
105 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
Distributions | ||
=========== | ||
|
||
.. automodule:: pyrenew.distributions | ||
:members: | ||
:undoc-members: | ||
:show-inheritance: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# numpydoc ignore=GL08 | ||
|
||
from pyrenew.distributions.censorednormal import CensoredNormal | ||
|
||
__all__ = [ | ||
"CensoredNormal", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
import jax | ||
import jax.numpy as jnp | ||
import numpyro | ||
import numpyro.util | ||
from numpyro.distributions import constraints | ||
from numpyro.distributions.util import promote_shapes, validate_sample | ||
|
||
|
||
class CensoredNormal(numpyro.distributions.Distribution): | ||
""" | ||
Censored normal distribution under which samples | ||
are truncated to lie within a specified interval. | ||
This implementation is adapted from | ||
https://github.com/dylanhmorris/host-viral-determinants/blob/main/src/distributions.py | ||
""" | ||
|
||
arg_constraints = {"loc": constraints.real, "scale": constraints.positive} | ||
support = constraints.real | ||
|
||
def __init__( | ||
self, | ||
loc=0, | ||
scale=1, | ||
lower_limit=-jnp.inf, | ||
upper_limit=jnp.inf, | ||
validate_args=None, | ||
): | ||
""" | ||
Default constructor | ||
Parameters | ||
---------- | ||
loc : ArrayLike or float, optional | ||
The mean of the normal distribution. | ||
Defaults to 0. | ||
scale : ArrayLike or float, optional | ||
The standard deviation of the normal | ||
distribution. Must be positive. Defaults to 1. | ||
lower_limit : float, optional | ||
The lower bound of the interval for censoring. | ||
Defaults to -inf (no lower bound). | ||
upper_limit : float, optional | ||
The upper bound of the interval for censoring. | ||
Defaults to inf (no upper bound). | ||
validate_args : bool, optional | ||
If True, checks if parameters are valid. | ||
Defaults to None. | ||
Returns | ||
------- | ||
None | ||
""" | ||
self.loc, self.scale = promote_shapes(loc, scale) | ||
self.lower_limit = lower_limit | ||
self.upper_limit = upper_limit | ||
|
||
batch_shape = jax.lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale)) | ||
self.normal_ = numpyro.distributions.Normal( | ||
loc=loc, scale=scale, validate_args=validate_args | ||
) | ||
super().__init__(batch_shape=batch_shape, validate_args=validate_args) | ||
|
||
def sample(self, key, sample_shape=()): | ||
""" | ||
Generates samples from the censored normal distribution. | ||
Returns | ||
------- | ||
Array | ||
Containing samples from the censored normal distribution. | ||
""" | ||
assert numpyro.util.is_prng_key(key) | ||
result = self.normal_.sample(key, sample_shape) | ||
return jnp.clip(result, min=self.lower_limit, max=self.upper_limit) | ||
|
||
@validate_sample | ||
def log_prob(self, value): | ||
""" | ||
Computes the log probability density of a given value(s) under | ||
the censored normal distribution. | ||
Returns | ||
------- | ||
Array | ||
Containing log probability of the given value(s) | ||
under the censored normal distribution | ||
""" | ||
rescaled_ulim = (self.upper_limit - self.loc) / self.scale | ||
rescaled_llim = (self.lower_limit - self.loc) / self.scale | ||
lim_val = jnp.where( | ||
value <= self.lower_limit, | ||
jax.scipy.special.log_ndtr(rescaled_llim), | ||
jax.scipy.special.log_ndtr(-rescaled_ulim), | ||
) | ||
# we exploit the fact that for the | ||
# standard normal, P(x > a) = P(-x < a) | ||
# to compute the log complementary CDF | ||
inbounds = jnp.logical_and(value > self.lower_limit, value < self.upper_limit) | ||
result = jnp.where(inbounds, self.normal_.log_prob(value), lim_val) | ||
|
||
return result |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters