Skip to content

Commit

Permalink
subtitles: rewrite make_ass
Browse files Browse the repository at this point in the history
  • Loading branch information
NextFire committed Oct 29, 2023
1 parent cda0d42 commit 5a79aad
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 106 deletions.
4 changes: 2 additions & 2 deletions notebook/yohane.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
"import logging\n",
"from pathlib import Path\n",
"\n",
"from yohane.audio_processing import compute_alignments, prepare_audio\n",
"from yohane.audio import compute_alignments, prepare_audio\n",
"from yohane.lyrics import Lyrics\n",
"from yohane.subtitles import make_ass\n",
"from yohane.text_processing import Lyrics\n",
"\n",
"logging.basicConfig(level=\"INFO\")\n"
]
Expand Down
4 changes: 2 additions & 2 deletions yohane/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

import click

from yohane.audio_processing import compute_alignments, prepare_audio
from yohane.audio import compute_alignments, prepare_audio
from yohane.lyrics import Lyrics
from yohane.subtitles import make_ass
from yohane.text_processing import Lyrics

logger = logging.getLogger(__name__)

Expand Down
3 changes: 2 additions & 1 deletion yohane/audio_processing.py → yohane/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ def compute_alignments(waveform: torch.Tensor, transcript: list[str]):
with torch.inference_mode():
emission, _ = model(waveform.to(device))
emission = cast(torch.Tensor, emission)
tokens = cast(list[list[int]], tokenizer(transcript))
tokens = tokenizer(transcript)
tokens = cast(list[list[int]], tokens)
token_spans = aligner(emission[0], tokens)

return emission, token_spans
8 changes: 4 additions & 4 deletions yohane/text_processing.py → yohane/lyrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


@dataclass
class Text:
class _Text:
raw: str

@cached_property
Expand All @@ -18,21 +18,21 @@ def transcript(self):


@dataclass
class Lyrics(_Text):
    """The full lyrics text, decomposed into its non-empty lines.

    Defect fixed: the pasted diff left the pre-rename class header
    (`class Lyrics(Text):`) duplicated above the new one and flattened
    the indentation; this is the reconstructed post-commit definition.
    """

    @cached_property
    def lines(self):
        """One `Line` per non-empty physical line of the raw lyrics text."""
        # filter(None, ...) drops empty strings produced by blank lines.
        return [Line(line) for line in filter(None, self.raw.splitlines())]


@dataclass
class Line(_Text):
    """A single lyrics line, decomposed into its words.

    Defect fixed: the pasted diff left the pre-rename class header
    (`class Line(Text):`) duplicated above the new one and flattened
    the indentation; this is the reconstructed post-commit definition.
    """

    @cached_property
    def words(self):
        """One `Word` per non-empty entry of this line's transcript."""
        # `transcript` is inherited from `_Text` (defined in a collapsed
        # diff region) — presumably a sequence of word strings; the
        # filter(None, ...) drops empty entries. TODO(review): confirm
        # against the full `_Text` definition.
        return [Word(word) for word in filter(None, self.transcript)]


@dataclass
class Word(_Text):
    """A single word, decomposed into syllables for karaoke timing.

    Defect fixed: the pasted diff left the pre-rename class header
    (`class Word(Text):`) duplicated above the new one and flattened
    the indentation; this is the reconstructed post-commit definition.
    """

    @cached_property
    def syllables(self):
        """Syllables of this word, as split by `auto_split`."""
        # `auto_split` and the `normalized` property come from parts of
        # the module hidden in the collapsed diff context — NOTE(review):
        # verify their contracts against the full source.
        return auto_split(self.normalized)
Expand Down
162 changes: 65 additions & 97 deletions yohane/subtitles.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from typing import cast
from pysubs2 import SSAEvent, SSAFile
from torch import Tensor
from torchaudio.functional import TokenSpan

from yohane.audio_processing import bundle
from yohane.text_processing import Lyrics
from yohane.audio import bundle
from yohane.lyrics import Lyrics


