Skip to content

Commit

Permalink
gradient accumulation
Browse files Browse the repository at this point in the history
  • Loading branch information
JohannesGaessler committed Sep 13, 2024
1 parent 17abed5 commit 884431c
Show file tree
Hide file tree
Showing 10 changed files with 80 additions and 61 deletions.
74 changes: 43 additions & 31 deletions examples/mnist/mnist-common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -368,16 +368,17 @@ mnist_model mnist_model_init_random(const std::string & arch, const std::string
return model;
}

void mnist_model_build(mnist_model & model, const int nbatch) {
model.nbatch = nbatch;
void mnist_model_build(mnist_model & model, const int nbatch_logical, const int nbatch_physical) {
model.nbatch_logical = nbatch_logical;
model.nbatch_physical = nbatch_physical;

if (model.arch == "mnist-fc") {
ggml_set_param(model.ctx_compute, model.fc1_weight);
ggml_set_param(model.ctx_compute, model.fc1_bias);
ggml_set_param(model.ctx_compute, model.fc2_weight);
ggml_set_param(model.ctx_compute, model.fc2_bias);

model.images = ggml_new_tensor_2d(model.ctx_compute, GGML_TYPE_F32, MNIST_NINPUT, model.nbatch);
model.images = ggml_new_tensor_2d(model.ctx_compute, GGML_TYPE_F32, MNIST_NINPUT, model.nbatch_physical);
ggml_set_name(model.images, "images");
ggml_set_input(model.images);

Expand All @@ -395,7 +396,7 @@ void mnist_model_build(mnist_model & model, const int nbatch) {
ggml_set_param(model.ctx_compute, model.dense_weight);
ggml_set_param(model.ctx_compute, model.dense_bias);

model.images = ggml_new_tensor_4d(model.ctx_compute, GGML_TYPE_F32, 28, 28, 1, model.nbatch);
model.images = ggml_new_tensor_4d(model.ctx_compute, GGML_TYPE_F32, 28, 28, 1, model.nbatch_physical);
ggml_set_name(model.images, "images");
ggml_set_input(model.images);

Expand All @@ -405,33 +406,33 @@ void mnist_model_build(mnist_model & model, const int nbatch) {
GGML_ASSERT(conv1_out->ne[0] == MNIST_HW);
GGML_ASSERT(conv1_out->ne[1] == MNIST_HW);
GGML_ASSERT(conv1_out->ne[2] == MNIST_CNN_NCB);
GGML_ASSERT(conv1_out->ne[3] == model.nbatch);
GGML_ASSERT(conv1_out->ne[3] == model.nbatch_physical);

struct ggml_tensor * conv2_in = ggml_pool_2d(model.ctx_compute, conv1_out, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
GGML_ASSERT(conv2_in->ne[0] == MNIST_HW/2);
GGML_ASSERT(conv2_in->ne[1] == MNIST_HW/2);
GGML_ASSERT(conv2_in->ne[2] == MNIST_CNN_NCB);
GGML_ASSERT(conv2_in->ne[3] == model.nbatch);
GGML_ASSERT(conv2_in->ne[3] == model.nbatch_physical);

struct ggml_tensor * conv2_out = ggml_relu(model.ctx_compute, ggml_add(model.ctx_compute,
ggml_conv_2d(model.ctx_compute, model.conv2_kernel, conv2_in, 1, 1, 1, 1, 1, 1),
model.conv2_bias));
GGML_ASSERT(conv2_out->ne[0] == MNIST_HW/2);
GGML_ASSERT(conv2_out->ne[1] == MNIST_HW/2);
GGML_ASSERT(conv2_out->ne[2] == MNIST_CNN_NCB*2);
GGML_ASSERT(conv2_out->ne[3] == model.nbatch);
GGML_ASSERT(conv2_out->ne[3] == model.nbatch_physical);

struct ggml_tensor * dense_in = ggml_pool_2d(model.ctx_compute, conv2_out, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
GGML_ASSERT(dense_in->ne[0] == MNIST_HW/4);
GGML_ASSERT(dense_in->ne[1] == MNIST_HW/4);
GGML_ASSERT(dense_in->ne[2] == MNIST_CNN_NCB*2);
GGML_ASSERT(dense_in->ne[3] == model.nbatch);
GGML_ASSERT(dense_in->ne[3] == model.nbatch_physical);

dense_in = ggml_reshape_2d(model.ctx_compute,
ggml_cont(model.ctx_compute, ggml_permute(model.ctx_compute, dense_in, 1, 2, 0, 3)),
(MNIST_HW/4)*(MNIST_HW/4)*(MNIST_CNN_NCB*2), model.nbatch);
(MNIST_HW/4)*(MNIST_HW/4)*(MNIST_CNN_NCB*2), model.nbatch_physical);
GGML_ASSERT(dense_in->ne[0] == (MNIST_HW/4)*(MNIST_HW/4)*(MNIST_CNN_NCB*2));
GGML_ASSERT(dense_in->ne[1] == model.nbatch);
GGML_ASSERT(dense_in->ne[1] == model.nbatch_physical);
GGML_ASSERT(dense_in->ne[2] == 1);
GGML_ASSERT(dense_in->ne[3] == 1);

Expand All @@ -444,7 +445,7 @@ void mnist_model_build(mnist_model & model, const int nbatch) {
ggml_set_output(model.logits);
GGML_ASSERT(model.logits->type == GGML_TYPE_F32);
GGML_ASSERT(model.logits->ne[0] == MNIST_NCLASSES);
GGML_ASSERT(model.logits->ne[1] == model.nbatch);
GGML_ASSERT(model.logits->ne[1] == model.nbatch_physical);
GGML_ASSERT(model.logits->ne[2] == 1);
GGML_ASSERT(model.logits->ne[3] == 1);

Expand All @@ -453,11 +454,11 @@ void mnist_model_build(mnist_model & model, const int nbatch) {
ggml_set_output(model.probs);
GGML_ASSERT(model.probs->type == GGML_TYPE_F32);
GGML_ASSERT(model.probs->ne[0] == MNIST_NCLASSES);
GGML_ASSERT(model.probs->ne[1] == model.nbatch);
GGML_ASSERT(model.probs->ne[1] == model.nbatch_physical);
GGML_ASSERT(model.probs->ne[2] == 1);
GGML_ASSERT(model.probs->ne[3] == 1);

model.labels = ggml_new_tensor_2d(model.ctx_compute, GGML_TYPE_F32, MNIST_NCLASSES, model.nbatch);
model.labels = ggml_new_tensor_2d(model.ctx_compute, GGML_TYPE_F32, MNIST_NCLASSES, model.nbatch_physical);
ggml_set_name(model.labels, "labels");
ggml_set_input(model.labels);

Expand All @@ -484,13 +485,13 @@ mnist_eval_result mnist_model_eval(mnist_model & model, const float * images, co
const int64_t t_start_us = ggml_time_us();

float loss;
std::vector<float> logits(model.nbatch*MNIST_NCLASSES);
std::vector<float> logits(model.nbatch_physical*MNIST_NCLASSES);

GGML_ASSERT(sizeof(loss) == ggml_nbytes(model.loss));
GGML_ASSERT(logits.size() == ggml_nelements(model.logits));

GGML_ASSERT(nex % model.nbatch == 0);
for (int iex0 = 0; iex0 < nex; iex0 += model.nbatch) {
GGML_ASSERT(nex % model.nbatch_physical == 0);
for (int iex0 = 0; iex0 < nex; iex0 += model.nbatch_physical) {
ggml_backend_tensor_set(model.images, images + iex0*MNIST_NINPUT, 0, ggml_nbytes(model.images));
ggml_backend_tensor_set(model.labels, labels + iex0*MNIST_NCLASSES, 0, ggml_nbytes(model.labels));

Expand All @@ -501,7 +502,7 @@ mnist_eval_result mnist_model_eval(mnist_model & model, const float * images, co

result.loss.push_back(loss);

for (int iexb = 0; iexb < model.nbatch; ++iexb) {
for (int iexb = 0; iexb < model.nbatch_physical; ++iexb) {
const float * logits_iexb = logits.data() + iexb*MNIST_NCLASSES;
result.pred.push_back(std::max_element(logits_iexb, logits_iexb + MNIST_NCLASSES) - logits_iexb);
}
Expand All @@ -520,32 +521,43 @@ mnist_eval_result mnist_model_eval(mnist_model & model, const float * images, co
void mnist_model_train(mnist_model & model, const float * images, const float * labels, const int nex, const int nthreads) {
const int64_t t_start_us = ggml_time_us();

struct ggml_cgraph * gf = ggml_new_graph_custom(model.ctx_compute, 16384, true);
struct ggml_cgraph * gf = ggml_new_graph_custom(model.ctx_compute, 16384, true); // Forward pass.
ggml_build_forward_expand(gf, model.loss);

struct ggml_cgraph * gb = ggml_graph_dup(model.ctx_compute, gf);
ggml_build_backward_expand(model.ctx_compute, gf, gb, false);
ggml_build_opt_adam( model.ctx_compute, gf, gb, 1e-3f, 0.9f, 0.999f, 1e-8f, 0.0f);
struct ggml_cgraph * gb_grad = ggml_graph_dup(model.ctx_compute, gf); // Backward pass, gradients.
ggml_build_backward_expand(model.ctx_compute, gf, gb_grad, /*accumulate =*/ true, false);

struct ggml_cgraph * gb_opt = ggml_graph_dup(model.ctx_compute, gf); // Backward pass, gradients + optimizer.
ggml_build_opt_adam(model.ctx_compute, gf, gb_opt, 1e-3f, 0.9f, 0.999f, 1e-8f, 0.0f);

model.buf_compute = ggml_backend_alloc_ctx_tensors(model.ctx_compute, model.backend);
ggml_graph_reset(gb);
ggml_graph_reset(gb_opt); // Set gradients to zero, reset optimizer.

for (int epoch = 0; epoch < 20; ++epoch) {
for (int epoch = 0; epoch < 30; ++epoch) {
fprintf(stderr, "%s: epoch %d start...", __func__, epoch);
const int64_t t_start_us = ggml_time_us();

float loss;
std::vector<float> logits(model.nbatch*MNIST_NCLASSES);
std::vector<float> logits(model.nbatch_physical*MNIST_NCLASSES);

mnist_eval_result result;
for (int iex0 = 0; iex0 < nex; iex0 += model.nbatch) {
for (int iex0 = 0; iex0 < nex; iex0 += model.nbatch_physical) {
ggml_backend_tensor_set(model.images, images + iex0*MNIST_NINPUT, 0, ggml_nbytes(model.images));
ggml_backend_tensor_set(model.labels, labels + iex0*MNIST_NCLASSES, 0, ggml_nbytes(model.labels));

ggml_backend_graph_compute(model.backend, gf);
ggml_backend_graph_compute(model.backend, gb);
for (int j = 0; j < gb->n_nodes; ++j) {
struct ggml_tensor * node = gb->nodes[j];
ggml_backend_graph_compute(model.backend, gf); // Always compute forward pass.

// With a period of nbatch_logical/nbatch_physical iterations:
if ((iex0 + model.nbatch_physical) % model.nbatch_logical != 0) {
// For the first nbatch_logical/nbatch_physical - 1 iterations, only calculate gradients and accumulate them:
ggml_backend_graph_compute(model.backend, gb_grad);
} else {
// For the last iteration, calculate gradients and also apply the optimizer:
ggml_backend_graph_compute(model.backend, gb_opt);
ggml_graph_reset(gb_grad); // Set gradients to zero, do not reset optimizer.
}
for (int j = 0; j < gb_grad->n_nodes; ++j) {
struct ggml_tensor * node = gb_grad->nodes[j];

if (node->op != GGML_OP_OPT_STEP_ADAM) {
continue;
Expand All @@ -559,7 +571,7 @@ void mnist_model_train(mnist_model & model, const float * images, const float *

result.loss.push_back(loss);

for (int iexb = 0; iexb < model.nbatch; ++iexb) {
for (int iexb = 0; iexb < model.nbatch_physical; ++iexb) {
const float * logits_iexb = logits.data() + iexb*MNIST_NCLASSES;
result.pred.push_back(std::max_element(logits_iexb, logits_iexb + MNIST_NCLASSES) - logits_iexb);
}
Expand Down Expand Up @@ -663,7 +675,7 @@ int wasm_eval(uint8_t * digitPtr) {
std::vector<float> labels(MNIST_NCLASSES);

mnist_model model = mnist_model_init_from_file("mnist-f32.gguf", "CPU");
mnist_model_build(model, 1);
mnist_model_build(model, 1, 1);
mnist_eval_result result = mnist_model_eval(model, digit.data(), labels.data(), 1, 1);

return result.pred[0];
Expand Down
17 changes: 10 additions & 7 deletions examples/mnist/mnist-common.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@
#include "ggml-backend.h"
#include "ggml.h"

#define MNIST_NTRAIN 60000
#define MNIST_NTEST 10000
#define MNIST_NBATCH 500
#define MNIST_NTRAIN 60000
#define MNIST_NTEST 10000
#define MNIST_NBATCH_LOGICAL 1000
#define MNIST_NBATCH_PHYSICAL 500

static_assert(MNIST_NTRAIN % MNIST_NBATCH == 0, "MNIST_NTRAIN % MNIST_BATCH != 0");
static_assert(MNIST_NTEST % MNIST_NBATCH == 0, "MNIST_NTRAIN % MNIST_BATCH != 0");
static_assert(MNIST_NBATCH_LOGICAL % MNIST_NBATCH_PHYSICAL == 0, "MNIST_NBATCH_LOGICAL % MNIST_NBATCH_PHYSICAL != 0");
static_assert(MNIST_NTRAIN % MNIST_NBATCH_LOGICAL == 0, "MNIST_NTRAIN % MNIST_NBATCH_LOGICAL != 0");
static_assert(MNIST_NTEST % MNIST_NBATCH_LOGICAL == 0, "MNIST_NTRAIN % MNIST_NBATCH_LOGICAL != 0");

#define MNIST_HW 28
#define MNIST_NINPUT (MNIST_HW*MNIST_HW)
Expand All @@ -26,7 +28,8 @@ static_assert(MNIST_NTEST % MNIST_NBATCH == 0, "MNIST_NTRAIN % MNIST_BATCH != 0
struct mnist_model {
std::string arch;
ggml_backend_t backend;
int nbatch;
int nbatch_logical;
int nbatch_physical;

struct ggml_tensor * images = nullptr;
struct ggml_tensor * labels = nullptr;
Expand Down Expand Up @@ -118,7 +121,7 @@ mnist_eval_result mnist_graph_eval(const std::string & fname, const float * imag

mnist_model mnist_model_init_from_file(const std::string & fname, const std::string & backend);
mnist_model mnist_model_init_random(const std::string & arch, const std::string & backend);
void mnist_model_build(mnist_model & model, const int nbatch);
void mnist_model_build(mnist_model & model, const int nbatch_logical, const int nbatch_physical);
mnist_eval_result mnist_model_eval(mnist_model & model, const float * images, const float * labels, const int nex, const int nthreads);
void mnist_model_train(mnist_model & model, const float * images, const float * labels, const int nex, const int nthreads);
void mnist_model_save(mnist_model & model, const std::string & fname);
Expand Down
2 changes: 1 addition & 1 deletion examples/mnist/mnist-eval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ int main(int argc, char ** argv) {

mnist_model model = mnist_model_init_from_file(argv[1], argc >= 5 ? argv[4] : "CPU");

mnist_model_build(model, MNIST_NBATCH);
mnist_model_build(model, MNIST_NBATCH_LOGICAL, MNIST_NBATCH_PHYSICAL);

const int64_t t_load_us = ggml_time_us() - t_start_us;

Expand Down
2 changes: 1 addition & 1 deletion examples/mnist/mnist-train.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ int main(int argc, char ** argv) {

mnist_model model = mnist_model_init_random(argv[1], argc >= 6 ? argv[5] : "CPU");

mnist_model_build(model, MNIST_NBATCH);
mnist_model_build(model, MNIST_NBATCH_LOGICAL, MNIST_NBATCH_PHYSICAL);

mnist_model_train(model, images.data(), labels.data(), MNIST_NTRAIN, std::thread::hardware_concurrency());

Expand Down
2 changes: 1 addition & 1 deletion include/ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -2053,7 +2053,7 @@ extern "C" {


GGML_API void ggml_build_forward_expand (struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep);
GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool accumulate, bool keep);

GGML_API void ggml_build_opt_adam(
struct ggml_context * ctx,
Expand Down
18 changes: 11 additions & 7 deletions src/ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -17713,7 +17713,7 @@ void ggml_build_backward_gradient_checkpointing(
struct ggml_tensor * * checkpoints,
int n_checkpoints) {
ggml_graph_cpy(gf, gb_tmp);
ggml_build_backward_expand(ctx, gf, gb_tmp, true);
ggml_build_backward_expand(ctx, gf, gb_tmp, false, true);

if (n_checkpoints <= 0) {
ggml_graph_cpy(gb_tmp, gb);
Expand Down Expand Up @@ -18738,7 +18738,7 @@ void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor *
ggml_build_forward_impl(cgraph, tensor, true);
}

void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep) {
void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool accumulate, bool keep) {
GGML_ASSERT(gf->n_nodes > 0);
GGML_ASSERT(gf->grads);

Expand All @@ -18754,11 +18754,15 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph *
}
}

// remember original gradients which start with zero values
// hash table of original gradients that should be overwritten instead of incremented
struct ggml_hash_set zero_table = ggml_hash_set_new(gf->size);
for (int i = 0; i < gf->n_nodes; i++) {
if (gf->grads[i]) {
ggml_hash_insert(&zero_table, gf->grads[i]);

// when accumulating gradients the table is empty -> gradients always incremented
if (!accumulate) {
for (int i = 0; i < gf->n_nodes; i++) {
if (gf->grads[i]) {
ggml_hash_insert(&zero_table, gf->grads[i]);
}
}
}

Expand Down Expand Up @@ -21175,7 +21179,7 @@ enum ggml_opt_result ggml_opt_resume(
ggml_build_forward_expand(gf, f);

struct ggml_cgraph * gb = ggml_graph_dup(ctx, gf);
ggml_build_backward_expand(ctx, gf, gb, true);
ggml_build_backward_expand(ctx, gf, gb, false, true);

return ggml_opt_resume_g(ctx, opt, f, gf, gb, NULL, NULL);
}
Expand Down
2 changes: 1 addition & 1 deletion tests/test-backend-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -803,7 +803,7 @@ struct test_case {

ggml_build_forward_expand(gf, out);
ggml_graph_cpy(gf, gb);
ggml_build_backward_expand(ctx, gf, gb, false);
ggml_build_backward_expand(ctx, gf, gb, false, false);
if (expect.size() != 1 || expect[0] != 0.0f) {
GGML_ASSERT(gb->n_nodes > gf->n_nodes);
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
Expand Down
2 changes: 1 addition & 1 deletion tests/test-grad0.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ static bool check_gradient(
struct ggml_cgraph * gb = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true);
ggml_build_forward_expand(gf, f);
ggml_graph_cpy(gf, gb);
ggml_build_backward_expand(ctx0, gf, gb, false);
ggml_build_backward_expand(ctx0, gf, gb, false, false);

ggml_graph_compute_with_ctx(ctx0, gf, n_threads);

Expand Down
2 changes: 1 addition & 1 deletion tests/test-mul-mat0.c
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ bool check_gradient(
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true);
ggml_build_forward_expand(gf, f);
struct ggml_cgraph * gb = ggml_graph_dup(ctx0, gf);
ggml_build_backward_expand(ctx0, gf, gb, false);
ggml_build_backward_expand(ctx0, gf, gb, false, false);

ggml_graph_compute_with_ctx(ctx0, gf, n_threads);
ggml_graph_reset (gf);
Expand Down
Loading

0 comments on commit 884431c

Please sign in to comment.