Skip to content

Commit

Permalink
fix 0.0f validation split
Browse files Browse the repository at this point in the history
  • Loading branch information
JohannesGaessler committed Sep 16, 2024
1 parent cf0f60e commit d07b6e3
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions examples/mnist/mnist-common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -604,15 +604,16 @@ void mnist_model_train(mnist_model & model, const float * images, const float *

const int64_t t_epoch_us = ggml_time_us() - t_start_us;
const double t_epoch_s = 1e-6*t_epoch_us;
fprintf(stderr, "done, took %.2lfs, train_loss=%.6lf, train_acc=%.2f%%, ", t_epoch_s, loss_mean, percent_correct);
fprintf(stderr, "done, took %.2lfs, train_loss=%.6lf, train_acc=%.2f%%", t_epoch_s, loss_mean, percent_correct);
}

{
if (iex_split < nex) {
const std::pair<double, double> loss = mnist_loss(result_val);
const std::pair<double, double> acc = mnist_accuracy(result_val, labels + iex_split*MNIST_NCLASSES);

fprintf(stderr, "val_loss=%.6lf+-%.6lf, train_acc=%.2f+-%.2f%%\n", loss.first, loss.second, 100.0*acc.first, 100.0*acc.second);
fprintf(stderr, ", val_loss=%.6lf+-%.6lf, train_acc=%.2f+-%.2f%%", loss.first, loss.second, 100.0*acc.first, 100.0*acc.second);
}
fprintf(stderr, "\n");
}

const int64_t t_total_us = ggml_time_us() - t_start_us;
Expand Down

0 comments on commit d07b6e3

Please sign in to comment.