diff --git a/examples/mnist/mnist-common.cpp b/examples/mnist/mnist-common.cpp
index 0a78f711f..ffc5e2d00 100644
--- a/examples/mnist/mnist-common.cpp
+++ b/examples/mnist/mnist-common.cpp
@@ -514,7 +514,7 @@ void mnist_model_train(mnist_model & model, const float * images, const float *
     opt_pars.print_backward_graph = false;
     opt_pars.n_threads            = std::thread::hardware_concurrency();
     opt_pars.adam.n_iter          = 1; // per call of ggml_opt_resume_g
-    ggml_opt_init(model.ctx_compute, &opt_ctx, opt_pars, 397510);
+    ggml_opt_init(model.ctx_compute, &opt_ctx, opt_pars, 0);
 
     model.buf_compute = ggml_backend_alloc_ctx_tensors(model.ctx_compute, model.backend);
 
@@ -530,8 +530,10 @@ void mnist_model_train(mnist_model & model, const float * images, const float *
             ggml_backend_tensor_set(model.images, images + iex0*MNIST_NINPUT,   0, ggml_nbytes(model.images));
             ggml_backend_tensor_set(model.labels, labels + iex0*MNIST_NCLASSES, 0, ggml_nbytes(model.labels));
 
-            enum ggml_opt_result opt_result = ggml_opt_resume_g(model.ctx_compute, &opt_ctx, model.loss, gf, gb, NULL, NULL);
-            GGML_ASSERT(opt_result == GGML_OPT_RESULT_OK || opt_result == GGML_OPT_RESULT_DID_NOT_CONVERGE);
+            const float onef = 1.0f;
+            ggml_backend_graph_compute(model.backend, gf);
+            ggml_backend_tensor_set(model.loss->grad, &onef, 0, sizeof(float));
+            ggml_backend_graph_compute(model.backend, gb);
 
             ggml_backend_tensor_get(model.loss,   &loss,         0, ggml_nbytes(model.loss));
             ggml_backend_tensor_get(model.logits, logits.data(), 0, ggml_nbytes(model.logits));
diff --git a/examples/mnist/mnist-common.h b/examples/mnist/mnist-common.h
index bb3c013e5..352921447 100644
--- a/examples/mnist/mnist-common.h
+++ b/examples/mnist/mnist-common.h
@@ -59,7 +59,7 @@ struct mnist_model {
     mnist_model() {
         // backend = ggml_backend_cuda_init(0);
         backend = ggml_backend_cpu_init();
-        ggml_backend_cpu_set_n_threads(backend, std::thread::hardware_concurrency());
+        ggml_backend_cpu_set_n_threads(backend, std::thread::hardware_concurrency()/2);
 
         buf_weight = malloc(size_weight);
         {
diff --git a/include/ggml.h b/include/ggml.h
index 59fa80edb..2961e7650 100644
--- a/include/ggml.h
+++ b/include/ggml.h
@@ -528,6 +528,7 @@ extern "C" {
         GGML_OP_CROSS_ENTROPY_LOSS,
         GGML_OP_CROSS_ENTROPY_LOSS_BACK,
 
+        GGML_OP_OPT_STEP_ADAM,
         GGML_OP_COUNT,
     };
 
@@ -2033,6 +2034,11 @@ extern "C" {
             struct ggml_tensor          * b,
             struct ggml_tensor          * c);
 
+    GGML_API struct ggml_tensor * ggml_opt_step_adam(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            float                 alpha);
+
     //
     // automatic differentiation
     //
diff --git a/src/ggml.c b/src/ggml.c
index 05519f167..d5751037e 100644
--- a/src/ggml.c
+++ b/src/ggml.c
@@ -2850,9 +2850,10 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
 
     "CROSS_ENTROPY_LOSS",
     "CROSS_ENTROPY_LOSS_BACK",
+    "OPT_STEP_ADAM",
 };
 
-static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78");
+static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -2942,9 +2943,10 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
 
     "cross_entropy_loss(x,y)",
     "cross_entropy_loss_back(x,y)",
+    "adam(x)",
 };
 
-static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78");
+static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -8104,6 +8106,26 @@ struct ggml_tensor * ggml_cross_entropy_loss_back(
     return result;
 }
 
+// opt_step_adam
+
+struct ggml_tensor * ggml_opt_step_adam(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        float                 alpha) {
+    GGML_ASSERT(a->grad);
+
+    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+    result->op     = GGML_OP_OPT_STEP_ADAM;
+    result->grad   = NULL;
+    result->src[0] = a;
+    result->src[1] = a->grad;
+
+    ggml_set_op_params(result, &alpha, sizeof(alpha));
+
+    return result;
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 
 void ggml_set_param(
@@ -17093,6 +17115,62 @@ static void ggml_compute_forward_cross_entropy_loss_back(
     }
 }
 
+static void ggml_compute_forward_opt_step_adam_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0      = dst->src[0];
+    const struct ggml_tensor * src0_grad = dst->src[1];
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(src0);
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    const float alpha = ggml_get_op_params_f32(dst, 0);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 = ir/(ne02*ne01);
+        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+        const size_t offset = i03*nb03 + i02*nb02 + i01*nb01;
+
+        float       * weight_ptr = (float       *) ((char       *) src0->data      + offset);
+        const float * grad_ptr   = (const float *) ((const char *) src0_grad->data + offset);
+
+        ggml_vec_mad_f32(ne00, weight_ptr, grad_ptr, -alpha);
+    }
+}
+
+static void ggml_compute_forward_opt_step_adam(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_opt_step_adam_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 /////////////////////////////////
 
 static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
@@ -17434,6 +17512,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
                 ggml_compute_forward_cross_entropy_loss_back(params, tensor);
             }
             break;
+        case GGML_OP_OPT_STEP_ADAM:
+            {
+                ggml_compute_forward_opt_step_adam(params, tensor);
+            }
+            break;
         case GGML_OP_NONE:
             {
                 // nop
@@ -18520,6 +18603,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ABORT("fatal error"); // not supported
             }
+        case GGML_OP_OPT_STEP_ADAM:
+            {
+                GGML_ABORT("fatal error"); // not supported
+            }
         case GGML_OP_NONE:
             {
                 // nop
@@ -18652,6 +18739,16 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph *
         }
     }
 
+    for (int i = 0; i < gf->n_nodes; i++) {
+        struct ggml_tensor * node = gf->nodes[i];
+
+        if (node->flags & GGML_TENSOR_FLAG_PARAM) {
+            GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
+            struct ggml_tensor * opt_step = ggml_opt_step_adam(ctx, node, 0.001f);
+            ggml_build_forward_expand(gb, opt_step);
+        }
+    }
+
     ggml_hash_set_free(&zero_table);
 }
 
@@ -19107,6 +19204,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
             } break;
        case GGML_OP_CROSS_ENTROPY_LOSS:
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+        case GGML_OP_OPT_STEP_ADAM:
            {
                 n_tasks = n_threads;
             } break;
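
Usage note: the mnist-common.cpp hunks above drop ggml_opt_resume_g in favor of explicit forward/backward graph evaluation, relying on the GGML_OP_OPT_STEP_ADAM nodes that ggml_build_backward_expand now appends to gb. The following is a minimal sketch of that training-step pattern, assuming gf/gb were built with ggml_build_forward_expand/ggml_build_backward_expand as in the patch; the helper name train_step is hypothetical and setup/error handling are omitted.

#include "ggml.h"
#include "ggml-backend.h"

static void train_step(ggml_backend_t backend, struct ggml_tensor * loss,
                       struct ggml_cgraph * gf, struct ggml_cgraph * gb) {
    // forward pass: compute the loss for the current batch
    ggml_backend_graph_compute(backend, gf);

    // seed backpropagation with d(loss)/d(loss) = 1
    const float onef = 1.0f;
    ggml_backend_tensor_set(loss->grad, &onef, 0, sizeof(float));

    // backward pass: besides the gradients, gb also contains one
    // GGML_OP_OPT_STEP_ADAM node per parameter tensor (appended by
    // ggml_build_backward_expand), so this single call computes the
    // gradients and updates the weights in place with the fixed step
    // size of 0.001f hard-coded in the patch.
    ggml_backend_graph_compute(backend, gb);
}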