Commit 478472b: fix gradient accumulation

JohannesGaessler committed Sep 17, 2024
1 parent c1d13df
Showing 3 changed files with 70 additions and 26 deletions.
9 changes: 5 additions & 4 deletions examples/mnist/mnist-common.cpp
@@ -530,13 +530,16 @@ mnist_eval_result mnist_model_eval(mnist_model & model, const float * images, co
 void mnist_model_train(mnist_model & model, const float * images, const float * labels, const int nex, const int nepoch, const float val_split) {
     const int64_t t_start_us = ggml_time_us();

+    // gf == graph forward, forward pass only.
     struct ggml_cgraph * gf = ggml_new_graph_custom(model.ctx_compute, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true); // Forward pass.
     ggml_build_forward_expand(gf, model.loss);

-    struct ggml_cgraph * gb_grad = ggml_graph_dup(model.ctx_compute, gf); // Backward pass, gradients.
+    // gb_grad == graph backward gradients, forward pass, then backward pass to calculate gradients.
+    struct ggml_cgraph * gb_grad = ggml_graph_dup(model.ctx_compute, gf);
     ggml_build_backward_expand(model.ctx_compute, gf, gb_grad, /*accumulate =*/ true, false);

-    struct ggml_cgraph * gb_opt = ggml_graph_dup(model.ctx_compute, gf); // Backward pass, gradients + optimizer.
+    // gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step.
+    struct ggml_cgraph * gb_opt = ggml_graph_dup(model.ctx_compute, gb_grad);
     ggml_build_opt_adamw(model.ctx_compute, gf, gb_opt, 1e-3f, 0.9f, 0.999f, 1e-8f, 0.0f);

     model.buf_compute = ggml_backend_alloc_ctx_tensors(model.ctx_compute, model.backend);
@@ -557,8 +560,6 @@ void mnist_model_train(mnist_model & model, const float * images, const float *
             ggml_backend_tensor_set(model.images, images + iex0*MNIST_NINPUT, 0, ggml_nbytes(model.images));
             ggml_backend_tensor_set(model.labels, labels + iex0*MNIST_NCLASSES, 0, ggml_nbytes(model.labels));

-            ggml_backend_graph_compute(model.backend, gf); // Always compute forward pass.
-
             // With a period of nbatch_logical/nbatch_physical iterations:
             if ((iex0 + model.nbatch_physical) % model.nbatch_logical != 0) {
                 // For the first nbatch_logical/nbatch_physical - 1 iterations, only calculate gradients and accumulate them:
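Note the second hunk: the unconditional forward-pass computation is gone because gb_opt is now duplicated from gb_grad rather than from gf, so both backward graphs already contain the forward pass. What remains is a fixed schedule: for all but the last physical batch of each logical batch, compute gb_grad and let the gradients accumulate; on the last one, compute gb_opt (which also applies the AdamW step) and reset the gradients. Below is a minimal standalone sketch of that schedule, with plain arrays in place of ggml tensors and a made-up SGD-style update in place of AdamW; the names mirror the diff, but nothing in it is ggml API.

    // Sketch of the accumulation schedule, assuming nbatch_logical is a
    // multiple of nbatch_physical. All stand-ins are hypothetical.
    #include <cstdio>
    #include <vector>

    int main() {
        const int nex             = 60000; // number of training examples
        const int nbatch_logical  = 1000;  // batch size the optimizer sees
        const int nbatch_physical = 500;   // batch size per graph computation
        const int nparams         = 4;

        std::vector<float> param(nparams, 0.0f);
        std::vector<float> grad (nparams, 0.0f); // accumulator, like a GRAD_ACC tensor

        for (int iex0 = 0; iex0 < nex; iex0 += nbatch_physical) {
            // stand-in for computing gb_grad/gb_opt: gradients are added, not overwritten
            for (int i = 0; i < nparams; i++) {
                grad[i] += 1.0f; // dummy per-batch gradient
            }

            if ((iex0 + nbatch_physical) % nbatch_logical != 0) {
                continue; // accumulate only (gb_grad in the real code)
            }

            // last physical batch of the logical batch: optimizer step + reset
            for (int i = 0; i < nparams; i++) {
                param[i] -= 1e-3f * grad[i] / (nbatch_logical / nbatch_physical);
                grad[i]   = 0.0f; // like ggml_graph_reset
            }
        }

        printf("param[0] after training: %f\n", param[0]);
        return 0;
    }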
10 changes: 6 additions & 4 deletions include/ggml.h
@@ -570,11 +570,13 @@ extern "C" {
         GGML_LOG_LEVEL_DEBUG = 5
     };

+    // this tensor...
     enum ggml_tensor_flag {
-        GGML_TENSOR_FLAG_INPUT  = 1,
-        GGML_TENSOR_FLAG_OUTPUT = 2,
-        GGML_TENSOR_FLAG_PARAM  = 4,
-        GGML_TENSOR_FLAG_LOSS   = 8,
+        GGML_TENSOR_FLAG_INPUT    =  1, // ...is an input for the GGML compute graph
+        GGML_TENSOR_FLAG_OUTPUT   =  2, // ...is an output for the GGML compute graph
+        GGML_TENSOR_FLAG_PARAM    =  4, // ...contains trainable parameters
+        GGML_TENSOR_FLAG_GRAD_ACC =  8, // ...is an accumulator for gradients
+        GGML_TENSOR_FLAG_LOSS     = 16, // ...defines loss for numerical optimization (multiple loss tensors add up)
     };

     // ggml object
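GGML_TENSOR_FLAG_LOSS moves from 8 to 16 to make room for the new GGML_TENSOR_FLAG_GRAD_ACC bit; the values are powers of two, so several flags can be combined on one tensor and tested independently. A small sketch of setting and testing such bit flags (the enum is re-declared locally here for illustration only):

    #include <cstdio>

    // local copy of the flag values, for illustration only; the real enum lives in ggml.h
    enum ggml_tensor_flag {
        GGML_TENSOR_FLAG_INPUT    =  1,
        GGML_TENSOR_FLAG_OUTPUT   =  2,
        GGML_TENSOR_FLAG_PARAM    =  4,
        GGML_TENSOR_FLAG_GRAD_ACC =  8,
        GGML_TENSOR_FLAG_LOSS     = 16,
    };

    int main() {
        int node_flags = 0;
        node_flags |= GGML_TENSOR_FLAG_PARAM;    // tensor holds trainable parameters

        int grad_flags = 0;
        grad_flags |= GGML_TENSOR_FLAG_GRAD_ACC; // its gradient accumulates across batches

        if (grad_flags & GGML_TENSOR_FLAG_GRAD_ACC) {
            printf("gradient is accumulated in place instead of being replaced\n");
        }
        return 0;
    }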
77 changes: 59 additions & 18 deletions src/ggml.c
@@ -18123,36 +18123,75 @@ void ggml_build_backward_gradient_checkpointing(
     ggml_hash_map_free(replacements);
 }

-// functions to change gradients considering the case that input a might be initial gradient with zero value
+// utility functions to change gradients
+// by default, just add/subtract/etc. the gradients
+// if a is in zero_table and not a gradient accumulator, replace a
+// if a is in zero_table and a gradient accumulator, modify gradients in-place and mark result as gradient accumulator

 static struct ggml_tensor * ggml_add_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_hash_set * zero_table) {
     if (ggml_hash_contains(zero_table, a)) {
-        return b;
+        if (a->flags & GGML_TENSOR_FLAG_GRAD_ACC) {
+            struct ggml_tensor * ret = ggml_add_impl(ctx, a, b, true);
+            ret->flags |= GGML_TENSOR_FLAG_GRAD_ACC;
+            const size_t insert_result = ggml_hash_insert(zero_table, ret);
+            GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
+            GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
+            return ret;
+        } else {
+            return b;
+        }
     } else {
         return ggml_add_impl(ctx, a, b, false);
     }
 }

 static struct ggml_tensor * ggml_acc_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, size_t nb1, size_t nb2, size_t nb3, size_t offset, struct ggml_hash_set * zero_table) {
     if (ggml_hash_contains(zero_table, a)) {
-        struct ggml_tensor * a_zero = ggml_scale(ctx, a, 0.0f);
-        return ggml_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false);
+        if (a->flags & GGML_TENSOR_FLAG_GRAD_ACC) {
+            struct ggml_tensor * ret = ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, true);
+            ret->flags |= GGML_TENSOR_FLAG_GRAD_ACC;
+            const size_t insert_result = ggml_hash_insert(zero_table, ret);
+            GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
+            GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
+            return ret;
+        } else {
+            struct ggml_tensor * a_zero = ggml_scale(ctx, a, 0.0f); // FIXME this is going to produce NaN if a contains inf/NaN
+            return ggml_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false);
+        }
     } else {
         return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
     }
 }
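The FIXME in ggml_acc_or_set is a consequence of IEEE-754 arithmetic: scaling by 0.0f does not produce a guaranteed-zero tensor, because 0 * inf and 0 * NaN are both NaN. A two-line demonstration:

    #include <cmath>
    #include <cstdio>

    int main() {
        printf("0 * inf = %f\n", 0.0f * INFINITY); // prints nan, not 0
        printf("0 * nan = %f\n", 0.0f * nanf("")); // nan as well
        return 0;
    }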

 static struct ggml_tensor * ggml_add1_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_hash_set * zero_table) {
     if (ggml_hash_contains(zero_table, a)) {
-        return ggml_repeat(ctx, b, a);
+        if (a->flags & GGML_TENSOR_FLAG_GRAD_ACC) {
+            struct ggml_tensor * ret = ggml_add1_impl(ctx, a, b, true);
+            ret->flags |= GGML_TENSOR_FLAG_GRAD_ACC;
+            const size_t insert_result = ggml_hash_insert(zero_table, ret);
+            GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
+            GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
+            return ret;
+        } else {
+            return ggml_repeat(ctx, b, a);
+        }
     } else {
         return ggml_add1_impl(ctx, a, b, false);
     }
 }

 static struct ggml_tensor * ggml_sub_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_hash_set * zero_table) {
     if (ggml_hash_contains(zero_table, a)) {
-        return ggml_neg(ctx, b);
+        if (a->flags & GGML_TENSOR_FLAG_GRAD_ACC) {
+            struct ggml_tensor * ret = ggml_sub_impl(ctx, a, b, true);
+            ret->flags |= GGML_TENSOR_FLAG_GRAD_ACC;
+            const size_t insert_result = ggml_hash_insert(zero_table, ret);
+            GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
+            GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
+            return ret;
+        } else {
+            return ggml_neg(ctx, b);
+        }
     } else {
         return ggml_sub_impl(ctx, a, b, false);
     }
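All four helpers share the three-way dispatch described by the comment at the top of the hunk: plain add when the gradient already holds a valid value, in-place accumulate when it is a marked accumulator, replace otherwise. A condensed standalone model of that control flow, where a float plus flags stands in for a ggml tensor and std::unordered_set for the ggml hash set (none of this is ggml API):

    #include <cstdio>
    #include <unordered_set>

    // Condensed model of the dispatch shared by ggml_add_or_set,
    // ggml_acc_or_set, ggml_add1_or_set and ggml_sub_or_set.
    // zero_table marks gradients whose current contents are not yet meaningful.
    enum { FLAG_GRAD_ACC = 8 };
    struct tensor { float value; int flags; };

    static tensor * add_or_set(tensor * a, tensor * b, std::unordered_set<tensor *> & zero_table) {
        if (zero_table.count(a) == 0) {
            return new tensor{a->value + b->value, 0}; // normal case: plain add
        }
        if (a->flags & FLAG_GRAD_ACC) {
            // accumulator: the real code builds an in-place add node, marks it
            // GGML_TENSOR_FLAG_GRAD_ACC and inserts it into zero_table
            // (allocations are intentionally not freed in this tiny sketch)
            tensor * ret = new tensor{a->value + b->value, FLAG_GRAD_ACC};
            zero_table.insert(ret);
            return ret;
        }
        return b; // "zero" gradient that is not an accumulator: replace it
    }

    int main() {
        std::unordered_set<tensor *> zero_table;
        tensor acc{0.0f, FLAG_GRAD_ACC}; // gradient accumulator, starts zeroed
        tensor g {1.5f, 0};              // freshly computed gradient
        zero_table.insert(&acc);

        tensor * sum = add_or_set(&acc, &g, zero_table);
        sum = add_or_set(sum, &g, zero_table);              // second physical batch
        printf("accumulated gradient: %.1f\n", sum->value); // 3.0
        return 0;
    }

Re-inserting the in-place result into zero_table is what lets the next accumulation step take the same branch again instead of replacing the gradient.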
@@ -19136,22 +19175,25 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph *
         }
     }

-    // hash table of original gradients that should be overwritten instead of incremented
+    // keep table of original gradients for replacement/accumulation logic
     struct ggml_hash_set zero_table = ggml_hash_set_new(gf->size);
+    for (int i = 0; i < gf->n_nodes; i++) {
+        struct ggml_tensor * node = gf->nodes[i];

-    // when accumulating gradients the table is empty -> gradients always incremented
-    if (!accumulate) {
-        for (int i = 0; i < gf->n_nodes; i++) {
-            if (gf->grads[i]) {
-                ggml_hash_insert(&zero_table, gf->grads[i]);
+        if (node->grad) {
+            // only gradients of trainable parameters should be accumulated
+            if (accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) {
+                node->grad->flags |= GGML_TENSOR_FLAG_GRAD_ACC;
             }
+
+            ggml_hash_insert(&zero_table, node->grad);
         }
     }

     for (int i = gf->n_nodes - 1; i >= 0; i--) {
         struct ggml_tensor * node = gf->nodes[i];

-        // inplace operations to add gradients are not created by ggml_compute_backward
+        // inplace operations to add gradients are not created by ggml_compute_backward except for gradient accumulation
         // use allocator to automatically make inplace operations
         if (node->grad) {
             ggml_compute_backward(ctx, node, &zero_table);
@@ -19319,19 +19361,18 @@ void ggml_graph_reset(struct ggml_cgraph * cgraph) {

     for (int i = 0; i < cgraph->n_nodes; i++) {
         struct ggml_tensor * node = cgraph->nodes[i];
-        struct ggml_tensor * grad = cgraph->grads[i];

         // initial gradients of loss should be 1, 0 otherwise
-        if (grad) {
+        if (node->grad) {
             if (node->flags & GGML_TENSOR_FLAG_LOSS) {
-                GGML_ASSERT(grad->buffer);
+                GGML_ASSERT(node->grad->buffer);
                 GGML_ASSERT(node->type == GGML_TYPE_F32);
                 GGML_ASSERT(ggml_is_scalar(node));

                 const float onef = 1.0f;
-                ggml_backend_tensor_set(grad, &onef, 0, ggml_nbytes(grad));
+                ggml_backend_tensor_set(node->grad, &onef, 0, ggml_nbytes(node->grad));
             } else {
-                ggml_set_zero(grad);
+                ggml_set_zero(node->grad);
             }
         }

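The reset semantics follow the chain rule: backpropagation is seeded with d(loss)/d(loss) = 1, and every other gradient, including the new accumulators, is cleared to 0 before the next logical batch. A tiny numeric illustration of why the seed must be 1 (hypothetical two-node chain, not ggml code):

    #include <cstdio>

    // chain rule on loss = f(y), y = g(x): dloss/dx = dloss/dy * dy/dx.
    // Backprop seeds the chain with dloss/dloss = 1; a seed of 0 would
    // zero out every downstream gradient.
    int main() {
        const float dy_dx    = 3.0f; // local derivative of y = g(x)
        const float dloss_dy = 2.0f; // local derivative of loss = f(y)

        const float seed     = 1.0f; // dloss/dloss, what ggml_graph_reset writes
        const float dloss_dx = seed * dloss_dy * dy_dx;
        printf("dloss/dx = %f\n", dloss_dx); // 6.0
        return 0;
    }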
