From c84623679c1293a8c2ddce066e5add3b3b7446b8 Mon Sep 17 00:00:00 2001 From: bokutotu Date: Mon, 11 Nov 2024 02:11:40 +0900 Subject: [PATCH] update optimizer and tests --- zenu-layer/src/layers/linear.rs | 7 +- zenu-optimizer/Cargo.toml | 3 + zenu-optimizer/src/adam.rs | 279 ++++++++++++++----------------- zenu-optimizer/src/adamw.rs | 189 ++++++++++++--------- zenu-optimizer/src/lib.rs | 3 +- zenu-optimizer/src/sgd.rs | 54 +----- zenu-optimizer/tests/net_test.rs | 141 ++++++++++++++++ zenu/examples/mnist.rs | 32 ++-- 8 files changed, 404 insertions(+), 304 deletions(-) create mode 100644 zenu-optimizer/tests/net_test.rs diff --git a/zenu-layer/src/layers/linear.rs b/zenu-layer/src/layers/linear.rs index 7c419580..46840a04 100644 --- a/zenu-layer/src/layers/linear.rs +++ b/zenu-layer/src/layers/linear.rs @@ -4,7 +4,7 @@ use crate::{Module, Parameters}; use rand_distr::{Distribution, StandardNormal}; use zenu_autograd::{ creator::{rand::normal, zeros::zeros}, - functions::matmul::matmul, + functions::{matmul::matmul, transpose::transpose}, Variable, }; use zenu_matrix::{device::Device, num::Num}; @@ -20,7 +20,8 @@ impl Module for Linear { type Input = Variable; type Output = Variable; fn call(&self, input: Variable) -> Variable { - let output = matmul(input, self.weight.clone()); + let weight_t = transpose(self.weight.clone()); + let output = matmul(input, weight_t); if let Some(bias) = &self.bias { output.set_name("linear.intermediate_output"); output + bias.clone() @@ -52,7 +53,7 @@ impl Linear { where StandardNormal: Distribution, { - let weight = normal(T::zero(), T::one(), None, [in_features, out_features]); + let weight = normal(T::zero(), T::one(), None, [out_features, in_features]); weight .get_data_mut() .to_ref_mut() diff --git a/zenu-optimizer/Cargo.toml b/zenu-optimizer/Cargo.toml index ca567ac9..e948a14d 100644 --- a/zenu-optimizer/Cargo.toml +++ b/zenu-optimizer/Cargo.toml @@ -15,6 +15,9 @@ zenu-layer = { path = "../zenu-layer", version = "0.1.0" } [dev-dependencies] zenu-test = { path="../zenu-test/", version="*"} +zenu = { path="../zenu/", version="*"} +rand = { version = "0.8.5", features = ["small_rng"] } +rand_distr = "0.4.2" [lints] workspace = true diff --git a/zenu-optimizer/src/adam.rs b/zenu-optimizer/src/adam.rs index 729ca2fc..c9d7d067 100644 --- a/zenu-optimizer/src/adam.rs +++ b/zenu-optimizer/src/adam.rs @@ -1,4 +1,4 @@ -use std::{cell::RefCell, rc::Rc}; +use std::{cell::RefCell, collections::HashMap, rc::Rc}; use zenu_autograd::{creator::zeros::zeros_like, Variable}; use zenu_layer::Parameters; @@ -12,43 +12,10 @@ pub struct Adam { beta2: T, epsilon: T, step: Rc>, - m: Vec>, - v: Vec>, + m: HashMap>, + v: HashMap>, } -// impl Optimizer for Adam { -// fn update(&self, parameters: &[Variable]) { -// let step = *self.step.borrow(); -// let step = step + T::one(); -// *self.step.borrow_mut() = step; -// -// let beta1_t = self.beta1.powf(step); -// let beta2_t = self.beta2.powf(step); -// -// for ((parameter, m), v) in parameters.iter().zip(&self.m).zip(&self.v) { -// let grad = parameter.get_grad().unwrap(); -// let grad = grad.get_data(); -// -// let mut v = v.get_data_mut(); -// let mut v = v.to_ref_mut(); -// let mut m = m.get_data_mut(); -// let mut m = m.to_ref_mut(); -// -// m *= self.beta1; -// m += grad.to_ref() * (T::one() - self.beta1); -// -// v *= self.beta2; -// v += grad.to_ref() * grad.to_ref() * (T::one() - self.beta2); -// -// let m_hat = m / (T::one() - beta1_t); -// let v_hat = v / (T::one() - beta2_t); -// -// let mut parameter_data = parameter.get_data_mut(); -// let mut parameter_data = parameter_data.to_ref_mut(); -// parameter_data -= m_hat / (v_hat.sqrt() + self.epsilon) * self.learning_rate; -// } -// } -// } impl> Optimizer for Adam { fn update(&self, parameters: &P) { let step = *self.step.borrow(); @@ -58,21 +25,21 @@ impl> Optimizer for Adam { let beta1_t = self.beta1.powf(step); let beta2_t = self.beta2.powf(step); - let weights = parameters.weights(); - let biases = parameters.biases(); - let mut parameters = Vec::new(); - for (_, weight) in weights.iter() { - if let Some(grad) = weight.get_grad() { - parameters.push(grad); - } - } - for (_, bias) in biases.iter() { - if let Some(grad) = bias.get_grad() { - parameters.push(grad); - } - } - - for ((parameter, m), v) in parameters.iter().zip(&self.m).zip(&self.v) { + let parameters = parameters + .parameters() + .iter() + .filter_map(|(key, value)| { + if value.get_grad().is_some() { + Some((key.clone(), value.clone())) + } else { + None + } + }) + .collect::>(); + + for (k, parameter) in ¶meters { + let v = self.v.get(k).unwrap(); + let m = self.m.get(k).unwrap(); let grad = parameter.get_data(); let mut v = v.get_data_mut(); let mut v = v.to_ref_mut(); @@ -101,15 +68,17 @@ impl Adam { beta1: T, beta2: T, epsilon: T, - parameters: &[Variable], + model: &impl Parameters, ) -> Self { - let m = parameters + let m = model + .parameters() .iter() - .map(|parameter| zeros_like(parameter)) + .map(|(key, value)| (key.clone(), zeros_like(value))) .collect(); - let v = parameters + let v = model + .parameters() .iter() - .map(|parameter| zeros_like(parameter)) + .map(|(key, value)| (key.clone(), zeros_like(value))) .collect(); Self { learning_rate, @@ -123,102 +92,102 @@ impl Adam { } } -#[cfg(test)] -mod adam_tests { - use zenu_autograd::{ - creator::from_vec::from_vec, functions::matmul::matmul, loss::mse::mean_squared_error, - Variable, - }; - use zenu_matrix::{device::Device, dim::DimDyn, matrix::Matrix}; - use zenu_test::{assert_val_eq, run_test}; - - use crate::Optimizer; - - use super::Adam; - - fn simple_function( - x: Variable, - weight1: Variable, - weight2: Variable, - ) -> Variable { - let x = matmul(x, weight1); - matmul(x, weight2) - } - - #[expect(clippy::needless_pass_by_value, clippy::type_complexity)] - fn adam_apply( - adam: &Adam, - forward_func: fn(Variable, Variable, Variable) -> Variable, - input: Variable, - target: Variable, - weight1: Variable, - weight2: Variable, - ) { - let output = forward_func(input.clone(), weight1.clone(), weight2.clone()); - let loss = mean_squared_error(target, output); - loss.backward(); - adam.update(&[weight1.clone(), weight2.clone()]); - loss.clear_grad(); - } - - #[expect(clippy::unreadable_literal)] - fn small_2_times() { - // Initial weights: - // Weight1: 10.000000 - // Weight2: 10.000000 - // - // Iteration 1: - // Input: 1.000000 - // Target: 6.000000 - // Weight1: 9.900000 - // Weight2: 9.900000 - // Loss: 8836.000000 - // - // Iteration 2: - // Input: 1.100000 - // Target: 6.600000 - // Weight1: 9.799901 - // Weight2: 9.799901 - // Loss: 10243.665039 - let ans_weight_1 = from_vec::(vec![2.], [1, 1]); - let ans_weight_2 = from_vec::(vec![3.], [1, 1]); - - let weight_1 = from_vec::(vec![10.], [1, 1]); - let weight_2 = from_vec::(vec![10.], [1, 1]); - - let adam = Adam::new(0.1, 0.9, 0.999, 1e-8, &[weight_1.clone(), weight_2.clone()]); - - // iter 1 - let input = from_vec::(vec![1.], [1, 1]); - let target = simple_function(input.clone(), ans_weight_1.clone(), ans_weight_2.clone()); - adam_apply( - &adam, - simple_function, - input, - target, - weight_1.clone(), - weight_2.clone(), - ); - let iter_1_weight_1 = Matrix::<_, DimDyn, D>::from_vec(vec![9.9], [1, 1]); - let iter_1_weight_2 = Matrix::<_, DimDyn, D>::from_vec(vec![9.9], [1, 1]); - assert_val_eq!(weight_1.clone(), iter_1_weight_1, 1e-6); - assert_val_eq!(weight_2.clone(), iter_1_weight_2, 1e-6); - - // iter 2 - let input = from_vec::(vec![1.1], [1, 1]); - let target = simple_function(input.clone(), ans_weight_1.clone(), ans_weight_2.clone()); - adam_apply( - &adam, - simple_function, - input, - target, - weight_1.clone(), - weight_2.clone(), - ); - let iter_2_weight_1 = Matrix::<_, DimDyn, D>::from_vec(vec![9.799901], [1, 1]); - let iter_2_weight_2 = Matrix::<_, DimDyn, D>::from_vec(vec![9.799901], [1, 1]); - assert_val_eq!(weight_1.clone(), iter_2_weight_1, 2e-4); - assert_val_eq!(weight_2.clone(), iter_2_weight_2, 2e-4); - } - run_test!(small_2_times, small_2_times_cpu, small_2_times_gpu); -} +// #[cfg(test)] +// mod adam_tests { +// use zenu_autograd::{ +// creator::from_vec::from_vec, functions::matmul::matmul, loss::mse::mean_squared_error, +// Variable, +// }; +// use zenu_matrix::{device::Device, dim::DimDyn, matrix::Matrix}; +// use zenu_test::{assert_val_eq, run_test}; +// +// use crate::Optimizer; +// +// use super::Adam; +// +// fn simple_function( +// x: Variable, +// weight1: Variable, +// weight2: Variable, +// ) -> Variable { +// let x = matmul(x, weight1); +// matmul(x, weight2) +// } +// +// #[expect(clippy::needless_pass_by_value, clippy::type_complexity)] +// fn adam_apply( +// adam: &Adam, +// forward_func: fn(Variable, Variable, Variable) -> Variable, +// input: Variable, +// target: Variable, +// weight1: Variable, +// weight2: Variable, +// ) { +// let output = forward_func(input.clone(), weight1.clone(), weight2.clone()); +// let loss = mean_squared_error(target, output); +// loss.backward(); +// adam.update(&[weight1.clone(), weight2.clone()]); +// loss.clear_grad(); +// } +// +// #[expect(clippy::unreadable_literal)] +// fn small_2_times() { +// // Initial weights: +// // Weight1: 10.000000 +// // Weight2: 10.000000 +// // +// // Iteration 1: +// // Input: 1.000000 +// // Target: 6.000000 +// // Weight1: 9.900000 +// // Weight2: 9.900000 +// // Loss: 8836.000000 +// // +// // Iteration 2: +// // Input: 1.100000 +// // Target: 6.600000 +// // Weight1: 9.799901 +// // Weight2: 9.799901 +// // Loss: 10243.665039 +// let ans_weight_1 = from_vec::(vec![2.], [1, 1]); +// let ans_weight_2 = from_vec::(vec![3.], [1, 1]); +// +// let weight_1 = from_vec::(vec![10.], [1, 1]); +// let weight_2 = from_vec::(vec![10.], [1, 1]); +// +// let adam = Adam::new(0.1, 0.9, 0.999, 1e-8, &[weight_1.clone(), weight_2.clone()]); +// +// // iter 1 +// let input = from_vec::(vec![1.], [1, 1]); +// let target = simple_function(input.clone(), ans_weight_1.clone(), ans_weight_2.clone()); +// adam_apply( +// &adam, +// simple_function, +// input, +// target, +// weight_1.clone(), +// weight_2.clone(), +// ); +// let iter_1_weight_1 = Matrix::<_, DimDyn, D>::from_vec(vec![9.9], [1, 1]); +// let iter_1_weight_2 = Matrix::<_, DimDyn, D>::from_vec(vec![9.9], [1, 1]); +// assert_val_eq!(weight_1.clone(), iter_1_weight_1, 1e-6); +// assert_val_eq!(weight_2.clone(), iter_1_weight_2, 1e-6); +// +// // iter 2 +// let input = from_vec::(vec![1.1], [1, 1]); +// let target = simple_function(input.clone(), ans_weight_1.clone(), ans_weight_2.clone()); +// adam_apply( +// &adam, +// simple_function, +// input, +// target, +// weight_1.clone(), +// weight_2.clone(), +// ); +// let iter_2_weight_1 = Matrix::<_, DimDyn, D>::from_vec(vec![9.799901], [1, 1]); +// let iter_2_weight_2 = Matrix::<_, DimDyn, D>::from_vec(vec![9.799901], [1, 1]); +// assert_val_eq!(weight_1.clone(), iter_2_weight_1, 2e-4); +// assert_val_eq!(weight_2.clone(), iter_2_weight_2, 2e-4); +// } +// run_test!(small_2_times, small_2_times_cpu, small_2_times_gpu); +// } diff --git a/zenu-optimizer/src/adamw.rs b/zenu-optimizer/src/adamw.rs index fe62a819..4aaa78af 100644 --- a/zenu-optimizer/src/adamw.rs +++ b/zenu-optimizer/src/adamw.rs @@ -1,83 +1,106 @@ -// use std::{cell::RefCell, rc::Rc}; -// -// use zenu_autograd::Variable; -// use zenu_matrix::{ -// constructor::zeros::Zeros, -// matrix::{MatrixBase, ToOwnedMatrix, ToViewMatrix, ToViewMutMatrix}, -// matrix_impl::OwnedMatrixDyn, -// num::Num, -// operation::basic_operations::{MatrixAddAssign, MatrixSubAssign}, -// }; -// -// pub struct AdamW { -// pub alpha: T, -// weight_decay: T, -// beta1: T, -// beta2: T, -// epsilon: T, -// m: Rc>>>, -// v: Rc>>>, -// theta: Rc>>>, -// } -// -// impl AdamW { -// pub fn new(alpha: T, weight_decay: T, beta1: T, beta2: T, epsilon: T) -> Self { -// Self { -// alpha, -// weight_decay, -// beta1, -// beta2, -// epsilon, -// m: Rc::new(RefCell::new(Vec::new())), -// v: Rc::new(RefCell::new(Vec::new())), -// theta: Rc::new(RefCell::new(Vec::new())), -// } -// } -// -// fn init_m(&self, parameters: &[Variable]) { -// let m = &mut *self.m.borrow_mut(); -// m.clear(); -// for p in parameters { -// m.push(OwnedMatrixDyn::zeros(p.get_data().shape())); -// } -// } -// -// fn init_v(&self, parameters: &[Variable]) { -// let v = &mut *self.v.borrow_mut(); -// v.clear(); -// for p in parameters { -// v.push(OwnedMatrixDyn::zeros(p.get_data().shape())); -// } -// } -// -// fn init_theta(&self, parameters: &[Variable]) { -// let theta = &mut *self.theta.borrow_mut(); -// theta.clear(); -// for p in parameters { -// theta.push(OwnedMatrixDyn::zeros(p.get_data().shape())); -// } -// } -// -// fn update_m(&self, parameters: &[Variable]) { -// let m = &mut *self.m.borrow_mut(); -// for (m, p) in m.iter_mut().zip(parameters) { -// let mut p_g = p.get_grad().unwrap().get_data(); -// let m_c = m.to_owned_matrix(); -// p_g.sub_assign(m_c); -// m.to_view_mut().add_assign(p_g * (T::one() - self.beta1)); -// } -// } -// -// fn update_v(&self, parameters: &[Variable]) { -// let v = &mut *self.v.borrow_mut(); -// for (v, p) in v.iter_mut().zip(parameters) { -// let p_g = p.get_grad().unwrap().get_data(); -// let v_c = v.to_owned_matrix(); -// let mut p_g = p_g.to_view() * p_g.to_view(); -// p_g.sub_assign(v_c); -// v.to_view_mut().add_assign(p_g * (T::one() - self.beta2)); -// } -// } -// -// fn -// } +use std::{cell::RefCell, collections::HashMap, rc::Rc}; + +use zenu_autograd::{creator::zeros::zeros_like, Variable}; +use zenu_layer::Parameters; +use zenu_matrix::{device::Device, num::Num}; + +use crate::Optimizer; + +pub struct AdamW { + learning_rate: T, + beta1: T, + beta2: T, + epsilon: T, + weight_decay: T, + step: Rc>, + m: HashMap>, + v: HashMap>, +} + +impl> Optimizer for AdamW { + fn update(&self, parameters: &P) { + let step = *self.step.borrow() + T::one(); + *self.step.borrow_mut() = step; + + let beta1_t = self.beta1.powf(step); + let beta2_t = self.beta2.powf(step); + + let weight_keys: Vec<_> = parameters.weights().keys().cloned().collect(); + + let params = parameters + .parameters() + .iter() + .filter_map(|(key, param)| { + if param.get_grad().is_some() { + Some((key.clone(), param.clone())) + } else { + None + } + }) + .collect::>(); + + for (k, parameter) in params { + let v_t = self.v.get(&k).unwrap(); + let m_t = self.m.get(&k).unwrap(); + let grad = parameter.get_grad().unwrap(); + let mut grad = grad.get_data_mut(); + let param_data = parameter.get_data(); + + if weight_keys.contains(&k) { + grad.to_ref_mut() + .add_assign(&(param_data.to_ref() * self.weight_decay).to_ref()); + } + + let mut m = m_t.get_data_mut(); + let mut v = v_t.get_data_mut(); + + m.to_ref_mut().mul_scalar_assign(self.beta1); + m.to_ref_mut() + .add_assign(&(grad.to_ref() * (T::one() - self.beta1)).to_ref()); + + v.to_ref_mut().mul_scalar_assign(self.beta2); + v.to_ref_mut() + .add_assign(&(grad.to_ref().sqrt() * (T::one() - self.beta2)).to_ref()); + + let m_hat = m.to_ref() / (T::one() - beta1_t); + let v_hat = v.to_ref() / (T::one() - beta2_t); + + let mut param_data_mut = parameter.get_data_mut(); + param_data_mut + .to_ref_mut() + .sub_assign(&(m_hat / (v_hat.sqrt() + self.epsilon) * self.learning_rate).to_ref()); + } + } +} + +impl AdamW { + pub fn new( + learning_rate: T, + beta1: T, + beta2: T, + epsilon: T, + weight_decay: T, + model: &impl Parameters, + ) -> Self { + let m = model + .parameters() + .iter() + .map(|(key, value)| (key.clone(), zeros_like(value))) + .collect(); + let v = model + .parameters() + .iter() + .map(|(key, value)| (key.clone(), zeros_like(value))) + .collect(); + Self { + learning_rate, + beta1, + beta2, + epsilon, + weight_decay, + step: Rc::new(RefCell::new(T::zero())), + m, + v, + } + } +} diff --git a/zenu-optimizer/src/lib.rs b/zenu-optimizer/src/lib.rs index b309bc7d..cdc6daf4 100644 --- a/zenu-optimizer/src/lib.rs +++ b/zenu-optimizer/src/lib.rs @@ -1,8 +1,7 @@ pub mod adam; -// pub mod adamw; +pub mod adamw; pub mod sgd; -use zenu_autograd::Variable; use zenu_layer::Parameters; use zenu_matrix::{device::Device, num::Num}; diff --git a/zenu-optimizer/src/sgd.rs b/zenu-optimizer/src/sgd.rs index f1890f2a..f9a2950b 100644 --- a/zenu-optimizer/src/sgd.rs +++ b/zenu-optimizer/src/sgd.rs @@ -1,4 +1,3 @@ -use zenu_autograd::Variable; use zenu_layer::Parameters; use zenu_matrix::{device::Device, num::Num}; @@ -20,54 +19,13 @@ impl SGD { impl> Optimizer for SGD { fn update(&self, parameters: &P) { - let weights = parameters.weights(); - let biases = parameters.biases(); - let mut parameters = Vec::new(); - for (_, weight) in weights.iter() { - if let Some(grad) = weight.get_grad() { - parameters.push(grad); + for data in parameters.parameters().values() { + if let Some(grad) = data.get_grad() { + let update_data = grad.get_data().to_ref() * self.learning_rate; + let mut data = data.get_data_mut(); + let mut data = data.to_ref_mut(); + data -= update_data; } } - for (_, bias) in biases.iter() { - if let Some(grad) = bias.get_grad() { - parameters.push(grad); - } - } - for parameter in parameters { - let grad = parameter.clone().get_grad().unwrap(); - let grad = grad.get_data(); - let update_data = grad.to_ref() * self.learning_rate; - - let mut data = parameter.get_data_mut(); - let mut data = data.to_ref_mut(); - data -= update_data; - } - } -} - -#[cfg(test)] -mod sgd { - use zenu_autograd::creator::from_vec::from_vec; - use zenu_matrix::{ - device::Device, - dim::DimDyn, - matrix::{Matrix, Owned}, - }; - use zenu_test::{assert_mat_eq_epsilon, run_test}; - - use crate::Optimizer; - - use super::SGD; - - // #[test] - fn simple_test() { - let variable = from_vec::(vec![1., 2., 3., 4., 5., 6.], [3, 2]); - variable.set_grad(from_vec(vec![1., 2., 3., 4., 5., 6.], [3, 2])); - let sgd = SGD::new(1.); - sgd.update(&[variable.clone()]); - let data = variable.get_data(); - let ans = Matrix::, DimDyn, D>::from_vec(vec![0., 0., 0., 0., 0., 0.], [3, 2]); - assert_mat_eq_epsilon!(data, ans, 1e-6); } - run_test!(simple_test, simple_test_cpu, simple_test_nvidia); } diff --git a/zenu-optimizer/tests/net_test.rs b/zenu-optimizer/tests/net_test.rs new file mode 100644 index 00000000..4905ea59 --- /dev/null +++ b/zenu-optimizer/tests/net_test.rs @@ -0,0 +1,141 @@ +use std::collections::HashMap; + +use zenu::{ + autograd::{loss::mse::mean_squared_error, Variable}, + layer::{layers::linear::Linear, Module}, + macros::Parameters, + matrix::{ + device::cpu::Cpu, + device::Device, + dim::DimDyn, + matrix::{Matrix, Owned}, + num::Num, + }, + optimizer::{sgd::SGD, Optimizer}, +}; +use zenu_autograd::creator::from_vec::from_vec; + +#[derive(Parameters)] +#[parameters(num = T, device = D)] +struct SimpleNet +where + T: Num, + D: Device, +{ + linear1: Linear, + linear2: Linear, +} + +impl SimpleNet { + fn new() -> Self { + use zenu::layer::Parameters; + let (input_weights, input_bias, output_weights, output_bias) = init_parameters(); + let input_weights = + Matrix::, DimDyn, D>::from_vec(input_weights, DimDyn::from([4, 2])); + let input_bias = Matrix::, DimDyn, D>::from_vec(input_bias, DimDyn::from([4])); + let output_weights = + Matrix::, DimDyn, D>::from_vec(output_weights, DimDyn::from([4, 4])); + let output_bias = Matrix::, DimDyn, D>::from_vec(output_bias, DimDyn::from([4])); + + let linear1 = Linear::new(2, 4, true); + let linear2 = Linear::new(4, 4, true); + + let weight = &(linear1.weights().values().collect::>())[0] + .get_data_mut() + .to_ref_mut(); + weight.copy_from(&input_weights); + + let bias = &(linear1.biases().values().collect::>())[0] + .get_data_mut() + .to_ref_mut(); + bias.copy_from(&input_bias); + + let weight = &(linear2.weights().values().collect::>())[0] + .get_data_mut() + .to_ref_mut(); + weight.copy_from(&output_weights); + + let bias = &(linear2.biases().values().collect::>())[0] + .get_data_mut() + .to_ref_mut(); + bias.copy_from(&output_bias); + + Self { linear1, linear2 } + } +} + +impl Module for SimpleNet { + type Input = Variable; + type Output = Variable; + + fn call(&self, input: Self::Input) -> Self::Output { + let x = self.linear1.call(input); + self.linear2.call(x) + } +} + +fn init_parameters() -> (Vec, Vec, Vec, Vec) { + let input_parameters = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.07, 0.08]; + let input_bias = vec![0.1, 0.2, 0.3, 0.4]; + + let output_parameters = vec![ + 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.07, 0.08, 0.09, 0.10, 0.11, 0.12, 0.13, 0.14, 0.15, 0.16, + ]; + let output_bias = vec![0.1, 0.2, 0.3, 0.4]; + + (input_parameters, input_bias, output_parameters, output_bias) +} + +fn test_funcion_inner>>( + net: &SimpleNet, + optimizer: &O, +) -> HashMap> { + use zenu::layer::Parameters; + + let input = from_vec(vec![0.1, 0.2], DimDyn::from([1, 2])); + let target = from_vec(vec![0.1, 0.2, 0.3, 0.4], DimDyn::from([4])); + let output = net.call(input); + let loss = mean_squared_error(target, output); + loss.backward(); + optimizer.update(net); + + net.parameters() +} + +#[test] +fn sgd_test() { + use zenu_test::assert_val_eq; + + let net = SimpleNet::::new(); + let optimizer = SGD::new(0.9); + let parameters = test_funcion_inner(&net, &optimizer); + + let ans_linear1 = vec![ + 0.0891, 0.1782, 0.2857, 0.3715, 0.4917, 0.5833, 0.0596, 0.0592, + ]; + let ans_linear1 = + Matrix::, DimDyn, Cpu>::from_vec(ans_linear1, DimDyn::from([4, 2])); + let ans_bias1 = vec![-0.0089, 0.0574, 0.2166, 0.2961]; + let ans_bias1 = Matrix::, DimDyn, Cpu>::from_vec(ans_bias1, DimDyn::from([4])); + let ans_linear2 = vec![ + 0.0739, 0.1460, 0.2181, 0.3263, 0.4779, 0.5543, 0.0007, 0.0176, 0.0801, 0.0795, 0.0789, + 0.0920, 0.1164, 0.1119, 0.1075, 0.1217, + ]; + let ans_linear2 = + Matrix::, DimDyn, Cpu>::from_vec(ans_linear2, DimDyn::from([4, 4])); + let ans_bias2 = vec![-0.0742, 0.0525, 0.2339, 0.3095]; + let ans_bias2 = Matrix::, DimDyn, Cpu>::from_vec(ans_bias2, DimDyn::from([4])); + + assert_val_eq!( + parameters["linear1.linear.weight"].clone(), + ans_linear1, + 1e-4 + ); + assert_val_eq!(parameters["linear1.linear.bias"].clone(), ans_bias1, 1e-4); + assert_val_eq!( + parameters["linear2.linear.weight"].clone(), + ans_linear2, + 1e-4 + ); + assert_val_eq!(parameters["linear2.linear.bias"].clone(), ans_bias2, 1e-4); +} diff --git a/zenu/examples/mnist.rs b/zenu/examples/mnist.rs index c36ad490..999fed72 100644 --- a/zenu/examples/mnist.rs +++ b/zenu/examples/mnist.rs @@ -6,17 +6,22 @@ use zenu::{ dataset::{train_val_split, DataLoader, Dataset}, dataset_loader::mnist_dataset, layer::{layers::linear::Linear, Module}, - matrix::device::{cpu::Cpu, Device}, - optimizer::sgd::SGD, - update_parameters, + matrix::{ + device::{cpu::Cpu, Device}, + num::Num, + }, + optimizer::{sgd::SGD, Optimizer}, }; +use zenu_macros::Parameters; -pub struct SimpleModel { - pub linear_1: Linear, - pub linear_2: Linear, +#[derive(Parameters)] +#[parameters(num=T, device=D)] +pub struct SimpleModel { + pub linear_1: Linear, + pub linear_2: Linear, } -impl SimpleModel { +impl SimpleModel { #[must_use] pub fn new() -> Self { Self { @@ -26,21 +31,19 @@ impl SimpleModel { } } -impl Default for SimpleModel { +impl Default for SimpleModel { fn default() -> Self { Self::new() } } -impl Module for SimpleModel { +impl Module for SimpleModel { type Input = Variable; type Output = Variable; fn call(&self, inputs: Variable) -> Variable { let x = self.linear_1.call(inputs); let x = relu(x); - // let x = self.linear_2.call(x); - // x self.linear_2.call(x) } } @@ -76,7 +79,7 @@ impl Dataset for MnistDataset { #[expect(clippy::cast_precision_loss)] fn main() { - let model = SimpleModel::::new(); + let model = SimpleModel::::new(); let (train, test) = mnist_dataset().unwrap(); let (train, val) = train_val_split(&train, 0.8, true); @@ -104,7 +107,10 @@ fn main() { let pred = model.call(input); let loss = cross_entropy(pred, target); let loss_asum = loss.get_data().asum(); - update_parameters(&loss, &optimizer); + // update_parameters(&loss, &optimizer); + loss.backward(); + optimizer.update(&model); + loss.clear_grad(); train_loss += loss_asum; num_iter += 1; }