Fix s_range in rnnt_loss (#1245)
* Fix s_range in rnnt_loss

* Fix style

* fix flake8
pkufool authored Sep 26, 2023
1 parent 1f11a51 commit e24993c
Showing 2 changed files with 159 additions and 27 deletions.
81 changes: 56 additions & 25 deletions k2/python/k2/rnnt_loss.py
@@ -43,6 +43,25 @@ def fix_for_boundary(px: Tensor, boundary: Optional[Tensor] = None) -> Tensor:
return px.scatter_(dim=2, index=boundary, value=float("-inf"))


def _validate_st_lengths(
    S: int,
    T: int,
    is_rnnt_type_regular: bool,
    boundary: Optional[Tensor] = None,
):
    assert S >= 0, S
    if boundary is None:
        assert (
            is_rnnt_type_regular or T >= S
        ), f"Modified transducer requires T >= S, but got T={T} and S={S}"
    else:
        Ss = boundary[:, 2]
        Ts = boundary[:, 3]
        assert (
            is_rnnt_type_regular or (Ts >= Ss).all()
        ), f"Modified transducer requires T >= S, but got T={Ts} and S={Ss}"


def get_rnnt_logprobs(
lm: Tensor,
am: Tensor,
@@ -145,6 +164,8 @@ def get_rnnt_logprobs(
     (B, T, C) = am.shape
     S = lm.shape[1] - 1
     assert symbols.shape == (B, S), (symbols.shape, B, S)
-    assert S >= 0, S
-    assert (
-        rnnt_type != "modified" or T >= S
-    ), f"Modified transducer requires T >= S, but got T={T} and S={S}"
     assert rnnt_type in ["regular", "modified", "constrained"], rnnt_type
+    _validate_st_lengths(S, T, rnnt_type == "regular", boundary)

# subtracting am_max and lm_max is to ensure the probs are in a good range
# to do exp() without causing underflow or overflow.
@@ -394,11 +410,8 @@ def get_rnnt_logprobs_joint(
     (B, T, S1, C) = logits.shape
     S = S1 - 1
     assert symbols.shape == (B, S), (symbols.shape, B, S)
-    assert S >= 0, S
-    assert (
-        rnnt_type != "modified" or T >= S
-    ), f"Modified transducer requires T >= S, but got T={T} and S={S}"
     assert rnnt_type in ["regular", "modified", "constrained"], rnnt_type
+    _validate_st_lengths(S, T, rnnt_type == "regular", boundary)

normalizers = torch.logsumexp(logits, dim=3)
normalizers = normalizers.permute((0, 2, 1))
@@ -669,28 +682,50 @@ def get_rnnt_prune_ranges(
"""
(B, S, T1) = px_grad.shape
T = py_grad.shape[-1]

is_regular = T1 != T

assert T1 in [T, T + 1], (T1, T)
S1 = S + 1
assert py_grad.shape == (B, S1, T), (py_grad.shape, B, S1, T)
assert boundary.shape == (B, 4), (boundary.shape, B)
assert S >= 0, S

_validate_st_lengths(S, T, is_regular, boundary)

# in regular case s_range should be no less than
# a minimum integer satisfying `(s_range - 1) * t + 1 >= s + 1`
if is_regular:
Ss = boundary[:, 2]
Ts = boundary[:, 3]
s_range_min = (
Ss.sub(1).div(Ts, rounding_mode="trunc").add(2).max().item()
)
if s_range < s_range_min:
print(
f"Warning: get_rnnt_prune_ranges - got s_range={s_range} "
f"for boundaries S={Ss}, T={Ts}. Adjusting to {s_range_min}"
)
s_range = s_range_min
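To make that bound concrete (an editorial sketch with hypothetical numbers, not part of the diff): solving `(s_range - 1) * T + 1 >= S + 1` for an integer `s_range` gives `s_range >= ceil(S / T) + 1`, which is exactly what the truncating division above computes per utterance:

import torch

Ss = torch.tensor([10, 3])  # hypothetical per-utterance symbol counts
Ts = torch.tensor([4, 20])  # hypothetical per-utterance frame counts

# (S - 1) // T + 2 equals ceil(S / T) + 1 for S >= 1.
s_range_min = Ss.sub(1).div(Ts, rounding_mode="trunc").add(2).max().item()
assert s_range_min == 4  # driven by S=10, T=4: ceil(10 / 4) + 1 == 4

# Checking the inequality for the binding utterance: (4 - 1) * 4 + 1 == 13
# >= 10 + 1, while s_range == 3 would give (3 - 1) * 4 + 1 == 9 < 11, so
# the last symbols could never be reached by any surviving path.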

     # s_range > S means we won't prune out any symbols. To make indexing with
     # ranges run normally, s_range should be equal to or less than ``S + 1``.
     if s_range > S:
+        print(
+            f"Warning: get_rnnt_prune_ranges - got s_range={s_range} "
+            f"for boundaries S={S}. Adjusting to {S + 1}"
+        )
         s_range = S + 1
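As a concrete editorial example: with S = 3 the pruning lattice has only S + 1 = 4 symbol positions (0 through 3), so a window of width s_range = 5 would have to index a fifth, nonexistent position; clamping to S + 1 = 4 keeps every window inside the grid while still covering all symbols.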

-    if T1 == T:
-        assert (
-            s_range >= 1
-        ), f"""Pruning range for modified RNN-T should be equal to or greater
-        than 1, or no valid paths could survive pruning. Given {s_range}"""
-
-    else:
+    if is_regular:
         assert (
             s_range >= 2
         ), f"""Pruning range for standard RNN-T should be equal to or greater
         than 2, or no valid paths could survive pruning. Given {s_range}"""
+    else:
+        assert (
+            s_range >= 1
+        ), f"""Pruning range for modified RNN-T should be equal to or greater
+        than 1, or no valid paths could survive pruning. Given {s_range}"""
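Why the two different floors (editorial note): in the regular topology a window of width s_range can advance by at most s_range - 1 symbols per frame, so s_range = 1 satisfies the bound above, (1 - 1) * T + 1 >= S + 1, only when S = 0; any non-empty reference therefore needs s_range >= 2. The modified and constrained topologies emit at most one symbol per frame and already require T >= S, so a width-1 window can still follow a valid path.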

(B_stride, S_stride, T_stride) = py_grad.stride()
blk_grad = torch.as_strided(
@@ -1035,11 +1070,8 @@ def get_rnnt_logprobs_pruned(
     (B, T, s_range, C) = logits.shape
     assert ranges.shape == (B, T, s_range), (ranges.shape, B, T, s_range)
     (B, S) = symbols.shape
-    assert S >= 0, S
-    assert (
-        rnnt_type != "modified" or T >= S
-    ), f"Modified transducer requires T >= S, but got T={T} and S={S}"
     assert rnnt_type in ["regular", "modified", "constrained"], rnnt_type
+    _validate_st_lengths(S, T, rnnt_type == "regular", boundary)

normalizers = torch.logsumexp(logits, dim=3)

@@ -1347,11 +1379,8 @@ def get_rnnt_logprobs_smoothed(
     (B, T, C) = am.shape
     S = lm.shape[1] - 1
     assert symbols.shape == (B, S), (symbols.shape, B, S)
-    assert S >= 0, S
-    assert (
-        rnnt_type != "modified" or T >= S
-    ), f"Modified transducer requires T >= S, but got T={T} and S={S}"
     assert rnnt_type in ["regular", "modified", "constrained"], rnnt_type
+    _validate_st_lengths(S, T, rnnt_type == "regular", boundary)

# Caution: some parts of this code are a little less clear than they could
# be due to optimizations. In particular it may not be totally obvious that
@@ -1422,7 +1451,9 @@ def get_rnnt_logprobs_smoothed(
         unigram_lm.expand(B, S, C), dim=2, index=symbols.unsqueeze(-1)
     )  # [B][S][1]

-    px = px_am + px_lm  # [B][S][T+1] if rnnt_type == "regular", otherwise [B][S][T]
+    px = (
+        px_am + px_lm
+    )  # [B][S][T+1] if rnnt_type == "regular", otherwise [B][S][T]
     px[:, :, :T] -= normalizers[:, :S, :]  # px: [B][S][T+1] or [B][S][T]

px_amonly = (
105 changes: 103 additions & 2 deletions k2/python/tests/rnnt_loss_test.py
@@ -281,7 +281,8 @@ def test_rnnt_loss_random(self):
)
             assert (
                 px.shape == (B, S, T)
-                if rnnt_type != "regular" else (B, S, T + 1)
+                if rnnt_type != "regular"
+                else (B, S, T + 1)
             )
assert py.shape == (B, S + 1, T)
assert symbols.shape == (B, S)
@@ -484,6 +485,7 @@ def test_rnnt_loss_smoothed(self):
assert torch.allclose(m, expected.to(device))

def test_rnnt_loss_pruned(self):
print("\ntest_rnnt_loss_pruned.")
B = 4
T = 300
S = 50
@@ -570,6 +572,7 @@ def test_rnnt_loss_pruned_small_symbols_number(self):
    # in this circumstance the s_range would be greater than S, which would
    # raise errors (like nan or inf loss) in previous versions.
def test_rnnt_loss_pruned_small_symbols_number(self):
print("\ntest_rnnt_loss_pruned_small_symbols_number.")
B = 2
T = 20
S = 3
@@ -629,7 +632,7 @@ def test_rnnt_loss_pruned_small_symbols_number(self):
)

             S0 = 2
-            if rnnt_type != "regular":
+            if rnnt_type == "modified":
                 S0 = 1

for r in range(S0, S + 2):
@@ -669,6 +672,7 @@ def test_rnnt_loss_pruned_small_symbols_number(self):
    # because we cannot be 100% sure that the new method is better than the
    # old one all the time; both of them are locally optimal bounds.
def test_prune_ranges(self):
print("\ntest_prune_range.")
B = 5
T = 200
S = 100
@@ -755,6 +759,103 @@ def test_prune_ranges(self):

print(f"Pruned with old ranges {r} : {loss}")

    # Test small s_range values with large S and small T. In this
    # circumstance the s_range may be too small to cover the whole symbol
    # sequence (in regular rnnt mode), which would result in an inf loss.
    def test_rnnt_loss_pruned_small_s_range(self):
        print("\ntest_rnnt_loss_pruned_small_s_range.")
        B = 2
        T = 2
        S = 10
        C = 10

        frames = torch.randint(1, T, (B,))
        seq_lengths = torch.randint(1, S, (B,))
        T = torch.max(frames)
        S = torch.max(seq_lengths)

        am_ = torch.randn((B, T, C), dtype=torch.float64)
        lm_ = torch.randn((B, S + 1, C), dtype=torch.float64)
        symbols_ = torch.randint(0, C, (B, S))
        terminal_symbol = C - 1

        boundary_ = torch.zeros((B, 4), dtype=torch.int64)
        boundary_[:, 2] = seq_lengths
        boundary_[:, 3] = frames

        print(f"B = {B}, T = {T}, S = {S}, C = {C}")

        for rnnt_type in ["regular"]:
            for device in self.devices:
                # normal rnnt
                am = am_.to(device)
                lm = lm_.to(device)
                symbols = symbols_.to(device)
                boundary = boundary_.to(device)

                logits = am.unsqueeze(2) + lm.unsqueeze(1)
                logits = logits.float()

                # nonlinear transform
                logits = torch.sigmoid(logits)

                loss = k2.rnnt_loss(
                    logits=logits,
                    symbols=symbols,
                    termination_symbol=terminal_symbol,
                    boundary=boundary,
                    rnnt_type=rnnt_type,
                    reduction="none",
                )

                print(f"Unpruned rnnt loss with {rnnt_type} rnnt : {loss}")

                # pruning
                simple_loss, (px_grad, py_grad) = k2.rnnt_loss_simple(
                    lm=lm,
                    am=am,
                    symbols=symbols,
                    termination_symbol=terminal_symbol,
                    boundary=boundary,
                    rnnt_type=rnnt_type,
                    return_grad=True,
                    reduction="none",
                )

                S0 = 2

                for r in range(S0, S + 2):
                    ranges = k2.get_rnnt_prune_ranges(
                        px_grad=px_grad,
                        py_grad=py_grad,
                        boundary=boundary,
                        s_range=r,
                    )
                    # (B, T, r, C)
                    pruned_am, pruned_lm = k2.do_rnnt_pruning(
                        am=am, lm=lm, ranges=ranges
                    )

                    logits = pruned_am + pruned_lm

                    # nonlinear transform
                    logits = torch.sigmoid(logits)

                    pruned_loss = k2.rnnt_loss_pruned(
                        logits=logits,
                        symbols=symbols,
                        ranges=ranges,
                        termination_symbol=terminal_symbol,
                        boundary=boundary,
                        rnnt_type=rnnt_type,
                        reduction="none",
                    )
                    assert (
                        not pruned_loss.isinf().any()
                    ), f"Pruned loss is inf for r={r}, S={S}, T={T}."
                    print(f"Pruned loss with range {r} : {pruned_loss}")

# Check that training with an empty reference does not cause a crash.
def _test_rnnt_loss_empty_reference(self):
B = 1
