Skip to content

Commit

Permalink
Torchgate improvements and tests (#84)
Browse files Browse the repository at this point in the history
* Removed code duplications of torchgating
Enabled the direct import of torchgating as a nn.module

* Removed tf from tests and added a test of torch gating

* Added documentation of torchgate as part of noise reduce function in notebook 2.0 and as individual class in notebook 3.0

* Removed tests with cuda
  • Loading branch information
nuniz authored May 11, 2023
1 parent 3f14aa3 commit 57512ce
Show file tree
Hide file tree
Showing 17 changed files with 462 additions and 974 deletions.
1 change: 0 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,3 @@ after_success:
- coveralls
install:
- pip install -r requirements.txt
- pip install -r requirements-test.txt
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ The most recent version of noisereduce comprises two algorithms:

# Usage
See example notebook: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/timsainb/noisereduce/blob/master/notebooks/1.0-test-noise-reduction.ipynb)
Parallel computing example: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/timsainb/noisereduce/blob/master/notebooks/2.0-test-noisereduce-pytorch.ipynb)

## reduce_noise

Expand Down Expand Up @@ -147,6 +148,7 @@ y : np.ndarray [shape=(# frames,) or (# channels, # frames)], real-valued
```

## Torch
See example notebook: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/timsainb/noisereduce/blob/master/notebooks/3.0-torchgate-as-nn-module.ipynb)
### Simplest usage
```
import torch
Expand Down Expand Up @@ -220,3 +222,4 @@ If you use this code in your research, please cite it:

<p><small>Project based on the <a target="_blank" href="https://drivendata.github.io/cookiecutter-data-science/">cookiecutter data science project template</a>. #cookiecutterdatascience</small></p>


50 changes: 25 additions & 25 deletions noisereduce/noisereduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,32 +8,32 @@
except ImportError:
TORCH_AVAILABLE = False
if TORCH_AVAILABLE:
from noisereduce.spectralgate.torch.torchgate import SpectralGateTorch
from noisereduce.spectralgate.streamed_torch_gate import StreamedTorchGate


def reduce_noise(
y,
sr,
stationary=False,
y_noise=None,
prop_decrease=1.0,
time_constant_s=2.0,
freq_mask_smooth_hz=500,
time_mask_smooth_ms=50,
thresh_n_mult_nonstationary=2,
sigmoid_slope_nonstationary=10,
n_std_thresh_stationary=1.5,
tmp_folder=None,
chunk_size=600000,
padding=30000,
n_fft=1024,
win_length=None,
hop_length=None,
clip_noise_stationary=True,
use_tqdm=False,
n_jobs=1,
use_torch=False,
device="cuda",
y,
sr,
stationary=False,
y_noise=None,
prop_decrease=1.0,
time_constant_s=2.0,
freq_mask_smooth_hz=500,
time_mask_smooth_ms=50,
thresh_n_mult_nonstationary=2,
sigmoid_slope_nonstationary=10,
n_std_thresh_stationary=1.5,
tmp_folder=None,
chunk_size=600000,
padding=30000,
n_fft=1024,
win_length=None,
hop_length=None,
clip_noise_stationary=True,
use_tqdm=False,
n_jobs=1,
use_torch=False,
device="cuda",
):
"""
Reduce noise via spectral gating.
Expand Down Expand Up @@ -110,7 +110,7 @@ def reduce_noise(
"""

if use_torch:
if TORCH_AVAILABLE == False:
if not TORCH_AVAILABLE:
raise ImportError(
"Torch is not installed. Please install torch to use torch version of spectral gating."
)
Expand All @@ -124,7 +124,7 @@ def reduce_noise(
device = (
torch.device(device) if torch.cuda.is_available() else torch.device(device)
)
sg = SpectralGateTorch(
sg = StreamedTorchGate(
y=y,
sr=sr,
stationary=stationary,
Expand Down
1 change: 1 addition & 0 deletions noisereduce/spectralgate/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .nonstationary import SpectralGateNonStationary
from .stationary import SpectralGateStationary
from .streamed_torch_gate import StreamedTorchGate
36 changes: 18 additions & 18 deletions noisereduce/spectralgate/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,21 +31,21 @@ def _smoothing_filter(n_grad_freq, n_grad_time):

class SpectralGate:
def __init__(
self,
y,
sr,
prop_decrease,
chunk_size,
padding,
n_fft,
win_length,
hop_length,
time_constant_s,
freq_mask_smooth_hz,
time_mask_smooth_ms,
tmp_folder,
use_tqdm,
n_jobs,
self,
y,
sr,
prop_decrease,
chunk_size,
padding,
n_fft,
win_length,
hop_length,
time_constant_s,
freq_mask_smooth_hz,
time_mask_smooth_ms,
tmp_folder,
use_tqdm,
n_jobs,
):
self.sr = sr
# if this is a 1D single channel recording
Expand Down Expand Up @@ -138,7 +138,7 @@ def _read_chunk(self, i1, i2):
else:
i2b = i2
chunk = np.zeros((self.n_channels, i2 - i1))
chunk[:, i1b - i1 : i2b - i1] = self.y[:, i1b:i2b]
chunk[:, i1b - i1: i2b - i1] = self.y[:, i1b:i2b]
return chunk

def filter_chunk(self, start_frame, end_frame):
Expand All @@ -147,7 +147,7 @@ def filter_chunk(self, start_frame, end_frame):
i2 = end_frame + self.padding
padded_chunk = self._read_chunk(i1, i2)
filtered_padded_chunk = self._do_filter(padded_chunk)
return filtered_padded_chunk[:, start_frame - i1 : end_frame - i1]
return filtered_padded_chunk[:, start_frame - i1: end_frame - i1]

def _get_filtered_chunk(self, ind):
"""Grabs a single chunk"""
Expand All @@ -161,7 +161,7 @@ def _do_filter(self, chunk):

def _iterate_chunk(self, filtered_chunk, pos, end0, start0, ich):
filtered_chunk0 = self._get_filtered_chunk(ich)
filtered_chunk[:, pos : pos + end0 - start0] = filtered_chunk0[:, start0:end0]
filtered_chunk[:, pos: pos + end0 - start0] = filtered_chunk0[:, start0:end0]
pos += end0 - start0

def get_traces(self, start_frame=None, end_frame=None):
Expand Down
87 changes: 87 additions & 0 deletions noisereduce/spectralgate/streamed_torch_gate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import torch
from noisereduce.spectralgate.base import SpectralGate
from noisereduce.torchgate import TorchGate as TG
import numpy as np


class StreamedTorchGate(SpectralGate):
'''
Run interface with noisereduce.
'''

def __init__(
self,
y,
sr,
stationary=False,
y_noise=None,
prop_decrease=1.0,
time_constant_s=2.0,
freq_mask_smooth_hz=500,
time_mask_smooth_ms=50,
thresh_n_mult_nonstationary=2,
sigmoid_slope_nonstationary=10,
n_std_thresh_stationary=1.5,
tmp_folder=None,
chunk_size=600000,
padding=30000,
n_fft=1024,
win_length=None,
hop_length=None,
clip_noise_stationary=True,
use_tqdm=False,
n_jobs=1,
device="cuda",
):
super().__init__(
y=y,
sr=sr,
chunk_size=chunk_size,
padding=padding,
n_fft=n_fft,
win_length=win_length,
hop_length=hop_length,
time_constant_s=time_constant_s,
freq_mask_smooth_hz=freq_mask_smooth_hz,
time_mask_smooth_ms=time_mask_smooth_ms,
tmp_folder=tmp_folder,
prop_decrease=prop_decrease,
use_tqdm=use_tqdm,
n_jobs=n_jobs,
)

self.device = device

# noise convert to torch if needed
if y_noise is not None:
if y_noise.shape[-1] > y.shape[-1] and clip_noise_stationary:
y_noise = y_noise[: y.shape[-1]]
y_noise = torch.from_numpy(y_noise).to(device)
# ensure that y_noise is in shape (#channels, #frames)
if len(y_noise.shape) == 1:
y_noise = y_noise.unsqueeze(0)
self.y_noise = y_noise

# create a torch object
self.tg = TG(
sr=sr,
nonstationary=not stationary,
n_std_thresh_stationary=n_std_thresh_stationary,
n_thresh_nonstationary=thresh_n_mult_nonstationary,
temp_coeff_nonstationary=1 / sigmoid_slope_nonstationary,
n_movemean_nonstationary=int(time_constant_s / self._hop_length * sr),
prop_decrease=prop_decrease,
n_fft=self._n_fft,
win_length=self._win_length,
hop_length=self._hop_length,
freq_mask_smooth_hz=freq_mask_smooth_hz,
time_mask_smooth_ms=time_mask_smooth_ms,
).to(device)

def _do_filter(self, chunk):
"""Do the actual filtering"""
# convert to torch if needed
if type(chunk) is np.ndarray:
chunk = torch.from_numpy(chunk).to(self.device)
chunk_filtered = self.tg(x=chunk, xn=self.y_noise)
return chunk_filtered.cpu().detach().numpy()
Loading

0 comments on commit 57512ce

Please sign in to comment.