diff --git a/examples/mnist/mnist-common.cpp b/examples/mnist/mnist-common.cpp
index 1bc5b4c81..ea90dd97c 100644
--- a/examples/mnist/mnist-common.cpp
+++ b/examples/mnist/mnist-common.cpp
@@ -556,15 +556,6 @@ void mnist_model_train(mnist_model & model, const float * images, const float *
             ggml_backend_graph_compute(model.backend, gb_opt);
             ggml_graph_reset(gb_grad); // Set gradients to zero, do not reset optimizer.
         }
-        for (int j = 0; j < gb_grad->n_nodes; ++j) {
-            struct ggml_tensor * node = gb_grad->nodes[j];
-
-            if (node->op != GGML_OP_OPT_STEP_ADAM) {
-                continue;
-            }
-
-            node->op_params[0]++;
-        }
 
         ggml_backend_tensor_get(model.loss, &loss, 0, ggml_nbytes(model.loss));
         ggml_backend_tensor_get(model.logits, logits.data(), 0, ggml_nbytes(model.logits));
diff --git a/src/ggml-cuda/opt-step-adam.cu b/src/ggml-cuda/opt-step-adam.cu
index 580f08167..4f2689bf6 100644
--- a/src/ggml-cuda/opt-step-adam.cu
+++ b/src/ggml-cuda/opt-step-adam.cu
@@ -63,15 +63,18 @@ void ggml_cuda_opt_step_adam(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
 
     const int64_t ne = ggml_nelements(src0);
 
-    int32_t iter;  memcpy(&iter,  &dst->op_params[0], sizeof(float));
-    float   alpha; memcpy(&alpha, &dst->op_params[1], sizeof(float));
-    float   beta1; memcpy(&beta1, &dst->op_params[2], sizeof(float));
-    float   beta2; memcpy(&beta2, &dst->op_params[3], sizeof(float));
-    float   eps;   memcpy(&eps,   &dst->op_params[4], sizeof(float));
-    float   l1;    memcpy(&l1,    &dst->op_params[5], sizeof(float));
+    int64_t iter;  memcpy(&iter,  &dst->op_params[0], sizeof(int64_t));
+    float   alpha; memcpy(&alpha, &dst->op_params[2], sizeof(float));
+    float   beta1; memcpy(&beta1, &dst->op_params[3], sizeof(float));
+    float   beta2; memcpy(&beta2, &dst->op_params[4], sizeof(float));
+    float   eps;   memcpy(&eps,   &dst->op_params[5], sizeof(float));
+    float   l1;    memcpy(&l1,    &dst->op_params[6], sizeof(float));
 
     const float beta1h = alpha/(1.0f - powf(beta1, iter));
     const float beta2h =  1.0f/(1.0f - powf(beta2, iter));
 
     opt_step_adam_f32_cuda(src0_d, src0_grad_d, src0_grad_m_d, src0_grad_v_d, ne, alpha, beta1, beta2, eps, l1, beta1h, beta2h, stream);
+
+    iter++;
+    memcpy(&dst->op_params[0], &iter, sizeof(int64_t));
 }
diff --git a/src/ggml.c b/src/ggml.c
index 0994309af..ffcc9defe 100644
--- a/src/ggml.c
+++ b/src/ggml.c
@@ -8132,12 +8132,13 @@ struct ggml_tensor * ggml_opt_step_adam(
     result->src[2] = ggml_dup_tensor(ctx, a->grad);
     result->src[3] = ggml_dup_tensor(ctx, a->grad);
 
-    ggml_set_op_params_i32(result, 0, 1); // iteration
-    ggml_set_op_params_f32(result, 1, alpha);
-    ggml_set_op_params_f32(result, 2, beta1);
-    ggml_set_op_params_f32(result, 3, beta2);
-    ggml_set_op_params_f32(result, 4, eps);
-    ggml_set_op_params_f32(result, 5, l1);
+    const int64_t iter = 1;
+    memcpy(&result->op_params[0], &iter, sizeof(int64_t));
+    ggml_set_op_params_f32(result, 2, alpha);
+    ggml_set_op_params_f32(result, 3, beta1);
+    ggml_set_op_params_f32(result, 4, beta2);
+    ggml_set_op_params_f32(result, 5, eps);
+    ggml_set_op_params_f32(result, 6, l1);
 
     return result;
 }
@@ -17162,12 +17163,12 @@ static void ggml_compute_forward_opt_step_adam_f32(
     const int ir1 = MIN(ir0 + dr, nr);
 
     /* const float gnorm = 1.0f; */
-    const int32_t iter  = ggml_get_op_params_i32(dst, 0);
-    const float   alpha = ggml_get_op_params_f32(dst, 1);
-    const float   beta1 = ggml_get_op_params_f32(dst, 2);
-    const float   beta2 = ggml_get_op_params_f32(dst, 3);
-    const float   eps   = ggml_get_op_params_f32(dst, 4);
-    const float   l1    = ggml_get_op_params_f32(dst, 5);
+    int64_t iter; memcpy(&iter, &dst->op_params[0], sizeof(int64_t));
+    const float alpha = ggml_get_op_params_f32(dst, 2);
+    const float beta1 = ggml_get_op_params_f32(dst, 3);
+    const float beta2 = ggml_get_op_params_f32(dst, 4);
+    const float eps   = ggml_get_op_params_f32(dst, 5);
+    const float l1    = ggml_get_op_params_f32(dst, 6);
 
     const float beta1h = alpha/(1.0f - powf(beta1, iter));
     const float beta2h =  1.0f/(1.0f - powf(beta2, iter));
@@ -17194,6 +17195,14 @@ static void ggml_compute_forward_opt_step_adam_f32(
             w[i00] = w[i00]*(1.0f - alpha*l1) - mh/vh;
         }
     }
+
+    ggml_barrier(params->shared);
+    if (ith != 0) {
+        return;
+    }
+
+    iter++;
+    memcpy(&dst->op_params[0], &iter, sizeof(int64_t));
 }
 
 static void ggml_compute_forward_opt_step_adam(
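Note on the op_params layout: op_params in ggml is an array of int32_t, so the 64-bit iteration counter written at op_params[0] spans slots 0 and 1; that is why the float hyperparameters shift from indices 1-5 to 2-6 throughout this diff. Below is a minimal standalone sketch of the pattern, with OP_PARAMS_BYTES standing in for ggml's GGML_MAX_OP_PARAMS; everything else is plain C.

#include <stdint.h>
#include <stdio.h>
#include <string.h>

#define OP_PARAMS_BYTES 64 // stand-in for GGML_MAX_OP_PARAMS

int main(void) {
    // op_params is assumed to be an int32_t array, as it is in ggml.
    int32_t op_params[OP_PARAMS_BYTES / sizeof(int32_t)] = {0};

    // Write the 64-bit iteration counter at slot 0; it occupies slots 0 and 1.
    int64_t iter = 1;
    memcpy(&op_params[0], &iter, sizeof(int64_t));

    // Float hyperparameters therefore start at slot 2, mirroring the diff.
    const float alpha = 1e-3f;
    memcpy(&op_params[2], &alpha, sizeof(float));

    // Read back and increment after a step, as the kernels above now do
    // (on the CPU path, only thread 0 writes it back after the barrier).
    memcpy(&iter, &op_params[0], sizeof(int64_t));
    iter++;
    memcpy(&op_params[0], &iter, sizeof(int64_t));

    printf("iter = %lld\n", (long long) iter); // prints: iter = 2
    return 0;
}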