Skip to content

Commit

Permalink
sync from sd.cpp
Browse files Browse the repository at this point in the history
  • Loading branch information
bssrdf committed Sep 29, 2024
1 parent 6afbf6e commit 0491858
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 122 deletions.
157 changes: 38 additions & 119 deletions src/ggml-cuda/conv-winograd.cu
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,8 @@ __device__ void __inline__ outer_product(float4* input_frag, float4* filter_frag
accumulator[1][15].w += input_frag[3].w*filter_frag[3].w;
}

extern "C"
{
// extern "C"
// {

__device__ __forceinline__ void transform_output_tile(float *pOutputs, float2 *C_tile, float2 *At,
int round, int c_tensor, int c_glb_offset, int i1, int i2,
Expand Down Expand Up @@ -248,7 +248,7 @@ float4 *input_frag_mem, float4* filter_frag_mem){

float2 *output_smem = (float2 *) shared_mem;
float2 *accumulator = (float2 *) acumm_smem;
float2 *C_out = (float2*)C;
// float2 *C_out = (float2*)C;

float2 *C_tile = (float2*) input_frag_mem;
float2 *At = (float2*) filter_frag_mem;
Expand Down Expand Up @@ -295,12 +295,11 @@ float4 *input_frag_mem, float4* filter_frag_mem){
// blockIdx.x*BN + (threadIdx.x%16)*2+
// ((threadIdx.x/16)*16 + (threadIdx.y%4)*4 + threadIdx.y/4)*c_glb_offset;

int tx = TW, ty = TH;
// int c_tile = blockIdx.x * tx + blockIdx.y * in_w * ty;
// int c_tensor = c_tile + (threadIdx.x % tw) * 2 + (threadIdx.x / tw) * in_w * 2 +
// threadIdx.y*(in_h*in_w) - (in_w+1);

int c_tensor = blockIdx.z*c_glb_offset*BK + blockIdx.x * tx + blockIdx.y * out_w * ty +
int c_tensor = blockIdx.z*c_glb_offset*BK + blockIdx.x * TW + blockIdx.y * out_w * TH +
// (threadIdx.x % tw) * 2 + (threadIdx.x / tw) * out_w * 2 +
((threadIdx.x/16)*16 + (threadIdx.y%4)*4 + threadIdx.y/4)*c_glb_offset;

Expand Down Expand Up @@ -382,77 +381,31 @@ __device__ float f_row1(float *G, int j){
return G[j+2];
}

typedef float(*pointFunction_t)(float *, int);

__global__ void FX(const float *pInputs, float *pOutputs, int filt_k,
int filt_c, int filt_h, int filt_w){
// Generic device-side conversion of an arbitrary arithmetic type to float.
// The primary template handles any T with a valid conversion to float;
// a specialization for `half` (declared elsewhere in this file) routes
// through __half2float instead.
template <typename T>
static __device__ __forceinline__ float t2f32(T val) {
    return static_cast<float>(val);
}

// assumes CHWK layout
int Inx = threadIdx.x, Iny = threadIdx.y;
int TileX = blockIdx.x, TileY = blockIdx.y;

int c_glb_offset = filt_k*filt_h*filt_w;
int c_kernel = TileY*BC*c_glb_offset + TileX*BK + Iny*c_glb_offset + Inx;
int c_glb_offset_s = filt_k*4*4;
int c_kernel_s = TileY*BC*c_glb_offset_s + TileX*BK + Iny*c_glb_offset_s + Inx;

float Gw[21]; //9+12. In registers
float *Gw_buffer = Gw+9;

pointFunction_t func1[4] = {f_row1, f_row2, f_row3, f_row4};
pointFunction_t func2[4] = {f_col1, f_col2, f_col3, f_col4};

for(int bk=0; bk<BK; bk+=blockDim.x){
// if(blockIdx.x == 0 && blockIdx.y == 0 && Inx == 0 && Iny == 0){
// printf("[");
// }
for(int i=0; i<9; i++){
Gw[i] = pInputs[c_kernel + i*filt_k];
// if(blockIdx.x == 0 && blockIdx.y == 0 && Inx == 0 && Iny == 0){
// printf("(%f,%d) ", Gw[i], c_kernel + i*filt_k);
// }
}
// if(blockIdx.x == 0 && blockIdx.y == 0 && Inx == 0 && Iny == 0){
// printf("]\n");
// }

int aux;
for(int i=0; i<4; i++){
aux = i*3;
for(int j=0; j<3; j++){
Gw_buffer[j+aux] = (*func1[i])(Gw, j);
}
}
// if(blockIdx.x == 0 && blockIdx.y == 0 && Inx == 0 && Iny == 0){
// printf("X[");
// for(int kk = 0; kk < 21; kk++){
// printf("%f, ", Gw[kk]);
// }
// printf("]\n");
// }

int aux2;
for(int i=0; i<4; i++){
aux = i*3; aux2 = i<<2;
for(int j=0; j<4; j++){
pOutputs[c_kernel_s+aux2*filt_k+j*filt_k] = (*func2[j])(Gw_buffer, aux);
}
}

c_kernel += blockDim.x;
c_kernel_s += blockDim.x;
}
// Specialization for half-precision input: use the dedicated intrinsic
// rather than an implicit conversion, matching the primary t2f32 template's
// contract of returning the value widened to float.
template <>
__device__ __forceinline__ float t2f32<half>(half val) {
    return __half2float(val);
}

__global__ void FX_FP16(const half *pInputs, float *pOutputs, int filt_k,
typedef float(*pointFunction_t)(float *, int);

template<typename T>
__global__ void FX(const T *pInputs, float *pOutputs, int filt_k,
int filt_c, int filt_h, int filt_w){

// assumes CHWK layout
// assumes KCHW layout
int Inx = threadIdx.x, Iny = threadIdx.y;
int TileX = blockIdx.x, TileY = blockIdx.y;

int c_glb_offset = filt_k*filt_h*filt_w;
int c_kernel = TileY*BC*c_glb_offset + TileX*BK + Iny*c_glb_offset + Inx;
// int c_glb_offset = filt_k*filt_h*filt_w;
// int c_kernel = TileY*BC*c_glb_offset + TileX*BK + Iny*c_glb_offset + Inx;
int c_glb_offset = filt_h*filt_w;
// int c_kernel = TileY*BC*c_glb_offset + TileX*BK*filt_c*c_glb_offset + Iny*c_glb_offset+ Inx*filt_c*c_glb_offset;
int c_kernel = (TileY*BC + (TileX*BK+Inx)*filt_c + Iny)*c_glb_offset;
int c_glb_offset_s = filt_k*4*4;
int c_kernel_s = TileY*BC*c_glb_offset_s + TileX*BK + Iny*c_glb_offset_s + Inx;

Expand All @@ -462,19 +415,11 @@ __device__ float f_row1(float *G, int j){
pointFunction_t func1[4] = {f_row1, f_row2, f_row3, f_row4};
pointFunction_t func2[4] = {f_col1, f_col2, f_col3, f_col4};

for(int bk=0; bk<BK; bk+=blockDim.x){
// if(blockIdx.x == 0 && blockIdx.y == 0 && Inx == 0 && Iny == 0){
// printf("[");
// }
for(int bk=0; bk<BK; bk+=blockDim.x){
for(int i=0; i<9; i++){
Gw[i] = __half2float(pInputs[c_kernel + i*filt_k]);
// if(blockIdx.x == 0 && blockIdx.y == 0 && Inx == 0 && Iny == 0){
// printf("(%f,%d) ", Gw[i], c_kernel + i*filt_k);
// }
}
// if(blockIdx.x == 0 && blockIdx.y == 0 && Inx == 0 && Iny == 0){
// printf("]\n");
// }
Gw[i] = t2f32(pInputs[c_kernel + i]);

}

int aux;
for(int i=0; i<4; i++){
Expand All @@ -483,14 +428,7 @@ __device__ float f_row1(float *G, int j){
Gw_buffer[j+aux] = (*func1[i])(Gw, j);
}
}
// if(blockIdx.x == 0 && blockIdx.y == 0 && Inx == 0 && Iny == 0){
// printf("X[");
// for(int kk = 0; kk < 21; kk++){
// printf("%f, ", Gw[kk]);
// }
// printf("]\n");
// }


int aux2;
for(int i=0; i<4; i++){
aux = i*3; aux2 = i<<2;
Expand All @@ -499,7 +437,7 @@ __device__ float f_row1(float *G, int j){
}
}

c_kernel += blockDim.x;
c_kernel += blockDim.x*(filt_c*c_glb_offset);
c_kernel_s += blockDim.x;
}
}
Expand Down Expand Up @@ -793,34 +731,16 @@ cudaError_t convolutionForward_32Tx64x8(float *k, int in_h, int in_w, float *w,
return cudaGetLastError();
}

}

// }

static void conv_winograd_stage0_f32_f32_cuda(
template<typename T>
static void conv_winograd_stage0_f32_cuda(
const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3,
const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3,
const float * src0, float * dst,
const T * src0, float * dst,
cudaStream_t stream) {


int64_t filt_k = src0_ne0;
int64_t filt_c = src0_ne3;

FX<<<dim3(filt_k/BK, filt_c/BC), dim3(32, BC), 0, stream>>>(src0, dst, filt_k, filt_c, src0_ne2, src0_ne1);

}

static void conv_winograd_stage0_f16_f32_cuda(
const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3,
const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3,
const half * src0, float * dst,
cudaStream_t stream) {


int64_t filt_k = src0_ne0;
int64_t filt_c = src0_ne3;

FX_FP16<<<dim3(filt_k/BK, filt_c/BC), dim3(32, BC), 0, stream>>>(src0, dst, filt_k, filt_c, src0_ne2, src0_ne1);
FX<<<dim3(src0_ne3/BK, src0_ne2/BC), dim3(32, BC), 0, stream>>>(src0, dst, src0_ne3, src0_ne2, src0_ne1, src0_ne0);

}

Expand All @@ -842,12 +762,9 @@ static void conv_winograd_stage1_f32_f32_cuda(int tiles_dim_w, int tiles_dim_h,
int64_t out_w = in_w;
int smem_size = (16*BN*BC + 16*BC*BK)*4;

// printf("A %d, %d\n", filt_k, filt_c);
// printf("B %d, %d, %d \n", in_c, in_h, in_w);
// printf("C %d, %d, %d \n", out_c, out_h, out_w);

Winograd_kernel<<<dim3((tiles_dim_w+X-1)/X, (tiles_dim_h+Y-1)/Y, filt_k/BK), dim3(BN, 8), smem_size, stream>>>(src1, src0, dst,
tiles_dim_w, tiles_dim_h, in_c, in_h, in_w, tile_size, X, Y, filt_k, filt_c, out_c, tile_2d_s, out_h, out_w);
tiles_dim_w, tiles_dim_h, in_c, in_h, in_w, tile_size, X, Y,
filt_k, filt_c, out_c, tile_2d_s, out_h, out_w);
}


Expand Down Expand Up @@ -876,12 +793,14 @@ void ggml_cuda_op_winograd_stage0(ggml_backend_cuda_context & ctx, ggml_tensor *
// const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *)src0->data : src0_ddq_as_f32.get();
if(src0->type == GGML_TYPE_F32){
const float* src0_d = (const float *)src0->data;
conv_winograd_stage0_f32_f32_cuda(src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
// conv_winograd_stage0_f32_f32_cuda(src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
conv_winograd_stage0_f32_cuda(src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
src0_d, dst_d, stream);
}else{
const half * src0_d = (const half *)src0->data;
conv_winograd_stage0_f16_f32_cuda(src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
// conv_winograd_stage0_f16_f32_cuda(src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
conv_winograd_stage0_f32_cuda(src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
src0_d, dst_d, stream);
}
Expand Down
6 changes: 3 additions & 3 deletions src/ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -7184,7 +7184,7 @@ struct ggml_tensor * ggml_winograd_stage0(
is_node = true;
}

struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, a->ne[0], 4, 4, a->ne[3]);
struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, a->ne[3], 4, 4, a->ne[2]);

result->op = GGML_OP_WINOGRAD_STAGE0;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
Expand Down Expand Up @@ -7228,8 +7228,8 @@ struct ggml_tensor * ggml_conv_2d_3x3(
if(a->ne[3] % 64 != 0 || a->ne[2] % 8 != 0) // only works for the number of filters is a multiple of 64
return ggml_conv_2d(ctx, a, b, 1, 1, 1, 1, 1, 1); // and the number of channels is a multiple of 8

struct ggml_tensor* ra = ggml_cont(ctx, ggml_permute(ctx, a, 1, 2, 3, 0)); // [N, OC, OH, OW]
struct ggml_tensor* W = ggml_winograd_stage0(ctx, ra);
// struct ggml_tensor* ra = ggml_cont(ctx, ggml_permute(ctx, a, 1, 2, 3, 0)); // [N, OC, OH, OW]
struct ggml_tensor* W = ggml_winograd_stage0(ctx, a);
struct ggml_tensor * result = ggml_winograd_stage1(ctx, W, b);

return result;
Expand Down

0 comments on commit 0491858

Please sign in to comment.