initialize gradients with ggml_graph_reset
JohannesGaessler committed Sep 12, 2024
1 parent aebd118 commit 17abed5
Showing 4 changed files with 42 additions and 47 deletions.
35 changes: 8 additions & 27 deletions examples/mnist/mnist-common.cpp
@@ -378,8 +378,8 @@ void mnist_model_build(mnist_model & model, const int nbatch) {
ggml_set_param(model.ctx_compute, model.fc2_bias);

model.images = ggml_new_tensor_2d(model.ctx_compute, GGML_TYPE_F32, MNIST_NINPUT, model.nbatch);
ggml_set_input(model.images);
ggml_set_name(model.images, "images");
ggml_set_input(model.images);

ggml_tensor * fc1 = ggml_relu(model.ctx_compute, ggml_add(model.ctx_compute,
ggml_mul_mat(model.ctx_compute, model.fc1_weight, model.images),
@@ -396,8 +396,8 @@ void mnist_model_build(mnist_model & model, const int nbatch) {
ggml_set_param(model.ctx_compute, model.dense_bias);

model.images = ggml_new_tensor_4d(model.ctx_compute, GGML_TYPE_F32, 28, 28, 1, model.nbatch);
ggml_set_input(model.images);
ggml_set_name(model.images, "images");
ggml_set_input(model.images);

struct ggml_tensor * conv1_out = ggml_relu(model.ctx_compute, ggml_add(model.ctx_compute,
ggml_conv_2d(model.ctx_compute, model.conv1_kernel, model.images, 1, 1, 1, 1, 1, 1),
@@ -440,30 +440,31 @@ void mnist_model_build(mnist_model & model, const int nbatch) {
GGML_ASSERT(false);
}

ggml_set_output(model.logits);
ggml_set_name(model.logits, "logits");
ggml_set_output(model.logits);
GGML_ASSERT(model.logits->type == GGML_TYPE_F32);
GGML_ASSERT(model.logits->ne[0] == MNIST_NCLASSES);
GGML_ASSERT(model.logits->ne[1] == model.nbatch);
GGML_ASSERT(model.logits->ne[2] == 1);
GGML_ASSERT(model.logits->ne[3] == 1);

model.probs = ggml_soft_max(model.ctx_compute, model.logits);
ggml_set_output(model.probs);
ggml_set_name(model.probs, "probs");
ggml_set_output(model.probs);
GGML_ASSERT(model.probs->type == GGML_TYPE_F32);
GGML_ASSERT(model.probs->ne[0] == MNIST_NCLASSES);
GGML_ASSERT(model.probs->ne[1] == model.nbatch);
GGML_ASSERT(model.probs->ne[2] == 1);
GGML_ASSERT(model.probs->ne[3] == 1);

model.labels = ggml_new_tensor_2d(model.ctx_compute, GGML_TYPE_F32, MNIST_NCLASSES, model.nbatch);
ggml_set_input(model.labels);
ggml_set_name(model.labels, "labels");
ggml_set_input(model.labels);

model.loss = ggml_cross_entropy_loss(model.ctx_compute, model.logits, model.labels);
ggml_set_output(model.loss);
ggml_set_name(model.loss, "loss");
ggml_set_output(model.loss);
ggml_set_loss(model.loss);
GGML_ASSERT(model.loss->type == GGML_TYPE_F32);
GGML_ASSERT(model.loss->ne[0] == 1);
GGML_ASSERT(model.loss->ne[1] == 1);
@@ -526,26 +527,8 @@ void mnist_model_train(mnist_model & model, const float * images, const float *
ggml_build_backward_expand(model.ctx_compute, gf, gb, false);
ggml_build_opt_adam( model.ctx_compute, gf, gb, 1e-3f, 0.9f, 0.999f, 1e-8f, 0.0f);

struct ggml_opt_context opt_ctx;
struct ggml_opt_params opt_pars = ggml_opt_default_params(GGML_OPT_TYPE_ADAM);
opt_pars.print_forward_graph = false;
opt_pars.print_backward_graph = false;
opt_pars.n_threads = std::thread::hardware_concurrency();
opt_pars.adam.n_iter = 1; // per call of ggml_opt_resume_g
ggml_opt_init(model.ctx_compute, &opt_ctx, opt_pars, 0);

model.buf_compute = ggml_backend_alloc_ctx_tensors(model.ctx_compute, model.backend);

for (int j = 0; j < gb->n_nodes; ++j) {
struct ggml_tensor * node = gb->nodes[j];

if (node->op != GGML_OP_OPT_STEP_ADAM) {
continue;
}

ggml_backend_tensor_memset(node->src[2], 0, 0, ggml_nbytes(node->src[2]));
ggml_backend_tensor_memset(node->src[3], 0, 0, ggml_nbytes(node->src[3]));
}
ggml_graph_reset(gb);

for (int epoch = 0; epoch < 20; ++epoch) {
fprintf(stderr, "%s: epoch %d start...", __func__, epoch);
@@ -559,9 +542,7 @@ void mnist_model_train(mnist_model & model, const float * images, const float *
ggml_backend_tensor_set(model.images, images + iex0*MNIST_NINPUT, 0, ggml_nbytes(model.images));
ggml_backend_tensor_set(model.labels, labels + iex0*MNIST_NCLASSES, 0, ggml_nbytes(model.labels));

const float onef = 1.0f;
ggml_backend_graph_compute(model.backend, gf);
ggml_backend_tensor_set(model.loss->grad, &onef, 0, sizeof(float));
ggml_backend_graph_compute(model.backend, gb);
for (int j = 0; j < gb->n_nodes; ++j) {
struct ggml_tensor * node = gb->nodes[j];
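
A minimal sketch (not part of the commit) of the per-batch step this change leaves behind in mnist_model_train: the manual zeroing of the Adam momenta and the manual seeding of model.loss->grad with 1.0f are gone, replaced by the single ggml_graph_reset(gb) call above. The helper name train_one_batch_sketch and its parameter names are made up for illustration; only the ggml calls themselves come from the example.

#include "ggml.h"
#include "ggml-backend.h"

// Sketch only: upload one batch and run forward + backward.
// ggml_graph_reset(gb) is assumed to have been called once after allocation,
// so the loss gradient is already 1.0f and the other gradients and Adam
// momenta are already zero.
static void train_one_batch_sketch(ggml_backend_t backend,
                                   struct ggml_cgraph * gf,       // forward graph
                                   struct ggml_cgraph * gb,       // backward graph (contains the opt step nodes)
                                   struct ggml_tensor * images,   // device tensor flagged with ggml_set_input
                                   struct ggml_tensor * labels,   // device tensor flagged with ggml_set_input
                                   const float        * batch_images,
                                   const float        * batch_labels) {
    ggml_backend_tensor_set(images, batch_images, 0, ggml_nbytes(images));
    ggml_backend_tensor_set(labels, batch_labels, 0, ggml_nbytes(labels));

    ggml_backend_graph_compute(backend, gf); // forward pass: computes the loss
    ggml_backend_graph_compute(backend, gb); // backward pass + GGML_OP_OPT_STEP_ADAM updates
}
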
6 changes: 3 additions & 3 deletions include/ggml.h
@@ -568,6 +568,7 @@ extern "C" {
GGML_TENSOR_FLAG_INPUT = 1,
GGML_TENSOR_FLAG_OUTPUT = 2,
GGML_TENSOR_FLAG_PARAM = 4,
GGML_TENSOR_FLAG_LOSS = 8,
};

// ggml object
@@ -2047,9 +2048,8 @@ extern "C" {
// automatic differentiation
//

GGML_API void ggml_set_param(
struct ggml_context * ctx,
struct ggml_tensor * tensor);
GGML_API void ggml_set_param(struct ggml_context * ctx, struct ggml_tensor * tensor);
GGML_API void ggml_set_loss(struct ggml_tensor * tensor);


GGML_API void ggml_build_forward_expand (struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
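
A minimal sketch, assuming the same calling pattern as the MNIST example above, of how this pair of flags is meant to be used while building a graph: ggml_set_param() marks trainable tensors (and allocates their gradients), and the new ggml_set_loss() marks the scalar objective so that ggml_graph_reset() knows to seed its gradient with 1. The function build_loss_sketch and the tensor shapes are illustrative, not part of the commit.

#include "ggml.h"

// Sketch only: a single linear layer followed by cross-entropy loss.
static struct ggml_tensor * build_loss_sketch(struct ggml_context * ctx,
                                              struct ggml_tensor  * weight,   // [n_input, n_classes], trainable
                                              struct ggml_tensor  * inputs,   // [n_input, n_batch]
                                              struct ggml_tensor  * labels) { // [n_classes, n_batch], one-hot
    ggml_set_param(ctx, weight); // marks the tensor as trainable and allocates weight->grad

    struct ggml_tensor * logits = ggml_mul_mat(ctx, weight, inputs);
    struct ggml_tensor * loss   = ggml_cross_entropy_loss(ctx, logits, labels);

    ggml_set_output(loss);
    ggml_set_loss(loss); // new in this commit: lets ggml_graph_reset() seed d(loss)/d(loss) = 1

    return loss;
}
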
32 changes: 28 additions & 4 deletions src/ggml.c
@@ -1,6 +1,7 @@
#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnings on Windows
#define _USE_MATH_DEFINES // For M_PI on MSVC

#include "ggml-backend.h"
#include "ggml-impl.h"
#include "ggml-quants.h"
#include "ggml.h"
@@ -8143,16 +8144,21 @@ struct ggml_tensor * ggml_opt_step_adam(

////////////////////////////////////////////////////////////////////////////////

void ggml_set_param(
struct ggml_context * ctx,
struct ggml_tensor * tensor) {
void ggml_set_param(struct ggml_context * ctx, struct ggml_tensor * tensor) {
tensor->flags |= GGML_TENSOR_FLAG_PARAM;

GGML_ASSERT(tensor->grad == NULL);
tensor->grad = ggml_dup_tensor(ctx, tensor);
ggml_format_name(tensor->grad, "%s (grad)", tensor->name);
}

void ggml_set_loss(struct ggml_tensor * tensor) {
GGML_ASSERT(ggml_is_scalar(tensor));
GGML_ASSERT(tensor->type == GGML_TYPE_F32);
GGML_ASSERT(tensor->grad);
tensor->flags |= GGML_TENSOR_FLAG_LOSS;
}

// ggml_compute_forward_dup

static void ggml_compute_forward_dup_same_cont(
@@ -18926,10 +18932,28 @@ void ggml_graph_reset(struct ggml_cgraph * cgraph) {
GGML_ASSERT(cgraph->grads != NULL);

for (int i = 0; i < cgraph->n_nodes; i++) {
struct ggml_tensor * node = cgraph->nodes[i];
struct ggml_tensor * grad = cgraph->grads[i];

// initial gradients of loss should be 1, 0 otherwise
if (grad) {
ggml_set_zero(grad);
if (node->flags & GGML_TENSOR_FLAG_LOSS) {
GGML_ASSERT(node->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_scalar(node));

const float onef = 1.0f;
ggml_backend_tensor_set(grad, &onef, 0, ggml_nbytes(grad));
} else {
ggml_backend_tensor_memset(grad, 0, 0, ggml_nbytes(grad));
}
}

GGML_ASSERT(node);
if (node->op == GGML_OP_OPT_STEP_ADAM) {
// set iteration to 1 and clear momenta
ggml_set_op_params_i32(node, 0, 1);
ggml_backend_tensor_memset(node->src[2], 0, 0, ggml_nbytes(node->src[2]));
ggml_backend_tensor_memset(node->src[3], 0, 0, ggml_nbytes(node->src[3]));
}
}
}
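
The reworked ggml_graph_reset() therefore establishes a simple post-condition: the gradient of every tensor flagged GGML_TENSOR_FLAG_LOSS holds 1.0f, every other gradient holds zeros, and each GGML_OP_OPT_STEP_ADAM node has its iteration counter set back to 1 and its momenta cleared. Below is a small hypothetical check (not part of the commit) that makes the gradient part of this contract concrete; it assumes gb was built with gradients and that its tensors are allocated on a backend.

#include <assert.h>
#include "ggml.h"
#include "ggml-backend.h"

// Sketch only: read back the first element of two gradients after a reset.
static void check_graph_reset_sketch(struct ggml_cgraph * gb,
                                     struct ggml_tensor * loss,     // flagged with ggml_set_loss
                                     struct ggml_tensor * weight) { // flagged with ggml_set_param
    ggml_graph_reset(gb);

    float g_loss   =  0.0f;
    float g_weight = -1.0f;
    ggml_backend_tensor_get(loss->grad,   &g_loss,   0, sizeof(float));
    ggml_backend_tensor_get(weight->grad, &g_weight, 0, sizeof(float));

    assert(g_loss   == 1.0f); // the loss gradient is seeded with 1
    assert(g_weight == 0.0f); // all other gradients are zero-initialized
}
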
16 changes: 3 additions & 13 deletions tests/test-backend-ops.cpp
@@ -799,6 +799,7 @@ struct test_case {
out = ggml_sum(ctx, out);
ggml_set_name(out, "sum_of_out");
}
ggml_set_loss(out);

ggml_build_forward_expand(gf, out);
ggml_graph_cpy(gf, gb);
@@ -837,22 +838,11 @@ struct test_case {
return false;
}

// randomize tensors
initialize_tensors(ctx);

for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
if (!t->grad) {
continue;
}

std::vector<float> tmp(ggml_nelements(t->grad));
ggml_backend_tensor_set(t->grad, tmp.data(), 0, ggml_nbytes(t->grad));
}
initialize_tensors(ctx); // Randomizes all tensors (including gradients).
ggml_graph_reset(gb); // Sets gradients to 1 if loss, 0 otherwise.

// build graphs
const float onef = 1.0f;
ggml_backend_graph_compute(backend, gf);
ggml_backend_tensor_set(out->grad, &onef, 0, ggml_nbytes(out->grad));
ggml_backend_graph_compute(backend, gb);

bool ok = true;
