Skip to content

Commit

Permalink
Merge pull request #43 from delta-rs/improving-sequential-model
Browse files Browse the repository at this point in the history
Improving sequential model output while training
  • Loading branch information
chaseWillden authored Dec 4, 2024
2 parents 8e43dad + ac14918 commit c2cfcad
Showing 1 changed file with 42 additions and 18 deletions.
60 changes: 42 additions & 18 deletions delta/src/neuralnet/models/sequential.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,12 +140,7 @@ impl Sequential {

for epoch in 0..epochs {
println!("\nEpoch {}/{}", epoch + 1, epochs);
let epoch_loss = self.train_one_epoch(train_data, batch_size, &mut optimizer);
println!(
"Epoch {} completed. Average Loss: {:.6}",
epoch + 1,
epoch_loss
);
self.train_one_epoch(train_data, batch_size, &mut optimizer);
}
}

Expand Down Expand Up @@ -178,13 +173,30 @@ impl Sequential {
) -> f32 {
let num_batches = train_data.len() / batch_size;
let mut epoch_loss = 0.0;
let mut correct_predictions = 0;
let mut total_samples = 0;

let start_time = Instant::now(); // Start timer
let start_time = Instant::now();

for batch_idx in 0..num_batches {
let (inputs, targets) = train_data.get_batch(batch_idx, batch_size);
epoch_loss += self.train_one_batch(&inputs, &targets, optimizer);
self.display_progress(batch_idx, num_batches, epoch_loss, start_time);
let batch_loss = self.train_one_batch(&inputs, &targets, optimizer);
epoch_loss += batch_loss;

// Calculate accuracy
let outputs = self.forward(&inputs);
let predictions = outputs.argmax(1);
let actuals = targets.argmax(1);
correct_predictions += predictions
.data
.iter()
.zip(actuals.data.iter())
.filter(|(pred, actual)| pred == actual)
.count();
total_samples += targets.shape()[0];

let accuracy = correct_predictions as f32 / total_samples as f32;
self.display_progress(batch_idx, num_batches, epoch_loss, accuracy, start_time);
}

epoch_loss / num_batches as f32
Expand Down Expand Up @@ -234,31 +246,35 @@ impl Sequential {
/// * `batch_idx` - The index of the current batch.
/// * `num_batches` - The total number of batches.
/// * `epoch_loss` - The current epoch loss.
/// * `accuracy` - The current accuracy.
/// * `start_time` - The start time of the training process.
fn display_progress(
&mut self,
batch_idx: usize,
num_batches: usize,
epoch_loss: f32,
accuracy: f32,
start_time: Instant,
) {
let progress = (batch_idx + 1) as f32 / num_batches as f32;
let current_avg_loss = epoch_loss / (batch_idx + 1) as f32;
let bar_width = 30;
let filled = (progress * bar_width as f32) as usize;
let arrow = if filled < bar_width { ">" } else { "=" };
let bar: String = std::iter::repeat('=')
.take(filled)
.chain(std::iter::repeat(' ').take(bar_width - filled))
.chain(std::iter::once(arrow.chars().next().unwrap()))
.chain(std::iter::repeat(' ').take((bar_width as isize - filled as isize - 1).max(0) as usize))
.collect();

let elapsed = start_time.elapsed();
let elapsed_secs = elapsed.as_secs_f32();
let estimated_total = elapsed_secs / progress;
let remaining_secs = estimated_total - elapsed_secs;
let remaining_secs = (estimated_total - elapsed_secs).max(0.0);

print!(
"\rProgress: [{}] - Current Average Loss: {:.6} | Elapsed: {:.2}s | Remaining: {:.2}s",
bar, current_avg_loss, elapsed_secs, remaining_secs
"\rProgress: [{}] - ETA: {:.2}s - loss: {:.6} - accuracy: {:.4}",
bar, remaining_secs, current_avg_loss, accuracy
);
std::io::stdout().flush().unwrap();
}
Expand Down Expand Up @@ -292,10 +308,9 @@ impl Sequential {
let mut correct_predictions = 0;
let mut total_samples = 0;

let num_batches = (test_data.len() + batch_size - 1) / batch_size; // Calculate number of batches
let num_batches = (test_data.len() + batch_size - 1) / batch_size;

for batch_idx in 0..num_batches {
// Fetch batch
let (inputs, targets) = test_data.get_batch(batch_idx, batch_size);

// Forward pass to get predictions
Expand Down Expand Up @@ -343,9 +358,18 @@ impl Sequential {
Ok(())
}

/*pub fn forward(&self, input: &Tensor) -> Tensor {
self.layers.iter().fold(input.clone(), |acc, layer| layer.forward(&acc))
}*/
/// Performs a forward pass through the model.
///
/// # Arguments
///
/// * `input` - The input tensor.
///
/// # Returns
///
/// The output tensor after passing through all layers.
pub fn forward(&mut self, input: &Tensor) -> Tensor {
self.layers.iter_mut().fold(input.clone(), |acc, layer| layer.forward(&acc))
}

/// Prints a summary of the model.
pub fn summary(&self) {
Expand Down

0 comments on commit c2cfcad

Please sign in to comment.