-
Notifications
You must be signed in to change notification settings - Fork 29
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[tune gemm v3.4] Add xcd-based pid remapping and change back to rocprofv1 #630
Conversation
- set --iters=200 as default. This is enough since the time is stable after the first few runs. - Filter out kernel time that is too large. We use the first kernel time as the threshold. There must be something wrong with the kernel if its elapsedTime is larger than the first run. We need to investigate the reason. For now, just filter them out.
@@ -54,7 +54,7 @@ def get_full_tuning_space(): | |||
block_k_range = [16, 32, 64, 128, 256] | |||
split_k_range = [1, 2, 4, 5, 6, 8, 10, 12, 16, 18, 24] | |||
num_warps_range = [1, 2, 4, 8] | |||
group_m_range = [1, 4, 8, 16, 32] | |||
group_m_range = [1, 2, 4, 8, 16] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
any reason why we get rid of 2, and 32? If we can't enable XCD mapping, we may need 32.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let me add XCD mapping and re-think about the range for group_m then.
I think we also need enable irregular shapes tuning by removing below two lines |
c550c5b
to
907605a
Compare
@@ -19,8 +30,9 @@ def matmul_kernel(a_ptr, b_ptr, c_ptr, bias_ptr, M, N, K, stride_am, stride_ak, | |||
group_id = pid // num_pid_in_group | |||
first_pid_m = group_id * GROUP_SIZE_M | |||
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) | |||
pid_m = first_pid_m + (pid % group_size_m) | |||
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this change make an impact?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not really. This affects the swizzling in the very last group when M % GROUP_SIZE_M !=0, which is not usually in our tuning space.
@@ -157,7 +157,7 @@ def prune_configs(M, N, K, configs, elemBytes_a, elemBytes_b): | |||
if num_warps < 4: | |||
continue | |||
# check if tiling is integer multiple of GEMM size because we have no boundary check | |||
if M % BLOCK_SIZE_M != 0 or N % BLOCK_SIZE_N != 0 or K % BLOCK_SIZE_K != 0: | |||
if M % BLOCK_SIZE_M != 0 or N % BLOCK_SIZE_N != 0: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
M, N could be irregular as well ?
@@ -366,11 +358,12 @@ def matmul(a, b, c, bias, block_m, block_n, block_k, group_m, split_k, num_warps | |||
grid = triton.cdiv(M, block_m) * triton.cdiv(N, block_n), split_k | |||
stride_bias = bias.stride(0) if use_bias else 0 | |||
EVEN_K = K % block_k == 0 | |||
num_xcds = 1 if split_k > 1 else 8 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
XCD = 8 is only applicable to MI300X ?
…ofv1 (#630) * Change to rocprofv1 * improve post processing of rocprof results - set --iters=200 as default. This is enough since the time is stable after the first few runs. - Filter out kernel time that is too large. We use the first kernel time as the threshold. There must be something wrong with the kernel if its elapsedTime is larger than the first run. We need to investigate the reason. For now, just filter them out. * Add xcd-based pid remapping * Enable EVEN_K=false for large gemms * Update readme
This PR reverts #613 since there is a severe problem with rocprofv2 described in ticket#228.
The problem is that rocprofv2 will "miss" a lot of kernels in the tuning space. Therefore, sub-optimal config is picked.
We will switch back to rocprofv2 when the issue is resolved.
This PR also enabled xcd-based pid remapping. I need to run more experiments to understand the effects of xcd-based remapping and group_size_m (as described in ticket#229).
To disable xcd-based remapping, change this line from
to