Skip to content

Commit

Permalink
Fixed non-local transpose-add kernel for edge cases
Browse files Browse the repository at this point in the history
  • Loading branch information
OuadiElfarouki authored and s-Nick committed Jun 14, 2023
1 parent 9841e4a commit c47c3d6
Showing 1 changed file with 14 additions and 10 deletions.
24 changes: 14 additions & 10 deletions src/operations/extension/transpose.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -214,31 +214,35 @@ TransposeAdd<both_trans, Tile_size, local_memory, in1_t, in2_t, out_t,

if constexpr (both_trans) {
// Compute sum & then transpose
auto j = idx / N_;
auto i = idx - j * N_;
auto j = idx / N_pad_;
auto i = idx - j * N_pad_;

auto in_index_a = i + j * lda_;
auto in_index_b = i + j * ldb_;

auto temp_sum = alpha_ * A[in_index_a] + beta_ * B[in_index_b];
if (i < N_ && j < M_) {
auto temp_sum = alpha_ * A[in_index_a] + beta_ * B[in_index_b];

auto out_index_c = i * ldc_ + j;
auto out_index_c = i * ldc_ + j;

C[out_index_c] = temp_sum;
C[out_index_c] = temp_sum;
}

} else {
// Transpose A then compute sum (Applies to B as well)
auto j = idx / M_;
auto i = idx - j * M_;
auto j = idx / M_pad_;
auto i = idx - j * M_pad_;

auto in_index_at = j + i * lda_;
auto in_index_b = i + j * ldb_;

auto temp_sum = alpha_ * A[in_index_at] + beta_ * B[in_index_b];
if (i < M_ && j < N_) {
auto temp_sum = alpha_ * A[in_index_at] + beta_ * B[in_index_b];

auto out_index_c = i + j * ldc_;
auto out_index_c = i + j * ldc_;

C[out_index_c] = temp_sum;
C[out_index_c] = temp_sum;
}
}
}
}
Expand Down

0 comments on commit c47c3d6

Please sign in to comment.