diff --git a/rust/moshi-core/src/seanet.rs b/rust/moshi-core/src/seanet.rs index 6dae77e..af75914 100644 --- a/rust/moshi-core/src/seanet.rs +++ b/rust/moshi-core/src/seanet.rs @@ -117,7 +117,7 @@ impl Module for SeaNetResnetBlock { impl StreamingModule for SeaNetResnetBlock { fn reset_state(&mut self) { - // TODO(laurent): self.skip_op should probably be resetted here. + self.skip_op.reset_state(); for block in self.block.iter_mut() { block.reset_state() } @@ -132,12 +132,9 @@ impl StreamingModule for SeaNetResnetBlock { for block in self.block.iter_mut() { ys = block.step(&ys.apply(&self.activation)?)?; } - match self.shortcut.as_ref() { + match self.shortcut.as_mut() { None => self.skip_op.step(&ys, xs), - Some(shortcut) => { - // TODO(laurent): shouldn't this use shortcut.step(xs) instead? - self.skip_op.step(&ys, &xs.apply(shortcut)?) - } + Some(shortcut) => self.skip_op.step(&ys, &shortcut.step(xs)?), } } }