diff --git a/amt/audio.py b/amt/audio.py
index fff4fb2..27e1e26 100644
--- a/amt/audio.py
+++ b/amt/audio.py
@@ -188,7 +188,7 @@ def __init__(
         reverb_factor: int = 1,
         min_snr: int = 10,
         max_snr: int = 40,
-        max_pitch_shift: int = 4,
+        max_pitch_shift: int = 5,
     ):
         super().__init__()
         self.tokenizer = AmtTokenizer()
@@ -232,10 +232,10 @@ def __init__(
         # Spec aug
         self.spec_aug = torch.nn.Sequential(
             torchaudio.transforms.FrequencyMasking(
-                freq_mask_param=10, iid_masks=True
+                freq_mask_param=15, iid_masks=True
             ),
             torchaudio.transforms.TimeMasking(
-                time_mask_param=100, iid_masks=True
+                time_mask_param=500, iid_masks=True
             ),
         )
diff --git a/amt/train.py b/amt/train.py
index ea6d7d1..c79a6a2 100644
--- a/amt/train.py
+++ b/amt/train.py
@@ -339,7 +339,8 @@ def train_loop(
         wav, src, tgt = batch  # (b_sz, s_len), (b_sz, s_len, v_sz)
         with torch.no_grad():
-            mel, (src, tgt) = audio_transform.forward(wav, src, tgt)
+            mel = audio_transform.mel(wav)
+            # mel, (src, tgt) = audio_transform.forward(wav, src, tgt)
         logits = model(mel, src)  # (b_sz, s_len, v_sz)
         logits = logits.transpose(1, 2)  # Transpose for CrossEntropyLoss
         loss = loss_fn(logits, tgt)
diff --git a/config/config.json b/config/config.json
index be4cfd7..67c407e 100644
--- a/config/config.json
+++ b/config/config.json
@@ -17,7 +17,7 @@
         "n_mels": 256
     },
     "data": {
-        "stride_factor": 1,
+        "stride_factor": 3,
         "max_seq_len": 4096
     }
 }
\ No newline at end of file
diff --git a/tests/test_data/maestro.wav b/tests/test_data/maestro.wav
new file mode 100644
index 0000000..ad2279b
Binary files /dev/null and b/tests/test_data/maestro.wav differ
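
For context, here is a minimal standalone sketch of the augmentation pipeline these changes describe, using the new masking parameters from the diff. The MelSpectrogram settings (sample rate, hop length) and the `transform` helper are illustrative assumptions, not the repository's actual `AudioTransform.mel`.

import torch
import torchaudio

# Spec-augmentation stage with the updated parameters from this diff:
# wider frequency masks (15 bins) and much longer time masks (500 frames).
spec_aug = torch.nn.Sequential(
    torchaudio.transforms.FrequencyMasking(freq_mask_param=15, iid_masks=True),
    torchaudio.transforms.TimeMasking(time_mask_param=500, iid_masks=True),
)

# Assumed mel front end; n_mels=256 matches config/config.json, the other
# values are placeholders rather than the repo's real settings.
mel = torchaudio.transforms.MelSpectrogram(
    sample_rate=16000,
    n_mels=256,
)

def transform(wav: torch.Tensor) -> torch.Tensor:
    """Compute a log-mel spectrogram from raw audio and apply spec augmentation."""
    spec = mel(wav)                    # (batch, n_mels, n_frames)
    log_spec = torch.log(spec + 1e-6)  # log-compress for numerical stability
    return spec_aug(log_spec)          # random frequency/time masking

In the train loop this mirrors the change from `audio_transform.forward(wav, src, tgt)` to computing only the mel features via `audio_transform.mel(wav)`, leaving `src` and `tgt` untouched.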