diff --git a/zenu-macros/src/lib.rs b/zenu-macros/src/lib.rs
index c375c4b7..1e1317fc 100644
--- a/zenu-macros/src/lib.rs
+++ b/zenu-macros/src/lib.rs
@@ -120,11 +120,11 @@ impl Parse for ParametersArgs {
             num = Some(ty);
         } else if ident == "device" {
             device = Some(ty);
-        } else {
-            return Err(syn::Error::new(
-                ident.span(),
-                "Expected 'num' or 'device' in parameters attribute",
-            ));
+        // } else {
+        //     return Err(syn::Error::new(
+        //         ident.span(),
+        //         "Expected 'num' or 'device' in parameters attribute",
+        //     ));
         }
 
         if content.peek(Comma) {
diff --git a/zenu-optimizer/src/adam.rs b/zenu-optimizer/src/adam.rs
index f7ada9f8..be2ccb58 100644
--- a/zenu-optimizer/src/adam.rs
+++ b/zenu-optimizer/src/adam.rs
@@ -15,6 +15,39 @@ pub struct Adam<T: Num, D: Device> {
     v: Vec<Variable<T, D>>,
 }
 
+// impl<T: Num, D: Device> Optimizer<T, D> for Adam<T, D> {
+//     fn update(&self, parameters: &[Variable<T, D>]) {
+//         let step = *self.step.borrow();
+//         let step = step + T::one();
+//         *self.step.borrow_mut() = step;
+//
+//         let beta1_t = self.beta1.powf(step);
+//         let beta2_t = self.beta2.powf(step);
+//
+//         for ((parameter, m), v) in parameters.iter().zip(&self.m).zip(&self.v) {
+//             let grad = parameter.get_grad().unwrap();
+//             let grad = grad.get_data();
+//
+//             let mut v = v.get_data_mut();
+//             let mut v = v.to_ref_mut();
+//             let mut m = m.get_data_mut();
+//             let mut m = m.to_ref_mut();
+//
+//             m *= self.beta1;
+//             m += grad.to_ref() * (T::one() - self.beta1);
+//
+//             v *= self.beta2;
+//             v += grad.to_ref() * grad.to_ref() * (T::one() - self.beta2);
+//
+//             let m_hat = m / (T::one() - beta1_t);
+//             let v_hat = v / (T::one() - beta2_t);
+//
+//             let mut parameter_data = parameter.get_data_mut();
+//             let mut parameter_data = parameter_data.to_ref_mut();
+//             parameter_data -= m_hat / (v_hat.sqrt() + self.epsilon) * self.learning_rate;
+//         }
+//     }
+// }
 impl<T: Num, D: Device> Optimizer<T, D> for Adam<T, D> {
     fn update(&self, parameters: &[Variable<T, D>]) {
         let step = *self.step.borrow();
diff --git a/zenu-optimizer/src/lib.rs b/zenu-optimizer/src/lib.rs
index 10b8fceb..b309bc7d 100644
--- a/zenu-optimizer/src/lib.rs
+++ b/zenu-optimizer/src/lib.rs
@@ -3,8 +3,9 @@ pub mod adam;
 pub mod sgd;
 
 use zenu_autograd::Variable;
+use zenu_layer::Parameters;
 use zenu_matrix::{device::Device, num::Num};
 
-pub trait Optimizer<T: Num, D: Device> {
-    fn update(&self, parameters: &[Variable<T, D>]);
+pub trait Optimizer<T: Num, D: Device, P: Parameters<T, D>> {
+    fn update(&self, parameters: &P);
 }
diff --git a/zenu-optimizer/src/sgd.rs b/zenu-optimizer/src/sgd.rs
index 36fc6d46..96aa2fff 100644
--- a/zenu-optimizer/src/sgd.rs
+++ b/zenu-optimizer/src/sgd.rs
@@ -1,4 +1,5 @@
 use zenu_autograd::Variable;
+use zenu_layer::Parameters;
 use zenu_matrix::{device::Device, num::Num};
 
 use crate::Optimizer;
@@ -17,12 +18,21 @@ impl<T: Num, D: Device> SGD<T, D> {
     }
 }
 
-impl<T: Num, D: Device> Optimizer<T, D> for SGD<T, D> {
-    fn update(&self, parameters: &[Variable<T, D>]) {
-        let parameters = parameters
-            .iter()
-            .filter(|parameter| parameter.get_grad().is_some())
-            .collect::<Vec<_>>();
+impl<T: Num, D: Device, P: Parameters<T, D>> Optimizer<T, D, P> for SGD<T, D> {
+    fn update(&self, parameters: &P) {
+        let weights = parameters.weights();
+        let biases = parameters.biases();
+        let mut parameters = Vec::new();
+        for (_, weight) in weights.iter() {
+            if let Some(grad) = weight.get_grad() {
+                parameters.push(grad);
+            }
+        }
+        for (_, bias) in biases.iter() {
+            if let Some(grad) = bias.get_grad() {
+                parameters.push(grad);
+            }
+        }
         for parameter in parameters {
             let grad = parameter.clone().get_grad().unwrap();
             let grad = grad.get_data();
diff --git a/zenu/src/lib.rs b/zenu/src/lib.rs
index 42f098ea..aad0a41a 100644
--- a/zenu/src/lib.rs
+++ b/zenu/src/lib.rs
@@ -13,7 +13,6 @@
 use serde::Deserialize;
 use zenu_autograd::Variable;
 use zenu_layer::Parameters;
 use zenu_matrix::{device::Device, num::Num};
-use zenu_optimizer::Optimizer;
 
 pub extern crate zenu_macros;
 
@@ -23,16 +22,6 @@ pub use zenu_macros as macros;
 pub use zenu_matrix as matrix;
 pub use zenu_optimizer as optimizer;
 
-pub fn update_parameters<T: Num, D: Device, O: Optimizer<T, D>>(
-    loss: &Variable<T, D>,
-    optimizer: &O,
-) {
-    loss.backward();
-    let parameters = loss.get_all_trainable_variables();
-    optimizer.update(&parameters);
-    loss.clear_grad();
-}
-
 #[expect(clippy::missing_errors_doc)]
 pub fn save_model<T: Num, D: Device, M: Parameters<T, D>, P: AsRef<Path>>(
     model: &M,
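
For reference, a minimal sketch (not part of the diff) of how a caller might drive the reworked trait now that `update_parameters` is removed. The `train_step` helper is hypothetical; the `Optimizer<T, D, P>` bounds follow the trait as reconstructed above, and `backward()`/`clear_grad()` are the same calls the removed `update_parameters` helper made.

```rust
use zenu_autograd::Variable;
use zenu_layer::Parameters;
use zenu_matrix::{device::Device, num::Num};
use zenu_optimizer::Optimizer;

// Hypothetical replacement for the removed update_parameters helper:
// the optimizer now receives the whole model instead of a flat Variable slice.
fn train_step<T, D, M, O>(loss: &Variable<T, D>, model: &M, optimizer: &O)
where
    T: Num,
    D: Device,
    M: Parameters<T, D>,
    O: Optimizer<T, D, M>,
{
    loss.backward();         // populate gradients on the trainable variables
    optimizer.update(model); // the optimizer walks the model's weights()/biases()
    loss.clear_grad();       // reset gradients before the next step
}
```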