diff --git a/amt/inference/transcribe.py b/amt/inference/transcribe.py index 85b32d3..e78f717 100644 --- a/amt/inference/transcribe.py +++ b/amt/inference/transcribe.py @@ -783,17 +783,17 @@ def get_save_path( file_path: str, input_dir: str | None, save_dir: str, - idx: int | str = "", + idx_str: int | str = "", ): if input_dir is None: save_path = os.path.join( save_dir, - os.path.splitext(os.path.basename(file_path))[0] + f"{idx}.mid", + os.path.splitext(os.path.basename(file_path))[0] + f"{idx_str}.mid", ) else: input_rel_path = os.path.relpath(file_path, input_dir) save_path = os.path.join( - save_dir, os.path.splitext(input_rel_path)[0] + f"{idx}.mid" + save_dir, os.path.splitext(input_rel_path)[0] + f"{idx_str}.mid" ) if not os.path.isdir(os.path.dirname(save_path)): os.makedirs(os.path.dirname(save_path), exist_ok=True) @@ -810,7 +810,7 @@ def process_file( save_dir: str, input_dir: str, logger: logging.Logger, - segments: List[Tuple[int, int]] | None = None, + segments: List[Tuple[int, Tuple[int, int]]] | None = None, ): def _save_seq(_seq: List, _save_path: str): if os.path.exists(_save_path): @@ -852,12 +852,17 @@ def remove_failures_from_queue_(_queue: Queue, _pid: int): pid = threading.get_ident() if segments is None: - segments = [None] + # process_file and get_wav_segments will interpret segment=None as + # processing the entire file + segments = [(None, None)] if len(segments) == 0: logger.info(f"No segments to transcribe, skipping file: {file_path}") - for idx, segment in enumerate(segments): + for idx, segment in segments: + idx_str = f"_{idx}" if idx is not None else "" + save_path = get_save_path(file_path, input_dir, save_dir, idx_str) + try: seq = transcribe_file( file_path, @@ -876,15 +881,17 @@ def remove_failures_from_queue_(_queue: Queue, _pid: int): logger.info(f"Removed {res_rmv_cnt} from result queue") continue - logger.info(f"Finished file: {file_path} (segment: {idx})") + logger.info( + f"Finished file: {file_path} (segment: {idx if idx is not None else 'full'})" + ) if len(seq) < 500: - logger.info(f"Skipping seq - too short (segment {idx})") + logger.info( + f"Skipping seq - too short (segment {idx if idx is not None else 'full'})" + ) else: logger.debug( - f"Saving seq of length {len(seq)} from file: {file_path} (segment: {idx})" + f"Saving seq of length {len(seq)} from file: {file_path} (segment: {idx if idx is not None else 'full'})" ) - idx = f"_{idx}" if segment is not None else "" - save_path = get_save_path(file_path, input_dir, save_dir, idx) _save_seq(seq, save_path) logger.info(f"{file_queue.qsize()} file(s) remaining in queue") @@ -997,20 +1004,28 @@ def batch_transcribe( files_to_process, key=lambda x: os.path.getsize(x["path"]), reverse=True ) for file_to_process in files_to_process: - # Only add to file_queue if transcription MIDI file doesn't exist - if ( - os.path.isfile( + if "segments" in file_to_process: + # Process files with segments + unsaved_segments = [] + for idx, segment in enumerate(file_to_process["segments"]): + segment_save_path = get_save_path( + file_to_process["path"], + input_dir, + save_dir, + idx_str=f"_{idx}", + ) + if not os.path.isfile(segment_save_path): + unsaved_segments.append((idx, segment)) + + if unsaved_segments: + file_to_process["segments"] = unsaved_segments + file_queue.put(file_to_process) + else: + # Process files without segments (whole file) + if not os.path.isfile( get_save_path(file_to_process["path"], input_dir, save_dir) - ) - is False - ) and os.path.isfile( - get_save_path( - file_to_process["path"], input_dir, save_dir, idx="_0" - ) - ) is False: - file_queue.put(file_to_process) - elif len(files_to_process) == 1: - file_queue.put(file_to_process) + ): + file_queue.put(file_to_process) logger.info( f"Files to process: {file_queue.qsize()}/{len(files_to_process)}" @@ -1026,7 +1041,7 @@ def batch_transcribe( file_queue.qsize(), ) num_processes_per_worker = min( - 3 * (batch_size // num_workers), file_queue.qsize() // num_workers + 5 * (batch_size // num_workers), file_queue.qsize() // num_workers ) mp_manager = Manager()