diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 01845a131..e32763f56 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -284,7 +284,9 @@ def tile_indices(self): class MatMul8bitLt(torch.autograd.Function): @staticmethod - def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState): + def forward(ctx, A, B, out=None, bias=None, state: MatmulLtState = None): + state = state or MatmulLtState() + using_igemmlt = supports_igemmlt(A.device) and not state.force_no_igemmlt # default of pytorch behavior if inputs are empty ctx.is_empty = False @@ -417,8 +419,7 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState): ctx.tensor_states = (None, None) ctx.save_for_backward(None, None) - clone_func = torch.clone if len(output_shape) == 3 else lambda x: x - return clone_func(output.view(output_shape)) + return output.reshape(output_shape) @staticmethod def backward(ctx, grad_output): @@ -442,37 +443,18 @@ def backward(ctx, grad_output): Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16)) if req_gradB: - # CxAt, SAt = F.transform(CAt, formatB, transpose=True) - # C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True) - # gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt) - # grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt) - gradB32, SgradB32 = F.igemmlt( - Cgradt.t(), CAt.t() - ) # issue here in test_linear_serialization w/ has fp16 weights + gradB32, SgradB32 = F.igemmlt(Cgradt.t(), CAt.t()) grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt) if state.threshold > 0.0 and subA is not None: grad_B[:, idx] += torch.matmul(grad_output.t(), subA) if req_gradA: if state.CBt is not None: - # C32grad, Sgrad = F.transform(Cgrad, "col32") - # if state.CxBt is None: - # state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True) - # gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt) - # grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A) gradA32, SgradA32 = F.igemmlt(Cgradt, state.CBt.t()) grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A) - elif state.CB is not None: CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0)) grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A) - # elif state.CxB is not None: - # CB = ( - # undo_layout(state.CxB, state.tile_indices) - # .to(ctx.dtype_A) - # .mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0)) - # ) - # grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A) else: raise Exception("State must contain either CBt or CB matrix for backward") diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 7f07778ef..d59fc8778 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -2330,7 +2330,7 @@ def igemmlt(A, B, out=None, Sout=None, dtype=torch.int32): ldb = shapeB[-1] # Activations (batch, tokens, inputs) ldc = shapeC[-1] # Output (batch, tokens, outputs) - assert lda == ldb, f"igemmlt only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ B={shapeA}" + assert lda == ldb, f"igemmlt only supports B^T @ A. 
Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}" prev_device = A.device torch.cuda.set_device(A.device) @@ -2361,18 +2361,25 @@ def igemmlt(A, B, out=None, Sout=None, dtype=torch.int32): return out, Sout -def mm_dequant(A, quant_state, row_stats, col_stats, out=None, new_row_stats=None, new_col_stats=None, bias=None): +def mm_dequant_torch( + A: torch.Tensor, + quant_state: Optional[Tuple[torch.Size, str]], # TODO: deprecate. (shape, format) + row_stats: torch.Tensor, + col_stats: torch.Tensor, + out: Optional[torch.Tensor] = None, + new_row_stats=None, # TODO: unused + new_col_stats=None, # TODO: unused + bias: Optional[torch.Tensor] = None, +): assert A.dtype == torch.int32 - compute_dtype = torch.float32 - - A_calc = A.view(-1, A.shape[-1]).to(compute_dtype) - row_stats = row_stats.reshape(-1).unsqueeze(-1).to(compute_dtype) - col_stats = col_stats.reshape(-1).unsqueeze(0).to(compute_dtype) + A_calc = A.view(-1, A.shape[-1]) + row_stats = row_stats.reshape(-1).unsqueeze(-1) + col_stats = col_stats.reshape(-1).unsqueeze(0) # TODO support out != None - out = A_calc * (row_stats * col_stats) * 6.200124e-5 # .to(torch.float16) + out = A_calc * (row_stats * col_stats) * 6.200124e-5 if bias is not None: # assert bias.dtype == torch.float16 @@ -2381,42 +2388,40 @@ def mm_dequant(A, quant_state, row_stats, col_stats, out=None, new_row_stats=Non return out.to(torch.float16) -def mm_dequant_old(A, quant_state, row_stats, col_stats, out=None, new_row_stats=None, new_col_stats=None, bias=None): +def mm_dequant( + A: torch.Tensor, + quant_state: Optional[Tuple[torch.Size, str]], # TODO: deprecate. (shape, format) + row_stats: torch.Tensor, + col_stats: torch.Tensor, + out: Optional[torch.Tensor] = None, + new_row_stats=None, # TODO: unused + new_col_stats=None, # TODO: unused + bias: Optional[torch.Tensor] = None, +): assert A.dtype == torch.int32 + if bias is not None: assert bias.dtype == torch.float16 - out_shape = quant_state[0] - if len(out_shape) == 3: - out_shape = (out_shape[0] * out_shape[1], out_shape[2]) if out is None: - out = torch.empty(out_shape, dtype=torch.float16, device=A.device) - if new_row_stats is None: - new_row_stats = torch.empty(out_shape[0], dtype=torch.float32, device=A.device) - if new_col_stats is None: - new_col_stats = torch.empty(out_shape[1], dtype=torch.float32, device=A.device) - assert new_row_stats.shape[0] == row_stats.shape[0], f"{new_row_stats.shape} vs {row_stats.shape}" - assert new_col_stats.shape[0] == col_stats.shape[0], f"{new_col_stats.shape} vs {col_stats.shape}" + out = torch.empty_like(A, dtype=torch.float16) - prev_device = pre_call(A.device) ptrA = get_ptr(A) ptrOut = get_ptr(out) ptrRowStats = get_ptr(row_stats) ptrColStats = get_ptr(col_stats) - ptrNewRowStats = get_ptr(new_row_stats) - ptrNewColStats = get_ptr(new_col_stats) ptrBias = get_ptr(bias) - numRows = ct.c_int32(out_shape[0]) - numCols = ct.c_int32(out_shape[1]) + numRows = ct.c_int32(prod(A.shape[:-1])) + numCols = ct.c_int32(A.shape[-1]) + + is_on_gpu([A, row_stats, col_stats, out, bias]) - is_on_gpu([A, row_stats, col_stats, out, new_row_stats, new_col_stats, bias]) + prev_device = pre_call(A.device) lib.cdequant_mm_int32_fp16( ptrA, ptrRowStats, ptrColStats, ptrOut, - ptrNewRowStats, - ptrNewColStats, ptrBias, numRows, numCols, @@ -2426,7 +2431,33 @@ def mm_dequant_old(A, quant_state, row_stats, col_stats, out=None, new_row_stats return out -def get_colrow_absmax(A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0): +def get_colrow_absmax( + A: 
torch.Tensor, + row_stats: torch.Tensor = None, + col_stats: torch.Tensor = None, + nnz_block_ptr: torch.Tensor = None, + threshold=0.0, +): + # Note: prior impl only works with fp16 + assert A.is_floating_point() + + if row_stats is None or col_stats is None: + absA = A.abs().view(-1, A.shape[-1]) # view as 2D + if row_stats is None: + # shape [rows]; unsqueeze(-1) gives [rows,1] + row_stats = absA.amax(dim=1, keepdim=False).float() + if col_stats is None: + # shape [cols]; unsqueeze(0) gives [1,cols] + col_stats = absA.amax(dim=0, keepdim=False).float() + + # TODO: threshold support + if nnz_block_ptr is None and threshold > 0.0: + nnz_block_ptr = torch.zeros_like(A, dtype=torch.int32) + + return row_stats, col_stats, nnz_block_ptr + + +def get_colrow_absmax_old(A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0): assert A.dtype == torch.float16 device = A.device @@ -2543,19 +2574,27 @@ def coo_zeros(rows, cols, nnz, device, dtype=torch.half): return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values) +@torch.compile def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0): - # TODO: Optimize/write CUDA kernel for this. Currently vectorwise_quant will recalculate row/col stats. + # TODO: Optimize/write CUDA kernel for this # TODO: Support threshold - # if out_col is None: - # out_col = torch.zeros(A.shape, device=A.device, dtype=torch.int8) - # if out_row is None: - # out_row = torch.zeros(A.shape, device=A.device, dtype=torch.int8) + if row_stats is None or col_stats is None: + row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(A, threshold=threshold) + + scaled_A = A.mul(C) + + # quant_row = torch.round(A * (C / row_stats.unsqueeze(-1))).to(torch.int8) + # quant_col = torch.round(A * (C / col_stats.unsqueeze(0))).to(torch.int8) + quant_row = torch.round(scaled_A / row_stats.unsqueeze(-1)).to(torch.int8) + quant_col = torch.round(scaled_A / col_stats.unsqueeze(0)).to(torch.int8) - out_col, Scol = vectorwise_quant(A, dim=0) - out_row, Srow = vectorwise_quant(A, dim=1) + if out_row is not None: + quant_row = out_row.copy_(quant_row) + if out_col is not None: + quant_col = out_col.copy_(quant_col) - return out_row, out_col, Srow.flatten().float(), Scol.flatten().float(), None # coo_tensor + return quant_row, quant_col, row_stats.flatten().float(), col_stats.flatten().float(), None def double_quant_old(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0): diff --git a/csrc/kernels.cu b/csrc/kernels.cu index be7779de1..5bdcb1a41 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -219,7 +219,7 @@ __device__ half dhDequantizeNF4(unsigned char val) } -__device__ float dDequantizeNF4(unsigned char val) +__device__ __forceinline__ float dDequantizeNF4(unsigned char val) { // the values for this tree was generated by test_normal_map_tree @@ -722,7 +722,7 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs __shared__ typename LoadChar::TempStorage loadchar; __shared__ typename StoreT::TempStorage storet; - for (unsigned int i = base_idx; i < n_load; i += gridDim.x*TILE_SIZE) + for (int i = base_idx; i < n_load; i += gridDim.x*TILE_SIZE) { if(DATA_TYPE > 0) { @@ -734,7 +734,8 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs valid_items_load = n - i > TILE_SIZE ? TILE_SIZE : n - i; valid_items_store = n - i > TILE_SIZE ? 
TILE_SIZE : n - i; } - local_abs_max = __ldg(&absmax[(i+threadIdx.x*NUM_PER_TH)/(blocksize)]); + local_abs_max = __ldg(&absmax[(i+threadIdx.x*NUM_PER_TH) >> (31 - __clz(blocksize))]); + //local_abs_max = __ldg(&absmax[(i+threadIdx.x*NUM_PER_TH)/(blocksize)]); __syncthreads(); LoadChar(loadchar).Load(&(A[i]), qvals, valid_items_load, 128); @@ -2291,128 +2292,68 @@ template __global__ void kgetColRowStats(half * __rest #define MM_DEQUANT_CONST 6.200012e-05f //1.0f/(127.0f*127.0f) -template __global__ void kdequant_mm_int32_fp16(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, float* newRowStats, float* newcolStats, half *__restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n) -{ +template +__global__ void kdequant_mm_int32_fp16( + int* __restrict__ const A, + float *__restrict__ const rowStats, + float *__restrict__ const colStats, + half *out, + half *__restrict__ const bias, + const int numRows, + const int numCols, + const int n +) { + const int n_out = numRows * numCols; - // Strategy: To dequantize we need to load col/row statistics. This can be very expensive - // since different row/col stats need to be loaded with each thread. - // (1, bad algorithm) Loading 32 items per thread would only occur 1 row load, but this increases register pressure - // and would lead to low global load utilization. - // (2, bad algorithm) If each thread loads some columns and multiple rows one needs to do lot of row loads - // for each thread and this is duplicated by a factor of 32/num-cols-per-thread. - // (3, good algorithm) Combining (1) and (2) we use sub-tiles of size 32xk in shared memory per threadblock. - // This allows for efficient row/col loading from shared memory within the tile. - // We can run for example 32x128 sub-tiles and warp-strided loads of 4 elements so that each thread has - // the same col statistic but needs to load 4 row stats from shared memory. To prevent bank conflicts - // we use a block-striped shared memory config [1, 31, 63, 95] so no bank conflicts happen during the - // shared memory loads. - - // data is in 32 column-tile major with tile width 32 columns and numRows rows - // L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory. - // L2. Load data in warp-striped arrangement (t0 holds colidx [0, 0, 0, 0], rowidx [0, 1, 2, 3]) - // C1. Compute val(row_stat*col_stat)/(127*127) (load 1/(127*127 into register)) - // C2. Compute normalization values and store col values in register - // S1. Store C1 into 16-bit output - // S2. Store col/row statistics of new buffer in shared memory - - // We allow for sub-tiles to span multiple col32 tiles. This is okay - // since the items per thread only rely on a single column statistic. - - - const int n_out = numRows*numCols; - - int num_row_tiles = (numRows/SUBTILE_ROWS) + (numRows % SUBTILE_ROWS == 0 ? 0 : 1); - // we have tiles of size numRows*32, thus col only increases every numRows - // num_row_tiles is the tiles after which the column increases by 32 - // blockIdx.x is the index of the current tile - int col = ((threadIdx.x % 32) + ((blockIdx.x/num_row_tiles)*32)); - // base_row increases by SUBTILE_ROWS every block. 
It wraps back to zero once num_row_tiles is reached - int base_row = (blockIdx.x*SUBTILE_ROWS) % (num_row_tiles*SUBTILE_ROWS); - - // SUBTILE_ROWS is independent from ITEMS_PER_THREAD is independent from THREADS - // subtiles have 32*SUBTILE_ROWS elements <= THREADS*ITEMS_PER_THREAD - // Total subtiles should be n/(32*SUBTILE_ROWS) where each subtile has SUBTILE_ROW*32/4 threads. - // For example for a 1024x1024 matrix with 128 SUBTILE_ROWS and 4 ITEMS_PER_THREAD we have - // 1024*1024/(128*32) = 256 tiles - // 256 tiles are 256*128*32/4 = 256*1024 threads - - // 1. Figure out how index relates to the start of the sub-tile - // 2. Each thread < SUBTILE_ROWS calculates row index - // 3. Load striped and store in shared memory + int block_offset = blockIdx.x * THREADS * ITEMS_PER_THREAD; + int thread_offset = threadIdx.x * ITEMS_PER_THREAD; int local_values[ITEMS_PER_THREAD]; half local_output[ITEMS_PER_THREAD]; + float local_rowStats[ITEMS_PER_THREAD]; - __shared__ float smem_rowStats[SUBTILE_ROWS]; + float local_colStats[ITEMS_PER_THREAD]; + float local_biasValue[ITEMS_PER_THREAD]; - typedef cub::BlockLoad LoadInt32; - typedef cub::BlockExchange ExchangeInt32; + typedef cub::BlockLoad LoadInt32; __shared__ typename LoadInt32::TempStorage loadint32; - __shared__ typename ExchangeInt32::TempStorage exchangeint32; + int row_idx, col_idx; - // L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory. - float colStat = col >= numCols ? 0.0f : colStats[col]; - float local_biasValue = ((bias == NULL) || (col >= numCols)) ? 0.0f : __half2float(bias[col]); - // no block loads for rows for now -- keep it simple - for(int j = threadIdx.x; j < SUBTILE_ROWS; j+=blockDim.x) - { - // todo: is this global mem access slow due to overlaps or does the L1 cache work well here? - int row = (base_row+j) % numRows; // wrap around - // each warp accesses the same element, for four consequitive elements - // todo: update description about striped shared memory, it is not needed - // rowidx: [0, 1, 2, 3...] and each warp reads ITEMS_PER_THREAD consequitive elements - smem_rowStats[j] = rowStats[row]; - } - __syncthreads(); - + #pragma unroll ITEMS_PER_THREAD + for (int j = 0; j < ITEMS_PER_THREAD; ++j) { - // each block processes SUBTILE_ROWS*32 elements - const int items_per_load = THREADS*ITEMS_PER_THREAD; - const int rows_per_load = items_per_load/32; + row_idx = (block_offset + thread_offset + j) / numCols; + col_idx = (block_offset + thread_offset + j) % numCols; - int subtile_base_row = (threadIdx.x / 32)*ITEMS_PER_THREAD; // row within the tile - int row_offset = 0; - // subtile_idx starts at the base_row*32 + the total offset for a full numRow*32 tile is passed - int subtile_start = (blockIdx.x/num_row_tiles)*(numRows*32) + (base_row*32); - for(int subtile_idx = subtile_start; subtile_idx < subtile_start + (SUBTILE_ROWS*32); subtile_idx+=items_per_load) - { - int valid_rows = numRows - (base_row+row_offset) > rows_per_load ? rows_per_load : numRows - (base_row+row_offset); - int valid_items = valid_rows*32; - if(valid_items <= 0) // the sub-tile might have more elements than the tile itself - break; + local_colStats[j] = col_idx >= numCols ? 0.0f : colStats[col_idx]; + local_rowStats[j] = row_idx >= numRows ? 0.0f : rowStats[row_idx]; + local_biasValue[j] = ((bias == nullptr) || col_idx >= numCols) ? 0.0f : __half2float(bias[col_idx]); + } - // L2. 
Load data in warp-striped arrangement (t0 holds colidx [0, 0, 0, 0], rowidx [0, 1, 2, 3]) - LoadInt32(loadint32).Load(&(A[subtile_idx]), local_values, valid_items, 0); - ExchangeInt32(exchangeint32).BlockedToWarpStriped(local_values, local_values); + // Each block loads THREADS * ITEMS_PER_THREAD values from A + int valid_items = block_offset + THREADS * ITEMS_PER_THREAD < n_out + ? THREADS * ITEMS_PER_THREAD + : n_out - block_offset; + __syncthreads(); + LoadInt32(loadint32).Load(&(A[block_offset]), local_values, valid_items, 0); - #pragma unroll ITEMS_PER_THREAD - for(int j = 0; j < ITEMS_PER_THREAD; j++) - local_rowStats[j] = smem_rowStats[subtile_base_row+row_offset+j]; + #pragma unroll ITEMS_PER_THREAD + for (int j = 0; j < ITEMS_PER_THREAD; ++j) { + local_output[j] = __float2half( + fmaf(local_values[j] * local_rowStats[j] * local_colStats[j], MM_DEQUANT_CONST, local_biasValue[j]) + ); + } - #pragma unroll ITEMS_PER_THREAD - for(int j = 0; j < ITEMS_PER_THREAD; j++) - local_output[j] = __float2half((local_values[j]*MM_DEQUANT_CONST*local_rowStats[j]*colStat) + local_biasValue); - //absmax_col = fmax(fabsf(local_output[j]), absmax_col); - - // we store data in row major - // to store data efficiently, we want to use block exchange: [0, 32, 64, 92] -> [0, 1, 2, 3] - // so that each thread holds ITEMS_PER_THREAD consecutive items for each row - // this way throughput into storage is increased by a factor of ~2x - // for now we use a simple store - #pragma unroll ITEMS_PER_THREAD - for(int j = 0; j < ITEMS_PER_THREAD; j++) - { - int outIdx = col + ((base_row+subtile_base_row+row_offset+j)*numCols); - if(outIdx< n_out && col < numCols) - out[outIdx] = local_output[j]; + #pragma unroll ITEMS_PER_THREAD + for (int j = 0; j < ITEMS_PER_THREAD; j++) { + int outIdx = block_offset + thread_offset + j; + if (outIdx < n_out) { + out[outIdx] = local_output[j]; } - - row_offset += rows_per_load; } } - template __global__ void kDoubleRowColQuant(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols) { // assumes TILE_SIZE == THREADS*ITEMS_PER_THREAD @@ -3525,17 +3466,20 @@ template __global__ void kgemm_4bit_inferenc __shared__ T quant_map[16]; T local_absmax = T(0.0f); - for(int i = threadIdx.x; i < 16; i++) - quant_map[i] = T(datatype[i]); + if (threadIdx.x < 16) + quant_map[threadIdx.x] = T(__ldg(&datatype[threadIdx.x])); + //for(int i = threadIdx.x; i < 16; i++) + //quant_map[i] = T(__ldg(&datatype[i])); __syncthreads(); // A: [1, K] // B: [N, K] for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += 32*num_values_4bit) { - int inner_idx_halved = inner_idx/2; - int offset_B = ldb*row_B; - int absidx = ((2*offset_B)+inner_idx)/blocksize; + const int inner_idx_halved = inner_idx/2; + const int offset_B = ldb*row_B; + const int absidx = ((2*offset_B)+inner_idx) >> (31 - __clz(blocksize)); + //int absidx = ((2*offset_B)+inner_idx)/blocksize; local_absmax = __ldg(&(absmax[absidx])); if(row_B < M) @@ -3810,7 +3754,7 @@ template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_TURING>( template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_AMPERE>(char *__restrict__ 
const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); -template __global__ void kdequant_mm_int32_fp16<4, 128, 512>(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, float* newRowStats, float* newcolStats, half * __restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n); +template __global__ void kdequant_mm_int32_fp16<4, 512>(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, half * __restrict__ const bias, const int numRows, const int numCols, const int n); template __global__ void kDoubleRowColQuant<64, 4, 16, 64*4, 0>(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols); template __global__ void kDoubleRowColQuant<64, 4, 16, 64*4, 1>(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols); diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh index ec6daebe5..1e094dbd2 100644 --- a/csrc/kernels.cuh +++ b/csrc/kernels.cuh @@ -112,9 +112,9 @@ __global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); -template __global__ void kdequant_mm_int32_fp16( +template __global__ void kdequant_mm_int32_fp16( int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, - half *out, float* newRowStats, float* newcolStats, half * __restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n); + half *out, half * __restrict__ const bias, const int numRows, const int numCols, const int n); template __global__ void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols); template __global__ void kDoubleRowColQuant(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols); diff --git a/csrc/ops.cu b/csrc/ops.cu index 8c72b22b4..f3d349a41 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -584,19 +584,15 @@ int fill_up_to_nearest_multiple(int value, int multiple) return value + (value % multiple == 0 ? 0 : (multiple - (value % multiple))); } -void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, half *bias, int numRows, int numCols) +void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half *bias, int numRows, int numCols) { - int threads = 512; - int tileCols = fill_up_to_nearest_multiple(numCols, 32); - int n = numRows*tileCols; - int subtile_rows = 128; - int tilesize = 32*subtile_rows; - int num_blocks = numRows/subtile_rows; - num_blocks += (numRows % subtile_rows == 0) ? 
0 : 1; - num_blocks = num_blocks*(tileCols/32); - assert(threads <= tilesize); - - kdequant_mm_int32_fp16<4, 128, 512><<>>(A, rowStats, colStats, out, newRowStats, newcolStats, bias, numRows, numCols, tileCols, n); + const int threads = 512; + const int num_per_thread = 4; + const int num_per_block = threads * num_per_thread; + const int n = numRows*numCols; + const int num_blocks = (n + num_per_block - 1) / num_per_block; + + kdequant_mm_int32_fp16<<>>(A, rowStats, colStats, out, bias, numRows, numCols, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } @@ -861,17 +857,10 @@ template void extractOutliers(char * A, int *idx, char *out, int idx template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); -template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); -template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); -template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); -template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); -template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); -template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); template int igemmlt<32, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); template int igemmlt<8, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); template int igemmlt<8, 1>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); - template void transformRowToFormat(char * A, char *out, int rows, int cols); template void transformRowToFormat(char * A, char *out, int rows, int cols); template void transformRowToFormat(char * A, char *out, int rows, int cols); diff --git a/csrc/ops.cuh b/csrc/ops.cuh index ab0185242..9ecb93bf2 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -176,11 +176,10 @@ void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, i long long int strideA, long long int strideB, long long int strideC, int batchCount); template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); -template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); template void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2); void cutlass_igemm(bool transposeA, bool transposeB, 
int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc);
-void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, half* bias, int numRows, int numCols);
+void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half* bias, int numRows, int numCols);
 void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols);
 void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int *nnz_block_ptr, float threshold, int rows, int cols);
diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp
index 09b9b62a9..0034db262 100644
--- a/csrc/pythonInterface.cpp
+++ b/csrc/pythonInterface.cpp
@@ -175,32 +175,15 @@ void transform_row2ampereT(char * A, char *out, int rows, int cols){ transformRo
 void extractOutliers_turing(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers(A, idx, out, idx_size, rows, cols); }
 void extractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers(A, idx, out, idx_size, rows, cols); }
- int igemmlt_32(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) {
- return igemmlt<32, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc);
- }
- int igemmlt_8(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) {
- return igemmlt<8, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc);
- }
- int igemmlt_8_rowscale(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) {
- return igemmlt<8, 1>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc);
- }
- int igemmlt_turing_32(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
- { return igemmlt(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
-
- int igemmlt_turing_8(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
- { return igemmlt(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
-
- int igemmlt_turing_8_rowscale(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
- { return igemmlt(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
-
- int igemmlt_ampere_32(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
- { return igemmlt(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
-
- int igemmlt_ampere_8(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
- { return igemmlt(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
-
- int igemmlt_ampere_8_rowscale(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
- { return igemmlt(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
+int igemmlt_32(cublasLtHandle_t ltHandle, int m, int n, int k, 
const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) { + return igemmlt<32, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); +} +int igemmlt_8(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) { + return igemmlt<8, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); +} +int igemmlt_8_rowscale(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) { + return igemmlt<8, 1>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); +} void spmm_coo_very_sparse_naive_fp16(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB) { spmm_coo_very_sparse_naive(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); } @@ -335,26 +319,6 @@ extern "C" return igemmlt_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } - int cigemmlt_turing_32(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) - { return igemmlt_turing_32((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } - //{ (cublasLtHandle_t)context->m_handle; return 0; } - //{ return 0; }//igemmlt_turing_32((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } - - int cigemmlt_turing_8(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) - { return igemmlt_turing_8((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } - - int cigemmlt_turing_8_rowscale(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) - { return igemmlt_turing_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } - - int cigemmlt_ampere_32(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) - { return igemmlt_ampere_32((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } - - int cigemmlt_ampere_8_rowscale(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) - { return igemmlt_ampere_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } - - int cigemmlt_ampere_8(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) - { return igemmlt_ampere_8((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } - #define MAKE_FUNC_CTRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \ void ctransform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(Context *context, dtype *A, dtype *out, int dim1, int dim2) \ { \ @@ -370,8 +334,8 @@ extern "C" MAKE_FUNC_CTRANSFORM(8, col32, row, n, int8_t, COL32, ROW, false, 8) MAKE_FUNC_CTRANSFORM(32, col32, row, n, int32_t, COL32, ROW, false, 32) - void cdequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, half* bias, int numRows, int numCols) - { 
dequant_mm_int32_fp16(A, rowStats, colStats, out, newRowStats, newcolStats, bias, numRows, numCols); } + void cdequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half* bias, int numRows, int numCols) + { dequant_mm_int32_fp16(A, rowStats, colStats, out, bias, numRows, numCols); } void cget_col_row_stats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols) { getColRowStats(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols); } diff --git a/tests/test_functional.py b/tests/test_functional.py index 522af516c..5052909e7 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -847,45 +847,41 @@ def test_dequant_mm(dim1, dim4, dims, has_bias): @pytest.mark.parametrize("dim1", [1 * 1024], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [1 * 1024], ids=id_formatter("dim2")) @pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims")) -def test_colrow_absmax(dim1, dim2, dims): +@pytest.mark.parametrize("threshold", [0.0, 3.0], ids=id_formatter("decomp")) +def test_colrow_absmax(dim1, dim2, dims, threshold): for i in range(k): - threshold = 3.0 A = torch.randn(dim1, dim2, device="cuda").half() - A_truncated = A.clone() - A_truncated[torch.abs(A_truncated) >= 3.0] = 0.0 - if dims == 2: - row_stats1, _ = torch.abs(A.float()).max(1) - col_stats1, _ = torch.abs(A.float()).max(0) + + assert dims == 2 + + row_stats1, _ = torch.abs(A.float()).max(1) + col_stats1, _ = torch.abs(A.float()).max(0) + + if threshold > 0.0: + A_truncated = A.clone() + A_truncated[torch.abs(A_truncated) >= threshold] = 0.0 row_stats1_trunc, _ = torch.abs(A_truncated.float()).max(1) col_stats1_trunc, _ = torch.abs(A_truncated.float()).max(0) - else: - assert False - row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=threshold) + row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=threshold) - A_blocked = einops.rearrange( - torch.abs(A), - "(rows row_tiles) (cols block_size)-> rows cols row_tiles block_size", - row_tiles=16, - block_size=64 * 4, - ) - nnz_rows1_counts = (torch.abs(A_blocked) >= threshold).sum(3).flatten() - nnz_block_ptr1 = torch.zeros( - nnz_rows1_counts.shape[0] + 1, - dtype=nnz_rows1_counts.dtype, - device=nnz_rows1_counts.device, - ) - nnz_block_ptr1[1:] = nnz_rows1_counts.cumsum(0) - - torch.testing.assert_close(col_stats1_trunc, col_stats2) - torch.testing.assert_close(row_stats1_trunc, row_stats2) - torch.testing.assert_close(nnz_block_ptr1.int(), nnz_block_ptr2) + nnz_rows1_counts = (torch.abs(A) >= threshold).sum(1).flatten() + nnz_block_ptr1 = torch.zeros( + nnz_rows1_counts.shape[0] + 1, + dtype=nnz_rows1_counts.dtype, + device=nnz_rows1_counts.device, + ) + nnz_block_ptr1[1:] = nnz_rows1_counts.cumsum(0) - row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=0.0) + torch.testing.assert_close(col_stats1_trunc, col_stats2) + torch.testing.assert_close(row_stats1_trunc, row_stats2) + torch.testing.assert_close(nnz_block_ptr1.int(), nnz_block_ptr2) + else: + row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=0.0) + assert nnz_block_ptr2 is None torch.testing.assert_close(col_stats1, col_stats2) torch.testing.assert_close(row_stats1, row_stats2) - assert nnz_block_ptr2 is None # @pytest.mark.parametrize("dim1", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim1")) @@ -1480,9 +1476,9 @@ def test_spmm_coo_dequant(dim1, dim2, dtype): ("batch", "seq", "model", "hidden"), [ # pytest.param(1, 128, 6656, 
4 * 6656, id="batch=1, seq=128, model=6656, hidden=26k"),
-        # pytest.param(2, 128, 6656, 4 * 6656, id="batch=2, seq=128, model=6656, hidden=26k"),
+        pytest.param(1, 1, 3584, 512, id="batch=1, seq=1, model=3584, hidden=512"),
         # pytest.param(4, 128, 6656, 4 * 6656, id="batch=4, seq=128, model=6656, hidden=26k"),
-        pytest.param(16, 256, 6656, 4 * 6656, id="batch=16, seq=256, model=6656, hidden=26k")
+        # pytest.param(16, 256, 6656, 4 * 6656, id="batch=16, seq=256, model=6656, hidden=26k")
     ],
 )
 @pytest.mark.benchmark
diff --git a/tests/test_modules.py b/tests/test_modules.py
index d5c968395..7369bb1cf 100644
--- a/tests/test_modules.py
+++ b/tests/test_modules.py
@@ -335,8 +335,8 @@ def test_linear8bitlt_accumulated_gradient():
             loss1.backward()
             loss2.backward()
             if i == 2:
-                assert l1[0].state.CxB is not None
-                assert l1[1].state.CxB is not None
+                assert l1[0].state.CB is not None
+                assert l1[1].state.CB is not None
 
             if i > 0 and i % acc_steps == 0:
                 opt1.step()
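
Reviewer note: a minimal sketch of how the pure-PyTorch pieces touched by this patch compose into one int8 matmul path, assuming CUDA fp16 inputs, that igemmlt(CX, CW) accumulates CX @ CW.T in int32 (the pairing MatMul8bitLt uses with mm_dequant), and that the helper name int8_linear_sketch is purely illustrative, not library API.

    import torch
    import bitsandbytes.functional as F

    def int8_linear_sketch(X: torch.Tensor, W: torch.Tensor, bias: torch.Tensor = None) -> torch.Tensor:
        # Row-wise int8 quantization of activations X [m, k] and weights W [n, k]
        # (outlier/threshold handling omitted; double_quant also returns the
        # column-wise variant and stats, which are unused here).
        CX, _, SCX, _, _ = F.double_quant(X.half())   # CX int8, SCX per-row absmax of X
        CW, _, SCW, _, _ = F.double_quant(W.half())   # CW int8, SCW per-row absmax of W
        # int32 accumulation; Sout is what mm_dequant receives as its quant_state argument.
        out32, Sout = F.igemmlt(CX, CW)
        # out[i, j] = out32[i, j] * SCX[i] * SCW[j] / 127^2 (+ bias[j]), cast to fp16.
        # bias, if given, is expected in fp16.
        return F.mm_dequant(out32, Sout, SCX, SCW, bias=bias)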
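
Reviewer note: kDequantizeBlockwise and kgemm_4bit_inference replace idx / blocksize with idx >> (31 - __clz(blocksize)), which is only valid because blocksize is a power of two; 31 - __clz(x) is floor(log2(x)) for a 32-bit value. A quick host-side check of the equivalence (plain Python, no CUDA needed):

    def absmax_index(idx: int, blocksize: int) -> int:
        # Mirrors the kernel's absmax indexing: idx >> log2(blocksize).
        assert blocksize > 0 and blocksize & (blocksize - 1) == 0, "requires power-of-two blocksize"
        shift = blocksize.bit_length() - 1  # == 31 - __clz(blocksize) for 32-bit ints
        return idx >> shift

    for blocksize in (64, 128, 256, 4096):
        for idx in range(0, 4 * blocksize, 37):
            assert absmax_index(idx, blocksize) == idx // blocksize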
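
Reviewer note: the new double_quant scales by a constant C that is not defined in this hunk; it is assumed here to be the int8 scale 127.0 (consistent with MM_DEQUANT_CONST = 1/(127*127)). A self-contained reference with that scale written out, mirroring get_colrow_absmax plus the rounding in double_quant:

    import torch

    def double_quant_reference(A: torch.Tensor):
        # Row-/column-wise symmetric int8 quantization of a tensor viewed as 2D,
        # with the scale C := 127.0 made explicit (threshold/outlier handling omitted).
        A2 = A.float().view(-1, A.shape[-1])
        row_stats = A2.abs().amax(dim=1)   # per-row absmax
        col_stats = A2.abs().amax(dim=0)   # per-column absmax
        quant_row = torch.round(127.0 * A2 / row_stats.unsqueeze(-1)).to(torch.int8)
        quant_col = torch.round(127.0 * A2 / col_stats.unsqueeze(0)).to(torch.int8)
        return quant_row, quant_col, row_stats, col_stats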