From 60b96f5532e92cc48ab8d0dac0429f8127af6966 Mon Sep 17 00:00:00 2001 From: loubbrad Date: Fri, 22 Mar 2024 16:47:39 +0000 Subject: [PATCH] format --- amt/inference/transcribe.py | 32 +++++++++++++++++++++----------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/amt/inference/transcribe.py b/amt/inference/transcribe.py index 038fe3d..9109e6a 100644 --- a/amt/inference/transcribe.py +++ b/amt/inference/transcribe.py @@ -468,7 +468,7 @@ def transcribe_file( logger.info(f"Finished segment (eos_tok): {file_path}") else: # This might need it's logic adjusted - + seq = _truncate_seq( seq, CHUNK_LEN_MS, @@ -477,9 +477,7 @@ def transcribe_file( ) if len(seq) == 1: - logger.error( - f"Failed to transcribe segment: {file_path}" - ) + logger.error(f"Failed to transcribe segment: {file_path}") if len(concat_seq) > 500: res.append(concat_seq) else: @@ -594,17 +592,29 @@ def worker( while len(threads) < tasks_per_worker and not file_queue.empty(): logging.info("Starting worker") file_path = file_queue.get() - t = threading.Thread(target=process_file, args=(file_path, file_queue, gpu_task_queue, result_queue, tokenizer, save_dir, input_dir, logger)) + t = threading.Thread( + target=process_file, + args=( + file_path, + file_queue, + gpu_task_queue, + result_queue, + tokenizer, + save_dir, + input_dir, + logger, + ), + ) t.start() threads.append(t) - + threads = [t for t in threads if t.is_alive()] - - time.sleep(0.1) - + + time.sleep(0.1) + for t in threads: t.join() - + except Exception as e: logger.error(f"File worker failed with exception: {e}") finally: @@ -650,7 +660,7 @@ def batch_transcribe( result_queue, save_dir, input_dir, - # Wait for all threads to finish + # Wait for all threads to finish 4, ), )