Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update inference #22

Merged
merged 11 commits into from
Mar 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 43 additions & 23 deletions amt/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,10 +194,11 @@ def __init__(
noise_ratio: float = 0.95,
reverb_ratio: float = 0.95,
applause_ratio: float = 0.01,
bandpass_ratio: float = 0.1,
bandpass_ratio: float = 0.15,
distort_ratio: float = 0.15,
reduce_ratio: float = 0.01,
codecs_ratio: float = 0.01,
detune_ratio: float = 0.1,
detune_max_shift: float = 0.15,
spec_aug_ratio: float = 0.5,
):
super().__init__()
Expand All @@ -219,8 +220,9 @@ def __init__(
self.bandpass_ratio = bandpass_ratio
self.distort_ratio = distort_ratio
self.reduce_ratio = reduce_ratio
self.detune_ratio = detune_ratio
self.detune_max_shift = detune_max_shift
self.spec_aug_ratio = spec_aug_ratio
self.codecs_ratio = codecs_ratio
self.reduction_resample_rate = 6000 # Hardcoded?

# Audio aug
Expand Down Expand Up @@ -268,6 +270,19 @@ def __init__(
),
)

def get_params(self) -> dict[str, float]:
    """Return the augmentation settings stored on this instance.

    Returns:
        A new dict keyed by attribute name, holding the current values of
        the ``*_ratio`` attributes and ``detune_max_shift``.
    """
    return {
        "noise_ratio": self.noise_ratio,
        "reverb_ratio": self.reverb_ratio,
        "applause_ratio": self.applause_ratio,
        "bandpass_ratio": self.bandpass_ratio,
        "distort_ratio": self.distort_ratio,
        "reduce_ratio": self.reduce_ratio,
        "detune_ratio": self.detune_ratio,
        "detune_max_shift": self.detune_max_shift,
        "spec_aug_ratio": self.spec_aug_ratio,
    }

def _get_paths(self, dir_path):
os.makedirs(dir_path, exist_ok=True)

Expand Down Expand Up @@ -399,21 +414,7 @@ def distortion_aug_cpu(self, wav: torch.Tensor):

return wav

def apply_codec(self, wav: torch.Tensor):
    """
    Apply different audio codecs to the audio.

    Each (format, encoder) pair is applied independently when a fresh
    ``random.random()`` draw is below ``self.codecs_ratio``, so zero, one
    or several codecs may be chained on the same waveform.

    Args:
        wav: Waveform passed to ``AudioEffector.apply`` together with
            ``self.sample_rate``.

    Returns:
        The waveform after any selected codecs have been applied; the
        input object itself when none were selected.
    """
    format_encoder_pairs = [
        ("wav", "pcm_mulaw"),
        ("g722", None),
        ("ogg", "vorbis"),
    ]
    for codec_format, codec_encoder in format_encoder_pairs:
        # One draw per pair. The effector is only built when it will
        # actually be used.
        if random.random() < self.codecs_ratio:
            effector = torchaudio.io.AudioEffector(
                format=codec_format, encoder=codec_encoder
            )
            wav = effector.apply(wav, self.sample_rate)

    # Without this return the caller receives None and the codec output
    # is discarded.
    return wav

def shift_spec(self, specs: torch.Tensor, shift: int):
def shift_spec(self, specs: torch.Tensor, shift: int | float):
if shift == 0:
return specs

Expand All @@ -438,9 +439,21 @@ def shift_spec(self, specs: torch.Tensor, shift: int):

return shifted_specs

def detune_spec(self, specs: torch.Tensor):
    """Randomly blend ``specs`` with a slightly shifted copy of itself.

    With probability ``self.detune_ratio`` a shift is drawn uniformly from
    ``[-self.detune_max_shift, self.detune_max_shift]`` and passed to
    ``self.shift_spec``; the element-wise mean of the original and the
    shifted result is returned. Otherwise ``specs`` is returned unchanged.

    Args:
        specs: Spectrogram tensor accepted by ``self.shift_spec``.

    Returns:
        Either ``specs`` itself or ``(specs + shifted) / 2``.
    """
    if random.random() >= self.detune_ratio:
        return specs

    detune_shift = random.uniform(
        -self.detune_max_shift, self.detune_max_shift
    )
    detuned_specs = self.shift_spec(specs, shift=detune_shift)

    # Average the original with the shifted copy instead of replacing it.
    return (specs + detuned_specs) / 2

def aug_wav(self, wav: torch.Tensor):
# This function doesn't apply distortion. If distortion is desired it
# should be run before hand on the cpu with distortion_aug_cpu.
# should be run beforehand on the cpu with distortion_aug_cpu. Note
# also that detuning is done to the spectrogram in log_mel, not the wav.

# Noise
if random.random() < self.noise_ratio:
Expand Down Expand Up @@ -468,10 +481,17 @@ def norm_mel(self, mel_spec: torch.Tensor):

return log_spec

def log_mel(self, wav: torch.Tensor, shift: int | None = None):
def log_mel(
self, wav: torch.Tensor, shift: int | None = None, detune: bool = False
):
spec = self.spec_transform(wav)[..., :-1]
if shift and shift != 0:

if shift is not None and shift != 0:
spec = self.shift_spec(spec, shift)
elif detune is True:
# Don't detune and spec shift at the same time
spec = self.detune_spec(spec)

mel_spec = self.mel_transform(spec)

# Norm
Expand All @@ -483,8 +503,8 @@ def forward(self, wav: torch.Tensor, shift: int = 0):
# Noise, and reverb
wav = self.aug_wav(wav)

# Spec & pitch shift
log_mel = self.log_mel(wav, shift)
# Spec, detuning & pitch shift
log_mel = self.log_mel(wav, shift, detune=True)

# Spec aug
if random.random() < self.spec_aug_ratio:
Expand Down
Loading
Loading