Skip to content

Commit

Permalink
clamp average_retention
Browse the repository at this point in the history
  • Loading branch information
L-M-Sherlock committed Jan 21, 2025
1 parent ab8f8e8 commit aa47f58
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions src/fsrs_optimizer/fsrs_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,10 +390,12 @@ def train(self, verbose: bool = True):
outputs, _ = self.model(sequences)
stabilities = outputs[seq_lens - 1, torch.arange(real_batch_size), 0]
retentions = power_forgetting_curve(delta_ts, stabilities)
average_retention = labels.mean().repeat(real_batch_size)
normalized_cross_entropy = average_retention * torch.log(average_retention) + (
1 - average_retention
) * torch.log(1 - average_retention)
average_retention = (
labels.mean().clamp(0.0001, 0.9999).repeat(real_batch_size)
)
normalized_cross_entropy = average_retention * torch.log(
average_retention
) + (1 - average_retention) * torch.log(1 - average_retention)
loss = (
self.loss_fn(retentions, labels)
* weights
Expand Down Expand Up @@ -453,10 +455,12 @@ def eval(self):
outputs, _ = self.model(sequences.transpose(0, 1))
stabilities = outputs[seq_lens - 1, torch.arange(real_batch_size), 0]
retentions = power_forgetting_curve(delta_ts, stabilities)
average_retention = labels.mean().repeat(real_batch_size)
normalized_cross_entropy = average_retention * torch.log(average_retention) + (
1 - average_retention
) * torch.log(1 - average_retention)
average_retention = (
labels.mean().clamp(0.0001, 0.9999).repeat(real_batch_size)
)
normalized_cross_entropy = average_retention * torch.log(
average_retention
) + (1 - average_retention) * torch.log(1 - average_retention)
loss = (
self.loss_fn(retentions, labels)
* weights
Expand Down

0 comments on commit aa47f58

Please sign in to comment.