From 14021bfcc48b9c79f8852ccdcb0fccdee2c26ea4 Mon Sep 17 00:00:00 2001 From: iver56 Date: Mon, 19 Sep 2022 09:47:52 +0200 Subject: [PATCH] Add support for constant cutoff frequency in LowPassFilter and HighPassFilter --- tests/test_low_pass_filter.py | 20 +++++ .../augmentations/high_pass_filter.py | 4 +- .../augmentations/low_pass_filter.py | 73 ++++++++++++------- 3 files changed, 70 insertions(+), 27 deletions(-) diff --git a/tests/test_low_pass_filter.py b/tests/test_low_pass_filter.py index 22143dfa..d1c5a1db 100644 --- a/tests/test_low_pass_filter.py +++ b/tests/test_low_pass_filter.py @@ -24,3 +24,23 @@ def test_low_pass_filter(self): ).samples.numpy() assert processed_samples.shape == samples.shape assert processed_samples.dtype == np.float32 + + def test_equal_cutoff_min_max(self): + samples = np.array( + [ + [[0.75, 0.5, -0.25, -0.125, 0.0], [0.65, 0.5, -0.25, -0.125, 0.0]], + [[0.3, 0.5, -0.25, -0.125, 0.0], [0.9, 0.5, -0.25, -0.125, 0.0]], + [[0.9, 0.5, -0.25, -1.06, 0.0], [0.9, 0.5, -0.25, -1.12, 0.0]], + ], + dtype=np.float32, + ) + sample_rate = 16000 + + augment = LowPassFilter( + min_cutoff_freq=2000, max_cutoff_freq=2000, p=1.0, output_type="dict" + ) + processed_samples = augment( + samples=torch.from_numpy(samples), sample_rate=sample_rate + ).samples.numpy() + assert processed_samples.shape == samples.shape + assert processed_samples.dtype == np.float32 diff --git a/torch_audiomentations/augmentations/high_pass_filter.py b/torch_audiomentations/augmentations/high_pass_filter.py index 90dcdf1d..0098c701 100644 --- a/torch_audiomentations/augmentations/high_pass_filter.py +++ b/torch_audiomentations/augmentations/high_pass_filter.py @@ -12,8 +12,8 @@ class HighPassFilter(LowPassFilter): def __init__( self, - min_cutoff_freq=20, - max_cutoff_freq=2400, + min_cutoff_freq: float = 20.0, + max_cutoff_freq: float = 2400.0, mode: str = "per_example", p: float = 0.5, p_mode: str = None, diff --git a/torch_audiomentations/augmentations/low_pass_filter.py
b/torch_audiomentations/augmentations/low_pass_filter.py index 7065519b..bae4c745 100644 --- a/torch_audiomentations/augmentations/low_pass_filter.py +++ b/torch_audiomentations/augmentations/low_pass_filter.py @@ -24,8 +24,8 @@ class LowPassFilter(BaseWaveformTransform): def __init__( self, - min_cutoff_freq=150, - max_cutoff_freq=7500, + min_cutoff_freq: float = 150.0, + max_cutoff_freq: float = 7500.0, mode: str = "per_example", p: float = 0.5, p_mode: str = None, @@ -55,6 +55,8 @@ def __init__( if self.min_cutoff_freq > self.max_cutoff_freq: raise ValueError("min_cutoff_freq must not be greater than max_cutoff_freq") + self.cached_lpf = None + def randomize_parameters( self, samples: Tensor = None, @@ -67,23 +69,40 @@ def randomize_parameters( """ batch_size, _, num_samples = samples.shape - # Sample frequencies uniformly in mel space, then convert back to frequency - dist = torch.distributions.Uniform( - low=convert_frequencies_to_mels( - torch.tensor( - self.min_cutoff_freq, dtype=torch.float32, device=samples.device - ) - ), - high=convert_frequencies_to_mels( - torch.tensor( - self.max_cutoff_freq, dtype=torch.float32, device=samples.device + if self.min_cutoff_freq == self.max_cutoff_freq: + # Speed up computation by caching the LPF instance if the cutoff is constant + cutoff_as_fraction_of_sr = self.min_cutoff_freq / sample_rate + lpf_needs_init = ( + self.cached_lpf is None + or self.cached_lpf.cutoff != cutoff_as_fraction_of_sr + ) + if lpf_needs_init: + self.cached_lpf = julius.LowPassFilter(cutoff=cutoff_as_fraction_of_sr) + self.transform_parameters["cutoff_freq"] = torch.full( + size=(batch_size,), + fill_value=self.min_cutoff_freq, + dtype=torch.float32, + device=samples.device, ) - ), - validate_args=True, - ) - self.transform_parameters["cutoff_freq"] = convert_mels_to_frequencies( - dist.sample(sample_shape=(batch_size,)) - ) + else: + # Sample frequencies uniformly in mel space, then convert back to frequency + dist = 
torch.distributions.Uniform( + low=convert_frequencies_to_mels( + torch.tensor( + self.min_cutoff_freq, dtype=torch.float32, device=samples.device + ) + ), + high=convert_frequencies_to_mels( + torch.tensor( + self.max_cutoff_freq, dtype=torch.float32, device=samples.device + ) + ), + validate_args=True, + ) + self.transform_parameters["cutoff_freq"] = convert_mels_to_frequencies( + dist.sample(sample_shape=(batch_size,)) + ) + self.cached_lpf = None def apply_transform( self, @@ -95,14 +114,18 @@ def apply_transform( batch_size, num_channels, num_samples = samples.shape - cutoffs_as_fraction_of_sample_rate = ( - self.transform_parameters["cutoff_freq"] / sample_rate - ) - # TODO: Instead of using a for loop, perform batched compute to speed things up - for i in range(batch_size): - samples[i] = julius.lowpass_filter( - samples[i], cutoffs_as_fraction_of_sample_rate[i].item() + if self.cached_lpf is None: + cutoffs_as_fraction_of_sample_rate = ( + self.transform_parameters["cutoff_freq"] / sample_rate ) + # TODO: Instead of using a for loop, perform batched compute to speed things up + for i in range(batch_size): + samples[i] = julius.lowpass_filter( + samples[i], cutoffs_as_fraction_of_sample_rate[i].item() + ) + else: + for i in range(batch_size): + samples[i] = self.cached_lpf(samples[i]) return ObjectDict( samples=samples,