
Commit

Minor fixes to shallow fusion
pkufool committed Oct 9, 2024
1 parent e4fa25a commit 6a0e41b
Showing 2 changed files with 134 additions and 43 deletions.
126 changes: 108 additions & 18 deletions egs/librispeech/ASR/zipformer/ctc_decode.py
@@ -111,6 +111,7 @@
import argparse
import logging
import math
import os
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
@@ -288,18 +289,18 @@ def get_parser():
)

parser.add_argument(
"--lm-type",
"--nnlm-type",
type=str,
default="rnn",
help="Type of NN lm",
choices=["rnn", "transformer"],
)

parser.add_argument(
"--lm-scale",
"--nnlm-scale",
type=float,
default=0.3,
help="""The scale of the neural network LM
default=0,
help="""The scale of the neural network LM, 0 means don't use nnlm shallow fussion.
Used only when `--use-shallow-fusion` is set to True.
""",
)
@@ -321,6 +322,47 @@ def get_parser():
""",
)

parser.add_argument(
"--backoff-id",
type=int,
default=500,
help="ID of the backoff symbol in the ngram LM",
)

parser.add_argument(
"--lodr-ngram",
type=str,
help="The path to the lodr ngram",
)

parser.add_argument(
"--lodr-lm-scale",
type=float,
default=0,
help="The scale of lodr ngram, should be less than 0. 0 means don't use lodr.",
)

parser.add_argument(
"--context-score",
type=float,
default=0,
help="""
The bonus score for each matched token of the context-biasing words/phrases.
0 means don't use contextual biasing.
Used only when --decoding-method is ctc-prefix-beam-search-shallow-fussion.
""",
)

parser.add_argument(
"--context-file",
type=str,
default="",
help="""
The path to the context-biasing list; one word/phrase per line.
Used only when --decoding-method is ctc-prefix-beam-search-shallow-fussion.
""",
)
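
Taken together, the new flags gate three independent fusion sources (neural LM, LODR n-gram, contextual biasing). A hedged sketch of a combined invocation, written as a Python argv list to match the file's language; every path and scale value below is a hypothetical placeholder, not a value from this commit:

```python
# Hypothetical invocation enabling all three fusion sources at once;
# paths, scales, and ids are placeholders.
example_argv = [
    "./zipformer/ctc_decode.py",
    "--decoding-method", "ctc-prefix-beam-search-shallow-fussion",
    "--nnlm-type", "rnn",
    "--nnlm-scale", "0.3",             # 0 would disable NN LM fusion
    "--lodr-ngram", "data/lm/2gram.fst.txt",
    "--lodr-lm-scale", "-0.16",        # negative: LODR subtracts the ngram score
    "--backoff-id", "500",
    "--context-score", "2.0",          # 0 would disable contextual biasing
    "--context-file", "data/context_phrases.txt",
]
```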

parser.add_argument(
"--skip-scoring",
type=str2bool,
@@ -358,7 +400,9 @@ def decode_one_batch(
batch: dict,
word_table: k2.SymbolTable,
G: Optional[k2.Fsa] = None,
LM: Optional[LmScorer] = None,
NNLM: Optional[LmScorer] = None,
LODR_lm: Optional[NgramLm] = None,
context_graph: Optional[ContextGraph] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
@@ -466,7 +510,10 @@ def decode_one_batch(
token_ids = ctc_prefix_beam_search_shallow_fussion(
ctc_output=ctc_output,
encoder_out_lens=encoder_out_lens,
LM=LM,
NNLM=NNLM,
LODR_lm=LODR_lm,
LODR_lm_scale=params.lodr_lm_scale,
context_graph=context_graph,
)
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
hyps = bpe_model.decode(token_ids)
@@ -649,7 +696,9 @@ def decode_dataset(
bpe_model: Optional[spm.SentencePieceProcessor],
word_table: k2.SymbolTable,
G: Optional[k2.Fsa] = None,
LM: Optional[LmScorer] = None,
NNLM: Optional[LmScorer] = None,
LODR_lm: Optional[NgramLm] = None,
context_graph: Optional[ContextGraph] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
@@ -700,7 +749,9 @@ def decode_dataset(
batch=batch,
word_table=word_table,
G=G,
LM=LM,
NNLM=NNLM,
LODR_lm=LODR_lm,
context_graph=context_graph,
)

for name, hyps in hyps_dict.items():
@@ -835,7 +886,12 @@ def main():
if "prefix-beam-search" in params.decoding_method:
params.suffix += f"_beam-{params.beam}"
if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion":
params.suffix += f"_lm-scale-{params.lm_scale}"
if params.nnlm_scale != 0:
params.suffix += f"_nnlm-scale-{params.nnlm_scale}"
if params.lodr_lm_scale != 0:
params.suffix += f"_lodr-scale-{params.lodr_lm_scale}"
if params.context_score != 0:
params.suffix += f"_context_score-{params.context_score}"

if params.use_averaged_model:
params.suffix += "_use-averaged-model"
@@ -947,17 +1003,49 @@ def main():
G = None

# only load the neural network LM if required
if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion":
LM = LmScorer(
lm_type=params.lm_type,
NNLM = None
if (
params.decoding_method == "ctc-prefix-beam-search-shallow-fussion"
and params.nnlm_scale != 0
):
NNLM = LmScorer(
lm_type=params.nnlm_type,
params=params,
device=device,
lm_scale=params.lm_scale,
lm_scale=params.nnlm_scale,
)
LM.to(device)
LM.eval()
else:
LM = None
NNLM.to(device)
NNLM.eval()

LODR_lm = None
if (
params.decoding_method == "ctc-prefix-beam-search-shallow-fussion"
and params.lodr_lm_scale != 0
):
assert os.path.exists(
params.lodr_ngram
), f"LODR ngram does not exists, given path : {params.lodr_ngram}"
logging.info(f"Loading LODR (token level lm): {params.lodr_ngram}")
LODR_lm = NgramLm(
params.lodr_ngram,
backoff_id=params.backoff_id,
is_binary=False,
)
logging.info(f"num states: {LODR_lm.lm.num_states}")

context_graph = None
if (
params.decoding_method == "ctc-prefix-beam-search-shallow-fussion"
and params.context_score != 0
):
assert os.path.exists(
params.context_file
), f"context_file does not exists, given path : {params.context_file}"
contexts = []
for line in open(params.context_file):
contexts.append(bpe_model.encode(line.strip()))
context_graph = ContextGraph(params.context_score)
context_graph.build(contexts)
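
A minimal standalone sketch of the same graph construction, plus the per-token stepping that `_step_worker` in icefall/decode.py performs; the module path, the `root` attribute, and all phrases/paths are assumptions for illustration:

```python
import sentencepiece as spm

from icefall.context_graph import ContextGraph  # module path assumed

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")  # placeholder BPE model path

phrases = ["HELLO WORLD", "SHALLOW FUSION"]  # hypothetical biasing phrases
context_graph = ContextGraph(2.0)  # bonus per matched token (--context-score)
context_graph.build([sp.encode(p) for p in phrases])

# Step the graph token by token, mirroring forward_one_step in the search:
state = context_graph.root  # initial state; attribute name assumed
for token in sp.encode("HELLO WORLD"):
    bonus, state, matched = context_graph.forward_one_step(state, token)
```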

logging.info("About to create model")
model = get_model(params)
@@ -1068,7 +1156,9 @@ def main():
bpe_model=bpe_model,
word_table=lexicon.word_table,
G=G,
LM=LM,
NNLM=NNLM,
LODR_lm=LODR_lm,
context_graph=context_graph,
)

save_asr_output(
51 changes: 26 additions & 25 deletions icefall/decode.py
@@ -1736,7 +1736,7 @@ def _step_worker(
B: HypothesisList,
beam: int = 4,
blank_id: int = 0,
lm_scale: float = 0,
nnlm_scale: float = 0,
LODR_lm_scale: float = 0,
context_graph: Optional[ContextGraph] = None,
) -> HypothesisList:
@@ -1815,14 +1815,16 @@ def _step_worker(
if update_prefix:
lm_score = hyp.lm_score
if hyp.lm_log_probs is not None:
lm_score += hyp.lm_log_probs[new_token] * lm_scale
lm_score = lm_score + hyp.lm_log_probs[new_token] * nnlm_scale
new_hyp.lm_log_probs = None

if context_graph is not None and hyp.context_state is not None:
context_score, new_context_state = context_graph.forward_one_step(
hyp.context_state, new_token
)
lm_score += context_score
(
context_score,
new_context_state,
matched_state,
) = context_graph.forward_one_step(hyp.context_state, new_token)
lm_score = lm_score + context_score
new_hyp.context_state = new_context_state

if hyp.LODR_state is not None:
@@ -1833,7 +1835,7 @@
state_cost.lm_score,
hyp.LODR_state.lm_score,
)
lm_score += LODR_lm_scale * current_ngram_score
lm_score = lm_score + LODR_lm_scale * current_ngram_score
new_hyp.LODR_state = state_cost

new_hyp.lm_score = lm_score
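
In summary, the three additions above accumulate one scalar per emitted token. An illustrative restatement with hypothetical names (not code from this commit):

```python
def fused_lm_increment(
    nnlm_log_prob: float,   # NN LM log-prob of the new token
    lodr_log_prob: float,   # token-level ngram log-prob delta (LODR)
    context_bonus: float,   # ContextGraph bonus, 0.0 when nothing matches
    nnlm_scale: float,
    lodr_lm_scale: float,   # expected to be negative
) -> float:
    # Mirrors the three lm_score updates in _step_worker above.
    return nnlm_scale * nnlm_log_prob + lodr_lm_scale * lodr_log_prob + context_bonus
```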
@@ -1944,7 +1946,7 @@ def ctc_prefix_beam_search_shallow_fussion(
blank_id: int = 0,
LODR_lm: Optional[NgramLm] = None,
LODR_lm_scale: Optional[float] = 0,
LM: Optional[LmScorer] = None,
NNLM: Optional[LmScorer] = None,
context_graph: Optional[ContextGraph] = None,
) -> List[List[int]]:
"""Implement prefix search decoding in "Connectionist Temporal Classification:
@@ -1981,17 +1983,16 @@
encoder_out_lens = encoder_out_lens.tolist()
device = ctc_output.device

lm_scale = 0
nnlm_scale = 0
init_scores = None
init_states = None

if LM is not None:
lm_scale = LM.lm_scale
sos_id = getattr(LM, "sos_id", 1)
if NNLM is not None:
nnlm_scale = NNLM.lm_scale
sos_id = getattr(NNLM, "sos_id", 1)
# get initial lm score and lm state by scoring the "sos" token
sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device)
lens = torch.tensor([1]).to(device)
init_scores, init_states = LM.score_token(sos_token, lens)
init_scores, init_states = NNLM.score_token(sos_token, lens)
init_scores, init_states = init_scores.cpu(), (
init_states[0].cpu(),
init_states[1].cpu(),
@@ -2016,16 +2017,16 @@
if j < encoder_out_lens[i]:
log_probs, indexes = topk_values[i][j], topk_indexes[i][j]
B[i] = _step_worker(
log_probs,
indexes,
B[i],
beam,
blank_id,
lm_scale=lm_scale,
log_probs=log_probs,
indexes=indexes,
B=B[i],
beam=beam,
blank_id=blank_id,
nnlm_scale=nnlm_scale,
LODR_lm_scale=LODR_lm_scale,
context_graph=context_graph,
)
if LM is None:
if NNLM is None:
continue
# update lm_log_probs
token_list = [] # a list of list
@@ -2035,7 +2036,7 @@
for batch_idx, hyps in enumerate(B):
for hyp in hyps:
if hyp.lm_log_probs is None: # those hyps that prefix changes
if LM.lm_type == "rnn":
if NNLM.lm_type == "rnn":
token_list.append([hyp.ys[-1]])
# store the LSTM states
hs.append(hyp.state[0])
@@ -2046,7 +2047,7 @@
indexes.append((batch_idx, hyp.key))
if len(token_list) != 0:
x_lens = torch.tensor([len(tokens) for tokens in token_list]).to(device)
if LM.lm_type == "rnn":
if NNLM.lm_type == "rnn":
tokens_to_score = (
torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1)
)
@@ -2065,13 +2066,13 @@
)
state = None

scores, lm_states = LM.score_token(tokens_to_score, x_lens, state)
scores, lm_states = NNLM.score_token(tokens_to_score, x_lens, state)
scores, lm_states = scores.cpu(), (lm_states[0].cpu(), lm_states[1].cpu())
assert scores.size(0) == len(indexes), (scores.size(0), len(indexes))
for i in range(scores.size(0)):
batch_idx, key = indexes[i]
B[batch_idx][key].lm_log_probs = scores[i]
if LM.lm_type == "rnn":
if NNLM.lm_type == "rnn":
state = (
lm_states[0][:, i, :].unsqueeze(1),
lm_states[1][:, i, :].unsqueeze(1),
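
Finally, a hedged sketch of calling the renamed entry point with the new keyword names; shapes are inferred from the surrounding code and the inputs are random, so this is illustrative rather than a test from this commit:

```python
import torch

from icefall.decode import ctc_prefix_beam_search_shallow_fussion

# (N, T, V) log-probs from the CTC head; random values for illustration.
ctc_output = torch.randn(2, 50, 500).log_softmax(dim=-1)
encoder_out_lens = torch.tensor([50, 42])

# With every fusion source left at None this reduces to plain CTC prefix
# beam search; pass NNLM / LODR_lm / context_graph to enable fusion.
token_ids = ctc_prefix_beam_search_shallow_fussion(
    ctc_output=ctc_output,
    encoder_out_lens=encoder_out_lens,
    NNLM=None,
    LODR_lm=None,
    context_graph=None,
)
print(token_ids)  # a list of token-id lists, one per utterance
```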
