diff --git a/k2/python/k2/rnnt_loss.py b/k2/python/k2/rnnt_loss.py
index 7a0b25632..7e07271d4 100644
--- a/k2/python/k2/rnnt_loss.py
+++ b/k2/python/k2/rnnt_loss.py
@@ -43,6 +43,30 @@ def fix_for_boundary(px: Tensor, boundary: Optional[Tensor] = None) -> Tensor:
     return px.scatter_(dim=2, index=boundary, value=float("-inf"))
 
 
+def _validate_st_lengths(
+    S: int,
+    T: int,
+    is_rnnt_type_regular: bool,
+    boundary: Optional[Tensor] = None,
+):
+    """Check that symbol and frame lengths are compatible.
+
+    Modified and constrained transducers emit at most one symbol per frame,
+    so T >= S must hold (checked per utterance if `boundary` is given).
+    """
+    assert S >= 0, S
+    if boundary is None:
+        assert (
+            is_rnnt_type_regular or T >= S
+        ), f"Modified/constrained transducer requires T >= S, but got T={T} and S={S}"
+    else:
+        Ss = boundary[:, 2]
+        Ts = boundary[:, 3]
+        assert (
+            is_rnnt_type_regular or (Ts >= Ss).all()
+        ), f"Modified/constrained transducer requires T >= S, but got T={Ts} and S={Ss}"
+
+
 def get_rnnt_logprobs(
     lm: Tensor,
     am: Tensor,
@@ -145,11 +169,8 @@ def get_rnnt_logprobs(
     (B, T, C) = am.shape
     S = lm.shape[1] - 1
     assert symbols.shape == (B, S), (symbols.shape, B, S)
-    assert S >= 0, S
-    assert (
-        rnnt_type != "modified" or T >= S
-    ), f"Modified transducer requires T >= S, but got T={T} and S={S}"
     assert rnnt_type in ["regular", "modified", "constrained"], rnnt_type
+    _validate_st_lengths(S, T, rnnt_type == "regular", boundary)
 
     # subtracting am_max and lm_max is to ensure the probs are in a good range
     # to do exp() without causing underflow or overflow.
@@ -394,11 +415,8 @@ def get_rnnt_logprobs_joint(
     (B, T, S1, C) = logits.shape
     S = S1 - 1
     assert symbols.shape == (B, S), (symbols.shape, B, S)
-    assert S >= 0, S
-    assert (
-        rnnt_type != "modified" or T >= S
-    ), f"Modified transducer requires T >= S, but got T={T} and S={S}"
     assert rnnt_type in ["regular", "modified", "constrained"], rnnt_type
+    _validate_st_lengths(S, T, rnnt_type == "regular", boundary)
 
     normalizers = torch.logsumexp(logits, dim=3)
     normalizers = normalizers.permute((0, 2, 1))
@@ -669,28 +687,50 @@ def get_rnnt_prune_ranges(
     """
     (B, S, T1) = px_grad.shape
     T = py_grad.shape[-1]
+
+    is_regular = T1 != T
+
     assert T1 in [T, T + 1], (T1, T)
     S1 = S + 1
     assert py_grad.shape == (B, S1, T), (py_grad.shape, B, S1, T)
     assert boundary.shape == (B, 4), (boundary.shape, B)
-    assert S >= 0, S
+
+    _validate_st_lengths(S, T, is_regular, boundary)
+
+    # In the regular case s_range must be no less than the smallest
+    # integer satisfying `(s_range - 1) * T + 1 >= S + 1`
+    if is_regular:
+        Ss = boundary[:, 2]
+        Ts = boundary[:, 3]
+        s_range_min = (
+            Ss.sub(1).div(Ts, rounding_mode="trunc").add(2).max().item()
+        )
+        if s_range < s_range_min:
+            print(
+                f"Warning: get_rnnt_prune_ranges - got s_range={s_range} "
+                f"for boundaries S={Ss}, T={Ts}. Adjusting to {s_range_min}"
+            )
+            s_range = s_range_min
 
     # s_range > S means we won't prune out any symbols. To make indexing with
     # ranges run normally, s_range should be equal to or less than ``S + 1``.
     if s_range > S:
+        print(
+            f"Warning: get_rnnt_prune_ranges - got s_range={s_range} "
+            f"for boundaries S={S}. Adjusting to {S + 1}"
+        )
         s_range = S + 1
 
-    if T1 == T:
-        assert (
-            s_range >= 1
-        ), f"""Pruning range for modified RNN-T should be equal to or greater
-               than 1, or no valid paths could survive pruning. Given {s_range}"""
-
-    else:
+    if is_regular:
         assert (
             s_range >= 2
         ), f"""Pruning range for standard RNN-T should be equal to or greater
                than 2, or no valid paths could survive pruning. Given {s_range}"""
+    else:
+        assert (
+            s_range >= 1
+        ), f"""Pruning range for modified RNN-T should be equal to or greater
+               than 1, or no valid paths could survive pruning. Given {s_range}"""
 
Given {s_range}""" (B_stride, S_stride, T_stride) = py_grad.stride() blk_grad = torch.as_strided( @@ -1035,11 +1070,8 @@ def get_rnnt_logprobs_pruned( (B, T, s_range, C) = logits.shape assert ranges.shape == (B, T, s_range), (ranges.shape, B, T, s_range) (B, S) = symbols.shape - assert S >= 0, S - assert ( - rnnt_type != "modified" or T >= S - ), f"Modified transducer requires T >= S, but got T={T} and S={S}" assert rnnt_type in ["regular", "modified", "constrained"], rnnt_type + _validate_st_lengths(S, T, rnnt_type == "regular", boundary) normalizers = torch.logsumexp(logits, dim=3) @@ -1347,11 +1379,8 @@ def get_rnnt_logprobs_smoothed( (B, T, C) = am.shape S = lm.shape[1] - 1 assert symbols.shape == (B, S), (symbols.shape, B, S) - assert S >= 0, S - assert ( - rnnt_type != "modified" or T >= S - ), f"Modified transducer requires T >= S, but got T={T} and S={S}" assert rnnt_type in ["regular", "modified", "constrained"], rnnt_type + _validate_st_lengths(S, T, rnnt_type == "regular", boundary) # Caution: some parts of this code are a little less clear than they could # be due to optimizations. In particular it may not be totally obvious that @@ -1422,7 +1451,9 @@ def get_rnnt_logprobs_smoothed( unigram_lm.expand(B, S, C), dim=2, index=symbols.unsqueeze(-1) ) # [B][S][1] - px = px_am + px_lm # [B][S][T+1] if rnnt_type == "regular", otherwise [B][S][T] + px = ( + px_am + px_lm + ) # [B][S][T+1] if rnnt_type == "regular", otherwise [B][S][T] px[:, :, :T] -= normalizers[:, :S, :] # px: [B][S][T+1] or [B][S][T] px_amonly = ( diff --git a/k2/python/tests/rnnt_loss_test.py b/k2/python/tests/rnnt_loss_test.py index 9e4cd8646..917d2f936 100644 --- a/k2/python/tests/rnnt_loss_test.py +++ b/k2/python/tests/rnnt_loss_test.py @@ -281,7 +281,8 @@ def test_rnnt_loss_random(self): ) assert ( px.shape == (B, S, T) - if rnnt_type != "regular" else (B, S, T + 1) + if rnnt_type != "regular" + else (B, S, T + 1) ) assert py.shape == (B, S + 1, T) assert symbols.shape == (B, S) @@ -484,6 +485,7 @@ def test_rnnt_loss_smoothed(self): assert torch.allclose(m, expected.to(device)) def test_rnnt_loss_pruned(self): + print("\ntest_rnnt_loss_pruned.") B = 4 T = 300 S = 50 @@ -570,6 +572,7 @@ def test_rnnt_loss_pruned(self): # at this circumstance, the s_range would be greater than S, which will # raise errors (like, nan or inf loss) in our previous versions. def test_rnnt_loss_pruned_small_symbols_number(self): + print("\ntest_rnnt_loss_pruned_small_symbols_number.") B = 2 T = 20 S = 3 @@ -629,7 +632,7 @@ def test_rnnt_loss_pruned_small_symbols_number(self): ) S0 = 2 - if rnnt_type != "regular": + if rnnt_type == "modified": S0 = 1 for r in range(S0, S + 2): @@ -669,6 +672,7 @@ def test_rnnt_loss_pruned_small_symbols_number(self): # because we can not 100% sure that the new method is better than the old # one all the time, both of them are local optimal bounds. 
     def test_prune_ranges(self):
+        print("\ntest_prune_ranges.")
         B = 5
         T = 200
         S = 100
@@ -755,6 +759,103 @@
 
         print(f"Pruned with old ranges {r} : {loss}")
 
+    # Test small s_range values with large S and small T; in this case
+    # s_range would not be enough to cover the whole sequence length
+    # (in regular rnnt mode), which previously resulted in an inf
+    # pruned loss.
+    def test_rnnt_loss_pruned_small_s_range(self):
+        print("\ntest_rnnt_loss_pruned_small_s_range.")
+        B = 2
+        T = 2
+        S = 10
+        C = 10
+
+        frames = torch.randint(1, T, (B,))
+        seq_lengths = torch.randint(1, S, (B,))
+        T = torch.max(frames)
+        S = torch.max(seq_lengths)
+
+        am_ = torch.randn((B, T, C), dtype=torch.float64)
+        lm_ = torch.randn((B, S + 1, C), dtype=torch.float64)
+        symbols_ = torch.randint(0, C, (B, S))
+        terminal_symbol = C - 1
+
+        boundary_ = torch.zeros((B, 4), dtype=torch.int64)
+        boundary_[:, 2] = seq_lengths
+        boundary_[:, 3] = frames
+
+        print(f"B = {B}, T = {T}, S = {S}, C = {C}")
+
+        for rnnt_type in ["regular"]:
+            for device in self.devices:
+                # normal rnnt
+                am = am_.to(device)
+                lm = lm_.to(device)
+                symbols = symbols_.to(device)
+                boundary = boundary_.to(device)
+
+                logits = am.unsqueeze(2) + lm.unsqueeze(1)
+                logits = logits.float()
+
+                # nonlinear transform
+                logits = torch.sigmoid(logits)
+
+                loss = k2.rnnt_loss(
+                    logits=logits,
+                    symbols=symbols,
+                    termination_symbol=terminal_symbol,
+                    boundary=boundary,
+                    rnnt_type=rnnt_type,
+                    reduction="none",
+                )
+
+                print(f"Unpruned rnnt loss with {rnnt_type} rnnt : {loss}")
+
+                # pruning
+                simple_loss, (px_grad, py_grad) = k2.rnnt_loss_simple(
+                    lm=lm,
+                    am=am,
+                    symbols=symbols,
+                    termination_symbol=terminal_symbol,
+                    boundary=boundary,
+                    rnnt_type=rnnt_type,
+                    return_grad=True,
+                    reduction="none",
+                )
+
+                S0 = 2
+
+                for r in range(S0, S + 2):
+                    ranges = k2.get_rnnt_prune_ranges(
+                        px_grad=px_grad,
+                        py_grad=py_grad,
+                        boundary=boundary,
+                        s_range=r,
+                    )
+                    # (B, T, r, C)
+                    pruned_am, pruned_lm = k2.do_rnnt_pruning(
+                        am=am, lm=lm, ranges=ranges
+                    )
+
+                    logits = pruned_am + pruned_lm
+
+                    # nonlinear transform
+                    logits = torch.sigmoid(logits)
+
+                    pruned_loss = k2.rnnt_loss_pruned(
+                        logits=logits,
+                        symbols=symbols,
+                        ranges=ranges,
+                        termination_symbol=terminal_symbol,
+                        boundary=boundary,
+                        rnnt_type=rnnt_type,
+                        reduction="none",
+                    )
+                    assert (
+                        not pruned_loss.isinf().any()
+                    ), f"Pruned loss is inf for r={r}, S={S}, T={T}."
+                    print(f"Pruned loss with range {r} : {pruned_loss}")
+
     # Check that training with an empty reference does not cause a crash.
     def _test_rnnt_loss_empty_reference(self):
         B = 1
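Reviewer note: the following is a minimal standalone sketch (not part of this
PR; `min_s_range` is a hypothetical helper name) of the lower bound that
get_rnnt_prune_ranges now enforces for regular RNN-T. Each of the T frames
keeps a window of s_range symbol positions, and the window can advance by at
most s_range - 1 positions per frame, so T frames reach at most position
(s_range - 1) * T; covering all S symbols therefore requires
(s_range - 1) * T + 1 >= S + 1.

    # Sketch only, assuming S >= 1 and T >= 1.
    def min_s_range(S: int, T: int) -> int:
        # Smallest s_range with (s_range - 1) * T + 1 >= S + 1, i.e.
        # ceil(S / T) + 1; for S >= 1 this equals (S - 1) // T + 2, the
        # truncating-division form used in the diff
        # (Ss.sub(1).div(Ts, rounding_mode="trunc").add(2)).
        return (S - 1) // T + 2

    # S=10 symbols in T=2 frames needs s_range >= 6:
    # (6 - 1) * 2 + 1 = 11 >= 11, while (5 - 1) * 2 + 1 = 9 < 11.
    assert min_s_range(10, 2) == 6

    # For short references the bound degenerates to the usual minimum
    # of 2 for regular RNN-T.
    assert min_s_range(3, 20) == 2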