Commit: add data aug, inference broken
loubbrad committed Feb 28, 2024
1 parent 4d12799 commit 274115d
Showing 6 changed files with 362 additions and 43 deletions.
200 changes: 199 additions & 1 deletion amt/audio.py
@@ -1,15 +1,19 @@
"""Contains code taken from https://github.com/openai/whisper"""

import os
import random
import torch
import numpy as np
import torchaudio
import torch.nn.functional as F
import torchaudio.functional as AF

from functools import lru_cache
from subprocess import CalledProcessError, run
from typing import Optional, Union

from amt.config import load_config
from amt.tokenizer import AmtTokenizer

# hard-coded audio hyperparameters
config = load_config()["audio"]
@@ -176,3 +180,197 @@ def log_mel_spectrogram(
log_spec = (log_spec + 4.0) / 4.0

return log_spec


class AudioTransform(torch.nn.Module):
def __init__(
self,
reverb_factor: int = 1,
min_snr: int = 10,
max_snr: int = 40,
max_pitch_shift: int = 4,
):
super().__init__()
self.tokenizer = AmtTokenizer()
self.reverb_factor = reverb_factor
self.min_snr = min_snr
self.max_snr = max_snr
self.max_pitch_shift = max_pitch_shift

self.config = load_config()["audio"]
self.sample_rate = self.config["sample_rate"]
self.chunk_len = self.config["chunk_len"]
self.num_samples = self.sample_rate * self.chunk_len

# Audio aug
impulse_paths = self._get_paths(
os.path.join(os.path.dirname(__file__), "assets", "impulse")
)
noise_paths = self._get_paths(
os.path.join(os.path.dirname(__file__), "assets", "noise")
)

# Register impulses and noises as buffers
self.num_impulse = 0
for i, impulse in enumerate(self._get_impulses(impulse_paths)):
self.register_buffer(f"impulse_{i}", impulse)
self.num_impulse += 1

self.num_noise = 0
for i, noise in enumerate(self._get_noise(noise_paths)):
self.register_buffer(f"noise_{i}", noise)
self.num_noise += 1

# Mel-spec
self.melspec_transform = torchaudio.transforms.MelSpectrogram(
sample_rate=self.sample_rate,
n_fft=self.config["n_fft"],
hop_length=self.config["hop_len"],
n_mels=self.config["n_mels"],
)

# Spec aug
self.spec_aug = torch.nn.Sequential(
torchaudio.transforms.FrequencyMasking(
freq_mask_param=10, iid_masks=True
),
torchaudio.transforms.TimeMasking(
time_mask_param=100, iid_masks=True
),
)

def _get_paths(self, dir_path):
return [
os.path.join(dir_path, f)
for f in os.listdir(dir_path)
if os.path.isfile(os.path.join(dir_path, f))
]

def _get_impulses(self, impulse_paths: list):
impulses = [torchaudio.load(path) for path in impulse_paths]
impulses = [
AF.resample(
waveform=wav, orig_freq=sr, new_freq=config["sample_rate"]
).mean(0, keepdim=True)[:, : 5 * self.sample_rate]
for wav, sr in impulses
]
return [
wav / torch.linalg.vector_norm(wav, ord=2) for wav in impulses
]

def _get_noise(self, noise_paths: list):
noises = [torchaudio.load(path) for path in noise_paths]
noises = [
AF.resample(
waveform=wav, orig_freq=sr, new_freq=config["sample_rate"]
).mean(0, keepdim=True)[:, : self.num_samples]
for wav, sr in noises
]

for wav in noises:
assert wav.shape[-1] == self.num_samples, "noise wav too short"

return noises

def apply_reverb(self, wav: torch.Tensor):
# wav: (bz, L)
batch_size, _ = wav.shape

reverb_strength = (
torch.Tensor([random.uniform(0, 1) for _ in range(batch_size)])
.unsqueeze(-1)
.to(wav.device)
)
reverb_type = random.randint(0, self.num_impulse - 1)
impulse = getattr(self, f"impulse_{reverb_type}")

reverb = AF.fftconvolve(wav, impulse, mode="full")[
:, : self.num_samples
]
if self.reverb_factor > 1:
for _ in range(self.reverb_factor - 1):
reverb = AF.fftconvolve(reverb, impulse, mode="full")[
:, : self.num_samples
]

res = (reverb_strength * reverb) + ((1 - reverb_strength) * wav)

return res

def apply_noise(self, wav: torch.Tensor):
batch_size, _ = wav.shape

snr_dbs = torch.tensor(
[
random.randint(self.min_snr, self.max_snr)
for _ in range(batch_size)
]
).to(wav.device)
noise_type = random.randint(0, self.num_noise - 1)
noise = getattr(self, f"noise_{noise_type}")

return AF.add_noise(waveform=wav, noise=noise, snr=snr_dbs)

def aug_pitch(self, wav: torch.Tensor, *seqs: torch.Tensor):
shift = random.randint(-self.max_pitch_shift, self.max_pitch_shift)

if seqs:
for seq in seqs:
assert seq.shape[0] == wav.shape[0]

if shift == 0:
return wav, [seq for seq in seqs]
else:
wav_aug = AF.pitch_shift(
waveform=wav,
sample_rate=self.sample_rate,
n_steps=shift,
n_fft=512,
)
return wav_aug, [
self.tokenizer.pitch_aug(seq, shift) for seq in seqs
]
else:
if shift == 0:
return wav
else:
return AF.pitch_shift(
waveform=wav,
sample_rate=self.sample_rate,
n_steps=shift,
n_fft=512,
)

def aug_wav(self, wav: torch.Tensor):
return self.apply_reverb(self.apply_noise(wav))

def mel(self, wav: torch.Tensor):
mel_spec = self.melspec_transform(wav)[..., :-1]

log_spec = torch.clamp(mel_spec, min=1e-10).log10()

max_over_mels = log_spec.max(dim=1, keepdim=True)[0]
max_log_spec = max_over_mels.max(dim=2, keepdim=True)[0]
log_spec = torch.maximum(log_spec, max_log_spec - 8.0)

log_spec = (log_spec + 4.0) / 4.0

return log_spec

def forward(self, wav: torch.Tensor, *seqs: torch.Tensor):
if seqs:
wav, seqs = self.aug_pitch(wav, *seqs)
else:
wav = self.aug_pitch(wav)

if random.random() < 0.2:
if seqs:
return self.mel(self.aug_wav(wav)), seqs
else:
return self.mel(self.aug_wav(wav))
else:
if seqs:
return self.spec_aug(self.mel(self.aug_wav(wav))), seqs
else:
return self.spec_aug(self.mel(self.aug_wav(wav)))
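
A minimal usage sketch of the transform above (not shown in the commit), assuming the assets/impulse and assets/noise directories are populated on disk; the batch shapes and the zero-filled token batch are illustrative placeholders.

import torch
from amt.audio import AudioTransform

transform = AudioTransform()  # defaults: min_snr=10, max_snr=40, max_pitch_shift=4

# A batch of mono audio chunks: (batch, sample_rate * chunk_len)
wav = torch.randn(4, transform.num_samples)

# Audio-only path: reverb/noise (and, most of the time, spec aug) -> log-mel
mel = transform(wav)

# Paired path: token sequences are pitch-shifted in lockstep with the audio;
# forward returns the (single-element) list produced by aug_pitch
seqs = torch.zeros(4, 512, dtype=torch.long)  # placeholder tokenized batch
mel, (seqs,) = transform(wav, seqs)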
66 changes: 34 additions & 32 deletions amt/data.py
@@ -3,59 +3,61 @@
import shutil
import orjson
import torch
import torchaudio

from multiprocessing import Pool

from aria.data.midi import MidiDict
from amt.tokenizer import AmtTokenizer
from amt.config import load_config
from amt.audio import (
log_mel_spectrogram,
pad_or_trim,
N_FRAMES,
)
from amt.audio import pad_or_trim

config = load_config()
STRIDE_FACTOR = config["data"]["stride_factor"]


def get_features(
def get_wav_mid_segments(
audio_path: str, mid_path: str = "", return_json: bool = False
):
"""This function yields tuples of matched log mel spectrograms and
tokenized sequences (np.array, list). If it is given only an audio path
then it will return an empty list for the mid_feature
"""
tokenizer = AmtTokenizer()
n_mels = config["audio"]["n_mels"]
config = load_config()
stride_factor = config["data"]["stride_factor"]
sample_rate = config["audio"]["sample_rate"]
chunk_len = config["audio"]["chunk_len"]
num_samples = sample_rate * chunk_len
samples_per_ms = sample_rate // 1000

if not os.path.isfile(audio_path):
return None

# Load midi if required
if mid_path == "":
pass
midi_dict = None
elif not os.path.isfile(mid_path):
return None

try:
log_spec = log_mel_spectrogram(audio=audio_path, n_mels=n_mels)
if mid_path != "":
midi_dict = MidiDict.from_midi(mid_path)
else:
midi_dict = None
except Exception as e:
print("Failed to convert files into features")
raise e

_, total_frames = log_spec.shape
else:
midi_dict = MidiDict.from_midi(mid_path)

# Load audio
wav, sr = torchaudio.load(audio_path)
if sr != sample_rate:
wav = torchaudio.functional.resample(
waveform=wav,
orig_freq=sr,
new_freq=sample_rate,
).mean(0)

# Create features
total_samples = wav.shape[-1]
res = []
for start_frame in range(0, total_frames, N_FRAMES // STRIDE_FACTOR):
audio_feature = pad_or_trim(log_spec[:, start_frame:], length=N_FRAMES)
for idx in range(0, total_samples, num_samples // stride_factor):
audio_feature = pad_or_trim(wav[idx:], length=num_samples)
if midi_dict is not None:
mid_feature = tokenizer._tokenize_midi_dict(
midi_dict=midi_dict,
start_ms=start_frame * 10,
end_ms=(start_frame + N_FRAMES) * 10,
start_ms=idx // samples_per_ms,
end_ms=(idx + num_samples) // samples_per_ms,
)
else:
mid_feature = []
@@ -70,7 +72,7 @@ def get_features(

def write_features(args):
audio_path, mid_path, save_path = args
features = get_features(
features = get_wav_mid_segments(
audio_path=audio_path,
mid_path=mid_path,
return_json=False,
@@ -79,10 +81,10 @@ def write_features(args):
proc_save_path = os.path.join(dirname, str(os.getpid()) + basename)

with open(proc_save_path, mode="ab") as file:
for mel, seq in features:
for wav, seq in features:
file.write(
orjson.dumps(
mel.numpy(),
wav.numpy(),
option=orjson.OPT_SERIALIZE_NUMPY,
)
)
@@ -126,7 +128,7 @@ def _format(tok):
self.file_mmap.seek(self.index[idx])

# Load data from line
mel = torch.tensor(orjson.loads(self.file_mmap.readline()))
wav = torch.tensor(orjson.loads(self.file_mmap.readline()))
_seq = orjson.loads(self.file_mmap.readline())

_seq = [_format(tok) for tok in _seq] # Format seq
@@ -141,7 +143,7 @@ def _format(tok):
seq_len=self.config["max_seq_len"],
)

return mel, self.tokenizer.encode(src), self.tokenizer.encode(tgt)
return wav, self.tokenizer.encode(src), self.tokenizer.encode(tgt)

def _build_index(self):
self.file_mmap.seek(0)
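With the dataset now storing raw audio segments instead of precomputed mel spectrograms, the spectrogram is presumably computed on the fly during training. A hedged sketch of that wiring (not shown in the commit); the dataset class name AmtDataset, the file path, and the batch size are assumptions, not names confirmed by this diff.

import torch
from torch.utils.data import DataLoader
from amt.audio import AudioTransform
from amt.data import AmtDataset  # hypothetical name for the dataset class above

dataset = AmtDataset("train.jsonl")  # placeholder path to the written features
loader = DataLoader(dataset, batch_size=4, shuffle=True)
audio_transform = AudioTransform()

for wav, src, tgt in loader:
    # Augment audio and both token views together so pitch shifts stay aligned
    mel, (src, tgt) = audio_transform(wav, src, tgt)
    # ... model forward/backward on (mel, src, tgt) ...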
17 changes: 17 additions & 0 deletions amt/tokenizer.py
@@ -395,3 +395,20 @@ def round_to_base(n, base=150):
return res

return msg_mixup

def pitch_aug(self, seqs, shift: int):
"""This functions acts on tensors and is used in audio.AudioFeature"""
batch_size, seq_len = seqs.shape

for i in range(batch_size):
for j in range(seq_len):
tok = self.id_to_tok[seqs[i, j].item()]
if type(tok) is tuple and tok[0] in {"on", "off"}:
msg_type, pitch = tok
seqs[i, j] = self.tok_to_id.get(
(msg_type, pitch + shift), self.unk_tok
)
elif tok == self.pad_tok:
break

return seqs
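
A small sketch of pitch_aug on a toy batch (not shown in the commit), assuming ("on", pitch) and ("off", pitch) tuples are registered in tok_to_id as the loop above implies.

import torch
from amt.tokenizer import AmtTokenizer

tokenizer = AmtTokenizer()
ids = torch.tensor(
    [[tokenizer.tok_to_id[("on", 60)], tokenizer.tok_to_id[("off", 60)]]]
)
shifted = tokenizer.pitch_aug(ids, shift=2)
# shifted (updated in place) now holds the ids of ("on", 62) and ("off", 62);
# shifts that leave the vocabulary fall back to the unk token.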