Skip to content

Commit

Permalink
Make histograms safe to complex numbers, fail safe if Numba does not …
Browse files Browse the repository at this point in the history
…work
  • Loading branch information
matteobachetti committed Sep 29, 2023
1 parent 6ebfa27 commit ea4bdb1
Show file tree
Hide file tree
Showing 2 changed files with 314 additions and 88 deletions.
150 changes: 146 additions & 4 deletions stingray/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
import numpy as np
import stingray.utils as utils
from stingray.utils import HAS_NUMBA
from scipy.stats import sem

np.random.seed(20150907)
Expand Down Expand Up @@ -368,10 +369,7 @@ def test_equal_count_energy_ranges():

def test_histogram_equiv_numpy():
x = np.random.uniform(0.0, 1.0, 100)
(
H,
_,
) = np.histogram(x, bins=5, range=(0.0, 1.0))
(H, _) = np.histogram(x, bins=5, range=(0.0, 1.0))

Hn = utils.histogram(x, bins=5, range=np.array([0.0, 1.0]))
assert np.all(H == Hn)
Expand All @@ -386,6 +384,150 @@ def test_histogram2d_equiv_numpy():
assert np.all(H == Hn)


@pytest.mark.skipif("not HAS_NUMBA")
def test_histogramnd_numba_fail_safely():
x = np.random.uniform(0.0, 1.0, 100)
y = np.random.uniform(2.0, 3.0, 100)
z = np.random.uniform(4.0, 5.0, 100)
w = np.random.uniform(6.0, 8.0, 100)
H, _ = np.histogramdd(
(x, y, z, w),
bins=np.array((5, 6, 7, 8)),
range=[(0.0, 1.0), (2.0, 3.0), (4.0, 5), (6.0, 8)],
density=True,
)

with pytest.warns(UserWarning, match="Cannot calculate the histogram with the numba"):
Hn = utils.histogramnd(
(x, y, z, w),
bins=np.array((5, 6, 7, 8)),
range=[(0.0, 1.0), (2.0, 3.0), (4.0, 5), (6.0, 8)],
density=True,
)
assert np.all(H == Hn)


@pytest.mark.skipif("not HAS_NUMBA")
def test_histogram3d_numba_fail_safely():
x = np.random.uniform(0.0, 1.0, 100)
y = np.random.uniform(2.0, 3.0, 100)
z = np.random.uniform(4.0, 5.0, 100)
H, _ = np.histogramdd(
(x, y, z), bins=np.array((5, 6, 7)), range=[(0.0, 1.0), (2.0, 3.0), (4.0, 5)], density=True
)

with pytest.warns(UserWarning, match="Cannot calculate the histogram with the numba"):
Hn = utils.histogram3d(
(x, y, z),
bins=np.array((5, 6, 7)),
range=[(0.0, 1.0), (2.0, 3.0), (4.0, 5)],
density=True,
)
assert np.all(H == Hn)


@pytest.mark.skipif("not HAS_NUMBA")
def test_histogram2d_numba_fail_safely():
x = np.random.uniform(0.0, 1.0, 100)
y = np.random.uniform(2.0, 3.0, 100)
H, _, _ = np.histogram2d(
x, y, bins=np.array((5, 5)), range=[(0.0, 1.0), (2.0, 3.0)], density=True
)

with pytest.warns(UserWarning, match="Cannot calculate the histogram with the numba"):
Hn = utils.histogram2d(
x,
y,
bins=np.array([5, 5]),
range=np.array([[0.0, 1.0], [2.0, 3.0]]),
density=True,
)
assert np.all(H == Hn)


@pytest.mark.skipif("not HAS_NUMBA")
def test_histogram_numba_fail_safely():
x = np.random.uniform(0.0, 1.0, 100)
(H, _) = np.histogram(x, bins=5, range=(0.0, 1.0), density=True)
with pytest.warns(UserWarning, match="Cannot calculate the histogram with the numba"):
Hn = utils.histogram(
x,
bins=5,
range=np.array([0.0, 1.0]),
density=True,
)
assert np.all(H == Hn)


@pytest.mark.skipif("not HAS_NUMBA")
def test_histogramnd_accept_complex():
x = np.random.uniform(0.0, 1.0, 100)
y = np.random.uniform(2.0, 3.0, 100)
z = np.random.uniform(4.0, 5.0, 100)
w = np.random.uniform(6.0, 8.0, 100)
H, _ = np.histogramdd(
(x, y, z, w),
bins=np.array((5, 6, 7, 8)),
range=[(0.0, 1.0), (2.0, 3.0), (4.0, 5), (6.0, 8)],
)

# This implementation does not support weights
with pytest.warns(UserWarning, match="Cannot calculate the histogram with the numba"):
Hn = utils.histogramnd(
(x, y, z, w),
bins=np.array((5, 6, 7, 8)),
range=[(0.0, 1.0), (2.0, 3.0), (4.0, 5), (6.0, 8)],
weights=np.ones_like(x) + 1.0j,
)
assert np.all(H == Hn.real)
assert np.all(H == Hn.imag)


@pytest.mark.skipif("not HAS_NUMBA")
def test_histogram3d_accept_complex():
x = np.random.uniform(0.0, 1.0, 100)
y = np.random.uniform(2.0, 3.0, 100)
z = np.random.uniform(4.0, 5.0, 100)
H, _ = np.histogramdd(
(x, y, z), bins=np.array((5, 6, 7)), range=[(0.0, 1.0), (2.0, 3.0), (4.0, 5)]
)

Hn = utils.histogram3d(
(x, y, z),
bins=np.array((5, 6, 7)),
range=[(0.0, 1.0), (2.0, 3.0), (4.0, 5)],
weights=np.ones_like(x) + 1.0j,
)
assert np.all(H == Hn.real)
assert np.all(H == Hn.imag)


@pytest.mark.skipif("not HAS_NUMBA")
def test_histogram2d_accept_complex():
x = np.random.uniform(0.0, 1.0, 100)
y = np.random.uniform(2.0, 3.0, 100)
H, _, _ = np.histogram2d(x, y, bins=np.array((5, 5)), range=[(0.0, 1.0), (2.0, 3.0)])

Hn = utils.histogram2d(
x,
y,
bins=np.array([5, 5]),
range=np.array([[0.0, 1.0], [2.0, 3.0]]),
weights=np.ones_like(x) + 1.0j,
)
assert np.all(H == Hn.real)
assert np.all(H == Hn.imag)


@pytest.mark.skipif("not HAS_NUMBA")
def test_histogram_accept_complex():
x = np.random.uniform(0.0, 1.0, 100)
(H, _) = np.histogram(x, bins=5, range=(0.0, 1.0))
Hn = utils.histogram(x, bins=5, range=np.array([0.0, 1.0]), weights=np.ones_like(x) + 1.0j)
assert np.all(H == Hn.real)
assert np.all(H == Hn.imag)


def test_compute_bin():
bin_edges = np.array([0, 5, 10])
assert utils.compute_bin(1, bin_edges) == 0
Expand Down
Loading

0 comments on commit ea4bdb1

Please sign in to comment.