Skip to content

Commit

Permalink
stochastic gradient descent op
Browse files Browse the repository at this point in the history
  • Loading branch information
JohannesGaessler committed Sep 5, 2024
1 parent d9316cc commit 0fc3efe
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 6 deletions.
8 changes: 5 additions & 3 deletions examples/mnist/mnist-common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ void mnist_model_train(mnist_model & model, const float * images, const float *
opt_pars.print_backward_graph = false;
opt_pars.n_threads = std::thread::hardware_concurrency();
opt_pars.adam.n_iter = 1; // per call of ggml_opt_resume_g
ggml_opt_init(model.ctx_compute, &opt_ctx, opt_pars, 397510);
ggml_opt_init(model.ctx_compute, &opt_ctx, opt_pars, 0);

model.buf_compute = ggml_backend_alloc_ctx_tensors(model.ctx_compute, model.backend);

Expand All @@ -530,8 +530,10 @@ void mnist_model_train(mnist_model & model, const float * images, const float *
ggml_backend_tensor_set(model.images, images + iex0*MNIST_NINPUT, 0, ggml_nbytes(model.images));
ggml_backend_tensor_set(model.labels, labels + iex0*MNIST_NCLASSES, 0, ggml_nbytes(model.labels));

enum ggml_opt_result opt_result = ggml_opt_resume_g(model.ctx_compute, &opt_ctx, model.loss, gf, gb, NULL, NULL);
GGML_ASSERT(opt_result == GGML_OPT_RESULT_OK || opt_result == GGML_OPT_RESULT_DID_NOT_CONVERGE);
const float onef = 1.0f;
ggml_backend_graph_compute(model.backend, gf);
ggml_backend_tensor_set(model.loss->grad, &onef, 0, sizeof(float));
ggml_backend_graph_compute(model.backend, gb);

ggml_backend_tensor_get(model.loss, &loss, 0, ggml_nbytes(model.loss));
ggml_backend_tensor_get(model.logits, logits.data(), 0, ggml_nbytes(model.logits));
Expand Down
2 changes: 1 addition & 1 deletion examples/mnist/mnist-common.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ struct mnist_model {
mnist_model() {
// backend = ggml_backend_cuda_init(0);
backend = ggml_backend_cpu_init();
ggml_backend_cpu_set_n_threads(backend, std::thread::hardware_concurrency());
ggml_backend_cpu_set_n_threads(backend, std::thread::hardware_concurrency()/2);

buf_weight = malloc(size_weight);
{
Expand Down
6 changes: 6 additions & 0 deletions include/ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,7 @@ extern "C" {

GGML_OP_CROSS_ENTROPY_LOSS,
GGML_OP_CROSS_ENTROPY_LOSS_BACK,
GGML_OP_OPT_STEP_ADAM,

GGML_OP_COUNT,
};
Expand Down Expand Up @@ -2033,6 +2034,11 @@ extern "C" {
struct ggml_tensor * b,
struct ggml_tensor * c);

GGML_API struct ggml_tensor * ggml_opt_step_adam(
struct ggml_context * ctx,
struct ggml_tensor * a,
float alpha);

//
// automatic differentiation
//
Expand Down
102 changes: 100 additions & 2 deletions src/ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -2850,9 +2850,10 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {

"CROSS_ENTROPY_LOSS",
"CROSS_ENTROPY_LOSS_BACK",
"OPT_STEP_ADAM",
};

static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78");
static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");

static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
Expand Down Expand Up @@ -2942,9 +2943,10 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {

"cross_entropy_loss(x,y)",
"cross_entropy_loss_back(x,y)",
"adam(x)",
};

static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78");
static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");

static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");

Expand Down Expand Up @@ -8104,6 +8106,26 @@ struct ggml_tensor * ggml_cross_entropy_loss_back(
return result;
}

// opt_step_adam

// Build a graph node that applies one optimizer step to parameter tensor a:
//   a -= alpha * a->grad
// The result is a view of a, so the update is performed in place when the
// node is executed.
// NOTE(review): despite the "adam" name, the forward kernel currently
// implements plain SGD (no first/second moments) — confirm intended.
struct ggml_tensor * ggml_opt_step_adam(
        struct ggml_context * ctx,
        struct ggml_tensor * a,
        float alpha) {
    // only meaningful for tensors that actually have a gradient
    GGML_ASSERT(a->grad);

    struct ggml_tensor * view = ggml_view_tensor(ctx, a);

    ggml_set_op_params(view, &alpha, sizeof(alpha));

    view->op     = GGML_OP_OPT_STEP_ADAM;
    view->grad   = NULL; // the optimizer step itself is not differentiated
    view->src[0] = a;
    view->src[1] = a->grad;

    return view;
}

////////////////////////////////////////////////////////////////////////////////

void ggml_set_param(
Expand Down Expand Up @@ -17093,6 +17115,62 @@ static void ggml_compute_forward_cross_entropy_loss_back(
}
}

// One optimizer step over an F32 parameter tensor, parallelized by rows:
//   weight -= alpha * grad
// src[0] is the parameter tensor (updated in place), src[1] its gradient.
// NOTE(review): despite the "adam" name this kernel is currently plain SGD
// (no moment estimates) — confirm intended.
static void ggml_compute_forward_opt_step_adam_f32(
        const struct ggml_compute_params * params,
        struct ggml_tensor * dst) {

    const struct ggml_tensor * src0      = dst->src[0]; // parameters, modified in place
    const struct ggml_tensor * src0_grad = dst->src[1]; // gradient of src0
    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));

    const int ith = params->ith;
    const int nth = params->nth;

    const int nr = ggml_nrows(src0);

    GGML_TENSOR_UNARY_OP_LOCALS
    GGML_ASSERT(nb00 == sizeof(float));
    // same shape does not imply same strides: the gradient must be indexed
    // with its own nb, and its rows must also be densely packed floats
    GGML_ASSERT(src0_grad->nb[0] == sizeof(float));

    // rows per thread
    const int dr = (nr + nth - 1)/nth;

    // row range for this thread
    const int ir0 = dr*ith;
    const int ir1 = MIN(ir0 + dr, nr);

    const float alpha = ggml_get_op_params_f32(dst, 0); // learning rate

    for (int ir = ir0; ir < ir1; ++ir) {
        // decompose the flat row index into (i03, i02, i01)
        const int64_t i03 = ir/(ne02*ne01);
        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);

        const size_t offset_w = i03*nb03             + i02*nb02             + i01*nb01;
        const size_t offset_g = i03*src0_grad->nb[3] + i02*src0_grad->nb[2] + i01*src0_grad->nb[1];

        float       * weight_ptr = (float *)       ((char *)       src0->data      + offset_w);
        const float * grad_ptr   = (const float *) ((const char *) src0_grad->data + offset_g);

        // weight -= alpha * grad, fused over one row of ne00 elements
        ggml_vec_mad_f32(ne00, weight_ptr, grad_ptr, -alpha);
    }
}

static void ggml_compute_forward_opt_step_adam(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {

const struct ggml_tensor * src0 = dst->src[0];

switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_opt_step_adam_f32(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}
/////////////////////////////////

static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
Expand Down Expand Up @@ -17434,6 +17512,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
ggml_compute_forward_cross_entropy_loss_back(params, tensor);
}
break;
case GGML_OP_OPT_STEP_ADAM:
{
ggml_compute_forward_opt_step_adam(params, tensor);
}
break;
case GGML_OP_NONE:
{
// nop
Expand Down Expand Up @@ -18520,6 +18603,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
{
GGML_ABORT("fatal error"); // not supported
}
case GGML_OP_OPT_STEP_ADAM:
{
GGML_ABORT("fatal error"); // not supported
}
case GGML_OP_NONE:
{
// nop
Expand Down Expand Up @@ -18652,6 +18739,16 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph *
}
}

for (int i = 0; i < gf->n_nodes; i++) {
struct ggml_tensor * node = gf->nodes[i];

if (node->flags & GGML_TENSOR_FLAG_PARAM) {
GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
struct ggml_tensor * opt_step = ggml_opt_step_adam(ctx, node, 0.001f);
ggml_build_forward_expand(gb, opt_step);
}
}

ggml_hash_set_free(&zero_table);
}

Expand Down Expand Up @@ -19107,6 +19204,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
} break;
case GGML_OP_CROSS_ENTROPY_LOSS:
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
case GGML_OP_OPT_STEP_ADAM:
{
n_tasks = n_threads;
} break;
Expand Down

0 comments on commit 0fc3efe

Please sign in to comment.