minor updates
JinZr committed Jul 31, 2023
1 parent 70d603d commit 3631361
Showing 4 changed files with 26 additions and 7 deletions.
@@ -907,9 +907,9 @@ def deprecated_greedy_search_batch_for_cross_attn(
         logits = model.joiner(
             current_encoder_out,
             decoder_out.unsqueeze(1),
-            attn_encoder_out if t > 0 else torch.zeros_like(current_encoder_out),
-            None,
-            apply_attn=True,
+            attn_encoder_out if t < 0 else torch.zeros_like(current_encoder_out),
+            encoder_out_lens,
+            apply_attn=False,
             project_input=False,
         )
         # logits'shape (batch_size, 1, 1, vocab_size)
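A quick sanity check of the new decode-time condition, as a hedged sketch (tensor shapes below are illustrative, not the recipe's actual ones): the decoding loop's frame index t starts at 0, so t < 0 never holds and the joiner is always handed a zero tensor in place of attn_encoder_out, consistent with apply_attn=False above.

import torch

# Illustrative stand-ins for the decoding loop's tensors (shapes assumed).
attn_encoder_out = torch.randn(4, 1, 512)
current_encoder_out = torch.randn(4, 1, 512)

for t in range(3):  # t is never negative in the decoding loop
    attn_in = attn_encoder_out if t < 0 else torch.zeros_like(current_encoder_out)
    assert torch.equal(attn_in, torch.zeros_like(current_encoder_out))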
14 changes: 12 additions & 2 deletions egs/librispeech/ASR/zipformer_label_level_algn/joiner.py
@@ -14,6 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
+
 import torch
 import torch.nn as nn
 from alignment_attention_module import AlignmentAttentionModule
@@ -34,6 +36,7 @@ def __init__(
         self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim, initial_scale=0.25)
         self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim, initial_scale=0.25)
         self.output_linear = nn.Linear(joiner_dim, vocab_size)
+        self.enable_attn = False
 
     def forward(
         self,
@@ -64,15 +67,22 @@ def forward(
             decoder_out.shape,
         )
 
-        if apply_attn and lengths is not None:
+        if apply_attn:
+            if not self.enable_attn:
+                self.enable_attn = True
+                logging.info("enabling ATTN!")
             attn_encoder_out = self.label_level_am_attention(
                 encoder_out, decoder_out, lengths
             )
 
         if project_input:
             logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
         else:
-            logit = encoder_out + decoder_out + attn_encoder_out
+            if apply_attn:
+                logit = encoder_out + decoder_out + attn_encoder_out
+            else:
+                # logging.info("disabling cross attn mdl")
+                logit = encoder_out + decoder_out
 
         logit = self.output_linear(torch.tanh(logit))
 
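For reference, a minimal self-contained sketch of the gating this hunk introduces. It is an illustration under assumptions: the real Joiner's signature, the AlignmentAttentionModule, and the projection layers are simplified away; only the apply_attn branching and the one-time log message mirror the diff.

import logging

import torch
import torch.nn as nn


class ToyJoiner(nn.Module):
    """Stripped-down joiner mirroring the apply_attn gating above."""

    def __init__(self, joiner_dim: int, vocab_size: int):
        super().__init__()
        self.attn = nn.Identity()  # stand-in for AlignmentAttentionModule
        self.output_linear = nn.Linear(joiner_dim, vocab_size)
        self.enable_attn = False

    def forward(self, encoder_out, decoder_out, apply_attn: bool = False):
        if apply_attn:
            if not self.enable_attn:
                self.enable_attn = True
                logging.info("enabling ATTN!")  # logged once, as in the commit
            attn_encoder_out = self.attn(encoder_out)
            logit = encoder_out + decoder_out + attn_encoder_out
        else:
            logit = encoder_out + decoder_out
        return self.output_linear(torch.tanh(logit))


joiner = ToyJoiner(joiner_dim=512, vocab_size=500)
x = torch.randn(2, 3, 4, 512)  # (batch, T, s_range, joiner_dim); shapes assumed
y = torch.randn(2, 3, 4, 512)
print(joiner(x, y, apply_attn=False).shape)  # attention path skipped
print(joiner(x, y, apply_attn=True).shape)   # attention path used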
11 changes: 9 additions & 2 deletions egs/librispeech/ASR/zipformer_label_level_algn/model.py
@@ -24,12 +24,13 @@
 from encoder_interface import EncoderInterface
 from scaling import ScaledLinear
 
-from icefall.utils import add_sos, make_pad_mask
+from icefall.utils import add_sos, make_pad_mask, AttributeDict
 
 
 class AsrModel(nn.Module):
     def __init__(
         self,
+        params: AttributeDict,
         encoder_embed: nn.Module,
         encoder: EncoderInterface,
         decoder: Optional[nn.Module] = None,
@@ -79,6 +80,8 @@ def __init__(
 
         assert isinstance(encoder, EncoderInterface), type(encoder)
 
+        self.params = params
+
         self.encoder_embed = encoder_embed
         self.encoder = encoder
 
@@ -180,6 +183,7 @@ def forward_transducer(
         prune_range: int = 5,
         am_scale: float = 0.0,
         lm_scale: float = 0.0,
+        batch_idx_train: int = 0,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Compute Transducer loss.
         Args:
@@ -264,12 +268,13 @@ def forward_transducer(
 
         # project_input=False since we applied the decoder's input projections
        # prior to do_rnnt_pruning (this is an optimization for speed).
+        # print(batch_idx_train)
         logits = self.joiner(
             am_pruned,
             lm_pruned,
             None,
             encoder_out_lens,
-            apply_attn=True,
+            apply_attn=batch_idx_train > self.params.warm_step, # True, # batch_idx_train > self.params.warm_step,
             project_input=False,
         )
 
@@ -293,6 +298,7 @@ def forward(
         prune_range: int = 5,
         am_scale: float = 0.0,
         lm_scale: float = 0.0,
+        batch_idx_train: int = 0,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -345,6 +351,7 @@
                 prune_range=prune_range,
                 am_scale=am_scale,
                 lm_scale=lm_scale,
+                batch_idx_train=batch_idx_train,
             )
         else:
             simple_loss = torch.empty(0)
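The substantive change in this file is the joiner call's apply_attn argument: label-level attention is no longer always on but is gated on the global training step. A tiny sketch of that gate (warm_step is read from params in the recipe; 2000 below is only an assumed example value):

def attn_enabled(batch_idx_train: int, warm_step: int) -> bool:
    # mirrors: apply_attn=batch_idx_train > self.params.warm_step
    return batch_idx_train > warm_step


assert attn_enabled(batch_idx_train=100, warm_step=2000) is False
assert attn_enabled(batch_idx_train=2001, warm_step=2000) is True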
2 changes: 2 additions & 0 deletions egs/librispeech/ASR/zipformer_label_level_algn/train.py
@@ -622,6 +622,7 @@ def get_model(params: AttributeDict) -> nn.Module:
         joiner = None
 
     model = AsrModel(
+        params=params,
         encoder_embed=encoder_embed,
         encoder=encoder,
         decoder=decoder,
@@ -800,6 +801,7 @@ def compute_loss(
             prune_range=params.prune_range,
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
+            batch_idx_train=batch_idx_train,
         )
 
     loss = 0.0
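Taking the four files together, a toy end-to-end sketch of the intended behaviour. Everything here is illustrative (class names, the warm_step value, the loop); only the params / batch_idx_train plumbing and the one-time switch mirror the commit.

import logging

logging.basicConfig(level=logging.INFO)


class ToyParams(dict):
    """Minimal stand-in for icefall's AttributeDict."""

    def __getattr__(self, key):
        return self[key]


class ToyAsrModel:
    def __init__(self, params):
        self.params = params          # mirrors: self.params = params
        self.attn_announced = False

    def forward(self, batch_idx_train: int) -> bool:
        apply_attn = batch_idx_train > self.params.warm_step
        if apply_attn and not self.attn_announced:
            self.attn_announced = True
            logging.info("enabling ATTN!")  # as logged by the joiner
        return apply_attn


params = ToyParams(warm_step=3)       # assumed value; the recipe reads it from params
model = ToyAsrModel(params=params)    # mirrors: AsrModel(params=params, ...)

for batch_idx_train in range(6):      # mirrors compute_loss passing batch_idx_train
    model.forward(batch_idx_train=batch_idx_train)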
