Skip to content

Commit

Permalink
Fix bug and add new NVIDIA configurations
Browse files (browse the repository at this point in the history)
  • Loading branch information
pgorlani committed Oct 16, 2023
1 parent dd587dd commit 5e655c9
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 34 deletions.
24 changes: 17 additions & 7 deletions cmake/CmakeFunctionHelper.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -469,18 +469,28 @@ elseif(${TUNING_TARGET} STREQUAL "NVIDIA_GPU")
add_gemm_configuration(
"float" 256 "false" "true" "true"
128 8 8 16 16 16 2 1 1 1 1 16 16 16 cl::sycl::half float "local" "standard" "none" 1 "strided" "true")
endif()
# Non-Joint Matrix specific GEMM Configurations
add_gemm_configuration(
"${data}" 128 "false" "false" "true"
128 2 2 8 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false")
add_gemm_configuration(
add_gemm_configuration(
"${data}" 64 "false" "false" "true"
64 8 8 8 8 1 1 2 2 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false")
endif()

# Non-Joint Matrix specific GEMM Configurations
add_gemm_configuration(
"${data}" 64 "false" "false" "false"
64 2 2 4 4 1 1 1 1 4 4 1 1 1 float float "no_local" "standard" "full" 4 "interleaved" "false")
endforeach()
add_gemm_configuration(
"${data}" 128 "false" "true" "true"
128 2 2 16 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false")
add_gemm_configuration(
"${data}" 128 "false" "true" "true"
128 4 4 16 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false")
add_gemm_configuration(
"${data}" 128 "false" "true" "true"
128 8 8 16 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false")
add_gemm_configuration(
"${data}" 256 "false" "true" "true"
128 8 8 16 16 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false")
endforeach()
else() # default cpu backend
set(supported_types
"float"
Expand Down
58 changes: 46 additions & 12 deletions src/interface/blas3/backend/nvidia_gpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,26 +100,50 @@ typename sb_handle_t::event_t _gemm(
_ldc, _stridec, batch_size,
_dependencies);
}
} else {
}
#endif // SB_ENABLE_JOINT_MATRIX

if (batch_size > 1) {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 64, false, false, true, 64,
Tile<8, 8, 8, 8, 1, 1, 2, 2, 1, 1, 1, 1, 1, float, float>, _t_a, _t_b,
s_a, s_b, static_cast<int>(gemm_memory_t::local),
container_0_t, container_1_t, container_2_t, 256, false, true, true,
128, Tile<8, 8, 16, 16, 1, 1, 1, 1, 1, 1, 1, 1, 1, float, float>, _t_a,
_t_b, s_a, s_b, static_cast<int>(gemm_memory_t::local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 1,
static_cast<int>(gemm_batch_type_t::strided),
false>::template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
_stridea, _b, _ldb, _strideb, _beta, _c,
_ldc, _stridec, batch_size,
_dependencies);
}

#else // SB_ENABLE_JOINT_MATRIX
else {
} else if (_M <= 256 && _N <= 256) {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 128, false, true, true,
128, Tile<2, 2, 16, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1, float, float>, _t_a,
_t_b, s_a, s_b, static_cast<int>(gemm_memory_t::local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 1,
static_cast<int>(gemm_batch_type_t::strided),
false>::template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
_stridea, _b, _ldb, _strideb, _beta, _c,
_ldc, _stridec, batch_size,
_dependencies);
} else if (_M <= 1024 && _N <= 1024) {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 64, false, false, true, 64,
Tile<8, 8, 8, 8, 1, 1, 2, 2, 1, 1, 1, 1, 1, float, float>, _t_a, _t_b,
s_a, s_b, static_cast<int>(gemm_memory_t::local),
container_0_t, container_1_t, container_2_t, 128, false, true, true,
128, Tile<4, 4, 16, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1, float, float>, _t_a,
_t_b, s_a, s_b, static_cast<int>(gemm_memory_t::local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 1,
static_cast<int>(gemm_batch_type_t::strided),
false>::template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
_stridea, _b, _ldb, _strideb, _beta, _c,
_ldc, _stridec, batch_size,
_dependencies);
} else if (_M <= 2048 && _N <= 2048) {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 128, false, true, true,
128, Tile<8, 8, 16, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1, float, float>, _t_a,
_t_b, s_a, s_b, static_cast<int>(gemm_memory_t::local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 1,
static_cast<int>(gemm_batch_type_t::strided),
Expand All @@ -128,7 +152,17 @@ typename sb_handle_t::event_t _gemm(
_ldc, _stridec, batch_size,
_dependencies);
}
#endif

return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 256, false, true, true, 128,
Tile<8, 8, 16, 16, 1, 1, 1, 1, 1, 1, 1, 1, 1, float, float>, _t_a, _t_b,
s_a, s_b, static_cast<int>(gemm_memory_t::local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 1,
static_cast<int>(gemm_batch_type_t::strided),
false>::template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
_stridea, _b, _ldb, _strideb, _beta, _c,
_ldc, _stridec, batch_size, _dependencies);
}
} // namespace backend
} // namespace gemm
Expand Down
19 changes: 7 additions & 12 deletions src/operations/blas3/gemm_local.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,6 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, TileType,
"of the number of columns in a block\n"
" --- this is ensured iff: item_cols | wg_rows");

static_assert(big_tile_rows == big_tile_cols,
"Big tile level dimensions should be square, i.e. tl_rows * "
"block_rows == tl_cols * block_cols");

static_assert(item_rows % packetize_t::packet_size == 0,
"Item rows must be a multiple of the vector packet size");

Expand Down Expand Up @@ -210,7 +206,9 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, TileType,
*/
PORTBLAS_INLINE index_t
get_num_workgroup_cluster(index_t compute_units) const noexcept {
return ((4 * compute_units - 1) / get_workgroup_cluster() + 1);
return (batch_size_ > 1)
? ((4 * compute_units - 1) / get_workgroup_cluster() + 1)
: 1;
}

/*!
Expand Down Expand Up @@ -289,7 +287,7 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, TileType,
const index_t tile_row = (tile_id % tiles_per_col) * tl_rows;
const index_t tile_col = (tile_id / tiles_per_col) * tl_cols;
const index_t wg_row = (tile_row + tile_local_id % tl_rows) * block_rows;
const index_t wg_col = (tile_col + tile_local_id / tl_rows) * block_rows;
const index_t wg_col = (tile_col + tile_local_id / tl_rows) * block_cols;
const bool out_of_range = (wg_row >= m || wg_col >= n);
const bool internal = m - wg_row >= block_rows && n - wg_col >= block_cols;
const index_t vector_offset = internal ? packetize_t::packet_size : 1;
Expand Down Expand Up @@ -373,15 +371,12 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, TileType,
}
constexpr index_t offset =
(!check_m_limit && !check_n_limit) ? packetize_t::packet_size : 1;
#pragma unroll
for (index_t i = 0; i < item_cols; ++i) {
#pragma unroll
for (index_t j = 0; j < item_rows / offset; ++j) {
const bool in_range =
do_check<check_m_limit>(j * wg_rows * offset < mc) &&
do_check<check_n_limit>(i < nc);
if (in_range) {
#pragma unroll
for (index_t l = 0; l < offset; ++l) {
reg_res[i * item_rows + j * offset + l] =
beta_ * *(C + j * (wg_rows * offset) + l);
Expand All @@ -397,7 +392,6 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, TileType,
PORTBLAS_INLINE typename std::enable_if<beta_zero>::type scaling_c(
element_t *reg_res, InputPointerType, const index_t &, const index_t &,
const index_t &, const bool) {
#pragma unroll
for (index_t i = 0; i < item_cols * item_rows; ++i) {
reg_res[i] = 0;
}
Expand Down Expand Up @@ -560,9 +554,7 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, TileType,
}
constexpr index_t offset =
(!check_m_limit && !check_n_limit) ? packetize_t::packet_size : 1;
#pragma unroll
for (index_t i = 0; i < item_cols; ++i) {
#pragma unroll
for (index_t j = 0; j < item_rows / offset; j++) {
const bool in_range =
do_check<check_m_limit>(j * wg_rows * offset < mc) &&
Expand Down Expand Up @@ -744,6 +736,9 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, TileType,
// resulting from loop unrolling.
constexpr index_t work_per_load =
!check_m_limit && !check_n_limit ? packetize_t::packet_size : 1;
#if defined NVIDIA_GPU
#pragma unroll
#endif
for (index_t i = 0; i < cl_elems; ++i) {
#pragma unroll
for (index_t j = 0; j < item_rows / work_per_load; ++j) {
Expand Down
2 changes: 1 addition & 1 deletion test/unittest/blas3/blas3_gemm_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ const auto LargeBetaNonZeroLDMatch = ::testing::Combine(
::testing::Values("usm", "buf"), // allocation type
::testing::Values(0), // offset
::testing::Values(1), // batch
::testing::Values(253, 511), // m
::testing::Values(253, 511, 1024, 2048, 2200), // m
::testing::Values(257, 511), // n
::testing::Values(253, 511), // k
::testing::Values('n', 't'), // transa
Expand Down
3 changes: 1 addition & 2 deletions tools/auto_tuner/gen/generate_combinations.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,7 @@ def is_valid(self):
return (self.tile.group_rows % self.tile.item_cols == 0
and self.tile.group_cols % self.tile.item_rows == 0
and self.tile.group_rows * self.tile.group_cols %
(self.cache_size / 4) == 0 and
self.tile.group_rows * self.tile.item_rows == self.tile.group_cols * self.tile.item_cols)
(self.cache_size / 4) == 0)


class NonLocalGemmStrided(GemmParams):
Expand Down

0 comments on commit 5e655c9

Please sign in to comment.