Skip to content

Commit

Permalink
Add stacked spectrograms (#34)
Browse files Browse the repository at this point in the history
* update README

* update README

* add fp16

* inference modification

* add stacked mels
  • Loading branch information
loubbrad authored May 4, 2024
1 parent 49327f1 commit 01ac3e2
Show file tree
Hide file tree
Showing 8 changed files with 127 additions and 43 deletions.
57 changes: 39 additions & 18 deletions amt/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ def __init__(
self.config = load_config()["audio"]
self.sample_rate = self.config["sample_rate"]
self.chunk_len = self.config["chunk_len"]
self.n_fft = self.config["n_fft"]
self.n_fft_reduced = self.config["n_fft_reduced"]
self.n_mels = self.config["n_mels"]
self.n_mels_reduced = self.config["n_mels_reduced"]
self.num_samples = self.sample_rate * self.chunk_len

self.noise_ratio = noise_ratio
Expand All @@ -95,7 +99,7 @@ def __init__(
self.spec_aug_ratio = spec_aug_ratio

self.time_mask_param = 2500
self.freq_mask_param = 15
self.freq_mask_param = 0
self.reduction_resample_rate = 6000

# Audio aug
Expand Down Expand Up @@ -126,13 +130,26 @@ def __init__(
self.num_applause += 1

self.spec_transform = torchaudio.transforms.Spectrogram(
n_fft=self.config["n_fft"],
n_fft=self.n_fft,
hop_length=self.config["hop_len"],
)
self.mel_transform = torchaudio.transforms.MelScale(
n_mels=self.config["n_mels"],
sample_rate=self.config["sample_rate"],
n_stft=self.config["n_fft"] // 2 + 1,
n_mels=self.n_mels,
sample_rate=self.sample_rate,
n_stft=self.n_fft // 2 + 1,
f_min=30,
f_max=8000,
)
self.spec_transform_reduced = torchaudio.transforms.Spectrogram(
n_fft=self.n_fft_reduced,
hop_length=self.config["hop_len"],
)
self.mel_transform_reduced = torchaudio.transforms.MelScale(
n_mels=self.n_mels_reduced,
sample_rate=self.sample_rate,
n_stft=self.n_fft_reduced // 2 + 1,
f_min=30,
f_max=8000,
)
self.spec_aug = torch.nn.Sequential(
torchaudio.transforms.TimeMasking(
Expand Down Expand Up @@ -315,16 +332,9 @@ def shift_spec(self, specs: torch.Tensor, shift: int | float):

return shifted_specs

def detune_spec(self, specs: torch.Tensor):
if random.random() < self.detune_ratio:
detune_shift = random.uniform(
-self.detune_max_shift, self.detune_max_shift
)
detuned_specs = self.shift_spec(specs, shift=detune_shift)

return (specs + detuned_specs) / 2
else:
return specs
def detune_spec(self, specs: torch.Tensor, detune_shift: float):
detuned_specs = self.shift_spec(specs, shift=detune_shift)
return (specs + detuned_specs) / 2

def aug_wav(self, wav: torch.Tensor):
# This function doesn't apply distortion. If distortion is desired it
Expand Down Expand Up @@ -361,19 +371,30 @@ def log_mel(
self, wav: torch.Tensor, shift: int | None = None, detune: bool = False
):
spec = self.spec_transform(wav)[..., :-1]
spec_reduced = self.spec_transform_reduced(wav)[..., :-1]

if shift is not None and shift != 0:
spec = self.shift_spec(spec, shift)
spec_reduced = self.shift_spec(spec_reduced, shift)
elif detune is True:
# Don't detune and spec shift at the same time
spec = self.detune_spec(spec)
if random.random() < self.detune_ratio:
detune_shift = random.uniform(
-self.detune_max_shift, self.detune_max_shift
)
spec = self.detune_spec(spec, detune_shift=detune_shift)
spec_reduced = self.detune_spec(
spec_reduced, detune_shift=detune_shift
)

mel_spec = self.mel_transform(spec)
mel_spec_reduced = self.mel_transform_reduced(spec_reduced)

# Norm
log_spec = self.norm_mel(mel_spec)
concat_mel = torch.cat((mel_spec, mel_spec_reduced), dim=1)
log_mel = self.norm_mel(concat_mel)

return log_spec
return log_mel

def forward(self, wav: torch.Tensor, shift: int = 0):
# Noise, and reverb
Expand Down
2 changes: 2 additions & 0 deletions amt/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,8 @@ def build(
print("The GNU cat command is not available")
else:
for _path in sharded_save_paths:
if os.path.isfile(_path) is False:
continue
shell_cmd = f"cat {_path} >> {save_path}"
os.system(shell_cmd)
os.remove(_path)
Expand Down
10 changes: 6 additions & 4 deletions amt/inference/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,8 @@ def __init__(
self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
):
super().__init__()
self.n_head = n_head
self.n_state = n_state
self.token_embedding = nn.Embedding(n_vocab, n_state)
self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))

Expand Down Expand Up @@ -396,15 +398,15 @@ def setup_cache(
b.attn.kv_cache = KVCache(
max_batch_size=batch_size,
max_seq_length=max_seq_len,
n_heads=8,
head_dim=64,
n_heads=self.n_head,
head_dim=self.n_state // self.n_head,
dtype=dtype,
).cuda()
b.cross_attn.kv_cache = KVCache(
max_batch_size=batch_size,
max_seq_length=max_audio_len,
n_heads=8,
head_dim=64,
n_heads=self.n_head,
head_dim=self.n_state // self.n_head,
dtype=dtype,
).cuda()

Expand Down
4 changes: 4 additions & 0 deletions amt/inference/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,9 @@ def process_segments(
),
)

logits[:, 389] *= 1.2
next_tok_ids = torch.argmax(logits, dim=-1)

next_tok_ids = recalculate_tok_ids(
logits=logits,
tok_ids=next_tok_ids,
Expand Down Expand Up @@ -683,6 +686,7 @@ def batch_transcribe(
model.decoder = quantize_int8(model.decoder)

file_queue = Queue()
sorted(file_paths, key=lambda x: os.path.getsize(x), reverse=True)
for file_path in file_paths:
if (
os.path.isfile(get_save_path(file_path, input_dir, save_dir))
Expand Down
4 changes: 3 additions & 1 deletion config/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
"audio": {
"sample_rate": 16000,
"n_fft": 2048,
"n_fft_reduced": 800,
"hop_len": 160,
"chunk_len": 30,
"n_mels": 256
"n_mels": 384,
"n_mels_reduced": 128
},
"data": {
"stride_factor": 15,
Expand Down
11 changes: 11 additions & 0 deletions config/models/medium-stacked.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
{
"n_mels": 512,
"n_audio_ctx": 1500,
"n_audio_state": 768,
"n_audio_head": 12,
"n_audio_layer": 6,
"n_text_ctx": 4096,
"n_text_state": 768,
"n_text_head": 12,
"n_text_layer": 6
}
4 changes: 2 additions & 2 deletions config/models/medium-final.json → config/models/medium.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
"n_audio_ctx": 1500,
"n_audio_state": 768,
"n_audio_head": 12,
"n_audio_layer": 12,
"n_audio_layer": 6,
"n_text_ctx": 4096,
"n_text_state": 768,
"n_text_head": 12,
"n_text_layer": 12
"n_text_layer": 6
}
78 changes: 60 additions & 18 deletions tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,44 @@
if os.path.isdir("tests/test_results") is False:
os.mkdir("tests/test_results")

MAESTRO_PATH = "/mnt/ssd1/amt/training_data/maestro/train-s15.txt"

MAESTRO_PATH = "/home/loubb/work/aria-amt/temp/train.txt"


def plot_spec(
    mel: torch.Tensor,
    name: str | int,
    onsets: list | None = None,
    offsets: list | None = None,
):
    """Save a 2-D spectrogram as ``tests/test_results/{name}.png``.

    The image is drawn without axes, colour bar or padding. Optional
    vertical marker lines are overlaid on top of it.

    Args:
        mel: Tensor with dimensions ``[height, width]``.
        name: Stem of the output file name.
        onsets: x positions at which to draw red vertical lines.
        offsets: x positions at which to draw purple vertical lines.
    """
    # ``None`` stands in for "no markers" so the signature does not carry
    # shared mutable list defaults.
    onsets = [] if onsets is None else onsets
    offsets = [] if offsets is None else offsets

    height, width = mel.shape

    # figsize is in inches; at dpi=100 this yields approximately one output
    # pixel per bin. Clamp to 1 so inputs with fewer than 100 rows/columns do
    # not request a zero-sized figure, which matplotlib rejects.
    fig_width = max(1, width // 100)
    fig_height = max(1, height // 100)
    plt.figure(figsize=(fig_width, fig_height), dpi=100)
    plt.imshow(
        mel, aspect="auto", origin="lower", cmap="viridis", interpolation="none"
    )

    # One pixel at 100 dpi expressed in points (72 points per inch).
    line_width_in_points = 1 / 100 * 72

    for positions, color in ((onsets, "red"), (offsets, "purple")):
        for x in positions:
            plt.axvline(
                x=x,
                color=color,
                alpha=0.5,
                linewidth=line_width_in_points,
            )

    plt.axis("off")
    plt.tight_layout(pad=0)
    plt.savefig(f"tests/test_results/{name}.png", dpi=100)
    plt.close()


Expand Down Expand Up @@ -184,7 +212,7 @@ def test_spec(self):

spec = audio_transform.spec_transform(wav)
shift_spec = audio_transform.shift_spec(spec, 1)
shift_wav = griffin_lim(shift_spec)
shift_wav = griffin_lim(shift_spec[..., :384])
torchaudio.save("tests/test_results/orig.wav", wav, SAMPLE_RATE)
torchaudio.save("tests/test_results/shift.wav", shift_wav, SAMPLE_RATE)

Expand Down Expand Up @@ -232,28 +260,42 @@ def test_detune(self):
spec = audio_transform.spec_transform(wav)
shift_spec = audio_transform.detune_spec(spec)
shift_wav = griffin_lim(shift_spec)
gl_wav = griffin_lim(spec)
gl_wav = griffin_lim(spec[..., :384])
torchaudio.save("tests/test_results/orig.wav", wav, SAMPLE_RATE)
torchaudio.save("tests/test_results/orig_gl.wav", gl_wav, SAMPLE_RATE)
torchaudio.save("tests/test_results/detune.wav", shift_wav, SAMPLE_RATE)

def test_mels(self):
    """Write the log-mel of a MAESTRO excerpt to disk for visual inspection.

    Loads ``tests/test_data/maestro.wav``, resamples it to the transform's
    sample rate, mixes it down to mono, keeps the first ``CHUNK_LEN``
    seconds, runs a batch of three identical copies through
    ``AudioTransform`` and plots the first result. There are no assertions;
    the output PNG is meant to be checked by eye.
    """
    audio_transform = AudioTransform()
    SAMPLE_RATE = audio_transform.sample_rate
    N_FFT = audio_transform.n_fft
    CHUNK_LEN = 30

    wav, sr = torchaudio.load("tests/test_data/maestro.wav")
    wav = torchaudio.functional.resample(wav, sr, SAMPLE_RATE).mean(
        0, keepdim=True
    )[:, : SAMPLE_RATE * CHUNK_LEN]

    # NOTE: disabled onset/offset overlay, kept for reference. The resulting
    # lists would be passed to plot_spec via its onsets/offsets arguments.
    # tokenizer = AmtTokenizer()
    # mid_dict = MidiDict.from_midi("tests/test_data/maestro-test.mid")
    # seq = tokenizer._tokenize_midi_dict(mid_dict, 0, 30000, 10000)
    # mid_dict = tokenizer._detokenize_midi_dict(seq, 30000)
    # onsets = [msg["data"]["start"] // 10 for msg in mid_dict.note_msgs]
    # offsets = [
    #     msg["data"]["end"] // 10
    #     for msg in mid_dict.note_msgs
    #     if msg["data"]["end"] < 30000
    # ]

    wavs = torch.stack((wav[0], wav[0], wav[0]))
    mels = audio_transform(wavs)

    # All three batch items come from the same waveform, so only the first
    # one is plotted. The file name records the mel height, FFT size and
    # sample rate that produced it.
    plot_spec(
        mels[0],
        f"{mels[0].shape[0]}-{N_FFT}-{SAMPLE_RATE}",
    )

def test_distortion(self):
SAMPLE_RATE, CHUNK_LEN = 16000, 30
Expand Down

0 comments on commit 01ac3e2

Please sign in to comment.