diff --git a/include/ggml-backend.h b/include/ggml-backend.h index 5f3f1e286..e73b9a745 100644 --- a/include/ggml-backend.h +++ b/include/ggml-backend.h @@ -63,6 +63,7 @@ extern "C" { GGML_API void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); GGML_API void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + // "offset" refers to the offset of the tensor data for setting/getting data GGML_API GGML_CALL void ggml_backend_tensor_set( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); GGML_API GGML_CALL void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); diff --git a/src/ggml-cuda.cu b/src/ggml-cuda.cu index 56c16a3c4..8a844b02a 100644 --- a/src/ggml-cuda.cu +++ b/src/ggml-cuda.cu @@ -9,8 +9,10 @@ #include "ggml-cuda/binbcast.cuh" #include "ggml-cuda/clamp.cuh" #include "ggml-cuda/concat.cuh" +#include "ggml-cuda/conv-transpose-1d.cuh" #include "ggml-cuda/convert.cuh" #include "ggml-cuda/cpy.cuh" +#include "ggml-cuda/cross-entropy-loss.cuh" #include "ggml-cuda/diagmask.cuh" #include "ggml-cuda/dmmv.cuh" #include "ggml-cuda/fattn.cuh" @@ -29,7 +31,6 @@ #include "ggml-cuda/tsembd.cuh" #include "ggml-cuda/unary.cuh" #include "ggml-cuda/upscale.cuh" -#include "ggml-cuda/conv-transpose-1d.cuh" #include #include @@ -2312,6 +2313,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_FLASH_ATTN_EXT: ggml_cuda_flash_attn_ext(ctx, dst); break; + case GGML_OP_CROSS_ENTROPY_LOSS: + ggml_cuda_cross_entropy_loss(ctx, dst); + break; default: return false; } @@ -2619,6 +2623,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device)); for (int j = 0; j < GGML_MAX_SRC; j++) { if (node->src[j] != nullptr) { + assert(node->src[j]->buffer); assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) || ggml_backend_buffer_is_cuda_split(node->src[j]->buffer)); } } @@ -2902,6 +2907,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons } return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16; + case GGML_OP_CROSS_ENTROPY_LOSS: + return true; #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) default: return false; diff --git a/src/ggml-cuda/cross-entropy-loss.cu b/src/ggml-cuda/cross-entropy-loss.cu new file mode 100644 index 000000000..a14043e70 --- /dev/null +++ b/src/ggml-cuda/cross-entropy-loss.cu @@ -0,0 +1,106 @@ +#include "common.cuh" +#include "cross-entropy-loss.cuh" +#include "sumrows.cuh" + +#include +#include + +static __global__ void cross_entropy_loss_f32(const float * logits, const float * labels, float * dst, const int nclasses, const int k) { + const int warp_id = threadIdx.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; + const int i0 = blockDim.x*blockIdx.x + warp_id*WARP_SIZE; + + const int ne_tmp = WARP_SIZE*nclasses; + + extern __shared__ float tmp_all[]; + float * tmp_logits = tmp_all + (2*warp_id + 0)*ne_tmp; + float * tmp_labels = tmp_all + (2*warp_id + 1)*ne_tmp; + + // Each warp first loads ne_tmp logits/labels into shared memory: + for (int i = lane_id; i < ne_tmp; i += WARP_SIZE) { + const int ig = i0*nclasses + i; // ig == i global + + tmp_logits[i] = ig < k*nclasses ? logits[ig] : 0.0f; + tmp_labels[i] = ig < k*nclasses ? labels[ig] : 0.0f; + } + + // Each thread in the warp then calculates the cross entropy loss for a single row. + // TODO: pad in order to avoid shared memory bank conflicts. + + // Find maximum for softmax: + float max = -INFINITY; + for (int i = 0; i < nclasses; ++i) { + max = fmaxf(max, tmp_logits[lane_id*nclasses + i]); + } + + // Calculate log(softmax(logits)) which is just logits - max: + float sum = 0.0f; + for (int i = 0; i < nclasses; ++i) { + float val = tmp_logits[lane_id*nclasses + i] - max; + sum += expf(val); + tmp_logits[lane_id*nclasses + i] = val; + } + sum = logf(sum); + + // log(exp(logits - max) / sum) = (logits - max) - log(sum) + float loss = 0.0f; + for (int i = 0; i < nclasses; ++i) { + loss += (tmp_logits[lane_id*nclasses + i] - sum) * tmp_labels[lane_id*nclasses + i]; + } + loss = -warp_reduce_sum(loss) / (float)k; + + __syncthreads(); + + if (lane_id == 0) { + tmp_all[warp_id] = loss; + } + + __syncthreads(); + + if (warp_id != 0) { + return; + } + + loss = lane_id < CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE/WARP_SIZE ? tmp_all[lane_id] : 0.0f; + loss = warp_reduce_sum(loss); + + if (lane_id != 0) { + return; + } + + dst[blockIdx.x] = loss; +} + +void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + GGML_ASSERT(ggml_is_contiguous(dst)); + + const int64_t ne00 = src0->ne[0]; + const int64_t nrows = ggml_nrows(src0); + + const float * src0_d = (const float *) src0->data; + const float * src1_d = (const float *) src1->data; + float * dst_d = (float *) dst->data; + + ggml_cuda_pool & pool = ctx.pool(); + cudaStream_t stream = ctx.stream(); + + const dim3 blocks_dim(CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE, 1, 1); + const dim3 blocks_num((nrows + CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE - 1) / CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE, 1, 1); + const int shmem = 2*CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE*ne00*sizeof(float); + + ggml_cuda_pool_alloc dst_tmp(pool, blocks_num.x); + + cross_entropy_loss_f32<<>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows); + + // Combine results from individual blocks: + sum_rows_f32_cuda(dst_tmp.ptr, dst_d, blocks_num.x, 1, stream); +} diff --git a/src/ggml-cuda/cross-entropy-loss.cuh b/src/ggml-cuda/cross-entropy-loss.cuh new file mode 100644 index 000000000..9d7b8b0f0 --- /dev/null +++ b/src/ggml-cuda/cross-entropy-loss.cuh @@ -0,0 +1,5 @@ +#include "common.cuh" + +#define CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE 256 + +void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/src/ggml-cuda/sumrows.cu b/src/ggml-cuda/sumrows.cu index 82e8e875f..38dbf1b5e 100644 --- a/src/ggml-cuda/sumrows.cu +++ b/src/ggml-cuda/sumrows.cu @@ -16,7 +16,7 @@ static __global__ void k_sum_rows_f32(const float * x, float * dst, const int nc } } -static void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) { +void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) { const dim3 block_dims(WARP_SIZE, 1, 1); const dim3 block_nums(nrows, 1, 1); k_sum_rows_f32<<>>(x, dst, ncols); @@ -32,7 +32,6 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { GGML_ASSERT( dst->type == GGML_TYPE_F32); GGML_ASSERT(ggml_is_contiguous(src0)); - const int64_t ncols = src0->ne[0]; const int64_t nrows = ggml_nrows(src0); diff --git a/src/ggml-cuda/sumrows.cuh b/src/ggml-cuda/sumrows.cuh index e7545f83c..191db1c13 100644 --- a/src/ggml-cuda/sumrows.cuh +++ b/src/ggml-cuda/sumrows.cuh @@ -1,3 +1,5 @@ #include "common.cuh" +void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream); + void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/src/ggml.c b/src/ggml.c index 07d9d5081..811aa0bfb 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -2671,6 +2671,19 @@ static ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, return sum; } +static ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max) { + // log(soft_max) = log(soft_max_i / soft_max_sum) = log(soft_max_i) - log(soft_max_sum) = (logit_i - max) - log(soft_max_i) + + int i = 0; + ggml_float sum = 0; + for (; i < n; ++i) { + float val = x[i] - max; + y[i] = val; + sum += (ggml_float)expf(val); + } + return sum = (ggml_float)logf(sum); +} + inline static float ggml_silu_backward_f32(float x, float dy) { const float s = 1.0f/(1.0f + expf(-x)); return dy*s*(1.0f + x*(1.0f - s)); @@ -17023,8 +17036,6 @@ static void ggml_compute_forward_cross_entropy_loss_f32( } ggml_barrier(params->shared); - const double eps = 1e-9; - // rows per thread const int dr = (nr + nth - 1)/nth; @@ -17045,20 +17056,15 @@ static void ggml_compute_forward_cross_entropy_loss_f32( } #endif - // soft_max float max = -INFINITY; ggml_vec_max_f32(nc, &max, s0); - ggml_float sum = ggml_vec_soft_max_f32(nc, st, s0, max); - assert(sum > 0.0); - sum = (1.0 - eps) / sum; + ggml_float sum = ggml_vec_log_soft_max_f32(nc, st, s0, max); + assert(sum >= 0.0); - // avoid log(0) by rescaling from [0..1] to [eps..1] - ggml_vec_scale_f32(nc, st, sum); - ggml_vec_add1_f32(nc, st, st, eps); - ggml_vec_log_f32(nc, st, st); + ggml_vec_add1_f32(nc, st, st, -sum); ggml_vec_mul_f32(nc, st, st, s1); - float st_sum = 0; + float st_sum = 0.0f; ggml_vec_sum_f32(nc, &st_sum, st); sums[ith] += st_sum; @@ -17115,8 +17121,6 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32( const int64_t ith = params->ith; const int64_t nth = params->nth; - const double eps = 1e-9; - // TODO: handle transposed/permuted matrices const int64_t nc = src0->ne[0]; const int64_t nr = ggml_nrows(src0); @@ -17148,11 +17152,9 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32( ggml_vec_max_f32(nc, &max, s0); ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max); assert(sum > 0.0); - sum = (1.0 - eps) / sum; + ggml_vec_scale_f32(nc, ds0, 1.0/sum); // grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr - ggml_vec_scale_f32(nc, ds0, sum); - ggml_vec_add1_f32(nc, ds0, ds0, eps); ggml_vec_sub_f32(nc, ds0, ds0, s1); ggml_vec_scale_f32(nc, ds0, d[0] / (float) nr); @@ -20288,6 +20290,7 @@ static enum ggml_opt_result ggml_opt_adam( ggml_opt_callback callback, void * callback_data) { GGML_ASSERT(ggml_is_scalar(f)); + GGML_ASSERT(f->type == GGML_TYPE_F32); // these will store the parameters we want to optimize struct ggml_tensor * ps[GGML_MAX_PARAMS]; diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 01702b109..8f60854b2 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1730,6 +1730,27 @@ struct test_flash_attn_ext : public test_case { } }; +// GGML_OP_CROSS_ENTROPY_LOSS +struct test_cross_entropy_loss : public test_case { + const ggml_type type; + const std::array ne; + + std::string vars() override { + return VARS_TO_STR2(type, ne); + } + + test_cross_entropy_loss(ggml_type type = GGML_TYPE_F32, + std::array ne = {10, 10, 10, 10}) + : type(type), ne(ne) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * logits = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_tensor * labels = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_tensor * out = ggml_cross_entropy_loss(ctx, logits, labels); + return out; + } +}; + enum llm_norm_type { LLM_NORM, LLM_NORM_RMS, @@ -2491,6 +2512,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op } } + test_cases.emplace_back(new test_cross_entropy_loss()); + // these tests are disabled to save execution time, but they can be handy for debugging #if 0 test_cases.emplace_back(new test_llama(1)); diff --git a/tests/test-grad0.cpp b/tests/test-grad0.cpp index 2221fa2d5..1834c11d8 100644 --- a/tests/test-grad0.cpp +++ b/tests/test-grad0.cpp @@ -1444,7 +1444,7 @@ int main(int argc, const char ** argv) { get_random_dims(ne2, 4); for (int ndims = 1; ndims <= 4; ++ndims) { - x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -0.1f, 0.1f); + x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f); x[1] = get_random_tensor_f32(ctx0, ndims, ne2, 0.0f, 1.0f); // the second argument to cross_entropy_loss must sum up to 1 for each row int nr = ggml_nrows(x[1]); @@ -1462,7 +1462,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_cross_entropy_loss(ctx0, x[0], x[1]); - check_gradient("cross_entropy_loss", ctx0, x, f, ndims, nargs, 1e-4f, 1e-3f, INFINITY, {}); + check_gradient("cross_entropy_loss", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); } }