diff --git a/beats/backbone.py b/beats/backbone.py index 43a7207dc..5d5a1bfc4 100644 --- a/beats/backbone.py +++ b/beats/backbone.py @@ -113,7 +113,7 @@ def extract_features(self, x, padding_mask=None, tgt_layer=None): x_conv = self.pos_conv(x.transpose(1, 2)) x_conv = x_conv.transpose(1, 2) - x += x_conv + x = x.clone() + x_conv if not self.layer_norm_first: x = self.layer_norm(x)