From 01ac3e253599c4dbbac7cea9d08da724112216c6 Mon Sep 17 00:00:00 2001 From: Louis Date: Sat, 4 May 2024 20:44:15 +0100 Subject: [PATCH] Add stacked spectrograms (#34) * update README * update README * add fp16 * inference modification * add stacked mels --- amt/audio.py | 57 +++++++++----- amt/data.py | 2 + amt/inference/model.py | 10 ++- amt/inference/transcribe.py | 4 + config/config.json | 4 +- config/models/medium-stacked.json | 11 +++ .../models/{medium-final.json => medium.json} | 4 +- tests/test_data.py | 78 ++++++++++++++----- 8 files changed, 127 insertions(+), 43 deletions(-) create mode 100644 config/models/medium-stacked.json rename config/models/{medium-final.json => medium.json} (77%) diff --git a/amt/audio.py b/amt/audio.py index 28913ab..7bb2a47 100644 --- a/amt/audio.py +++ b/amt/audio.py @@ -82,6 +82,10 @@ def __init__( self.config = load_config()["audio"] self.sample_rate = self.config["sample_rate"] self.chunk_len = self.config["chunk_len"] + self.n_fft = self.config["n_fft"] + self.n_fft_reduced = self.config["n_fft_reduced"] + self.n_mels = self.config["n_mels"] + self.n_mels_reduced = self.config["n_mels_reduced"] self.num_samples = self.sample_rate * self.chunk_len self.noise_ratio = noise_ratio @@ -95,7 +99,7 @@ def __init__( self.spec_aug_ratio = spec_aug_ratio self.time_mask_param = 2500 - self.freq_mask_param = 15 + self.freq_mask_param = 0 self.reduction_resample_rate = 6000 # Audio aug @@ -126,13 +130,26 @@ def __init__( self.num_applause += 1 self.spec_transform = torchaudio.transforms.Spectrogram( - n_fft=self.config["n_fft"], + n_fft=self.n_fft, hop_length=self.config["hop_len"], ) self.mel_transform = torchaudio.transforms.MelScale( - n_mels=self.config["n_mels"], - sample_rate=self.config["sample_rate"], - n_stft=self.config["n_fft"] // 2 + 1, + n_mels=self.n_mels, + sample_rate=self.sample_rate, + n_stft=self.n_fft // 2 + 1, + f_min=30, + f_max=8000, + ) + self.spec_transform_reduced = torchaudio.transforms.Spectrogram( + n_fft=self.n_fft_reduced, + hop_length=self.config["hop_len"], + ) + self.mel_transform_reduced = torchaudio.transforms.MelScale( + n_mels=self.n_mels_reduced, + sample_rate=self.sample_rate, + n_stft=self.n_fft_reduced // 2 + 1, + f_min=30, + f_max=8000, ) self.spec_aug = torch.nn.Sequential( torchaudio.transforms.TimeMasking( @@ -315,16 +332,9 @@ def shift_spec(self, specs: torch.Tensor, shift: int | float): return shifted_specs - def detune_spec(self, specs: torch.Tensor): - if random.random() < self.detune_ratio: - detune_shift = random.uniform( - -self.detune_max_shift, self.detune_max_shift - ) - detuned_specs = self.shift_spec(specs, shift=detune_shift) - - return (specs + detuned_specs) / 2 - else: - return specs + def detune_spec(self, specs: torch.Tensor, detune_shift: float): + detuned_specs = self.shift_spec(specs, shift=detune_shift) + return (specs + detuned_specs) / 2 def aug_wav(self, wav: torch.Tensor): # This function doesn't apply distortion. If distortion is desired it @@ -361,19 +371,30 @@ def log_mel( self, wav: torch.Tensor, shift: int | None = None, detune: bool = False ): spec = self.spec_transform(wav)[..., :-1] + spec_reduced = self.spec_transform_reduced(wav)[..., :-1] if shift is not None and shift != 0: spec = self.shift_spec(spec, shift) + spec_reduced = self.shift_spec(spec_reduced, shift) elif detune is True: # Don't detune and spec shift at the same time - spec = self.detune_spec(spec) + if random.random() < self.detune_ratio: + detune_shift = random.uniform( + -self.detune_max_shift, self.detune_max_shift + ) + spec = self.detune_spec(spec, detune_shift=detune_shift) + spec_reduced = self.detune_spec( + spec_reduced, detune_shift=detune_shift + ) mel_spec = self.mel_transform(spec) + mel_spec_reduced = self.mel_transform_reduced(spec_reduced) # Norm - log_spec = self.norm_mel(mel_spec) + concat_mel = torch.cat((mel_spec, mel_spec_reduced), dim=1) + log_mel = self.norm_mel(concat_mel) - return log_spec + return log_mel def forward(self, wav: torch.Tensor, shift: int = 0): # Noise, and reverb diff --git a/amt/data.py b/amt/data.py index 6abb2f3..c966e96 100644 --- a/amt/data.py +++ b/amt/data.py @@ -435,6 +435,8 @@ def build( print("The GNU cat command is not available") else: for _path in sharded_save_paths: + if os.path.isfile(_path) is False: + continue shell_cmd = f"cat {_path} >> {save_path}" os.system(shell_cmd) os.remove(_path) diff --git a/amt/inference/model.py b/amt/inference/model.py index 44655c6..8819dd5 100644 --- a/amt/inference/model.py +++ b/amt/inference/model.py @@ -344,6 +344,8 @@ def __init__( self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int ): super().__init__() + self.n_head = n_head + self.n_state = n_state self.token_embedding = nn.Embedding(n_vocab, n_state) self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state)) @@ -396,15 +398,15 @@ def setup_cache( b.attn.kv_cache = KVCache( max_batch_size=batch_size, max_seq_length=max_seq_len, - n_heads=8, - head_dim=64, + n_heads=self.n_head, + head_dim=self.n_state // self.n_head, dtype=dtype, ).cuda() b.cross_attn.kv_cache = KVCache( max_batch_size=batch_size, max_seq_length=max_audio_len, - n_heads=8, - head_dim=64, + n_heads=self.n_head, + head_dim=self.n_state // self.n_head, dtype=dtype, ).cuda() diff --git a/amt/inference/transcribe.py b/amt/inference/transcribe.py index ba50102..d4ea3cd 100644 --- a/amt/inference/transcribe.py +++ b/amt/inference/transcribe.py @@ -218,6 +218,9 @@ def process_segments( ), ) + logits[:, 389] *= 1.2 + next_tok_ids = torch.argmax(logits, dim=-1) + next_tok_ids = recalculate_tok_ids( logits=logits, tok_ids=next_tok_ids, @@ -683,6 +686,7 @@ def batch_transcribe( model.decoder = quantize_int8(model.decoder) file_queue = Queue() + sorted(file_paths, key=lambda x: os.path.getsize(x), reverse=True) for file_path in file_paths: if ( os.path.isfile(get_save_path(file_path, input_dir, save_dir)) diff --git a/config/config.json b/config/config.json index 9442648..7503777 100644 --- a/config/config.json +++ b/config/config.json @@ -12,9 +12,11 @@ "audio": { "sample_rate": 16000, "n_fft": 2048, + "n_fft_reduced": 800, "hop_len": 160, "chunk_len": 30, - "n_mels": 256 + "n_mels": 384, + "n_mels_reduced": 128 }, "data": { "stride_factor": 15, diff --git a/config/models/medium-stacked.json b/config/models/medium-stacked.json new file mode 100644 index 0000000..0cf3a51 --- /dev/null +++ b/config/models/medium-stacked.json @@ -0,0 +1,11 @@ +{ + "n_mels": 512, + "n_audio_ctx": 1500, + "n_audio_state": 768, + "n_audio_head": 12, + "n_audio_layer": 6, + "n_text_ctx": 4096, + "n_text_state": 768, + "n_text_head": 12, + "n_text_layer": 6 +} \ No newline at end of file diff --git a/config/models/medium-final.json b/config/models/medium.json similarity index 77% rename from config/models/medium-final.json rename to config/models/medium.json index 69b79c7..a0b3857 100644 --- a/config/models/medium-final.json +++ b/config/models/medium.json @@ -3,9 +3,9 @@ "n_audio_ctx": 1500, "n_audio_state": 768, "n_audio_head": 12, - "n_audio_layer": 12, + "n_audio_layer": 6, "n_text_ctx": 4096, "n_text_state": 768, "n_text_head": 12, - "n_text_layer": 12 + "n_text_layer": 6 } \ No newline at end of file diff --git a/tests/test_data.py b/tests/test_data.py index 9c566ad..56b9eba 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -20,16 +20,44 @@ if os.path.isdir("tests/test_results") is False: os.mkdir("tests/test_results") -MAESTRO_PATH = "/mnt/ssd1/amt/training_data/maestro/train-s15.txt" - +MAESTRO_PATH = "/home/loubb/work/aria-amt/temp/train.txt" + + +def plot_spec( + mel: torch.Tensor, + name: str | int, + onsets: list = [], + offsets: list = [], +): + # mel tensor dimensions [height, width] + + height, width = mel.shape + fig_width, fig_height = width // 100, height // 100 + plt.figure(figsize=(fig_width, fig_height), dpi=100) + plt.imshow( + mel, aspect="auto", origin="lower", cmap="viridis", interpolation="none" + ) + + line_width_in_points = 1 / 100 * 72 # Convert pixel width to points + + for x in onsets: + plt.axvline( + x=x, + color="red", + alpha=0.5, + linewidth=line_width_in_points, # setting the correct line width + ) + for x in offsets: + plt.axvline( + x=x, + color="purple", + alpha=0.5, + linewidth=line_width_in_points, # setting the correct line width + ) -def plot_spec(mel: torch.Tensor, name: str | int): - plt.figure(figsize=(10, 4)) - plt.imshow(mel, aspect="auto", origin="lower", cmap="viridis") - plt.colorbar(format="%+2.0f dB") - plt.title("(mel)-Spectrogram") - plt.tight_layout() - plt.savefig(f"tests/test_results/{name}.png") + plt.axis("off") + plt.tight_layout(pad=0) + plt.savefig(f"tests/test_results/{name}.png", dpi=100) plt.close() @@ -184,7 +212,7 @@ def test_spec(self): spec = audio_transform.spec_transform(wav) shift_spec = audio_transform.shift_spec(spec, 1) - shift_wav = griffin_lim(shift_spec) + shift_wav = griffin_lim(shift_spec[..., :384]) torchaudio.save("tests/test_results/orig.wav", wav, SAMPLE_RATE) torchaudio.save("tests/test_results/shift.wav", shift_wav, SAMPLE_RATE) @@ -232,28 +260,42 @@ def test_detune(self): spec = audio_transform.spec_transform(wav) shift_spec = audio_transform.detune_spec(spec) shift_wav = griffin_lim(shift_spec) - gl_wav = griffin_lim(spec) + gl_wav = griffin_lim(spec[..., :384]) torchaudio.save("tests/test_results/orig.wav", wav, SAMPLE_RATE) torchaudio.save("tests/test_results/orig_gl.wav", gl_wav, SAMPLE_RATE) torchaudio.save("tests/test_results/detune.wav", shift_wav, SAMPLE_RATE) def test_mels(self): - SAMPLE_RATE, CHUNK_LEN = 16000, 30 audio_transform = AudioTransform() + SAMPLE_RATE, N_FFT, CHUNK_LEN = ( + audio_transform.sample_rate, + audio_transform.n_fft, + 30, + ) wav, sr = torchaudio.load("tests/test_data/maestro.wav") wav = torchaudio.functional.resample(wav, sr, SAMPLE_RATE).mean( 0, keepdim=True )[:, : SAMPLE_RATE * CHUNK_LEN] - wav_aug = audio_transform.aug_wav( - audio_transform.distortion_aug_cpu(wav) - ) - torchaudio.save("tests/test_results/orig.wav", wav, SAMPLE_RATE) - torchaudio.save("tests/test_results/aug.wav", wav_aug, SAMPLE_RATE) + + # tokenizer = AmtTokenizer() + # mid_dict = MidiDict.from_midi("tests/test_data/maestro-test.mid") + # seq = tokenizer._tokenize_midi_dict(mid_dict, 0, 30000, 10000) + # mid_dict = tokenizer._detokenize_midi_dict(seq, 30000) + # onsets = [msg["data"]["start"] // 10 for msg in mid_dict.note_msgs] + # offsets = [ + # msg["data"]["end"] // 10 + # for msg in mid_dict.note_msgs + # if msg["data"]["end"] < 30000 + # ] wavs = torch.stack((wav[0], wav[0], wav[0])) mels = audio_transform(wavs) for idx in range(mels.shape[0]): - plot_spec(mels[idx], idx) + plot_spec( + mels[idx], + f"{mels[0].shape[0]}-{N_FFT}-{SAMPLE_RATE}", + ) + break def test_distortion(self): SAMPLE_RATE, CHUNK_LEN = 16000, 30