Skip to content

Commit

Permalink
Unrolled the Global Memory write loops
Browse files Browse the repository at this point in the history
  • Loading branch information
muhammad-tanvir-1211 committed Jan 5, 2024
1 parent d1beeff commit 774286d
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/operations/blas3/gemm_local_joint_matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,7 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, TileType,
constexpr index_t nc_conditional =
frags_per_sg > 1 ? tile_type::joint_matrix_N : block_cols;

#pragma unroll
for (index_t frag = 0; frag < frags_per_sg; frag++,
C += output_global_outer_offset,
nc -= tile_type::joint_matrix_N) {
Expand All @@ -559,6 +560,7 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, TileType,
if constexpr (check_m_limit && check_n_limit) {
if (mc >= block_rows && nc >= nc_conditional) {
const index_t loop_limit = nc_conditional / rows_per_iter;
#pragma unroll
for (int i = 0; i < loop_limit; i++,
new_C += output_global_inner_offset,
new_scratch += output_local_inner_offset) {
Expand All @@ -574,6 +576,7 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, TileType,
if (mc < block_rows && nc < nc_conditional) {
if (item_id < mc) {
const index_t loop_limit = nc;
#pragma unroll
for (int i = 0; i < loop_limit; i++,
new_C += output_global_inner_offset,
new_scratch += output_local_inner_offset) {
Expand All @@ -590,6 +593,7 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, TileType,
if (mc < block_rows) {
if (it_mod_brows < mc) {
const index_t loop_limit = nc_conditional / rows_per_iter;
#pragma unroll
for (int i = 0; i < loop_limit; i++,
new_C += output_global_inner_offset,
new_scratch += output_local_inner_offset) {
Expand All @@ -606,6 +610,7 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, TileType,
if (nc < nc_conditional) {
if (item_id < block_rows) {
const index_t loop_limit = nc;
#pragma unroll
for (int i = 0; i < loop_limit; i++,
new_C += output_global_inner_offset,
new_scratch += output_local_inner_offset) {
Expand All @@ -621,6 +626,7 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, TileType,
}
} else {
const index_t loop_limit = nc_conditional / rows_per_iter;
#pragma unroll
for (int i = 0; i < loop_limit; i++,
new_C += output_global_inner_offset,
new_scratch += output_local_inner_offset) {
Expand Down

0 comments on commit 774286d

Please sign in to comment.