ggml/ex: ref. CEL, ggml_backend_sched for MNIST
JohannesGaessler committed Sep 30, 2024
1 parent 4de6ee8 commit caf425b
Showing 4 changed files with 87 additions and 55 deletions.
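
The change replaces the MNIST example's single ggml_backend_t handle and its per-graph ggml_backend_alloc_ctx_tensors() calls with a ggml_backend_sched_t that owns a primary backend plus a CPU fallback. A minimal sketch of the flow the example moves to (not code from this commit; the model fields and error handling are abbreviated):

    // Build the forward graph as before, then let the scheduler
    // allocate its tensors across the backends and run it:
    struct ggml_cgraph * gf = ggml_new_graph(model.ctx_compute);
    ggml_build_forward_expand(gf, model.loss);

    if (!ggml_backend_sched_alloc_graph(model.backend_sched, gf)) { // returns success
        fprintf(stderr, "failed to allocate compute graph\n");
        exit(1);
    }
    ggml_backend_sched_graph_compute(model.backend_sched, gf);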
26 changes: 16 additions & 10 deletions examples/mnist/mnist-common.cpp
@@ -299,7 +299,7 @@ mnist_model mnist_model_init_from_file(const std::string & fname, const std::str
     } else {
         fprintf(stderr, "%s: unknown model arch: %s\n", __func__, model.arch.c_str());
     }
-    model.buf_weight = ggml_backend_alloc_ctx_tensors(model.ctx_weight, model.backend);
+    model.buf_weight = ggml_backend_alloc_ctx_tensors(model.ctx_weight, model.backends[0]);
 
     if(!load_from_gguf(fname.c_str(), model.ctx_weight, ctx)) {
         fprintf(stderr, "%s: loading weights from %s failed\n", __func__, fname.c_str());
@@ -361,7 +361,7 @@ mnist_model mnist_model_init_random(const std::string & arch, const std::string
         fprintf(stderr, "%s: unknown model arch: %s\n", __func__, model.arch.c_str());
     }
 
-    model.buf_weight = ggml_backend_alloc_ctx_tensors(model.ctx_weight, model.backend);
+    model.buf_weight = ggml_backend_alloc_ctx_tensors(model.ctx_weight, model.backends[0]);
 
     for (ggml_tensor * t : init_tensors) {
         GGML_ASSERT(t->type == GGML_TYPE_F32);
@@ -488,7 +488,10 @@ mnist_eval_result mnist_model_eval(mnist_model & model, const float * images, co
     struct ggml_cgraph * gf = ggml_new_graph(model.ctx_compute);
     ggml_build_forward_expand(gf, model.loss);
 
-    model.buf_compute = ggml_backend_alloc_ctx_tensors(model.ctx_compute, model.backend);
+    if(!ggml_backend_sched_alloc_graph(model.backend_sched, gf)) {
+        fprintf(stderr, "%s: failed to allocate compute graph\n", __func__);
+        exit(1);
+    }
 
     {
         const int64_t t_start_us = ggml_time_us();
@@ -504,7 +507,7 @@ mnist_eval_result mnist_model_eval(mnist_model & model, const float * images, co
         ggml_backend_tensor_set(model.images, images + iex0*MNIST_NINPUT, 0, ggml_nbytes(model.images));
         ggml_backend_tensor_set(model.labels, labels + iex0*MNIST_NCLASSES, 0, ggml_nbytes(model.labels));
 
-        ggml_backend_graph_compute(model.backend, gf);
+        ggml_backend_sched_graph_compute(model.backend_sched, gf);
 
         ggml_backend_tensor_get(model.loss, &loss, 0, ggml_nbytes(model.loss));
         ggml_backend_tensor_get(model.logits, logits.data(), 0, ggml_nbytes(model.logits));
@@ -536,13 +539,16 @@ void mnist_model_train(mnist_model & model, const float * images, const float *
 
     // gb_grad == graph backward gradients, forward pass, then backward pass to calculate gradients.
     struct ggml_cgraph * gb_grad = ggml_graph_dup(model.ctx_compute, gf);
-    ggml_build_backward_expand(model.ctx_compute, gf, gb_grad, /*accumulate =*/ true);
+    ggml_build_backward_expand(model.ctx_compute, gf, gb_grad, /*accumulate =*/ model.nbatch_logical != model.nbatch_physical);
 
     // gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step.
     struct ggml_cgraph * gb_opt = ggml_graph_dup(model.ctx_compute, gb_grad);
     ggml_build_opt_adamw(model.ctx_compute, gf, gb_opt, 1e-3f, 0.9f, 0.999f, 1e-8f, 0.0f);
 
-    model.buf_compute = ggml_backend_alloc_ctx_tensors(model.ctx_compute, model.backend);
+    if(!ggml_backend_sched_alloc_graph(model.backend_sched, gb_opt)) {
+        fprintf(stderr, "%s: failed to allocate compute graph\n", __func__);
+        exit(1);
+    }
     ggml_graph_reset(gb_opt); // Set gradients to zero, reset optimizer.
 
     const int iex_split = ((int)((1.0f - val_split)*nex) / model.nbatch_logical) * model.nbatch_logical;
@@ -563,10 +569,10 @@ void mnist_model_train(mnist_model & model, const float * images, const float *
             // With a period of nbatch_logical/nbatch_physical iterations:
             if ((iex0 + model.nbatch_physical) % model.nbatch_logical != 0) {
                 // For the first nbatch_logical/nbatch_physical - 1 iterations, only calculate gradients and accumulate them:
-                ggml_backend_graph_compute(model.backend, gb_grad);
+                ggml_backend_sched_graph_compute(model.backend_sched, gb_grad);
             } else {
                 // For the last iteration, calculate gradients and also apply the optimizer:
-                ggml_backend_graph_compute(model.backend, gb_opt); // gb_opt contains all nodes of gb_grad so no extra call for gb_grad is needed.
+                ggml_backend_sched_graph_compute(model.backend_sched, gb_opt); // gb_opt contains all nodes of gb_grad so no extra call for gb_grad is needed.
                 ggml_graph_reset(gb_grad); // Set gradients to zero, do not reset optimizer.
             }
 
@@ -586,7 +592,7 @@ void mnist_model_train(mnist_model & model, const float * images, const float *
         ggml_backend_tensor_set(model.images, images + iex0*MNIST_NINPUT, 0, ggml_nbytes(model.images));
         ggml_backend_tensor_set(model.labels, labels + iex0*MNIST_NCLASSES, 0, ggml_nbytes(model.labels));
 
-        ggml_backend_graph_compute(model.backend, gf); // For the validation set, only the forward pass is needed.
+        ggml_backend_sched_graph_compute(model.backend_sched, gf); // For the validation set, only the forward pass is needed.
 
         ggml_backend_tensor_get(model.loss, &loss, 0, ggml_nbytes(model.loss));
         ggml_backend_tensor_get(model.logits, logits.data(), 0, ggml_nbytes(model.logits));
@@ -621,7 +627,7 @@ void mnist_model_train(mnist_model & model, const float * images, const float *
     const double t_total_s = 1e-6*t_total_us;
     fprintf(stderr, "%s: training took %.2lfs\n", __func__, t_total_s);
 
-    if (ggml_backend_is_cpu(model.backend)) {
+    if (model.backends.size() == 1 && ggml_backend_is_cpu(model.backends[0])) {
         std::string fname = model.arch + "-f32.ggml";
         fprintf(stderr, "%s: saving the GGML graph for the forward pass to %s\n", __func__, fname.c_str());
         ggml_graph_export(gf, fname.c_str());
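
A side effect visible in mnist_model_train() above: gradients are now only accumulated across physical batches when the logical batch is larger than the physical one, and the modulo test decides which graph runs. A standalone sketch of that cadence (batch sizes hypothetical, chosen small for illustration):

    #include <cstdio>

    int main() {
        // One optimizer step per logical batch, made up of two physical batches:
        const int nbatch_logical = 4, nbatch_physical = 2, nex = 12;
        for (int iex0 = 0; iex0 < nex; iex0 += nbatch_physical) {
            if ((iex0 + nbatch_physical) % nbatch_logical != 0) {
                printf("iex0=%2d: compute gb_grad (accumulate gradients only)\n", iex0);
            } else {
                printf("iex0=%2d: compute gb_opt (gradients + AdamW step), reset gb_grad\n", iex0);
            }
        }
        return 0;
    }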
48 changes: 35 additions & 13 deletions examples/mnist/mnist-common.h
@@ -27,7 +27,8 @@ static_assert(MNIST_NTEST % MNIST_NBATCH_LOGICAL == 0, "MNIST_NTRAIN % MNIST_NB
 
 struct mnist_model {
     std::string arch;
-    ggml_backend_t backend;
+    std::vector<ggml_backend_t> backends;
+    ggml_backend_sched_t backend_sched;
     int nbatch_logical;
     int nbatch_physical;
 
@@ -55,22 +56,39 @@ struct mnist_model {
     ggml_backend_buffer_t buf_compute = nullptr;
 
     mnist_model(const std::string & backend_name) {
-        const size_t backend_index = ggml_backend_reg_find_by_name(backend_name.c_str());
-        if (backend_index == SIZE_MAX) {
-            fprintf(stderr, "%s: ERROR: backend %s not found, available:\n", __func__, backend_name.c_str());
-            for (size_t i = 0; i < ggml_backend_reg_get_count(); ++i) {
-                fprintf(stderr, " - %s\n", ggml_backend_reg_get_name(i));
+        std::vector<std::string> backend_names = {backend_name};
+        if (backend_name != "CPU") {
+            backend_names.push_back("CPU");
+        }
+        for (const std::string & bn : backend_names) {
+            const size_t backend_index = ggml_backend_reg_find_by_name(bn.c_str());
+            if (backend_index == SIZE_MAX) {
+                fprintf(stderr, "%s: ERROR: backend %s not found, available:\n", __func__, bn.c_str());
+                for (size_t i = 0; i < ggml_backend_reg_get_count(); ++i) {
+                    fprintf(stderr, " - %s\n", ggml_backend_reg_get_name(i));
+                }
+                exit(1);
+            }
+
+            ggml_backend_t be = ggml_backend_reg_init_backend(backend_index, nullptr);
+            if (ggml_backend_is_cpu(be)) {
+                const int ncores_logical = std::thread::hardware_concurrency();
+                ggml_backend_cpu_set_n_threads(be, std::min(ncores_logical, (ncores_logical + 4)/2));
             }
-            exit(1);
+            backends.push_back(be);
         }
 
-        fprintf(stderr, "%s: using %s backend\n", __func__, backend_name.c_str());
-        backend = ggml_backend_reg_init_backend(backend_index, nullptr);
-        if (ggml_backend_is_cpu(backend)) {
-            const int ncores_logical = std::thread::hardware_concurrency();
-            ggml_backend_cpu_set_n_threads(backend, std::min(ncores_logical, (ncores_logical + 4)/2));
+        if (backends.size() == 1) {
+            fprintf(stderr, "%s: using %s backend\n", __func__, ggml_backend_name(backends[0]));
+        } else if (backends.size() == 2) {
+            fprintf(stderr, "%s: using %s as primary backend with %s as fallback\n",
+                    __func__, ggml_backend_name(backends[0]), ggml_backend_name(backends[1]));
+        } else {
+            GGML_ASSERT(false);
        }
 
+        backend_sched = ggml_backend_sched_new(backends.data(), nullptr, backends.size(), GGML_DEFAULT_GRAPH_SIZE, /*parallel =*/ false);
+
         {
             const size_t size_meta = 1024*ggml_tensor_overhead();
             struct ggml_init_params params = {
@@ -99,7 +117,11 @@ struct mnist_model {
 
         ggml_backend_buffer_free(buf_weight);
         ggml_backend_buffer_free(buf_compute);
-        ggml_backend_free(backend);
+
+        ggml_backend_sched_free(backend_sched);
+        for (ggml_backend_t be : backends) {
+            ggml_backend_free(be);
+        }
     }
 };
 
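Note the teardown order in the destructor above: the scheduler is freed before the backends it references. A minimal creation/teardown sketch under the same assumptions (a primary backend plus the CPU fallback, both already initialized; the handles here are hypothetical):

    std::vector<ggml_backend_t> backends = {primary_backend, cpu_backend};
    ggml_backend_sched_t sched = ggml_backend_sched_new(
        backends.data(), /*bufts =*/ nullptr, backends.size(),
        GGML_DEFAULT_GRAPH_SIZE, /*parallel =*/ false);

    // ... allocate and compute graphs through sched ...

    ggml_backend_sched_free(sched); // first: the scheduler only references the backends
    for (ggml_backend_t be : backends) {
        ggml_backend_free(be);      // then: release the backends themselves
    }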
4 changes: 2 additions & 2 deletions include/ggml-backend.h
@@ -185,7 +185,7 @@ extern "C" {
    GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched);
 
    // Initialize backend buffers from a measure graph
-   GGML_API bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph);
+   GGML_API bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph); // returns success
 
    GGML_API int ggml_backend_sched_get_n_backends(ggml_backend_sched_t sched);
    GGML_API ggml_backend_t ggml_backend_sched_get_backend(ggml_backend_sched_t sched, int i);
@@ -200,7 +200,7 @@ extern "C" {
    GGML_API ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node);
 
    // Allocate and compute graph on the backend scheduler
-   GGML_API bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
+   GGML_API bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph); // returns success
    GGML_API enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
    GGML_API enum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
    GGML_API void ggml_backend_sched_synchronize(ggml_backend_sched_t sched);
64 changes: 34 additions & 30 deletions src/ggml.c
@@ -4195,9 +4195,13 @@ static void ggml_set_op_params_f32(struct ggml_tensor * tensor, uint32_t i, floa
 }
 
 struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) {
+    if (ggml_is_empty(tensor)) {
+        return tensor;
+    }
     if (tensor->buffer) {
         ggml_backend_tensor_memset(tensor, 0, 0, ggml_nbytes(tensor));
     } else {
+        GGML_ASSERT(tensor->data);
         memset(tensor->data, 0, ggml_nbytes(tensor));
     }
     return tensor;
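
With the two guards added above, ggml_set_zero() covers three cases: empty tensors are a no-op, tensors placed in a backend buffer are cleared through ggml_backend_tensor_memset(), and plain host tensors are cleared with memset() (with the data pointer now asserted). A hypothetical usage sketch (ctx and weights are placeholder names, not from this commit):

    // Tensor whose data lives in host memory owned by the ggml context:
    struct ggml_tensor * t = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 16);
    ggml_set_zero(t);       // t->buffer == NULL -> memset(t->data, 0, ggml_nbytes(t))

    // Tensor allocated through a backend, e.g. via ggml_backend_alloc_ctx_tensors():
    ggml_set_zero(weights); // weights->buffer != NULL -> ggml_backend_tensor_memset(...)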
@@ -16810,41 +16814,40 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
 
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    GGML_ASSERT(ggml_is_contiguous(src1));
-    GGML_ASSERT(ggml_is_scalar(dst));
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
+    GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
     GGML_ASSERT(ggml_are_same_shape(src0, src1));
+    GGML_ASSERT(ggml_is_scalar(dst));
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    // TODO: handle transposed/permuted matrices
+    const int64_t nc = src0->ne[0];
+    const int64_t nr = ggml_nrows(src0);
 
     const int ith = params->ith;
     const int nth = params->nth;
 
-    float * sums = (float *) params->wdata;
-
-    // TODO: handle transposed/permuted matrices
-    const int nc = src0->ne[0];
-    const int nr = ggml_nrows(src0);
+    float * sums = (float *) params->wdata;
+    float * st = ((float *) params->wdata) + nth + ith*nc;
+    float sum_thread = 0.0f;
 
     GGML_ASSERT(params->wsize >= sizeof(float) * (nth + nth * nc));
 
     if (ith == 0) {
         memset(sums, 0, sizeof(float) * (nth + nth * nc));
     }
     ggml_barrier(params->threadpool);
 
     // rows per thread
-    const int dr = (nr + nth - 1)/nth;
+    const int64_t dr = (nr + nth - 1)/nth;
 
     // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
+    const int64_t ir0 = dr*ith;
+    const int64_t ir1 = MIN(ir0 + dr, nr);
 
-    for (int i1 = ir0; i1 < ir1; i1++) {
-        float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]);
-        float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]);
-        float * st = ((float *) params->wdata) + nth + ith*nc;
+    for (int64_t i1 = ir0; i1 < ir1; ++i1) {
+        const float * s0 = (const float *)((const char *) src0->data + i1*src0->nb[1]);
+        const float * s1 = (const float *)((const char *) src1->data + i1*src1->nb[1]);
 
 #ifndef NDEBUG
-        for (int i = 0; i < nc; ++i) {
+        for (int64_t i = 0; i < nc; ++i) {
             //printf("p[%d] = %f\n", i, p[i]);
             assert(!isnan(s0[i]));
             assert(!isnan(s1[i]));
@@ -16853,23 +16856,24 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
 
         float max = -INFINITY;
         ggml_vec_max_f32(nc, &max, s0);
-        ggml_float sum = ggml_vec_log_soft_max_f32(nc, st, s0, max);
-        assert(sum >= 0.0);
+        const ggml_float sum_softmax = ggml_vec_log_soft_max_f32(nc, st, s0, max);
+        assert(sum_softmax >= 0.0);
 
-        ggml_vec_add1_f32(nc, st, st, -sum);
+        ggml_vec_add1_f32(nc, st, st, -sum_softmax);
         ggml_vec_mul_f32(nc, st, st, s1);
 
-        float st_sum = 0.0f;
-        ggml_vec_sum_f32(nc, &st_sum, st);
-        sums[ith] += st_sum;
+        float sum_st = 0.0f;
+        ggml_vec_sum_f32(nc, &sum_st, st);
+        sum_thread += sum_st;
 
 #ifndef NDEBUG
-        for (int i = 0; i < nc; ++i) {
+        for (int64_t i = 0; i < nc; ++i) {
             assert(!isnan(st[i]));
             assert(!isinf(st[i]));
         }
 #endif
     }
+    sums[ith] = sum_thread;
     ggml_barrier(params->threadpool);
 
     if (ith == 0) {
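
For reference, the quantity the rewritten loop accumulates per row is the label-weighted log-softmax. ggml_vec_log_soft_max_f32() writes $x_c - \max_{c'} x_{c'}$ into st and returns $\log \sum_{c'} e^{x_{c'} - \max}$, so after the add1/mul steps each row contributes (a sketch of the math; the final sign and scaling live in the ith == 0 reduction, which is outside this hunk):

    st_c = (x_c - \max_{c'} x_{c'}) - \log \sum_{c'} e^{x_{c'} - \max}
         = \log \operatorname{softmax}(x)_c
    t_i  = \sum_c y_{ic} \, st_c
    \text{loss} \propto -\frac{1}{nr} \sum_i t_i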
@@ -16935,7 +16939,7 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
         float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]);
         float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]);
 
 #ifndef NDEBUG
-        for (int i = 0; i < nc; ++i) {
+        for (int64_t i = 0; i < nc; ++i) {
             //printf("p[%d] = %f\n", i, p[i]);
             assert(!isnan(s0[i]));
             assert(!isnan(s1[i]));
@@ -16954,7 +16958,7 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
         ggml_vec_scale_f32(nc, ds0, d_by_nr);
 
 #ifndef NDEBUG
-        for (int i = 0; i < nc; ++i) {
+        for (int64_t i = 0; i < nc; ++i) {
             assert(!isnan(ds0[i]));
             assert(!isinf(ds0[i]));
         }
