SM75 (Turing) support for FP6 kernel #942

Merged: 5 commits, Sep 29, 2024
Changes from 1 commit
2 changes: 1 addition & 1 deletion benchmarks/benchmark_fp6.py
@@ -7,7 +7,7 @@
from tqdm import tqdm


def benchmark(m: int, k: int, n: int):
def benchmark(m: int, n: int, k: int):
Contributor

Is this change intentional? We usually talk about matmuls as m/k/n, i.e. an m x k activation and a k x n weight. It seems odd to reverse them, and I'm not sure whether the benchmarks previously assumed the other ordering.
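
A minimal sketch of that m/k/n convention, using arbitrary example shapes rather than anything from the benchmark:

import torch

# m x k activation times k x n weight gives an m x n output.
# Shapes here are arbitrary examples.
m, k, n = 4, 64, 128
activation = torch.randn(m, k)   # m x k
weight = torch.randn(k, n)       # k x n
output = activation @ weight     # m x n
assert output.shape == (m, n)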

Collaborator

Looking at this again, I indeed made a mistake. The benchmark results are still correct; it's just that the list of shapes is different.

I took the benchmark shapes from here: https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/tests/python/run.sh. It's a bit confusing since the author uses different variable names...

The code generating the list of shapes (under __name__ == "__main__") is correct (it follows the author's), and it calls benchmark(m, n, k). If you think we should benchmark a different set of shapes, that should be fine too!

In summary, this change corrects my previous mistake.

Contributor Author

Yes, it's intentional. The function signature was def benchmark(m: int, k: int, n: int):, but the arguments were passed as (m, n, k), so I thought that was unnecessarily confusing and wanted to change the ordering in either the function call or the function signature. In the function itself, the shapes become m x k for the activation and n x k for the weight.

I see one benchmark example (benchmark_gpu_sparsity, see below) where the ordering is m, k, n, so let me change the function signature back to that ordering for consistency.

def run_gpu_sparse_benchmark(m, k, n, args):
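
A minimal sketch of the mismatch being described, with assumed example shapes (this is not the actual benchmark file):

def benchmark(m: int, k: int, n: int):
    # inside the function: activation is (m, k), weight is (n, k)
    ...

if __name__ == "__main__":
    for m, n, k in [(1, 8192, 8192), (1, 10240, 8192)]:  # assumed example shapes
        benchmark(m, n, k)  # positional call: n lands in the k slot and k in the n slot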

Contributor Author

Missed your comment before I posted my own, @gau-nernst. Thanks for clarifying! I actually noticed that m gets passed as n to the actual kernel, which is slightly confusing. If you don't mind, I'll change this for consistency. I don't think it should affect the results, except that m will be swapped with n in the performance table.

@gau-nernst (Collaborator), Sep 26, 2024

@tobiasvanderwerff You mean k and n, right? Your current change looks correct. Yeah, it doesn't affect the results; it will only show results for different shapes instead.

@tobiasvanderwerff (Contributor Author), Sep 26, 2024

What I meant is slightly different, @gau-nernst. I'm referring to the fact that the original authors do some odd switching of the shapes in fp6_linear.cu. The arguments that get passed are _in_feats (activations of shape m x k) and _weights (shape n x k), but then they unpack the shapes as M = _weights.size(0), K = _in_feats.size(1), N = _in_feats.size(0) (see below).

int num_in_feats = _in_feats.size(0);
int num_in_channels = _in_feats.size(1);
int num_out_channels = _weights.size(0);
TORCH_CHECK(num_in_channels % 64 == 0, "Expected in_features to be a multiple of 64, but received ", num_in_channels);
TORCH_CHECK((num_in_channels / 8 * NBITS) == _weights.size(1)); // Making sure the K dimension is matched.
//
int M = num_out_channels;
int K = num_in_channels;
int N = num_in_feats;

So even though we pass the arguments correctly to the benchmark function as m, k, n, the names get switched inside the kernel. Anyway, this is mainly confusing when debugging the kernel, but it might actually be fine to just leave it as is.
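
To make the renaming concrete, here is a small sketch that mirrors the C++ snippet above (the concrete sizes are assumed examples):

# The benchmark side passes activation (m, k) and weight (n, k).
m, k, n = 1, 8192, 8192   # assumed example sizes
# Inside fp6_linear.cu the same three sizes reappear under different names:
M = n   # num_out_channels = _weights.size(0)
K = k   # num_in_channels  = _in_feats.size(1)
N = m   # num_in_feats     = _in_feats.size(0)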

float_data = torch.randn(n, k, dtype=torch.half, device="cuda")
fp6_weight = to_affine_quantized_fpx(float_data, FloatxTensorCoreLayoutType(3, 2))
fp16_weight = fp6_weight.dequantize(torch.half)
2 changes: 1 addition & 1 deletion torchao/csrc/cuda/fp6_llm/fp6_linear.cu
@@ -14,7 +14,7 @@
//
// This file is adapted from https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/fp6_linear.cu

#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 // at least Ampere
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 750 // at least Turing

#include "kernel_matmul.cuh"
#include "kernel_reduction.cuh"
6 changes: 6 additions & 0 deletions torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh
@@ -140,7 +140,9 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales,
for(int j=0; j<REG_PER_THREAD_C_TENSOR_16_16; j++)
c[i][j] = 0.0f;
//
#if __CUDA_ARCH__ >= 800
cp_async_wait_all();
#endif
__syncthreads();

/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -175,12 +177,16 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales,
if(USE_SEG_4BIT) CopyFromGlobalToShared_A<SMEM_SIZE_PER_WARP_4BIT>(write_SPTR_Frag_4bit, WARP_StartGPTR_A_4BIT, GlobalCopy);
// copying B tile from GlobalMemory to SharedMemory
CopyFromGlobalToShared<TilingConfig::TILE_N, TilingConfig::BLOCK_WARPS> (write_SPTR, BTile_GPTR, K_Global, NumColumnToCopy, GlobalCopy);
#if __CUDA_ARCH__ >= 800
cp_async_group_commit();
#endif
core_mma_slice<TilingConfig, EXPONENT, MANTISSA>(c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit, read_SPTR, Scales_RPTR, 1); // read_SPTR_Frag_2bit, read_SPTR_Frag_4bit are different for each WARP; read_SPTR is shared among WARPs
core_mma_slice<TilingConfig, EXPONENT, MANTISSA>(c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit, read_SPTR, Scales_RPTR, 2);
core_mma_slice<TilingConfig, EXPONENT, MANTISSA>(c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit, read_SPTR, Scales_RPTR, 3);
// Barriers and Synchronizations
#if __CUDA_ARCH__ >= 800
cp_async_wait_group<PIPELINE_LEVEL_GMEM-2>();
#endif
__syncthreads();
core_mma_slice<TilingConfig, EXPONENT, MANTISSA>(c, a, b, read2_SPTR_Frag_1bit, read2_SPTR_Frag_2bit, read2_SPTR_Frag_4bit, read2_SPTR, Scales_RPTR, 0);
// Updating global PTRs
31 changes: 31 additions & 0 deletions torchao/csrc/cuda/fp6_llm/ptx_mma.cuh
@@ -55,6 +55,14 @@ __device__ __forceinline__ void B_FromSharedToReg(uint32_t (* __restrict__ Reg)[
assert( warp_start_col==0 );
#endif

#if __CUDA_ARCH__ == 750
if (TilingConfig::WARP_COL_MMA_TENSORS==1) {
// For .target sm_75, all threads must contain valid addresses for the 'ldmatrix' op. below. Otherwise, the behavior is undefined.
// See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-load-instruction-ldmatrix
// To avoid this, we make threads 16-31 point to the same smem addresses as threads 0-15 by changing the lane id.
lane_id = lane_id % 16;
}
#endif
int col = (lane_id%8) + (lane_id/16)*8;
int row = (lane_id%16) / 8 * 8;
uint32_t smem_local_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&read_SPTR[warp_start_col+col][slice_id*MMA_16 + row]));
@@ -80,6 +88,28 @@ __device__ __forceinline__ void B_FromSharedToReg(uint32_t (* __restrict__ Reg)[
__device__ __forceinline__ void
MMA_FP16_M16N8K16(uint32_t * __restrict__ c, uint32_t * __restrict__ a, uint32_t * __restrict__ b)
{
#if __CUDA_ARCH__ == 750
// m16n8k16 op. requires >=sm_80, so instead we use two m16n8k8 ops.
asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{ %0, %1, %2, %3},"
"{ %4, %5},"
"{ %6 },"
"{ %7, %8, %9, %10 };"
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
: "r"(a[0]), "r"(a[1]),
"r"(b[0]),
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{ %0, %1, %2, %3},"
"{ %4, %5},"
"{ %6 },"
"{ %7, %8, %9, %10 };"
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
: "r"(a[2]), "r"(a[3]),
"r"(b[1]),
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));

#else
asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{ %0, %1, %2, %3},"
"{ %4, %5, %6, %7 },"
@@ -89,6 +119,7 @@ MMA_FP16_M16N8K16(uint32_t * __restrict__ c, uint32_t * __restrict__ a, uint32_t
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]),
"r"(b[0]), "r"(b[1]),
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
#endif
}

#endif
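
The sm_75 path above replaces the single m16n8k16 MMA with two m16n8k8 MMAs that accumulate into the same output registers. As a quick sanity check of why that is equivalent (a PyTorch sketch, not part of the patch): splitting the K dimension in half and summing the two partial products reproduces the full product.

import torch

# Splitting K=16 into two halves of 8 and summing the partial products
# matches the full m16n8k16 product, which is what the two m16n8k8 MMAs do.
torch.manual_seed(0)
A = torch.randn(16, 16, dtype=torch.half)  # 16 x 16 operand (m x k)
B = torch.randn(16, 8, dtype=torch.half)   # 16 x 8 operand  (k x n)
full = A.float() @ B.float()
split = A[:, :8].float() @ B[:8].float() + A[:, 8:].float() @ B[8:].float()
assert torch.allclose(full, split, atol=1e-3)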
17 changes: 16 additions & 1 deletion torchao/csrc/cuda/fp6_llm/utils_gmem.cuh
@@ -39,7 +39,15 @@ __device__ __forceinline__ void CopyFromGlobalToShared_A(uint32_t* SPTR,
GPTR_HALF += lane_id*8;
#pragma unroll
for(int i=0; i<SMEM_SIZE_IN_BYTES_PER_WARP/WARP_SIZE/16; i++) {
#if __CUDA_ARCH__ == 750
if (pred_guard) {
float4* SPTR_VEC = reinterpret_cast<float4*>(SPTR_HALF);
const float4* GPTR_VEC = reinterpret_cast<const float4*>(GPTR_HALF);
SPTR_VEC[0] = GPTR_VEC[0];
}
#else
cp_async<16>( SPTR_HALF, GPTR_HALF, pred_guard);
#endif
SPTR_HALF += 256; // Forward 512 Bytes
GPTR_HALF += 256; // Forward 512 Bytes
}
@@ -82,8 +90,15 @@ __device__ __forceinline__ void CopyFromGlobalToShared(half (* __restrict__ Shar
#pragma unroll
for (int i = 0; i < MaxIteration; i++) {
bool AsyncCopyPred = (line_id+i*NumOfGroups) < NumOfLinesLeft && Pred;
#if __CUDA_ARCH__ == 750
if (AsyncCopyPred) {
float4* SharedPtrVec = reinterpret_cast<float4*>(&(*SharedPTR)[line_offset]);
const float4* GlobalPtrVec = reinterpret_cast<const float4*>(GlobalPTR);
SharedPtrVec[0] = GlobalPtrVec[0];
}
#else
cp_async<16>( &(*SharedPTR)[line_offset], GlobalPTR, AsyncCopyPred);
//
#endif
GlobalPTR += NumOfGroups * GlobalStride;
SharedPTR += NumOfGroups;
}