diff --git a/src/model.rs b/src/model.rs index 59761946..e82e80ab 100644 --- a/src/model.rs +++ b/src/model.rs @@ -26,18 +26,6 @@ impl Get for Tensor { } } -trait Pow { - // https://github.com/burn-rs/burn/issues/590 , after that finished, just remove this trait and below impl, all will ok. - fn pow(&self, other: Tensor) -> Tensor; -} - -impl Pow for Tensor { - fn pow(&self, other: Self) -> Self { - // a ^ b => exp(ln(a^b)) => exp(b ln (a)) - (self.clone().log() * other).exp() - } -} - impl Model { #[allow(clippy::new_without_default)] pub fn new(config: ModelConfig) -> Self { @@ -77,7 +65,7 @@ impl Model { last_s.clone() * (self.w.get(8).exp() * (-last_d + 11) - * (last_s.pow(-self.w.get(9))) + * (last_s.powf_scalar(self.w.get(9).neg().into_scalar())) * (((-r + 1) * self.w.get(10)).exp() - 1) * hard_penalty * easy_bonus @@ -91,8 +79,8 @@ impl Model { r: Tensor, ) -> Tensor { let new_s = self.w.get(11) - * last_d.pow(-self.w.get(12)) - * ((last_s.clone() + 1).pow(self.w.get(13)) - 1) + * last_d.powf_scalar(self.w.get(12).neg().into_scalar()) + * ((last_s.clone() + 1).powf_scalar(self.w.get(13).into_scalar()) - 1) * ((-r + 1) * self.w.get(14)).exp(); new_s .clone()