diff --git a/egs/aishell/ASR/zipformer/attention_decoder.py b/egs/aishell/ASR/zipformer/attention_decoder.py new file mode 120000 index 0000000000..384e1b95ea --- /dev/null +++ b/egs/aishell/ASR/zipformer/attention_decoder.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/attention_decoder.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/ctc_decode.py b/egs/aishell/ASR/zipformer/ctc_decode.py new file mode 100755 index 0000000000..01df090ab1 --- /dev/null +++ b/egs/aishell/ASR/zipformer/ctc_decode.py @@ -0,0 +1,864 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Liyong Guo, +# Quandong Wang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +(1) ctc-greedy-search +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --max-duration 600 \ + --decoding-method ctc-greedy-search + +(2) ctc-decoding +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --max-duration 600 \ + --decoding-method ctc-decoding + +(3) 1best +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --max-duration 600 \ + --hlg-scale 0.6 \ + --decoding-method 1best + +(4) nbest +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --max-duration 600 \ + --hlg-scale 0.6 \ + --decoding-method nbest + +(5) nbest-rescoring +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --max-duration 600 \ + --hlg-scale 0.6 \ + --nbest-scale 1.0 \ + --lm-dir data/lm \ + --decoding-method nbest-rescoring + +(6) whole-lattice-rescoring +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --max-duration 600 \ + --hlg-scale 0.6 \ + --nbest-scale 1.0 \ + --lm-dir data/lm \ + --decoding-method whole-lattice-rescoring + +(7) attention-decoder-rescoring-no-ngram +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --use-attention-decoder 1 \ + --max-duration 100 \ + --decoding-method attention-decoder-rescoring-no-ngram + +(8) attention-decoder-rescoring-with-ngram +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --use-attention-decoder 1 \ + --max-duration 100 \ + --hlg-scale 0.6 \ + --nbest-scale 1.0 \ + --lm-dir data/lm \ + --decoding-method attention-decoder-rescoring-with-ngram +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import torch +import torch.nn as nn +from asr_datamodule import AishellAsrDataModule +from lhotse import set_caching_enabled +from lhotse.cut import Cut +from train import add_model_arguments, get_model, 
get_params + +from icefall.context_graph import ContextGraph, ContextState +from icefall.ngram_lm import NgramLm, NgramLmStateCost +from icefall.lm_wrapper import LmScorer + + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.decode import ( + ctc_greedy_search, + ctc_prefix_beam_search, + ctc_prefix_beam_search_attention_decoder_rescoring, + ctc_prefix_beam_search_shallow_fussion, + get_lattice, + one_best_decoding, + rescore_with_attention_decoder_no_ngram, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + get_texts, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_char", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="ctc-decoding", + help="""Decoding method. + Supported values are: + - (1) ctc-greedy-search. Use CTC greedy search. It uses a sentence piece + model, i.e., lang_dir/bpe.model, to convert word pieces to words. + It needs neither a lexicon nor an n-gram LM. + - (2) ctc-decoding. Use CTC decoding. It uses a sentence piece + model, i.e., lang_dir/bpe.model, to convert word pieces to words. + It needs neither a lexicon nor an n-gram LM. + - (3) attention-decoder-rescoring-no-ngram. Extract n paths from the decoding + lattice, rescore them with the attention decoder. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help="""Number of paths for n-best based decoding method. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, and nbest-oracle + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=1.0, + help="""The scale to be applied to `lattice.scores`. + It's needed if you use any kinds of n-best based rescoring. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, and nbest-oracle + A smaller value results in more unique paths. 
+ """, + ) + + parser.add_argument( + "--skip-scoring", + type=str2bool, + default=False, + help="""Skip scoring, but still save the ASR output (for eval sets).""", + ) + + parser.add_argument( + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. + """, + ) + + add_model_arguments(parser) + + return parser + + +def get_decoding_params() -> AttributeDict: + """Parameters for decoding.""" + params = AttributeDict( + { + "frame_shift_ms": 10, + "search_beam": 20, # for k2 fsa composition + "output_beam": 8, # for k2 fsa composition + "beam": 4, # for prefix-beam-search + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + } + ) + return params + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + lexicon: Lexicon, + batch: dict, + H: Optional[k2.Fsa], + LM: Optional[LmScorer] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + - key: It indicates the setting used for decoding. For example, + if no rescoring is used, the key is the string `no_rescore`. + If LM rescoring is used, the key is the string `lm_scale_xxx`, + where `xxx` is the value of `lm_scale`. An example key is + `lm_scale_0.7` + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + + Args: + params: + It's the return value of :func:`get_params`. + + - params.decoding_method is "1best", it uses 1best decoding without LM rescoring. + - params.decoding_method is "nbest", it uses nbest decoding without LM rescoring. + - params.decoding_method is "nbest-rescoring", it uses nbest LM rescoring. + - params.decoding_method is "whole-lattice-rescoring", it uses whole lattice LM + rescoring. + + model: + The neural model. + HLG: + The decoding graph. Used only when params.decoding_method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.decoding_method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.decoding_method is ctc-decoding. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + G: + An LM. It is not None when params.decoding_method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return the decoding result. See above description for the format of + the returned dict. Note: If it decodes to nothing, then return None. + """ + # TODO + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. 
+ pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) + ctc_output = model.ctc_output(encoder_out) # (N, T, C) + + batch_size = encoder_out.size(0) + + if params.decoding_method == "ctc-greedy-search": + hyp_tokens = ctc_greedy_search(ctc_output, encoder_out_lens) + hyps = [] + for i in range(batch_size): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + key = "ctc-greedy-search" + return {key: hyps} + + if params.decoding_method == "prefix-beam-search": + hyp_tokens = ctc_prefix_beam_search( + ctc_output=ctc_output, encoder_out_lens=encoder_out_lens + ) + hyps = [] + for i in range(batch_size): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + key = "prefix-beam-search" + return {key: hyps} + + if params.decoding_method == "ctc-prefix-beam-search-attention-decoder-rescoring": + best_path_dict = ctc_prefix_beam_search_attention_decoder_rescoring( + ctc_output=ctc_output, + attention_decoder=model.attention_decoder, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + ans = dict() + for a_scale_str, hyp_tokens in best_path_dict.items(): + hyps = [] + for i in range(batch_size): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + ans[a_scale_str] = hyps + return ans + + if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion": + hyp_tokens = ctc_prefix_beam_search_shallow_fussion( + ctc_output=ctc_output, + encoder_out_lens=encoder_out_lens, + LM=LM, + ) + hyps = [] + for i in range(batch_size): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + key = "prefix-beam-search-shallow-fussion" + return {key: hyps} + + supervision_segments = torch.stack( + ( + supervisions["sequence_idx"], + torch.div( + supervisions["start_frame"], + params.subsampling_factor, + rounding_mode="floor", + ), + torch.div( + supervisions["num_frames"], + params.subsampling_factor, + rounding_mode="floor", + ), + ), + 1, + ).to(torch.int32) + + assert H is not None + decoding_graph = H + lattice = get_lattice( + nnet_output=ctc_output, + decoding_graph=decoding_graph, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.decoding_method == "ctc-decoding": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + # Note: `best_path.aux_labels` contains token IDs, not word IDs + # since we are using H, not HLG here. 
+ # + # token_ids is a lit-of-list of IDs + hyp_tokens = get_texts(best_path) + hyps = [] + for i in range(encoder_out.size(0)): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + key = "ctc-decoding" + return {key: hyps} # note: returns words + + if params.decoding_method == "attention-decoder-rescoring-no-ngram": + best_path_dict = rescore_with_attention_decoder_no_ngram( + lattice=lattice, + num_paths=params.num_paths, + attention_decoder=model.attention_decoder, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + nbest_scale=params.nbest_scale, + ) + ans = dict() + for a_scale_str, best_path in best_path_dict.items(): + # token_ids is a lit-of-list of IDs + hyps = [] + hyp_tokens = get_texts(best_path) + for i in range(encoder_out.size(0)): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + ans[a_scale_str] = hyps + return ans + else: + assert False, f"Unsupported decoding method: {params.decoding_method}" + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + lexicon: Lexicon, + H: Optional[k2.Fsa] = None, + LM: Optional[LmScorer] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + HLG: + The decoding graph. Used only when params.decoding_method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.decoding_method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.decoding_method is ctc-decoding. + word_table: + It is the word symbol table. + G: + An LM. It is not None when params.decoding_method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return a dict, whose key may be "no-rescore" if no LM rescoring + is used, or it may be "lm_scale_0.7" if LM rescoring is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + texts = [list("".join(text.split())) for text in texts] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + batch=batch, + lexicon=lexicon, + H=H, + LM=LM, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + this_batch.append((cut_id, ref_text, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_asr_output( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + """ + Save text produced by ASR. 
+ """ + for key, results in results_dict.items(): + + recogs_filename = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + + results = sorted(results) + store_transcripts(filename=recogs_filename, texts=results, char_level=True) + + logging.info(f"The transcripts are stored in {recogs_filename}") + + +def save_wer_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + if params.decoding_method == "attention-decoder-rescoring-no-ngram": + # Set it to False since there are too many logs. + enable_log = False + else: + enable_log = True + + test_set_wers = dict() + for key, results in results_dict.items(): + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w", encoding="utf8") as fd: + wer = write_error_stats( + fd, + f"{test_set_name}-{key}", + results, + enable_log=enable_log, + compute_CER=True, + ) + test_set_wers[key] = wer + + logging.info(f"Wrote detailed error stats to {errs_filename}") + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + + wer_filename = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + + with open(wer_filename, "w", encoding="utf8") as fd: + print("settings\tWER", file=fd) + for key, val in test_set_wers: + print(f"{key}\t{val}", file=fd) + + s = f"\nFor {test_set_name}, WER of different settings are:\n" + note = f"\tbest for {test_set_name}" + for key, val in test_set_wers: + s += f"{key}\t{val}{note}\n" + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + AishellAsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + + params = get_params() + # add decoding params + params.update(get_decoding_params()) + params.update(vars(args)) + + # enable AudioCache + set_caching_enabled(True) # lhotse + + assert params.decoding_method in ( + "ctc-greedy-search", + "prefix-beam-search", + "ctc-prefix-beam-search-attention-decoder-rescoring", + "ctc-prefix-beam-search-shallow-fussion", + "ctc-decoding", + "attention-decoder-rescoring-no-ngram", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." 
+ params.suffix += f"_chunk-{params.chunk_size}" + params.suffix += f"_left-context-{params.left_context_frames}" + + if "prefix-beam-search" in params.decoding_method: + params.suffix += f"_beam-{params.beam}" + if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion": + params.suffix += f"_lm-scale-{params.lm_scale}" + + if params.use_averaged_model: + params.suffix += "_use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + logging.info(params) + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + params.vocab_size = num_classes + # and are defined in local/train_bpe_model.py + params.blank_id = 0 + params.eos_id = 1 + params.sos_id = 1 + + if params.decoding_method in [ + "ctc-decoding", + "attention-decoder-rescoring-no-ngram", + ]: + H = k2.ctc_topo( + max_token=max_token_id, + modified=True, + device=device, + ) + else: + H = None + + # only load the neural network LM if required + if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion": + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, + ) + LM.to(device) + LM.eval() + else: + LM = None + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch 
range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + aishell = AishellAsrDataModule(args) + + def remove_short_utt(c: Cut): + T = ((c.num_frames - 7) // 2 + 1) // 2 + if T <= 0: + logging.warning( + f"Exclude cut with ID {c.id} from decoding, num_frames : {c.num_frames}." + ) + return T > 0 + + dev_cuts = aishell.valid_cuts() + dev_cuts = dev_cuts.filter(remove_short_utt) + dev_dl = aishell.valid_dataloaders(dev_cuts) + + test_cuts = aishell.test_cuts() + test_cuts = test_cuts.filter(remove_short_utt) + test_dl = aishell.test_dataloaders(test_cuts) + + test_sets = ["dev", "test"] + test_dls = [dev_dl, test_dl] + + for test_set, test_dl in zip(test_sets, test_dls): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + H=H, + lexicon=lexicon, + LM=LM, + ) + + save_asr_output( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + if not params.skip_scoring: + save_wer_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/aishell/ASR/zipformer/label_smoothing.py b/egs/aishell/ASR/zipformer/label_smoothing.py new file mode 120000 index 0000000000..175c633cc7 --- /dev/null +++ b/egs/aishell/ASR/zipformer/label_smoothing.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/label_smoothing.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/spec_augment.py b/egs/aishell/ASR/zipformer/spec_augment.py new file mode 120000 index 0000000000..d00c7c9ddc --- /dev/null +++ b/egs/aishell/ASR/zipformer/spec_augment.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/spec_augment.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/train.py b/egs/aishell/ASR/zipformer/train.py index cd253c5970..e01025cb21 100755 --- a/egs/aishell/ASR/zipformer/train.py +++ b/egs/aishell/ASR/zipformer/train.py @@ -61,6 +61,7 @@ import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import AishellAsrDataModule +from attention_decoder import AttentionDecoderModel from decoder import Decoder from joiner import Joiner from lhotse.cut import Cut @@ -96,6 +97,7 @@ setup_logger, str2bool, ) +from spec_augment import SpecAugment LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -216,6 +218,41 @@ def add_model_arguments(parser: argparse.ArgumentParser): """, ) + parser.add_argument( + "--attention-decoder-dim", + type=int, + default=512, + help="""Dimension used in the attention decoder""", + ) + + parser.add_argument( + "--attention-decoder-num-layers", + type=int, + default=6, + help="""Number of transformer layers used in attention decoder""", + ) + + parser.add_argument( + "--attention-decoder-attention-dim", + type=int, + default=512, + help="""Attention dimension used in attention decoder""", + ) + + parser.add_argument( + "--attention-decoder-num-heads", + type=int, + default=8, + help="""Number of attention heads used in attention decoder""", + ) + + parser.add_argument( + "--attention-decoder-feedforward-dim", + type=int, + default=2048, + help="""Feedforward 
dimension used in attention decoder""", + ) + parser.add_argument( "--causal", type=str2bool, @@ -239,6 +276,34 @@ def add_model_arguments(parser: argparse.ArgumentParser): chunk left-context frames will be chosen randomly from this list; else not relevant.""", ) + parser.add_argument( + "--use-transducer", + type=str2bool, + default=True, + help="If True, use Transducer head.", + ) + + parser.add_argument( + "--use-ctc", + type=str2bool, + default=False, + help="If True, use CTC head.", + ) + + parser.add_argument( + "--use-attention-decoder", + type=str2bool, + default=False, + help="If True, use attention-decoder head.", + ) + + parser.add_argument( + "--use-cr-ctc", + type=str2bool, + default=False, + help="If True, use consistency-regularized CTC.", + ) + def get_parser(): parser = argparse.ArgumentParser( @@ -379,6 +444,41 @@ def get_parser(): with this parameter before adding to the final loss.""", ) + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--cr-loss-scale", + type=float, + default=0.15, + help="Scale for consistency-regularization loss.", + ) + + parser.add_argument( + "--time-mask-ratio", + type=float, + default=2.0, + help="When using cr-ctc, we increase the time-masking ratio.", + ) + + parser.add_argument( + "--cr-loss-masked-scale", + type=float, + default=1.0, + help="The value used to scale up the cr_loss at masked positions", + ) + + parser.add_argument( + "--attention-decoder-loss-scale", + type=float, + default=0.8, + help="Scale for attention-decoder loss.", + ) + parser.add_argument( "--seed", type=int, @@ -507,6 +607,9 @@ def get_params() -> AttributeDict: # parameters for zipformer "feature_dim": 80, "subsampling_factor": 4, # not passed in, this is fixed. 
+ # parameters for attention-decoder + "ignore_id": -1, + "label_smoothing": 0.1, "warm_step": 2000, "env_info": get_env_info(), } @@ -579,24 +682,79 @@ def get_joiner_model(params: AttributeDict) -> nn.Module: return joiner +def get_attention_decoder_model(params: AttributeDict) -> nn.Module: + decoder = AttentionDecoderModel( + vocab_size=params.vocab_size, + decoder_dim=params.attention_decoder_dim, + num_decoder_layers=params.attention_decoder_num_layers, + attention_dim=params.attention_decoder_attention_dim, + num_heads=params.attention_decoder_num_heads, + feedforward_dim=params.attention_decoder_feedforward_dim, + memory_dim=max(_to_int_tuple(params.encoder_dim)), + sos_id=params.sos_id, + eos_id=params.eos_id, + ignore_id=params.ignore_id, + label_smoothing=params.label_smoothing, + ) + return decoder + + def get_model(params: AttributeDict) -> nn.Module: + assert params.use_transducer or params.use_ctc, ( + f"At least one of them should be True, " + f"but got params.use_transducer={params.use_transducer}, " + f"params.use_ctc={params.use_ctc}" + ) + encoder_embed = get_encoder_embed(params) encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) + + if params.use_transducer: + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + else: + decoder = None + joiner = None + + if params.use_attention_decoder: + attention_decoder = get_attention_decoder_model(params) + else: + attention_decoder = None model = AsrModel( encoder_embed=encoder_embed, encoder=encoder, decoder=decoder, joiner=joiner, - encoder_dim=int(max(params.encoder_dim.split(","))), + attention_decoder=attention_decoder, + encoder_dim=max(_to_int_tuple(params.encoder_dim)), decoder_dim=params.decoder_dim, vocab_size=params.vocab_size, + use_transducer=params.use_transducer, + use_ctc=params.use_ctc, + use_attention_decoder=params.use_attention_decoder, ) return model +def get_spec_augment(params: AttributeDict) -> SpecAugment: + num_frame_masks = int(10 * params.time_mask_ratio) + max_frames_mask_fraction = 0.15 * params.time_mask_ratio + logging.info( + f"num_frame_masks: {num_frame_masks}, " + f"max_frames_mask_fraction: {max_frames_mask_fraction}" + ) + spec_augment = SpecAugment( + time_warp_factor=0, # Do time warping in model.py + num_frame_masks=num_frame_masks, # default: 10 + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + max_frames_mask_fraction=max_frames_mask_fraction, # default: 0.15 + ) + return spec_augment + + def load_checkpoint_if_available( params: AttributeDict, model: nn.Module, @@ -722,6 +880,7 @@ def compute_loss( graph_compiler: CharCtcTrainingGraphCompiler, batch: dict, is_training: bool, + spec_augment: Optional[SpecAugment] = None, ) -> Tuple[Tensor, MetricsTracker]: """ Compute CTC loss given the model and its inputs. @@ -738,8 +897,8 @@ def compute_loss( True for training. False for validation. When it is True, this function enables autograd during computation; when it is False, it disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. + spec_augment: + The SpecAugment instance used only when use_cr_ctc is True. 
""" device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] @@ -757,32 +916,62 @@ def compute_loss( y = graph_compiler.texts_to_ids(texts) y = k2.RaggedTensor(y).to(device) + use_cr_ctc = params.use_cr_ctc + use_spec_aug = use_cr_ctc and is_training + if use_spec_aug: + supervision_intervals = batch["supervisions"] + supervision_segments = torch.stack( + [ + supervision_intervals["sequence_idx"], + supervision_intervals["start_frame"], + supervision_intervals["num_frames"], + ], + dim=1, + ) # shape: (S, 3) + else: + supervision_segments = None + with torch.set_grad_enabled(is_training): - losses = model( + simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss = model( x=feature, x_lens=feature_lens, y=y, prune_range=params.prune_range, am_scale=params.am_scale, lm_scale=params.lm_scale, + use_cr_ctc=use_cr_ctc, + use_spec_aug=use_spec_aug, + spec_augment=spec_augment, + supervision_segments=supervision_segments, + time_warp_factor=params.spec_aug_time_warp_factor, + cr_loss_masked_scale=params.cr_loss_masked_scale, ) - simple_loss, pruned_loss = losses[:2] - - s = params.simple_loss_scale - # take down the scale on the simple loss from 1.0 at the start - # to params.simple_loss scale by warm_step. - simple_loss_scale = ( - s - if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) - loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + loss = 0.0 + + if params.use_transducer: + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + if params.use_ctc: + loss += params.ctc_loss_scale * ctc_loss + if use_cr_ctc: + loss += params.cr_loss_scale * cr_loss + + if params.use_attention_decoder: + loss += params.attention_decoder_loss_scale * attention_decoder_loss assert loss.requires_grad == is_training @@ -793,8 +982,15 @@ def compute_loss( # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() + if params.use_transducer: + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + if params.use_ctc: + info["ctc_loss"] = ctc_loss.detach().cpu().item() + if params.use_cr_ctc: + info["cr_loss"] = cr_loss.detach().cpu().item() + if params.use_attention_decoder: + info["attn_decoder_loss"] = attention_decoder_loss.detach().cpu().item() return loss, info @@ -842,6 +1038,7 @@ def train_one_epoch( train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, scaler: GradScaler, + spec_augment: Optional[SpecAugment] = None, model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -868,6 +1065,8 @@ def train_one_epoch( Dataloader for the validation dataset. scaler: The scaler used for mix precision training. + spec_augment: + The SpecAugment instance used only when use_cr_ctc is True. 
model_avg: The stored model averaged from the start of training. tb_writer: @@ -917,6 +1116,7 @@ def save_bad_model(suffix: str = ""): graph_compiler=graph_compiler, batch=batch, is_training=True, + spec_augment=spec_augment, ) # summary stats tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info @@ -1080,8 +1280,18 @@ def run(rank, world_size, args): ) params.blank_id = lexicon.token_table[""] + params.sos_id = params.eos_id = lexicon.token_table[""] params.vocab_size = max(lexicon.tokens) + 1 + if not params.use_transducer: + if not params.use_attention_decoder: + params.ctc_loss_scale = 1.0 + else: + assert params.ctc_loss_scale + params.attention_decoder_loss_scale == 1.0, ( + params.ctc_loss_scale, + params.attention_decoder_loss_scale, + ) + logging.info(params) logging.info("About to create model") @@ -1090,6 +1300,13 @@ def run(rank, world_size, args): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") + if params.use_cr_ctc: + assert params.use_ctc + assert not params.enable_spec_aug # we will do spec_augment in model.py + spec_augment = get_spec_augment(params) + else: + spec_augment = None + assert params.save_every_n >= params.average_period model_avg: Optional[nn.Module] = None if rank == 0: @@ -1199,6 +1416,7 @@ def remove_short_and_long_utt(c: Cut): optimizer=optimizer, graph_compiler=graph_compiler, params=params, + spec_augment=spec_augment, ) scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) @@ -1226,6 +1444,7 @@ def remove_short_and_long_utt(c: Cut): train_dl=train_dl, valid_dl=valid_dl, scaler=scaler, + spec_augment=spec_augment, tb_writer=tb_writer, world_size=world_size, rank=rank, @@ -1292,6 +1511,7 @@ def scan_pessimistic_batches_for_oom( optimizer: torch.optim.Optimizer, graph_compiler: CharCtcTrainingGraphCompiler, params: AttributeDict, + spec_augment: Optional[SpecAugment] = None, ): from lhotse.dataset import find_pessimistic_batches @@ -1309,6 +1529,7 @@ def scan_pessimistic_batches_for_oom( graph_compiler=graph_compiler, batch=batch, is_training=True, + spec_augment=spec_augment, ) loss.backward() optimizer.zero_grad() diff --git a/egs/gigaspeech/ASR/zipformer/attention_decoder.py b/egs/gigaspeech/ASR/zipformer/attention_decoder.py new file mode 120000 index 0000000000..384e1b95ea --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/attention_decoder.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/attention_decoder.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/ctc_decode.py b/egs/gigaspeech/ASR/zipformer/ctc_decode.py index 651f20cb65..f9597379b7 100755 --- a/egs/gigaspeech/ASR/zipformer/ctc_decode.py +++ b/egs/gigaspeech/ASR/zipformer/ctc_decode.py @@ -1,9 +1,10 @@ #!/usr/bin/env python3 # -# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Copyright 2021-2024 Xiaomi Corporation (Author: Fangjun Kuang, # Liyong Guo, # Quandong Wang, -# Zengwei Yao) +# Zengwei Yao, +# Wei Kang) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -21,7 +22,16 @@ """ Usage: -(1) ctc-decoding +(1) ctc-greedy-search +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --max-duration 600 \ + --decoding-method ctc-greedy-search + +(2) ctc-decoding ./zipformer/ctc_decode.py \ --epoch 30 \ --avg 15 \ @@ -30,7 +40,7 @@ --max-duration 600 \ --decoding-method ctc-decoding -(2) 1best +(3) 1best ./zipformer/ctc_decode.py \ --epoch 30 \ --avg 15 \ @@ -40,7 +50,7 @@ --hlg-scale 0.6 \ 
--decoding-method 1best -(3) nbest +(4) nbest ./zipformer/ctc_decode.py \ --epoch 30 \ --avg 15 \ @@ -50,7 +60,7 @@ --hlg-scale 0.6 \ --decoding-method nbest -(4) nbest-rescoring +(5) nbest-rescoring ./zipformer/ctc_decode.py \ --epoch 30 \ --avg 15 \ @@ -62,7 +72,7 @@ --lm-dir data/lm \ --decoding-method nbest-rescoring -(5) whole-lattice-rescoring +(6) whole-lattice-rescoring ./zipformer/ctc_decode.py \ --epoch 30 \ --avg 15 \ @@ -73,6 +83,29 @@ --nbest-scale 1.0 \ --lm-dir data/lm \ --decoding-method whole-lattice-rescoring + +(7) attention-decoder-rescoring-no-ngram +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --use-attention-decoder 1 \ + --max-duration 100 \ + --decoding-method attention-decoder-rescoring-no-ngram + +(8) attention-decoder-rescoring-with-ngram +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --use-attention-decoder 1 \ + --max-duration 100 \ + --hlg-scale 0.6 \ + --nbest-scale 1.0 \ + --lm-dir data/lm \ + --decoding-method attention-decoder-rescoring-with-ngram """ @@ -87,8 +120,16 @@ import sentencepiece as spm import torch import torch.nn as nn + from asr_datamodule import GigaSpeechAsrDataModule -from train import add_model_arguments, get_model, get_params +from gigaspeech_scoring import asr_text_post_processing + +from lhotse import set_caching_enabled +from train_cr_aed import add_model_arguments, get_model, get_params + +from icefall.context_graph import ContextGraph, ContextState +from icefall.ngram_lm import NgramLm, NgramLmStateCost +from icefall.lm_wrapper import LmScorer from icefall.checkpoint import ( average_checkpoints, @@ -97,10 +138,16 @@ load_checkpoint, ) from icefall.decode import ( + ctc_greedy_search, + ctc_prefix_beam_search, + ctc_prefix_beam_search_attention_decoder_rescoring, + ctc_prefix_beam_search_shallow_fussion, get_lattice, nbest_decoding, nbest_oracle, one_best_decoding, + rescore_with_attention_decoder_no_ngram, + rescore_with_attention_decoder_with_ngram, rescore_with_n_best_list, rescore_with_whole_lattice, ) @@ -195,23 +242,30 @@ def get_parser(): default="ctc-decoding", help="""Decoding method. Supported values are: - - (1) ctc-decoding. Use CTC decoding. It uses a sentence piece + - (1) ctc-greedy-search. Use CTC greedy search. It uses a sentence piece model, i.e., lang_dir/bpe.model, to convert word pieces to words. It needs neither a lexicon nor an n-gram LM. - - (2) 1best. Extract the best path from the decoding lattice as the + - (2) ctc-decoding. Use CTC decoding. It uses a sentence piece + model, i.e., lang_dir/bpe.model, to convert word pieces to words. + It needs neither a lexicon nor an n-gram LM. + - (3) 1best. Extract the best path from the decoding lattice as the decoding result. - - (3) nbest. Extract n paths from the decoding lattice; the path + - (4) nbest. Extract n paths from the decoding lattice; the path with the highest score is the decoding result. - - (4) nbest-rescoring. Extract n paths from the decoding lattice, + - (5) nbest-rescoring. Extract n paths from the decoding lattice, rescore them with an n-gram LM (e.g., a 4-gram LM), the path with the highest score is the decoding result. - - (5) whole-lattice-rescoring. Rescore the decoding lattice with an + - (6) whole-lattice-rescoring. Rescore the decoding lattice with an n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice is the decoding result. you have trained an RNN LM using ./rnn_lm/train.py - - (6) nbest-oracle. 
Its WER is the lower bound of any n-best + - (7) nbest-oracle. Its WER is the lower bound of any n-best rescoring method can achieve. Useful for debugging n-best rescoring method. + - (8) attention-decoder-rescoring-no-ngram. Extract n paths from the decoding + lattice, rescore them with the attention decoder. + - (9) attention-decoder-rescoring-with-ngram. Extract n paths from the LM + rescored lattice, rescore them with the attention decoder. """, ) @@ -237,6 +291,23 @@ def get_parser(): """, ) + parser.add_argument( + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. + """, + ) + parser.add_argument( "--hlg-scale", type=float, @@ -254,6 +325,13 @@ def get_parser(): """, ) + parser.add_argument( + "--skip-scoring", + type=str2bool, + default=False, + help="""Skip scoring, but still save the ASR output (for eval sets).""", + ) + add_model_arguments(parser) return parser @@ -264,8 +342,9 @@ def get_decoding_params() -> AttributeDict: params = AttributeDict( { "frame_shift_ms": 10, - "search_beam": 20, - "output_beam": 8, + "search_beam": 20, # for k2 fsa composition + "output_beam": 8, # for k2 fsa composition + "beam": 4, # for prefix-beam-search "min_active_states": 30, "max_active_states": 10000, "use_double_scores": True, @@ -274,6 +353,17 @@ def get_decoding_params() -> AttributeDict: return params +def post_processing( + results: List[Tuple[str, List[str], List[str]]], +) -> List[Tuple[str, List[str], List[str]]]: + new_results = [] + for key, ref, hyp in results: + new_ref = asr_text_post_processing(" ".join(ref)).split() + new_hyp = asr_text_post_processing(" ".join(hyp)).split() + new_results.append((key, new_ref, new_hyp)) + return new_results + + def decode_one_batch( params: AttributeDict, model: nn.Module, @@ -283,6 +373,7 @@ def decode_one_batch( batch: dict, word_table: k2.SymbolTable, G: Optional[k2.Fsa] = None, + LM: Optional[LmScorer] = None, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -327,10 +418,7 @@ def decode_one_batch( Return the decoding result. See above description for the format of the returned dict. Note: If it decodes to nothing, then return None. """ - if HLG is not None: - device = HLG.device - else: - device = H.device + device = params.device feature = batch["inputs"] assert feature.ndim == 3 feature = feature.to(device) @@ -352,6 +440,57 @@ def decode_one_batch( encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) ctc_output = model.ctc_output(encoder_out) # (N, T, C) + if params.decoding_method == "ctc-greedy-search": + hyps = ctc_greedy_search(ctc_output, encoder_out_lens) + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(hyps) + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + key = "ctc-greedy-search" + return {key: hyps} + + if params.decoding_method == "prefix-beam-search": + token_ids = ctc_prefix_beam_search( + ctc_output=ctc_output, encoder_out_lens=encoder_out_lens + ) + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... 
] + hyps = [s.split() for s in hyps] + key = "prefix-beam-search" + return {key: hyps} + + if params.decoding_method == "ctc-prefix-beam-search-attention-decoder-rescoring": + best_path_dict = ctc_prefix_beam_search_attention_decoder_rescoring( + ctc_output=ctc_output, + attention_decoder=model.attention_decoder, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + ans = dict() + for a_scale_str, token_ids in best_path_dict.items(): + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + ans[a_scale_str] = hyps + return ans + + if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion": + token_ids = ctc_prefix_beam_search_shallow_fussion( + ctc_output=ctc_output, + encoder_out_lens=encoder_out_lens, + LM=LM, + ) + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + key = "prefix-beam-search-shallow-fussion" + return {key: hyps} + supervision_segments = torch.stack( ( supervisions["sequence_idx"], @@ -404,7 +543,27 @@ def decode_one_batch( # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] hyps = [s.split() for s in hyps] key = "ctc-decoding" - return {key: hyps} + return {key: hyps} # note: returns words + + if params.decoding_method == "attention-decoder-rescoring-no-ngram": + best_path_dict = rescore_with_attention_decoder_no_ngram( + lattice=lattice, + num_paths=params.num_paths, + attention_decoder=model.attention_decoder, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + nbest_scale=params.nbest_scale, + ) + ans = dict() + for a_scale_str, best_path in best_path_dict.items(): + # token_ids is a lit-of-list of IDs + token_ids = get_texts(best_path) + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + ans[a_scale_str] = hyps + return ans if params.decoding_method == "nbest-oracle": # Note: You can also pass rescored lattices to it. 
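# A minimal sketch of the idea behind ctc_greedy_search as used above: a frame-wise
# argmax over the (N, T, C) CTC log-probs, collapsing consecutive repeats and dropping
# blanks (assumed here to be id 0). Illustrative only; icefall's actual helper may differ.
def greedy_ctc_sketch(ctc_output, encoder_out_lens, blank_id=0):
    hyps = []
    for log_probs, length in zip(ctc_output, encoder_out_lens):
        ids = log_probs[: int(length)].argmax(dim=-1).tolist()
        if not ids:
            hyps.append([])
            continue
        # collapse runs of identical ids, then remove blanks
        collapsed = [ids[0]] + [t for prev, t in zip(ids, ids[1:]) if t != prev]
        hyps.append([t for t in collapsed if t != blank_id])
    return hyps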
@@ -421,7 +580,7 @@ def decode_one_batch( ) hyps = get_texts(best_path) hyps = [[word_table[i] for i in ids] for ids in hyps] - key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa + key = f"oracle_{params.num_paths}_nbest-scale-{params.nbest_scale}" # noqa return {key: hyps} if params.decoding_method in ["1best", "nbest"]: @@ -429,7 +588,7 @@ def decode_one_batch( best_path = one_best_decoding( lattice=lattice, use_double_scores=params.use_double_scores ) - key = "no_rescore" + key = "no-rescore" else: best_path = nbest_decoding( lattice=lattice, @@ -437,15 +596,16 @@ def decode_one_batch( use_double_scores=params.use_double_scores, nbest_scale=params.nbest_scale, ) - key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa + key = f"no-rescore_nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa hyps = get_texts(best_path) hyps = [[word_table[i] for i in ids] for ids in hyps] - return {key: hyps} + return {key: hyps} # note: returns BPE tokens assert params.decoding_method in [ "nbest-rescoring", "whole-lattice-rescoring", + "attention-decoder-rescoring-with-ngram", ] lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] @@ -466,6 +626,21 @@ def decode_one_batch( G_with_epsilon_loops=G, lm_scale_list=lm_scale_list, ) + elif params.decoding_method == "attention-decoder-rescoring-with-ngram": + # lattice uses a 3-gram Lm. We rescore it with a 4-gram LM. + rescored_lattice = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=None, + ) + best_path_dict = rescore_with_attention_decoder_with_ngram( + lattice=rescored_lattice, + num_paths=params.num_paths, + attention_decoder=model.attention_decoder, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + nbest_scale=params.nbest_scale, + ) else: assert False, f"Unsupported decoding method: {params.decoding_method}" @@ -489,6 +664,7 @@ def decode_dataset( bpe_model: Optional[spm.SentencePieceProcessor], word_table: k2.SymbolTable, G: Optional[k2.Fsa] = None, + LM: Optional[LmScorer] = None, ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. @@ -539,6 +715,7 @@ def decode_dataset( batch=batch, word_table=word_table, G=G, + LM=LM, ) for name, hyps in hyps_dict.items(): @@ -559,38 +736,64 @@ def decode_dataset( return results -def save_results( +def save_asr_output( params: AttributeDict, test_set_name: str, results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): - test_set_wers = dict() + """ + Save text produced by ASR. + """ for key, results in results_dict.items(): - recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + recogs_filename = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + results = post_processing(results) results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") + store_transcripts(filename=recogs_filename, texts=results) + + logging.info(f"The transcripts are stored in {recogs_filename}") + + +def save_wer_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + if params.decoding_method in ( + "attention-decoder-rescoring-with-ngram", + "whole-lattice-rescoring", + ): + # Set it to False since there are too many logs. 
+ enable_log = False + else: + enable_log = True + test_set_wers = dict() + for key, results in results_dict.items(): + results = post_processing(results) # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats(f, f"{test_set_name}-{key}", results) + with open(errs_filename, "w", encoding="utf8") as fd: + wer = write_error_stats( + fd, f"{test_set_name}_{key}", results, enable_log=enable_log + ) test_set_wers[key] = wer - logging.info("Wrote detailed error stats to {}".format(errs_filename)) + logging.info(f"Wrote detailed error stats to {errs_filename}") test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - with open(errs_info, "w") as f: - print("settings\tWER", file=f) + + wer_filename = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + + with open(wer_filename, "w", encoding="utf8") as fd: + print("settings\tWER", file=fd) for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) + print(f"{key}\t{val}", file=fd) - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) + s = f"\nFor {test_set_name}, WER of different settings are:\n" + note = f"\tbest for {test_set_name}" for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) + s += f"{key}\t{val}{note}\n" note = "" logging.info(s) @@ -609,20 +812,29 @@ def main(): params.update(get_decoding_params()) params.update(vars(args)) + # enable AudioCache + set_caching_enabled(True) # lhotse + assert params.decoding_method in ( + "ctc-greedy-search", + "prefix-beam-search", + "ctc-prefix-beam-search-attention-decoder-rescoring", + "ctc-prefix-beam-search-shallow-fussion", "ctc-decoding", "1best", "nbest", "nbest-rescoring", "whole-lattice-rescoring", "nbest-oracle", + "attention-decoder-rescoring-no-ngram", + "attention-decoder-rescoring-with-ngram", ) params.res_dir = params.exp_dir / params.decoding_method if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" + params.suffix = f"iter-{params.iter}_avg-{params.avg}" else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + params.suffix = f"epoch-{params.epoch}_avg-{params.avg}" if params.causal: assert ( @@ -631,11 +843,16 @@ def main(): assert ( "," not in params.left_context_frames ), "left_context_frames should be one value in decoding." 
- params.suffix += f"-chunk-{params.chunk_size}" - params.suffix += f"-left-context-{params.left_context_frames}" + params.suffix += f"_chunk-{params.chunk_size}" + params.suffix += f"_left-context-{params.left_context_frames}" + + if "prefix-beam-search" in params.decoding_method: + params.suffix += f"_beam-{params.beam}" + if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion": + params.suffix += f"_lm-scale-{params.lm_scale}" if params.use_averaged_model: - params.suffix += "-use-averaged-model" + params.suffix += "_use-averaged-model" setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") logging.info("Decoding started") @@ -643,6 +860,7 @@ def main(): device = torch.device("cpu") if torch.cuda.is_available(): device = torch.device("cuda", 0) + params.device = device logging.info(f"Device: {device}") logging.info(params) @@ -654,14 +872,28 @@ def main(): params.vocab_size = num_classes # and are defined in local/train_bpe_model.py params.blank_id = 0 + params.eos_id = 1 + params.sos_id = 1 - if params.decoding_method == "ctc-decoding": + if params.decoding_method in [ + "ctc-greedy-search", + "ctc-decoding", + "prefix-beam-search", + "ctc-prefix-beam-search-attention-decoder-rescoring", + "ctc-prefix-beam-search-shallow-fussion", + "attention-decoder-rescoring-no-ngram", + ]: HLG = None - H = k2.ctc_topo( - max_token=max_token_id, - modified=False, - device=device, - ) + H = None + if params.decoding_method in [ + "ctc-decoding", + "attention-decoder-rescoring-no-ngram", + ]: + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) bpe_model = spm.SentencePieceProcessor() bpe_model.load(str(params.lang_dir / "bpe.model")) else: @@ -679,6 +911,7 @@ def main(): if params.decoding_method in ( "nbest-rescoring", "whole-lattice-rescoring", + "attention-decoder-rescoring-with-ngram", ): if not (params.lm_dir / "G_4_gram.pt").is_file(): logging.info("Loading G_4_gram.fst.txt") @@ -710,7 +943,10 @@ def main(): d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) G = k2.Fsa.from_dict(d) - if params.decoding_method == "whole-lattice-rescoring": + if params.decoding_method in [ + "whole-lattice-rescoring", + "attention-decoder-rescoring-with-ngram", + ]: # Add epsilon self-loops to G as we will compose # it with the whole lattice later G = k2.add_epsilon_self_loops(G) @@ -723,6 +959,19 @@ def main(): else: G = None + # only load the neural network LM if required + if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion": + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, + ) + LM.to(device) + LM.eval() + else: + LM = None + logging.info("About to create model") model = get_model(params) @@ -811,18 +1060,19 @@ def main(): # we need cut ids to display recognition results. 
args.return_cuts = True + gigaspeech = GigaSpeechAsrDataModule(args) - test_clean_cuts = gigaspeech.test_clean_cuts() - test_other_cuts = gigaspeech.test_other_cuts() + test_cuts = gigaspeech.test_cuts() + dev_cuts = gigaspeech.dev_cuts() - test_clean_dl = gigaspeech.test_dataloaders(test_clean_cuts) - test_other_dl = gigaspeech.test_dataloaders(test_other_cuts) + test_dl = gigaspeech.test_dataloaders(test_cuts) + dev_dl = gigaspeech.test_dataloaders(dev_cuts) - test_sets = ["test-clean", "test-other"] - test_dl = [test_clean_dl, test_other_dl] + test_sets = ["test", "dev"] + test_dls = [test_dl, dev_dl] - for test_set, test_dl in zip(test_sets, test_dl): + for test_set, test_dl in zip(test_sets, test_dls): results_dict = decode_dataset( dl=test_dl, params=params, @@ -834,12 +1084,19 @@ def main(): G=G, ) - save_results( + save_asr_output( params=params, test_set_name=test_set, results_dict=results_dict, ) + if not params.skip_scoring: + save_wer_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + logging.info("Done!") diff --git a/egs/gigaspeech/ASR/zipformer/label_smoothing.py b/egs/gigaspeech/ASR/zipformer/label_smoothing.py new file mode 120000 index 0000000000..175c633cc7 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/label_smoothing.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/label_smoothing.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/spec_augment.py b/egs/gigaspeech/ASR/zipformer/spec_augment.py new file mode 120000 index 0000000000..d00c7c9ddc --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/spec_augment.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/spec_augment.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/train.py b/egs/gigaspeech/ASR/zipformer/train.py index 4c122effee..0174b427ba 100755 --- a/egs/gigaspeech/ASR/zipformer/train.py +++ b/egs/gigaspeech/ASR/zipformer/train.py @@ -65,6 +65,7 @@ import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import GigaSpeechAsrDataModule +from attention_decoder import AttentionDecoderModel from decoder import Decoder from joiner import Joiner from lhotse.cut import Cut @@ -99,6 +100,8 @@ str2bool, ) +from spec_augment import SpecAugment + LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -220,6 +223,41 @@ def add_model_arguments(parser: argparse.ArgumentParser): """, ) + parser.add_argument( + "--attention-decoder-dim", + type=int, + default=512, + help="""Dimension used in the attention decoder""", + ) + + parser.add_argument( + "--attention-decoder-num-layers", + type=int, + default=6, + help="""Number of transformer layers used in attention decoder""", + ) + + parser.add_argument( + "--attention-decoder-attention-dim", + type=int, + default=512, + help="""Attention dimension used in attention decoder""", + ) + + parser.add_argument( + "--attention-decoder-num-heads", + type=int, + default=8, + help="""Number of attention heads used in attention decoder""", + ) + + parser.add_argument( + "--attention-decoder-feedforward-dim", + type=int, + default=2048, + help="""Feedforward dimension used in attention decoder""", + ) + parser.add_argument( "--causal", type=str2bool, @@ -258,6 +296,20 @@ def add_model_arguments(parser: argparse.ArgumentParser): help="If True, use CTC head.", ) + parser.add_argument( + "--use-attention-decoder", + type=str2bool, + default=False, + help="If True, use attention-decoder head.", + ) + + parser.add_argument( + "--use-cr-ctc", + type=str2bool, + 
default=False, + help="If True, use consistency-regularized CTC.", + ) + def get_parser(): parser = argparse.ArgumentParser( @@ -403,6 +455,34 @@ def get_parser(): help="Scale for CTC loss.", ) + parser.add_argument( + "--cr-loss-scale", + type=float, + default=0.15, + help="Scale for consistency-regularization loss.", + ) + + parser.add_argument( + "--time-mask-ratio", + type=float, + default=2.0, + help="When using cr-ctc, we increase the time-masking ratio.", + ) + + parser.add_argument( + "--cr-loss-masked-scale", + type=float, + default=1.0, + help="The value used to scale up the cr_loss at masked positions", + ) + + parser.add_argument( + "--attention-decoder-loss-scale", + type=float, + default=0.8, + help="Scale for attention-decoder loss.", + ) + parser.add_argument( "--seed", type=int, @@ -542,6 +622,9 @@ def get_params() -> AttributeDict: # parameters for zipformer "feature_dim": 80, "subsampling_factor": 4, # not passed in, this is fixed. + # parameters for attention-decoder + "ignore_id": -1, + "label_smoothing": 0.1, "warm_step": 2000, "env_info": get_env_info(), } @@ -614,6 +697,23 @@ def get_joiner_model(params: AttributeDict) -> nn.Module: return joiner +def get_attention_decoder_model(params: AttributeDict) -> nn.Module: + decoder = AttentionDecoderModel( + vocab_size=params.vocab_size, + decoder_dim=params.attention_decoder_dim, + num_decoder_layers=params.attention_decoder_num_layers, + attention_dim=params.attention_decoder_attention_dim, + num_heads=params.attention_decoder_num_heads, + feedforward_dim=params.attention_decoder_feedforward_dim, + memory_dim=max(_to_int_tuple(params.encoder_dim)), + sos_id=params.sos_id, + eos_id=params.eos_id, + ignore_id=params.ignore_id, + label_smoothing=params.label_smoothing, + ) + return decoder + + def get_model(params: AttributeDict) -> nn.Module: assert params.use_transducer or params.use_ctc, ( f"At least one of them should be True, " @@ -631,20 +731,45 @@ def get_model(params: AttributeDict) -> nn.Module: decoder = None joiner = None + if params.use_attention_decoder: + attention_decoder = get_attention_decoder_model(params) + else: + attention_decoder = None + model = AsrModel( encoder_embed=encoder_embed, encoder=encoder, decoder=decoder, joiner=joiner, + attention_decoder=attention_decoder, encoder_dim=max(_to_int_tuple(params.encoder_dim)), decoder_dim=params.decoder_dim, vocab_size=params.vocab_size, use_transducer=params.use_transducer, use_ctc=params.use_ctc, + use_attention_decoder=params.use_attention_decoder, ) return model +def get_spec_augment(params: AttributeDict) -> SpecAugment: + num_frame_masks = int(10 * params.time_mask_ratio) + max_frames_mask_fraction = 0.15 * params.time_mask_ratio + logging.info( + f"num_frame_masks: {num_frame_masks}, " + f"max_frames_mask_fraction: {max_frames_mask_fraction}" + ) + spec_augment = SpecAugment( + time_warp_factor=0, # Do time warping in model.py + num_frame_masks=num_frame_masks, # default: 10 + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + max_frames_mask_fraction=max_frames_mask_fraction, # default: 0.15 + ) + return spec_augment + + def load_checkpoint_if_available( params: AttributeDict, model: nn.Module, @@ -767,6 +892,7 @@ def compute_loss( sp: spm.SentencePieceProcessor, batch: dict, is_training: bool, + spec_augment: Optional[SpecAugment] = None, ) -> Tuple[Tensor, MetricsTracker]: """ Compute loss given the model and its inputs. @@ -783,8 +909,8 @@ def compute_loss( True for training. False for validation. 
When it is True, this function enables autograd during computation; when it is False, it disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. + spec_augment: + The SpecAugment instance used only when use_cr_ctc is True. """ device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] @@ -802,6 +928,21 @@ def compute_loss( y = sp.encode(texts, out_type=int) y = k2.RaggedTensor(y) + use_cr_ctc = params.use_cr_ctc + use_spec_aug = use_cr_ctc and is_training + if use_spec_aug: + supervision_intervals = batch["supervisions"] + supervision_segments = torch.stack( + [ + supervision_intervals["sequence_idx"], + supervision_intervals["start_frame"], + supervision_intervals["num_frames"], + ], + dim=1, + ) # shape: (S, 3) + else: + supervision_segments = None + with torch.set_grad_enabled(is_training): losses = model( x=feature, @@ -810,8 +951,14 @@ def compute_loss( prune_range=params.prune_range, am_scale=params.am_scale, lm_scale=params.lm_scale, + use_cr_ctc=use_cr_ctc, + use_spec_aug=use_spec_aug, + spec_augment=spec_augment, + supervision_segments=supervision_segments, + time_warp_factor=params.spec_aug_time_warp_factor, + cr_loss_masked_scale=params.cr_loss_masked_scale, ) - simple_loss, pruned_loss, ctc_loss = losses[:3] + simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss = losses[:5] loss = 0.0 @@ -833,6 +980,11 @@ def compute_loss( if params.use_ctc: loss += params.ctc_loss_scale * ctc_loss + if use_cr_ctc: + loss += params.cr_loss_scale * cr_loss + + if params.use_attention_decoder: + loss += params.attention_decoder_loss_scale * attention_decoder_loss assert loss.requires_grad == is_training @@ -848,6 +1000,10 @@ def compute_loss( info["pruned_loss"] = pruned_loss.detach().cpu().item() if params.use_ctc: info["ctc_loss"] = ctc_loss.detach().cpu().item() + if params.use_cr_ctc: + info["cr_loss"] = cr_loss.detach().cpu().item() + if params.use_attention_decoder: + info["attn_decoder_loss"] = attention_decoder_loss.detach().cpu().item() return loss, info @@ -895,6 +1051,7 @@ def train_one_epoch( train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, scaler: GradScaler, + spec_augment: Optional[SpecAugment] = None, model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -921,6 +1078,8 @@ def train_one_epoch( Dataloader for the validation dataset. scaler: The scaler used for mix precision training. + spec_augment: + The SpecAugment instance used only when use_cr_ctc is True. model_avg: The stored model averaged from the start of training. 
tb_writer: @@ -965,6 +1124,7 @@ def save_bad_model(suffix: str = ""): sp=sp, batch=batch, is_training=True, + spec_augment=spec_augment, ) # summary stats tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info @@ -1124,10 +1284,17 @@ def run(rank, world_size, args): # <blk> is defined in local/train_bpe_model.py params.blank_id = sp.piece_to_id("<blk>") + params.sos_id = params.eos_id = sp.piece_to_id("<sos/eos>") params.vocab_size = sp.get_piece_size() if not params.use_transducer: - params.ctc_loss_scale = 1.0 + if not params.use_attention_decoder: + params.ctc_loss_scale = 1.0 + else: + assert params.ctc_loss_scale + params.attention_decoder_loss_scale == 1.0, ( + params.ctc_loss_scale, + params.attention_decoder_loss_scale, + ) logging.info(params) @@ -1137,6 +1304,13 @@ def run(rank, world_size, args): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") + if params.use_cr_ctc: + assert params.use_ctc + assert not params.enable_spec_aug # we will do spec_augment in model.py + spec_augment = get_spec_augment(params) + else: + spec_augment = None + assert params.save_every_n >= params.average_period model_avg: Optional[nn.Module] = None if rank == 0: @@ -1215,6 +1389,7 @@ def remove_short_utt(c: Cut): optimizer=optimizer, sp=sp, params=params, + spec_augment=spec_augment, ) scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) @@ -1242,6 +1417,7 @@ def remove_short_utt(c: Cut): train_dl=train_dl, valid_dl=valid_dl, scaler=scaler, + spec_augment=spec_augment, tb_writer=tb_writer, world_size=world_size, rank=rank, @@ -1307,6 +1483,7 @@ def scan_pessimistic_batches_for_oom( optimizer: torch.optim.Optimizer, sp: spm.SentencePieceProcessor, params: AttributeDict, + spec_augment: Optional[SpecAugment] = None, ): from lhotse.dataset import find_pessimistic_batches @@ -1324,6 +1501,7 @@ def scan_pessimistic_batches_for_oom( sp=sp, batch=batch, is_training=True, + spec_augment=spec_augment, ) loss.backward() optimizer.zero_grad() diff --git a/egs/librispeech/ASR/zipformer/attention_decoder.py b/egs/librispeech/ASR/zipformer/attention_decoder.py index 81682e87b5..d1cedb6fda 100644 --- a/egs/librispeech/ASR/zipformer/attention_decoder.py +++ b/egs/librispeech/ASR/zipformer/attention_decoder.py @@ -17,6 +17,7 @@ import math +import warnings from typing import List, Optional import k2 @@ -234,10 +235,13 @@ def forward( # construct attn_mask for self-attn modules padding_mask = make_pad_mask(x_lens) # (batch, tgt_len) causal_mask = subsequent_mask(x.shape[0], device=x.device) # (seq_len, seq_len) - attn_mask = torch.logical_or( - padding_mask.unsqueeze(1), # (batch, 1, seq_len) - torch.logical_not(causal_mask).unsqueeze(0) # (1, seq_len, seq_len) - ) # (batch, seq_len, seq_len) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + attn_mask = torch.logical_or( + padding_mask.unsqueeze(1), # (batch, 1, seq_len) + torch.logical_not(causal_mask).unsqueeze(0), # (1, seq_len, seq_len) + ) # (batch, seq_len, seq_len) if memory is not None: memory = memory.permute(1, 0, 2) # (src_len, batch, memory_dim) @@ -367,7 +371,9 @@ def __init__( self.num_heads = num_heads self.head_dim = attention_dim // num_heads assert self.head_dim * num_heads == attention_dim, ( - self.head_dim, num_heads, attention_dim + self.head_dim, + num_heads, + attention_dim, ) self.dropout = dropout self.name = None # will be overwritten in training code; for diagnostics.
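For readers following the attn_mask change above: the padding mask supplies a per-key mask that is broadcast over query positions, the inverted causal mask supplies a per-query mask that is broadcast over the batch, and logical_or merges them into one (batch, seq_len, seq_len) boolean mask in which True means "do not attend". A small self-contained illustration; make_pad_mask and subsequent_mask are re-implemented inline as stand-ins, so the helper names here are not the library functions:

import torch

def pad_mask(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    # True at padded positions, shape (batch, max_len)
    return torch.arange(max_len).unsqueeze(0) >= lengths.unsqueeze(1)

def causal_allowed(size: int) -> torch.Tensor:
    # True at allowed (current and past) positions, shape (size, size)
    return torch.ones(size, size).tril().to(torch.bool)

lengths = torch.tensor([3, 2])  # two sequences padded to length 4
seq_len = 4
padding_mask = pad_mask(lengths, seq_len)            # (2, 4)
causal_mask = causal_allowed(seq_len)                # (4, 4)
attn_mask = torch.logical_or(
    padding_mask.unsqueeze(1),                       # (2, 1, 4): mask padded keys
    torch.logical_not(causal_mask).unsqueeze(0),     # (1, 4, 4): mask future keys
)                                                    # (2, 4, 4), True = masked
assert attn_mask.shape == (2, 4, 4)
print(attn_mask[1])  # row i shows which key positions query i may not attend to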
@@ -437,15 +443,19 @@ def forward( if key_padding_mask is not None: assert key_padding_mask.shape == (batch, src_len), key_padding_mask.shape attn_weights = attn_weights.masked_fill( - key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf"), + key_padding_mask.unsqueeze(1).unsqueeze(2), + float("-inf"), ) if attn_mask is not None: - assert ( - attn_mask.shape == (batch, 1, src_len) - or attn_mask.shape == (batch, tgt_len, src_len) + assert attn_mask.shape == (batch, 1, src_len) or attn_mask.shape == ( + batch, + tgt_len, + src_len, ), attn_mask.shape - attn_weights = attn_weights.masked_fill(attn_mask.unsqueeze(1), float("-inf")) + attn_weights = attn_weights.masked_fill( + attn_mask.unsqueeze(1), float("-inf") + ) attn_weights = attn_weights.view(batch * num_heads, tgt_len, src_len) attn_weights = nn.functional.softmax(attn_weights, dim=-1) @@ -456,7 +466,11 @@ def forward( # (batch * head, tgt_len, head_dim) attn_output = torch.bmm(attn_weights, v) - assert attn_output.shape == (batch * num_heads, tgt_len, head_dim), attn_output.shape + assert attn_output.shape == ( + batch * num_heads, + tgt_len, + head_dim, + ), attn_output.shape attn_output = attn_output.transpose(0, 1).contiguous() attn_output = attn_output.view(tgt_len, batch, num_heads * head_dim) diff --git a/egs/librispeech/ASR/zipformer/ctc_decode.py b/egs/librispeech/ASR/zipformer/ctc_decode.py index 9db4299592..183d42360b 100755 --- a/egs/librispeech/ASR/zipformer/ctc_decode.py +++ b/egs/librispeech/ASR/zipformer/ctc_decode.py @@ -111,6 +111,7 @@ import argparse import logging import math +import os from collections import defaultdict from pathlib import Path from typing import Dict, List, Optional, Tuple @@ -123,6 +124,10 @@ from lhotse import set_caching_enabled from train import add_model_arguments, get_model, get_params +from icefall.context_graph import ContextGraph, ContextState +from icefall.ngram_lm import NgramLm, NgramLmStateCost +from icefall.lm_wrapper import LmScorer + from icefall.checkpoint import ( average_checkpoints, average_checkpoints_with_averaged_model, @@ -131,6 +136,9 @@ ) from icefall.decode import ( ctc_greedy_search, + ctc_prefix_beam_search, + ctc_prefix_beam_search_attention_decoder_rescoring, + ctc_prefix_beam_search_shallow_fussion, get_lattice, nbest_decoding, nbest_oracle, @@ -280,6 +288,23 @@ def get_parser(): """, ) + parser.add_argument( + "--nnlm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--nnlm-scale", + type=float, + default=0, + help="""The scale of the neural network LM, 0 means don't use nnlm shallow fussion. + Used only when `--use-shallow-fusion` is set to True. + """, + ) + parser.add_argument( "--hlg-scale", type=float, @@ -297,11 +322,52 @@ def get_parser(): """, ) + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="ID of the backoff symbol in the ngram LM", + ) + + parser.add_argument( + "--lodr-ngram", + type=str, + help="The path to the lodr ngram", + ) + + parser.add_argument( + "--lodr-lm-scale", + type=float, + default=0, + help="The scale of lodr ngram, should be less than 0. 0 means don't use lodr.", + ) + + parser.add_argument( + "--context-score", + type=float, + default=0, + help=""" + The bonus score of each token for the context biasing words/phrases. + 0 means don't use contextual biasing. + Used only when --decoding-method is ctc-prefix-beam-search-shallow-fussion. 
+ """, + ) + + parser.add_argument( + "--context-file", + type=str, + default="", + help=""" + The path of the context biasing lists, one word/phrase each line + Used only when --decoding-method is ctc-prefix-beam-search-shallow-fussion. + """, + ) + parser.add_argument( "--skip-scoring", type=str2bool, default=False, - help="""Skip scoring, but still save the ASR output (for eval sets).""" + help="""Skip scoring, but still save the ASR output (for eval sets).""", ) add_model_arguments(parser) @@ -314,8 +380,9 @@ def get_decoding_params() -> AttributeDict: params = AttributeDict( { "frame_shift_ms": 10, - "search_beam": 20, - "output_beam": 8, + "search_beam": 20, # for k2 fsa composition + "output_beam": 8, # for k2 fsa composition + "beam": 4, # for prefix-beam-search "min_active_states": 30, "max_active_states": 10000, "use_double_scores": True, @@ -333,6 +400,9 @@ def decode_one_batch( batch: dict, word_table: k2.SymbolTable, G: Optional[k2.Fsa] = None, + NNLM: Optional[LmScorer] = None, + LODR_lm: Optional[NgramLm] = None, + context_graph: Optional[ContextGraph] = None, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -377,10 +447,7 @@ def decode_one_batch( Return the decoding result. See above description for the format of the returned dict. Note: If it decodes to nothing, then return None. """ - if HLG is not None: - device = HLG.device - else: - device = H.device + device = params.device feature = batch["inputs"] assert feature.ndim == 3 feature = feature.to(device) @@ -411,6 +478,51 @@ def decode_one_batch( key = "ctc-greedy-search" return {key: hyps} + if params.decoding_method == "prefix-beam-search": + token_ids = ctc_prefix_beam_search( + ctc_output=ctc_output, encoder_out_lens=encoder_out_lens + ) + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + key = "prefix-beam-search" + return {key: hyps} + + if params.decoding_method == "ctc-prefix-beam-search-attention-decoder-rescoring": + best_path_dict = ctc_prefix_beam_search_attention_decoder_rescoring( + ctc_output=ctc_output, + attention_decoder=model.attention_decoder, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + ans = dict() + for a_scale_str, token_ids in best_path_dict.items(): + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + ans[a_scale_str] = hyps + return ans + + if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion": + token_ids = ctc_prefix_beam_search_shallow_fussion( + ctc_output=ctc_output, + encoder_out_lens=encoder_out_lens, + NNLM=NNLM, + LODR_lm=LODR_lm, + LODR_lm_scale=params.lodr_lm_scale, + context_graph=context_graph, + ) + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... 
] + hyps = [s.split() for s in hyps] + key = "prefix-beam-search-shallow-fussion" + return {key: hyps} + supervision_segments = torch.stack( ( supervisions["sequence_idx"], @@ -584,6 +696,9 @@ def decode_dataset( bpe_model: Optional[spm.SentencePieceProcessor], word_table: k2.SymbolTable, G: Optional[k2.Fsa] = None, + NNLM: Optional[LmScorer] = None, + LODR_lm: Optional[NgramLm] = None, + context_graph: Optional[ContextGraph] = None, ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. @@ -634,6 +749,9 @@ def decode_dataset( batch=batch, word_table=word_table, G=G, + NNLM=NNLM, + LODR_lm=LODR_lm, + context_graph=context_graph, ) for name, hyps in hyps_dict.items(): @@ -664,9 +782,7 @@ def save_asr_output( """ for key, results in results_dict.items(): - recogs_filename = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recogs_filename = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recogs_filename, texts=results) @@ -680,7 +796,8 @@ def save_wer_results( results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): if params.decoding_method in ( - "attention-decoder-rescoring-with-ngram", "whole-lattice-rescoring" + "attention-decoder-rescoring-with-ngram", + "whole-lattice-rescoring", ): # Set it to False since there are too many logs. enable_log = False @@ -721,6 +838,7 @@ def save_wer_results( def main(): parser = get_parser() LibriSpeechAsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) args.lang_dir = Path(args.lang_dir) @@ -735,8 +853,11 @@ def main(): set_caching_enabled(True) # lhotse assert params.decoding_method in ( - "ctc-greedy-search", "ctc-decoding", + "ctc-greedy-search", + "prefix-beam-search", + "ctc-prefix-beam-search-attention-decoder-rescoring", + "ctc-prefix-beam-search-shallow-fussion", "1best", "nbest", "nbest-rescoring", @@ -762,6 +883,16 @@ def main(): params.suffix += f"_chunk-{params.chunk_size}" params.suffix += f"_left-context-{params.left_context_frames}" + if "prefix-beam-search" in params.decoding_method: + params.suffix += f"_beam-{params.beam}" + if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion": + if params.nnlm_scale != 0: + params.suffix += f"_nnlm-scale-{params.nnlm_scale}" + if params.lodr_lm_scale != 0: + params.suffix += f"_lodr-scale-{params.lodr_lm_scale}" + if params.context_score != 0: + params.suffix += f"_context_score-{params.context_score}" + if params.use_averaged_model: params.suffix += "_use-averaged-model" @@ -772,6 +903,8 @@ def main(): if torch.cuda.is_available(): device = torch.device("cuda", 0) + params.device = device + logging.info(f"Device: {device}") logging.info(params) @@ -786,14 +919,24 @@ def main(): params.sos_id = 1 if params.decoding_method in [ - "ctc-greedy-search", "ctc-decoding", "attention-decoder-rescoring-no-ngram" + "ctc-greedy-search", + "ctc-decoding", + "attention-decoder-rescoring-no-ngram", + "prefix-beam-search", + "ctc-prefix-beam-search-attention-decoder-rescoring", + "ctc-prefix-beam-search-shallow-fussion", ]: HLG = None - H = k2.ctc_topo( - max_token=max_token_id, - modified=False, - device=device, - ) + H = None + if params.decoding_method in [ + "ctc-decoding", + "attention-decoder-rescoring-no-ngram", + ]: + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) bpe_model = spm.SentencePieceProcessor() bpe_model.load(str(params.lang_dir / "bpe.model")) 
else: @@ -844,7 +987,8 @@ def main(): G = k2.Fsa.from_dict(d) if params.decoding_method in [ - "whole-lattice-rescoring", "attention-decoder-rescoring-with-ngram" + "whole-lattice-rescoring", + "attention-decoder-rescoring-with-ngram", ]: # Add epsilon self-loops to G as we will compose # it with the whole lattice later @@ -858,6 +1002,51 @@ def main(): else: G = None + # only load the neural network LM if required + NNLM = None + if ( + params.decoding_method == "ctc-prefix-beam-search-shallow-fussion" + and params.nnlm_scale != 0 + ): + NNLM = LmScorer( + lm_type=params.nnlm_type, + params=params, + device=device, + lm_scale=params.nnlm_scale, + ) + NNLM.to(device) + NNLM.eval() + + LODR_lm = None + if ( + params.decoding_method == "ctc-prefix-beam-search-shallow-fussion" + and params.lodr_lm_scale != 0 + ): + assert os.path.exists( + params.lodr_ngram + ), f"LODR ngram does not exists, given path : {params.lodr_ngram}" + logging.info(f"Loading LODR (token level lm): {params.lodr_ngram}") + LODR_lm = NgramLm( + params.lodr_ngram, + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {LODR_lm.lm.num_states}") + + context_graph = None + if ( + params.decoding_method == "ctc-prefix-beam-search-shallow-fussion" + and params.context_score != 0 + ): + assert os.path.exists( + params.context_file + ), f"context_file does not exists, given path : {params.context_file}" + contexts = [] + for line in open(params.context_file).readlines(): + contexts.append(bpe_model.encode(line.strip())) + context_graph = ContextGraph(params.context_score) + context_graph.build(contexts) + logging.info("About to create model") model = get_model(params) @@ -967,6 +1156,9 @@ def main(): bpe_model=bpe_model, word_table=lexicon.word_table, G=G, + NNLM=NNLM, + LODR_lm=LODR_lm, + context_graph=context_graph, ) save_asr_output( diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py index bd1ed26d8d..2de1e08fee 100644 --- a/egs/librispeech/ASR/zipformer/model.py +++ b/egs/librispeech/ASR/zipformer/model.py @@ -25,6 +25,7 @@ from scaling import ScaledLinear from icefall.utils import add_sos, make_pad_mask +from spec_augment import SpecAugment, time_warp class AsrModel(nn.Module): @@ -181,6 +182,63 @@ def forward_ctc( ) return ctc_loss + def forward_cr_ctc( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + targets: torch.Tensor, + target_lengths: torch.Tensor, + time_mask: Optional[torch.Tensor] = None, + cr_loss_masked_scale: float = 1.0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute CTC loss with consistency regularization loss. + Args: + encoder_out: + Encoder output, of shape (2 * N, T, C). + encoder_out_lens: + Encoder output lengths, of shape (2 * N,). + targets: + Target Tensor of shape (2 * sum(target_lengths)). The targets are assumed + to be un-padded and concatenated within 1 dimension. + time_mask: + Downsampled time masks of shape (2 * N, T, 1). + cr_loss_masked_scale: + The loss scale used to scale up the cr_loss at masked positions. 
+ """ + # Compute CTC loss + ctc_output = self.ctc_output(encoder_out) # (2 * N, T, C) + ctc_loss = torch.nn.functional.ctc_loss( + log_probs=ctc_output.permute(1, 0, 2), # (T, 2 * N, C) + targets=targets.cpu(), + input_lengths=encoder_out_lens.cpu(), + target_lengths=target_lengths.cpu(), + reduction="sum", + ) + + # Compute consistency regularization loss + exchanged_targets = ctc_output.detach().chunk(2, dim=0) + exchanged_targets = torch.cat( + [exchanged_targets[1], exchanged_targets[0]], dim=0 + ) # exchange: [x1, x2] -> [x2, x1] + cr_loss = nn.functional.kl_div( + input=ctc_output, + target=exchanged_targets, + reduction="none", + log_target=True, + ) # (2 * N, T, C) + if time_mask is not None: + assert time_mask.shape[:-1] == ctc_output.shape[:-1], ( + time_mask.shape, ctc_output.shape + ) + masked_scale = time_mask * (cr_loss_masked_scale - 1) + 1 + # e.g., if cr_loss_masked_scale = 3, scales at masked positions are 3, + # scales at unmasked positions are 1 + cr_loss = cr_loss * masked_scale # scaling up masked positions + length_mask = make_pad_mask(encoder_out_lens).unsqueeze(-1) + cr_loss = cr_loss.masked_fill(length_mask, 0.0).sum() + + return ctc_loss, cr_loss + def forward_transducer( self, encoder_out: torch.Tensor, @@ -296,7 +354,13 @@ def forward( prune_range: int = 5, am_scale: float = 0.0, lm_scale: float = 0.0, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + use_cr_ctc: bool = False, + use_spec_aug: bool = False, + spec_augment: Optional[SpecAugment] = None, + supervision_segments: Optional[torch.Tensor] = None, + time_warp_factor: Optional[int] = 80, + cr_loss_masked_scale: float = 1.0, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Args: x: @@ -316,9 +380,28 @@ def forward( lm_scale: The scale to smooth the loss with lm (output of predictor network) part + use_cr_ctc: + Whether use consistency-regularized CTC. + use_spec_aug: + Whether apply spec-augment manually, used only if use_cr_ctc is True. + spec_augment: + The SpecAugment instance that returns time masks, + used only if use_cr_ctc is True. + supervision_segments: + An int tensor of shape ``(S, 3)``. ``S`` is the number of + supervision segments that exist in ``features``. + Used only if use_cr_ctc is True. + time_warp_factor: + Parameter for the time warping; larger values mean more warping. + Set to ``None``, or less than ``1``, to disable. + Used only if use_cr_ctc is True. + cr_loss_masked_scale: + The loss scale used to scale up the cr_loss at masked positions. 
+ Returns: - Return the transducer losses and CTC loss, - in form of (simple_loss, pruned_loss, ctc_loss, attention_decoder_loss) + Return the transducer losses, CTC loss, AED loss, + and consistency-regularization loss in form of + (simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss) Note: Regarding am_scale & lm_scale, it will make the loss-function one of @@ -334,6 +417,27 @@ def forward( device = x.device + if use_cr_ctc: + assert self.use_ctc + if use_spec_aug: + assert spec_augment is not None and spec_augment.time_warp_factor < 1 + # Apply time warping before input duplicating + assert supervision_segments is not None + x = time_warp( + x, + time_warp_factor=time_warp_factor, + supervision_segments=supervision_segments, + ) + # Independently apply frequency masking and time masking to the two copies + x, time_mask = spec_augment(x.repeat(2, 1, 1)) + # time_mask: 1 for masked, 0 for unmasked + time_mask = downsample_time_mask(time_mask, x.dtype) + else: + x = x.repeat(2, 1, 1) + time_mask = None + x_lens = x_lens.repeat(2) + y = k2.ragged.cat([y, y], axis=0) + # Compute encoder outputs encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens) @@ -351,6 +455,9 @@ def forward( am_scale=am_scale, lm_scale=lm_scale, ) + if use_cr_ctc: + simple_loss = simple_loss * 0.5 + pruned_loss = pruned_loss * 0.5 else: simple_loss = torch.empty(0) pruned_loss = torch.empty(0) @@ -358,14 +465,28 @@ def forward( if self.use_ctc: # Compute CTC loss targets = y.values - ctc_loss = self.forward_ctc( - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - targets=targets, - target_lengths=y_lens, - ) + if not use_cr_ctc: + ctc_loss = self.forward_ctc( + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + targets=targets, + target_lengths=y_lens, + ) + cr_loss = torch.empty(0) + else: + ctc_loss, cr_loss = self.forward_cr_ctc( + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + targets=targets, + target_lengths=y_lens, + time_mask=time_mask, + cr_loss_masked_scale=cr_loss_masked_scale, + ) + ctc_loss = ctc_loss * 0.5 + cr_loss = cr_loss * 0.5 else: ctc_loss = torch.empty(0) + cr_loss = torch.empty(0) if self.use_attention_decoder: attention_decoder_loss = self.attention_decoder.calc_att_loss( @@ -374,7 +495,37 @@ def forward( ys=y.to(device), ys_lens=y_lens.to(device), ) + if use_cr_ctc: + attention_decoder_loss = attention_decoder_loss * 0.5 else: attention_decoder_loss = torch.empty(0) - return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss + return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss + + +def downsample_time_mask(time_mask: torch.Tensor, dtype: torch.dtype): + """Downsample the time masks as in Zipformer. 
+ Args: + time_mask: shape of (N, T) + Returns: + The downsampled time masks of shape (N, T', 1), + where T' = ((T - 7) // 2 + 1) // 2 + """ + # Downsample the time masks as in Zipformer + time_mask = time_mask.to(dtype).unsqueeze(dim=1) + # as in conv-embed + time_mask = nn.functional.max_pool1d( + time_mask, kernel_size=3, stride=1, padding=0 + ) # T - 2 + time_mask = nn.functional.max_pool1d( + time_mask, kernel_size=3, stride=2, padding=0 + ) # (T - 3) // 2 + time_mask = nn.functional.max_pool1d( + time_mask, kernel_size=3, stride=1, padding=0 + ) # (T - 7) // 2 + # as in output-downsampling + time_mask = nn.functional.max_pool1d( + time_mask, kernel_size=2, stride=2, padding=0, ceil_mode=True + ) + time_mask = time_mask.transpose(1, 2) # (N * 2, T', 1) + return time_mask diff --git a/egs/librispeech/ASR/zipformer/spec_augment.py b/egs/librispeech/ASR/zipformer/spec_augment.py new file mode 100644 index 0000000000..6ddf2b09bd --- /dev/null +++ b/egs/librispeech/ASR/zipformer/spec_augment.py @@ -0,0 +1,313 @@ +# Copyright 2024 Xiaomi Corp. (authors: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Copied from https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/signal_transforms.py +# with minor modification for cr-ctc training. + + +import math +import random +from typing import Any, Dict, Optional, Tuple + +import torch +from lhotse.dataset.signal_transforms import time_warp as time_warp_impl + + +class SpecAugment(torch.nn.Module): + """SpecAugment from lhotse with minor modification, returning time masks. + + SpecAugment performs three augmentations: + - time warping of the feature matrix + - masking of ranges of features (frequency bands) + - masking of ranges of frames (time) + + The current implementation works with batches, but processes each example separately + in a loop rather than simultaneously to achieve different augmentation parameters for + each example. + """ + + def __init__( + self, + time_warp_factor: Optional[int] = 80, + num_feature_masks: int = 2, + features_mask_size: int = 27, + num_frame_masks: int = 10, + frames_mask_size: int = 100, + max_frames_mask_fraction: float = 0.15, + p=0.9, + ): + """ + SpecAugment's constructor. + + :param time_warp_factor: parameter for the time warping; larger values mean more warping. + Set to ``None``, or less than ``1``, to disable. + :param num_feature_masks: how many feature masks should be applied. Set to ``0`` to disable. + :param features_mask_size: the width of the feature mask (expressed in the number of masked feature bins). + This is the ``F`` parameter from the SpecAugment paper. + :param num_frame_masks: the number of masking regions for utterances. Set to ``0`` to disable. + :param frames_mask_size: the width of the frame (temporal) masks (expressed in the number of masked frames). + This is the ``T`` parameter from the SpecAugment paper. 
+ :param max_frames_mask_fraction: limits the size of the frame (temporal) mask to this value times the length + of the utterance (or supervision segment). + This is the parameter denoted by ``p`` in the SpecAugment paper. + :param p: the probability of applying this transform. + It is different from ``p`` in the SpecAugment paper! + """ + super().__init__() + assert 0 <= p <= 1 + assert num_feature_masks >= 0 + assert num_frame_masks >= 0 + assert features_mask_size > 0 + assert frames_mask_size > 0 + self.time_warp_factor = time_warp_factor + self.num_feature_masks = num_feature_masks + self.features_mask_size = features_mask_size + self.num_frame_masks = num_frame_masks + self.frames_mask_size = frames_mask_size + self.max_frames_mask_fraction = max_frames_mask_fraction + self.p = p + + def forward( + self, + features: torch.Tensor, + supervision_segments: Optional[torch.IntTensor] = None, + *args, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Computes SpecAugment for a batch of feature matrices. + + Since the batch will usually already be padded, the user can optionally + provide a ``supervision_segments`` tensor that will be used to apply SpecAugment + only to selected areas of the input. The format of this input is described below. + + :param features: a batch of feature matrices with shape ``(B, T, F)``. + :param supervision_segments: an int tensor of shape ``(S, 3)``. ``S`` is the number of + supervision segments that exist in ``features`` -- there may be either + fewer or more than the batch size. + The second dimension encodes three kinds of information: + the sequence index of the corresponding feature matrix in `features`, + the start frame index, and the number of frames for each segment. + :return: + - an augmented tensor of shape ``(B, T, F)``. + - the corresponding time masks of shape ``(B, T)``. + """ + assert len(features.shape) == 3, ( + "SpecAugment only supports batches of " "single-channel feature matrices." + ) + features = features.clone() + + time_masks = [] + + if supervision_segments is None: + # No supervisions - apply spec augment to full feature matrices. + for sequence_idx in range(features.size(0)): + masked_feature, time_mask = self._forward_single(features[sequence_idx]) + features[sequence_idx] = masked_feature + time_masks.append(time_mask) + else: + # Supervisions provided - we will apply time warping only on the supervised areas. + for sequence_idx, start_frame, num_frames in supervision_segments: + end_frame = start_frame + num_frames + warped_feature, _ = self._forward_single( + features[sequence_idx, start_frame:end_frame], warp=True, mask=False + ) + features[sequence_idx, start_frame:end_frame] = warped_feature + # ... and then time-mask the full feature matrices. Note that in this mode, + # it might happen that masks are applied to different sequences/examples + # than the time warping. + for sequence_idx in range(features.size(0)): + masked_feature, time_mask = self._forward_single( + features[sequence_idx], warp=False, mask=True + ) + features[sequence_idx] = masked_feature + time_masks.append(time_mask) + + time_masks = torch.cat(time_masks, dim=0) + assert time_masks.shape == features.shape[:-1], (time_masks.shape, features.shape[:-1]) + return features, time_masks + + def _forward_single( + self, features: torch.Tensor, warp: bool = True, mask: bool = True + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply SpecAugment to a single feature matrix of shape (T, F).
+ """ + if random.random() > self.p: + # Randomly choose whether this transform is applied + time_mask = torch.zeros( + 1, features.size(0), dtype=torch.bool, device=features.device + ) + return features, time_mask + + time_mask = None + if warp: + if self.time_warp_factor is not None and self.time_warp_factor >= 1: + features = time_warp_impl(features, factor=self.time_warp_factor) + + if mask: + mean = features.mean() + # Frequency masking + features, _ = mask_along_axis_optimized( + features, + mask_size=self.features_mask_size, + mask_times=self.num_feature_masks, + mask_value=mean, + axis=2, + ) + # Time masking + max_tot_mask_frames = self.max_frames_mask_fraction * features.size(0) + num_frame_masks = min( + self.num_frame_masks, + math.ceil(max_tot_mask_frames / self.frames_mask_size), + ) + max_mask_frames = min( + self.frames_mask_size, max_tot_mask_frames // num_frame_masks + ) + features, time_mask = mask_along_axis_optimized( + features, + mask_size=max_mask_frames, + mask_times=num_frame_masks, + mask_value=mean, + axis=1, + return_time_mask=True, + ) + + return features, time_mask + + def state_dict(self, **kwargs) -> Dict[str, Any]: + return dict( + time_warp_factor=self.time_warp_factor, + num_feature_masks=self.num_feature_masks, + features_mask_size=self.features_mask_size, + num_frame_masks=self.num_frame_masks, + frames_mask_size=self.frames_mask_size, + max_frames_mask_fraction=self.max_frames_mask_fraction, + p=self.p, + ) + + def load_state_dict(self, state_dict: Dict[str, Any]): + self.time_warp_factor = state_dict.get( + "time_warp_factor", self.time_warp_factor + ) + self.num_feature_masks = state_dict.get( + "num_feature_masks", self.num_feature_masks + ) + self.features_mask_size = state_dict.get( + "features_mask_size", self.features_mask_size + ) + self.num_frame_masks = state_dict.get("num_frame_masks", self.num_frame_masks) + self.frames_mask_size = state_dict.get( + "frames_mask_size", self.frames_mask_size + ) + self.max_frames_mask_fraction = state_dict.get( + "max_frames_mask_fraction", self.max_frames_mask_fraction + ) + self.p = state_dict.get("p", self.p) + + +def mask_along_axis_optimized( + features: torch.Tensor, + mask_size: int, + mask_times: int, + mask_value: float, + axis: int, + return_time_mask: bool = False, +) -> torch.Tensor: + """ + Apply Frequency and Time masking along axis. + Frequency and Time masking as described in the SpecAugment paper. + + :param features: input tensor of shape ``(T, F)`` + :mask_size: the width size for masking. + :mask_times: the number of masking regions. + :mask_value: Value to assign to the masked regions. 
+ :axis: Axis to apply masking on (1 -> time, 2 -> frequency) + :return_time_mask: Whether return the time mask of shape ``(1, T)`` + """ + if axis not in [1, 2]: + raise ValueError("Only Frequency and Time masking are supported!") + + if return_time_mask and axis == 1: + time_mask = torch.zeros( + 1, features.size(0), dtype=torch.bool, device=features.device + ) + else: + time_mask = None + + features = features.unsqueeze(0) + features = features.reshape([-1] + list(features.size()[-2:])) + + values = torch.randint(int(0), int(mask_size), (1, mask_times)) + min_values = torch.rand(1, mask_times) * (features.size(axis) - values) + mask_starts = (min_values.long()).squeeze() + mask_ends = (min_values.long() + values.long()).squeeze() + + if axis == 1: + if mask_times == 1: + features[:, mask_starts:mask_ends] = mask_value + if return_time_mask: + time_mask[:, mask_starts:mask_ends] = True + return features.squeeze(0), time_mask + for (mask_start, mask_end) in zip(mask_starts, mask_ends): + features[:, mask_start:mask_end] = mask_value + if return_time_mask: + time_mask[:, mask_start:mask_end] = True + else: + if mask_times == 1: + features[:, :, mask_starts:mask_ends] = mask_value + return features.squeeze(0), time_mask + for (mask_start, mask_end) in zip(mask_starts, mask_ends): + features[:, :, mask_start:mask_end] = mask_value + + features = features.squeeze(0) + return features, time_mask + + +def time_warp( + features: torch.Tensor, + p: float = 0.9, + time_warp_factor: Optional[int] = 80, + supervision_segments: Optional[torch.Tensor] = None, +): + if time_warp_factor is None or time_warp_factor < 1: + return features + assert len(features.shape) == 3, ( + "SpecAugment only supports batches of single-channel feature matrices." + ) + features = features.clone() + if supervision_segments is None: + # No supervisions - apply spec augment to full feature matrices. + for sequence_idx in range(features.size(0)): + if random.random() > p: + # Randomly choose whether this transform is applied + continue + features[sequence_idx] = time_warp_impl( + features[sequence_idx], factor=time_warp_factor + ) + else: + # Supervisions provided - we will apply time warping only on the supervised areas. 
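Stepping back from the masking helpers above: with cr-ctc the recipe doubles the time-masking ratio (time_mask_ratio=2.0 gives 20 frame masks and a 0.3 mask fraction), but `_forward_single` still caps the total number of masked frames through `max_frames_mask_fraction`. A short worked example of that budget; the utterance length T=1000 is only an illustrative number:

import math

# Budget arithmetic mirroring SpecAugment._forward_single (values are illustrative).
T = 1000                          # frames in one utterance (example value)
frames_mask_size = 100            # the "T" parameter of SpecAugment
num_frame_masks = 20              # 10 * time_mask_ratio, with time_mask_ratio = 2.0
max_frames_mask_fraction = 0.3    # 0.15 * time_mask_ratio

max_tot_mask_frames = max_frames_mask_fraction * T                                    # 300.0
num_masks = min(num_frame_masks, math.ceil(max_tot_mask_frames / frames_mask_size))   # 3
max_mask_frames = min(frames_mask_size, max_tot_mask_frames // num_masks)             # 100.0
print(num_masks, max_mask_frames)  # at most 3 masks, each up to 100 frames wide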
+ for sequence_idx, start_frame, num_frames in supervision_segments: + if random.random() > p: + # Randomly choose whether this transform is applied + continue + end_frame = start_frame + num_frames + features[sequence_idx, start_frame:end_frame] = time_warp_impl( + features[sequence_idx, start_frame:end_frame], factor=time_warp_factor + ) + + return features diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 9c1c7f5a78..3fde55de24 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -102,6 +102,7 @@ setup_logger, str2bool, ) +from spec_augment import SpecAugment LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -304,6 +305,13 @@ def add_model_arguments(parser: argparse.ArgumentParser): help="If True, use attention-decoder head.", ) + parser.add_argument( + "--use-cr-ctc", + type=str2bool, + default=False, + help="If True, use consistency-regularized CTC.", + ) + def get_parser(): parser = argparse.ArgumentParser( @@ -449,6 +457,27 @@ def get_parser(): help="Scale for CTC loss.", ) + parser.add_argument( + "--cr-loss-scale", + type=float, + default=0.15, + help="Scale for consistency-regularization loss.", + ) + + parser.add_argument( + "--time-mask-ratio", + type=float, + default=2.0, + help="When using cr-ctc, we increase the time-masking ratio.", + ) + + parser.add_argument( + "--cr-loss-masked-scale", + type=float, + default=1.0, + help="The value used to scale up the cr_loss at masked positions", + ) + parser.add_argument( "--attention-decoder-loss-scale", type=float, @@ -717,6 +746,24 @@ def get_model(params: AttributeDict) -> nn.Module: return model +def get_spec_augment(params: AttributeDict) -> SpecAugment: + num_frame_masks = int(10 * params.time_mask_ratio) + max_frames_mask_fraction = 0.15 * params.time_mask_ratio + logging.info( + f"num_frame_masks: {num_frame_masks}, " + f"max_frames_mask_fraction: {max_frames_mask_fraction}" + ) + spec_augment = SpecAugment( + time_warp_factor=0, # Do time warping in model.py + num_frame_masks=num_frame_masks, # default: 10 + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + max_frames_mask_fraction=max_frames_mask_fraction, # default: 0.15 + ) + return spec_augment + + def load_checkpoint_if_available( params: AttributeDict, model: nn.Module, @@ -839,6 +886,7 @@ def compute_loss( sp: spm.SentencePieceProcessor, batch: dict, is_training: bool, + spec_augment: Optional[SpecAugment] = None, ) -> Tuple[Tensor, MetricsTracker]: """ Compute loss given the model and its inputs. @@ -855,8 +903,8 @@ def compute_loss( True for training. False for validation. When it is True, this function enables autograd during computation; when it is False, it disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. + spec_augment: + The SpecAugment instance used only when use_cr_ctc is True. 
""" device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] @@ -874,14 +922,35 @@ def compute_loss( y = sp.encode(texts, out_type=int) y = k2.RaggedTensor(y) + use_cr_ctc = params.use_cr_ctc + use_spec_aug = use_cr_ctc and is_training + if use_spec_aug: + supervision_intervals = batch["supervisions"] + supervision_segments = torch.stack( + [ + supervision_intervals["sequence_idx"], + supervision_intervals["start_frame"], + supervision_intervals["num_frames"], + ], + dim=1, + ) # shape: (S, 3) + else: + supervision_segments = None + with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, ctc_loss, attention_decoder_loss = model( + simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss = model( x=feature, x_lens=feature_lens, y=y, prune_range=params.prune_range, am_scale=params.am_scale, lm_scale=params.lm_scale, + use_cr_ctc=use_cr_ctc, + use_spec_aug=use_spec_aug, + spec_augment=spec_augment, + supervision_segments=supervision_segments, + time_warp_factor=params.spec_aug_time_warp_factor, + cr_loss_masked_scale=params.cr_loss_masked_scale, ) loss = 0.0 @@ -904,6 +973,8 @@ def compute_loss( if params.use_ctc: loss += params.ctc_loss_scale * ctc_loss + if use_cr_ctc: + loss += params.cr_loss_scale * cr_loss if params.use_attention_decoder: loss += params.attention_decoder_loss_scale * attention_decoder_loss @@ -922,6 +993,8 @@ def compute_loss( info["pruned_loss"] = pruned_loss.detach().cpu().item() if params.use_ctc: info["ctc_loss"] = ctc_loss.detach().cpu().item() + if params.use_cr_ctc: + info["cr_loss"] = cr_loss.detach().cpu().item() if params.use_attention_decoder: info["attn_decoder_loss"] = attention_decoder_loss.detach().cpu().item() @@ -971,6 +1044,7 @@ def train_one_epoch( train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, scaler: GradScaler, + spec_augment: Optional[SpecAugment] = None, model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -997,6 +1071,8 @@ def train_one_epoch( Dataloader for the validation dataset. scaler: The scaler used for mix precision training. + spec_augment: + The SpecAugment instance used only when use_cr_ctc is True. model_avg: The stored model averaged from the start of training. 
tb_writer: @@ -1043,6 +1119,7 @@ def save_bad_model(suffix: str = ""): sp=sp, batch=batch, is_training=True, + spec_augment=spec_augment, ) # summary stats tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info @@ -1238,6 +1315,13 @@ def run(rank, world_size, args): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") + if params.use_cr_ctc: + assert params.use_ctc + assert not params.enable_spec_aug # we will do spec_augment in model.py + spec_augment = get_spec_augment(params) + else: + spec_augment = None + assert params.save_every_n >= params.average_period model_avg: Optional[nn.Module] = None if rank == 0: @@ -1360,6 +1444,7 @@ def remove_short_and_long_utt(c: Cut): optimizer=optimizer, sp=sp, params=params, + spec_augment=spec_augment, ) scaler = GradScaler(enabled=params.use_autocast, init_scale=1.0) @@ -1387,6 +1472,7 @@ def remove_short_and_long_utt(c: Cut): train_dl=train_dl, valid_dl=valid_dl, scaler=scaler, + spec_augment=spec_augment, tb_writer=tb_writer, world_size=world_size, rank=rank, @@ -1452,6 +1538,7 @@ def scan_pessimistic_batches_for_oom( optimizer: torch.optim.Optimizer, sp: spm.SentencePieceProcessor, params: AttributeDict, + spec_augment: Optional[SpecAugment] = None, ): from lhotse.dataset import find_pessimistic_batches @@ -1471,6 +1558,7 @@ def scan_pessimistic_batches_for_oom( sp=sp, batch=batch, is_training=True, + spec_augment=spec_augment, ) loss.backward() optimizer.zero_grad() diff --git a/icefall/decode.py b/icefall/decode.py index dd3af1e99b..6b642c94d8 100644 --- a/icefall/decode.py +++ b/icefall/decode.py @@ -1,4 +1,5 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -15,11 +16,18 @@ # limitations under the License. import logging -from typing import Dict, List, Optional, Union +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Tuple, Union import k2 import torch +from multiprocessing.pool import Pool + +from icefall.context_graph import ContextGraph, ContextState +from icefall.ngram_lm import NgramLm, NgramLmStateCost +from icefall.lm_wrapper import LmScorer + from icefall.utils import add_eos, add_sos, get_texts DEFAULT_LM_SCALE = [ @@ -1497,3 +1505,684 @@ def ctc_greedy_search( hyps = [h[h != blank_id].tolist() for h in hyps] return hyps + + +@dataclass +class Hypothesis: + # The predicted tokens so far. + # Newly predicted tokens are appended to `ys`. + ys: List[int] = field(default_factory=list) + + # The log prob of ys that ends with blank token. + # It contains only one entry. + log_prob_blank: torch.Tensor = torch.zeros(1, dtype=torch.float32) + + # The log prob of ys that ends with non blank token. + # It contains only one entry. + log_prob_non_blank: torch.Tensor = torch.tensor( + [float("-inf")], dtype=torch.float32 + ) + + # timestamp[i] is the frame index after subsampling + # on which ys[i] is decoded + timestamp: List[int] = field(default_factory=list) + + # The lm score of ys + # May contain external LM score (including LODR score) and contextual biasing score + # It contains only one entry + lm_score: torch.Tensor = torch.zeros(1, dtype=torch.float32) + + # the lm log_probs for next token given the history ys + # The number of elements should be equal to vocabulary size. 
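An aside on the two log-probability fields above: they carry the usual CTC prefix-search bookkeeping, where a prefix's total acoustic score is the log-sum-exp of its blank-ending and non-blank-ending scores, which is exactly what the `log_prob` property defined a few lines below computes. A tiny numeric illustration with made-up values:

import torch

log_prob_blank = torch.tensor([-1.2])      # log P(prefix, last frame was blank)
log_prob_non_blank = torch.tensor([-0.7])  # log P(prefix, last frame was a token)
log_prob = torch.logaddexp(log_prob_non_blank, log_prob_blank)
print(log_prob)  # log(exp(-0.7) + exp(-1.2)) ≈ -0.226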
+ lm_log_probs: Optional[torch.Tensor] = None + + # the RNNLM states (h and c in LSTM) + state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None + + # LODR (N-gram LM) state + LODR_state: Optional[NgramLmStateCost] = None + + # N-gram LM state + Ngram_state: Optional[NgramLmStateCost] = None + + # Context graph state + context_state: Optional[ContextState] = None + + # This is the total score of current path, acoustic plus external LM score. + @property + def tot_score(self) -> torch.Tensor: + return self.log_prob + self.lm_score + + # This is only the probability from model output (i.e External LM score not included). + @property + def log_prob(self) -> torch.Tensor: + return torch.logaddexp(self.log_prob_non_blank, self.log_prob_blank) + + @property + def key(self) -> tuple: + """Return a tuple representation of self.ys""" + return tuple(self.ys) + + def clone(self) -> "Hypothesis": + return Hypothesis( + ys=self.ys, + log_prob_blank=self.log_prob_blank, + log_prob_non_blank=self.log_prob_non_blank, + timestamp=self.timestamp, + lm_log_probs=self.lm_log_probs, + lm_score=self.lm_score, + state=self.state, + LODR_state=self.LODR_state, + Ngram_state=self.Ngram_state, + context_state=self.context_state, + ) + + +class HypothesisList(object): + def __init__(self, data: Optional[Dict[tuple, Hypothesis]] = None) -> None: + """ + Args: + data: + A dict of Hypotheses. Its key is its `value.key`. + """ + if data is None: + self._data = {} + else: + self._data = data + + @property + def data(self) -> Dict[tuple, Hypothesis]: + return self._data + + def add(self, hyp: Hypothesis) -> None: + """Add a Hypothesis to `self`. + + If `hyp` already exists in `self`, its probability is updated using + `log-sum-exp` with the existed one. + + Args: + hyp: + The hypothesis to be added. + """ + key = hyp.key + if key in self: + old_hyp = self._data[key] # shallow copy + torch.logaddexp( + old_hyp.log_prob_blank, hyp.log_prob_blank, out=old_hyp.log_prob_blank + ) + torch.logaddexp( + old_hyp.log_prob_non_blank, + hyp.log_prob_non_blank, + out=old_hyp.log_prob_non_blank, + ) + else: + self._data[key] = hyp + + def get_most_probable(self, length_norm: bool = False) -> Hypothesis: + """Get the most probable hypothesis, i.e., the one with + the largest `tot_score`. + + Args: + length_norm: + If True, the `tot_score` of a hypothesis is normalized by the + number of tokens in it. + Returns: + Return the hypothesis that has the largest `tot_score`. + """ + if length_norm: + return max(self._data.values(), key=lambda hyp: hyp.tot_score / len(hyp.ys)) + else: + return max(self._data.values(), key=lambda hyp: hyp.tot_score) + + def remove(self, hyp: Hypothesis) -> None: + """Remove a given hypothesis. + + Caution: + `self` is modified **in-place**. + + Args: + hyp: + The hypothesis to be removed from `self`. + Note: It must be contained in `self`. Otherwise, + an exception is raised. + """ + key = hyp.key + assert key in self, f"{key} does not exist" + del self._data[key] + + def filter(self, threshold: torch.Tensor) -> "HypothesisList": + """Remove all Hypotheses whose tot_score is less than threshold. + + Caution: + `self` is not modified. Instead, a new HypothesisList is returned. + + Returns: + Return a new HypothesisList containing all hypotheses from `self` + with `tot_score` being greater than the given `threshold`. 
+ """ + ans = HypothesisList() + for _, hyp in self._data.items(): + if hyp.tot_score > threshold: + ans.add(hyp) # shallow copy + return ans + + def topk(self, k: int, length_norm: bool = False) -> "HypothesisList": + """Return the top-k hypothesis. + + Args: + length_norm: + If True, the `tot_score` of a hypothesis is normalized by the + number of tokens in it. + """ + hyps = list(self._data.items()) + + if length_norm: + hyps = sorted( + hyps, key=lambda h: h[1].tot_score / len(h[1].ys), reverse=True + )[:k] + else: + hyps = sorted(hyps, key=lambda h: h[1].tot_score, reverse=True)[:k] + + ans = HypothesisList(dict(hyps)) + return ans + + def __contains__(self, key: tuple): + return key in self._data + + def __getitem__(self, key: tuple): + return self._data[key] + + def __iter__(self): + return iter(self._data.values()) + + def __len__(self) -> int: + return len(self._data) + + def __str__(self) -> str: + s = [] + for key in self: + s.append(key) + return ", ".join(str(s)) + + +def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape: + """Return a ragged shape with axes [utt][num_hyps]. + + Args: + hyps: + len(hyps) == batch_size. It contains the current hypothesis for + each utterance in the batch. + Returns: + Return a ragged shape with 2 axes [utt][num_hyps]. Note that + the shape is on CPU. + """ + num_hyps = [len(h) for h in hyps] + + # torch.cumsum() is inclusive sum, so we put a 0 at the beginning + # to get exclusive sum later. + num_hyps.insert(0, 0) + + num_hyps = torch.tensor(num_hyps) + row_splits = torch.cumsum(num_hyps, dim=0, dtype=torch.int32) + ans = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=row_splits[-1].item() + ) + return ans + + +def _step_worker( + log_probs: torch.Tensor, + indexes: torch.Tensor, + B: HypothesisList, + beam: int = 4, + blank_id: int = 0, + nnlm_scale: float = 0, + LODR_lm_scale: float = 0, + context_graph: Optional[ContextGraph] = None, +) -> HypothesisList: + """The worker to decode one step. + + Args: + log_probs: + topk log_probs of current step (i.e. the kept tokens of first pass pruning), + the shape is (beam,) + topk_indexes: + The indexes of the topk_values above, the shape is (beam,) + B: + An instance of HypothesisList containing the kept hypothesis. + beam: + The number of hypothesis to be kept at each step. + blank_id: + The id of blank in the vocabulary. + lm_scale: + The scale of nn lm. + LODR_lm_scale: + The scale of the LODR_lm + context_graph: + A ContextGraph instance containing contextual phrases. + + Return: + Returns the updated HypothesisList. 
+ """ + A = list(B) + B = HypothesisList() + for h in range(len(A)): + hyp = A[h] + for k in range(log_probs.size(0)): + log_prob, index = log_probs[k], indexes[k] + new_token = index.item() + update_prefix = False + new_hyp = hyp.clone() + if new_token == blank_id: + # Case 0: *a + ε => *a + # *aε + ε => *a + # Prefix does not change, update log_prob of blank + new_hyp.log_prob_non_blank = torch.tensor( + [float("-inf")], dtype=torch.float32 + ) + new_hyp.log_prob_blank = hyp.log_prob + log_prob + B.add(new_hyp) + elif len(hyp.ys) > 0 and hyp.ys[-1] == new_token: + # Case 1: *a + a => *a + # Prefix does not change, update log_prob of non_blank + new_hyp.log_prob_non_blank = hyp.log_prob_non_blank + log_prob + new_hyp.log_prob_blank = torch.tensor( + [float("-inf")], dtype=torch.float32 + ) + B.add(new_hyp) + + # Case 2: *aε + a => *aa + # Prefix changes, update log_prob of blank + new_hyp = hyp.clone() + # Caution: DO NOT use append, as clone is shallow copy + new_hyp.ys = hyp.ys + [new_token] + new_hyp.log_prob_non_blank = hyp.log_prob_blank + log_prob + new_hyp.log_prob_blank = torch.tensor( + [float("-inf")], dtype=torch.float32 + ) + update_prefix = True + else: + # Case 3: *a + b => *ab, *aε + b => *ab + # Prefix changes, update log_prob of non_blank + # Caution: DO NOT use append, as clone is shallow copy + new_hyp.ys = hyp.ys + [new_token] + new_hyp.log_prob_non_blank = hyp.log_prob + log_prob + new_hyp.log_prob_blank = torch.tensor( + [float("-inf")], dtype=torch.float32 + ) + update_prefix = True + + if update_prefix: + lm_score = hyp.lm_score + if hyp.lm_log_probs is not None: + lm_score = lm_score + hyp.lm_log_probs[new_token] * nnlm_scale + new_hyp.lm_log_probs = None + + if context_graph is not None and hyp.context_state is not None: + ( + context_score, + new_context_state, + matched_state, + ) = context_graph.forward_one_step(hyp.context_state, new_token) + lm_score = lm_score + context_score + new_hyp.context_state = new_context_state + + if hyp.LODR_state is not None: + state_cost = hyp.LODR_state.forward_one_step(new_token) + # calculate the score of the latest token + current_ngram_score = state_cost.lm_score - hyp.LODR_state.lm_score + assert current_ngram_score <= 0.0, ( + state_cost.lm_score, + hyp.LODR_state.lm_score, + ) + lm_score = lm_score + LODR_lm_scale * current_ngram_score + new_hyp.LODR_state = state_cost + + new_hyp.lm_score = lm_score + B.add(new_hyp) + B = B.topk(beam) + return B + + +def _sequence_worker( + topk_values: torch.Tensor, + topk_indexes: torch.Tensor, + B: HypothesisList, + encoder_out_lens: torch.Tensor, + beam: int = 4, + blank_id: int = 0, +) -> HypothesisList: + """The worker to decode one sequence. + + Args: + topk_values: + topk log_probs of model output (i.e. the kept tokens of first pass pruning), + the shape is (T, beam) + topk_indexes: + The indexes of the topk_values above, the shape is (T, beam) + B: + An instance of HypothesisList containing the kept hypothesis. + encoder_out_lens: + The lengths (frames) of sequences after subsampling, the shape is (B,) + beam: + The number of hypothesis to be kept at each step. + blank_id: + The id of blank in the vocabulary. + + Return: + Returns the updated HypothesisList. 
+ """ + B.add(Hypothesis()) + for j in range(encoder_out_lens): + log_probs, indexes = topk_values[j], topk_indexes[j] + B = _step_worker(log_probs, indexes, B, beam, blank_id) + return B + + +def ctc_prefix_beam_search( + ctc_output: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: int = 4, + blank_id: int = 0, + process_pool: Optional[Pool] = None, + return_nbest: Optional[bool] = False, +) -> Union[List[List[int]], List[HypothesisList]]: + """Implement prefix search decoding in "Connectionist Temporal Classification: + Labelling Unsegmented Sequence Data with Recurrent Neural Networks". + + Args: + ctc_output: + The output of ctc head (log probability), the shape is (B, T, V) + encoder_out_lens: + The lengths (frames) of sequences after subsampling, the shape is (B,) + beam: + The number of hypothesis to be kept at each step. + blank_id: + The id of blank in the vocabulary. + process_pool: + The process pool for parallel decoding, if not provided, it will use all + you cpu cores by default. + return_nbest: + If true, return a list of HypothesisList, return a list of list of decoded token ids otherwise. + """ + batch_size, num_frames, vocab_size = ctc_output.shape + + # TODO: using a larger beam for first pass pruning + topk_values, topk_indexes = ctc_output.topk(beam) # (B, T, beam) + topk_values = topk_values.cpu() + topk_indexes = topk_indexes.cpu() + + B = [HypothesisList() for _ in range(batch_size)] + + pool = Pool() if process_pool is None else process_pool + arguments = [] + for i in range(batch_size): + arguments.append( + ( + topk_values[i], + topk_indexes[i], + B[i], + encoder_out_lens[i].item(), + beam, + blank_id, + ) + ) + async_results = pool.starmap_async(_sequence_worker, arguments) + B = list(async_results.get()) + if process_pool is None: + pool.close() + pool.join() + if return_nbest: + return B + else: + best_hyps = [b.get_most_probable() for b in B] + return [hyp.ys for hyp in best_hyps] + + +def ctc_prefix_beam_search_shallow_fussion( + ctc_output: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: int = 4, + blank_id: int = 0, + LODR_lm: Optional[NgramLm] = None, + LODR_lm_scale: Optional[float] = 0, + NNLM: Optional[LmScorer] = None, + context_graph: Optional[ContextGraph] = None, +) -> List[List[int]]: + """Implement prefix search decoding in "Connectionist Temporal Classification: + Labelling Unsegmented Sequence Data with Recurrent Neural Networks" and add + nervous language model shallow fussion, it also supports contextual + biasing with a given grammar. + + Args: + ctc_output: + The output of ctc head (log probability), the shape is (B, T, V) + encoder_out_lens: + The lengths (frames) of sequences after subsampling, the shape is (B,) + beam: + The number of hypothesis to be kept at each step. + blank_id: + The id of blank in the vocabulary. + LODR_lm: + A low order n-gram LM, whose score will be subtracted during shallow fusion + LODR_lm_scale: + The scale of the LODR_lm + LM: + A neural net LM, e.g an RNNLM or transformer LM + context_graph: + A ContextGraph instance containing contextual phrases. + + Return: + Returns a list of list of decoded token ids. 
+ """ + batch_size, num_frames, vocab_size = ctc_output.shape + # TODO: using a larger beam for first pass pruning + topk_values, topk_indexes = ctc_output.topk(beam) # (B, T, beam) + topk_values = topk_values.cpu() + topk_indexes = topk_indexes.cpu() + encoder_out_lens = encoder_out_lens.tolist() + device = ctc_output.device + + nnlm_scale = 0 + init_scores = None + init_states = None + if NNLM is not None: + nnlm_scale = NNLM.lm_scale + sos_id = getattr(NNLM, "sos_id", 1) + # get initial lm score and lm state by scoring the "sos" token + sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device) + lens = torch.tensor([1]).to(device) + init_scores, init_states = NNLM.score_token(sos_token, lens) + init_scores, init_states = init_scores.cpu(), ( + init_states[0].cpu(), + init_states[1].cpu(), + ) + + B = [HypothesisList() for _ in range(batch_size)] + for i in range(batch_size): + B[i].add( + Hypothesis( + ys=[], + log_prob_non_blank=torch.tensor([float("-inf")], dtype=torch.float32), + log_prob_blank=torch.zeros(1, dtype=torch.float32), + lm_score=torch.zeros(1, dtype=torch.float32), + state=init_states, + lm_log_probs=None if init_scores is None else init_scores.reshape(-1), + LODR_state=None if LODR_lm is None else NgramLmStateCost(LODR_lm), + context_state=None if context_graph is None else context_graph.root, + ) + ) + for j in range(num_frames): + for i in range(batch_size): + if j < encoder_out_lens[i]: + log_probs, indexes = topk_values[i][j], topk_indexes[i][j] + B[i] = _step_worker( + log_probs=log_probs, + indexes=indexes, + B=B[i], + beam=beam, + blank_id=blank_id, + nnlm_scale=nnlm_scale, + LODR_lm_scale=LODR_lm_scale, + context_graph=context_graph, + ) + if NNLM is None: + continue + # update lm_log_probs + token_list = [] # a list of list + hs = [] + cs = [] + indexes = [] # (batch_idx, key) + for batch_idx, hyps in enumerate(B): + for hyp in hyps: + if hyp.lm_log_probs is None: # those hyps that prefix changes + if NNLM.lm_type == "rnn": + token_list.append([hyp.ys[-1]]) + # store the LSTM states + hs.append(hyp.state[0]) + cs.append(hyp.state[1]) + else: + # for transformer LM + token_list.append([sos_id] + hyp.ys[:]) + indexes.append((batch_idx, hyp.key)) + if len(token_list) != 0: + x_lens = torch.tensor([len(tokens) for tokens in token_list]).to(device) + if NNLM.lm_type == "rnn": + tokens_to_score = ( + torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1) + ) + hs = torch.cat(hs, dim=1).to(device) + cs = torch.cat(cs, dim=1).to(device) + state = (hs, cs) + else: + # for transformer LM + tokens_list = [torch.tensor(tokens) for tokens in token_list] + tokens_to_score = ( + torch.nn.utils.rnn.pad_sequence( + tokens_list, batch_first=True, padding_value=0.0 + ) + .to(device) + .to(torch.int64) + ) + state = None + + scores, lm_states = NNLM.score_token(tokens_to_score, x_lens, state) + scores, lm_states = scores.cpu(), (lm_states[0].cpu(), lm_states[1].cpu()) + assert scores.size(0) == len(indexes), (scores.size(0), len(indexes)) + for i in range(scores.size(0)): + batch_idx, key = indexes[i] + B[batch_idx][key].lm_log_probs = scores[i] + if NNLM.lm_type == "rnn": + state = ( + lm_states[0][:, i, :].unsqueeze(1), + lm_states[1][:, i, :].unsqueeze(1), + ) + B[batch_idx][key].state = state + + # finalize context_state, if the matched contexts do not reach final state + # we need to add the score on the corresponding backoff arc + if context_graph is not None: + for hyps in B: + for hyp in hyps: + context_score, new_context_state = context_graph.finalize( + 
+                    hyp.context_state
+                )
+                hyp.lm_score += context_score
+                hyp.context_state = new_context_state
+
+    best_hyps = [b.get_most_probable() for b in B]
+    return [hyp.ys for hyp in best_hyps]
+
+
+def ctc_prefix_beam_search_attention_decoder_rescoring(
+    ctc_output: torch.Tensor,
+    attention_decoder: torch.nn.Module,
+    encoder_out: torch.Tensor,
+    encoder_out_lens: torch.Tensor,
+    beam: int = 8,
+    blank_id: int = 0,
+    attention_scale: Optional[float] = None,
+    process_pool: Optional[Pool] = None,
+):
+    """Implement prefix search decoding in "Connectionist Temporal Classification:
+    Labelling Unsegmented Sequence Data with Recurrent Neural Networks" and add
+    attention decoder rescoring.
+
+    Args:
+      ctc_output:
+        The output of the CTC head (log probabilities); the shape is (B, T, V).
+      attention_decoder:
+        The attention decoder.
+      encoder_out:
+        The output of the encoder; the shape is (B, T, D).
+      encoder_out_lens:
+        The lengths (frames) of sequences after subsampling; the shape is (B,).
+      beam:
+        The number of hypotheses to be kept at each step.
+      blank_id:
+        The id of blank in the vocabulary.
+      attention_scale:
+        The scale of the attention decoder score; if not provided, a default
+        list of scales is searched (see the code below).
+      process_pool:
+        The process pool for parallel decoding; if not provided, all your CPU
+        cores are used by default.
+    """
+    # List[HypothesisList]
+    nbest = ctc_prefix_beam_search(
+        ctc_output=ctc_output,
+        encoder_out_lens=encoder_out_lens,
+        beam=beam,
+        blank_id=blank_id,
+        process_pool=process_pool,
+        return_nbest=True,
+    )
+
+    device = ctc_output.device
+
+    hyp_shape = get_hyps_shape(nbest).to(device)
+    hyp_to_utt_map = hyp_shape.row_ids(1).to(torch.long)
+    # the shape of encoder_out is (N, T, C), so we use axis=0 here
+    expanded_encoder_out = encoder_out.index_select(0, hyp_to_utt_map)
+    expanded_encoder_out_lens = encoder_out_lens.index_select(0, hyp_to_utt_map)
+
+    nbest = [list(x) for x in nbest]
+    token_ids = []
+    scores = []
+    for hyps in nbest:
+        for hyp in hyps:
+            token_ids.append(hyp.ys)
+            scores.append(hyp.log_prob.reshape(1))
+    scores = torch.cat(scores).to(device)
+
+    nll = attention_decoder.nll(
+        encoder_out=expanded_encoder_out,
+        encoder_out_lens=expanded_encoder_out_lens,
+        token_ids=token_ids,
+    )
+    assert nll.ndim == 2
+    assert nll.shape[0] == len(token_ids)
+
+    attention_scores = -nll.sum(dim=1)
+
+    if attention_scale is None:
+        attention_scale_list = [0.01, 0.05, 0.08]
+        attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
+        attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
+        attention_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
+        attention_scale_list += [6.0, 7.0, 8.0, 9.0]
+    else:
+        attention_scale_list = [attention_scale]
+
+    ans = dict()
+
+    start_indexes = hyp_shape.row_splits(1)[0:-1]
+    for a_scale in attention_scale_list:
+        tot_scores = scores + a_scale * attention_scores
+        ragged_tot_scores = k2.RaggedTensor(hyp_shape, tot_scores)
+        max_indexes = ragged_tot_scores.argmax()
+        max_indexes = max_indexes - start_indexes
+        max_indexes = max_indexes.cpu()
+        best_path = [nbest[i][max_indexes[i]].ys for i in range(len(max_indexes))]
+        key = f"attention_scale_{a_scale}"
+        ans[key] = best_path
+    return ans
diff --git a/icefall/utils.py b/icefall/utils.py
index 1dbb954ded..1f72addf23 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -19,8 +19,10 @@
 import argparse
 import collections
+import json
 import logging
 import os
+import pathlib
 import re
 import subprocess
 from collections import defaultdict
@@ -178,6 +180,15 @@ def __delattr__(self, key):
             return
AttributeError(f"No such attribute '{key}'") + def __str__(self, indent: int = 2): + tmp = {} + for k, v in self.items(): + # PosixPath is ont JSON serializable + if isinstance(v, pathlib.Path) or isinstance(v, torch.device): + v = str(v) + tmp[k] = v + return json.dumps(tmp, indent=indent, sort_keys=True) + def encode_supervisions( supervisions: dict,