From 9c6522630045f2cdbc43f55e4d8fc31199388f08 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Fri, 2 Jun 2023 16:03:33 +0300
Subject: [PATCH] ggml : make src0 broadcast-able into src1 for ggml_mul_mat

---
 src/ggml.c | 259 ++++++++++++++++++++++++++---------------------------
 1 file changed, 127 insertions(+), 132 deletions(-)

diff --git a/src/ggml.c b/src/ggml.c
index 1a441eb98..42005fb1f 100644
--- a/src/ggml.c
+++ b/src/ggml.c
@@ -3968,10 +3968,9 @@ static inline bool ggml_is_matrix(const struct ggml_tensor * tensor) {
 static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
-    return
-        (t0->ne[0] == t1->ne[0]) &&
-        (t0->ne[2] == t1->ne[2]) &&
-        (t0->ne[3] == t1->ne[3]);
+    return (t0->ne[0]           == t1->ne[0]) &&
+           (t1->ne[2]%t0->ne[2] == 0)         && // verify t0 is broadcastable
+           (t1->ne[3]%t0->ne[3] == 0);
 }
 
 static inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
@@ -5737,8 +5736,8 @@ struct ggml_tensor * ggml_mul_mat(
         is_node = true;
     }
 
-    const int64_t ne[4] = { a->ne[1], b->ne[1], a->ne[2], b->ne[3] };
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MIN(a->n_dims, b->n_dims), ne);
+    const int64_t ne[4] = { a->ne[1], b->ne[1], b->ne[2], b->ne[3] };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MAX(a->n_dims, b->n_dims), ne);
 
     result->op   = GGML_OP_MUL_MAT;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -10466,11 +10465,8 @@ static void ggml_compute_forward_mul_mat_f32(
     const int64_t ne02 = src0->ne[2];
     const int64_t ne03 = src0->ne[3];
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
-    const int64_t ne10 = src1->ne[0];
-#endif
+    const int64_t ne10 = src1->ne[0]; UNUSED(ne10);
     const int64_t ne11 = src1->ne[1];
-#ifndef NDEBUG
     const int64_t ne12 = src1->ne[2];
     const int64_t ne13 = src1->ne[3];
 
@@ -10480,17 +10476,14 @@ static void ggml_compute_forward_mul_mat_f32(
     const int64_t ne3 = dst->ne[3];
 
     const int nb00 = src0->nb[0];
-#endif
     const int nb01 = src0->nb[1];
     const int nb02 = src0->nb[2];
     const int nb03 = src0->nb[3];
 
-#ifndef NDEBUG
     const int nb10 = src1->nb[0];
-#endif
-    const int nb11 = src1->nb[1];
-    const int nb12 = src1->nb[2];
-    const int nb13 = src1->nb[3];
+    const int nb11 = src1->nb[1]; UNUSED(nb11);
+    const int nb12 = src1->nb[2]; UNUSED(nb12);
+    const int nb13 = src1->nb[3]; UNUSED(nb13);
 
     const int nb0 = dst->nb[0];
     const int nb1 = dst->nb[1];
@@ -10500,31 +10493,31 @@ static void ggml_compute_forward_mul_mat_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    assert(ne02 == ne12);
-    assert(ne03 == ne13);
-    assert(ne2  == ne12);
-    assert(ne3  == ne13);
+    GGML_ASSERT(ne0 == ne01);
+    GGML_ASSERT(ne1 == ne11);
+    GGML_ASSERT(ne2 == ne12);
+    GGML_ASSERT(ne3 == ne13);
 
     // we don't support permuted src0 or src1
-    assert(nb00 == sizeof(float));
-    assert(nb10 == sizeof(float));
+    GGML_ASSERT(nb00 == sizeof(float));
+    GGML_ASSERT(nb10 == sizeof(float));
 
     // dst cannot be transposed or permuted
-    assert(nb0 == sizeof(float));
-    assert(nb0 <= nb1);
-    assert(nb1 <= nb2);
-    assert(nb2 <= nb3);
-
-    assert(ne0 == ne01);
-    assert(ne1 == ne11);
-    assert(ne2 == ne02);
-    assert(ne3 == ne03);
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
 
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
 #if defined(GGML_USE_CLBLAST)
     if (ggml_cl_can_mul_mat(src0, src1, dst)) {
+        // TODO: handle case when src0 is broadcast-able into src1 across 2nd,3rd dimension
+        // ref: https://github.com/ggerganov/ggml/pull/224
+        GGML_ASSERT(ne02 == ne12);
+        GGML_ASSERT(ne03 == ne13);
+
         if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
             ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
         }
@@ -10534,6 +10527,11 @@ static void ggml_compute_forward_mul_mat_f32(
 
 #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
+        // TODO: handle case when src0 is broadcast-able into src1 across 2nd,3rd dimension
+        // ref: https://github.com/ggerganov/ggml/pull/224
+        GGML_ASSERT(ne02 == ne12);
+        GGML_ASSERT(ne03 == ne13);
+
         if (params->ith != 0) {
             return;
         }
@@ -10573,40 +10571,35 @@ static void ggml_compute_forward_mul_mat_f32(
         return;
     }
 
-    // parallelize by src0 rows using ggml_vec_dot_f32
+    // parallelize by src0 rows
+    const int64_t dr = (ne01 + nth - 1)/nth;
 
-    // total rows in src0
-    const int nr = ne01*ne02*ne03;
+    const int64_t ir10 = dr*ith;
+    const int64_t ir11 = MIN(ir10 + dr, ne01);
 
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
+    // src1 rows
+    const int64_t nr1 = ne11*ne12*ne13;
 
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
+    for (int64_t ir1 = 0; ir1 < nr1; ++ir1) {
+        const int64_t i13 = (ir1/(ne12*ne11));
+        const int64_t i12 = (ir1 - i13*ne12*ne11)/ne11;
+        const int64_t i11 = (ir1 - i13*ne12*ne11 - i12*ne11);
 
-    for (int ir = ir0; ir < ir1; ++ir) {
-        // src0 indices
-        const int i03 = ir/(ne02*ne01);
-        const int i02 = (ir - i03*ne02*ne01)/ne01;
-        const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
+        const int64_t ir0 = (ir1/ne11)%(ne02*ne03);
+        const int64_t i03 = (ir0/(ne02));
+        const int64_t i02 = (ir0 - i03*ne02);
 
-        for (int64_t ic = 0; ic < ne11; ++ic) {
-            // src1 indices
-            const int i13 = i03;
-            const int i12 = i02;
-            const int i11 = ic;
+        const int64_t i1 = i11;
+        const int64_t i2 = i12;
+        const int64_t i3 = i13;
 
-            // dst indices
-            const int i0 = i01;
-            const int i1 = i11;
-            const int i2 = i02;
-            const int i3 = i03;
-
-            ggml_vec_dot_f32(ne00,
-                    (float *) ((char *) dst->data  + (i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
-                    (float *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)),
-                    (float *) ((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13)));
+        char * src0_row = (char *) src0->data + ( 0 + i02*nb02 + i03*nb03);
+        char * src1_col = (char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13);
+
+        float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
+
+        for (int64_t ir = ir10; ir < ir11; ++ir) {
+            ggml_vec_dot_f32(ne00, &dst_col[ir], (float *) (src0_row + ir*nb01), (float *) src1_col);
         }
     }
 
@@ -10646,7 +10639,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     const int64_t ne1 = dst->ne[1];
     const int64_t ne2 = dst->ne[2];
     const int64_t ne3 = dst->ne[3];
-    //const int64_t ne = ne0*ne1*ne2*ne3;
 
     const int nb00 = src0->nb[0];
     const int nb01 = src0->nb[1];
@@ -10666,10 +10658,10 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    GGML_ASSERT(ne02 == ne12);
-    GGML_ASSERT(ne03 == ne13);
-    GGML_ASSERT(ne2  == ne12);
-    GGML_ASSERT(ne3  == ne13);
+    GGML_ASSERT(ne0 == ne01);
+    GGML_ASSERT(ne1 == ne11);
+    GGML_ASSERT(ne2 == ne12);
+    GGML_ASSERT(ne3 == ne13);
 
     // TODO: we don't support permuted src0
     GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
@@ -10680,16 +10672,16 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     GGML_ASSERT(nb1 <= nb2);
     GGML_ASSERT(nb2 <= nb3);
 
-    GGML_ASSERT(ne0 == ne01);
-    GGML_ASSERT(ne1 == ne11);
-    GGML_ASSERT(ne2 == ne02);
-    GGML_ASSERT(ne3 == ne03);
-
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
 #if defined(GGML_USE_CLBLAST)
     if (ggml_cl_can_mul_mat(src0, src1, dst)) {
+        // TODO: handle case when src0 is broadcast-able into src1 across 2nd,3rd dimension
+        // ref: https://github.com/ggerganov/ggml/pull/224
+        GGML_ASSERT(ne02 == ne12);
+        GGML_ASSERT(ne03 == ne13);
+
         if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
             ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
         }
@@ -10701,6 +10693,11 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
         GGML_ASSERT(nb10 == sizeof(float));
 
+        // TODO: handle case when src0 is broadcast-able into src1 across 2nd,3rd dimension
+        // ref: https://github.com/ggerganov/ggml/pull/224
+        GGML_ASSERT(ne02 == ne12);
+        GGML_ASSERT(ne03 == ne13);
+
         if (params->ith != 0) {
             return;
         }
@@ -10774,40 +10771,38 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     // TODO: do not support transposed src1
     assert(nb10/2 == sizeof(ggml_fp16_t));
 
-    // parallelize by src0 rows using ggml_vec_dot_f16
+    // parallelize by src0 rows
+    const int64_t dr = (ne01 + nth - 1)/nth;
 
-    // total rows in src0
-    const int nr = ne01*ne02*ne03;
+    const int64_t ir10 = dr*ith;
+    const int64_t ir11 = MIN(ir10 + dr, ne01);
 
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
+    // src1 rows
+    const int64_t nr1 = ne11*ne12*ne13;
 
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    ggml_fp16_t * wdata = params->wdata;
+    void * wdata = params->wdata;
+    const size_t row_size = ne10*GGML_TYPE_SIZE[GGML_TYPE_F16];
 
-    for (int ir = ir0; ir < ir1; ++ir) {
-        // src0 indices
-        const int i03 = ir/(ne02*ne01);
-        const int i02 = (ir - i03*ne02*ne01)/ne01;
-        const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
+    for (int64_t ir1 = 0; ir1 < nr1; ++ir1) {
+        const int64_t i13 = (ir1/(ne12*ne11));
+        const int64_t i12 = (ir1 - i13*ne12*ne11)/ne11;
+        const int64_t i11 = (ir1 - i13*ne12*ne11 - i12*ne11);
 
-        const int i13 = i03;
-        const int i12 = i02;
+        const int64_t ir0 = (ir1/ne11)%(ne02*ne03);
+        const int64_t i03 = (ir0/(ne02));
+        const int64_t i02 = (ir0 - i03*ne02);
 
-        const int i0 = i01;
-        const int i2 = i02;
-        const int i3 = i03;
+        const int64_t i1 = i11;
+        const int64_t i2 = i12;
+        const int64_t i3 = i13;
 
-        ggml_fp16_t * src0_row = (ggml_fp16_t *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
-        ggml_fp16_t * src1_col = wdata + ( 0 + i12*ne11 + i13*ne12*ne11)*ne00;
+        char * src0_row = (char *) src0->data + ( 0 + i02*nb02 + i03*nb03 );
+        char * src1_col = (char *) wdata + (i11 + i12*ne11 + i13*ne12*ne11)*row_size;
 
-        float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
+        float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
 
-        for (int64_t ic = 0; ic < ne11; ++ic) {
-            ggml_vec_dot_f16(ne00, &dst_col[ic*ne0], src0_row, src1_col + ic*ne00);
+        for (int64_t ir = ir10; ir < ir11; ++ir) {
+            ggml_vec_dot_f16(ne00, &dst_col[ir], (ggml_fp16_t *) (src0_row + ir*nb01), (ggml_fp16_t *) src1_col);
         }
     }
 
@@ -10865,16 +10860,16 @@ static void ggml_compute_forward_mul_mat_q_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    GGML_ASSERT(ne02 == ne12);
-    GGML_ASSERT(ne03 == ne13);
-    GGML_ASSERT(ne2  == ne12);
-    GGML_ASSERT(ne3  == ne13);
-
     const enum ggml_type type = src0->type;
     quantize_row_q_t const quantize_row_q_dot = quantize_fns[type].quantize_row_q_dot;
     vec_dot_q_t      const vec_dot_q          = quantize_fns[type].vec_dot_q;
     enum ggml_type   const vec_dot_type       = quantize_fns[type].vec_dot_type;
 
+    GGML_ASSERT(ne0 == ne01);
+    GGML_ASSERT(ne1 == ne11);
+    GGML_ASSERT(ne2 == ne12);
+    GGML_ASSERT(ne3 == ne13);
+
     // we don't support permuted src0 or src1
     GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
     GGML_ASSERT(nb10 == sizeof(float));
@@ -10885,16 +10880,16 @@ static void ggml_compute_forward_mul_mat_q_f32(
     GGML_ASSERT(nb1 <= nb2);
    GGML_ASSERT(nb2 <= nb3);
 
-    GGML_ASSERT(ne0 == ne01);
-    GGML_ASSERT(ne1 == ne11);
-    GGML_ASSERT(ne2 == ne02);
-    GGML_ASSERT(ne3 == ne03);
-
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
 #if defined(GGML_USE_CLBLAST)
     if (ggml_cl_can_mul_mat(src0, src1, dst)) {
+        // TODO: handle case when src0 is broadcast-able into src1 across 2nd,3rd dimension
+        // ref: https://github.com/ggerganov/ggml/pull/224
+        GGML_ASSERT(ne02 == ne12);
+        GGML_ASSERT(ne03 == ne13);
+
        if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
             ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
         }
@@ -10904,6 +10899,11 @@ static void ggml_compute_forward_mul_mat_q_f32(
 
 #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
+        // TODO: handle case when src0 is broadcast-able into src1 across 2nd,3rd dimension
+        // ref: https://github.com/ggerganov/ggml/pull/224
+        GGML_ASSERT(ne02 == ne12);
+        GGML_ASSERT(ne03 == ne13);
+
         if (params->ith != 0) {
             return;
         }
@@ -10971,43 +10971,38 @@ static void ggml_compute_forward_mul_mat_q_f32(
         return;
     }
 
-    // parallelize by src0 rows using ggml_vec_dot_q
-
-    // total rows in src0
-    const int nr = ne01*ne02*ne03;
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
+    // parallelize by src0 rows
+    const int64_t dr = (ne01 + nth - 1)/nth;
 
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
+    const int64_t ir10 = dr*ith;
+    const int64_t ir11 = MIN(ir10 + dr, ne01);
 
-    void * wdata = params->wdata;
-    const size_t row_size = ne00*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
+    // src1 rows
+    const int64_t nr1 = ne11*ne12*ne13;
 
-    for (int ir = ir0; ir < ir1; ++ir) {
-        // src0 indices
-        const int i03 = ir/(ne02*ne01);
-        const int i02 = (ir - i03*ne02*ne01)/ne01;
-        const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
+    const void * wdata = params->wdata;
+    const size_t row_size = ne10*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
 
-        const int i13 = i03;
-        const int i12 = i02;
+    for (int64_t ir1 = 0; ir1 < nr1; ++ir1) {
+        const int64_t i13 = (ir1/(ne12*ne11));
+        const int64_t i12 = (ir1 - i13*ne12*ne11)/ne11;
+        const int64_t i11 = (ir1 - i13*ne12*ne11 - i12*ne11);
 
-        const int i0 = i01;
-        const int i2 = i02;
-        const int i3 = i03;
+        const int64_t ir0 = (ir1/ne11)%(ne02*ne03);
+        const int64_t i03 = (ir0/(ne02));
+        const int64_t i02 = (ir0 - i03*ne02);
 
-        void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
-        char * src1_col = ((char *) wdata + ( (0 + i12*ne11 + i13*ne12*ne11)*row_size));
+        const int64_t i1 = i11;
+        const int64_t i2 = i12;
+        const int64_t i3 = i13;
 
-        float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
+        const char * src0_row = (const char *) src0->data + ( 0 + i02*nb02 + i03*nb03 );
+        const char * src1_col = (const char *) wdata + (i11 + i12*ne11 + i13*ne12*ne11)*row_size;
 
-        assert(ne00 % 32 == 0);
+        float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
 
-        for (int64_t ic = 0; ic < ne11; ++ic) {
-            vec_dot_q(ne00, &dst_col[ic*ne0], src0_row, (void *) (src1_col + ic*row_size));
+        for (int64_t ir = ir10; ir < ir11; ++ir) {
+            vec_dot_q(ne00, &dst_col[ir], src0_row + ir*nb01, src1_col);
        }
    }
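
Note (not part of the patch): the f32, f16 and quantized paths above all share the same index
arithmetic, so the standalone C sketch below replays it with made-up example shapes and a
hypothetical thread count. It shows how each src1 row index ir1 is unflattened into
(i11, i12, i13), how the flat src1 batch index ir1/ne11 is wrapped modulo ne02*ne03 to pick the
src0 matrix that gets broadcast, and how each thread takes a contiguous slice [ir10, ir11) of
src0 rows, exactly as the new loops do.

// illustration only: broadcast index mapping and per-thread row split from the patch
#include <stdint.h>
#include <stdio.h>

int main(void) {
    // example shapes (made up): src0 holds a single matrix, src1 holds a 2x2 batch of matrices
    const int64_t ne01 = 4;                     // src0 rows (= dst rows)
    const int64_t ne02 = 1, ne03 = 1;           // src0 batch dims (broadcast source)
    const int64_t ne11 = 3, ne12 = 2, ne13 = 2; // src1 rows and batch dims

    const int nth = 2; // pretend thread count

    // per-thread slice of src0 rows, as computed in the patch
    for (int ith = 0; ith < nth; ++ith) {
        const int64_t dr   = (ne01 + nth - 1)/nth;
        const int64_t ir10 = dr*ith;
        const int64_t ir11 = ir10 + dr < ne01 ? ir10 + dr : ne01; // MIN(ir10 + dr, ne01)
        printf("thread %d -> src0 rows [%lld, %lld)\n", ith, (long long) ir10, (long long) ir11);
    }

    // for every src1 row, recover its indices and the src0 matrix it is multiplied against
    const int64_t nr1 = ne11*ne12*ne13;
    for (int64_t ir1 = 0; ir1 < nr1; ++ir1) {
        const int64_t i13 = ir1/(ne12*ne11);
        const int64_t i12 = (ir1 - i13*ne12*ne11)/ne11;
        const int64_t i11 = (ir1 - i13*ne12*ne11 - i12*ne11);

        // flat src1 batch index wrapped onto src0's ne02 x ne03 batches
        const int64_t ir0 = (ir1/ne11)%(ne02*ne03);
        const int64_t i03 = ir0/ne02;
        const int64_t i02 = ir0 - i03*ne02;

        printf("src1 row (i11=%lld, i12=%lld, i13=%lld) -> src0 matrix (i02=%lld, i03=%lld)\n",
               (long long) i11, (long long) i12, (long long) i13,
               (long long) i02, (long long) i03);
    }

    return 0;
}

With ne02 = ne03 = 1 the wrap always selects matrix (0, 0), which is the common broadcast case the
new ggml_can_mul_mat check (t1->ne[2]%t0->ne[2] == 0 && t1->ne[3]%t0->ne[3] == 0) is meant to allow.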