feat: ref. cross entropy, add CUDA, fix grad test
JohannesGaessler committed Aug 27, 2024
1 parent 879dcb8 commit c5fb49b
Showing 9 changed files with 167 additions and 21 deletions.
1 change: 1 addition & 0 deletions include/ggml-backend.h
@@ -63,6 +63,7 @@ extern "C" {
GGML_API void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
GGML_API void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);

// "offset" refers to the offset of the tensor data for setting/getting data
GGML_API GGML_CALL void ggml_backend_tensor_set( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
GGML_API GGML_CALL void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);

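As a usage note (not part of this commit), a minimal sketch of how the "offset" parameter documented above is meant to be used; the tensor t, its context, and its backend buffer are assumed to already exist:

// Hypothetical example: write row 1 of a 2D f32 tensor "t" created with
// ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 8) and allocated in a backend buffer.
// "offset" is a byte offset into the tensor's data; nb[1] is the row stride in bytes.
float row[4] = {1.0f, 2.0f, 3.0f, 4.0f};
const size_t offset = 1*t->nb[1];
ggml_backend_tensor_set(t, row, offset, sizeof(row));

// Reading the data back uses the same offset convention:
float check[4];
ggml_backend_tensor_get(t, check, offset, sizeof(check));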
9 changes: 8 additions & 1 deletion src/ggml-cuda.cu
@@ -9,8 +9,10 @@
#include "ggml-cuda/binbcast.cuh"
#include "ggml-cuda/clamp.cuh"
#include "ggml-cuda/concat.cuh"
#include "ggml-cuda/conv-transpose-1d.cuh"
#include "ggml-cuda/convert.cuh"
#include "ggml-cuda/cpy.cuh"
#include "ggml-cuda/cross-entropy-loss.cuh"
#include "ggml-cuda/diagmask.cuh"
#include "ggml-cuda/dmmv.cuh"
#include "ggml-cuda/fattn.cuh"
@@ -29,7 +31,6 @@
#include "ggml-cuda/tsembd.cuh"
#include "ggml-cuda/unary.cuh"
#include "ggml-cuda/upscale.cuh"
#include "ggml-cuda/conv-transpose-1d.cuh"

#include <algorithm>
#include <array>
@@ -2312,6 +2313,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_FLASH_ATTN_EXT:
ggml_cuda_flash_attn_ext(ctx, dst);
break;
case GGML_OP_CROSS_ENTROPY_LOSS:
ggml_cuda_cross_entropy_loss(ctx, dst);
break;
default:
return false;
}
@@ -2619,6 +2623,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
for (int j = 0; j < GGML_MAX_SRC; j++) {
if (node->src[j] != nullptr) {
assert(node->src[j]->buffer);
assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) || ggml_backend_buffer_is_cuda_split(node->src[j]->buffer));
}
}
@@ -2902,6 +2907,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
}
return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA &&
op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
case GGML_OP_CROSS_ENTROPY_LOSS:
return true;
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
default:
return false;
106 changes: 106 additions & 0 deletions src/ggml-cuda/cross-entropy-loss.cu
@@ -0,0 +1,106 @@
#include "common.cuh"
#include "cross-entropy-loss.cuh"
#include "sumrows.cuh"

#include <cmath>
#include <cstdint>

static __global__ void cross_entropy_loss_f32(const float * logits, const float * labels, float * dst, const int nclasses, const int k) {
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
const int i0 = blockDim.x*blockIdx.x + warp_id*WARP_SIZE;

const int ne_tmp = WARP_SIZE*nclasses;

extern __shared__ float tmp_all[];
float * tmp_logits = tmp_all + (2*warp_id + 0)*ne_tmp;
float * tmp_labels = tmp_all + (2*warp_id + 1)*ne_tmp;

// Each warp first loads ne_tmp logits/labels into shared memory:
for (int i = lane_id; i < ne_tmp; i += WARP_SIZE) {
const int ig = i0*nclasses + i; // ig == i global

tmp_logits[i] = ig < k*nclasses ? logits[ig] : 0.0f;
tmp_labels[i] = ig < k*nclasses ? labels[ig] : 0.0f;
}

// Each thread in the warp then calculates the cross entropy loss for a single row.
// TODO: pad in order to avoid shared memory bank conflicts.

// Find maximum for softmax:
float max = -INFINITY;
for (int i = 0; i < nclasses; ++i) {
max = fmaxf(max, tmp_logits[lane_id*nclasses + i]);
}

// Calculate log(softmax(logits)) which is just logits - max:
float sum = 0.0f;
for (int i = 0; i < nclasses; ++i) {
float val = tmp_logits[lane_id*nclasses + i] - max;
sum += expf(val);
tmp_logits[lane_id*nclasses + i] = val;
}
sum = logf(sum);

// log(exp(logits - max) / sum) = (logits - max) - log(sum)
float loss = 0.0f;
for (int i = 0; i < nclasses; ++i) {
loss += (tmp_logits[lane_id*nclasses + i] - sum) * tmp_labels[lane_id*nclasses + i];
}
loss = -warp_reduce_sum(loss) / (float)k;

__syncthreads();

if (lane_id == 0) {
tmp_all[warp_id] = loss;
}

__syncthreads();

if (warp_id != 0) {
return;
}

loss = lane_id < CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE/WARP_SIZE ? tmp_all[lane_id] : 0.0f;
loss = warp_reduce_sum(loss);

if (lane_id != 0) {
return;
}

dst[blockIdx.x] = loss;
}

void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];

GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);

GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));
GGML_ASSERT(ggml_is_contiguous(dst));

const int64_t ne00 = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);

const float * src0_d = (const float *) src0->data;
const float * src1_d = (const float *) src1->data;
float * dst_d = (float *) dst->data;

ggml_cuda_pool & pool = ctx.pool();
cudaStream_t stream = ctx.stream();

const dim3 blocks_dim(CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE, 1, 1);
const dim3 blocks_num((nrows + CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE - 1) / CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE, 1, 1);
const int shmem = 2*CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE*ne00*sizeof(float);

ggml_cuda_pool_alloc<float> dst_tmp(pool, blocks_num.x);

cross_entropy_loss_f32<<<blocks_num, blocks_dim, shmem, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);

// Combine results from individual blocks:
sum_rows_f32_cuda(dst_tmp.ptr, dst_d, blocks_num.x, 1, stream);
}
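For reference, a minimal CPU sketch (not part of the commit) of the quantity the kernel computes, assuming contiguous f32 logits/labels of shape [nclasses, nrows]; it returns the mean over rows of -sum_i labels_i * log_softmax(logits)_i, matching the 1/k factor in the kernel and the final block-wise reduction:

#include <math.h>

static float cross_entropy_loss_ref(const float * logits, const float * labels, int nclasses, int nrows) {
    float total = 0.0f;
    for (int row = 0; row < nrows; ++row) {
        const float * lg = logits + row*nclasses;
        const float * lb = labels + row*nclasses;

        // log_softmax(lg)_i = (lg_i - max) - log(sum_j exp(lg_j - max))
        float max = -INFINITY;
        for (int i = 0; i < nclasses; ++i) {
            max = fmaxf(max, lg[i]);
        }
        float sum = 0.0f;
        for (int i = 0; i < nclasses; ++i) {
            sum += expf(lg[i] - max);
        }
        const float log_sum = logf(sum);

        for (int i = 0; i < nclasses; ++i) {
            total += (lg[i] - max - log_sum) * lb[i];
        }
    }
    return -total / (float) nrows;
}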
5 changes: 5 additions & 0 deletions src/ggml-cuda/cross-entropy-loss.cuh
@@ -0,0 +1,5 @@
#include "common.cuh"

#define CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE 256

void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
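A quick sanity check on the launch parameters above (editorial note, not in the commit): the kernel requests 2*CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE*ne00*sizeof(float) bytes of dynamic shared memory, i.e. 2*256*10*4 = 20480 bytes for the {10, 10, 10, 10} test case added below. That fits well within the common 48 KiB per-block default, but very large class counts (for example an LLM vocabulary of tens of thousands of entries) would exceed it, so the kernel presumably targets modest nclasses.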
3 changes: 1 addition & 2 deletions src/ggml-cuda/sumrows.cu
@@ -16,7 +16,7 @@ static __global__ void k_sum_rows_f32(const float * x, float * dst, const int nc
}
}

static void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
const dim3 block_dims(WARP_SIZE, 1, 1);
const dim3 block_nums(nrows, 1, 1);
k_sum_rows_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
@@ -32,7 +32,6 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(src0));


const int64_t ncols = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);

2 changes: 2 additions & 0 deletions src/ggml-cuda/sumrows.cuh
@@ -1,3 +1,5 @@
#include "common.cuh"

void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream);

void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
35 changes: 19 additions & 16 deletions src/ggml.c
@@ -2671,6 +2671,19 @@ static ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x,
return sum;
}

static ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max) {
// log(soft_max) = log(soft_max_i / soft_max_sum) = log(soft_max_i) - log(soft_max_sum) = (logit_i - max) - log(soft_max_sum)

int i = 0;
ggml_float sum = 0;
for (; i < n; ++i) {
float val = x[i] - max;
y[i] = val;
sum += (ggml_float)expf(val);
}
return sum = (ggml_float)logf(sum);
}

inline static float ggml_silu_backward_f32(float x, float dy) {
const float s = 1.0f/(1.0f + expf(-x));
return dy*s*(1.0f + x*(1.0f - s));
@@ -17023,8 +17036,6 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
}
ggml_barrier(params->shared);

const double eps = 1e-9;

// rows per thread
const int dr = (nr + nth - 1)/nth;

@@ -17045,20 +17056,15 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
}
#endif

// soft_max
float max = -INFINITY;
ggml_vec_max_f32(nc, &max, s0);
ggml_float sum = ggml_vec_soft_max_f32(nc, st, s0, max);
assert(sum > 0.0);
sum = (1.0 - eps) / sum;
ggml_float sum = ggml_vec_log_soft_max_f32(nc, st, s0, max);
assert(sum >= 0.0);

// avoid log(0) by rescaling from [0..1] to [eps..1]
ggml_vec_scale_f32(nc, st, sum);
ggml_vec_add1_f32(nc, st, st, eps);
ggml_vec_log_f32(nc, st, st);
ggml_vec_add1_f32(nc, st, st, -sum);
ggml_vec_mul_f32(nc, st, st, s1);

float st_sum = 0;
float st_sum = 0.0f;
ggml_vec_sum_f32(nc, &st_sum, st);
sums[ith] += st_sum;

@@ -17115,8 +17121,6 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
const int64_t ith = params->ith;
const int64_t nth = params->nth;

const double eps = 1e-9;

// TODO: handle transposed/permuted matrices
const int64_t nc = src0->ne[0];
const int64_t nr = ggml_nrows(src0);
@@ -17148,11 +17152,9 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
ggml_vec_max_f32(nc, &max, s0);
ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
assert(sum > 0.0);
sum = (1.0 - eps) / sum;
ggml_vec_scale_f32(nc, ds0, 1.0/sum);

// grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr
ggml_vec_scale_f32(nc, ds0, sum);
ggml_vec_add1_f32(nc, ds0, ds0, eps);
ggml_vec_sub_f32(nc, ds0, ds0, s1);
ggml_vec_scale_f32(nc, ds0, d[0] / (float) nr);

@@ -20288,6 +20290,7 @@ static enum ggml_opt_result ggml_opt_adam(
ggml_opt_callback callback,
void * callback_data) {
GGML_ASSERT(ggml_is_scalar(f));
GGML_ASSERT(f->type == GGML_TYPE_F32);

// these will store the parameters we want to optimize
struct ggml_tensor * ps[GGML_MAX_PARAMS];
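The backward pass above relies on the identity d/ds0_i [ -sum_j s1_j * log_softmax(s0)_j ] = softmax(s0)_i - s1_i when the labels s1 sum to 1 per row, which is exactly what remains once the eps rescaling is dropped. A small self-contained numeric sketch (not part of the commit), comparing the analytic gradient against central differences for one row:

#include <math.h>
#include <stdio.h>

int main(void) {
    const float s0[3] = {0.2f, -0.5f, 1.0f}; // logits
    const float s1[3] = {0.1f,  0.3f, 0.6f}; // labels, sum to 1
    const float eps   = 1e-3f;

    for (int i = 0; i < 3; ++i) {
        // analytic gradient: softmax(s0)_i - s1_i
        float sum = 0.0f;
        for (int j = 0; j < 3; ++j) {
            sum += expf(s0[j]);
        }
        const float analytic = expf(s0[i])/sum - s1[i];

        // numeric gradient of loss(x) = -sum_j s1_j * (x_j - log(sum_k exp(x_k)))
        float loss[2];
        for (int k = 0; k < 2; ++k) {
            float x[3] = {s0[0], s0[1], s0[2]};
            x[i] += k == 0 ? eps : -eps;
            float s = 0.0f;
            for (int j = 0; j < 3; ++j) {
                s += expf(x[j]);
            }
            float l = 0.0f;
            for (int j = 0; j < 3; ++j) {
                l -= s1[j]*(x[j] - logf(s));
            }
            loss[k] = l;
        }
        const float numeric = (loss[0] - loss[1]) / (2.0f*eps);

        printf("i=%d analytic=%.6f numeric=%.6f\n", i, analytic, numeric);
    }
    return 0;
}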
23 changes: 23 additions & 0 deletions tests/test-backend-ops.cpp
@@ -1730,6 +1730,27 @@ struct test_flash_attn_ext : public test_case {
}
};

// GGML_OP_CROSS_ENTROPY_LOSS
struct test_cross_entropy_loss : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;

std::string vars() override {
return VARS_TO_STR2(type, ne);
}

test_cross_entropy_loss(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {10, 10, 10, 10})
: type(type), ne(ne) {}

ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * logits = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_tensor * labels = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_tensor * out = ggml_cross_entropy_loss(ctx, logits, labels);
return out;
}
};

enum llm_norm_type {
LLM_NORM,
LLM_NORM_RMS,
@@ -2491,6 +2512,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
}
}

test_cases.emplace_back(new test_cross_entropy_loss());

// these tests are disabled to save execution time, but they can be handy for debugging
#if 0
test_cases.emplace_back(new test_llama(1));
4 changes: 2 additions & 2 deletions tests/test-grad0.cpp
@@ -1444,7 +1444,7 @@ int main(int argc, const char ** argv) {
get_random_dims(ne2, 4);

for (int ndims = 1; ndims <= 4; ++ndims) {
x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -0.1f, 0.1f);
x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
x[1] = get_random_tensor_f32(ctx0, ndims, ne2, 0.0f, 1.0f);
// the second argument to cross_entropy_loss must sum up to 1 for each row
int nr = ggml_nrows(x[1]);
@@ -1462,7 +1462,7 @@

struct ggml_tensor * f = ggml_cross_entropy_loss(ctx0, x[0], x[1]);

check_gradient("cross_entropy_loss", ctx0, x, f, ndims, nargs, 1e-4f, 1e-3f, INFINITY, {});
check_gradient("cross_entropy_loss", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
}
}

