From 01845ee126bec3a95d48821d849a64e551e6a08c Mon Sep 17 00:00:00 2001
From: loubbrad
Date: Sun, 25 Feb 2024 18:47:44 +0000
Subject: [PATCH] fix model

---
 amt/data.py             | 29 +++++++++++++++---
 amt/inference.py        |  9 ++++--
 amt/model.py            | 68 ++++++++++++++---------------------------
 amt/train.py            |  5 +--
 config/config.json      |  2 +-
 tests/test_tokenizer.py |  6 ++--
 6 files changed, 62 insertions(+), 57 deletions(-)

diff --git a/amt/data.py b/amt/data.py
index ee87c33..ec90141 100644
--- a/amt/data.py
+++ b/amt/data.py
@@ -1,5 +1,6 @@
 import mmap
 import os
+import logging
 import json
 import jsonlines
 import torch
@@ -20,7 +21,27 @@
 STRIDE_FACTOR = config["stride_factor"]
 
 
-def get_features(audio_path: str, mid_path: str | None = None):
+def setup_logger():
+    # Get logger and reset all handlers
+    logger = logging.getLogger(__name__)
+    for h in logger.handlers[:]:
+        logger.removeHandler(h)
+
+    logger.propagate = False
+    logger.setLevel(logging.INFO)
+    formatter = logging.Formatter(
+        "[%(asctime)s] %(name)s: [%(levelname)s] %(message)s",
+    )
+
+    ch = logging.StreamHandler()
+    ch.setLevel(logging.INFO)
+    ch.setFormatter(formatter)
+    logger.addHandler(ch)
+
+    return logger
+
+
+def get_features(audio_path: str, mid_path: str = ""):
     """This function yields tuples of matched log mel spectrograms and
     tokenized sequences (np.array, list). If it is given only an audio path
     then it will return an empty list for the mid_feature
@@ -30,14 +51,14 @@ def get_features(audio_path: str, mid_path: str | None = None):
     if not os.path.isfile(audio_path):
         return None
 
-    if mid_path is not None:
+    if mid_path == "":
         pass
     elif not os.path.isfile(mid_path):
         return None
 
     try:
         log_spec = log_mel_spectrogram(audio=audio_path)
-        if mid_path is not None:
+        if mid_path != "":
             midi_dict = MidiDict.from_midi(mid_path)
         else:
             midi_dict = None
@@ -49,7 +70,7 @@
     res = []
     for start_frame in range(0, total_frames, N_FRAMES // STRIDE_FACTOR):
         audio_feature = pad_or_trim(log_spec[:, start_frame:], length=N_FRAMES)
-        if midi_dict:
+        if midi_dict is not None:
             mid_feature = tokenizer._tokenize_midi_dict(
                 midi_dict=midi_dict,
                 start_ms=start_frame * 10,
diff --git a/amt/inference.py b/amt/inference.py
index 1b807aa..83fd61d 100644
--- a/amt/inference.py
+++ b/amt/inference.py
@@ -13,6 +13,10 @@
 
 
 # TODO: Implement this with KV-caching, see the whisper inference file
+# Due to the autoregressive nature, a good inference algorithm should use some
+# sort of branching to make sure that we don't miss notes, etc. Implement this
+# next week -- exciting problem (check out other inference algos)
+
 
 def greedy_sample(
     model: AmtEncoderDecoder,
@@ -43,8 +47,9 @@ def _process_segment(
         )
     ):
         logits = model.forward(mel=audio_seg, tokens=seq[:, :idx])
-        probs = torch.softmax(logits[0, -1], dim=-1)
-        next_tok_id = torch.multinomial(probs / 0.001, num_samples=1)
+        next_tok_id = torch.argmax(logits[0, -1], dim=-1)
+        # probs = torch.softmax(logits[0, -1], dim=-1)
+        # next_tok_id = torch.argmax(probs, dim=-1)
 
         # Debug logging:
         # print(f"input seq shape: {seq[:, :idx].shape}")
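Note on the inference TODO above: the "branching" idea could be prototyped as a small beam search over the same model.forward(mel=..., tokens=...) call that greedy_sample uses. A minimal sketch, assuming batch size 1 and a hypothetical eos_id stop token; this illustrates the idea and is not the repo's implementation:

    import torch

    def beam_sample(model, audio_seg, prefix, max_len, beam_width=4, eos_id=None):
        # Each beam is a (tokens of shape (1, t), cumulative log-prob) pair.
        beams = [(prefix, 0.0)]
        for _ in range(max_len - prefix.shape[1]):
            candidates = []
            for seq, score in beams:
                logits = model.forward(mel=audio_seg, tokens=seq)
                log_probs = torch.log_softmax(logits[0, -1], dim=-1)
                top = torch.topk(log_probs, beam_width)
                for tok_id, tok_lp in zip(top.indices, top.values):
                    new_seq = torch.cat([seq, tok_id.view(1, 1)], dim=1)
                    candidates.append((new_seq, score + tok_lp.item()))
            # Keep only the beam_width highest-scoring continuations.
            candidates.sort(key=lambda c: c[1], reverse=True)
            beams = candidates[:beam_width]
            if eos_id is not None and all(
                b[0][0, -1].item() == eos_id for b in beams
            ):
                break
        return beams[0][0]

Keeping several hypotheses alive is one way to avoid the missed-note failure mode of pure argmax decoding; KV-caching (as in the whisper inference file) would make the repeated forward passes affordable.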
diff --git a/amt/model.py b/amt/model.py
index 45c6eb8..5e19d7f 100644
--- a/amt/model.py
+++ b/amt/model.py
@@ -1,5 +1,6 @@
 """Contains code modified from https://github.com/openai/whisper"""
 
+import math
 import numpy as np
 import torch
 import torch.nn.functional as F
@@ -8,9 +9,6 @@
 from dataclasses import dataclass
 from typing import Dict, Iterable, Optional
 
-# TODO:
-# Go through and make this more efficient using flash attention ect...
-
 
 @dataclass
 class ModelConfig:
@@ -29,40 +27,20 @@ def set_vocab_size(self, vocab_size: int):
         self.n_vocab = vocab_size
 
 
-class LayerNorm(nn.LayerNorm):
-    def forward(self, x: Tensor) -> Tensor:
-        return super().forward(x.float()).type(x.dtype)
-
-
-class Linear(nn.Linear):
-    def forward(self, x: Tensor) -> Tensor:
-        return F.linear(
-            x,
-            self.weight.to(x.dtype),
-            None if self.bias is None else self.bias.to(x.dtype),
-        )
-
-
-class Conv1d(nn.Conv1d):
-    def _conv_forward(
-        self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
-    ) -> Tensor:
-        return super()._conv_forward(
-            x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
-        )
-
-
-def sinusoids(length, channels, max_timescale=10000):
+def sinusoids(
+    length: int, channels: int, max_timescale: float = 10000
+) -> torch.Tensor:
     """Returns sinusoids for positional embedding"""
-    assert channels % 2 == 0
-    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
+    if channels % 2 != 0:
+        raise ValueError(
+            f"Number of channels must be divisible by 2 for sinusoidal positional embeddings, got {channels} channels."
+        )
+    log_timescale_increment = math.log(max_timescale) / (channels // 2 - 1)
     inv_timescales = torch.exp(
         -log_timescale_increment * torch.arange(channels // 2)
     )
-    scaled_time = (
-        torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
-    )
-    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
+    scaled_time = torch.arange(length).view(-1, 1) * inv_timescales.view(1, -1)
+    return torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1)
 
 
 class MultiHeadAttention(nn.Module):
@@ -72,10 +50,10 @@ def __init__(self, n_state: int, n_head: int):
         self.n_head = n_head
         self.d_head = n_state // n_head
 
-        self.query = Linear(n_state, n_state)
-        self.key = Linear(n_state, n_state, bias=False)
-        self.value = Linear(n_state, n_state)
-        self.out = Linear(n_state, n_state)
+        self.query = nn.Linear(n_state, n_state)
+        self.key = nn.Linear(n_state, n_state, bias=False)
+        self.value = nn.Linear(n_state, n_state)
+        self.out = nn.Linear(n_state, n_state)
 
     def forward(
         self,
@@ -170,18 +148,18 @@ def __init__(
         super().__init__()
 
         self.attn = MultiHeadAttention(n_state, n_head)
-        self.attn_ln = LayerNorm(n_state)
+        self.attn_ln = nn.LayerNorm(n_state)
 
         self.cross_attn = (
             MultiHeadAttention(n_state, n_head) if cross_attention else None
         )
-        self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
+        self.cross_attn_ln = nn.LayerNorm(n_state) if cross_attention else None
 
         n_mlp = n_state * 4
         self.mlp = nn.Sequential(
-            Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
+            nn.Linear(n_state, n_mlp), nn.GELU(), nn.Linear(n_mlp, n_state)
         )
-        self.mlp_ln = LayerNorm(n_state)
+        self.mlp_ln = nn.LayerNorm(n_state)
 
     def forward(
         self,
@@ -207,8 +185,8 @@ def __init__(
         self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
     ):
         super().__init__()
-        self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
-        self.conv2 = Conv1d(
+        self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1)
+        self.conv2 = nn.Conv1d(
             n_state, n_state, kernel_size=3, stride=2, padding=1
         )
         self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
@@ -216,7 +194,7 @@ def __init__(
         self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
             [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
         )
-        self.ln_post = LayerNorm(n_state)
+        self.ln_post = nn.LayerNorm(n_state)
 
     def forward(self, x: Tensor):
         """
@@ -253,7 +231,7 @@ def __init__(
                 for _ in range(n_layer)
             ]
         )
-        self.ln = LayerNorm(n_state)
+        self.ln = nn.LayerNorm(n_state)
 
         mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
         self.register_buffer("mask", mask, persistent=False)
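A quick sanity check on the sinusoids rewrite: it preserves Whisper's layout, where the first channels // 2 columns are sines and the rest are cosines. With illustrative sizes (the length and channels values here are made up, not the repo's config):

    import torch

    pe = sinusoids(length=1500, channels=512)
    assert pe.shape == (1500, 512)
    # Position 0: sin(0) = 0 in the first half, cos(0) = 1 in the second.
    assert torch.allclose(pe[0, :256], torch.zeros(256))
    assert torch.allclose(pe[0, 256:], torch.ones(256))

Using math.log on a Python float avoids pulling numpy into this path, and .view(-1, 1) replaces the np.newaxis indexing with a pure-torch equivalent.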
+ ) + log_timescale_increment = math.log(max_timescale) / (channels // 2 - 1) inv_timescales = torch.exp( -log_timescale_increment * torch.arange(channels // 2) ) - scaled_time = ( - torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :] - ) - return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1) + scaled_time = torch.arange(length).view(-1, 1) * inv_timescales.view(1, -1) + return torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1) class MultiHeadAttention(nn.Module): @@ -72,10 +50,10 @@ def __init__(self, n_state: int, n_head: int): self.n_head = n_head self.d_head = n_state // n_head - self.query = Linear(n_state, n_state) - self.key = Linear(n_state, n_state, bias=False) - self.value = Linear(n_state, n_state) - self.out = Linear(n_state, n_state) + self.query = nn.Linear(n_state, n_state) + self.key = nn.Linear(n_state, n_state, bias=False) + self.value = nn.Linear(n_state, n_state) + self.out = nn.Linear(n_state, n_state) def forward( self, @@ -170,18 +148,18 @@ def __init__( super().__init__() self.attn = MultiHeadAttention(n_state, n_head) - self.attn_ln = LayerNorm(n_state) + self.attn_ln = nn.LayerNorm(n_state) self.cross_attn = ( MultiHeadAttention(n_state, n_head) if cross_attention else None ) - self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None + self.cross_attn_ln = nn.LayerNorm(n_state) if cross_attention else None n_mlp = n_state * 4 self.mlp = nn.Sequential( - Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state) + nn.Linear(n_state, n_mlp), nn.GELU(), nn.Linear(n_mlp, n_state) ) - self.mlp_ln = LayerNorm(n_state) + self.mlp_ln = nn.LayerNorm(n_state) def forward( self, @@ -207,8 +185,8 @@ def __init__( self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int ): super().__init__() - self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1) - self.conv2 = Conv1d( + self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1) + self.conv2 = nn.Conv1d( n_state, n_state, kernel_size=3, stride=2, padding=1 ) self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state)) @@ -216,7 +194,7 @@ def __init__( self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList( [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)] ) - self.ln_post = LayerNorm(n_state) + self.ln_post = nn.LayerNorm(n_state) def forward(self, x: Tensor): """ @@ -253,7 +231,7 @@ def __init__( for _ in range(n_layer) ] ) - self.ln = LayerNorm(n_state) + self.ln = nn.LayerNorm(n_state) mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1) self.register_buffer("mask", mask, persistent=False) diff --git a/amt/train.py b/amt/train.py index ba91ab4..a6a846e 100644 --- a/amt/train.py +++ b/amt/train.py @@ -172,7 +172,7 @@ def get_pretrain_optim( ): LR = 3e-4 END_RATIO = 0.1 - WARMUP_STEPS = 200 + WARMUP_STEPS = 500 return _get_optim( lr=LR, @@ -210,6 +210,7 @@ def get_dataloaders( num_workers: int, ): logger = get_logger(__name__) + logger.info("Indexing datasets...") train_dataset = AmtDataset(load_path=train_data_path) val_dataset = AmtDataset(load_path=val_data_path) logger.info( @@ -220,7 +221,7 @@ def get_dataloaders( train_dataset, batch_size=batch_size, num_workers=num_workers, - shuffle=True, # Maybe remove + shuffle=True, ) val_dataloader = DataLoader( val_dataset, diff --git a/config/config.json b/config/config.json index f727540..0e77778 100644 --- a/config/config.json +++ b/config/config.json @@ -16,7 +16,7 @@ "chunk_len": 30 }, "data": { - "stride_factor": 1, + "stride_factor": 3, "max_seq_len": 
diff --git a/tests/test_tokenizer.py b/tests/test_tokenizer.py
index d6e5d5f..06b24c0 100644
--- a/tests/test_tokenizer.py
+++ b/tests/test_tokenizer.py
@@ -56,7 +56,7 @@ def aug(_midi_dict: MidiDict, _start_ms: int, _end_ms: int):
 
         DELTA_MS = 5000
         tokenizer = AmtTokenizer()
-        midi_dict = MidiDict.from_midi("tests/test_data/bach.mid")
+        midi_dict = MidiDict.from_midi("tests/test_data/maestro2.mid")
         __end_ms = midi_dict.note_msgs[-1]["data"]["end"]
 
         for idx, __start_ms in enumerate(range(0, __end_ms, DELTA_MS)):
@@ -86,13 +86,13 @@ def aug(_midi_dict: MidiDict, _start_ms: int, _end_ms: int):
             tokenized_seq, DELTA_MS
         )
         _mid = _midi_dict.to_midi()
-        _mid.save(f"tests/test_results/bach_orig.mid")
+        _mid.save("tests/test_results/maestro2_orig.mid")
 
         _midi_dict = tokenizer._detokenize_midi_dict(
             aug_tokenized_seq, DELTA_MS
         )
         _mid = _midi_dict.to_midi()
-        _mid.save(f"tests/test_results/bach_aug.mid")
+        _mid.save("tests/test_results/maestro2_aug.mid")
 
 
 if __name__ == "__main__":
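For anyone regenerating local test artifacts after the bach -> maestro2 rename: the round trip the test exercises can be reproduced standalone. A minimal sketch using the AmtTokenizer calls visible in the hunks above; the import paths and the end_ms keyword are assumptions from context, not verified against the repo:

    from amt.tokenizer import AmtTokenizer   # import path assumed
    from aria.data.midi import MidiDict      # import path assumed

    tokenizer = AmtTokenizer()
    midi_dict = MidiDict.from_midi("tests/test_data/maestro2.mid")

    # Tokenize the first 5 s window, then decode it back to MIDI.
    seq = tokenizer._tokenize_midi_dict(
        midi_dict=midi_dict, start_ms=0, end_ms=5000
    )
    decoded = tokenizer._detokenize_midi_dict(seq, 5000)
    decoded.to_midi().save("tests/test_results/roundtrip.mid")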