Skip to content

Commit

Permalink
add more aug
Browse files Browse the repository at this point in the history
  • Loading branch information
loubbrad committed Feb 29, 2024
1 parent 274115d commit e10dca0
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 5 deletions.
6 changes: 3 additions & 3 deletions amt/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def __init__(
reverb_factor: int = 1,
min_snr: int = 10,
max_snr: int = 40,
max_pitch_shift: int = 4,
max_pitch_shift: int = 5,
):
super().__init__()
self.tokenizer = AmtTokenizer()
Expand Down Expand Up @@ -232,10 +232,10 @@ def __init__(
# Spec aug
self.spec_aug = torch.nn.Sequential(
torchaudio.transforms.FrequencyMasking(
freq_mask_param=10, iid_masks=True
freq_mask_param=15, iid_masks=True
),
torchaudio.transforms.TimeMasking(
time_mask_param=100, iid_masks=True
time_mask_param=500, iid_masks=True
),
)

Expand Down
3 changes: 2 additions & 1 deletion amt/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,8 @@ def train_loop(

wav, src, tgt = batch # (b_sz, s_len), (b_sz, s_len, v_sz)
with torch.no_grad():
mel, (src, tgt) = audio_transform.forward(wav, src, tgt)
mel = audio_transform.mel(wav)
# mel, (src, tgt) = audio_transform.forward(wav, src, tgt)
logits = model(mel, src) # (b_sz, s_len, v_sz)
logits = logits.transpose(1, 2) # Transpose for CrossEntropyLoss
loss = loss_fn(logits, tgt)
Expand Down
2 changes: 1 addition & 1 deletion config/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"n_mels": 256
},
"data": {
"stride_factor": 1,
"stride_factor": 3,
"max_seq_len": 4096
}
}
Binary file added tests/test_data/maestro.wav
Binary file not shown.

0 comments on commit e10dca0

Please sign in to comment.