forked from triton-lang/triton
[tune gemm v3.4] Add xcd-based pid remapping and change back to rocprofv1 #630
Merged
Changes from all commits · 5 commits
```diff
@@ -54,7 +54,7 @@ def get_full_tuning_space():
     block_k_range = [16, 32, 64, 128, 256]
     split_k_range = [1, 2, 4, 5, 6, 8, 10, 12, 16, 18, 24]
     num_warps_range = [1, 2, 4, 8]
-    group_m_range = [1, 4, 8, 16, 32]
+    group_m_range = [1, 2, 4, 8, 16, 32]
     # For now we see better perf with num_stages=0 for all gemm configs we care
     # But keep this explicit so that we do not forget we may need to set it to
     # other values in the future
```
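The only change here adds 2 to `group_m_range`, widening the `GROUP_SIZE_M` search space. For context, `GROUP_SIZE_M` controls the grouped-ordering swizzle Triton GEMM kernels use for L2 reuse; a minimal sketch of that standard mapping (after the Triton matmul tutorial, not necessarily this repo's exact kernel code):

```python
import triton
import triton.language as tl


@triton.jit
def grouped_pid(pid, M, N, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
                GROUP_SIZE_M: tl.constexpr):
    # Map a linear program id to (pid_m, pid_n) so that GROUP_SIZE_M rows of
    # tiles are walked column-major before moving on; tiles that share A/B
    # blocks then run close together in time, improving L2 reuse.
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    # The last group is smaller when num_pid_m % GROUP_SIZE_M != 0; this is
    # the corner case mentioned in the review discussion at the bottom.
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    return pid_m, pid_n
```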
```diff
@@ -157,7 +157,7 @@ def prune_configs(M, N, K, configs, elemBytes_a, elemBytes_b):
         if num_warps < 4:
             continue
         # check if tiling is integer multiple of GEMM size because we have no boundary check
-        if M % BLOCK_SIZE_M != 0 or N % BLOCK_SIZE_N != 0 or K % BLOCK_SIZE_K != 0:
+        if M % BLOCK_SIZE_M != 0 or N % BLOCK_SIZE_N != 0:
             continue
         pruned_configs.append(config)
```

Review comment on the changed line: M, N could be irregular as well?
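Note the `K % BLOCK_SIZE_K` condition is dropped here, consistent with the `matmul()` change further down that passes `EVEN_K` so uneven K can be masked in the kernel. A standalone illustration of the remaining predicate (hypothetical sizes, not from the PR):

```python
def tile_divides(M, N, BLOCK_SIZE_M, BLOCK_SIZE_N):
    # No boundary check in the M/N dimensions, so tiles must divide evenly.
    return M % BLOCK_SIZE_M == 0 and N % BLOCK_SIZE_N == 0


print(tile_divides(4096, 4096, 128, 256))  # True: config is kept
print(tile_divides(4100, 4096, 128, 256))  # False: 4100 % 128 == 4, pruned
```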
```diff
@@ -169,20 +169,15 @@ def need_split_k(SIZE_M, SIZE_N, SIZE_K):
     return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024


-def extract_kernel_time(M, N, K, config, df, bias_size):
-    # Correct the header by removing 'sig' and 'obj' to reduce number from 21 to 19
-    # once the bug(https://github.com/ROCm/rocprofiler/issues/144) fixed, we should
-    # not need below two lines
-    cols = [
-        'Index', 'KernelName', 'gpu-id', 'queue-id', 'queue-index', 'pid', 'tid', 'grd', 'wgr', 'lds', 'scr',
-        'arch_vgpr', 'accum_vgpr', 'sgpr', 'wave_size', 'DispatchNs', 'BeginNs', 'EndNs', 'CompleteNs'
-    ]
-    df.columns = cols
-    configStr = gen_configStr(config)
-    filtered_df = df[df['KernelName'].str.contains(configStr, na=False)].copy()
-    filtered_df['DurationNs'] = filtered_df['EndNs'] - filtered_df['BeginNs']
-    meanTime = filtered_df['DurationNs'].tail(100).mean()
-    return config, meanTime
+def extract_kernel_time(M, N, K, config, df):
+    configStr = gen_configStr(config)
+    df = df[df['KernelName'].str.contains(configStr)]
+
+    first_value = df['DurationNs'].iloc[0]
+    filtered_data = df['DurationNs'][df['DurationNs'] <= first_value]
+    new_meanTime = filtered_data.tail(100).mean()
+
+    return config, new_meanTime


 def profile_batch_kernels(M, N, K, gpuid, gpus, jobs, verbose):
```
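The rewritten `extract_kernel_time` no longer patches rocprofiler's broken CSV header or computes durations from `BeginNs`/`EndNs`; it reads `DurationNs` directly, drops samples slower than the first recorded dispatch as outliers, and averages the most recent 100 that remain. A toy illustration of that filtering (made-up numbers, not PR code):

```python
import pandas as pd

# Made-up timings: the third dispatch (5000 ns) is a slow outlier.
df = pd.DataFrame({
    'KernelName': ['matmul_kernel'] * 5,
    'DurationNs': [1000, 950, 5000, 980, 990],
})

first_value = df['DurationNs'].iloc[0]
filtered = df['DurationNs'][df['DurationNs'] <= first_value]
mean_time = filtered.tail(100).mean()
print(mean_time)  # 980.0 == (1000 + 950 + 980 + 990) / 4
```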
```diff
@@ -197,7 +192,7 @@ def profile_batch_kernels(M, N, K, gpuid, gpus, jobs, verbose):
         if verbose:
             print(f"profiling {kernel_name} on GPU {gpuid}")
         run_bash_command_wrapper(
-            f"rocprofv2 --plugin file --plugin-version 1 --kernel-trace -o {jobId} python {get_filename_profile_driver(M, N, K, jobId)}",
+            f"rocprof --stats -o results_{jobId}.csv python {get_filename_profile_driver(M, N, K, jobId)}",
            capture=(verbose < 2))
         jobId += ngpus
```
```diff
@@ -244,13 +239,10 @@ def tune_gemm_config(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type
     thread_pool = multiprocessing.Pool(processes=num_threads)
     tasks = []
     idx = 0
-    df_prof = [
-        pd.read_csv(f"results_{i}.csv", skiprows=1, header=None, delimiter=',', quotechar='"', escapechar='\\')
-        for i in range(jobs)
-    ]
+    df_prof = [pd.read_csv(f"results_{i}.csv") for i in range(jobs)]
     for config in configs:
         file_idx = idx % jobs
-        tasks += [thread_pool.apply_async(extract_kernel_time, args=(M, N, K, config, df_prof[file_idx], bias_size))]
+        tasks += [thread_pool.apply_async(extract_kernel_time, args=(M, N, K, config, df_prof[file_idx]))]
         idx += 1
     thread_pool.close()
     thread_pool.join()
```
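For reference, `apply_async` returns `AsyncResult` handles, so the `(config, meanTime)` pairs are presumably fetched with `.get()` after `join()`, along these lines (toy worker, not the PR's code):

```python
import multiprocessing


def work(x):
    # Stands in for extract_kernel_time returning (config, meanTime).
    return x, x * 2


if __name__ == '__main__':
    pool = multiprocessing.Pool(processes=4)
    tasks = [pool.apply_async(work, args=(i,)) for i in range(8)]
    pool.close()
    pool.join()
    results = [t.get() for t in tasks]  # list of (config, time) pairs
    print(results)
```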
```diff
@@ -366,11 +358,12 @@ def matmul(a, b, c, bias, block_m, block_n, block_k, group_m, split_k, num_warps
     grid = triton.cdiv(M, block_m) * triton.cdiv(N, block_n), split_k
     stride_bias = bias.stride(0) if use_bias else 0
     EVEN_K = K % block_k == 0
+    num_xcds = 1 if split_k > 1 else 8
     matmul_kernel[grid](a, b, c, bias, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0),
                         c.stride(1), stride_bias=stride_bias, BLOCK_SIZE_M=block_m, BLOCK_SIZE_N=block_n,
                         BLOCK_SIZE_K=block_k, GROUP_SIZE_M=group_m, SPLIT_K=split_k, num_warps=num_warps,
                         num_stages=num_stages, waves_per_eu=waves_per_eu, matrix_instr_nonkdim=mfmaInstrSize,
-                        kpack=kpack, BIAS=use_bias, EVEN_K=EVEN_K)
+                        kpack=kpack, BIAS=use_bias, EVEN_K=EVEN_K, GRID_MN=grid[0], NUM_XCDS=num_xcds)
     return c
```

Review comment on the `num_xcds` line: XCD = 8 is only applicable to MI300X?
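This is the xcd-based pid remapping from the PR title, driven from the host by the new `GRID_MN` and `NUM_XCDS` kernel parameters (remapping is disabled for split-k by setting `num_xcds = 1`). The kernel-side remap itself is not shown in this excerpt; a hedged sketch of the usual scheme, assuming `GRID_MN` divides evenly by `NUM_XCDS`:

```python
import triton
import triton.language as tl


@triton.jit
def remap_xcd(pid, GRID_MN, NUM_XCDS: tl.constexpr):
    # Workgroups are dispatched round-robin over the chiplets, so
    # pid % NUM_XCDS is the XCD a workgroup lands on and pid // NUM_XCDS is
    # its rank within that XCD. Reassigning each XCD a contiguous range of
    # tile indices keeps neighboring tiles in the same chiplet's L2.
    # With NUM_XCDS == 1 this reduces to the identity mapping.
    pids_per_xcd = GRID_MN // NUM_XCDS
    xcd = pid % NUM_XCDS
    local_pid = pid // NUM_XCDS
    return xcd * pids_per_xcd + local_pid
```

The reviewer's question stands: 8 is the XCD count of MI300X specifically, so other parts would need a different value.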
```diff
@@ -441,7 +434,7 @@ def parse_args():
     parser.add_argument("--num_threads", type=int, default=32,
                         help="number of threads to use for kernel compilation and post processing")
     parser.add_argument("--jobs", type=int, default=1, help="number of tasks during the profiling process")
-    parser.add_argument("--iters", type=int, default=1000, help="number of iterations used in --benchmark mode")
+    parser.add_argument("--iters", type=int, default=200, help="number of iterations used in --benchmark mode")
     parser.add_argument("--init_type", type=str, default='randn', choices=['randn', 'hpl', 'trig_float', 'zeros'],
                         help="Input tensor initialization (default normal distribution)")
     parser.add_argument(
```
Review comment: Does this change make an impact?

Reply: Not really. This affects the swizzling in the very last group when M % GROUP_SIZE_M != 0, which is not usually in our tuning space.