From 5a79aad78b0c79f7eabdd40066175d0b2c29e854 Mon Sep 17 00:00:00 2001 From: NextFire Date: Sun, 29 Oct 2023 01:15:55 -0400 Subject: [PATCH] subtitles: rewrite make_ass --- notebook/yohane.ipynb | 4 +- yohane/__main__.py | 4 +- yohane/{audio_processing.py => audio.py} | 3 +- yohane/{text_processing.py => lyrics.py} | 8 +- yohane/subtitles.py | 162 +++++++++-------------- 5 files changed, 75 insertions(+), 106 deletions(-) rename yohane/{audio_processing.py => audio.py} (96%) rename yohane/{text_processing.py => lyrics.py} (94%) diff --git a/notebook/yohane.ipynb b/notebook/yohane.ipynb index 53faeb331..fe398b250 100644 --- a/notebook/yohane.ipynb +++ b/notebook/yohane.ipynb @@ -19,9 +19,9 @@ "import logging\n", "from pathlib import Path\n", "\n", - "from yohane.audio_processing import compute_alignments, prepare_audio\n", + "from yohane.audio import compute_alignments, prepare_audio\n", + "from yohane.lyrics import Lyrics\n", "from yohane.subtitles import make_ass\n", - "from yohane.text_processing import Lyrics\n", "\n", "logging.basicConfig(level=\"INFO\")\n" ] diff --git a/yohane/__main__.py b/yohane/__main__.py index f715d177f..492ffbf57 100644 --- a/yohane/__main__.py +++ b/yohane/__main__.py @@ -5,9 +5,9 @@ import click -from yohane.audio_processing import compute_alignments, prepare_audio +from yohane.audio import compute_alignments, prepare_audio +from yohane.lyrics import Lyrics from yohane.subtitles import make_ass -from yohane.text_processing import Lyrics logger = logging.getLogger(__name__) diff --git a/yohane/audio_processing.py b/yohane/audio.py similarity index 96% rename from yohane/audio_processing.py rename to yohane/audio.py index e789c4570..4a4211e2b 100644 --- a/yohane/audio_processing.py +++ b/yohane/audio.py @@ -61,7 +61,8 @@ def compute_alignments(waveform: torch.Tensor, transcript: list[str]): with torch.inference_mode(): emission, _ = model(waveform.to(device)) emission = cast(torch.Tensor, emission) - tokens = cast(list[list[int]], tokenizer(transcript)) + tokens = tokenizer(transcript) + tokens = cast(list[list[int]], tokens) token_spans = aligner(emission[0], tokens) return emission, token_spans diff --git a/yohane/text_processing.py b/yohane/lyrics.py similarity index 94% rename from yohane/text_processing.py rename to yohane/lyrics.py index 619a71864..2616ab358 100644 --- a/yohane/text_processing.py +++ b/yohane/lyrics.py @@ -5,7 +5,7 @@ @dataclass -class Text: +class _Text: raw: str @cached_property @@ -18,21 +18,21 @@ def transcript(self): @dataclass -class Lyrics(Text): +class Lyrics(_Text): @cached_property def lines(self): return [Line(line) for line in filter(None, self.raw.splitlines())] @dataclass -class Line(Text): +class Line(_Text): @cached_property def words(self): return [Word(word) for word in filter(None, self.transcript)] @dataclass -class Word(Text): +class Word(_Text): @cached_property def syllables(self): return auto_split(self.normalized) diff --git a/yohane/subtitles.py b/yohane/subtitles.py index 9e715556d..9baf249df 100644 --- a/yohane/subtitles.py +++ b/yohane/subtitles.py @@ -1,10 +1,9 @@ -from typing import cast from pysubs2 import SSAEvent, SSAFile from torch import Tensor from torchaudio.functional import TokenSpan -from yohane.audio_processing import bundle -from yohane.text_processing import Lyrics +from yohane.audio import bundle +from yohane.lyrics import Lyrics def make_ass( @@ -19,101 +18,70 @@ def make_ass( sample_rate = bundle.sample_rate tokenizer = bundle.get_tokenizer() - # init subs and event subs = SSAFile() - event = SSAEvent(-1) - k_cumul = 0 - - # init iterators - lines_iter = iter(lyrics.lines) - curr_line = next(lines_iter) - words_iter = iter(curr_line.words) - curr_word = next(words_iter) - syllables_iter = iter(curr_word.syllables) - - # iterate over all aligned tokens - # 1 span = 1 word - for i, spans in enumerate(token_spans): - j = 0 - while True: - curr_syllable = next(syllables_iter) - syllable_tokens = cast(list[list[int]], tokenizer([curr_syllable])) - nb_tokens = len(syllable_tokens[0]) - - # start and end time of syllable - x0 = ratio * spans[j].start - x1 = ratio * spans[j + nb_tokens - 1].end - t_start = x0 / sample_rate # s - t_end = x1 / sample_rate # s - - if event.start == -1: - # set event start time if this is a new one - event.start = int(t_start * 1000) # ms - elif j == 0: + spans_ids = 0 + + for line in lyrics.lines: + event: SSAEvent | None = None + k_cumul = 0 + + for word in line.words: + spans = token_spans[spans_ids] + span_idx = 0 + + for syllable in word.syllables: + syllable_tokens = tokenizer([syllable]) + nb_tokens = len(syllable_tokens[0]) + + print(syllable, span_idx, span_idx + nb_tokens - 1, len(spans)) + + # start and end time of syllable + x0 = ratio * spans[span_idx].start + x1 = ratio * spans[span_idx + nb_tokens - 1].end + t_start = x0 / sample_rate # s + t_end = x1 / sample_rate # s + + if event is None: + # new line, new event + event = SSAEvent(int(t_start * 1000)) # ms + elif span_idx == 0: + # k tag logic: + # new word starting on the same line + # add a space and adjust timing + space_dt = round(t_start * 100 - event.start / 10 - k_cumul) # cs + event.text += rf"{{\k{space_dt}}} " + k_cumul += space_dt # cs + # k tag logic: - # if this is a new word, add a space and adjust timing - space_duration = round(t_start * 100 - event.start / 10 - k_cumul) # cs - event.text += rf"{{\k{space_duration}}} " - k_cumul += space_duration # cs - - # k tag logic: - # snap the token end time to the next token start time - try: - next_x0 = ratio * spans[j + nb_tokens].start - next_t_start = next_x0 / sample_rate - adjusted_t_end = next_t_start - except IndexError: - adjusted_t_end = t_end - - k_duration = round((adjusted_t_end - t_start) * 100) # cs - k_cumul += k_duration # cs - - event.text += rf"{{\k{k_duration}}}{curr_syllable}" - - # increment event end time - event.end = int(adjusted_t_end * 1000) # ms - - j += nb_tokens - if j >= len(spans): - break - - # if that was the last token, don't enter the iter logic - if i == len(token_spans) - 1: - break - - try: - # add a space, advance to next word - curr_word = next(words_iter) - syllables_iter = iter(curr_word.syllables) - except StopIteration: - try: - # no more words in line: - # save the event in subs, reset it and advance to next line - - # save the raw line in a comment - comment = SSAEvent( - event.start, event.end, curr_line.raw, type="Comment" - ) - subs.append(comment) - - # save the timed event - event.text = event.text.strip() - subs.append(event) - event = SSAEvent(-1) - k_cumul = 0 - - # iterators logic - curr_line = next(lines_iter) - words_iter = iter(curr_line.words) - curr_word = next(words_iter) - syllables_iter = iter(curr_word.syllables) - except StopIteration as e: - raise RuntimeError("should not happen") from e - - # save last line - comment = SSAEvent(event.start, event.end, curr_line.raw, type="Comment") - subs.append(comment) - event.text = event.text.strip() - subs.append(event) + # snap the token end time to the next token start time + try: + next_x0 = ratio * spans[span_idx + nb_tokens].start + next_t_start = next_x0 / sample_rate # s + adjusted_t_end = next_t_start # s + except IndexError: + adjusted_t_end = t_end # s + + k_duration = round((adjusted_t_end - t_start) * 100) # cs + k_cumul += k_duration # cs + + event.text += rf"{{\k{k_duration}}}{syllable}" + + # increment event end time + event.end = int(adjusted_t_end * 1000) # ms + + # consumed tokens + span_idx += nb_tokens + + # consumed spans + spans_ids += 1 + + if event is not None: + # save the raw line in a comment + comment = SSAEvent(event.start, event.end, line.raw, type="Comment") + subs.append(comment) + + # save the timed event + event.text = event.text.strip() + subs.append(event) return subs