
Commit

extra call to add optimizer
JohannesGaessler committed Sep 12, 2024
1 parent 56b9195 commit aebd118
Showing 3 changed files with 25 additions and 3 deletions.
1 change: 1 addition & 0 deletions examples/mnist/mnist-common.cpp
@@ -524,6 +524,7 @@ void mnist_model_train(mnist_model & model, const float * images, const float *

     struct ggml_cgraph * gb = ggml_graph_dup(model.ctx_compute, gf);
     ggml_build_backward_expand(model.ctx_compute, gf, gb, false);
+    ggml_build_opt_adam(model.ctx_compute, gf, gb, 1e-3f, 0.9f, 0.999f, 1e-8f, 0.0f);
 
     struct ggml_opt_context opt_ctx;
     struct ggml_opt_params opt_pars = ggml_opt_default_params(GGML_OPT_TYPE_ADAM);
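Note: read against the new declaration in include/ggml.h below, these arguments are the standard Adam hyperparameters: alpha = 1e-3f (step size), beta1 = 0.9f and beta2 = 0.999f (first and second moment decay), eps = 1e-8f (numerical stability), and l1 = 0.0f, which disables the regularization term previously hard-coded to 1e-3f in src/ggml.c.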
10 changes: 10 additions & 0 deletions include/ggml.h
@@ -2055,6 +2055,16 @@ extern "C" {
     GGML_API void ggml_build_forward_expand (struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
     GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep);
 
+    GGML_API void ggml_build_opt_adam(
+            struct ggml_context * ctx,
+            struct ggml_cgraph * gf,
+            struct ggml_cgraph * gb,
+            float alpha,
+            float beta1,
+            float beta2,
+            float eps,
+            float l1);
+
     // graph allocation in a context
     GGML_API struct ggml_cgraph * ggml_new_graph        (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false
     GGML_API struct ggml_cgraph * ggml_new_graph_custom (struct ggml_context * ctx, size_t size, bool grads);
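Note: the declaration names the usual Adam quantities. As a reference point only, here is a plain-C sketch of the textbook Adam update (Kingma & Ba, 2015) that these hyperparameters parameterize, under the assumption that l1 enters as an L1-penalty subgradient added to the raw gradient. It is not ggml's fused ggml_opt_step_adam kernel, whose implementation is not part of this diff.

#include <math.h>
#include <stdint.h>

// Hypothetical reference code; names mirror the ggml_build_opt_adam arguments.
// One Adam step over a parameter tensor w with gradient grad and moment
// buffers m, v of n elements each; t is the 1-based iteration count.
static void adam_step_ref(
        float * w, const float * grad, float * m, float * v,
        int64_t n, int64_t t,
        float alpha, float beta1, float beta2, float eps, float l1) {
    const float bc1 = 1.0f - powf(beta1, (float) t); // bias-correction terms
    const float bc2 = 1.0f - powf(beta2, (float) t);
    for (int64_t i = 0; i < n; i++) {
        // assumption: l1 contributes the subgradient of an L1 penalty
        const float sgn = w[i] > 0.0f ? 1.0f : (w[i] < 0.0f ? -1.0f : 0.0f);
        const float g   = grad[i] + l1*sgn;
        m[i] = beta1*m[i] + (1.0f - beta1)*g;   // first moment (mean)
        v[i] = beta2*v[i] + (1.0f - beta2)*g*g; // second moment (uncentered variance)
        const float mh = m[i]/bc1;              // bias-corrected estimates
        const float vh = v[i]/bc2;
        w[i] -= alpha*mh/(sqrtf(vh) + eps);     // parameter update
    }
}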
17 changes: 14 additions & 3 deletions src/ggml.c
@@ -18775,19 +18775,30 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph *
         }
     }
 
+    ggml_hash_set_free(&zero_table);
+}
+
+void ggml_build_opt_adam(
+        struct ggml_context * ctx,
+        struct ggml_cgraph * gf,
+        struct ggml_cgraph * gb,
+        float alpha,
+        float beta1,
+        float beta2,
+        float eps,
+        float l1) {
     for (int i = 0; i < gf->n_nodes; i++) {
         struct ggml_tensor * node = gf->nodes[i];
 
         if (node->flags & GGML_TENSOR_FLAG_PARAM) {
             GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
-            struct ggml_tensor * opt_step = ggml_opt_step_adam(ctx, node, 1e-3f, 0.9f, 0.999f, 1e-8f, 1e-3f);
+            struct ggml_tensor * opt_step = ggml_opt_step_adam(ctx, node, alpha, beta1, beta2, eps, l1);
             ggml_build_forward_expand(gb, opt_step);
         }
     }
-
-    ggml_hash_set_free(&zero_table);
 }
+
 
 static void * incr_ptr_aligned(void ** p, size_t size, size_t align) {
     void * ptr = *p;
     ptr = (void *) GGML_PAD((uintptr_t) ptr, align);
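Taken together, the three hunks pull the optimizer out of ggml_build_backward_expand: the gradient graph and the Adam step nodes are now built by two separate calls, with the hyperparameters supplied by the caller rather than hard-coded. A minimal sketch of the resulting call order, assuming a context ctx with enough graph memory and a scalar loss tensor whose weight tensors were flagged via ggml_set_param:

// sketch of the call sequence implied by this commit (cf. mnist-common.cpp);
// ctx and loss are assumed to exist, error handling is omitted
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true);
ggml_build_forward_expand(gf, loss);                       // forward pass ending in the loss

struct ggml_cgraph * gb = ggml_graph_dup(ctx, gf);         // backward graph starts as a copy
ggml_build_backward_expand(ctx, gf, gb, /*keep =*/ false); // appends gradient ops only

// new in this commit: append one ggml_opt_step_adam node per tensor with
// GGML_TENSOR_FLAG_PARAM set
ggml_build_opt_adam(ctx, gf, gb, 1e-3f, 0.9f, 0.999f, 1e-8f, 0.0f);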
