-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement dataset and update tests (#5)
- Loading branch information
Showing
15 changed files
with
437 additions
and
32 deletions.
There are no files selected for viewing
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,178 @@ | ||
"""Contains code taken from https://github.com/openai/whisper""" | ||
|
||
import os | ||
import torch | ||
import numpy as np | ||
import torch.nn.functional as F | ||
|
||
from functools import lru_cache | ||
from subprocess import CalledProcessError, run | ||
from typing import Optional, Union | ||
|
||
from amt.config import load_config | ||
|
||
# Audio hyperparameters, read from the "audio" section of the project config.
config = load_config()["audio"]
SAMPLE_RATE = config["sample_rate"]
N_FFT = config["n_fft"]
HOP_LENGTH = config["hop_len"]
CHUNK_LENGTH = config["chunk_len"]
# NOTE(review): the figures quoted in the comments below are the original
# author's (Whisper-style) numbers; the real values depend on the loaded
# config - confirm before relying on them.
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000 samples in a 30-second chunk
N_FRAMES = N_SAMPLES // HOP_LENGTH  # 3000 frames in a mel spectrogram input
N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2  # the initial convolutions have stride 2
FRAMES_PER_SECOND = SAMPLE_RATE // HOP_LENGTH  # 10ms per audio frame
TOKENS_PER_SECOND = SAMPLE_RATE // N_SAMPLES_PER_TOKEN  # 20ms per audio token
|
||
|
||
def load_audio(file: str, sr: int = SAMPLE_RATE) -> np.ndarray:
    """
    Open an audio file and read it as a mono waveform, resampling as necessary.

    Parameters
    ----------
    file: str
        The audio file to open
    sr: int
        The sample rate to resample the audio to if necessary

    Returns
    -------
    A one-dimensional NumPy array containing the audio waveform, in float32
    dtype, obtained by dividing the decoded 16-bit samples by 32768.

    Raises
    ------
    RuntimeError
        If ffmpeg exits with a non-zero status; the message carries ffmpeg's
        stderr output.
    """

    # This launches a subprocess to decode audio while down-mixing
    # and resampling as necessary. Requires the ffmpeg CLI in PATH.
    # fmt: off
    cmd = [
        "ffmpeg",
        "-nostdin",
        "-threads", "0",
        "-i", file,
        "-f", "s16le",
        "-ac", "1",
        "-acodec", "pcm_s16le",
        "-ar", str(sr),
        "-"
    ]
    # fmt: on

    try:
        out = run(cmd, capture_output=True, check=True).stdout
    except CalledProcessError as e:
        # ffmpeg echoes file names and metadata that need not be valid UTF-8;
        # decode leniently so the real failure is not masked by a
        # UnicodeDecodeError raised while building the error message.
        stderr = e.stderr.decode(errors="replace")
        raise RuntimeError(f"Failed to load audio: {stderr}") from e

    # Raw output is signed 16-bit PCM; scale it to floats in [-1.0, 1.0).
    return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
|
||
|
||
def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
    """
    Force `array` to have exactly `length` entries along `axis`.

    Longer inputs keep only their first `length` entries; shorter inputs are
    zero-padded at the end. Works on both torch tensors and NumPy arrays and
    returns the same kind of object it was given.
    """
    is_tensor = torch.is_tensor(array)

    # Trim: keep the leading `length` entries along the chosen axis.
    if array.shape[axis] > length:
        if is_tensor:
            keep = torch.arange(length, device=array.device)
            array = array.index_select(dim=axis, index=keep)
        else:
            array = array.take(indices=range(length), axis=axis)

    # Pad: append zeros on the right of the chosen axis only.
    if array.shape[axis] < length:
        widths = [(0, 0)] * array.ndim
        widths[axis] = (0, length - array.shape[axis])
        if is_tensor:
            # F.pad wants a flat (left, right) list ordered from the last
            # dimension to the first, hence the reversal.
            flat_widths = [w for pair in reversed(widths) for w in pair]
            array = F.pad(array, flat_widths)
        else:
            array = np.pad(array, widths)

    return array
|
||
|
||
@lru_cache(maxsize=None)
def mel_filters(device, n_mels: int) -> torch.Tensor:
    """
    Return the mel filterbank matrix used to project an STFT onto mel bins.

    The matrix is read from "assets/mel_filters.npz" next to this module and
    moved to `device`; results are memoised per (device, n_mels) pair. Shipping
    the matrix avoids a librosa dependency; the archive was saved using:
        np.savez_compressed(
            "mel_filters.npz",
            mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
            mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
        )
    """
    assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"

    assets_dir = os.path.join(os.path.dirname(__file__), "assets")
    npz_path = os.path.join(assets_dir, "mel_filters.npz")
    with np.load(npz_path, allow_pickle=False) as archive:
        bank = archive[f"mel_{n_mels}"]
        return torch.from_numpy(bank).to(device)
|
||
|
||
def log_mel_spectrogram(
    audio: Union[str, np.ndarray, torch.Tensor],
    n_mels: int = 80,
    padding: int = 0,
    device: Optional[Union[str, torch.device]] = None,
):
    """
    Compute the normalised log-Mel spectrogram of an audio waveform.

    Parameters
    ----------
    audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
        Either a path to an audio file (decoded with load_audio) or a NumPy
        array / Tensor that already holds the waveform
    n_mels: int
        The number of Mel-frequency filters, forwarded to mel_filters
    padding: int
        Number of zero samples to pad to the right
    device: Optional[Union[str, torch.device]]
        If given, the audio tensor is moved to this device before STFT

    Returns
    -------
    torch.Tensor
        The log-Mel spectrogram, presumably of shape (n_mels, n_frames) for a
        1-D input; values are clamped to within 8 (log10 units) of the
        maximum and then rescaled with (x + 4) / 4
    """
    # Normalise the input to a tensor: path -> ndarray -> tensor.
    if isinstance(audio, str):
        audio = load_audio(audio)
    if not torch.is_tensor(audio):
        audio = torch.from_numpy(audio)

    if device is not None:
        audio = audio.to(device)
    if padding > 0:
        audio = F.pad(audio, (0, padding))

    hann = torch.hann_window(N_FFT).to(audio.device)
    spectrum = torch.stft(
        audio, N_FFT, HOP_LENGTH, window=hann, return_complex=True
    )
    # Power spectrum; the final STFT frame is dropped.
    power = spectrum[..., :-1].abs() ** 2

    mel_spec = mel_filters(audio.device, n_mels) @ power

    # Log-compress, limit the dynamic range relative to the peak, rescale.
    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    floor = log_spec.max() - 8.0
    log_spec = torch.maximum(log_spec, floor)
    return (log_spec + 4.0) / 4.0
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
import mmap | ||
import os | ||
import json | ||
import jsonlines | ||
import torch | ||
|
||
from typing import Callable | ||
from multiprocessing import Pool | ||
|
||
from aria.data.midi import MidiDict | ||
from amt.tokenizer import AmtTokenizer | ||
from amt.config import load_config | ||
from amt.audio import ( | ||
log_mel_spectrogram, | ||
pad_or_trim, | ||
N_FRAMES, | ||
) | ||
|
||
# Dataset-building settings, read from the "data" section of the project config.
config = load_config()["data"]
# N_FRAMES is integer-divided by this value to obtain the step between the
# start frames of consecutive feature windows in get_features.
STRIDE_FACTOR = config["stride_factor"]
|
||
|
||
def get_features(audio_path: str, mid_path: str):
    """Build matched (log-mel spectrogram window, tokenized MIDI) pairs.

    The audio is converted to a log-mel spectrogram and cut into windows of
    N_FRAMES frames whose start frames are N_FRAMES // STRIDE_FACTOR apart
    (the last windows are zero-padded). For every window the MIDI file is
    tokenized over the matching time span.

    Args:
        audio_path: Path to the audio file.
        mid_path: Path to the MIDI file matching the audio.

    Returns:
        A list of (audio_feature, mid_feature) tuples, where audio_feature is
        a torch.Tensor with N_FRAMES frames along its last axis and
        mid_feature is whatever AmtTokenizer._tokenize_midi_dict returns
        (presumably a token list). Returns None if either path is not a file
        or if loading / converting the files fails.
    """
    if not (os.path.isfile(audio_path) and os.path.isfile(mid_path)):
        return None

    try:
        midi_dict = MidiDict.from_midi(mid_path)
        log_spec = log_mel_spectrogram(audio=audio_path)
    except Exception as e:
        # Best-effort: a single bad pair must not abort a whole dataset
        # build, but say which pair failed and why so it can be diagnosed.
        print(
            "Failed to convert files into features "
            f"({audio_path}, {mid_path}): {e!r}"
        )
        return None

    # Only build the tokenizer once the inputs are known to be usable.
    tokenizer = AmtTokenizer()

    # NOTE(review): 10 ms per spectrogram frame is hard-coded here; it
    # presumably matches the configured hop length / sample rate - confirm.
    ms_per_frame = 10
    stride = N_FRAMES // STRIDE_FACTOR

    _, total_frames = log_spec.shape
    res = []
    for start_frame in range(0, total_frames, stride):
        audio_feature = pad_or_trim(log_spec[:, start_frame:], length=N_FRAMES)
        mid_feature = tokenizer._tokenize_midi_dict(
            midi_dict=midi_dict,
            start_ms=start_frame * ms_per_frame,
            end_ms=(start_frame + N_FRAMES) * ms_per_frame,
        )
        res.append((audio_feature, mid_feature))

    return res
|
||
|
||
def get_features_mp(args):
    """Pool-friendly adapter around get_features.

    Takes the (audio_path, mid_path) arguments packed into one iterable and
    returns a (success, features) pair: (False, None) when get_features gave
    up on the pair, otherwise (True, features).
    """
    features = get_features(*args)
    return (False, None) if features is None else (True, features)
|
||
|
||
class AmtDataset(torch.utils.data.Dataset): | ||
def __init__(self, load_path: str): | ||
self.tokenizer = AmtTokenizer(return_tensors=True) | ||
self.aug_fn = self.tokenizer.export_msg_mixup() | ||
self.file_buff = open(load_path, mode="r") | ||
self.file_mmap = mmap.mmap( | ||
self.file_buff.fileno(), 0, access=mmap.ACCESS_READ | ||
) | ||
self.index = self._build_index() | ||
|
||
def close(self): | ||
if self.file_buff: | ||
self.file_buff.close() | ||
if self.file_mmap: | ||
self.file_mmap.close() | ||
|
||
def __del__(self): | ||
self.close() | ||
|
||
def __len__(self): | ||
return len(self.index) | ||
|
||
def __getitem__(self, idx: int): | ||
def _format(tok): | ||
# This is required because json formats tuples into lists | ||
if isinstance(tok, list): | ||
return tuple(tok) | ||
return tok | ||
|
||
self.file_mmap.seek(self.index[idx]) | ||
|
||
# This isn't going to load properly | ||
spec, seq = json.loads(self.file_mmap.readline()) # Load data from line | ||
|
||
spec = torch.tensor(spec) # Format spectrogram into tensor | ||
seq = [_format(tok) for tok in seq] # Format seq | ||
seq = self.aug_fn(seq) # Data augmentation | ||
|
||
src = seq | ||
tgt = seq[1:] + [self.tokenizer.pad_tok] | ||
|
||
return spec, self.tokenizer.encode(src), self.tokenizer.encode(tgt) | ||
|
||
def _build_index(self): | ||
self.file_mmap.seek(0) | ||
index = [] | ||
while True: | ||
pos = self.file_mmap.tell() | ||
line_buffer = self.file_mmap.readline() | ||
if line_buffer == b"": | ||
break | ||
else: | ||
index.append(pos) | ||
|
||
return index | ||
|
||
@classmethod | ||
def build( | ||
cls, | ||
matched_load_paths: list[tuple[str, str]], | ||
save_path: str, | ||
audio_aug_hook: Callable | None = None, | ||
): | ||
def _get_features(_matched_load_paths: list): | ||
with Pool(4) as pool: | ||
results = pool.imap(get_features_mp, _matched_load_paths) | ||
num_paths = len(_matched_load_paths) | ||
for idx, (success, res) in enumerate(results): | ||
if idx % 50 == 0 and idx != 0: | ||
print(f"Processed audio-mid pairs: {idx}/{num_paths}") | ||
|
||
if success == False: | ||
continue | ||
for _audio_feature, _mid_feature in res: | ||
yield _audio_feature.tolist(), _mid_feature | ||
|
||
with jsonlines.open(save_path, mode="w") as writer: | ||
for audio_feature, mid_feature in _get_features(matched_load_paths): | ||
writer.write([audio_feature, mid_feature]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.