From 1f6687317d1fb13d12f9f8f9e672d1963363d4c8 Mon Sep 17 00:00:00 2001 From: Jan Chorowski Date: Thu, 1 Apr 2021 02:42:37 +0200 Subject: [PATCH] Remove self loop gain. --- cpc/criterion/soft_align.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/cpc/criterion/soft_align.py b/cpc/criterion/soft_align.py index 12ed438..358f30a 100644 --- a/cpc/criterion/soft_align.py +++ b/cpc/criterion/soft_align.py @@ -202,10 +202,10 @@ def __init__(self, self.allowed_skips_beg = allowed_skips_beg self.allowed_skips_end = allowed_skips_end self.predict_self_loop = predict_self_loop - if predict_self_loop: - self.self_loop_gain = torch.nn.Parameter(torch.ones(1)) - else: - self.register_parameter('self_loop_gain', None) + # if predict_self_loop: + # self.self_loop_gain = torch.nn.Parameter(torch.ones(1)) + # else: + # self.register_parameter('self_loop_gain', None) self.limit_negs_in_batch = limit_negs_in_batch if masq_rules: @@ -325,7 +325,7 @@ def forward(self, cFeature, encodedData, label, captureOptions=None, return_loca # old and buggy # extra_preds.append(cFeature.unsqueeze(-1)) # new and shiny - extra_preds.append(encodedData[:, :windowSize, :].unsqueeze(-1) * self.self_loop_gain) + extra_preds.append(encodedData[:, :windowSize, :].unsqueeze(-1) ) # * self.self_loop_gain) if extra_preds: nPredicts += len(extra_preds)