Skip to content

Commit

Permalink
fix the mismatch in batch_idx_train (#1757)
Browse files Browse the repository at this point in the history
  • Loading branch information
zzasdf authored Oct 12, 2024
1 parent fbba712 commit 2653df5
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions icefall/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,8 +424,12 @@ def average_checkpoints_with_averaged_model(
state_dict_start = torch.load(filename_start, map_location=device)
state_dict_end = torch.load(filename_end, map_location=device)

average_period = state_dict_start["average_period"]

batch_idx_train_start = state_dict_start["batch_idx_train"]
batch_idx_train_start = (batch_idx_train_start // average_period) * average_period
batch_idx_train_end = state_dict_end["batch_idx_train"]
batch_idx_train_end = (batch_idx_train_end // average_period) * average_period
interval = batch_idx_train_end - batch_idx_train_start
assert interval > 0, interval
weight_end = batch_idx_train_end / interval
Expand Down

0 comments on commit 2653df5

Please sign in to comment.