From b654173e3125b8523d40dac4593cdbb48b46f611 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Fri, 2 Jun 2023 16:03:33 +0300
Subject: [PATCH] ggml : make src0 broadcast-able into src1 for ggml_mul_mat

---
 src/ggml.c | 272 +++++++++++++++++++++++++++--------------------------
 1 file changed, 141 insertions(+), 131 deletions(-)

diff --git a/src/ggml.c b/src/ggml.c
index 18a2007b7..349a18584 100644
--- a/src/ggml.c
+++ b/src/ggml.c
@@ -3785,10 +3785,9 @@ static inline bool ggml_is_matrix(const struct ggml_tensor * tensor) {
 static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
-    return
-        (t0->ne[0] == t1->ne[0]) &&
-        (t0->ne[2] == t1->ne[2]) &&
-        (t0->ne[3] == t1->ne[3]);
+    return (t0->ne[0] == t1->ne[0]) &&
+           (t1->ne[2]%t0->ne[2] == 0) && // verify t0 is broadcastable
+           (t1->ne[3]%t0->ne[3] == 0);
 }
 
 bool ggml_is_quantized(enum ggml_type type) {
@@ -5432,8 +5431,8 @@ struct ggml_tensor * ggml_mul_mat(
         is_node = true;
     }
 
-    const int64_t ne[4] = { a->ne[1], b->ne[1], a->ne[2], b->ne[3] };
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MIN(a->n_dims, b->n_dims), ne);
+    const int64_t ne[4] = { a->ne[1], b->ne[1], b->ne[2], b->ne[3] };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MAX(a->n_dims, b->n_dims), ne);
 
     result->op   = GGML_OP_MUL_MAT;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -9628,11 +9627,8 @@ static void ggml_compute_forward_mul_mat_f32(
     const int64_t ne02 = src0->ne[2];
     const int64_t ne03 = src0->ne[3];
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
     const int64_t ne10 = src1->ne[0];
-#endif
     const int64_t ne11 = src1->ne[1];
-#ifndef NDEBUG
     const int64_t ne12 = src1->ne[2];
     const int64_t ne13 = src1->ne[3];
 
     const int64_t ne0 = dst->ne[0];
     const int64_t ne1 = dst->ne[1];
     const int64_t ne2 = dst->ne[2];
     const int64_t ne3 = dst->ne[3];
 
     const int nb00 = src0->nb[0];
-#endif
     const int nb01 = src0->nb[1];
     const int nb02 = src0->nb[2];
     const int nb03 = src0->nb[3];
 
-#ifndef NDEBUG
     const int nb10 = src1->nb[0];
-#endif
-    const int nb11 = src1->nb[1];
-    const int nb12 = src1->nb[2];
-    const int nb13 = src1->nb[3];
+    const int nb11 = src1->nb[1]; UNUSED(nb11);
+    const int nb12 = src1->nb[2]; UNUSED(nb12);
+    const int nb13 = src1->nb[3]; UNUSED(nb13);
 
     const int nb0 = dst->nb[0];
     const int nb1 = dst->nb[1];
@@ -9662,31 +9655,31 @@ static void ggml_compute_forward_mul_mat_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    assert(ne02 == ne12);
-    assert(ne03 == ne13);
-    assert(ne2 == ne12);
-    assert(ne3 == ne13);
+    GGML_ASSERT(ne0 == ne01);
+    GGML_ASSERT(ne1 == ne11);
+    GGML_ASSERT(ne2 == ne12);
+    GGML_ASSERT(ne3 == ne13);
 
     // we don't support permuted src0 or src1
-    assert(nb00 == sizeof(float));
-    assert(nb10 == sizeof(float));
+    GGML_ASSERT(nb00 == sizeof(float));
+    GGML_ASSERT(nb10 == sizeof(float));
 
     // dst cannot be transposed or permuted
-    assert(nb0 == sizeof(float));
-    assert(nb0 <= nb1);
-    assert(nb1 <= nb2);
-    assert(nb2 <= nb3);
-
-    assert(ne0 == ne01);
-    assert(ne1 == ne11);
-    assert(ne2 == ne02);
-    assert(ne3 == ne03);
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
 
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
 #if defined(GGML_USE_CUBLAS)
     if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
+        // TODO: handle case when src0 is broadcast-able into src1 across 2nd,3rd dimension
+        //       ref: https://github.com/ggerganov/ggml/pull/224
+        GGML_ASSERT(ne02 == ne12);
+        GGML_ASSERT(ne03 == ne13);
+
         if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
             ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
         }
@@ -9694,6 +9687,11 @@ static void ggml_compute_forward_mul_mat_f32(
     }
 #elif defined(GGML_USE_CLBLAST)
     if (ggml_cl_can_mul_mat(src0, src1, dst)) {
+        // TODO: handle case when src0 is broadcast-able into src1 across 2nd,3rd dimension
+        //       ref: https://github.com/ggerganov/ggml/pull/224
+        GGML_ASSERT(ne02 == ne12);
+        GGML_ASSERT(ne03 == ne13);
+
         if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
             ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
         }
@@ -9703,6 +9701,11 @@ static void ggml_compute_forward_mul_mat_f32(
 
 #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
+        // TODO: handle case when src0 is broadcast-able into src1 across 2nd,3rd dimension
+        //       ref: https://github.com/ggerganov/ggml/pull/224
+        GGML_ASSERT(ne02 == ne12);
+        GGML_ASSERT(ne03 == ne13);
+
         if (params->ith != 0) {
             return;
         }
@@ -9742,40 +9745,35 @@ static void ggml_compute_forward_mul_mat_f32(
         return;
     }
 
-    // parallelize by src0 rows using ggml_vec_dot_f32
+    // parallelize by src0 rows
+    const int64_t dr = (ne01 + nth - 1)/nth;
 
-    // total rows in src0
-    const int nr = ne01*ne02*ne03;
+    const int64_t ir10 = dr*ith;
+    const int64_t ir11 = MIN(ir10 + dr, ne01);
 
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
+    // src1 rows
+    const int64_t nr1 = ne11*ne12*ne13;
 
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
+    for (int64_t ir1 = 0; ir1 < nr1; ++ir1) {
+        const int64_t i13 = (ir1/(ne12*ne11));
+        const int64_t i12 = (ir1 - i13*ne12*ne11)/ne11;
+        const int64_t i11 = (ir1 - i13*ne12*ne11 - i12*ne11);
 
-    for (int ir = ir0; ir < ir1; ++ir) {
-        // src0 indices
-        const int i03 = ir/(ne02*ne01);
-        const int i02 = (ir - i03*ne02*ne01)/ne01;
-        const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
+        const int64_t ir0 = (ir1/ne11)%(ne02*ne03);
+        const int64_t i03 = (ir0/(ne02));
+        const int64_t i02 = (ir0 - i03*ne02);
 
-        for (int64_t ic = 0; ic < ne11; ++ic) {
-            // src1 indices
-            const int i13 = i03;
-            const int i12 = i02;
-            const int i11 = ic;
+        const int64_t i1 = i11;
+        const int64_t i2 = i12;
+        const int64_t i3 = i13;
 
-            // dst indices
-            const int i0 = i01;
-            const int i1 = i11;
-            const int i2 = i02;
-            const int i3 = i03;
-
-            ggml_vec_dot_f32(ne00,
-                    (float *) ((char *) dst->data + (i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
-                    (float *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)),
-                    (float *) ((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13)));
+        char * src0_row = (char *) src0->data + ( 0 + i02*nb02 + i03*nb03);
+        char * src1_col = (char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13);
+
+        float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
+
+        for (int64_t ir = ir10; ir < ir11; ++ir) {
+            ggml_vec_dot_f32(ne00, &dst_col[ir], (float *) (src0_row + ir*nb01), (float *) src1_col);
         }
     }
 
@@ -9815,7 +9813,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     const int64_t ne1 = dst->ne[1];
     const int64_t ne2 = dst->ne[2];
     const int64_t ne3 = dst->ne[3];
-    //const int64_t ne = ne0*ne1*ne2*ne3;
 
     const int nb00 = src0->nb[0];
     const int nb01 = src0->nb[1];
@@ -9835,10 +9832,10 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    GGML_ASSERT(ne02 == ne12);
-    GGML_ASSERT(ne03 == ne13);
-    GGML_ASSERT(ne2 == ne12);
-    GGML_ASSERT(ne3 == ne13);
+    GGML_ASSERT(ne0 == ne01);
+    GGML_ASSERT(ne1 == ne11);
+    GGML_ASSERT(ne2 == ne12);
+    GGML_ASSERT(ne3 == ne13);
 
     // TODO: we don't support permuted src0
     GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
@@ -9849,16 +9846,16 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     GGML_ASSERT(nb1 <= nb2);
     GGML_ASSERT(nb2 <= nb3);
 
-    GGML_ASSERT(ne0 == ne01);
-    GGML_ASSERT(ne1 == ne11);
-    GGML_ASSERT(ne2 == ne02);
-    GGML_ASSERT(ne3 == ne03);
-
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
 #if defined(GGML_USE_CUBLAS)
     if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
+        // TODO: handle case when src0 is broadcast-able into src1 across 2nd,3rd dimension
+        //       ref: https://github.com/ggerganov/ggml/pull/224
+        GGML_ASSERT(ne02 == ne12);
+        GGML_ASSERT(ne03 == ne13);
+
         if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
             ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
         }
@@ -9866,6 +9863,11 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     }
 #elif defined(GGML_USE_CLBLAST)
     if (ggml_cl_can_mul_mat(src0, src1, dst)) {
+        // TODO: handle case when src0 is broadcast-able into src1 across 2nd,3rd dimension
+        //       ref: https://github.com/ggerganov/ggml/pull/224
+        GGML_ASSERT(ne02 == ne12);
+        GGML_ASSERT(ne03 == ne13);
+
         if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
             ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
         }
@@ -9877,6 +9879,11 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
         GGML_ASSERT(nb10 == sizeof(float));
 
+        // TODO: handle case when src0 is broadcast-able into src1 across 2nd,3rd dimension
+        //       ref: https://github.com/ggerganov/ggml/pull/224
+        GGML_ASSERT(ne02 == ne12);
+        GGML_ASSERT(ne03 == ne13);
+
         if (params->ith != 0) {
             return;
         }
@@ -9950,40 +9957,38 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     // TODO: do not support transposed src1
    assert(nb10/2 == sizeof(ggml_fp16_t));
 
-    // parallelize by src0 rows using ggml_vec_dot_f16
-
-    // total rows in src0
-    const int nr = ne01*ne02*ne03;
+    // parallelize by src0 rows
+    const int64_t dr = (ne01 + nth - 1)/nth;
 
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
+    const int64_t ir10 = dr*ith;
+    const int64_t ir11 = MIN(ir10 + dr, ne01);
 
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
+    // src1 rows
+    const int64_t nr1 = ne11*ne12*ne13;
 
-    ggml_fp16_t * wdata = params->wdata;
+    void * wdata = params->wdata;
+    const size_t row_size = ne10*GGML_TYPE_SIZE[GGML_TYPE_F16];
 
-    for (int ir = ir0; ir < ir1; ++ir) {
-        // src0 indices
-        const int i03 = ir/(ne02*ne01);
-        const int i02 = (ir - i03*ne02*ne01)/ne01;
-        const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
+    for (int64_t ir1 = 0; ir1 < nr1; ++ir1) {
+        const int64_t i13 = (ir1/(ne12*ne11));
+        const int64_t i12 = (ir1 - i13*ne12*ne11)/ne11;
+        const int64_t i11 = (ir1 - i13*ne12*ne11 - i12*ne11);
 
-        const int i13 = i03;
-        const int i12 = i02;
+        const int64_t ir0 = (ir1/ne11)%(ne02*ne03);
+        const int64_t i03 = (ir0/(ne02));
+        const int64_t i02 = (ir0 - i03*ne02);
 
-        const int i0 = i01;
-        const int i2 = i02;
-        const int i3 = i03;
+        const int64_t i1 = i11;
+        const int64_t i2 = i12;
+        const int64_t i3 = i13;
 
-        ggml_fp16_t * src0_row = (ggml_fp16_t *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
-        ggml_fp16_t * src1_col = wdata + ( 0 + i12*ne11 + i13*ne12*ne11)*ne00;
+        char * src0_row = (char *) src0->data + ( 0 + i02*nb02 + i03*nb03 );
+        char * src1_col = (char *) wdata + (i11 + i12*ne11 + i13*ne12*ne11)*row_size;
 
-        float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
+        float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
 
-        for (int64_t ic = 0; ic < ne11; ++ic) {
-            ggml_vec_dot_f16(ne00, &dst_col[ic*ne0], src0_row, src1_col + ic*ne00);
+        for (int64_t ir = ir10; ir < ir11; ++ir) {
+            ggml_vec_dot_f16(ne00, &dst_col[ir], (ggml_fp16_t *) (src0_row + ir*nb01), (ggml_fp16_t *) src1_col);
        }
     }
 
@@ -10041,16 +10046,16 @@ static void ggml_compute_forward_mul_mat_q_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    GGML_ASSERT(ne02 == ne12);
-    GGML_ASSERT(ne03 == ne13);
-    GGML_ASSERT(ne2 == ne12);
-    GGML_ASSERT(ne3 == ne13);
-
     const enum ggml_type type = src0->type;
     quantize_row_q_t const quantize_row_q_dot = quantize_fns[type].quantize_row_q_dot;
     vec_dot_q_t const vec_dot_q = quantize_fns[type].vec_dot_q;
     enum ggml_type const vec_dot_type = quantize_fns[type].vec_dot_type;
 
+    GGML_ASSERT(ne0 == ne01);
+    GGML_ASSERT(ne1 == ne11);
+    GGML_ASSERT(ne2 == ne12);
+    GGML_ASSERT(ne3 == ne13);
+
     // we don't support permuted src0 or src1
     GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
     GGML_ASSERT(nb10 == sizeof(float));
@@ -10061,16 +10066,16 @@ static void ggml_compute_forward_mul_mat_q_f32(
     GGML_ASSERT(nb1 <= nb2);
     GGML_ASSERT(nb2 <= nb3);
 
-    GGML_ASSERT(ne0 == ne01);
-    GGML_ASSERT(ne1 == ne11);
-    GGML_ASSERT(ne2 == ne02);
-    GGML_ASSERT(ne3 == ne03);
-
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
 #if defined(GGML_USE_CUBLAS)
     if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
+        // TODO: handle case when src0 is broadcast-able into src1 across 2nd,3rd dimension
+        //       ref: https://github.com/ggerganov/ggml/pull/224
+        GGML_ASSERT(ne02 == ne12);
+        GGML_ASSERT(ne03 == ne13);
+
         if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
             ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
         }
@@ -10078,6 +10083,11 @@ static void ggml_compute_forward_mul_mat_q_f32(
     }
 #elif defined(GGML_USE_CLBLAST)
     if (ggml_cl_can_mul_mat(src0, src1, dst)) {
+        // TODO: handle case when src0 is broadcast-able into src1 across 2nd,3rd dimension
+        //       ref: https://github.com/ggerganov/ggml/pull/224
+        GGML_ASSERT(ne02 == ne12);
+        GGML_ASSERT(ne03 == ne13);
+
         if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
             ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
         }
@@ -10087,6 +10097,11 @@ static void ggml_compute_forward_mul_mat_q_f32(
 
 #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
+        // TODO: handle case when src0 is broadcast-able into src1 across 2nd,3rd dimension
+        //       ref: https://github.com/ggerganov/ggml/pull/224
+        GGML_ASSERT(ne02 == ne12);
+        GGML_ASSERT(ne03 == ne13);
+
         if (params->ith != 0) {
             return;
         }
@@ -10154,43 +10169,38 @@ static void ggml_compute_forward_mul_mat_q_f32(
         return;
     }
 
-    // parallelize by src0 rows using ggml_vec_dot_q
+    // parallelize by src0 rows
+    const int64_t dr = (ne01 + nth - 1)/nth;
 
-    // total rows in src0
-    const int nr = ne01*ne02*ne03;
+    const int64_t ir10 = dr*ith;
+    const int64_t ir11 = MIN(ir10 + dr, ne01);
 
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
+    // src1 rows
+    const int64_t nr1 = ne11*ne12*ne13;
 
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
+    const void * wdata = params->wdata;
+    const size_t row_size = ne10*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
 
-    void * wdata = params->wdata;
-    const size_t row_size = ne00*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
+    for (int64_t ir1 = 0; ir1 < nr1; ++ir1) {
+        const int64_t i13 = (ir1/(ne12*ne11));
+        const int64_t i12 = (ir1 - i13*ne12*ne11)/ne11;
+        const int64_t i11 = (ir1 - i13*ne12*ne11 - i12*ne11);
 
-    for (int ir = ir0; ir < ir1; ++ir) {
-        // src0 indices
-        const int i03 = ir/(ne02*ne01);
-        const int i02 = (ir - i03*ne02*ne01)/ne01;
-        const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
+        const int64_t ir0 = (ir1/ne11)%(ne02*ne03);
+        const int64_t i03 = (ir0/(ne02));
+        const int64_t i02 = (ir0 - i03*ne02);
 
-        const int i13 = i03;
-        const int i12 = i02;
+        const int64_t i1 = i11;
+        const int64_t i2 = i12;
+        const int64_t i3 = i13;
 
-        const int i0 = i01;
-        const int i2 = i02;
-        const int i3 = i03;
-
-        void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
-        char * src1_col = ((char *) wdata + ( (0 + i12*ne11 + i13*ne12*ne11)*row_size));
+        const char * src0_row = (const char *) src0->data + ( 0 + i02*nb02 + i03*nb03 );
+        const char * src1_col = (const char *) wdata + (i11 + i12*ne11 + i13*ne12*ne11)*row_size;
 
-        float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
-
-        assert(ne00 % 32 == 0);
+        float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
 
-        for (int64_t ic = 0; ic < ne11; ++ic) {
-            vec_dot_q(ne00, &dst_col[ic*ne0], src0_row, (void *) (src1_col + ic*row_size));
+        for (int64_t ir = ir10; ir < ir11; ++ir) {
+            vec_dot_q(ne00, &dst_col[ir], src0_row + ir*nb01, src1_col);
         }
     }
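
Notes (editor's addition, not part of the patch). The core of the change is the new loop structure in the three CPU kernels: instead of walking src0 rows and assuming src1 has identical 2nd/3rd dimensions, each kernel now walks src1 rows and folds the combined src1 batch index back into src0's (possibly smaller) batch extent via (ir1/ne11)%(ne02*ne03). When ne02 == ne12 and ne03 == ne13, this reduces to i02 == i12 and i03 == i13, i.e. the old behavior. A minimal standalone C sketch of just that index mapping follows; the shapes (ne02 = 2, ne12 = 4, etc.) are made up for illustration:

// index_mapping.c - standalone sketch of the broadcast mapping used by the patched kernels
#include <stdint.h>
#include <stdio.h>

int main(void) {
    const int64_t ne02 = 2, ne03 = 1;           // src0 batch dims (2 matrices)
    const int64_t ne11 = 3, ne12 = 4, ne13 = 1; // src1 rows per matrix and batch dims

    const int64_t nr1 = ne11*ne12*ne13; // total src1 rows, as in the patch

    for (int64_t ir1 = 0; ir1 < nr1; ++ir1) {
        // src1 indices - same arithmetic as in the patched kernels
        const int64_t i13 = (ir1/(ne12*ne11));
        const int64_t i12 = (ir1 - i13*ne12*ne11)/ne11;
        const int64_t i11 = (ir1 - i13*ne12*ne11 - i12*ne11);

        // src0 batch indices - the combined src1 batch index wraps around the
        // smaller src0 batch extent, which is what makes src0 broadcast-able
        const int64_t ir0 = (ir1/ne11)%(ne02*ne03);
        const int64_t i03 = (ir0/(ne02));
        const int64_t i02 = (ir0 - i03*ne02);

        printf("src1 row (i11=%lld, i12=%lld, i13=%lld) -> src0 matrix (i02=%lld, i03=%lld)\n",
               (long long) i11, (long long) i12, (long long) i13,
               (long long) i02, (long long) i03);
    }

    return 0;
}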
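At the API level, ggml_mul_mat now takes the result's 3rd/4th dimensions from b instead of requiring them to match a, so one weight matrix can be multiplied against a whole batch of activations. A hedged usage sketch, assuming the ggml API of this period (ggml_build_forward/ggml_graph_compute); the arena size and tensor shapes are arbitrary:

// broadcast_mul_mat.c - what the patch enables at the user level
#include "ggml.h"

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024, // scratch arena for tensors + graph
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // a: one 4x3 matrix (src0); b: a batch of eight 4x2 matrices (src1)
    // before this patch ggml_can_mul_mat required a->ne[2] == b->ne[2];
    // now it only requires b->ne[2] % a->ne[2] == 0 (same for ne[3])
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 3);
    struct ggml_tensor * b = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 4, 2, 8);

    // result shape is now { a->ne[1], b->ne[1], b->ne[2], b->ne[3] } = [3, 2, 8]
    struct ggml_tensor * c = ggml_mul_mat(ctx, a, b);

    struct ggml_cgraph gf = ggml_build_forward(c);
    ggml_graph_compute(ctx, &gf);

    ggml_free(ctx);
    return 0;
}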