diff --git a/torch_cfc.py b/torch_cfc.py index 9c445c2c..7704e1b9 100644 --- a/torch_cfc.py +++ b/torch_cfc.py @@ -145,7 +145,6 @@ def forward(self, input, hx, ts): return new_hidden - class Cfc(nn.Module): def __init__( self, @@ -189,6 +188,8 @@ def forward(self, x, timespans=None, mask=None): time_since_update = torch.zeros( (batch_size, true_in_features), device=device ) + if timespans is None: + timespans = torch.ones(x.size(0), x.size(1), device=x.device) for t in range(seq_len): inputs = x[:, t] ts = timespans[:, t].squeeze() @@ -351,7 +352,6 @@ def _allocate_parameters(self): init_value=torch.zeros((self.sensory_size,)), ) - def _sigmoid(self, v_pre, mu, sigma): v_pre = torch.unsqueeze(v_pre, -1) # For broadcasting mues = v_pre - mu