Misaligned read from smem doing TMA store #3602
Generated kernel:
__global__ void nvfuser_none_f0_c0_r0_g0(Tensor<__half, 3, 3> T0, Tensor<__half, 3, 3> T1, const __grid_constant__ TensorMap var0, const __grid_constant__ TensorMap var1, const __grid_constant__ TensorMap var2, Tensor<__half, 2, 2> T3) {
alignas(16) extern __shared__ char array[];
const unsigned smem_offset = 0;
nvfuser_index_t i3;
i3 = ceilDiv(T0.logical_size[2LL], 16);
const TensorMap* ptr4;
ptr4 = &var0;
nvfuser_index_t i5;
i5 = 128 * ((nvfuser_index_t)blockIdx.x);
int i6;
i6 = (int32_t)(i5);
__half* T5 = reinterpret_cast<__half*>(array + smem_offset + 4096);
unsigned i7;
i7 = toSmem(T5);
const TensorMap* ptr8;
ptr8 = &var1;
nvfuser_index_t i9;
i9 = 128 * ((nvfuser_index_t)blockIdx.y);
int i10;
i10 = (int32_t)(i9);
__half* T4 = reinterpret_cast<__half*>(array + smem_offset + 0);
unsigned i11;
i11 = toSmem(T4);
nvfuser_index_t i12;
i12 = ((nvfuser_index_t)threadIdx.y) / 2;
unsigned i13;
i13 = i11 + (2048 * i12);
uint64_t i14;
i14 = 13835058124001640448ULL | ((262143ULL & (uint64_t)(i13)) >> 4ULL);
nvfuser_index_t i15;
i15 = ((nvfuser_index_t)threadIdx.y) % 2;
unsigned i16;
i16 = i7 + (2048 * i15);
uint64_t i17;
i17 = 13835058124001640448ULL | ((262143ULL & (uint64_t)(i16)) >> 4ULL);
nvfuser_index_t i18;
i18 = ((((nvfuser_index_t)threadIdx.x) / 32) * 16) + ((((nvfuser_index_t)threadIdx.x) % 32) % 16);
bool b19;
b19 = ((nvfuser_index_t)threadIdx.x) == 0;
bool b20;
b20 = b19 && (((nvfuser_index_t)threadIdx.y) == 0);
Array<__half, 32, 8> T6;
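// Annotation added for readability (not part of the generated output): T7 below
// is the smem epilogue buffer targeted by the stmatrix writes and the final TMA
// store; its offset 8208 is a multiple of 16 but not of 64, which is the
// misalignment discussed in the comment following the kernel.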
__half* T7 = reinterpret_cast<__half*>(array + smem_offset + 8208);
float T2[32];
((*reinterpret_cast<Array<float, 32, 1>*>(&T2[0]))).set(0);
asm volatile("wgmma.fence.sync.aligned;\n");
asm volatile("fence.proxy.async;\n");
#pragma unroll 1
for(nvfuser_index_t i21 = 0; i21 < i3; ++i21) {
int i22;
i22 = (int32_t)((16 * i21));
Array<int, 2, 1> a23;
a23 = Array<int, 2, 1>{i22, i6};
Array<int, 2, 1> a24;
a24 = Array<int, 2, 1>{i22, i10};
uint64_t* T8 = reinterpret_cast<uint64_t*>(array + smem_offset + 8208);
mbarrier::init(toSmem(T8), 1U);
__syncthreads();
if (b20) {
uint64_t i25;
i25 = mbarrier::arriveExpectTX(toSmem(T8), 4096U);
Hopper::cpAsyncBulkTensorTileG2S((Hopper::CpAsyncBulkTensorTileG2SIndex<2>{ ptr4, a23, toSmem(T8) }), i7);
mbarrier::wait(toSmem(T8), i25);
}
__syncthreads();
mbarrier::inval(toSmem(T8));
uint64_t* T9 = reinterpret_cast<uint64_t*>(array + smem_offset + 8192);
mbarrier::init(toSmem(T9), 1U);
__syncthreads();
if (b20) {
uint64_t i26;
i26 = mbarrier::arriveExpectTX(toSmem(T9), 4096U);
Hopper::cpAsyncBulkTensorTileG2S((Hopper::CpAsyncBulkTensorTileG2SIndex<2>{ ptr8, a24, toSmem(T9) }), i11);
mbarrier::wait(toSmem(T9), i26);
}
__syncthreads();
mbarrier::inval(toSmem(T9));
asm volatile(
"{\n"
" .reg .pred p0; \n"
" setp.ne.b32 p0, %34, 0;\n"
" wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16 {%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31}, %32, %33, p0, %35, %36, %37, %38;\n"
"}\n"
:"+f"((*reinterpret_cast<Array<float, 32, 1>*>(&T2[0]))[0]),
"+f"((*reinterpret_cast<Array<float, 32, 1>*>(&T2[0]))[1]),
"+f"((*reinterpret_cast<Array<float, 32, 1>*>(&T2[0]))[2]),
"+f"((*reinterpret_cast<Array<float, 32, 1>*>(&T2[0]))[3]),
"+f"((*reinterpret_cast<Array<float, 32, 1>*>(&T2[0]))[4]),
"+f"((*reinterpret_cast<Array<float, 32, 1>*>(&T2[0]))[5]),
"+f"((*reinterpret_cast<Array<float, 32, 1>*>(&T2[0]))[6]),
"+f"((*reinterpret_cast<Array<float, 32, 1>*>(&T2[0]))[7]),
"+f"((*reinterpret_cast<Array<float, 32, 1>*>(&T2[0]))[8]),
"+f"((*reinterpret_cast<Array<float, 32, 1>*>(&T2[0]))[9]),
"+f"((*reinterpret_cast<Array<float, 32, 1>*>(&T2[0]))[10]),
"+f"((*reinterpret_cast<Array<float, 32, 1>*>(&T2[0]))[11]),
"+f"((*reinterpret_cast<Array<float, 32, 1>*>(&T2[0]))[12]),
"+f"((*reinterpret_cast<Array<float, 32, 1>*>(&T2[0]))[13]),
"+f"((*reinterpret_cast<Array<float, 32, 1>*>(&T2[0]))[14]),
"+f"((*reinterpret_cast<Array<float, 32, 1>*>(&T2[0]))[15]),
"+f"((*reinterpret_cast<Array<float, 32, 1>*>(&T2[0]))[16]),
"+f"((*reinterpret_cast<Array<float, 32, 1>*>(&T2[0]))[17]),
"+f"((*reinterpret_cast<Array<float, 32, 1>*>(&T2[0]))[18]),
"+f"((*reinterpret_cast<Array<float, 32, 1>*>(&T2[0]))[19]),
"+f"((*reinterpret_cast<Array<float, 32, 1>*>(&T2[0]))[20]),
"+f"((*reinterpret_cast<Array<float, 32, 1>*>(&T2[0]))[21]),
"+f"((*reinterpret_cast<Array<float, 32, 1>*>(&T2[0]))[22]),
"+f"((*reinterpret_cast<Array<float, 32, 1>*>(&T2[0]))[23]),
"+f"((*reinterpret_cast<Array<float, 32, 1>*>(&T2[0]))[24]),
"+f"((*reinterpret_cast<Array<float, 32, 1>*>(&T2[0]))[25]),
"+f"((*reinterpret_cast<Array<float, 32, 1>*>(&T2[0]))[26]),
"+f"((*reinterpret_cast<Array<float, 32, 1>*>(&T2[0]))[27]),
"+f"((*reinterpret_cast<Array<float, 32, 1>*>(&T2[0]))[28]),
"+f"((*reinterpret_cast<Array<float, 32, 1>*>(&T2[0]))[29]),
"+f"((*reinterpret_cast<Array<float, 32, 1>*>(&T2[0]))[30]),
"+f"((*reinterpret_cast<Array<float, 32, 1>*>(&T2[0]))[31])
:"l"(i14),
"l"(i17),
"n"((uint32_t)(true)),
"n"(1),
"n"(1),
"n"(0),
"n"(0)
);
asm volatile("wgmma.commit_group.sync.aligned;\n");
asm volatile("wgmma.wait_group.sync.aligned %0;\n"::"n"(0LL):"memory");
}
asm volatile("wgmma.commit_group.sync.aligned;\n");
asm volatile("wgmma.wait_group.sync.aligned %0;\n"::"n"(0LL):"memory");
#pragma unroll
for(nvfuser_index_t i27 = 0; i27 < 8; ++i27) {
nvfuser_index_t i28;
i28 = 4 * i27;
#pragma unroll
for(nvfuser_index_t i29 = 0; i29 < 2; ++i29) {
nvfuser_index_t i30;
i30 = i28 + (2 * i29);
#pragma unroll
for(nvfuser_index_t i31 = 0; i31 < 2; ++i31) {
nvfuser_index_t i32;
i32 = i30 + i31;
T6[i32]
= __float2half(T2[i32]);
}
}
}
#pragma unroll
for(nvfuser_index_t i33 = 0; i33 < 4; ++i33) {
asm volatile(
"stmatrix.sync.aligned.x4.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n"
:
:"r"((uint32_t)((toSmem(T7) + ((((nvfuser_index_t)threadIdx.y) * 8192) + (((i33 / 4) * 8192) + ((i18 * 128) + (((((((nvfuser_index_t)threadIdx.x) % 32) / 16) + ((i33 % 4) * 2)) ^ (i18 % 8)) * 16))))))),
"r"((*reinterpret_cast<Array<uint32_t, 4, 1>*>(&T6[(8 * i33)]))[0]),
"r"((*reinterpret_cast<Array<uint32_t, 4, 1>*>(&T6[(8 * i33)]))[1]),
"r"((*reinterpret_cast<Array<uint32_t, 4, 1>*>(&T6[(8 * i33)]))[2]),
"r"((*reinterpret_cast<Array<uint32_t, 4, 1>*>(&T6[(8 * i33)]))[3])
);
}
__syncthreads();
asm volatile("fence.proxy.async;\n");
if (b19) {
Hopper::cpAsyncBulkTensorTileS2G((Hopper::CpAsyncBulkTensorTileS2GIndex<2>{ (&var2), (Array<int, 2, 1>{(int32_t)(((64 * i15) + i5)), (int32_t)(((64 * i12) + i9))}) }), (toSmem(T7) + (8192 * ((nvfuser_index_t)threadIdx.y))));
}
asm volatile("cp.async.bulk.commit_group;\n");
asm volatile("cp.async.bulk.wait_group.read %0;\n"::"n"(0LL):"memory");
}
I think the issue might be that we need to promote memory reuse; otherwise the epilogue smem buffer could end up at an address that is not divisible by 64, since it might be placed on top of an mbarrier.
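To make the arithmetic concrete, here is a small standalone sketch (not nvFuser's actual allocator; the sizes are read off the kernel dump above and the alignUp helper is just illustrative) of why stacking the epilogue buffer after the mbarrier storage produces the 8208-byte offset seen above, and how reusing the 128-byte-aligned prologue region, or padding up to the required alignment, would avoid it:

// Standalone illustration, not nvFuser code. Sizes mirror the kernel above:
// two 4096 B operand buffers (T4, T5) followed by 16 B of mbarrier storage
// (offset 8192 through 8208 in the dump).
#include <cstdint>
#include <cstdio>

constexpr int64_t alignUp(int64_t offset, int64_t alignment) {
  return (offset + alignment - 1) / alignment * alignment;
}

int main() {
  constexpr int64_t operandBytes  = 2 * 4096;  // T4 + T5
  constexpr int64_t mbarrierBytes = 16;        // mbarrier storage preceding T7

  // Stacking the epilogue buffer right after the mbarriers gives 8208, which
  // is 16 B aligned but not 64 B aligned -- the situation in the repro.
  int64_t stacked = operandBytes + mbarrierBytes;

  // Reusing the prologue operand region would put the epilogue buffer back at
  // a 128 B aligned offset; explicit padding would achieve the same alignment.
  int64_t reused = 0;
  int64_t padded = alignUp(stacked, 128);

  std::printf("stacked=%lld (mod 64 = %lld), reused=%lld, padded=%lld\n",
              (long long)stacked, (long long)(stacked % 64),
              (long long)reused, (long long)padded);
  return 0;
}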
jacobhinkle added a commit that referenced this issue on Jan 8, 2025:
This updates the default (non-plugin) matmul heuristic to support Hopper matmuls. This change means that we can now run matmuls on Hopper similarly to how we do it on Ampere and Turing, including using the Python interface. I tried to make the default heuristic somewhat thoughtful and not just a placeholder. Here are some notes about the Hopper heuristic in its current form:
- I set the macro to Hopper_64_64_16. I intended to always use the largest macro for which the N size divided the problem's N, but this led to lower perf on the handful of examples I looked at. We should benchmark more and find out why this is once we have warp specialization and register stealing fully plumbed in, but for the time being I simply left it at N=64.
- Once the instruction tile is set we set the warp tile equal to the instruction tile (we can revisit this in the future). Then to find the CTA tile we double the instruction tile in the M or N dimension until we run out of registers.
- We start with 8 circular buffering stages and decrease until the circular buffers fit into smem.
- We use `use_smem_epilogue` when possible. Whenever that is possible we _always_ use `promote_prologue_smem_reuse` even if it's not needed. This is to try and avoid bugs like #3602.
- I set the tile rasterization order so that the fast axis is the axis with the fewest tiles, which should encourage more L2 hits unless there are tons of tiles in each dimension.
- I cannot yet set grid swizzling due to #3671, but I placed a TODO comment and some code to do the proper swizzling.
---------
Co-authored-by: Ryan Spring <[email protected]>
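A rough, self-contained sketch of the tile-growing and stage-shrinking procedure described in those notes follows; the cost models, budgets, and names here are made up for illustration and are not the scheduler's actual API:

// Illustrative only: grow the CTA tile from the Hopper_64_64_16 macro until
// the accumulator no longer fits the register budget, then shrink the number
// of circular-buffering stages until the operand buffers fit in smem.
#include <cstdio>

struct Tile { long m, n, k; };

// Hypothetical cost models (not nvFuser's): fp32 accumulator registers per
// thread of a 128-thread warpgroup, and fp16 A/B operand bytes per stage.
long accumRegsPerThread(Tile cta) { return cta.m * cta.n / 128; }
long operandSmemBytes(Tile cta, int stages) {
  return (cta.m + cta.n) * cta.k * 2 /*bytes per half*/ * stages;
}

int main() {
  const long kRegsPerThreadBudget = 255;        // architectural per-thread limit
  const long kSmemBudgetBytes     = 228 * 1024; // approx. Hopper smem per CTA

  Tile cta{64, 64, 16};  // start from the Hopper_64_64_16 instruction/warp tile
  while (true) {
    Tile grownM{cta.m * 2, cta.n, cta.k};
    Tile grownN{cta.m, cta.n * 2, cta.k};
    if (accumRegsPerThread(grownM) <= kRegsPerThreadBudget)      cta = grownM;
    else if (accumRegsPerThread(grownN) <= kRegsPerThreadBudget) cta = grownN;
    else break;  // out of registers in both directions
  }

  int stages = 8;  // start with 8 circular-buffering stages and back off
  while (stages > 1 && operandSmemBytes(cta, stages) > kSmemBudgetBytes) {
    --stages;
  }

  std::printf("CTA tile %ldx%ldx%ld with %d circular-buffering stages\n",
              cta.m, cta.n, cta.k, stages);
  return 0;
}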
The following repro hits a misaligned read from shared memory during the TMA instruction:
The error is
cc @rdspring1 @protonu