multiheadattention int8 quantization (#5733)
* x86 vulkan fallback
* comment about bf16s
nihui authored Oct 15, 2024
1 parent 1c7af00 commit 66b54cb
Showing 10 changed files with 953 additions and 60 deletions.
5 changes: 5 additions & 0 deletions docs/developer-guide/operators.md
@@ -1277,6 +1277,7 @@ y = affine(out)
| 4 | vdim | int | embed_dim | |
| 5 | attn_mask | int | 0 | |
| 6 | scale | float | 1.f / sqrt(embed_dim / num_heads) | |
+| 18 | int8_scale_term | int | 0 | |

| weight | type | shape |
| ------------- | ----- | --------------------- |
@@ -1288,6 +1289,10 @@ y = affine(out)
| v_bias_data | float | [embed_dim] |
| out_weight_data| float/fp16/int8 | [qdim * embed_dim] |
| out_bias_data | float | [qdim] |
+| q_weight_data_int8_scales| float | [embed_dim] |
+| k_weight_data_int8_scales| float | [embed_dim] |
+| v_weight_data_int8_scales| float | [embed_dim] |
+| out_weight_data_int8_scales| float | [1] |

# MVN
```
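The three q/k/v scale blobs hold one scale per output channel of the corresponding projection weight, while out_weight_data_int8_scales holds a single per-tensor scale. A minimal sketch of how such per-channel scales are typically derived — assuming the common convention int8 = round(float * scale) with scale = 127 / absmax of each output row, and a row-major [num_output x num_input] weight layout; the helper below is illustrative, not code from this commit:

```cpp
// Illustrative sketch only (not from this commit): build a [num_output] scale
// blob for a row-major float weight blob, one scale per output channel,
// assuming int8 = round(float * scale) with scale = 127 / absmax of the row.
#include "mat.h" // ncnn

#include <algorithm>
#include <cmath>

static ncnn::Mat make_per_channel_int8_scales(const ncnn::Mat& weight, int num_output, int num_input)
{
    ncnn::Mat scales(num_output);
    const float* p = weight; // flat [num_output * num_input] float blob
    for (int i = 0; i < num_output; i++)
    {
        float absmax = 0.f;
        for (int j = 0; j < num_input; j++)
        {
            absmax = std::max(absmax, fabsf(p[i * num_input + j]));
        }
        scales[i] = absmax == 0.f ? 1.f : 127.f / absmax;
    }
    return scales;
}
```

For q_weight_data this would give the [embed_dim] shape listed above; the [1] shape for out_weight_data_int8_scales matches the single out_weight_data_int8_scale value wrapped into a one-element Mat further down in the ARM diff.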
42 changes: 37 additions & 5 deletions src/layer/arm/multiheadattention_arm.cpp
@@ -28,7 +28,7 @@ MultiHeadAttention_arm::MultiHeadAttention_arm()
#endif
#endif // __ARM_NEON

-support_bf16_storage = false;
+support_bf16_storage = false;// TODO enable bf16 when gemm has proper out_elemtype support

q_gemm = 0;
k_gemm = 0;
@@ -76,10 +76,16 @@ int MultiHeadAttention_arm::create_pipeline(const Option& _opt)
pd.set(11, 0); // output_N1M
pd.set(12, 1); // output_elempack
pd.set(14, 0); // output_transpose
+#if NCNN_INT8
+pd.set(18, int8_scale_term);
+#endif
q_gemm->load_param(pd);
-Mat weights[2];
+Mat weights[3];
weights[0] = q_weight_data;
weights[1] = q_bias_data;
+#if NCNN_INT8
+weights[2] = q_weight_data_int8_scales;
+#endif
q_gemm->load_model(ModelBinFromMatArray(weights));
q_gemm->create_pipeline(opt);
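What parameter 18 and the extra weights[2] blob buy the gemm: with the weight quantized per output row and the input quantized per tensor, the int32 accumulator is mapped back to float by dividing by the product of the two scales. A rough sketch of that arithmetic under the same int8 = round(float * scale) convention as above — illustrative only, not the gemm kernel from this commit:

```cpp
// Illustrative arithmetic only: recover a float output value from an int32
// accumulator given the per-row weight scale and the per-tensor input scale.
#include <cstdint>

static float dequantize_acc(int32_t acc, float weight_scale_row, float input_scale)
{
    // int8_w = round(w * weight_scale_row), int8_x = round(x * input_scale)
    // => acc = sum(int8_w * int8_x) ≈ sum(w * x) * weight_scale_row * input_scale
    return (float)acc / (weight_scale_row * input_scale);
}
```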

@@ -105,10 +111,16 @@ int MultiHeadAttention_arm::create_pipeline(const Option& _opt)
pd.set(11, 0); // output_N1M
pd.set(12, 1); // output_elempack
pd.set(14, 0); // output_transpose
+#if NCNN_INT8
+pd.set(18, int8_scale_term);
+#endif
k_gemm->load_param(pd);
-Mat weights[2];
+Mat weights[3];
weights[0] = k_weight_data;
weights[1] = k_bias_data;
+#if NCNN_INT8
+weights[2] = k_weight_data_int8_scales;
+#endif
k_gemm->load_model(ModelBinFromMatArray(weights));
k_gemm->create_pipeline(opt);

@@ -134,10 +146,16 @@ int MultiHeadAttention_arm::create_pipeline(const Option& _opt)
pd.set(11, 0); // output_N1M
pd.set(12, 1); // output_elempack
pd.set(14, 0); // output_transpose
+#if NCNN_INT8
+pd.set(18, int8_scale_term);
+#endif
v_gemm->load_param(pd);
-Mat weights[2];
+Mat weights[3];
weights[0] = v_weight_data;
weights[1] = v_bias_data;
+#if NCNN_INT8
+weights[2] = v_weight_data_int8_scales;
+#endif
v_gemm->load_model(ModelBinFromMatArray(weights));
v_gemm->create_pipeline(opt);

@@ -161,10 +179,18 @@ int MultiHeadAttention_arm::create_pipeline(const Option& _opt)
pd.set(9, embed_dim); // K = maxk*inch
pd.set(10, 4); // constant_broadcast_type_C = null
pd.set(11, 0); // output_N1M
+#if NCNN_INT8
+pd.set(18, int8_scale_term);
+#endif
o_gemm->load_param(pd);
-Mat weights[2];
+Mat weights[3];
weights[0] = out_weight_data;
weights[1] = out_bias_data;
+#if NCNN_INT8
+Mat out_weight_data_int8_scales(1);
+out_weight_data_int8_scales[0] = out_weight_data_int8_scale;
+weights[2] = out_weight_data_int8_scales;
+#endif
o_gemm->load_model(ModelBinFromMatArray(weights));
o_gemm->create_pipeline(opt);
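Unlike the q/k/v projections, the out projection above carries a single out_weight_data_int8_scale, wrapped into a one-element Mat so it can travel through the same weights[2] slot. A hedged sketch of the quantization such a per-tensor scale implies, with the symmetric clamp to [-127, 127] ncnn typically uses — illustrative, not this commit's code:

```cpp
// Illustrative only: quantize a float weight blob with one per-tensor scale,
// rounding and clamping to the symmetric int8 range [-127, 127].
#include <cmath>
#include <cstdint>
#include <vector>

static std::vector<int8_t> quantize_per_tensor(const float* w, int count, float scale)
{
    std::vector<int8_t> q(count);
    for (int i = 0; i < count; i++)
    {
        int v = (int)roundf(w[i] * scale);
        if (v > 127) v = 127;
        if (v < -127) v = -127;
        q[i] = (int8_t)v;
    }
    return q;
}
```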

@@ -189,6 +215,9 @@ int MultiHeadAttention_arm::create_pipeline(const Option& _opt)
pd.set(10, attn_mask ? 3 : -1); // constant_broadcast_type_C
pd.set(11, 0); // output_N1M
pd.set(12, 1); // output_elempack
+#if NCNN_INT8
+pd.set(18, int8_scale_term);
+#endif
qk_gemm->load_param(pd);
qk_gemm->load_model(ModelBinFromMatArray(0));
Option opt1 = opt;
@@ -211,6 +240,9 @@ int MultiHeadAttention_arm::create_pipeline(const Option& _opt)
pd.set(11, 0); // output_N1M
pd.set(12, 1); // output_elempack
pd.set(14, 1); // output_transpose
+#if NCNN_INT8
+pd.set(18, int8_scale_term);
+#endif
qkv_gemm->load_param(pd);
qkv_gemm->load_model(ModelBinFromMatArray(0));
Option opt1 = opt;
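The qk and qkv gemms above load no constant weights (ModelBinFromMatArray(0)), so an int8 path for them has to quantize runtime activations rather than stored blobs. How this commit's gemm picks those activation scales is not visible in this excerpt; the sketch below only illustrates the usual per-tensor absmax choice under the same convention as before:

```cpp
// Illustrative only: per-tensor activation scale by absmax over an ncnn::Mat,
// assuming int8 = round(float * scale) with scale = 127 / absmax.
#include "mat.h" // ncnn

#include <algorithm>
#include <cmath>

static float activation_int8_scale(const ncnn::Mat& blob)
{
    float absmax = 0.f;
    for (int q = 0; q < blob.c; q++)
    {
        const float* ptr = blob.channel(q);
        const int size = blob.w * blob.h * blob.d;
        for (int i = 0; i < size; i++)
        {
            absmax = std::max(absmax, fabsf(ptr[i]));
        }
    }
    return absmax == 0.f ? 1.f : 127.f / absmax;
}
```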
