Skip to content

Commit

Permalink
Triple stacked mels (#35)
Browse files Browse the repository at this point in the history
* update README

* update README

* add fp16

* inference modification

* add stacked mels

* triple
  • Loading branch information
loubbrad authored May 7, 2024
1 parent 01ac3e2 commit 0d7badb
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 29 deletions.
78 changes: 53 additions & 25 deletions amt/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# hard-coded audio hyperparameters
config = load_config()["audio"]
SAMPLE_RATE = config["sample_rate"]
N_FFT = config["n_fft"]
N_FFT = config["n_fft_large"]
HOP_LENGTH = config["hop_len"]
CHUNK_LENGTH = config["chunk_len"]
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk
Expand Down Expand Up @@ -82,10 +82,12 @@ def __init__(
self.config = load_config()["audio"]
self.sample_rate = self.config["sample_rate"]
self.chunk_len = self.config["chunk_len"]
self.n_fft = self.config["n_fft"]
self.n_fft_reduced = self.config["n_fft_reduced"]
self.n_mels = self.config["n_mels"]
self.n_mels_reduced = self.config["n_mels_reduced"]
self.n_fft_large = self.config["n_fft_large"]
self.n_fft_med = self.config["n_fft_med"]
self.n_fft_small = self.config["n_fft_small"]
self.n_mels_large = self.config["n_mels_large"]
self.n_mels_med = self.config["n_mels_med"]
self.n_mels_small = self.config["n_mels_small"]
self.num_samples = self.sample_rate * self.chunk_len

self.noise_ratio = noise_ratio
Expand Down Expand Up @@ -129,28 +131,39 @@ def __init__(
self.register_buffer(f"applause_{i}", applause)
self.num_applause += 1

self.spec_transform = torchaudio.transforms.Spectrogram(
n_fft=self.n_fft,
self.spec_transform_large = torchaudio.transforms.Spectrogram(
n_fft=self.n_fft_large,
hop_length=self.config["hop_len"],
)
self.mel_transform = torchaudio.transforms.MelScale(
n_mels=self.n_mels,
self.mel_transform_large = torchaudio.transforms.MelScale(
n_mels=self.n_mels_large,
sample_rate=self.sample_rate,
n_stft=self.n_fft // 2 + 1,
n_stft=self.n_fft_large // 2 + 1,
f_min=30,
f_max=8000,
)
self.spec_transform_reduced = torchaudio.transforms.Spectrogram(
n_fft=self.n_fft_reduced,
self.spec_transform_med = torchaudio.transforms.Spectrogram(
n_fft=self.n_fft_med,
hop_length=self.config["hop_len"],
)
self.mel_transform_reduced = torchaudio.transforms.MelScale(
n_mels=self.n_mels_reduced,
self.mel_transform_med = torchaudio.transforms.MelScale(
n_mels=self.n_mels_med,
sample_rate=self.sample_rate,
n_stft=self.n_fft_reduced // 2 + 1,
n_stft=self.n_fft_med // 2 + 1,
f_min=30,
f_max=8000,
)
self.spec_transform_small = torchaudio.transforms.Spectrogram(
n_fft=self.n_fft_small,
hop_length=self.config["hop_len"],
)
self.mel_transform_small = torchaudio.transforms.MelScale(
n_mels=self.n_mels_small,
sample_rate=self.sample_rate,
n_stft=self.n_fft_small // 2 + 1,
f_min=30,
f_max=4000,
)
self.spec_aug = torch.nn.Sequential(
torchaudio.transforms.TimeMasking(
time_mask_param=self.time_mask_param,
Expand Down Expand Up @@ -370,28 +383,43 @@ def norm_mel(self, mel_spec: torch.Tensor):
def log_mel(
self, wav: torch.Tensor, shift: int | None = None, detune: bool = False
):
spec = self.spec_transform(wav)[..., :-1]
spec_reduced = self.spec_transform_reduced(wav)[..., :-1]
spec_large = self.spec_transform_large(wav)[..., :-1]
spec_med = self.spec_transform_med(wav)[..., :-1]
spec_small = self.spec_transform_small(wav)[..., :-1]

if shift is not None and shift != 0:
spec = self.shift_spec(spec, shift)
spec_reduced = self.shift_spec(spec_reduced, shift)
spec_large = self.shift_spec(spec_large, shift)
spec_med = self.shift_spec(spec_med, shift)
spec_small = self.shift_spec(spec_small, shift)
elif detune is True:
# Don't detune and spec shift at the same time
if random.random() < self.detune_ratio:
detune_shift = random.uniform(
-self.detune_max_shift, self.detune_max_shift
)
spec = self.detune_spec(spec, detune_shift=detune_shift)
spec_reduced = self.detune_spec(
spec_reduced, detune_shift=detune_shift
spec_large = self.detune_spec(
spec_large,
detune_shift=detune_shift,
)
spec_med = self.detune_spec(
spec_med,
detune_shift=detune_shift,
)
spec_small = self.detune_spec(
spec_small,
detune_shift=detune_shift,
)

mel_spec = self.mel_transform(spec)
mel_spec_reduced = self.mel_transform_reduced(spec_reduced)
mel_spec_large = self.mel_transform_large(spec_large)
mel_spec_med = self.mel_transform_med(spec_med)
mel_spec_small = self.mel_transform_small(spec_small)

# Norm
concat_mel = torch.cat((mel_spec, mel_spec_reduced), dim=1)
concat_mel = torch.cat(
(mel_spec_large, mel_spec_med, mel_spec_small),
# (mel_spec_large, mel_spec_small),
dim=1,
)
log_mel = self.norm_mel(concat_mel)

return log_mel
Expand Down
10 changes: 6 additions & 4 deletions config/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@
},
"audio": {
"sample_rate": 16000,
"n_fft": 2048,
"n_fft_reduced": 800,
"n_fft_large": 4096,
"n_fft_med": 2048,
"n_fft_small": 768,
"hop_len": 160,
"chunk_len": 30,
"n_mels": 384,
"n_mels_reduced": 128
"n_mels_large": 384,
"n_mels_med": 256,
"n_mels_small": 128
},
"data": {
"stride_factor": 15,
Expand Down
11 changes: 11 additions & 0 deletions config/models/medium-triple.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
{
"n_mels": 768,
"n_audio_ctx": 1500,
"n_audio_state": 768,
"n_audio_head": 12,
"n_audio_layer": 4,
"n_text_ctx": 4096,
"n_text_state": 768,
"n_text_head": 12,
"n_text_layer": 4
}

0 comments on commit 0d7badb

Please sign in to comment.