Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Method for adding noise to a signal #377

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions madmom/audio/signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,38 @@ def trim(signal, where='fb'):
return signal[first:last]


def add_noise(signal, snr=10.):
"""
Add Additive White Gaussian Noise (AWGN) to a signal.

Parameters
----------
signal : numpy array
Original signal.
snr : float, optional
The Signal-to-noise ratio determines how strong the noise should be
relative to the signal. A lower value results in a stronger noise.

Returns
-------
numpy array
Signal with added noise.

"""
# make sure SNR is in allowed range
if snr <= 0:
raise ValueError("Invalid value for SNR, must be greater than zero.")
# generate gaussian random noise
noise = np.random.normal(0, 1, signal.shape)
# compute power for signal and noise
power_signal = root_mean_square(signal)
power_noise = root_mean_square(noise)
# scale noise to match SNR
noise *= (power_signal / power_noise) / snr
# return signal with added noise
return signal + noise.astype(signal.dtype)


def energy(signal):
"""
Compute the energy of a (framed) signal.
Expand Down
60 changes: 60 additions & 0 deletions tests/test_audio_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,66 @@ def test_values(self):
self.assertTrue(np.allclose(result, [5, 4, 3, 2, 1]))


class TestAddNoiseFunction(unittest.TestCase):

def test_types(self):
# mono signals
result = add_noise(sig_1d)
self.assertTrue(type(result) == type(sig_1d))
self.assertTrue(result.ndim == sig_1d.ndim)
self.assertTrue(result.dtype == sig_1d.dtype)
signal = Signal(sample_file)
result = add_noise(signal)
self.assertIsInstance(result, Signal)
self.assertIsInstance(result, np.ndarray)
self.assertTrue(result.ndim == signal.ndim)
self.assertTrue(result.dtype == np.int16)
# multi-channel signals
result = trim(sig_2d)
self.assertTrue(type(result) == type(sig_2d))
self.assertTrue(len(result) == len(sig_2d))
self.assertTrue(result.ndim == sig_2d.ndim)
self.assertTrue(result.dtype == sig_2d.dtype)
signal = Signal(stereo_sample_file)
result = trim(signal)
self.assertIsInstance(result, Signal)
self.assertIsInstance(result, np.ndarray)
self.assertTrue(result.ndim == signal.ndim)
self.assertTrue(result.dtype == np.int16)

def test_errors(self):
with self.assertRaises(ValueError):
add_noise(sig_1d, 0)
with self.assertRaises(ValueError):
add_noise(sig_1d, -1)

def test_values(self):
# fixed random seed
np.random.seed(0)
# mono signals
result = add_noise(sig_1d, snr=10.)
self.assertTrue(np.allclose(result, sig_1d, atol=0.11))
result = add_noise(sig_1d, snr=20.)
self.assertTrue(np.allclose(result, sig_1d, atol=0.06))
# same with int dtype
result = add_noise(sig_1d.astype(np.int), 10.)
self.assertTrue(np.allclose(result, sig_1d))
result = add_noise(sig_1d.astype(np.int), 0.5)
self.assertTrue(np.allclose(result, [0, 1, 2, 0, 0, 1, -2, 0, 1]))
# multi-channel signals
result = add_noise(sig_2d, snr=10.)
self.assertTrue(np.allclose(result, sig_2d, atol=0.13))
result = add_noise(sig_2d, snr=20.)
self.assertTrue(np.allclose(result, sig_2d, atol=0.08))
# same with int dtype
result = add_noise(sig_2d.astype(np.int), 10.)
self.assertTrue(np.allclose(result, sig_2d))
result = add_noise(sig_2d.astype(np.int), 0.5)
self.assertTrue(np.allclose(result, [[0, 2], [0, 1], [1, 1],
[0, 2], [0, 1], [3, -1],
[-1, 2], [-1, 2], [1, 1]]))


class TestEnergyFunction(unittest.TestCase):

def test_types(self):
Expand Down