From 3c9f4de1d8da2ed6a60611a06c434c60cb945f48 Mon Sep 17 00:00:00 2001 From: Hagen Wierstorf Date: Thu, 4 Jul 2024 13:54:32 +0200 Subject: [PATCH] Ensure correct float types with numpy 2.0 --- auglib/core/utils.py | 24 +++++++++++++++++++----- tests/test_transform_babble_noise.py | 6 ++---- tests/test_transform_pink_noise.py | 2 +- 3 files changed, 22 insertions(+), 10 deletions(-) diff --git a/auglib/core/utils.py b/auglib/core/utils.py index d0b3b55..2d486f0 100644 --- a/auglib/core/utils.py +++ b/auglib/core/utils.py @@ -22,7 +22,7 @@ def from_db(x_db: typing.Union[float, observe.Base]) -> float: """ x_db = observe.observe(x_db) x = pow(10.0, x_db / 20.0) - return x + return float(x) def get_peak(signal: np.ndarray) -> float: @@ -34,6 +34,10 @@ def get_peak(signal: np.ndarray) -> float: Returns: peak as positive value + Examples: + >>> get_peak(np.array([1, 2, 3])) + 3.0 + """ minimum = np.min(signal) maximum = np.max(signal) @@ -41,21 +45,31 @@ def get_peak(signal: np.ndarray) -> float: peak = abs(minimum) else: peak = maximum - return peak + return float(peak) -def rms_db(signal: np.ndarray): +def rms_db(signal: np.ndarray) -> float: r"""Root mean square in dB. Very soft signals are limited to a value of -120 dB. + Args: + signal: input signal + + Returns: + root mean square in decibel + + Examples: + >>> rms_db(np.zeros((1, 4))) + -120.0 + """ # It is: # 20 * log10(rms) = 10 * log10(power) # which saves us from calculating sqrt() power = np.mean(np.square(signal)) - return 10 * np.log10(max(1e-12, power)) + return float(10 * np.log10(max(1e-12, power))) def to_db(x: typing.Union[float, observe.Base]) -> float: @@ -75,7 +89,7 @@ def to_db(x: typing.Union[float, observe.Base]) -> float: x = observe.observe(x) assert x > 0, "cannot convert gain {} to decibels".format(x) x_db = 20 * np.log10(x) - return x_db + return float(x_db) def to_samples( diff --git a/tests/test_transform_babble_noise.py b/tests/test_transform_babble_noise.py index e8644ef..c2ebb07 100644 --- a/tests/test_transform_babble_noise.py +++ b/tests/test_transform_babble_noise.py @@ -45,10 +45,8 @@ def test_babble_noise_1( if snr_db is not None: gain_db = -120 - snr_db gain = audmath.inverse_db(gain_db) - expected_babble = gain * np.ones( - (1, int(duration * sampling_rate)), - dtype=auglib.core.transform.DTYPE, - ) + expected_babble = gain * np.ones((1, int(duration * sampling_rate))) + expected_babble = expected_babble.astype(auglib.core.transform.DTYPE) babble = transform(signal) assert babble.dtype == expected_babble.dtype diff --git a/tests/test_transform_pink_noise.py b/tests/test_transform_pink_noise.py index ef80f4c..5613cf7 100644 --- a/tests/test_transform_pink_noise.py +++ b/tests/test_transform_pink_noise.py @@ -73,7 +73,7 @@ def test_pink_noise(duration, sampling_rate, gain_db, snr_db): ) assert noise.shape == expected_noise.shape assert noise.dtype == expected_noise.dtype - np.testing.assert_almost_equal(noise, expected_noise) + np.testing.assert_almost_equal(noise, expected_noise, decimal=6) @pytest.mark.parametrize(