diff --git a/src/pre_training.rs b/src/pre_training.rs index c027a3da..a79617fc 100644 --- a/src/pre_training.rs +++ b/src/pre_training.rs @@ -303,6 +303,8 @@ fn smooth_and_fill( #[cfg(test)] mod tests { + use burn::tensor::Data; + use super::*; use crate::dataset::split_data; use crate::training::calculate_average_recall; @@ -324,10 +326,8 @@ mod tests { let init_s0 = 1.0; let actual = loss(&delta_t, &recall, &count, init_s0, init_s0); assert_eq!(actual, 13.6243305); - assert_eq!( - format!("{:.4}", loss(&delta_t, &recall, &count, 2.0, init_s0)), - "14.5771" - ); + Data::from([loss(&delta_t, &recall, &count, 2.0, init_s0)]) + .assert_approx_eq(&Data::from([14.5771]), 5); } #[test]