feat: adds new OP ggml_unfold_1d #867

Draft · wants to merge 4 commits into master
9 changes: 9 additions & 0 deletions include/ggml/ggml.h
@@ -474,6 +474,7 @@ extern "C" {
GGML_OP_CONV_TRANSPOSE_2D,
GGML_OP_POOL_1D,
GGML_OP_POOL_2D,
GGML_OP_UNFOLD_1D,
GGML_OP_UPSCALE, // nearest interpolate
GGML_OP_PAD,
GGML_OP_ARANGE,
@@ -1708,6 +1709,14 @@ extern "C" {
float p0,
float p1);
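
// unfold a into overlapping windows along dimension 0 (window size k, stride s)
// a:      [ne0, ne1, ne2, 1]
// result: [k, (ne0 - k)/s + 1, ne1, ne2]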
GGML_API struct ggml_tensor * ggml_unfold_1d(
struct ggml_context * ctx,
struct ggml_tensor * a,
int k,
int s);

// nearest interpolate
// multiplies ne0 and ne1 by scale factor
// used in stable-diffusion
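
For orientation, a minimal call sketch of the new API (not part of this PR; ctx is assumed to be an already-initialized ggml_context, and as with the other ops the result is only evaluated once a graph is built and computed):

// given a 1-D tensor of length 5
struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 5);

// k = 3, s = 1  ->  result shape [3, (5 - 3)/1 + 1, 1, 1] = [3, 3]
// windows: [a0,a1,a2], [a1,a2,a3], [a2,a3,a4]
struct ggml_tensor * u = ggml_unfold_1d(ctx, a, /*k=*/3, /*s=*/1);
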
5 changes: 5 additions & 0 deletions src/ggml-cuda.cu
@@ -29,6 +29,7 @@
#include "ggml-cuda/tsembd.cuh"
#include "ggml-cuda/unary.cuh"
#include "ggml-cuda/upscale.cuh"
#include "ggml-cuda/unfold1d.cuh"

#include <algorithm>
#include <array>
@@ -2288,6 +2289,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_PAD:
ggml_cuda_op_pad(ctx, dst);
break;
case GGML_OP_UNFOLD_1D:
ggml_cuda_op_unfold_1d(ctx, dst);
break;
case GGML_OP_ARANGE:
ggml_cuda_op_arange(ctx, dst);
break;
@@ -2895,6 +2899,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_OP_GROUP_NORM:
case GGML_OP_UPSCALE:
case GGML_OP_PAD:
case GGML_OP_UNFOLD_1D:
case GGML_OP_ARANGE:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_LEAKY_RELU:
44 changes: 44 additions & 0 deletions src/ggml-cuda/unfold1d.cu
@@ -0,0 +1,44 @@
#include "unfold1d.cuh"

static __global__ void unfold_1d_f32(const float * x, float * dst, const int s, const int ne0, const int ne1, const int ne2,
const int ne3, const int ne00, const int ne01, const int ne02, const int ne03) {
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
if (nidx >= ne0 * ne1 * ne2 * ne3) {
return;
}

// decompose the flat dst index into coordinates (i0, i1, i2, i3)
const int i3 = nidx/(ne0*ne1*ne2);
const int i2 = (nidx - i3*ne0*ne1*ne2)/(ne0*ne1);
const int i1 = (nidx - i3*ne0*ne1*ne2 - i2*ne0*ne1)/ne0;
const int i0 = nidx - i3*ne0*ne1*ne2 - i2*ne0*ne1 - i1*ne0;

// window i1 starts at offset i1*s in the source row
const int src_idx = i3*(ne00*ne01) + i2*ne00 + i1*s + i0;

dst[nidx] = x[src_idx];
}

static void unfold_1d_f32_cuda(const float * x, float * dst, const int s,
const int ne0, const int ne1, const int ne2, const int ne3,
const int ne00, const int ne01, const int ne02, const int ne03, cudaStream_t stream) {
int num_blocks = ((ne0 * ne1 * ne2 * ne3) + CUDA_UNFOLD_1D_BLOCK_SIZE - 1) / CUDA_UNFOLD_1D_BLOCK_SIZE;

unfold_1d_f32<<<num_blocks, CUDA_UNFOLD_1D_BLOCK_SIZE, 0, stream>>>(x, dst, s, ne0, ne1, ne2, ne3, ne00, ne01, ne02, ne03);
}

void ggml_cuda_op_unfold_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();

GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(src0->ne[3] == 1); // the input tensor has at most 3 dimensions

// op_params holds { k, s }; only the stride is needed here, since k == dst->ne[0]
const int32_t * opts = (const int32_t *)dst->op_params;
const int s = opts[1];

unfold_1d_f32_cuda(src0_d, dst_d, s,
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], stream);
}
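
As a quick sanity check on the index arithmetic above (a standalone sketch, not part of this PR): for a dst of shape [k = 3, 3, 1, 1] built from a length-5 input with stride 1, flat index 5 must map back to source element 3.

#include <assert.h>

int main(void) {
    const int nidx = 5;                  // flat dst index, as in the kernel
    const int ne0 = 3, ne1 = 3, ne2 = 1; // dst dims: k, n_windows, rows
    const int s = 1;                     // stride

    const int i3 = nidx / (ne0*ne1*ne2);
    const int i2 = (nidx - i3*ne0*ne1*ne2) / (ne0*ne1);
    const int i1 = (nidx - i3*ne0*ne1*ne2 - i2*ne0*ne1) / ne0;
    const int i0 = nidx - i3*ne0*ne1*ne2 - i2*ne0*ne1 - i1*ne0;

    // window i1 = 1, element i0 = 2  ->  source index i1*s + i0 = 3
    assert(i1 == 1 && i0 == 2 && i1*s + i0 == 3);
    return 0;
}
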
5 changes: 5 additions & 0 deletions src/ggml-cuda/unfold1d.cuh
@@ -0,0 +1,5 @@
#include "common.cuh"

#define CUDA_UNFOLD_1D_BLOCK_SIZE 256

void ggml_cuda_op_unfold_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
99 changes: 97 additions & 2 deletions src/ggml.c
@@ -2696,7 +2696,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"CROSS_ENTROPY_LOSS_BACK",
};

static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
static_assert(GGML_OP_COUNT == 75, "GGML_OP_COUNT != 75");

static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@@ -2752,6 +2752,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"pool_1d(x)",
"pool_2d(x)",
"upscale(x)",
"unfold_1d(x)",
"pad(x)",
"arange(start, stop, step)",
"timestep_embedding(timesteps, dim, max_period)",
@@ -2784,7 +2785,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"cross_entropy_loss_back(x,y)",
};

static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
static_assert(GGML_OP_COUNT == 75, "GGML_OP_COUNT != 75");

static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");

@@ -6852,6 +6853,43 @@ struct ggml_tensor * ggml_upscale_ext(
return ggml_upscale_impl(ctx, a, ne0, ne1, ne2, ne3);
}


// ggml_unfold_1d

struct ggml_tensor * ggml_unfold_1d(
struct ggml_context * ctx,
struct ggml_tensor * a,
int k,
int s) {

bool is_node = false;

if (a->grad) {
GGML_ASSERT(false); // TODO: implement backward
is_node = true;
}

GGML_ASSERT(a->ne[3] == 1); // only up to 3-dimensional inputs are allowed, since this operation adds a dimension

GGML_ASSERT((a->ne[0] - k) % s == 0); // the kernel size and stride must tile the unfold dimension exactly

const int64_t ne[4] = { k, (a->ne[0] - k) / s + 1, a->ne[1], a->ne[2] };
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);

int32_t params[] = { k, s };
ggml_set_op_params(result, params, sizeof(params));

result->op = GGML_OP_UNFOLD_1D;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;

return result;
}
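
One behavioral note (an observation, not something the PR states): the divisibility assert makes this op stricter than PyTorch's Tensor.unfold, which keeps floor((ne0 - k)/s) + 1 windows and silently drops any remainder. A concrete case:

// ne0 = 6, k = 3, s = 2
// this op:      (6 - 3) % 2 == 1  ->  the GGML_ASSERT above fires
// torch.unfold: floor((6 - 3)/2) + 1 = 2 windows, [x0,x1,x2] and [x2,x3,x4]
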
// ggml_pad

struct ggml_tensor * ggml_pad(
@@ -15445,6 +15483,51 @@ static void ggml_compute_forward_upscale(
}
}

// ggml_compute_forward_unfold_1d

static void ggml_compute_forward_unfold_1d(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {

const struct ggml_tensor * src0 = dst->src[0];

if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
return;
}

GGML_ASSERT(src0->type == GGML_TYPE_F32);

const int ith = params->ith;
const int nth = params->nth;

GGML_TENSOR_UNARY_OP_LOCALS

float * dst_ptr = (float *) dst->data;
float * src0_ptr = (float *) src0->data;


const int32_t * opts = (const int32_t *)dst->op_params;
const int s = opts[1];


// TODO: optimize
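// each dst element (i0, i1, i2, i3) copies src element (i1*s + i0, i2, i3)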

for (int64_t i2 = 0; i2 < ne2; ++i2) {
for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
for (int64_t i0 = 0; i0 < ne0; ++i0) {
for (int64_t i3 = 0; i3 < ne3; ++i3) {
const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;

const int64_t src_idx = i3*(ne00*ne01) + i2*ne00 + i1*s + i0;

dst_ptr[dst_idx] = src0_ptr[src_idx];
}
}
}
}
}


// ggml_compute_forward_pad

@@ -17453,6 +17536,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_upscale(params, tensor);
} break;
case GGML_OP_UNFOLD_1D:
{
ggml_compute_forward_unfold_1d(params, tensor);
} break;
case GGML_OP_PAD:
{
ggml_compute_forward_pad(params, tensor);
@@ -18463,6 +18550,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
{
GGML_ASSERT(false); // TODO: not implemented
} break;
case GGML_OP_UNFOLD_1D:
{
GGML_ASSERT(false); // TODO: not implemented
} break;
case GGML_OP_PAD:
{
GGML_ASSERT(false); // TODO: not implemented
@@ -19206,6 +19297,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
{
n_tasks = n_threads;
} break;
case GGML_OP_UNFOLD_1D:
{
n_tasks = 1; // single-threaded for now, although the forward loop already strides by nth
} break;
case GGML_OP_PAD:
{
n_tasks = n_threads;
8 changes: 8 additions & 0 deletions tests/CMakeLists.txt
@@ -403,6 +403,14 @@ target_link_libraries(${TEST_TARGET} PRIVATE ggml)
add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT "LLVM_PROFILE_FILE=${TEST_TARGET}.profraw")

#
# test-unfold-1d

set(TEST_TARGET test-unfold-1d)
add_executable(${TEST_TARGET} ${TEST_TARGET}.cpp)
target_link_libraries(${TEST_TARGET} PRIVATE ggml)
add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT "LLVM_PROFILE_FILE=${TEST_TARGET}.profraw")
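
The test source itself is not included in this diff. A minimal sketch of what tests/test-unfold-1d.cpp could assert, written against the public ggml API (hypothetical; CPU path only):

#include "ggml/ggml.h"

#include <assert.h>
#include <stdio.h>

int main(void) {
    struct ggml_init_params ip = { /*mem_size=*/ 16*1024*1024, /*mem_buffer=*/ NULL, /*no_alloc=*/ false };
    struct ggml_context * ctx = ggml_init(ip);

    // a = [0, 1, 2, 3, 4]
    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 5);
    for (int i = 0; i < 5; ++i) { ((float *) a->data)[i] = (float) i; }

    struct ggml_tensor * u = ggml_unfold_1d(ctx, a, /*k=*/3, /*s=*/1); // shape [3, 3]

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, u);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/1);

    // expected windows: [0,1,2], [1,2,3], [2,3,4]
    const float * d = (const float *) u->data;
    for (int w = 0; w < 3; ++w) {
        for (int j = 0; j < 3; ++j) {
            assert(d[w*3 + j] == (float)(w + j));
        }
    }

    ggml_free(ctx);
    printf("test-unfold-1d: OK\n");
    return 0;
}
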

#
# test-mul-mat