Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Triton] Refactoring condition in autotuner to be more robust. Added test to make sure crashing Triton configurations are actually skipped and to guard against breaking it. #13895

Merged
merged 1 commit into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 23 additions & 22 deletions xla/service/gpu/gemm_fusion_autotuner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -312,11 +312,6 @@ absl::StatusOr<TileSizeLimit> GetLimits(const HloDotInstruction& dot) {
const int max_k = tsl::NextPowerOfTwoS64(
dot.operand(1)->shape().dimensions(contracting_index));

// TODO(b/337839570): block_k = 16 is bugged in Triton for dots with 8-bit
// input. Setting minimum to 32 instead of 16 for these cases.
// TODO(b/337838200): Write the restriction on the minimum tile size to be
// generic. Currently we only handle the 8-bit case as this was the bug we
// ran into.
return TileSizeLimit{
/*block_m=*/std::max(max_m, kMinTileSize),
/*block_n=*/std::max(max_n, kMinTileSize),
Expand Down Expand Up @@ -634,13 +629,20 @@ absl::StatusOr<std::vector<Config>> GemmFusionAutotunerImpl::GenerateConfigs(

absl::StatusOr<std::vector<TritonGemmConfig>>
GemmFusionAutotunerImpl::GenerateTritonConfigs(const HloDotInstruction& dot) {
bool has_8_bit_operand = HloAnyOf({&dot}, [&](const HloInstruction* node) {
if (node->opcode() != HloOpcode::kConvert) {
return false;
}
auto in_type = node->operand(0)->shape().element_type();
return primitive_util::BitWidth(in_type) == 8;
});
// Retrieve the minimum bit-width participating in the dot. This is needed
// to avoid autotuning configurations that are not supported by Triton. This
// is used to restrict the values for tile_k.
std::vector<const HloInstruction*> converts =
HloFindAll({&dot}, [&](const HloInstruction* node) {
return node->opcode() == HloOpcode::kConvert;
});
int minBitWidth = primitive_util::BitWidth(dot.shape().element_type());
for (auto convert : converts) {
auto in_type = convert->operand(0)->shape().element_type();
auto out_type = convert->shape().element_type();
minBitWidth = std::min({minBitWidth, primitive_util::BitWidth(in_type),
primitive_util::BitWidth(out_type)});
}

std::vector<TritonGemmConfig> result_configs;
TF_ASSIGN_OR_RETURN(TileSizeLimit limits, GetLimits(dot));
Expand Down Expand Up @@ -690,14 +692,12 @@ GemmFusionAutotunerImpl::GenerateTritonConfigs(const HloDotInstruction& dot) {
}
config.split_k = std::min(config.split_k, max_split_k);

// TODO(b/337839570): block_k = 16 is bugged in Triton for dots with 8-bit
// input. Setting minimum to 32 instead of 16 for these cases.
// TODO(b/337838200): Write the restriction on the minimum tile size to be
// generic. Currently we only handle the 8-bit case as this was the bug we
// ran into.
if (has_8_bit_operand && config.block_k == kMinTileSize) {
config.block_k *= 2;
}
// TODO(b/337839570): Triton currently has a limitation where it crashes
// on small block_k values depending on the bit-width of the inputs to the
// dot. The logic below accounts for this limitation.
constexpr int kLdmatrixGranularity = 256;
config.block_k =
std::max(config.block_k, kLdmatrixGranularity / minBitWidth);

// Sparse meta should have at least one element per thread.
// Note: only 2:4 structured sparsity is currently supported.
Expand All @@ -706,8 +706,9 @@ GemmFusionAutotunerImpl::GenerateTritonConfigs(const HloDotInstruction& dot) {
config.block_m = std::max(config.block_m, 64);
config.num_warps = std::max(config.num_warps, 4);
}
config.block_k =
std::max(config.block_k, kMinTileSize * (has_8_bit_operand ? 4 : 2));
config.block_k = std::max(
config.block_k,
2 * std::max(kMinTileSize, kLdmatrixGranularity / minBitWidth));
int meta_elements = config.block_m * config.block_k / 16;
config.num_warps =
std::min<int>(config.num_warps, meta_elements / WarpSize());
Expand Down
28 changes: 28 additions & 0 deletions xla/service/gpu/gemm_fusion_autotuner_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -845,6 +845,34 @@ ENTRY e {
)");
}

// TODO(b/337839570): Triton currently has a limitation where it crashes
// on small block_k values depending on the bit-width of the inputs to the
// dot. For this fusion the narrowest operand type is s8 (8 bits), so every
// generated config must use block_k strictly greater than 16; the autotuner
// is expected to have filtered out the crashing tile sizes.
TEST_F(GemmFusionAutotunerExhaustiveTest, SkipsCrashingTileKConfig) {
  // A GEMM whose LHS is converted from an 8-bit type, triggering the
  // minimum-block_k restriction in the config generator.
  constexpr char kHloText[] = R"(
HloModule module
ENTRY e {
x = s8[33,33]{1,0} parameter(0)
c = f16[33,33]{1,0} convert(x)
y = f16[33,33]{1,0} parameter(1)
ROOT out = f16[33,33]{1,0} dot(c, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";
  std::unique_ptr<VerifiedHloModule> hlo_module =
      ParseAndReturnVerifiedModule(kHloText).value();
  const se::CudaComputeCapability ampere{se::CudaComputeCapability::AMPERE,
                                         /*minor=*/0};
  TF_ASSERT_OK_AND_ASSIGN(
      const std::vector<TritonGemmConfig> configs,
      GetPossibleMatmulAutotuneConfigs(
          *Cast<HloDotInstruction>(
              hlo_module->entry_computation()->root_instruction()),
          ampere, GetToolkitVersion(), GetDebugOptionsForTest()));
  // With an 8-bit operand, block_k must be at least 256 / 8 = 32.
  const auto block_k_is_safe = [](const TritonGemmConfig& config) {
    return config.block_k > 16;
  };
  EXPECT_TRUE(std::all_of(configs.begin(), configs.end(), block_k_is_safe));
}

class GemmFusionAutotunerDisableSplitK : public GemmFusionAutotunerTest {
public:
DebugOptions GetDebugOptionsForTest() override {
Expand Down
Loading