Skip to content

Commit

Permalink
remove pow polyfill
Browse files Browse the repository at this point in the history
  • Loading branch information
asukaminato0721 committed Jun 10, 2024
1 parent a2d9ea0 commit dd88b5d
Showing 1 changed file with 3 additions and 15 deletions.
18 changes: 3 additions & 15 deletions src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,6 @@ impl<B: Backend, const N: usize> Get<B, N> for Tensor<B, N> {
}
}

trait Pow<B: Backend, const N: usize> {
// https://github.com/burn-rs/burn/issues/590 , after that finished, just remove this trait and below impl, all will ok.
fn pow(&self, other: Tensor<B, N>) -> Tensor<B, N>;
}

impl<B: Backend, const N: usize> Pow<B, N> for Tensor<B, N> {
fn pow(&self, other: Self) -> Self {
// a ^ b => exp(ln(a^b)) => exp(b ln (a))
(self.clone().log() * other).exp()
}
}

impl<B: Backend> Model<B> {
#[allow(clippy::new_without_default)]
pub fn new(config: ModelConfig) -> Self {
Expand Down Expand Up @@ -77,7 +65,7 @@ impl<B: Backend> Model<B> {
last_s.clone()
* (self.w.get(8).exp()
* (-last_d + 11)
* (last_s.pow(-self.w.get(9)))
* (last_s.powf_scalar(self.w.get(9).neg().into_scalar()))
* (((-r + 1) * self.w.get(10)).exp() - 1)
* hard_penalty
* easy_bonus
Expand All @@ -91,8 +79,8 @@ impl<B: Backend> Model<B> {
r: Tensor<B, 1>,
) -> Tensor<B, 1> {
let new_s = self.w.get(11)
* last_d.pow(-self.w.get(12))
* ((last_s.clone() + 1).pow(self.w.get(13)) - 1)
* last_d.powf_scalar(self.w.get(12).neg().into_scalar())
* ((last_s.clone() + 1).powf_scalar(self.w.get(13).into_scalar()) - 1)
* ((-r + 1) * self.w.get(14)).exp();
new_s
.clone()
Expand Down

0 comments on commit dd88b5d

Please sign in to comment.