
fix parameter_clipper
L-M-Sherlock committed Oct 28, 2024
1 parent f9c128c commit 29e13f8
Showing 2 changed files with 21 additions and 12 deletions.
src/parameter_clipper.rs (17 additions, 7 deletions)
@@ -2,13 +2,23 @@ use crate::{
     inference::{Parameters, S_MIN},
     pre_training::INIT_S_MAX,
 };
-use burn::tensor::{backend::Backend, Data, Tensor};
+use burn::{
+    module::Param,
+    tensor::{backend::Backend, Data, Tensor},
+};
 
-pub(crate) fn parameter_clipper<B: Backend>(parameters: Tensor<B, 1>) -> Tensor<B, 1> {
-    let val = clip_parameters(&parameters.to_data().convert().value);
-    Tensor::from_data(
-        Data::new(val, parameters.shape()).convert(),
-        &B::Device::default(),
+pub(crate) fn parameter_clipper<B: Backend>(
+    parameters: Param<Tensor<B, 1>>,
+) -> Param<Tensor<B, 1>> {
+    let (id, val) = parameters.consume();
+    let clipped = clip_parameters(&val.to_data().convert().value);
+    Param::initialized(
+        id,
+        Tensor::from_data(
+            Data::new(clipped, val.shape()).convert(),
+            &B::Device::default(),
+        )
+        .require_grad(),
     )
 }
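
A note on the rewrite above: parameter_clipper now accepts and returns Param<Tensor<B, 1>> rather than a bare tensor, so it can consume() the parameter, clip its values, and rebuild the wrapper with Param::initialized under the original ParamId. The old call sites re-wrapped the clipped tensor with Param::from_tensor, which mints a fresh id each time (see the TODO removed from src/training.rs below), so anything keyed by the previous id, such as per-parameter optimizer state, would presumably be left behind. Here is a minimal, dependency-free sketch of the id-preserving pattern; Param, ParamId, and the clamp bounds are illustrative stand-ins, not burn's actual API:

// Dependency-free sketch of the id-preserving pattern in this commit.
// `Param` and `ParamId` are stand-ins for burn's types; the clamp
// bounds are placeholders, not the real fsrs-rs ranges.
use std::sync::atomic::{AtomicU64, Ordering};

static NEXT_ID: AtomicU64 = AtomicU64::new(0);

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct ParamId(u64);

struct Param {
    id: ParamId,
    values: Vec<f32>,
}

impl Param {
    // Mints a fresh id, as the old `Param::from_tensor` workaround did on
    // every step, leaving behind any state keyed by the previous id.
    fn from_tensor(values: Vec<f32>) -> Self {
        let id = ParamId(NEXT_ID.fetch_add(1, Ordering::Relaxed));
        Self { id, values }
    }
    fn consume(self) -> (ParamId, Vec<f32>) {
        (self.id, self.values)
    }
    fn initialized(id: ParamId, values: Vec<f32>) -> Self {
        Self { id, values }
    }
}

fn parameter_clipper(parameters: Param) -> Param {
    let (id, mut values) = parameters.consume();
    for v in &mut values {
        *v = v.clamp(0.001, 100.0); // placeholder bounds
    }
    // Rebuilt around the same id, so anything keyed by it stays attached.
    Param::initialized(id, values)
}

fn main() {
    let w = Param::from_tensor(vec![0.0005, 42.0, 250.0]);
    let before = w.id;
    let w = parameter_clipper(w);
    assert_eq!(w.id, before); // the id survives the clip
    assert_eq!(w.values, vec![0.001, 42.0, 100.0]);
}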

@@ -58,7 +68,7 @@ mod tests {
             &device,
         );
 
-        let param: Tensor<1> = parameter_clipper(tensor);
+        let param = parameter_clipper(Param::from_tensor(tensor));
         let values = &param.to_data().value;
 
         assert_eq!(
src/training.rs (4 additions, 5 deletions)
@@ -18,7 +18,7 @@ use burn::tensor::backend::Backend;
 use burn::tensor::{Int, Tensor};
 use burn::train::renderer::{MetricState, MetricsRenderer, TrainingProgress};
 use burn::train::TrainingInterrupter;
-use burn::{config::Config, module::Param, tensor::backend::AutodiffBackend};
+use burn::{config::Config, tensor::backend::AutodiffBackend};
 use core::marker::PhantomData;
 use log::info;

@@ -369,8 +369,7 @@ fn train<B: AutodiffBackend>(
         }
         let grads = GradientsParams::from_grads(gradients, &model);
         model = optim.step(lr, model, grads);
-        // TODO: bug in https://github.com/tracel-ai/burn/issues/2428
-        model.w = Param::from_tensor(parameter_clipper(model.w.val()));
+        model.w = parameter_clipper(model.w);
         // info!("epoch: {:?} iteration: {:?} lr: {:?}", epoch, iteration, lr);
         renderer.render_train(TrainingProgress {
             progress,
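
The call-site change above is the heart of the commit: each optimizer step is followed by a projection of the weights back into their feasible range, and model.w now flows through parameter_clipper as a Param, keeping its identity across iterations. A self-contained sketch of this step-then-project loop follows; the learning rate, bounds, values, and plain-slice representation are illustrative only:

// Step-then-project: one gradient step, then clamp back into the box.
fn projected_step(w: &mut [f32], grad: &[f32], lr: f32, lo: f32, hi: f32) {
    for (wi, gi) in w.iter_mut().zip(grad) {
        *wi = (*wi - lr * gi).clamp(lo, hi);
    }
}

fn main() {
    let mut w = [0.05_f32, 9.8, 120.0];
    let grad = [0.5_f32, -0.1, 0.2];
    projected_step(&mut w, &grad, 0.04, 0.001, 100.0);
    assert!((w[0] - 0.03).abs() < 1e-6);
    assert!((w[1] - 9.804).abs() < 1e-3);
    // w[2] stepped to 119.992 and was then projected onto the upper bound.
    assert_eq!(w[2], 100.0);
}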
@@ -521,7 +520,7 @@ mod tests {
         let lr = 0.04;
         let grads = GradientsParams::from_grads(gradients, &model);
         model = optim.step(lr, model, grads);
-        model.w = Param::from_tensor(parameter_clipper(model.w.val()));
+        model.w = parameter_clipper(model.w);
         assert_eq!(
             model.w.val().to_data(),
             Data::from([
@@ -592,7 +591,7 @@
             .assert_approx_eq(&w_grad.clone().into_data(), 5);
         let grads = GradientsParams::from_grads(gradients, &model);
         model = optim.step(lr, model, grads);
-        model.w = Param::from_tensor(parameter_clipper(model.w.val()));
+        model.w = parameter_clipper(model.w);
         assert_eq!(
             model.w.val().to_data(),
             Data::from([
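
The tests above pin the exact weight vector expected after one optimizer step plus clipping. For orientation, a hedged sketch of the kind of per-index clamping clip_parameters performs: S_MIN and INIT_S_MAX match the imports at the top of src/parameter_clipper.rs, but their numeric values and the catch-all range below are assumptions, not the crate's actual table.

// Sketch of per-index clamping in the style of clip_parameters.
const S_MIN: f32 = 0.001; // assumed value, for illustration
const INIT_S_MAX: f32 = 100.0; // assumed value, for illustration

fn clip_parameters(parameters: &[f32]) -> Vec<f32> {
    parameters
        .iter()
        .enumerate()
        .map(|(i, &w)| match i {
            0..=3 => w.clamp(S_MIN, INIT_S_MAX), // initial stabilities
            _ => w.clamp(0.0, 10.0),             // assumed catch-all range
        })
        .collect()
}

fn main() {
    let clipped = clip_parameters(&[0.0, 150.0, 1.0, 2.0, -0.5]);
    assert_eq!(clipped, vec![S_MIN, INIT_S_MAX, 1.0, 2.0, 0.0]);
}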
