From d26d6f3d2c63a43c8e60be4db80f7c467257058f Mon Sep 17 00:00:00 2001
From: Molly Sophia
Date: Tue, 16 Jul 2024 14:47:53 +0800
Subject: [PATCH] Temporarily use custom operator for GroupNorm

...until ggml GroupNorm has the eps parameter

Signed-off-by: Molly Sophia
---
 rwkv_graph.inc     |  7 ++-----
 rwkv_operators.inc | 94 ++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 96 insertions(+), 5 deletions(-)

diff --git a/rwkv_graph.inc b/rwkv_graph.inc
index fdf8eec..53a47da 100644
--- a/rwkv_graph.inc
+++ b/rwkv_graph.inc
@@ -356,7 +356,7 @@ static struct ggml_tensor * rwkv_att_v5(
 
     // ggml_group_norm considers groups in the third dimension.
     x = ggml_reshape_4d(ctx, x, 1, 1, n_embed, sequence_length);
-    x = ggml_group_norm_inplace(ctx, x, head_count);
+    x = rwkv_group_norm_eps_1e_minus5_inplace(ctx, x, head_count);
     // Convert back to a regular vector.
     x = ggml_reshape_2d(ctx, x, n_embed, sequence_length);
     x = ggml_add_inplace(
@@ -563,12 +563,9 @@ static struct ggml_tensor * rwkv_att_v6(
 
     state.att_heads = state_out;
 
-    // rwkv/ggml ggml_group_norm uses eps=1e-5, while rwkv v6 uses eps=64e-5
-    // Do 1/8 scale to x before group_norm for now.
-    x = ggml_scale_inplace(ctx, x, 0.125);
     // ggml_group_norm considers groups in the third dimension.
     x = ggml_reshape_4d(ctx, x, 1, 1, n_embed, sequence_length);
-    x = ggml_group_norm_inplace(ctx, x, head_count);
+    x = rwkv_group_norm_eps_64e_minus5_inplace(ctx, x, head_count);
     // Convert back to a regular vector.
     x = ggml_reshape_2d(ctx, x, n_embed, sequence_length);
     x = ggml_add_inplace(
diff --git a/rwkv_operators.inc b/rwkv_operators.inc
index 9cc9460..8fe1ab0 100644
--- a/rwkv_operators.inc
+++ b/rwkv_operators.inc
@@ -93,6 +93,85 @@ static void rwkv_max_impl(
     SUPPRESS_UNUSED_WARNINGS_IN_CUSTOM_OP();
 }
 
+// From ggml.c's GroupNorm implementation, with eps taken from userdata instead of the hard-coded 1e-5.
+static void rwkv_groupnorm_impl(
+    struct ggml_tensor * dst,
+    const struct ggml_tensor * src0,
+    int ith,
+    int nth,
+    void * userdata
+) {
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    const float eps = ((float*)userdata)[0];      // userdata layout: { float eps; int32_t n_groups; }
+    const int n_groups = ((int32_t*)userdata)[1];
+
+    int n_channels = src0->ne[2];
+    int n_channels_per_group = (n_channels + n_groups - 1) / n_groups;
+    for (int i = ith; i < n_groups; i += nth) {
+        int start = i * n_channels_per_group;
+        int end = start + n_channels_per_group;
+        if (end > n_channels) {
+            end = n_channels;
+        }
+        int step = end - start;
+
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            float sum = 0.0;
+            for (int64_t i02 = start; i02 < end; i02++) {
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
+
+                    float sumr = 0.0;
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        sumr += (float)x[i00];
+                    }
+                    sum += sumr;
+                }
+            }
+            const float mean = sum / (ne00 * ne01 * step);
+
+            float sum2 = 0.0;
+            for (int64_t i02 = start; i02 < end; i02++) {
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
+
+                    float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3);
+
+                    float sumr = 0.0;
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        float v = x[i00] - mean;
+                        y[i00] = v;
+                        sumr += (float)(v * v);
+                    }
+                    sum2 += sumr;
+                }
+            }
+            const float variance = sum2 / (ne00 * ne01 * step);
+            const float scale = 1.0f / sqrtf(variance + eps);
+
+            for (int64_t i02 = start; i02 < end; i02++) {
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3);
+                    for (int i00 = 0; i00 < ne00; i00++) {
+                        y[i00] *= scale;
+                    }
+                }
+            }
+        }
+    }
+
+    SUPPRESS_UNUSED_WARNINGS_IN_CUSTOM_OP();
+}
+
 // Element-wise exp(x)
 struct ggml_tensor * rwkv_exp(struct ggml_context * ctx, struct ggml_tensor * x) {
     return ggml_map_custom1(ctx, x, rwkv_exp_impl, 1, NULL);
@@ -113,6 +192,21 @@ struct ggml_tensor * rwkv_max(struct ggml_context * ctx, struct ggml_tensor * x,
     return ggml_map_custom2(ctx, x, y, rwkv_max_impl, 1, NULL);
 }
 
+// GroupNorm with a custom eps value; remove once ggml_group_norm accepts eps as an argument.
+struct ggml_tensor * rwkv_group_norm_eps_1e_minus5_inplace(struct ggml_context * ctx, struct ggml_tensor * x, int n_groups) {
+    static float params[2]; // static: ggml reads userdata at graph compute time, so the buffer must outlive this call
+    params[0] = 1e-5F;
+    ((int*)params)[1] = n_groups; // n_groups is stored as raw int32 bits in the second float slot
+    return ggml_map_custom1_inplace(ctx, x, rwkv_groupnorm_impl, 1, params);
+}
+
+struct ggml_tensor * rwkv_group_norm_eps_64e_minus5_inplace(struct ggml_context * ctx, struct ggml_tensor * x, int n_groups) {
+    static float params[2];
+    params[0] = 64e-5F;
+    ((int*)params)[1] = n_groups;
+    return ggml_map_custom1_inplace(ctx, x, rwkv_groupnorm_impl, 1, params);
+}
+
 struct ggml_tensor * rwkv_layer_norm(struct ggml_context * ctx, struct ggml_tensor * x, struct ggml_tensor * weight, struct ggml_tensor * bias) {
     // LayerNorm in RWKV is `x = (x - mean(x)) / sqrt(variance(x) + 1e-5) * weight + bias`
    // Looks like ggml_norm does the first part, we only need to apply weight & bias.
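
Note on the removed 1/8 pre-scale in rwkv_att_v6: for a constant c > 0, scaling GroupNorm's input is equivalent to scaling its eps, because the group mean scales by 1/c and the group variance by 1/c²:

    (x/c - mean/c) / sqrt(var/c² + eps) = (x - mean) / sqrt(var + c²·eps)

With c = 8 and ggml's hard-coded eps = 1e-5, the old workaround therefore computed GroupNorm with an effective eps of 8² · 1e-5 = 64e-5, which is the value RWKV v6 expects; the custom operator makes that eps explicit instead of relying on this identity.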
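
The same arithmetic can be checked outside ggml. The following standalone C sketch is illustrative only (not part of the patch; group_norm_ref and its flat [channels][n] layout are assumptions made for the demo): it performs the same per-group mean/center/scale passes as rwkv_groupnorm_impl and verifies numerically that eps = 64e-5 reproduces the 1/8-scale-then-eps-1e-5 result.

#include <math.h>
#include <stdio.h>

// Reference GroupNorm: normalize x of shape [n_channels][n] in place,
// one group of channels at a time, using the given eps.
static void group_norm_ref(float * x, int n_channels, int n, int n_groups, float eps) {
    int per_group = (n_channels + n_groups - 1) / n_groups;
    for (int g = 0; g < n_groups; g++) {
        int start = g * per_group;
        int end = (start + per_group > n_channels) ? n_channels : start + per_group;
        int count = (end - start) * n;

        // Pass 1: group mean.
        float sum = 0.0f;
        for (int c = start; c < end; c++)
            for (int i = 0; i < n; i++)
                sum += x[c * n + i];
        float mean = sum / count;

        // Pass 2: center and accumulate variance.
        float sum2 = 0.0f;
        for (int c = start; c < end; c++)
            for (int i = 0; i < n; i++) {
                float v = x[c * n + i] - mean;
                x[c * n + i] = v;
                sum2 += v * v;
            }
        float scale = 1.0f / sqrtf(sum2 / count + eps);

        // Pass 3: scale.
        for (int c = start; c < end; c++)
            for (int i = 0; i < n; i++)
                x[c * n + i] *= scale;
    }
}

int main(void) {
    float a[8], b[8];
    for (int i = 0; i < 8; i++) {
        a[i] = (float) i;           // direct: eps = 64e-5
        b[i] = (float) i * 0.125f;  // old workaround: pre-scale by 1/8, eps = 1e-5
    }
    group_norm_ref(a, 4, 2, 2, 64e-5f);
    group_norm_ref(b, 4, 2, 2, 1e-5f);
    for (int i = 0; i < 8; i++)
        printf("%d: % .6f % .6f\n", i, a[i], b[i]);  // the two columns should agree
    return 0;
}

Built with a plain C compiler (cc sketch.c -lm), the two printed columns agree to float precision, which is the property the patch relies on when it drops the ggml_scale_inplace call.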