diff --git a/src/training.rs b/src/training.rs
index deaa47a..33ce701 100644
--- a/src/training.rs
+++ b/src/training.rs
@@ -369,6 +369,7 @@ fn train(
             }
             let grads = GradientsParams::from_grads(gradients, &model);
             model = optim.step(lr, model, grads);
+            // TODO: bug in https://github.com/tracel-ai/burn/issues/2428
             model.w = Param::from_tensor(parameter_clipper(model.w.val()));
             // info!("epoch: {:?} iteration: {:?} lr: {:?}", epoch, iteration, lr);
             renderer.render_train(TrainingProgress {
@@ -520,6 +521,7 @@ mod tests {
         let lr = 0.04;
         let grads = GradientsParams::from_grads(gradients, &model);
         model = optim.step(lr, model, grads);
+        model.w = Param::from_tensor(parameter_clipper(model.w.val()));
         assert_eq!(
             model.w.val().to_data(),
             Data::from([
@@ -528,6 +530,93 @@ mod tests {
                 0.47655,
                 0.62210006
             ])
         );
+
+        let item = FSRSBatch {
+            t_historys: Tensor::from_floats(
+                Data::from([
+                    [0.0, 0.0, 0.0, 0.0],
+                    [0.0, 0.0, 0.0, 0.0],
+                    [0.0, 0.0, 0.0, 1.0],
+                    [0.0, 1.0, 1.0, 3.0],
+                    [1.0, 3.0, 3.0, 5.0],
+                    [3.0, 6.0, 6.0, 12.0],
+                ]),
+                &device,
+            ),
+            r_historys: Tensor::from_floats(
+                Data::from([
+                    [1.0, 2.0, 3.0, 4.0],
+                    [3.0, 4.0, 2.0, 4.0],
+                    [1.0, 4.0, 4.0, 3.0],
+                    [4.0, 3.0, 3.0, 3.0],
+                    [3.0, 1.0, 3.0, 3.0],
+                    [2.0, 3.0, 3.0, 4.0],
+                ]),
+                &device,
+            ),
+            delta_ts: Tensor::from_floats(Data::from([4.0, 11.0, 12.0, 23.0]), &device),
+            labels: Tensor::from_ints(Data::from([1, 1, 1, 0]), &device),
+        };
+
+        let loss = model.forward_classification(
+            item.t_historys,
+            item.r_historys,
+            item.delta_ts,
+            item.labels,
+            Reduction::Sum,
+        );
+        assert_eq!(loss.clone().into_data().convert::<f32>().value[0], 4.176347);
+        let gradients = loss.backward();
+        let w_grad = model.w.grad(&gradients).unwrap();
+        Data::from([
+            -0.0401341,
+            -0.0061790533,
+            -0.00288913,
+            0.01216853,
+            -0.05624995,
+            1.147413,
+            0.068084724,
+            -0.6906936,
+            0.48760873,
+            -2.5428302,
+            0.49044546,
+            -0.011574259,
+            0.037729632,
+            -0.09633919,
+            -0.0009513022,
+            -0.12789416,
+            0.19088513,
+            0.2574597,
+            0.049311582,
+        ])
+        .assert_approx_eq(&w_grad.clone().into_data(), 5);
+        let grads = GradientsParams::from_grads(gradients, &model);
+        model = optim.step(lr, model, grads);
+        model.w = Param::from_tensor(parameter_clipper(model.w.val()));
+        assert_eq!(
+            model.w.val().to_data(),
+            Data::from([
+                0.48150504,
+                1.2636971,
+                3.2530522,
+                15.611003,
+                7.2749534,
+                0.45482785,
+                1.3808222,
+                0.083782874,
+                1.4658877,
+                0.19898315,
+                0.9393105,
+                2.0193,
+                0.030164223,
+                0.37562984,
+                2.3498251,
+                0.3112984,
+                2.909878,
+                0.43652722,
+                0.5825156
+            ])
+        );
     }
     #[test]
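
Note on the pattern above: the diff re-applies `parameter_clipper` to
`model.w` immediately after every `model = optim.step(lr, model, grads);`,
in the test path as well as in `train` (where a TODO now points at
https://github.com/tracel-ai/burn/issues/2428), so the asserted weights
reflect the same clamp-after-step behavior training uses. A minimal
standalone sketch of that idea follows; the name `clip_parameters` and all
bounds are illustrative assumptions, not the real fsrs-rs
`parameter_clipper` or its actual weight ranges:

    // Sketch only: hypothetical per-weight bounds, not fsrs-rs API.
    fn clip_parameters(weights: &mut [f32], bounds: &[(f32, f32)]) {
        for (w, &(lo, hi)) in weights.iter_mut().zip(bounds) {
            // Clamp each weight into its own valid range; the optimizer
            // step itself is unconstrained and may overshoot the bounds.
            *w = w.clamp(lo, hi);
        }
    }

    fn main() {
        // Two example weights pushed out of range by a hypothetical step.
        let mut w = [-0.3, 42.0];
        let bounds = [(0.01, 100.0), (0.01, 10.0)];
        clip_parameters(&mut w, &bounds);
        assert_eq!(w, [0.01, 10.0]);
    }

Keeping the clamp adjacent to the step, as the diff does, means no code
path can observe weights outside their valid ranges between iterations.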