From 9bb575082950aba1895893b1aa7dd006fdc1adde Mon Sep 17 00:00:00 2001 From: loubbrad Date: Thu, 7 Mar 2024 19:31:14 +0000 Subject: [PATCH] add multi gpu inference --- amt/infer.py | 397 +++++++++++++++++++++++++++++++++++++++++++++++ amt/inference.py | 117 -------------- amt/model.py | 150 ++++++------------ amt/run.py | 112 ++++++++++--- 4 files changed, 538 insertions(+), 238 deletions(-) create mode 100644 amt/infer.py delete mode 100644 amt/inference.py diff --git a/amt/infer.py b/amt/infer.py new file mode 100644 index 0000000..72c99b0 --- /dev/null +++ b/amt/infer.py @@ -0,0 +1,397 @@ +import os +import time +import random +import torch +import torch.multiprocessing as multiprocessing + +from torch.multiprocessing import Queue +from torch.cuda import device_count, is_available +from tqdm import tqdm + +from amt.model import AmtEncoderDecoder +from amt.tokenizer import AmtTokenizer +from amt.audio import AudioTransform +from amt.data import get_wav_mid_segments +from amt.config import load_config +from aria.data.midi import MidiDict + +MAX_SEQ_LEN = 4096 +LEN_MS = 30000 +STRIDE_FACTOR = 3 +CHUNK_LEN_MS = LEN_MS // STRIDE_FACTOR +BEAM = 3 +ONSET_TOLERANCE = 50 +VEL_TOLERANCE = 50 + + +def calculate_vel( + logits: torch.tensor, + init_vel: int, + tokenizer: AmtTokenizer = AmtTokenizer(), +): + probs, idxs = torch.topk(torch.softmax(logits, dim=-1), BEAM) + vels = [tokenizer.id_to_tok[idx.item()] for idx in idxs] + + # Get rid of outliers + for idx in range(BEAM): + vel = vels[idx] + if type(vel) is not tuple: + vels[idx] = 0 + probs[idx] = 0.0 + elif vel[0] != "vel": + vels[idx] = 0 + probs[idx] = 0.0 + elif (vel[1] < init_vel - VEL_TOLERANCE / 2) or ( + vel[1] > init_vel + VEL_TOLERANCE / 2 + ): + vels[idx] = vels[idx][1] + probs[idx] = 0.0 + else: + vels[idx] = vels[idx][1] + + vels = torch.tensor(vels).to(probs.device) + new_vel = torch.sum(vels * probs) / torch.sum(probs) + new_vel = round(new_vel.item() / 10) * 10 + + return tokenizer.tok_to_id[("vel", new_vel)] + + +def calculate_onset( + logits: torch.tensor, + init_onset: int, + tokenizer: AmtTokenizer = AmtTokenizer(), +): + probs, idxs = torch.topk(torch.softmax(logits, dim=-1), BEAM) + onsets = [tokenizer.id_to_tok[idx.item()] for idx in idxs] + + # Get rid of outliers + for idx in range(BEAM): + onset = onsets[idx] + if type(onset) is not tuple: + onsets[idx] = 0 + probs[idx] = 0.0 + elif onset[0] != "onset": + onsets[idx] = 0 + probs[idx] = 0.0 + elif (onset[1] < init_onset - ONSET_TOLERANCE / 2) or ( + onset[1] > init_onset + ONSET_TOLERANCE / 2 + ): + onsets[idx] = onsets[idx][1] + probs[idx] = 0.0 + else: + onsets[idx] = onsets[idx][1] + + onsets = torch.tensor(onsets).to(probs.device) + new_onset = torch.sum(onsets * probs) / torch.sum(probs) + new_onset = round(new_onset.item() / 10) * 10 + + return tokenizer.tok_to_id[("onset", new_onset)] + + +def process_segments( + tasks: list, + model: AmtEncoderDecoder, + audio_transform: AudioTransform, + tokenizer: AmtTokenizer, +): + audio_segs = torch.stack( + [audio_seg for (audio_seg, prefix), _ in tasks] + ).cuda() + log_mels = audio_transform.log_mel(audio_segs) + audio_features = model.embed_audio(mel=log_mels) + + raw_prefixes = [prefix for (audio_seg, prefix), _ in tasks] + prefix_lens = [len(prefix) for prefix in raw_prefixes] + min_prefix_len = min(prefix_lens) + prefixes = [ + tokenizer.trunc_seq(prefix, MAX_SEQ_LEN) for prefix in raw_prefixes + ] + seq = torch.stack([tokenizer.encode(prefix) for prefix in prefixes]).cuda() + eos_seen = [False for _ in prefixes] + 
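+    # Greedy batched decode with KV caching: the first step feeds the whole
+    # prefix to populate the cache, after which only the newest token is fed
+    # to the decoder; predicted onset and velocity tokens are re-estimated
+    # from the top-k logits via calculate_onset / calculate_vel above.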
+ kv_cache = model.get_empty_cache() + + # for idx in ( + # pbar := tqdm( + # range(min_prefix_len, MAX_SEQ_LEN - 1), + # total=MAX_SEQ_LEN - (min_prefix_len + 1), + # leave=False, + # ) + # ): + for idx in range(min_prefix_len, MAX_SEQ_LEN - 1): + if idx == min_prefix_len: + logits = model.decoder( + xa=audio_features, + x=seq[:, :idx], + kv_cache=kv_cache, + ) + else: + logits = model.decoder( + xa=audio_features, + x=seq[:, idx - 1 : idx], + kv_cache=kv_cache, + ) + + next_tok_ids = torch.argmax(logits[:, -1], dim=-1) + + for batch_idx in range(logits.shape[0]): + if eos_seen[batch_idx] is not False: + # End already seen, add pad token + tok_id = tokenizer.pad_id + elif idx >= prefix_lens[batch_idx]: + # New token required, recalculated if needed + tok_id = next_tok_ids[batch_idx].item() + tok = tokenizer.id_to_tok[tok_id] + if type(tok) is tuple and tok[0] == "onset": + # If onset token, recalculate + tok_id = calculate_onset(logits[batch_idx, -1], tok[1]) + elif type(tok) is tuple and tok[0] == "vel": + # If velocity token, recalculate + tok_id = calculate_vel(logits[batch_idx, -1], tok[1]) + + else: + # Still in prefix tokens, do nothing + tok_id = tokenizer.tok_to_id[prefixes[batch_idx][idx]] + + seq[batch_idx, idx] = tok_id + if tokenizer.id_to_tok[tok_id] == tokenizer.eos_tok: + eos_seen[batch_idx] = idx + + if all(eos_seen): + break + + results = [ + tokenizer.decode(seq[_idx, : eos_seen[_idx] + 1]) + for _idx in range(seq.shape[0]) + ] + + return results + + +def gpu_manager( + gpu_task_queue: Queue, + result_queue: Queue, + model: AmtEncoderDecoder, + batch_size: int, +): + model.cuda() + model.eval() + model.compile() + audio_transform = AudioTransform().cuda() + tokenizer = AmtTokenizer(return_tensors=True) + process_pid = multiprocessing.current_process().pid + + wait_for_batch = True + batch = [] + while True: + try: + task, pid = gpu_task_queue.get(timeout=5) + except: + print(f"{process_pid}: GPU task timeout") + if len(batch) == 0: + print(f"{process_pid}: Finished GPU tasks") + return + else: + wait_for_batch = False + else: + batch.append((task, pid)) + + if len(batch) == batch_size or ( + len(batch) > 0 and wait_for_batch is False + ): + # Process batch on GPU + results = process_segments( + tasks=[task for task in batch], + model=model, + audio_transform=audio_transform, + tokenizer=tokenizer, + ) + for result, (_, pid) in zip(results, batch): + result_queue.put({"result": result, "pid": pid}) + batch.clear() + + +def _shift_onset(seq: list, shift_ms: int): + res = [] + for tok in seq: + if type(tok) is tuple and tok[0] == "onset": + res.append(("onset", tok[1] + shift_ms)) + else: + res.append(tok) + + return res + + +def _truncate_seq( + seq: list, + start_ms: int, + end_ms: int, + tokenizer: AmtTokenizer = AmtTokenizer(), +): + if start_ms == end_ms: + _mid_dict, unclosed_notes = tokenizer._detokenize_midi_dict( + seq, start_ms, return_unclosed_notes=True + ) + random.shuffle(unclosed_notes) + return [("prev", p) for p in unclosed_notes] + [tokenizer.bos_tok] + else: + _mid_dict = tokenizer._detokenize_midi_dict(seq, LEN_MS) + try: + res = tokenizer._tokenize_midi_dict(_mid_dict, start_ms, end_ms - 1) + except: + return [""] + else: + return res[: res.index(tokenizer.eos_tok)] + + +def process_file( + file_path, + gpu_task_queue: Queue, + result_queue: Queue, + tokenizer: AmtTokenizer = AmtTokenizer(), +): + process_pid = multiprocessing.current_process().pid + print(f"{process_pid}: Getting wav segments") + audio_segments = [ + f + for f, _ in 
get_wav_mid_segments( + audio_path=file_path, stride_factor=STRIDE_FACTOR + ) + ] + seq = [""] + res = [""] + for idx, audio_seg in enumerate(audio_segments): + init_idx = len(seq) + + # Add to gpu queue and wait for results + gpu_task_queue.put(((audio_seg, seq), process_pid)) + while True: + gpu_result = result_queue.get() + if gpu_result["pid"] == process_pid: + seq = gpu_result["result"] + break + else: + result_queue.put(gpu_result) + + res += _shift_onset( + seq[init_idx : seq.index(tokenizer.eos_tok)], + idx * CHUNK_LEN_MS, + ) + print( + f"{process_pid}: Finished {idx+1}/{len(audio_segments)} audio segments" + ) + + if idx == len(audio_segments) - 1: + break + else: + seq = _truncate_seq(seq, CHUNK_LEN_MS, LEN_MS) + if len(seq) == 1: + print(f"{process_pid}: exiting early") + return res + + return res + + +def worker( + file_queue: Queue, + gpu_task_queue: Queue, + result_queue: Queue, + save_dir: str, + input_dir: str | None = None, +): + def _get_save_path(_file_path: str): + if input_dir is None: + save_path = os.path.join( + save_dir, + os.path.splitext(os.path.basename(file_path))[0] + ".mid", + ) + else: + input_rel_path = os.path.relpath(_file_path, input_dir) + save_path = os.path.join( + save_dir, os.path.splitext(input_rel_path)[0] + ".mid" + ) + if not os.path.exists(os.path.dirname(save_path)): + os.makedirs(os.path.dirname(save_path)) + + return save_path + + pid = multiprocessing.current_process().pid + tokenizer = AmtTokenizer() + files_processed = 0 + while not file_queue.empty(): + file_path = file_queue.get() + save_path = _get_save_path(file_path) + if os.path.exists(save_path): + print(f"{pid}: {save_path} already exists, overwriting") + + try: + res = process_file(file_path, gpu_task_queue, result_queue) + except Exception as e: + print(f"{pid}: Failed to transcribe {file_path}") + continue + + files_processed += 1 + + for tok in res[::-1]: + if type(tok) is tuple and tok[0] == "onset": + last_onset = tok[1] + break + + try: + mid_dict = tokenizer._detokenize_midi_dict( + tokenized_seq=res, len_ms=last_onset + ) + mid = mid_dict.to_midi() + mid.save(save_path) + except Exception as e: + print(f"{pid}: Failed to detokenize with error {e}") + else: + print(f"{pid}: Finished file {files_processed} - {file_path}") + print(f"{pid}: {file_queue.qsize()} file(s) remaining in queue") + + +def batch_transcribe( + file_paths: list, + model: AmtEncoderDecoder, + save_dir: str, + batch_size: int = 16, + gpu_id: int | None = None, + input_dir: str | None = None, +): + if gpu_id is not None: + os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id) + + model.to("cuda") + file_queue = Queue() + for file_path in file_paths: + file_queue.put(file_path) + + gpu_task_queue = Queue() + result_queue = Queue() + + gpu_manager_process = multiprocessing.Process( + target=gpu_manager, + args=(gpu_task_queue, result_queue, model, batch_size), + ) + gpu_manager_process.start() + + worker_processes = [ + multiprocessing.Process( + target=worker, + args=( + file_queue, + gpu_task_queue, + result_queue, + save_dir, + input_dir, + ), + ) + for _ in range(batch_size + 1) + ] + for p in worker_processes: + p.start() + + for p in worker_processes: + p.join() + + gpu_manager_process.join() diff --git a/amt/inference.py b/amt/inference.py deleted file mode 100644 index 31eca8f..0000000 --- a/amt/inference.py +++ /dev/null @@ -1,117 +0,0 @@ -import os -import random -import torch - -from tqdm import tqdm - -from amt.model import AmtEncoderDecoder -from amt.tokenizer import AmtTokenizer -from amt.data 
import get_features -from amt.config import load_config -from aria.data.midi import MidiDict - - -# TODO: Implement this with KV-caching, see the whisper inference file - -# Due to the autoregressive nature, a good inference algorithm should use some -# sort of branching to make sure that we don't miss notes, ect... Implement this -# next week -- Exciting problem (checkout other inference algos) - -# Implement maximum note len =5s -# Implement either beam search or decoding initial onset note on first - - -def greedy_sample( - model: AmtEncoderDecoder, - audio_path: str, - device: str, -): - LEN_MS = 30000 # This should not be hardcoded - MAX_SEQ_LEN = model.dims.n_text_ctx - - def _process_segment( - audio_seg: torch.tensor, - prefix: list, - model: AmtEncoderDecoder, - tokenizer: AmtTokenizer = AmtTokenizer(), - ): - start_idx = len(prefix) - pad_id = tokenizer.pad_id - eos_id = tokenizer.tok_to_id[tokenizer.eos_tok] - audio_seg = audio_seg.unsqueeze(0).to(device) - seq = tokenizer.encode(tokenizer.trunc_seq(prefix, MAX_SEQ_LEN)) - seq = torch.tensor(seq).unsqueeze(0).to(device) - audio_feature = model.embed_audio(mel=audio_seg) - - for idx in ( - pbar := tqdm( - range(start_idx, MAX_SEQ_LEN - 1), - total=MAX_SEQ_LEN - (start_idx + 1), - leave=False, - ) - ): - logits = model.logits( - audio_features=audio_feature, tokens=seq[:, :idx] - ) - next_tok_id = torch.argmax(logits[0, -1], dim=-1) - - seq[0, idx] = next_tok_id - if next_tok_id == pad_id or next_tok_id == eos_id: - break - - if idx == MAX_SEQ_LEN - 2: - print("WARNING: Ran out of context when generating sequence") - - seq = tokenizer.decode(seq[0, :]) - _, unclosed_notes = tokenizer._detokenize_midi_dict( - tokenized_seq=seq, - len_ms=LEN_MS, - return_unclosed_notes=True, - ) - - return seq, unclosed_notes - - audio_segments = [f for f, _ in get_features(audio_path=audio_path)] - print(f"{len(audio_segments)} audio segments to process...") - - model.to(device) - model.eval() - tokenizer = AmtTokenizer() - _unclosed_notes = [] - concat_seq = [tokenizer.bos_tok] - _onset_adj = 0 - for idx, _audio_seg in enumerate(audio_segments): - _seq = [("prev", p) for p in _unclosed_notes] + [tokenizer.bos_tok] - - _seq, _unclosed_notes = _process_segment( - audio_seg=_audio_seg, - prefix=_seq, - model=model, - tokenizer=tokenizer, - ) - random.shuffle(_unclosed_notes) - - # DEBUG - __midi_dict = tokenizer._detokenize_midi_dict(_seq, 30000) - __midi = __midi_dict.to_midi() - __midi.save(f"/weka/proj-aria/aria-amt/samples/res{idx}.mid") - - print(f"Done {idx + 1}/{len(audio_segments)}") - for tok in _seq: - if type(tok) is tuple and tok[0] == "onset": - _onset_orig = tok[1] - _onset_adj = _onset_orig + (idx * LEN_MS) - concat_seq.append(("onset", _onset_adj)) - elif type(tok) is tuple and tok[0] == "prev": - continue - elif tok is tokenizer.bos_tok: - continue - elif tok is tokenizer.pad_tok or tok is tokenizer.eos_tok: - break - else: - concat_seq.append(tok) - - return tokenizer._detokenize_midi_dict( - tokenized_seq=concat_seq, - len_ms=_onset_adj, - ) diff --git a/amt/model.py b/amt/model.py index 5e19d7f..9e8ccb2 100644 --- a/amt/model.py +++ b/amt/model.py @@ -64,22 +64,54 @@ def forward( ): q = self.query(x) - if kv_cache is None or xa is None or self.key not in kv_cache: - # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors; - # otherwise, perform key/value projections for self- or cross-attention as usual. 
- k = self.key(x if xa is None else xa) - v = self.value(x if xa is None else xa) + if kv_cache is None: + # Normal forward + if xa is not None: + # Cross att + k = self.key(xa) + v = self.value(xa) + else: + # Self att in encoder/decoder + k = self.key(x) + v = self.value(x) else: - # for cross-attention, calculate keys and values once and reuse in subsequent calls. - k = kv_cache[self.key] - v = kv_cache[self.value] - - # Old --- - # wv, qk = self.qkv_attention(q, k, v, mask) - # End --- + # Using cache + k_id = f"{id(self)}_k" + v_id = f"{id(self)}_v" + + if xa is not None: + # Cross att - calculate once and reuse + if kv_cache.get(k_id) is None: + # Not recorded yet, calculate and store + k = self.key(xa) + v = self.value(xa) + kv_cache[k_id] = k + kv_cache[v_id] = v + else: + # Already recorded, get + k = kv_cache[k_id] + v = kv_cache[v_id] + else: + # Decoder self att, append each time + if kv_cache.get(k_id) is None: + # Not recorded yet, calculate and store + k = self.key(x) + v = self.value(x) + kv_cache[k_id] = k + kv_cache[v_id] = v + else: + # Already recorded, get and append + k = torch.cat((kv_cache[k_id], self.key(x)), dim=1).detach() + v = torch.cat( + (kv_cache[v_id], self.value(x)), dim=1 + ).detach() + kv_cache[k_id] = k + kv_cache[v_id] = v + + # When using kv_cache for decoder self attention, we don't + # want to use a mask in the self attention calculation + mask = None - # New code ------ - debug = False # Reshape and transpose for attention calculation batch_size, target_seq_len, _ = q.shape batch_size, source_seq_len, _ = k.shape @@ -90,19 +122,11 @@ def forward( # (bz, L, nh, dh) -> (bz, nh, L, dh) q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v)) - if debug is True: - print(f"q shape: {q.shape}") - print(f"k shape: {k.shape}") - print(f"v shape: {v.shape}") - if mask is not None: - print(f"mask shape: {mask.shape}") - if mask == None: _is_causal = False else: _is_causal = True - qk = None # Only used during kv-caching? wv = F.scaled_dot_product_attention( query=q, key=k, @@ -114,31 +138,7 @@ def forward( wv = wv.transpose(1, 2) wv = wv.view(batch_size, target_seq_len, self.n_head * self.d_head) - if debug is True: - print(f"att_out shape: {wv.shape}") - if qk is not None: - print(f"att_weights shape: {qk.shape}") - - # End new code ------ - - return self.out(wv), qk - - def qkv_attention( - self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None - ): - n_batch, n_ctx, n_state = q.shape - scale = (n_state // self.n_head) ** -0.25 - q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale - k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale - v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) - - qk = q @ k - if mask is not None: - qk = qk + mask[:n_ctx, :n_ctx] - qk = qk.float() - - w = F.softmax(qk, dim=-1).to(q.dtype) - return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach() + return self.out(wv), None class ResidualAttentionBlock(nn.Module): @@ -279,26 +279,6 @@ def __init__(self, dims: ModelConfig): self.dims.n_text_head, self.dims.n_text_layer, ) - # use the last half among the decoder layers for time alignment by default; - # to use a specific set of heads, see `set_alignment_heads()` below. 
- # all_heads = torch.zeros( - # self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool - # ) - # all_heads[self.dims.n_text_layer // 2 :] = True - # self.register_buffer( - # "alignment_heads", all_heads.to_sparse(), persistent=False - # ) - - # def set_alignment_heads(self, dump: bytes): - # array = np.frombuffer( - # gzip.decompress(base64.b85decode(dump)), dtype=bool - # ).copy() - # mask = torch.from_numpy(array).reshape( - # self.dims.n_text_layer, self.dims.n_text_head - # ) - # self.register_buffer( - # "alignment_heads", mask.to_sparse(), persistent=False - # ) def embed_audio(self, mel: torch.Tensor): return self.encoder(mel) @@ -314,37 +294,5 @@ def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor: def device(self): return next(self.parameters()).device - def install_kv_cache_hooks(self, cache: Optional[dict] = None): - """ - The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value - tensors calculated for the previous positions. This method returns a dictionary that stores - all caches, and the necessary hooks for the key and value projection modules that save the - intermediate tensors to be reused during later calculations. - - Returns - ------- - cache : Dict[nn.Module, torch.Tensor] - A dictionary object mapping the key/value projection modules to its cache - hooks : List[RemovableHandle] - List of PyTorch RemovableHandle objects to stop the hooks to be called - """ - cache = {**cache} if cache is not None else {} - hooks = [] - - def save_to_cache(module, _, output): - if module not in cache or output.shape[1] > self.dims.n_text_ctx: - # save as-is, for the first token or cross attention - cache[module] = output - else: - cache[module] = torch.cat( - [cache[module], output], dim=1 - ).detach() - return cache[module] - - def install_hooks(layer: nn.Module): - if isinstance(layer, MultiHeadAttention): - hooks.append(layer.key.register_forward_hook(save_to_cache)) - hooks.append(layer.value.register_forward_hook(save_to_cache)) - - self.decoder.apply(install_hooks) - return cache, hooks + def get_empty_cache(self): + return {} diff --git a/amt/run.py b/amt/run.py index ec0eb9f..c068b98 100644 --- a/amt/run.py +++ b/amt/run.py @@ -3,6 +3,7 @@ import argparse import sys import os +import glob from csv import DictReader @@ -29,8 +30,15 @@ def _parse_transcribe_args(): argp = argparse.ArgumentParser(prog="amt transcribe") argp.add_argument("model_name", help="name of model config file") argp.add_argument("cp", help="checkpoint path") - argp.add_argument("-load_path", help="wav file load path", required=True) - argp.add_argument("-save_path", help="midi file save path", required=True) + argp.add_argument("-load_path", help="path to mp3/wav file", required=False) + argp.add_argument( + "-load_dir", help="dir containing mp3/wav files", required=False + ) + argp.add_argument("-save_dir", help="dir to save midi files", required=True) + argp.add_argument( + "-multi_gpu", help="use all GPUs", action="store_true", default=False + ) + argp.add_argument("-bs", help="batch size", type=int, default=16) return argp.parse_args(sys.argv[2:]) @@ -99,37 +107,101 @@ def build_maestro(args): def transcribe(args): + import torch from torch.cuda import is_available as cuda_is_available from amt.tokenizer import AmtTokenizer - from amt.inference import greedy_sample + from amt.infer import batch_transcribe from amt.config import load_model_config from amt.model import ModelConfig, AmtEncoderDecoder - from aria.data.midi import MidiDict 
from aria.utils import _load_weight - assert os.path.isfile(args.load_path), "audio file not found" + assert cuda_is_available(), "CUDA device not found" assert os.path.isfile(args.cp), "model checkpoint file not found" - - if not cuda_is_available(): - print("CUDA device is not available. Using CPU instead.") - device = "cpu" - else: - device = "cuda" - + assert args.load_path or args.load_dir, "must give either load path or dir" + if args.load_path: + assert os.path.isfile(args.load_path), "audio file not found" + trans_mode = "single" + if args.load_dir: + assert os.path.isdir(args.load_dir), "load directory doesn't exist" + trans_mode = "batch" + if not os.path.exists(args.save_dir): + os.makedirs(args.save_dir) + assert os.path.isdir(args.save_dir), "save dir doesn't exist" + + # Setup model tokenizer = AmtTokenizer() model_config = ModelConfig(**load_model_config(args.model_name)) model_config.set_vocab_size(tokenizer.vocab_size) model = AmtEncoderDecoder(model_config) - model_state = _load_weight(ckpt_path=args.cp, device=device) + model_state = _load_weight(ckpt_path=args.cp) + + # Fix keys in compiled model checkpoint + _model_state = {} + for k, v in model_state.items(): + if k.startswith("_orig_mod."): + _model_state[k[len("_orig_mod.") :]] = v + else: + _model_state[k] = v + model_state = _model_state model.load_state_dict(model_state) + torch.multiprocessing.set_start_method("spawn") + + if trans_mode == "batch": + found_wav = glob.glob( + os.path.join(args.load_dir, "**/*.wav"), recursive=True + ) + found_mp3 = glob.glob( + os.path.join(args.load_dir, "**/*.mp3"), recursive=True + ) + print(f"Found {len(found_mp3)} mp3 and {len(found_wav)} wav files") + file_paths = found_mp3 + found_wav + else: + file_paths = [args.load_path] + + if args.multi_gpu: + # Generate chunks + gpu_ids = [ + int(id) for id in os.getenv("CUDA_VISIBLE_DEVICES").split(",") + ] + num_gpus = len(gpu_ids) + print(f"Visible gpu_ids: {gpu_ids}") + + chunk_size = (len(file_paths) // num_gpus) + 1 + chunks = [ + file_paths[i : i + chunk_size] + for i in range(0, len(file_paths), chunk_size) + ] + print(f"Split {len(file_paths)} files into {len(chunks)} chunks") + + processes = [] + for idx, chunk in enumerate(chunks): + print( + f"Starting process on cuda-{idx}: {len(chunk)} files to process" + ) + process = torch.multiprocessing.Process( + target=batch_transcribe, + args=( + chunk, + model, + args.save_dir, + args.bs, + gpu_ids[idx], + args.load_dir, + ), + ) + process.start() + processes.append(process) - mid_dict = greedy_sample( - model=model, - audio_path=args.load_path, - device=device, - ) - mid = mid_dict.to_midi() - mid.save(args.save_path) + for process in processes: + process.join() + + else: + batch_transcribe( + file_paths=file_paths, + model=model, + save_dir=args.save_dir, + batch_size=args.bs, + ) def main():
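For reference, a minimal sketch of driving the new amt.infer.batch_transcribe API directly from Python (the model name, checkpoint path, and directories below are placeholders; the wiring mirrors the single-GPU branch of transcribe above):

    import glob
    import os
    import torch

    from amt.tokenizer import AmtTokenizer
    from amt.config import load_model_config
    from amt.model import ModelConfig, AmtEncoderDecoder
    from amt.infer import batch_transcribe
    from aria.utils import _load_weight

    tokenizer = AmtTokenizer()
    model_config = ModelConfig(**load_model_config("medium"))  # placeholder model name
    model_config.set_vocab_size(tokenizer.vocab_size)
    model = AmtEncoderDecoder(model_config)
    model.load_state_dict(_load_weight(ckpt_path="checkpoint.pt"))  # placeholder path

    torch.multiprocessing.set_start_method("spawn")
    batch_transcribe(
        file_paths=glob.glob(os.path.join("audio", "**/*.wav"), recursive=True),
        model=model,
        save_dir="transcriptions",
        batch_size=16,
    )

Checkpoints saved from a torch.compile'd model would additionally need the "_orig_mod." key-prefix fix applied in transcribe before load_state_dict.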