locate the bug related to parameter_clipper
L-M-Sherlock committed Oct 27, 2024
1 parent 55cb2e4 commit f9c128c
Showing 1 changed file with 89 additions and 0 deletions.
src/training.rs: 89 additions & 0 deletions
@@ -369,6 +369,7 @@ fn train<B: AutodiffBackend>(
            }
            let grads = GradientsParams::from_grads(gradients, &model);
            model = optim.step(lr, model, grads);
            // TODO: bug in https://github.com/tracel-ai/burn/issues/2428
            model.w = Param::from_tensor(parameter_clipper(model.w.val()));
            // info!("epoch: {:?} iteration: {:?} lr: {:?}", epoch, iteration, lr);
            renderer.render_train(TrainingProgress {
@@ -520,6 +521,7 @@ mod tests {
        let lr = 0.04;
        let grads = GradientsParams::from_grads(gradients, &model);
        model = optim.step(lr, model, grads);
        // Re-apply the clipper right after the optimizer step, mirroring the training loop.
        model.w = Param::from_tensor(parameter_clipper(model.w.val()));
        assert_eq!(
            model.w.val().to_data(),
            Data::from([
@@ -528,6 +530,93 @@
                0.47655, 0.62210006
            ])
        );

        // A synthetic batch: per-review interval and rating histories, elapsed
        // days since the last review, and recall labels for four cards.
        let item = FSRSBatch {
            t_historys: Tensor::from_floats(
                Data::from([
                    [0.0, 0.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0, 1.0],
                    [0.0, 1.0, 1.0, 3.0],
                    [1.0, 3.0, 3.0, 5.0],
                    [3.0, 6.0, 6.0, 12.0],
                ]),
                &device,
            ),
            r_historys: Tensor::from_floats(
                Data::from([
                    [1.0, 2.0, 3.0, 4.0],
                    [3.0, 4.0, 2.0, 4.0],
                    [1.0, 4.0, 4.0, 3.0],
                    [4.0, 3.0, 3.0, 3.0],
                    [3.0, 1.0, 3.0, 3.0],
                    [2.0, 3.0, 3.0, 4.0],
                ]),
                &device,
            ),
            delta_ts: Tensor::from_floats(Data::from([4.0, 11.0, 12.0, 23.0]), &device),
            labels: Tensor::from_ints(Data::from([1, 1, 1, 0]), &device),
        };

        let loss = model.forward_classification(
            item.t_historys,
            item.r_historys,
            item.delta_ts,
            item.labels,
            Reduction::Sum,
        );
        assert_eq!(loss.clone().into_data().convert::<f32>().value[0], 4.176347);
        let gradients = loss.backward();
        let w_grad = model.w.grad(&gradients).unwrap();
        // Expected gradient of the loss with respect to each of the 19 weights.
        Data::from([
            -0.0401341,
            -0.0061790533,
            -0.00288913,
            0.01216853,
            -0.05624995,
            1.147413,
            0.068084724,
            -0.6906936,
            0.48760873,
            -2.5428302,
            0.49044546,
            -0.011574259,
            0.037729632,
            -0.09633919,
            -0.0009513022,
            -0.12789416,
            0.19088513,
            0.2574597,
            0.049311582,
        ])
        .assert_approx_eq(&w_grad.clone().into_data(), 5);
        let grads = GradientsParams::from_grads(gradients, &model);
        model = optim.step(lr, model, grads);
        model.w = Param::from_tensor(parameter_clipper(model.w.val()));
        assert_eq!(
            model.w.val().to_data(),
            Data::from([
                0.48150504,
                1.2636971,
                3.2530522,
                15.611003,
                7.2749534,
                0.45482785,
                1.3808222,
                0.083782874,
                1.4658877,
                0.19898315,
                0.9393105,
                2.0193,
                0.030164223,
                0.37562984,
                2.3498251,
                0.3112984,
                2.909878,
                0.43652722,
                0.5825156
            ])
        );
    }

    #[test]
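For context, `parameter_clipper` (defined elsewhere in this crate) clips the trained weights back into their allowed ranges, as its name and the surrounding asserts suggest; the diff re-applies it immediately after every `optim.step`, in both the training loop and the test, and the TODO points to tracel-ai/burn issue 2428 for the underlying bug. Below is a minimal sketch of that clamp-after-step idea in plain Rust; `clip_parameters` and the bounds shown are hypothetical stand-ins, not the crate's actual API or the real FSRS parameter ranges.

// Hypothetical sketch: clamp each weight into its allowed interval after an
// optimizer step. The real parameter_clipper operates on a burn Tensor and
// uses the FSRS-defined bounds.
fn clip_parameters(weights: &mut [f32], bounds: &[(f32, f32)]) {
    for (w, &(lo, hi)) in weights.iter_mut().zip(bounds) {
        *w = w.clamp(lo, hi);
    }
}

fn main() {
    // Placeholder bounds; the optimizer step has pushed the second weight
    // below its lower limit, and clipping projects it back.
    let mut w = [0.4, -0.2, 3.9];
    let bounds = [(0.1, 100.0), (0.1, 100.0), (0.1, 100.0)];
    clip_parameters(&mut w, &bounds);
    assert_eq!(w, [0.4, 0.1, 3.9]);
}

Re-clipping after each step keeps the invariant that the weights are always in range local to the update site, which is why the test asserts the clipped values directly after `optim.step` rather than only at the end of training.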
