Add fixes and new NVIDIA gemm configurations #469

Merged: 1 commit, Oct 16, 2023
24 changes: 17 additions & 7 deletions cmake/CmakeFunctionHelper.cmake
@@ -469,18 +469,28 @@ elseif(${TUNING_TARGET} STREQUAL "NVIDIA_GPU")
add_gemm_configuration(
"float" 256 "false" "true" "true"
128 8 8 16 16 16 2 1 1 1 1 16 16 16 cl::sycl::half float "local" "standard" "none" 1 "strided" "true")
endif()
# Non-Joint Matrix specific GEMM Configurations
add_gemm_configuration(
"${data}" 128 "false" "false" "true"
128 2 2 8 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false")
add_gemm_configuration(
add_gemm_configuration(
"${data}" 64 "false" "false" "true"
64 8 8 8 8 1 1 2 2 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false")
endif()

# Non-Joint Matrix specific GEMM Configurations
add_gemm_configuration(
"${data}" 64 "false" "false" "false"
64 2 2 4 4 1 1 1 1 4 4 1 1 1 float float "no_local" "standard" "full" 4 "interleaved" "false")
endforeach()
add_gemm_configuration(
"${data}" 128 "false" "true" "true"
128 2 2 16 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false")
add_gemm_configuration(
"${data}" 128 "false" "true" "true"
128 4 4 16 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false")
add_gemm_configuration(
"${data}" 128 "false" "true" "true"
128 8 8 16 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false")
add_gemm_configuration(
"${data}" 256 "false" "true" "true"
128 8 8 16 16 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false")
endforeach()
else() # default cpu backend
set(supported_types
"float"
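The add_gemm_configuration arguments are positional and easy to misread. As a reading aid only, the sketch below tabulates the four new strided NVIDIA entries; the field names are assumptions inferred by lining the numbers up against the Gemm_Launcher and Tile template arguments in nvidia_gpu.hpp, not the macro's documented parameter names.

#include <cstdio>

// Illustrative field names only (assumed from the Gemm_Launcher/Tile
// template arguments in src/interface/blas3/backend/nvidia_gpu.hpp).
struct GemmConfig {
  int wg_size;    // work-group size (2nd positional argument)
  int cl_size;    // ClSize-style parameter (6th positional argument)
  int item_rows;  // Tile<item_rows, item_cols, wg_rows, wg_cols, ...>
  int item_cols;
  int wg_rows;
  int wg_cols;
};

// The four "local"/"standard"/"full"/"strided" entries added for NVIDIA_GPU.
constexpr GemmConfig kNewNvidiaConfigs[] = {
    {128, 128, 2, 2, 16, 8},
    {128, 128, 4, 4, 16, 8},
    {128, 128, 8, 8, 16, 8},
    {256, 128, 8, 8, 16, 16},
};

int main() {
  for (const auto &c : kNewNvidiaConfigs) {
    std::printf("wg=%d cl=%d tile=%d,%d,%d,%d\n", c.wg_size, c.cl_size,
                c.item_rows, c.item_cols, c.wg_rows, c.wg_cols);
  }
}

Read this way, each CMake entry lines up with one of the launcher instantiations shown in the next file.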
58 changes: 46 additions & 12 deletions src/interface/blas3/backend/nvidia_gpu.hpp
@@ -100,26 +100,50 @@ typename sb_handle_t::event_t _gemm(
_ldc, _stridec, batch_size,
_dependencies);
}
} else {
}
#endif // SB_ENABLE_JOINT_MATRIX

if (batch_size > 1) {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 64, false, false, true, 64,
Tile<8, 8, 8, 8, 1, 1, 2, 2, 1, 1, 1, 1, 1, float, float>, _t_a, _t_b,
s_a, s_b, static_cast<int>(gemm_memory_t::local),
container_0_t, container_1_t, container_2_t, 256, false, true, true,
128, Tile<8, 8, 16, 16, 1, 1, 1, 1, 1, 1, 1, 1, 1, float, float>, _t_a,
_t_b, s_a, s_b, static_cast<int>(gemm_memory_t::local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 1,
static_cast<int>(gemm_batch_type_t::strided),
false>::template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
_stridea, _b, _ldb, _strideb, _beta, _c,
_ldc, _stridec, batch_size,
_dependencies);
}

#else // SB_ENABLE_JOINT_MATRIX
else {
} else if (_M <= 256 && _N <= 256) {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 128, false, true, true,
128, Tile<2, 2, 16, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1, float, float>, _t_a,
_t_b, s_a, s_b, static_cast<int>(gemm_memory_t::local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 1,
static_cast<int>(gemm_batch_type_t::strided),
false>::template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
_stridea, _b, _ldb, _strideb, _beta, _c,
_ldc, _stridec, batch_size,
_dependencies);
} else if (_M <= 1024 && _N <= 1024) {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 64, false, false, true, 64,
Tile<8, 8, 8, 8, 1, 1, 2, 2, 1, 1, 1, 1, 1, float, float>, _t_a, _t_b,
s_a, s_b, static_cast<int>(gemm_memory_t::local),
container_0_t, container_1_t, container_2_t, 128, false, true, true,
128, Tile<4, 4, 16, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1, float, float>, _t_a,
_t_b, s_a, s_b, static_cast<int>(gemm_memory_t::local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 1,
static_cast<int>(gemm_batch_type_t::strided),
false>::template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
_stridea, _b, _ldb, _strideb, _beta, _c,
_ldc, _stridec, batch_size,
_dependencies);
} else if (_M <= 2048 && _N <= 2048) {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 128, false, true, true,
128, Tile<8, 8, 16, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1, float, float>, _t_a,
_t_b, s_a, s_b, static_cast<int>(gemm_memory_t::local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 1,
static_cast<int>(gemm_batch_type_t::strided),
@@ -128,7 +152,17 @@ typename sb_handle_t::event_t _gemm(
_ldc, _stridec, batch_size,
_dependencies);
}
#endif

return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 256, false, true, true, 128,
Tile<8, 8, 16, 16, 1, 1, 1, 1, 1, 1, 1, 1, 1, float, float>, _t_a, _t_b,
s_a, s_b, static_cast<int>(gemm_memory_t::local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 1,
static_cast<int>(gemm_batch_type_t::strided),
false>::template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
_stridea, _b, _ldb, _strideb, _beta, _c,
_ldc, _stridec, batch_size, _dependencies);
}
} // namespace backend
} // namespace gemm
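Stripped of the template boilerplate, the new non-joint-matrix selection above is a size-bucketed dispatch. A minimal standalone sketch of that decision tree (the function and TileChoice names are illustrative, not the portBLAS Gemm_Launcher API):

#include <cstdio>

// Mirrors the if/else chain in src/interface/blas3/backend/nvidia_gpu.hpp
// for the non-joint-matrix path; names are illustrative only.
struct TileChoice {
  int wg_size;         // launcher work-group size
  int ir, ic, wr, wc;  // leading Tile<ir, ic, wr, wc, ...> parameters
};

TileChoice select_nvidia_gemm(long m, long n, long batch_size) {
  if (batch_size > 1) return {256, 8, 8, 16, 16};      // batched strided GEMM
  if (m <= 256 && n <= 256) return {128, 2, 2, 16, 8};
  if (m <= 1024 && n <= 1024) return {128, 4, 4, 16, 8};
  if (m <= 2048 && n <= 2048) return {128, 8, 8, 16, 8};
  return {256, 8, 8, 16, 16};                          // large-size fallback
}

int main() {
  const long sizes[] = {253, 511, 1024, 2048, 2200};  // the new test m values
  for (long s : sizes) {
    const TileChoice t = select_nvidia_gemm(s, s, /*batch_size=*/1);
    std::printf("m=n=%ld -> wg=%d Tile<%d,%d,%d,%d>\n", s, t.wg_size, t.ir,
                t.ic, t.wr, t.wc);
  }
}

The joint-matrix path guarded by SB_ENABLE_JOINT_MATRIX is tried before any of this and is not modeled by the sketch.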
19 changes: 7 additions & 12 deletions src/operations/blas3/gemm_local.hpp
@@ -136,10 +136,6 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, TileType,
"of the number of columns in a block\n"
" --- this is ensured iff: item_cols | wg_rows");

static_assert(big_tile_rows == big_tile_cols,
"Big tile level dimensions should be square, i.e. tl_rows * "
"block_rows == tl_cols * block_cols");

static_assert(item_rows % packetize_t::packet_size == 0,
"Item rows must be a multiple of the vector packet size");

@@ -210,7 +206,9 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, TileType,
*/
PORTBLAS_INLINE index_t
get_num_workgroup_cluster(index_t compute_units) const noexcept {
return ((4 * compute_units - 1) / get_workgroup_cluster() + 1);
return (batch_size_ > 1)
? ((4 * compute_units - 1) / get_workgroup_cluster() + 1)
: 1;
}

/*!
@@ -289,7 +287,7 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, TileType,
const index_t tile_row = (tile_id % tiles_per_col) * tl_rows;
const index_t tile_col = (tile_id / tiles_per_col) * tl_cols;
const index_t wg_row = (tile_row + tile_local_id % tl_rows) * block_rows;
const index_t wg_col = (tile_col + tile_local_id / tl_rows) * block_rows;
const index_t wg_col = (tile_col + tile_local_id / tl_rows) * block_cols;
const bool out_of_range = (wg_row >= m || wg_col >= n);
const bool internal = m - wg_row >= block_rows && n - wg_col >= block_cols;
const index_t vector_offset = internal ? packetize_t::packet_size : 1;
@@ -373,15 +371,12 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, TileType,
}
constexpr index_t offset =
(!check_m_limit && !check_n_limit) ? packetize_t::packet_size : 1;
#pragma unroll
for (index_t i = 0; i < item_cols; ++i) {
#pragma unroll
for (index_t j = 0; j < item_rows / offset; ++j) {
const bool in_range =
do_check<check_m_limit>(j * wg_rows * offset < mc) &&
do_check<check_n_limit>(i < nc);
if (in_range) {
#pragma unroll
for (index_t l = 0; l < offset; ++l) {
reg_res[i * item_rows + j * offset + l] =
beta_ * *(C + j * (wg_rows * offset) + l);
@@ -397,7 +392,6 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, TileType,
PORTBLAS_INLINE typename std::enable_if<beta_zero>::type scaling_c(
element_t *reg_res, InputPointerType, const index_t &, const index_t &,
const index_t &, const bool) {
#pragma unroll
for (index_t i = 0; i < item_cols * item_rows; ++i) {
reg_res[i] = 0;
}
@@ -560,9 +554,7 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, TileType,
}
constexpr index_t offset =
(!check_m_limit && !check_n_limit) ? packetize_t::packet_size : 1;
#pragma unroll
for (index_t i = 0; i < item_cols; ++i) {
#pragma unroll
for (index_t j = 0; j < item_rows / offset; j++) {
const bool in_range =
do_check<check_m_limit>(j * wg_rows * offset < mc) &&
@@ -744,6 +736,9 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, TileType,
// resulting from loop unrollment.
constexpr index_t work_per_load =
!check_m_limit && !check_n_limit ? packetize_t::packet_size : 1;
#if defined NVIDIA_GPU
#pragma unroll
#endif
for (index_t i = 0; i < cl_elems; ++i) {
#pragma unroll
for (index_t j = 0; j < item_rows / work_per_load; ++j) {
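Two of the fixes above appear to go together: the square big-tile static_assert is dropped and wg_col is now scaled by block_cols instead of block_rows, which only matters once blocks are non-square. A small self-contained sketch of the coordinate computation (the block and tile sizes below are assumed for illustration, e.g. a 16x8 work-group doing 8x8 items per work-item):

#include <cstdio>

int main() {
  // Assumed sizes for illustration: block_rows = wg_rows * item_rows and
  // block_cols = wg_cols * item_cols, e.g. for Tile<8, 8, 16, 8, ...>.
  const int tl_rows = 1, tl_cols = 1;  // big-tile level dimensions
  const int block_rows = 128;          // 16 * 8
  const int block_cols = 64;           // 8 * 8
  const int tiles_per_col = 4;         // assumed partitioning of the problem

  for (int tile_id = 0; tile_id < 8; ++tile_id) {
    const int tile_local_id = 0;  // first work-group within the tile
    const int tile_row = (tile_id % tiles_per_col) * tl_rows;
    const int tile_col = (tile_id / tiles_per_col) * tl_cols;
    const int wg_row = (tile_row + tile_local_id % tl_rows) * block_rows;
    // Fixed version (block_cols) vs. the previous code (block_rows):
    const int wg_col_new = (tile_col + tile_local_id / tl_rows) * block_cols;
    const int wg_col_old = (tile_col + tile_local_id / tl_rows) * block_rows;
    std::printf("tile %d: wg_row=%d wg_col=%d (was %d)\n", tile_id, wg_row,
                wg_col_new, wg_col_old);
  }
}

With square blocks the two expressions coincide, so the issue would only surface with rectangular tiles such as Tile<2, 2, 16, 8>.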
2 changes: 1 addition & 1 deletion test/unittest/blas3/blas3_gemm_test.cpp
@@ -126,7 +126,7 @@ const auto LargeBetaNonZeroLDMatch = ::testing::Combine(
::testing::Values("usm", "buf"), // allocation type
::testing::Values(0), // offset
::testing::Values(1), // batch
::testing::Values(253, 511), // m
::testing::Values(253, 511, 1024, 2048, 2200), // m
::testing::Values(257, 511), // n
::testing::Values(253, 511), // k
::testing::Values('n', 't'), // transa
3 changes: 1 addition & 2 deletions tools/auto_tuner/gen/generate_combinations.py
@@ -114,8 +114,7 @@ def is_valid(self):
return (self.tile.group_rows % self.tile.item_cols == 0
and self.tile.group_cols % self.tile.item_rows == 0
and self.tile.group_rows * self.tile.group_cols %
(self.cache_size / 4) == 0 and
self.tile.group_rows * self.tile.item_rows == self.tile.group_cols * self.tile.item_cols)
(self.cache_size / 4) == 0)


class NonLocalGemmStrided(GemmParams):