diff --git a/amt/assets/mel_filters.npz b/amt/assets/mel_filters.npz deleted file mode 100644 index c57535f..0000000 Binary files a/amt/assets/mel_filters.npz and /dev/null differ diff --git a/amt/audio.py b/amt/audio.py index 7038d17..599885c 100644 --- a/amt/audio.py +++ b/amt/audio.py @@ -4,7 +4,9 @@ import random import torch import torchaudio +import torch.nn.functional as F import torchaudio.functional as AF +import numpy as np from amt.config import load_config from amt.tokenizer import AmtTokenizer @@ -21,6 +23,32 @@ FRAMES_PER_SECOND = SAMPLE_RATE // HOP_LENGTH # 10ms per audio frame TOKENS_PER_SECOND = SAMPLE_RATE // N_SAMPLES_PER_TOKEN # 20ms per audio token +def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1): + """ + Pad or trim the audio array to N_SAMPLES, as expected by the encoder. + """ + if torch.is_tensor(array): + if array.shape[axis] > length: + array = array.index_select( + dim=axis, index=torch.arange(length, device=array.device) + ) + + if array.shape[axis] < length: + pad_widths = [(0, 0)] * array.ndim + pad_widths[axis] = (0, length - array.shape[axis]) + array = F.pad( + array, [pad for sizes in pad_widths[::-1] for pad in sizes] + ) + else: + if array.shape[axis] > length: + array = array.take(indices=range(length), axis=axis) + + if array.shape[axis] < length: + pad_widths = [(0, 0)] * array.ndim + pad_widths[axis] = (0, length - array.shape[axis]) + array = np.pad(array, pad_widths) + + return array # Refactor default params are stored in config.json class AudioTransform(torch.nn.Module): diff --git a/amt/data.py b/amt/data.py index f4103ae..cde420a 100644 --- a/amt/data.py +++ b/amt/data.py @@ -106,15 +106,18 @@ def get_wav_mid_segments( def pianoteq_cmd_fn(mid_path: str, wav_path: str): presets = [ - "C. Bechstein", - "C. Bechstein Close Mic", - "C. Bechstein Under Lid", - "C. Bechstein 440", - "C. Bechstein Recording", - "C. Bechstein Werckmeister III", - "C. Bechstein Neidhardt III", - "C. Bechstein mesotonic", - "C. Bechstein well tempered", + "C. Bechstein DG Prelude", + "C. Bechstein DG Sweet", + "C. Bechstein DG Felt I", + "C. Bechstein DG Felt II", + "C. Bechstein DG D 282", + "C. Bechstein DG Recording 1", + "C. Bechstein DG Recording 2", + "C. Bechstein DG Recording 3", + "C. Bechstein DG Cinematic", + "C. Bechstein DG Snappy", + "C. Bechstein DG Venue", + "C. Bechstein DG Player", "HB Steinway D Blues", "HB Steinway D Pop", "HB Steinway D New Age", @@ -137,8 +140,6 @@ def pianoteq_cmd_fn(mid_path: str, wav_path: str): "HB Steinway D Cabaret", "HB Steinway D Bright", "HB Steinway D Hyper Bright", - "HB Steinway D Prepared", - "HB Steinway D Honky Tonk", ] preset = random.choice(presets)