Skip to content

Commit

Permalink
ggml: fix gradient allocation logic
Browse files Browse the repository at this point in the history
  • Loading branch information
JohannesGaessler committed Sep 23, 2024
1 parent 336c10a commit a19c280
Show file tree
Hide file tree
Showing 2 changed files with 568 additions and 1,014 deletions.
39 changes: 20 additions & 19 deletions include/ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -574,10 +574,11 @@ extern "C" {

// this tensor...
enum ggml_tensor_flag {
GGML_TENSOR_FLAG_INPUT = 1, // ...is an input for the GGML compute graph
GGML_TENSOR_FLAG_OUTPUT = 2, // ...is an output for the GGML compute graph
GGML_TENSOR_FLAG_PARAM = 4, // ...contains trainable parameters
GGML_TENSOR_FLAG_LOSS = 8, // ...defines loss for numerical optimization (multiple loss tensors add up)
GGML_TENSOR_FLAG_INPUT = 1, // ...is an input for the GGML compute graph
GGML_TENSOR_FLAG_OUTPUT = 2, // ...is an output for the GGML compute graph
GGML_TENSOR_FLAG_PARAM = 4, // ...contains trainable parameters
GGML_TENSOR_FLAG_LOSS = 8, // ...defines loss for numerical optimization (multiple loss tensors add up)
GGML_TENSOR_FLAG_GRAD = 16, // ...is a gradient for another tensor
};

// n-dimensional tensor
Expand Down Expand Up @@ -1407,14 +1408,14 @@ extern "C" {
// supports 3D: a->ne[2] == b->ne[1]
GGML_API struct ggml_tensor * ggml_get_rows(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
struct ggml_tensor * a, // data
struct ggml_tensor * b); // row indices

GGML_API struct ggml_tensor * ggml_get_rows_back(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * c);
struct ggml_tensor * a, // gradients of ggml_get_rows result
struct ggml_tensor * b, // row indices
struct ggml_tensor * c); // data for ggml_get_rows, only used for its shape

GGML_API struct ggml_tensor * ggml_diag(
struct ggml_context * ctx,
Expand Down Expand Up @@ -1565,9 +1566,9 @@ extern "C" {
// a - dy
GGML_API struct ggml_tensor * ggml_rope_back(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * c,
struct ggml_tensor * a, // gradients of ggml_rope result
struct ggml_tensor * b, // positions
struct ggml_tensor * c, // freq factors
int n_dims,
int mode,
int n_ctx_orig,
Expand Down Expand Up @@ -2033,15 +2034,15 @@ extern "C" {
// loss function

GGML_API struct ggml_tensor * ggml_cross_entropy_loss(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
struct ggml_context * ctx,
struct ggml_tensor * a, // logits
struct ggml_tensor * b); // labels

GGML_API struct ggml_tensor * ggml_cross_entropy_loss_back(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * c);
struct ggml_context * ctx,
struct ggml_tensor * a, // logits
struct ggml_tensor * b, // labels
struct ggml_tensor * c); // gradients of cross_entropy_loss result

// AdamW optimizer step
// Paper: https://arxiv.org/pdf/1711.05101v3.pdf
Expand Down
Loading

0 comments on commit a19c280

Please sign in to comment.