
Commit

Fix self loop implementation.
janchorowski committed Mar 31, 2021
1 parent d45002b commit 97993a7
Showing 1 changed file with 8 additions and 1 deletion.
cpc/criterion/soft_align.py: 9 changes (8 additions, 1 deletion)
@@ -202,6 +202,10 @@ def __init__(self,
         self.allowed_skips_beg = allowed_skips_beg
         self.allowed_skips_end = allowed_skips_end
         self.predict_self_loop = predict_self_loop
+        if predict_self_loop:
+            self.self_loop_gain = torch.nn.Parameter(torch.ones(1))
+        else:
+            self.register_parameter('self_loop_gain', None)
         self.limit_negs_in_batch = limit_negs_in_batch

         if masq_rules:
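
For context on the first hunk: the gain is created only when self-loop prediction is enabled, and register_parameter('self_loop_gain', None) keeps the attribute defined either way, so the forward pass can reference self.self_loop_gain without an attribute check. A minimal standalone sketch of this optional-parameter pattern (the Demo module below is an assumption for illustration, not code from the repository):

    import torch

    class Demo(torch.nn.Module):
        def __init__(self, predict_self_loop: bool):
            super().__init__()
            if predict_self_loop:
                # learnable scalar, initialized to 1.0 and trained with the rest of the model
                self.self_loop_gain = torch.nn.Parameter(torch.ones(1))
            else:
                # attribute exists but holds no parameter, so it does not appear in named_parameters()
                self.register_parameter('self_loop_gain', None)

    print(dict(Demo(True).named_parameters()))   # {'self_loop_gain': Parameter containing: tensor([1.], ...)}
    print(dict(Demo(False).named_parameters()))  # {}
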
@@ -318,7 +322,10 @@ def forward(self, cFeature, encodedData, label, captureOptions=None, return_loca
             extra_preds.append(self.blank_proto.expand(batchSize, windowSize, self.blank_proto.size(2), 1))

         if self.predict_self_loop:
-            extra_preds.append(cFeature.unsqueeze(-1))
+            # old and buggy
+            # extra_preds.append(cFeature.unsqueeze(-1))
+            # new and shiny
+            extra_preds.append(encodedData[:, :windowSize, :].unsqueeze(-1) * self.self_loop_gain)

         if extra_preds:
             nPredicts += len(extra_preds)
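
For context on the second hunk: instead of appending the autoregressive context feature cFeature as its own "self-loop" prediction, the fixed code appends the encoder output at the same time step, scaled by the learnable self_loop_gain, so the self-loop candidate lives in the same space as the other prediction candidates. A rough sketch of the shapes involved (batchSize, windowSize, and the feature dimension are made-up values for illustration only):

    import torch

    batchSize, windowSize, dim = 4, 20, 256                      # assumed shapes, illustration only
    cFeature = torch.randn(batchSize, windowSize, dim)           # context (AR) features
    encodedData = torch.randn(batchSize, windowSize + 12, dim)   # raw encoder outputs
    self_loop_gain = torch.nn.Parameter(torch.ones(1))           # the parameter added in __init__

    extra_preds = []
    # old and buggy: the context feature was appended as its own prediction
    # extra_preds.append(cFeature.unsqueeze(-1))
    # fixed: the self-loop candidate is the encoded frame at the same position, scaled by the gain
    extra_preds.append(encodedData[:, :windowSize, :].unsqueeze(-1) * self_loop_gain)
    print(extra_preds[0].shape)  # torch.Size([4, 20, 256, 1])
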
