
Commit

change macros
bokutotu committed Nov 8, 2024
1 parent 6ccf9a0 commit c25714a
Showing 5 changed files with 57 additions and 24 deletions.
10 changes: 5 additions & 5 deletions zenu-macros/src/lib.rs
@@ -120,11 +120,11 @@ impl Parse for ParametersArgs {
num = Some(ty);
} else if ident == "device" {
device = Some(ty);
} else {
return Err(syn::Error::new(
ident.span(),
"Expected 'num' or 'device' in parameters attribute",
));
// } else {
// return Err(syn::Error::new(
// ident.span(),
// "Expected 'num' or 'device' in parameters attribute",
// ));
}

if content.peek(Comma) {
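For context, the parser above accepts `num = <Type>` and `device = <Type>` pairs; with the `else` branch commented out, an unrecognized key is now silently skipped instead of producing a compile error. A minimal standalone sketch of that parsing behavior, written against syn 2.x (the struct and field names here are illustrative, not the crate's actual code):

```rust
use syn::parse::{Parse, ParseStream};
use syn::{Ident, Token, Type};

struct Args {
    num: Option<Type>,
    device: Option<Type>,
}

impl Parse for Args {
    fn parse(input: ParseStream) -> syn::Result<Self> {
        let mut num = None;
        let mut device = None;
        while !input.is_empty() {
            let ident: Ident = input.parse()?;
            input.parse::<Token![=]>()?;
            let ty: Type = input.parse()?;
            if ident == "num" {
                num = Some(ty);
            } else if ident == "device" {
                device = Some(ty);
            }
            // With the `else { return Err(..) }` branch commented out above,
            // an unknown key simply falls through here and is dropped.
            if input.peek(Token![,]) {
                input.parse::<Token![,]>()?;
            }
        }
        Ok(Args { num, device })
    }
}
```

Under this sketch, `num = f32, device = Cpu` yields both types, while `num = f32, foo = Bar` now parses successfully and drops `foo`.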
33 changes: 33 additions & 0 deletions zenu-optimizer/src/adam.rs
@@ -15,6 +15,39 @@ pub struct Adam<T: Num, D: Device> {
v: Vec<Variable<T, D>>,
}

// impl<T: Num, D: Device> Optimizer<T, D> for Adam<T, D> {
// fn update(&self, parameters: &[Variable<T, D>]) {
// let step = *self.step.borrow();
// let step = step + T::one();
// *self.step.borrow_mut() = step;
//
// let beta1_t = self.beta1.powf(step);
// let beta2_t = self.beta2.powf(step);
//
// for ((parameter, m), v) in parameters.iter().zip(&self.m).zip(&self.v) {
// let grad = parameter.get_grad().unwrap();
// let grad = grad.get_data();
//
// let mut v = v.get_data_mut();
// let mut v = v.to_ref_mut();
// let mut m = m.get_data_mut();
// let mut m = m.to_ref_mut();
//
// m *= self.beta1;
// m += grad.to_ref() * (T::one() - self.beta1);
//
// v *= self.beta2;
// v += grad.to_ref() * grad.to_ref() * (T::one() - self.beta2);
//
// let m_hat = m / (T::one() - beta1_t);
// let v_hat = v / (T::one() - beta2_t);
//
// let mut parameter_data = parameter.get_data_mut();
// let mut parameter_data = parameter_data.to_ref_mut();
// parameter_data -= m_hat / (v_hat.sqrt() + self.epsilon) * self.learning_rate;
// }
// }
// }
impl<T: Num, D: Device> Optimizer<T, D> for Adam<T, D> {
fn update(&self, parameters: &[Variable<T, D>]) {
let step = *self.step.borrow();
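The block commented out above is the standard bias-corrected Adam step. As a reference for what that arithmetic does, here is a minimal sketch on plain `f32` slices (an illustration only; the names and free-function shape are not zenu's API):

```rust
/// One Adam step over flat f32 buffers. `m` and `v` are the first and
/// second moment estimates carried between steps; `step` starts at 1.
fn adam_step(
    param: &mut [f32],
    grad: &[f32],
    m: &mut [f32],
    v: &mut [f32],
    step: f32,
    lr: f32,
    beta1: f32,
    beta2: f32,
    eps: f32,
) {
    let beta1_t = beta1.powf(step);
    let beta2_t = beta2.powf(step);
    for i in 0..param.len() {
        // Exponential moving averages of the gradient and its square.
        m[i] = beta1 * m[i] + (1.0 - beta1) * grad[i];
        v[i] = beta2 * v[i] + (1.0 - beta2) * grad[i] * grad[i];
        // Bias correction, then the parameter update.
        let m_hat = m[i] / (1.0 - beta1_t);
        let v_hat = v[i] / (1.0 - beta2_t);
        param[i] -= lr * m_hat / (v_hat.sqrt() + eps);
    }
}
```

Dividing by `1 - beta^t` corrects the bias of the zero-initialized moment estimates during the first steps of training.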
5 changes: 3 additions & 2 deletions zenu-optimizer/src/lib.rs
@@ -3,8 +3,9 @@ pub mod adam;
pub mod sgd;

use zenu_autograd::Variable;
use zenu_layer::Parameters;
use zenu_matrix::{device::Device, num::Num};

pub trait Optimizer<T: Num, D: Device> {
fn update(&self, parameters: &[Variable<T, D>]);
pub trait Optimizer<T: Num, D: Device, P: Parameters<T, D>> {
fn update(&self, parameters: &P);
}
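The trait now threads the model type through a `P: Parameters<T, D>` bound instead of taking a flat slice of variables. A minimal, hypothetical implementor of the new shape, just to show the bound and the call signature (nothing here is part of the commit):

```rust
use zenu_layer::Parameters;
use zenu_matrix::{device::Device, num::Num};

use crate::Optimizer;

/// Hypothetical no-op optimizer; it exists only to illustrate the new bound.
struct NullOptimizer;

impl<T: Num, D: Device, P: Parameters<T, D>> Optimizer<T, D, P> for NullOptimizer {
    fn update(&self, _parameters: &P) {
        // A real optimizer walks the model's weights and biases here,
        // as the SGD change below does.
    }
}
```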
22 changes: 16 additions & 6 deletions zenu-optimizer/src/sgd.rs
@@ -1,4 +1,5 @@
use zenu_autograd::Variable;
use zenu_layer::Parameters;
use zenu_matrix::{device::Device, num::Num};

use crate::Optimizer;
@@ -17,12 +18,21 @@ impl<T: Num, D: Device> SGD<T, D> {
}
}

impl<T: Num, D: Device> Optimizer<T, D> for SGD<T, D> {
fn update(&self, parameters: &[Variable<T, D>]) {
let parameters = parameters
.iter()
.filter(|parameter| parameter.get_grad().is_some())
.collect::<Vec<_>>();
impl<T: Num, D: Device, P: Parameters> Optimizer<T, D, P> for SGD<T, D> {
fn update(&self, parameters: &P) {
let weights = parameters.weights();
let biases = parameters.biases();
let mut parameters = Vec::new();
for (_, weight) in weights.iter() {
if let Some(grad) = weight.get_grad() {
parameters.push(grad);
}
}
for (_, bias) in biases.iter() {
if let Some(grad) = bias.get_grad() {
parameters.push(grad);
}
}
for parameter in parameters {
let grad = parameter.clone().get_grad().unwrap();
let grad = grad.get_data();
11 changes: 0 additions & 11 deletions zenu/src/lib.rs
@@ -13,7 +13,6 @@ use serde::Deserialize;
use zenu_autograd::Variable;
use zenu_layer::Parameters;
use zenu_matrix::{device::Device, num::Num};
use zenu_optimizer::Optimizer;

pub extern crate zenu_macros;

@@ -23,16 +22,6 @@ pub use zenu_macros as macros;
pub use zenu_matrix as matrix;
pub use zenu_optimizer as optimizer;

pub fn update_parameters<T: Num, D: Device, O: Optimizer<T, D>>(
loss: &Variable<T, D>,
optimizer: &O,
) {
loss.backward();
let parameters = loss.get_all_trainable_variables();
optimizer.update(&parameters);
loss.clear_grad();
}

#[expect(clippy::missing_errors_doc)]
pub fn save_model<T: Num, D: Device, M: Parameters<T, D>, P: AsRef<Path>>(
model: &M,
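With `update_parameters` removed, the equivalent step presumably calls the optimizer directly through the new trait, passing the model rather than a slice of trainable variables. A hedged sketch (the helper name and the exact call site are assumptions, not part of this commit):

```rust
use zenu_autograd::Variable;
use zenu_layer::Parameters;
use zenu_matrix::{device::Device, num::Num};
use zenu_optimizer::Optimizer;

/// Hypothetical replacement for the removed `update_parameters` helper.
fn train_step<T: Num, D: Device, M: Parameters<T, D>, O: Optimizer<T, D, M>>(
    loss: &Variable<T, D>,
    model: &M,
    optimizer: &O,
) {
    loss.backward();         // populate gradients for every trainable variable
    optimizer.update(model); // step weights and biases via the Parameters trait
    loss.clear_grad();       // reset gradients before the next iteration
}
```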
