Skip to content

Commit

Permalink
pre-commit run
Browse files Browse the repository at this point in the history
  • Loading branch information
sbidari committed Sep 6, 2024
1 parent c336e54 commit a56ac2f
Showing 1 changed file with 35 additions and 8 deletions.
43 changes: 35 additions & 8 deletions test/test_censorednormal.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,36 @@
import jax.numpy as jnp
import numpyro
import pytest
from numpy.testing import assert_array_almost_equal, assert_array_equal, assert_equal
from numpy.testing import (
assert_array_almost_equal,
assert_array_equal,
assert_equal,
)

from numpyro.distributions import constraints
from pyrenew.distributions import CensoredNormal


@pytest.mark.parametrize(
["loc", "scale", "lower_limit", "upper_limit", "in_val", "l_val", "h_val"],
[
[jnp.array([0]), jnp.array([2.0, 1.0]), -1, 1, jnp.array([0, 0.5]), -2, 2],
[jnp.array([0, 1]), jnp.array([1.0]), -1, 2, jnp.array([0, 0.5]), -2, 3],
[
jnp.array([0]),
jnp.array([2.0, 1.0]),
-1,
1,
jnp.array([0, 0.5]),
-2,
2,
],
[
jnp.array([0, 1]),
jnp.array([1.0]),
-1,
2,
jnp.array([0, 0.5]),
-2,
3,
],
],
)
def test_interval_censored_normal_distribution(
Expand All @@ -41,7 +60,9 @@ def test_interval_censored_normal_distribution(
assert jnp.all(samp <= upper_limit)

# test log prob of values within bounds
assert_array_equal(censored_dist.log_prob(in_val), normal_dist.log_prob(in_val))
assert_array_equal(
censored_dist.log_prob(in_val), normal_dist.log_prob(in_val)
)

# test log prob of values lower than the limit
assert_array_almost_equal(
Expand Down Expand Up @@ -92,7 +113,9 @@ def test_left_censored_normal_distribution(
assert jnp.all(samp >= lower_limit)

# test log prob of values within bounds
assert_array_equal(censored_dist.log_prob(in_val), normal_dist.log_prob(in_val))
assert_array_equal(
censored_dist.log_prob(in_val), normal_dist.log_prob(in_val)
)

# test log prob of values lower than the limit
assert_array_almost_equal(
Expand Down Expand Up @@ -121,15 +144,19 @@ def test_right_censored_normal_distribution(
Tests the upper censored normal distribution samples
within the limit and calculation of log probability
"""
censored_dist = CensoredNormal(loc=loc, scale=scale, upper_limit=upper_limit)
censored_dist = CensoredNormal(
loc=loc, scale=scale, upper_limit=upper_limit
)
normal_dist = numpyro.distributions.Normal(loc=loc, scale=scale)

# test samples within the bounds
samp = censored_dist.sample(jax.random.PRNGKey(0), sample_shape=(100,))
assert jnp.all(samp <= upper_limit)

# test log prob of values within bounds
assert_array_equal(censored_dist.log_prob(in_val), normal_dist.log_prob(in_val))
assert_array_equal(
censored_dist.log_prob(in_val), normal_dist.log_prob(in_val)
)

# test log prob of values higher than the limit
assert_array_almost_equal(
Expand Down

0 comments on commit a56ac2f

Please sign in to comment.