def make_ass(
Expand All @@ -19,101 +18,70 @@ def make_ass(
sample_rate = bundle.sample_rate
tokenizer = bundle.get_tokenizer()

# init subs and event
subs = SSAFile()
event = SSAEvent(-1)
k_cumul = 0

# init iterators
lines_iter = iter(lyrics.lines)
curr_line = next(lines_iter)
words_iter = iter(curr_line.words)
curr_word = next(words_iter)
syllables_iter = iter(curr_word.syllables)

# iterate over all aligned tokens
# 1 span = 1 word
for i, spans in enumerate(token_spans):
j = 0
while True:
curr_syllable = next(syllables_iter)
syllable_tokens = cast(list[list[int]], tokenizer([curr_syllable]))
nb_tokens = len(syllable_tokens[0])

# start and end time of syllable
x0 = ratio * spans[j].start
x1 = ratio * spans[j + nb_tokens - 1].end
t_start = x0 / sample_rate # s
t_end = x1 / sample_rate # s

if event.start == -1:
# set event start time if this is a new one
event.start = int(t_start * 1000) # ms
elif j == 0:
spans_ids = 0

for line in lyrics.lines:
event: SSAEvent | None = None
k_cumul = 0

for word in line.words:
spans = token_spans[spans_ids]
span_idx = 0

for syllable in word.syllables:
syllable_tokens = tokenizer([syllable])
nb_tokens = len(syllable_tokens[0])

print(syllable, span_idx, span_idx + nb_tokens - 1, len(spans))

# start and end time of syllable
x0 = ratio * spans[span_idx].start
x1 = ratio * spans[span_idx + nb_tokens - 1].end
t_start = x0 / sample_rate # s
t_end = x1 / sample_rate # s

if event is None:
# new line, new event
event = SSAEvent(int(t_start * 1000)) # ms
elif span_idx == 0:
# k tag logic:
# new word starting on the same line
# add a space and adjust timing
space_dt = round(t_start * 100 - event.start / 10 - k_cumul) # cs
event.text += rf"{{\k{space_dt}}} "
k_cumul += space_dt # cs

# k tag logic:
# if this is a new word, add a space and adjust timing
space_duration = round(t_start * 100 - event.start / 10 - k_cumul) # cs
event.text += rf"{{\k{space_duration}}} "
k_cumul += space_duration # cs

# k tag logic:
# snap the token end time to the next token start time
try:
next_x0 = ratio * spans[j + nb_tokens].start
next_t_start = next_x0 / sample_rate
adjusted_t_end = next_t_start
except IndexError:
adjusted_t_end = t_end

k_duration = round((adjusted_t_end - t_start) * 100) # cs
k_cumul += k_duration # cs

event.text += rf"{{\k{k_duration}}}{curr_syllable}"

# increment event end time
event.end = int(adjusted_t_end * 1000) # ms

j += nb_tokens
if j >= len(spans):
break

# if that was the last token, don't enter the iter logic
if i == len(token_spans) - 1:
break

try:
# add a space, advance to next word
curr_word = next(words_iter)
syllables_iter = iter(curr_word.syllables)
except StopIteration:
try:
# no more words in line:
# save the event in subs, reset it and advance to next line

# save the raw line in a comment
comment = SSAEvent(
event.start, event.end, curr_line.raw, type="Comment"
)
subs.append(comment)

# save the timed event
event.text = event.text.strip()
subs.append(event)
event = SSAEvent(-1)
k_cumul = 0

# iterators logic
curr_line = next(lines_iter)
words_iter = iter(curr_line.words)
curr_word = next(words_iter)
syllables_iter = iter(curr_word.syllables)
except StopIteration as e:
raise RuntimeError("should not happen") from e

# save last line
comment = SSAEvent(event.start, event.end, curr_line.raw, type="Comment")
subs.append(comment)
event.text = event.text.strip()
subs.append(event)
# snap the token end time to the next token start time
try:
next_x0 = ratio * spans[span_idx + nb_tokens].start
next_t_start = next_x0 / sample_rate # s
adjusted_t_end = next_t_start # s
except IndexError:
adjusted_t_end = t_end # s

k_duration = round((adjusted_t_end - t_start) * 100) # cs
k_cumul += k_duration # cs

event.text += rf"{{\k{k_duration}}}{syllable}"

# increment event end time
event.end = int(adjusted_t_end * 1000) # ms

# consumed tokens
span_idx += nb_tokens

# consumed spans
spans_ids += 1

if event is not None:
# save the raw line in a comment
comment = SSAEvent(event.start, event.end, line.raw, type="Comment")
subs.append(comment)

# save the timed event
event.text = event.text.strip()
subs.append(event)

return subs

0 comments on commit 5a79aad

Please sign in to comment.