Showing 4 changed files with 538 additions and 238 deletions.
@@ -0,0 +1,397 @@
import os
import time
import random
import torch
import torch.multiprocessing as multiprocessing

from torch.multiprocessing import Queue
from torch.cuda import device_count, is_available
from tqdm import tqdm

from amt.model import AmtEncoderDecoder
from amt.tokenizer import AmtTokenizer
from amt.audio import AudioTransform
from amt.data import get_wav_mid_segments
from amt.config import load_config
from aria.data.midi import MidiDict

MAX_SEQ_LEN = 4096  # maximum decoder sequence length
LEN_MS = 30000  # length of each audio segment in milliseconds
STRIDE_FACTOR = 3
CHUNK_LEN_MS = LEN_MS // STRIDE_FACTOR  # hop between consecutive segments (ms)
BEAM = 3  # number of candidates considered when refining onsets/velocities
ONSET_TOLERANCE = 50  # maximum accepted deviation (ms) for onset candidates
VEL_TOLERANCE = 50  # maximum accepted deviation for velocity candidates


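# The two helpers below refine greedy velocity/onset predictions: they take the
# top-BEAM candidates at a position, discard candidates of the wrong kind or
# outside the tolerance window around the greedy choice, and replace that
# choice with a probability-weighted average rounded to the nearest 10.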
def calculate_vel(
    logits: torch.Tensor,
    init_vel: int,
    tokenizer: AmtTokenizer = AmtTokenizer(),
):
    probs, idxs = torch.topk(torch.softmax(logits, dim=-1), BEAM)
    vels = [tokenizer.id_to_tok[idx.item()] for idx in idxs]

    # Zero out candidates that are not velocity tokens or that fall outside
    # the tolerance window around the greedy prediction
    for idx in range(BEAM):
        vel = vels[idx]
        if type(vel) is not tuple:
            vels[idx] = 0
            probs[idx] = 0.0
        elif vel[0] != "vel":
            vels[idx] = 0
            probs[idx] = 0.0
        elif (vel[1] < init_vel - VEL_TOLERANCE / 2) or (
            vel[1] > init_vel + VEL_TOLERANCE / 2
        ):
            vels[idx] = vels[idx][1]
            probs[idx] = 0.0
        else:
            vels[idx] = vels[idx][1]

    # Probability-weighted average of the surviving candidates, quantized to 10
    vels = torch.tensor(vels).to(probs.device)
    new_vel = torch.sum(vels * probs) / torch.sum(probs)
    new_vel = round(new_vel.item() / 10) * 10

    return tokenizer.tok_to_id[("vel", new_vel)]


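# Same procedure as calculate_vel, applied to "onset" tokens with ONSET_TOLERANCE.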
def calculate_onset(
    logits: torch.Tensor,
    init_onset: int,
    tokenizer: AmtTokenizer = AmtTokenizer(),
):
    probs, idxs = torch.topk(torch.softmax(logits, dim=-1), BEAM)
    onsets = [tokenizer.id_to_tok[idx.item()] for idx in idxs]

    # Zero out candidates that are not onset tokens or that fall outside
    # the tolerance window around the greedy prediction
    for idx in range(BEAM):
        onset = onsets[idx]
        if type(onset) is not tuple:
            onsets[idx] = 0
            probs[idx] = 0.0
        elif onset[0] != "onset":
            onsets[idx] = 0
            probs[idx] = 0.0
        elif (onset[1] < init_onset - ONSET_TOLERANCE / 2) or (
            onset[1] > init_onset + ONSET_TOLERANCE / 2
        ):
            onsets[idx] = onsets[idx][1]
            probs[idx] = 0.0
        else:
            onsets[idx] = onsets[idx][1]

    # Probability-weighted average of the surviving candidates, quantized to 10
    onsets = torch.tensor(onsets).to(probs.device)
    new_onset = torch.sum(onsets * probs) / torch.sum(probs)
    new_onset = round(new_onset.item() / 10) * 10

    return tokenizer.tok_to_id[("onset", new_onset)]


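# Decodes a batch of (audio segment, prefix) tasks in one pass: the audio is
# embedded once, then tokens are generated autoregressively with a kv-cache,
# with onset and velocity tokens re-estimated by the helpers above.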
def process_segments(
    tasks: list,
    model: AmtEncoderDecoder,
    audio_transform: AudioTransform,
    tokenizer: AmtTokenizer,
):
    audio_segs = torch.stack(
        [audio_seg for (audio_seg, prefix), _ in tasks]
    ).cuda()
    log_mels = audio_transform.log_mel(audio_segs)
    audio_features = model.embed_audio(mel=log_mels)

    raw_prefixes = [prefix for (audio_seg, prefix), _ in tasks]
    prefix_lens = [len(prefix) for prefix in raw_prefixes]
    min_prefix_len = min(prefix_lens)
    prefixes = [
        tokenizer.trunc_seq(prefix, MAX_SEQ_LEN) for prefix in raw_prefixes
    ]
    seq = torch.stack([tokenizer.encode(prefix) for prefix in prefixes]).cuda()
    eos_seen = [False for _ in prefixes]

    kv_cache = model.get_empty_cache()

    # for idx in (
    #     pbar := tqdm(
    #         range(min_prefix_len, MAX_SEQ_LEN - 1),
    #         total=MAX_SEQ_LEN - (min_prefix_len + 1),
    #         leave=False,
    #     )
    # ):
    for idx in range(min_prefix_len, MAX_SEQ_LEN - 1):
        if idx == min_prefix_len:
            # First step: run the decoder over the full prefix to fill the cache
            logits = model.decoder(
                xa=audio_features,
                x=seq[:, :idx],
                kv_cache=kv_cache,
            )
        else:
            # Later steps: only the most recent token needs to be processed
            logits = model.decoder(
                xa=audio_features,
                x=seq[:, idx - 1 : idx],
                kv_cache=kv_cache,
            )

        next_tok_ids = torch.argmax(logits[:, -1], dim=-1)

        for batch_idx in range(logits.shape[0]):
            if eos_seen[batch_idx] is not False:
                # End already seen, add pad token
                tok_id = tokenizer.pad_id
            elif idx >= prefix_lens[batch_idx]:
                # New token required, recalculate if needed
                tok_id = next_tok_ids[batch_idx].item()
                tok = tokenizer.id_to_tok[tok_id]
                if type(tok) is tuple and tok[0] == "onset":
                    # If onset token, recalculate
                    tok_id = calculate_onset(logits[batch_idx, -1], tok[1])
                elif type(tok) is tuple and tok[0] == "vel":
                    # If velocity token, recalculate
                    tok_id = calculate_vel(logits[batch_idx, -1], tok[1])

            else:
                # Still in prefix tokens, do nothing
                tok_id = tokenizer.tok_to_id[prefixes[batch_idx][idx]]

            seq[batch_idx, idx] = tok_id
            if tokenizer.id_to_tok[tok_id] == tokenizer.eos_tok:
                eos_seen[batch_idx] = idx

        if all(eos_seen):
            break

    # Decode each sequence up to and including its eos token; sequences that
    # never produced an eos token are decoded in full
    results = [
        (
            tokenizer.decode(seq[_idx, : eos_seen[_idx] + 1])
            if eos_seen[_idx] is not False
            else tokenizer.decode(seq[_idx])
        )
        for _idx in range(seq.shape[0])
    ]

    return results


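# Single GPU-owning process: collects (task, pid) items from gpu_task_queue into
# batches of up to batch_size, runs process_segments on each batch, and routes
# every result back through result_queue keyed by the worker pid. Exits once the
# task queue has been idle past the timeout and no partial batch remains.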
def gpu_manager(
    gpu_task_queue: Queue,
    result_queue: Queue,
    model: AmtEncoderDecoder,
    batch_size: int,
):
    model.cuda()
    model.eval()
    model.compile()
    audio_transform = AudioTransform().cuda()
    tokenizer = AmtTokenizer(return_tensors=True)
    process_pid = multiprocessing.current_process().pid

    wait_for_batch = True
    batch = []
    while True:
        try:
            task, pid = gpu_task_queue.get(timeout=5)
        except Exception:
            print(f"{process_pid}: GPU task timeout")
            if len(batch) == 0:
                print(f"{process_pid}: Finished GPU tasks")
                return
            else:
                # Flush the partial batch instead of waiting for more tasks
                wait_for_batch = False
        else:
            batch.append((task, pid))

        if len(batch) == batch_size or (
            len(batch) > 0 and wait_for_batch is False
        ):
            # Process batch on GPU
            results = process_segments(
                tasks=batch,
                model=model,
                audio_transform=audio_transform,
                tokenizer=tokenizer,
            )
            for result, (_, pid) in zip(results, batch):
                result_queue.put({"result": result, "pid": pid})
            batch.clear()


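# Shifts every ("onset", t) token by shift_ms so per-chunk times become absolute
# times within the full recording.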
def _shift_onset(seq: list, shift_ms: int):
    res = []
    for tok in seq:
        if type(tok) is tuple and tok[0] == "onset":
            res.append(("onset", tok[1] + shift_ms))
        else:
            res.append(tok)

    return res


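# Re-tokenizes the decoded chunk restricted to [start_ms, end_ms), so the tail
# of one chunk can seed the decoding prefix for the next, overlapping chunk.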
def _truncate_seq(
    seq: list,
    start_ms: int,
    end_ms: int,
    tokenizer: AmtTokenizer = AmtTokenizer(),
):
    if start_ms == end_ms:
        _mid_dict, unclosed_notes = tokenizer._detokenize_midi_dict(
            seq, start_ms, return_unclosed_notes=True
        )
        random.shuffle(unclosed_notes)
        return [("prev", p) for p in unclosed_notes] + [tokenizer.bos_tok]
    else:
        _mid_dict = tokenizer._detokenize_midi_dict(seq, LEN_MS)
        try:
            res = tokenizer._tokenize_midi_dict(_mid_dict, start_ms, end_ms - 1)
        except Exception:
            return ["<S>"]
        else:
            return res[: res.index(tokenizer.eos_tok)]


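# Transcribes a single audio file chunk by chunk: each segment is sent to the
# GPU manager together with the prefix carried over from the previous chunk, and
# the newly generated tokens are shifted to absolute time and appended to res.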
def process_file(
    file_path,
    gpu_task_queue: Queue,
    result_queue: Queue,
    tokenizer: AmtTokenizer = AmtTokenizer(),
):
    process_pid = multiprocessing.current_process().pid
    print(f"{process_pid}: Getting wav segments")
    audio_segments = [
        f
        for f, _ in get_wav_mid_segments(
            audio_path=file_path, stride_factor=STRIDE_FACTOR
        )
    ]
    seq = ["<S>"]
    res = ["<S>"]
    for idx, audio_seg in enumerate(audio_segments):
        init_idx = len(seq)

        # Add to gpu queue and wait for results
        gpu_task_queue.put(((audio_seg, seq), process_pid))
        while True:
            gpu_result = result_queue.get()
            if gpu_result["pid"] == process_pid:
                seq = gpu_result["result"]
                break
            else:
                result_queue.put(gpu_result)

        res += _shift_onset(
            seq[init_idx : seq.index(tokenizer.eos_tok)],
            idx * CHUNK_LEN_MS,
        )
        print(
            f"{process_pid}: Finished {idx+1}/{len(audio_segments)} audio segments"
        )

        if idx == len(audio_segments) - 1:
            break
        else:
            seq = _truncate_seq(seq, CHUNK_LEN_MS, LEN_MS)
            if len(seq) == 1:
                print(f"{process_pid}: exiting early")
                return res

    return res


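# CPU worker process: drains the file queue, runs process_file on each path, and
# writes the detokenized result as a .mid file (mirroring the input directory
# layout when input_dir is given, otherwise flat in save_dir).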
def worker(
    file_queue: Queue,
    gpu_task_queue: Queue,
    result_queue: Queue,
    save_dir: str,
    input_dir: str | None = None,
):
    def _get_save_path(_file_path: str):
        if input_dir is None:
            save_path = os.path.join(
                save_dir,
                os.path.splitext(os.path.basename(_file_path))[0] + ".mid",
            )
        else:
            input_rel_path = os.path.relpath(_file_path, input_dir)
            save_path = os.path.join(
                save_dir, os.path.splitext(input_rel_path)[0] + ".mid"
            )
            if not os.path.exists(os.path.dirname(save_path)):
                os.makedirs(os.path.dirname(save_path))

        return save_path

    pid = multiprocessing.current_process().pid
    tokenizer = AmtTokenizer()
    files_processed = 0
    while not file_queue.empty():
        file_path = file_queue.get()
        save_path = _get_save_path(file_path)
        if os.path.exists(save_path):
            print(f"{pid}: {save_path} already exists, overwriting")

        try:
            res = process_file(file_path, gpu_task_queue, result_queue)
        except Exception as e:
            print(f"{pid}: Failed to transcribe {file_path} with error {e}")
            continue

        files_processed += 1

        # Find the latest onset in the result so the detokenizer knows the
        # total length of the transcription
        last_onset = 0
        for tok in res[::-1]:
            if type(tok) is tuple and tok[0] == "onset":
                last_onset = tok[1]
                break

        try:
            mid_dict = tokenizer._detokenize_midi_dict(
                tokenized_seq=res, len_ms=last_onset
            )
            mid = mid_dict.to_midi()
            mid.save(save_path)
        except Exception as e:
            print(f"{pid}: Failed to detokenize with error {e}")
        else:
            print(f"{pid}: Finished file {files_processed} - {file_path}")
            print(f"{pid}: {file_queue.qsize()} file(s) remaining in queue")


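# Entry point: builds the queues, starts one GPU-manager process plus
# batch_size + 1 CPU worker processes, and waits for them all to finish.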
def batch_transcribe(
    file_paths: list,
    model: AmtEncoderDecoder,
    save_dir: str,
    batch_size: int = 16,
    gpu_id: int | None = None,
    input_dir: str | None = None,
):
    if gpu_id is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

    model.to("cuda")
    file_queue = Queue()
    for file_path in file_paths:
        file_queue.put(file_path)

    gpu_task_queue = Queue()
    result_queue = Queue()

    gpu_manager_process = multiprocessing.Process(
        target=gpu_manager,
        args=(gpu_task_queue, result_queue, model, batch_size),
    )
    gpu_manager_process.start()

    worker_processes = [
        multiprocessing.Process(
            target=worker,
            args=(
                file_queue,
                gpu_task_queue,
                result_queue,
                save_dir,
                input_dir,
            ),
        )
        for _ in range(batch_size + 1)
    ]
    for p in worker_processes:
        p.start()

    for p in worker_processes:
        p.join()

    gpu_manager_process.join()
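

# ---------------------------------------------------------------------------
# Illustrative driver sketch (an assumption, not part of this commit): since a
# CUDA-initialised model and CUDA tensors are shared across processes, a caller
# would typically use the "spawn" start method. The checkpoint loading and the
# paths below are placeholders.
#
#   import glob
#   import torch.multiprocessing as multiprocessing
#
#   if __name__ == "__main__":
#       multiprocessing.set_start_method("spawn")
#       model = ...  # construct and load an AmtEncoderDecoder checkpoint here
#       wav_paths = glob.glob("/path/to/audio/**/*.wav", recursive=True)
#       batch_transcribe(
#           wav_paths,
#           model=model,
#           save_dir="/path/to/output",
#           batch_size=16,
#           input_dir="/path/to/audio",
#       )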