Temporarily use custom operator for GroupNorm
...until ggml GroupNorm has the eps parameter

Signed-off-by: Molly Sophia <[email protected]>
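
Background: ggml's ggml_group_norm hard-codes eps = 1e-5, while RWKV v6 needs eps = 64e-5. The previous workaround scaled x by 1/8 before the group norm, which is exact for the normalization itself:

    (x/8 - mean(x/8)) / sqrt(var(x/8) + 1e-5) = (x - mean(x)) / sqrt(var(x) + 64e-5)

since mean(x/8) = mean(x)/8 and var(x/8) = var(x)/64. The custom operators in this commit make eps explicit instead of relying on that rescaling.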
MollySophia committed Jul 16, 2024
1 parent 271a90d commit d26d6f3
Showing 2 changed files with 96 additions and 5 deletions.
7 changes: 2 additions & 5 deletions rwkv_graph.inc
@@ -356,7 +356,7 @@ static struct ggml_tensor * rwkv_att_v5(

// ggml_group_norm considers groups in the third dimension.
x = ggml_reshape_4d(ctx, x, 1, 1, n_embed, sequence_length);
- x = ggml_group_norm_inplace(ctx, x, head_count);
+ x = rwkv_group_norm_eps_1e_minus5_inplace(ctx, x, head_count);
// Convert back to a regular vector.
x = ggml_reshape_2d(ctx, x, n_embed, sequence_length);
x = ggml_add_inplace(
@@ -563,12 +563,9 @@ static struct ggml_tensor * rwkv_att_v6(

state.att_heads = state_out;

- // rwkv/ggml ggml_group_norm uses eps=1e-5, while rwkv v6 uses eps=64e-5
- // Do 1/8 scale to x before group_norm for now.
- x = ggml_scale_inplace(ctx, x, 0.125);
// ggml_group_norm considers groups in the third dimension.
x = ggml_reshape_4d(ctx, x, 1, 1, n_embed, sequence_length);
- x = ggml_group_norm_inplace(ctx, x, head_count);
+ x = rwkv_group_norm_eps_64e_minus5_inplace(ctx, x, head_count);
// Convert back to a regular vector.
x = ggml_reshape_2d(ctx, x, n_embed, sequence_length);
x = ggml_add_inplace(
94 changes: 94 additions & 0 deletions rwkv_operators.inc
@@ -93,6 +93,85 @@ static void rwkv_max_impl(
SUPPRESS_UNUSED_WARNINGS_IN_CUSTOM_OP();
}

// From ggml.c
static void rwkv_groupnorm_impl(
struct ggml_tensor * dst,
const struct ggml_tensor * src0,
int ith,
int nth,
void * userdata
) {
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(dst));
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_are_same_shape(src0, dst));

GGML_ASSERT(src0->nb[0] == sizeof(float));

GGML_TENSOR_UNARY_OP_LOCALS

const float eps = ((float*)userdata)[0];
const int n_groups = ((int32_t*)userdata)[1];
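// userdata is a two-slot buffer filled by the wrappers below: slot 0 holds eps as a float, slot 1 holds n_groups as an int32.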

int n_channels = src0->ne[2];
int n_channels_per_group = (n_channels + n_groups - 1) / n_groups;
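// Threads stride over groups: thread ith handles groups ith, ith + nth, ith + 2*nth, ...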
for (int i = ith; i < n_groups; i += nth) {
int start = i * n_channels_per_group;
int end = start + n_channels_per_group;
if (end > n_channels) {
end = n_channels;
}
int step = end - start;

for (int64_t i03 = 0; i03 < ne03; i03++) {
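// First pass: sum every element of the group to get its mean.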
float sum = 0.0;
for (int64_t i02 = start; i02 < end; i02++) {
for (int64_t i01 = 0; i01 < ne01; i01++) {
const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);

float sumr = 0.0;
for (int64_t i00 = 0; i00 < ne00; i00++) {
sumr += (float)x[i00];
}
sum += sumr;
}
}
const float mean = sum / (ne00 * ne01 * step);

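// Second pass: write the centered values to dst and accumulate the squared deviations.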
float sum2 = 0.0;
for (int64_t i02 = start; i02 < end; i02++) {
for (int64_t i01 = 0; i01 < ne01; i01++) {
const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);

float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3);

float sumr = 0.0;
for (int64_t i00 = 0; i00 < ne00; i00++) {
float v = x[i00] - mean;
y[i00] = v;
sumr += (float)(v * v);
}
sum2 += sumr;
}
}
const float variance = sum2 / (ne00 * ne01 * step);
const float scale = 1.0f / sqrtf(variance + eps);

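// Third pass: multiply the centered values by the inverse standard deviation.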
for (int64_t i02 = start; i02 < end; i02++) {
for (int64_t i01 = 0; i01 < ne01; i01++) {
float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3);
for (int i00 = 0; i00 < ne00; i00++) {
y[i00] *= scale;
}
}
}
}
}

SUPPRESS_UNUSED_WARNINGS_IN_CUSTOM_OP();
}

// Element-wise exp(x)
struct ggml_tensor * rwkv_exp(struct ggml_context * ctx, struct ggml_tensor * x) {
return ggml_map_custom1(ctx, x, rwkv_exp_impl, 1, NULL);
Expand All @@ -113,6 +192,21 @@ struct ggml_tensor * rwkv_max(struct ggml_context * ctx, struct ggml_tensor * x,
return ggml_map_custom2(ctx, x, y, rwkv_max_impl, 1, NULL);
}

// GroupNorm with a custom eps value; remove when ggml_group_norm supports eps as an argument.
struct ggml_tensor * rwkv_group_norm_eps_1e_minus5_inplace(struct ggml_context * ctx, struct ggml_tensor * x, int n_groups) {
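// ggml stores only the userdata pointer, so params must stay valid until the graph is computed; static storage guarantees that.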
static float params[2];
params[0] = 1e-5F;
((int*)params)[1] = n_groups;
return ggml_map_custom1_inplace(ctx, x, rwkv_groupnorm_impl, 1, params);
}

struct ggml_tensor * rwkv_group_norm_eps_64e_minus5_inplace(struct ggml_context * ctx, struct ggml_tensor * x, int n_groups) {
static float params[2];
params[0] = 64e-5F;
((int*)params)[1] = n_groups;
return ggml_map_custom1_inplace(ctx, x, rwkv_groupnorm_impl, 1, params);
}

struct ggml_tensor * rwkv_layer_norm(struct ggml_context * ctx, struct ggml_tensor * x, struct ggml_tensor * weight, struct ggml_tensor * bias) {
// LayerNorm in RWKV is `x = (x - mean(x)) / sqrt(variance(x) + 1e-5) * weight + bias`
// Looks like ggml_norm does the first part, we only need to apply weight & bias.
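For reference, a minimal standalone smoke test for the new operator could look like the sketch below. It is not part of this commit: the include line is hypothetical (the .inc file normally builds as part of rwkv.cpp), it assumes the standard ggml graph API (ggml_new_graph, ggml_build_forward_expand, ggml_graph_compute_with_ctx), and the sizes are arbitrary.

#include <stdio.h>
#include <stdbool.h>
#include "ggml.h"
#include "rwkv_operators.inc" // hypothetical include; brings in rwkv_group_norm_eps_64e_minus5_inplace

int main(void) {
    const int n_embed = 8;
    const int sequence_length = 1;
    const int head_count = 2;

    struct ggml_init_params init_params = {
        /*.mem_size   =*/ 16 * 1024 * 1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(init_params);

    // Same layout as rwkv_graph.inc: groups live in the third dimension.
    struct ggml_tensor * x = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_embed, sequence_length);
    for (int i = 0; i < n_embed; i++) {
        ((float *) x->data)[i] = (float) i;
    }

    struct ggml_tensor * y = rwkv_group_norm_eps_64e_minus5_inplace(ctx, x, head_count);

    struct ggml_cgraph * graph = ggml_new_graph(ctx);
    ggml_build_forward_expand(graph, y);
    ggml_graph_compute_with_ctx(ctx, graph, /*n_threads=*/ 1);

    for (int i = 0; i < n_embed; i++) {
        printf("y[%d] = %f\n", i, ((float *) y->data)[i]);
    }

    ggml_free(ctx);
    return 0;
}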
