Commit

Add support for constant cutoff frequency in LowPassFilter and HighPassFilter
iver56 committed Sep 19, 2022
1 parent 47546e8 commit 14021bf
Showing 3 changed files with 70 additions and 27 deletions.
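As a rough usage sketch of what this change enables (illustrative code, not part of the commit): setting min_cutoff_freq equal to max_cutoff_freq now makes the transform treat the cutoff as constant and reuse a cached julius.LowPassFilter instead of re-sampling a cutoff on every call. The sample data and parameter values below are assumptions chosen for illustration.

# Usage sketch (illustrative, not from this commit): a constant cutoff is
# requested by making min_cutoff_freq == max_cutoff_freq, which triggers the
# new cached-filter code path.
import numpy as np
import torch
from torch_audiomentations import LowPassFilter

samples = np.random.uniform(low=-0.2, high=0.2, size=(2, 1, 16000)).astype(np.float32)

augment = LowPassFilter(
    min_cutoff_freq=2000,  # equal to max_cutoff_freq -> constant cutoff
    max_cutoff_freq=2000,
    p=1.0,
    output_type="dict",
)
processed = augment(samples=torch.from_numpy(samples), sample_rate=16000).samples
print(processed.shape)  # torch.Size([2, 1, 16000])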
20 changes: 20 additions & 0 deletions tests/test_low_pass_filter.py
@@ -24,3 +24,23 @@ def test_low_pass_filter(self):
         ).samples.numpy()
         assert processed_samples.shape == samples.shape
         assert processed_samples.dtype == np.float32
+
+    def test_equal_cutoff_min_max(self):
+        samples = np.array(
+            [
+                [[0.75, 0.5, -0.25, -0.125, 0.0], [0.65, 0.5, -0.25, -0.125, 0.0]],
+                [[0.3, 0.5, -0.25, -0.125, 0.0], [0.9, 0.5, -0.25, -0.125, 0.0]],
+                [[0.9, 0.5, -0.25, -1.06, 0.0], [0.9, 0.5, -0.25, -1.12, 0.0]],
+            ],
+            dtype=np.float32,
+        )
+        sample_rate = 16000
+
+        augment = LowPassFilter(
+            min_cutoff_freq=2000, max_cutoff_freq=2000, p=1.0, output_type="dict"
+        )
+        processed_samples = augment(
+            samples=torch.from_numpy(samples), sample_rate=sample_rate
+        ).samples.numpy()
+        assert processed_samples.shape == samples.shape
+        assert processed_samples.dtype == np.float32
4 changes: 2 additions & 2 deletions torch_audiomentations/augmentations/high_pass_filter.py
@@ -12,8 +12,8 @@ class HighPassFilter(LowPassFilter):

     def __init__(
         self,
-        min_cutoff_freq=20,
-        max_cutoff_freq=2400,
+        min_cutoff_freq: float = 20.0,
+        max_cutoff_freq: float = 2400.0,
         mode: str = "per_example",
         p: float = 0.5,
         p_mode: str = None,
73 changes: 48 additions & 25 deletions torch_audiomentations/augmentations/low_pass_filter.py
@@ -24,8 +24,8 @@ class LowPassFilter(BaseWaveformTransform):

     def __init__(
         self,
-        min_cutoff_freq=150,
-        max_cutoff_freq=7500,
+        min_cutoff_freq: float = 150.0,
+        max_cutoff_freq: float = 7500.0,
         mode: str = "per_example",
         p: float = 0.5,
         p_mode: str = None,
@@ -55,6 +55,8 @@ def __init__(
         if self.min_cutoff_freq > self.max_cutoff_freq:
             raise ValueError("min_cutoff_freq must not be greater than max_cutoff_freq")

+        self.cached_lpf = None
+
     def randomize_parameters(
         self,
         samples: Tensor = None,
@@ -67,23 +69,40 @@ def randomize_parameters(
         """
         batch_size, _, num_samples = samples.shape

-        # Sample frequencies uniformly in mel space, then convert back to frequency
-        dist = torch.distributions.Uniform(
-            low=convert_frequencies_to_mels(
-                torch.tensor(
-                    self.min_cutoff_freq, dtype=torch.float32, device=samples.device
-                )
-            ),
-            high=convert_frequencies_to_mels(
-                torch.tensor(
-                    self.max_cutoff_freq, dtype=torch.float32, device=samples.device
-                )
-            ),
-            validate_args=True,
-        )
-        self.transform_parameters["cutoff_freq"] = convert_mels_to_frequencies(
-            dist.sample(sample_shape=(batch_size,))
-        )
+        if self.min_cutoff_freq == self.max_cutoff_freq:
+            # Speed up computation by caching the LPF instance if the cutoff is constant
+            cutoff_as_fraction_of_sr = self.min_cutoff_freq / sample_rate
+            lpf_needs_init = (
+                self.cached_lpf is None
+                or self.cached_lpf.cutoff != cutoff_as_fraction_of_sr
+            )
+            if lpf_needs_init:
+                self.cached_lpf = julius.LowPassFilter(cutoff=cutoff_as_fraction_of_sr)
+            self.transform_parameters["cutoff_freq"] = torch.full(
+                size=(batch_size,),
+                fill_value=self.min_cutoff_freq,
+                dtype=torch.float32,
+                device=samples.device,
+            )
+        else:
+            # Sample frequencies uniformly in mel space, then convert back to frequency
+            dist = torch.distributions.Uniform(
+                low=convert_frequencies_to_mels(
+                    torch.tensor(
+                        self.min_cutoff_freq, dtype=torch.float32, device=samples.device
+                    )
+                ),
+                high=convert_frequencies_to_mels(
+                    torch.tensor(
+                        self.max_cutoff_freq, dtype=torch.float32, device=samples.device
+                    )
+                ),
+                validate_args=True,
+            )
+            self.transform_parameters["cutoff_freq"] = convert_mels_to_frequencies(
+                dist.sample(sample_shape=(batch_size,))
+            )
+            self.cached_lpf = None

     def apply_transform(
         self,
@@ -95,14 +114,18 @@ def apply_transform(

         batch_size, num_channels, num_samples = samples.shape

-        cutoffs_as_fraction_of_sample_rate = (
-            self.transform_parameters["cutoff_freq"] / sample_rate
-        )
-        # TODO: Instead of using a for loop, perform batched compute to speed things up
-        for i in range(batch_size):
-            samples[i] = julius.lowpass_filter(
-                samples[i], cutoffs_as_fraction_of_sample_rate[i].item()
+        if self.cached_lpf is None:
+            cutoffs_as_fraction_of_sample_rate = (
+                self.transform_parameters["cutoff_freq"] / sample_rate
             )
+            # TODO: Instead of using a for loop, perform batched compute to speed things up
+            for i in range(batch_size):
+                samples[i] = julius.lowpass_filter(
+                    samples[i], cutoffs_as_fraction_of_sample_rate[i].item()
+                )
+        else:
+            for i in range(batch_size):
+                samples[i] = self.cached_lpf(samples[i])

         return ObjectDict(
             samples=samples,
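A closing note on why the cache helps, as a sketch under the assumption that julius's functional lowpass_filter constructs a fresh filter (and its FIR kernel) on each call, while a julius.LowPassFilter module precomputes its kernel once. The timing loop and values below are illustrative only.

# Illustrative micro-benchmark sketch, not part of the commit: compare the
# functional call, which sets up the filter on every iteration, with a cached
# julius.LowPassFilter that reuses its precomputed kernel.
import time

import julius
import torch

x = torch.randn(1, 16000)
cutoff = 2000 / 16000  # cutoff as a fraction of the sample rate, as in the diff above

cached_lpf = julius.LowPassFilter(cutoff=cutoff)

t0 = time.perf_counter()
for _ in range(100):
    _ = julius.lowpass_filter(x, cutoff)  # filter is set up anew each time
t1 = time.perf_counter()
for _ in range(100):
    _ = cached_lpf(x)  # reuses the cached filter, as apply_transform now does
t2 = time.perf_counter()

print(f"functional: {t1 - t0:.3f} s   cached module: {t2 - t1:.3f} s")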