Implement dataset and update tests (#5)
loubbrad authored Feb 17, 2024
1 parent 5835a26 commit 4a9a70d
Showing 15 changed files with 437 additions and 32 deletions.
Binary file added amt/assets/mel_filters.npz
178 changes: 178 additions & 0 deletions amt/audio.py
@@ -0,0 +1,178 @@
"""Contains code taken from https://github.com/openai/whisper"""

import os
import torch
import numpy as np
import torch.nn.functional as F

from functools import lru_cache
from subprocess import CalledProcessError, run
from typing import Optional, Union

from amt.config import load_config

# hard-coded audio hyperparameters
config = load_config()["audio"]
SAMPLE_RATE = config["sample_rate"]
N_FFT = config["n_fft"]
HOP_LENGTH = config["hop_len"]
CHUNK_LENGTH = config["chunk_len"]
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk
N_FRAMES = N_SAMPLES // HOP_LENGTH # 3000 frames in a mel spectrogram input
N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolution has stride 2
FRAMES_PER_SECOND = SAMPLE_RATE // HOP_LENGTH # 10ms per audio frame
TOKENS_PER_SECOND = SAMPLE_RATE // N_SAMPLES_PER_TOKEN # 20ms per audio token
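# With the values in config/config.json (sample_rate=16000, n_fft=400,
# hop_len=160, chunk_len=30), these constants work out to:
#   N_SAMPLES = 30 * 16000 = 480000
#   N_FRAMES = 480000 // 160 = 3000
#   N_SAMPLES_PER_TOKEN = 160 * 2 = 320
#   FRAMES_PER_SECOND = 16000 // 160 = 100 (one frame per 10 ms)
#   TOKENS_PER_SECOND = 16000 // 320 = 50 (one token per 20 ms)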


def load_audio(file: str, sr: int = SAMPLE_RATE):
"""
Open an audio file and read as mono waveform, resampling as necessary
Parameters
----------
file: str
The audio file to open
sr: int
The sample rate to resample the audio if necessary
Returns
-------
A NumPy array containing the audio waveform, in float32 dtype.
"""

# This launches a subprocess to decode audio while down-mixing
# and resampling as necessary. Requires the ffmpeg CLI in PATH.
# fmt: off
cmd = [
"ffmpeg",
"-nostdin",
"-threads", "0",
"-i", file,
"-f", "s16le",
"-ac", "1",
"-acodec", "pcm_s16le",
"-ar", str(sr),
"-"
]

# An untested alternative (suggested by ChatGPT) that may handle mp3 input:
# cmd = [
# "ffmpeg",
# "-nostdin",
# "-threads", "0",
# "-i", file,
# "-ac", "1",
# "-ar", str(sr),
# "-"
# ]

# fmt: on
try:
out = run(cmd, capture_output=True, check=True).stdout
except CalledProcessError as e:
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e

return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0


def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
"""
Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
"""
if torch.is_tensor(array):
if array.shape[axis] > length:
array = array.index_select(
dim=axis, index=torch.arange(length, device=array.device)
)

if array.shape[axis] < length:
pad_widths = [(0, 0)] * array.ndim
pad_widths[axis] = (0, length - array.shape[axis])
array = F.pad(
array, [pad for sizes in pad_widths[::-1] for pad in sizes]
)
else:
if array.shape[axis] > length:
array = array.take(indices=range(length), axis=axis)

if array.shape[axis] < length:
pad_widths = [(0, 0)] * array.ndim
pad_widths[axis] = (0, length - array.shape[axis])
array = np.pad(array, pad_widths)

return array


@lru_cache(maxsize=None)
def mel_filters(device, n_mels: int) -> torch.Tensor:
"""
Load the mel filterbank matrix for projecting the STFT onto a Mel basis.
This decouples the librosa dependency; the filters were saved using:
np.savez_compressed(
"mel_filters.npz",
mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
)
"""
assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"

filters_path = os.path.join(
os.path.dirname(__file__), "assets", "mel_filters.npz"
)
with np.load(filters_path, allow_pickle=False) as f:
return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)


def log_mel_spectrogram(
audio: Union[str, np.ndarray, torch.Tensor],
n_mels: int = 80,
padding: int = 0,
device: Optional[Union[str, torch.device]] = None,
):
"""
Compute the log-Mel spectrogram of the input audio.
Parameters
----------
audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
The path to an audio file, or a NumPy array or Tensor containing the audio waveform at 16 kHz
n_mels: int
The number of Mel-frequency filters; only 80 and 128 are supported
padding: int
Number of zero samples to pad to the right
device: Optional[Union[str, torch.device]]
If given, the audio tensor is moved to this device before STFT
Returns
-------
torch.Tensor, shape = (n_mels, n_frames)
A Tensor that contains the Mel spectrogram
"""
if not torch.is_tensor(audio):
if isinstance(audio, str):
audio = load_audio(audio)
audio = torch.from_numpy(audio)

if device is not None:
audio = audio.to(device)
if padding > 0:
audio = F.pad(audio, (0, padding))
window = torch.hann_window(N_FFT).to(audio.device)
stft = torch.stft(
audio, N_FFT, HOP_LENGTH, window=window, return_complex=True
)
magnitudes = stft[..., :-1].abs() ** 2

filters = mel_filters(audio.device, n_mels)
mel_spec = filters @ magnitudes

log_spec = torch.clamp(mel_spec, min=1e-10).log10()
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
log_spec = (log_spec + 4.0) / 4.0

return log_spec
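
Taken together, these helpers form the audio front end: decode with ffmpeg, pad or trim to one 30-second chunk, then compute the log-Mel features the encoder consumes. A minimal usage sketch, not part of the commit ("example.wav" is a placeholder path):

from amt.audio import load_audio, log_mel_spectrogram, pad_or_trim, N_SAMPLES

audio = load_audio("example.wav")      # float32 mono waveform at 16 kHz
audio = pad_or_trim(audio, N_SAMPLES)  # exactly 30 s of samples
mel = log_mel_spectrogram(audio)       # torch.Tensor of shape (80, 3000)
print(mel.shape)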
140 changes: 140 additions & 0 deletions amt/data.py
@@ -0,0 +1,140 @@
import mmap
import os
import json
import jsonlines
import torch

from typing import Callable
from multiprocessing import Pool

from aria.data.midi import MidiDict
from amt.tokenizer import AmtTokenizer
from amt.config import load_config
from amt.audio import (
log_mel_spectrogram,
pad_or_trim,
N_FRAMES,
)

config = load_config()["data"]
STRIDE_FACTOR = config["stride_factor"]


def get_features(audio_path: str, mid_path: str):
"""This function yields tuples of matched log mel spectrograms and
tokenized sequences (np.array, list).
"""
tokenizer = AmtTokenizer()

if not os.path.isfile(audio_path) or not os.path.isfile(mid_path):
return None

try:
midi_dict = MidiDict.from_midi(mid_path)
log_spec = log_mel_spectrogram(audio=audio_path)
except Exception as e:
print(f"Failed to convert {audio_path} into features: {e}")
return None

_, total_frames = log_spec.shape
res = []
for start_frame in range(0, total_frames, N_FRAMES // STRIDE_FACTOR):
audio_feature = pad_or_trim(log_spec[:, start_frame:], length=N_FRAMES)
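# Each spectrogram frame spans hop_len / sample_rate = 160 / 16000 = 10 ms,
# so frame indices are converted to milliseconds by multiplying by 10 below.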
mid_feature = tokenizer._tokenize_midi_dict(
midi_dict=midi_dict,
start_ms=start_frame * 10,
end_ms=(start_frame + N_FRAMES) * 10,
)
res.append((audio_feature, mid_feature))

return res


def get_features_mp(args):
"""Multiprocessing wrapper for get_features"""
res = get_features(*args)
if res is None:
return False, None
else:
return True, res


class AmtDataset(torch.utils.data.Dataset):
def __init__(self, load_path: str):
self.tokenizer = AmtTokenizer(return_tensors=True)
self.aug_fn = self.tokenizer.export_msg_mixup()
self.file_buff = open(load_path, mode="r")
self.file_mmap = mmap.mmap(
self.file_buff.fileno(), 0, access=mmap.ACCESS_READ
)
self.index = self._build_index()

def close(self):
if self.file_buff:
self.file_buff.close()
if self.file_mmap:
self.file_mmap.close()

def __del__(self):
self.close()

def __len__(self):
return len(self.index)

def __getitem__(self, idx: int):
def _format(tok):
# This is required because json formats tuples into lists
if isinstance(tok, list):
return tuple(tok)
return tok

self.file_mmap.seek(self.index[idx])

# This isn't going to load properly
spec, seq = json.loads(self.file_mmap.readline()) # Load data from line

spec = torch.tensor(spec) # Format spectrogram into tensor
seq = [_format(tok) for tok in seq] # Format seq
seq = self.aug_fn(seq) # Data augmentation

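# Standard next-token-prediction shift: tgt is src offset by one position
# and padded at the end, so the model at position i predicts src[i + 1].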
src = seq
tgt = seq[1:] + [self.tokenizer.pad_tok]

return spec, self.tokenizer.encode(src), self.tokenizer.encode(tgt)

def _build_index(self):
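# Record the byte offset at which each line starts so that __getitem__
# can seek straight to an example instead of scanning the file.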
self.file_mmap.seek(0)
index = []
while True:
pos = self.file_mmap.tell()
line_buffer = self.file_mmap.readline()
if line_buffer == b"":
break
else:
index.append(pos)

return index

@classmethod
def build(
cls,
matched_load_paths: list[tuple[str, str]],
save_path: str,
audio_aug_hook: Callable | None = None,
):
def _get_features(_matched_load_paths: list):
with Pool(4) as pool:
results = pool.imap(get_features_mp, _matched_load_paths)
num_paths = len(_matched_load_paths)
for idx, (success, res) in enumerate(results):
if idx % 50 == 0 and idx != 0:
print(f"Processed audio-mid pairs: {idx}/{num_paths}")

if not success:
continue
for _audio_feature, _mid_feature in res:
yield _audio_feature.tolist(), _mid_feature

with jsonlines.open(save_path, mode="w") as writer:
for audio_feature, mid_feature in _get_features(matched_load_paths):
writer.write([audio_feature, mid_feature])
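
A sketch of the intended workflow, not part of the commit (the file paths are placeholders):

from amt.data import AmtDataset

pairs = [("track.wav", "track.mid")]  # matched audio/MIDI paths
AmtDataset.build(matched_load_paths=pairs, save_path="train.jsonl")

dataset = AmtDataset("train.jsonl")
spec, src, tgt = dataset[0]  # spectrogram, input ids, shifted target ids

Since build() spins up a multiprocessing.Pool, it should be called under an if __name__ == "__main__": guard on platforms that spawn worker processes.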
18 changes: 13 additions & 5 deletions amt/tokenizer.py
@@ -115,17 +115,24 @@ def _tokenize_midi_dict(
if note_end_ms <= start_ms or note_start_ms >= end_ms: # Skip
continue
elif (
- note_start_ms <= start_ms and _pitch not in prev_notes
+ note_start_ms < start_ms and _pitch not in prev_notes
): # Add to prev notes
prev_notes.append(_pitch)
if note_end_ms < end_ms:
on_off_notes.append(
("off", _pitch, rel_note_end_ms_q, None)
)
else: # Add to on_off_msgs
- # Skip notes with no duration
+ # Skip notes with no duration or duplicate notes
if rel_note_start_ms_q == rel_note_end_ms_q:
continue
elif (
"on",
_pitch,
rel_note_start_ms_q,
velocity_q,
) in on_off_notes:
continue

on_off_notes.append(
("on", _pitch, rel_note_start_ms_q, velocity_q)
@@ -190,7 +197,7 @@ def _detokenize_midi_dict(self, tokenized_seq: list, len_ms: int):
for tok_1, tok_2, tok_3 in zip(
tokenized_seq[:],
tokenized_seq[1:],
- tokenized_seq[2:],
+ tokenized_seq[2:] + [(None, None)],
):
tok_1_type, tok_1_data = tok_1
tok_2_type, tok_2_data = tok_2
@@ -210,7 +217,7 @@ def _detokenize_midi_dict(self, tokenized_seq: list, len_ms: int):
# Process note and add to note msgs
note_to_close = notes_to_close.pop(tok_1_data, None)
if note_to_close is None:
print("No 'on' token corresponding to 'off' token")
print(f"No 'on' token corresponding to 'off' token")
continue
else:
_pitch = tok_1_data
@@ -267,6 +274,7 @@ def export_msg_mixup(self):
def msg_mixup(src: list):
# Reorder prev tokens
res = []
+ idx = 0
for idx, tok in enumerate(src):
tok_type, tok_data = tok
if tok_type != "prev":
@@ -279,7 +287,7 @@ def msg_mixup(src: list):
for tok_1, tok_2, tok_3 in zip(
src[idx:],
src[idx + 1 :],
- src[idx + 2 :],
+ src[idx + 2 :] + [(None, None)],
):
tok_1_type, tok_1_data = tok_1
tok_2_type, tok_2_data = tok_2
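
Both "+ [(None, None)]" changes fix the same edge case: zip() stops at its shortest argument, so the three-wide sliding window previously ended early and the last note tokens never passed through the tok_1 slot. A standalone illustration, not from the repo:

seq = [("on", 60), ("off", 60), ("on", 62), ("off", 62)]

# Unpadded: iteration stops once seq[2:] is exhausted (2 windows).
print(list(zip(seq, seq[1:], seq[2:])))

# Padded with a (None, None) sentinel: one extra window, so the
# second-to-last token is also processed in the tok_1 position.
print(list(zip(seq, seq[1:], seq[2:] + [(None, None)])))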
11 changes: 10 additions & 1 deletion config/config.json
@@ -8,5 +8,14 @@
"num_steps": 3000,
"step": 10
}
},
"audio": {
"sample_rate": 16000,
"n_fft": 400,
"hop_len": 160,
"chunk_len": 30
},
"data": {
"stride_factor": 1
}
}
}
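
The new stride_factor setting controls how training windows are cut in amt/data.py: get_features advances by N_FRAMES // STRIDE_FACTOR frames per chunk, so 1 gives non-overlapping 30-second windows and 2 would overlap consecutive windows by half. A small illustration under those assumptions:

N_FRAMES = 3000  # frames per 30-second chunk

for stride_factor in (1, 2):
    hop = N_FRAMES // stride_factor
    print(stride_factor, list(range(0, 9000, hop)))
# 1 -> [0, 3000, 6000]                      non-overlapping
# 2 -> [0, 1500, 3000, 4500, 6000, 7500]    50% overlap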