From bef6151d0aa47db9a36fb4a3dde7516d6b6e9800 Mon Sep 17 00:00:00 2001 From: Khari Date: Thu, 4 Jul 2024 20:33:22 -0400 Subject: [PATCH] =?UTF-8?q?Bug=20Fix=20-=20Assumes=20even=20distribution?= =?UTF-8?q?=20for=20timespans=20if=20timespans=20is=20None=20=E2=9C=85?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- torch_cfc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_cfc.py b/torch_cfc.py index 9c445c2c..7704e1b9 100644 --- a/torch_cfc.py +++ b/torch_cfc.py @@ -145,7 +145,6 @@ def forward(self, input, hx, ts): return new_hidden - class Cfc(nn.Module): def __init__( self, @@ -189,6 +188,8 @@ def forward(self, x, timespans=None, mask=None): time_since_update = torch.zeros( (batch_size, true_in_features), device=device ) + if timespans is None: + timespans = torch.ones(x.size(0), x.size(1), device=x.device) for t in range(seq_len): inputs = x[:, t] ts = timespans[:, t].squeeze() @@ -351,7 +352,6 @@ def _allocate_parameters(self): init_value=torch.zeros((self.sensory_size,)), ) - def _sigmoid(self, v_pre, mu, sigma): v_pre = torch.unsqueeze(v_pre, -1) # For broadcasting mues = v_pre - mu