diff --git a/amt/inference/transcribe.py b/amt/inference/transcribe.py
index 3f32a0d..057649d 100644
--- a/amt/inference/transcribe.py
+++ b/amt/inference/transcribe.py
@@ -35,7 +35,7 @@ torch._inductor.config.fx_graph_cache = True
 
 MAX_SEQ_LEN = 4096
-MAX_BLOCK_LEN = 4096
+MAX_BLOCK_LEN = 2048
 LEN_MS = 30000
 STRIDE_FACTOR = 3
 CHUNK_LEN_MS = LEN_MS // STRIDE_FACTOR
@@ -294,7 +294,7 @@ def process_segments(
     # to make sure that a sequence of the correct format is returned. Right now
     # it messes things up somehow
     if not all(_idx <= idx for _idx in eos_idxs):
-        logger.warning("Context length overflow when transcribing segment")
+        logger.warning("Context length overflow when transcribing segment(s)")
 
     results = [
         tokenizer.decode(seq[_idx, : eos_idxs[_idx] + 1])
diff --git a/amt/run.py b/amt/run.py
index 32bbd01..f3993d2 100644
--- a/amt/run.py
+++ b/amt/run.py
@@ -467,7 +467,7 @@ def transcribe(
 
     files_to_process = []
     for audio_path in file_paths:
-        if segments_by_audio_file.get(audio_path, None):
+        if segments_by_audio_file.get(audio_path, None) is not None:
             file_info = {
                 "path": audio_path,
                 "segments": segments_by_audio_file[audio_path],