Commit 5b5d48f
[XLA:GPU] Skip small tile sizes for sparse gemms on Ampere as well. Re-enable the JAX test that has been failing.

PiperOrigin-RevId: 695360850
chsigg authored and Google-ML-Automation committed Nov 11, 2024
1 parent b70917b commit 5b5d48f
Showing 1 changed file with 2 additions and 6 deletions.
xla/service/gpu/autotuning/gemm_fusion_autotuner.cc (+2, -6)

@@ -887,8 +887,6 @@ GemmFusionAutotunerImpl::GenerateTritonConfigs(const HloDotInstruction& dot) {
 
   // Triton configurations are adjusted and deduplicated.
   absl::flat_hash_set<TritonGemmConfig> added;
-  bool is_hopper =
-      !config_.IsDeviceless() && GetComputeCapability().IsAtLeastHopper();
   for (TritonGemmConfig& config : triton_configs) {
     config.block_m = std::min(config.block_m, limits.block_m);
     config.block_n = std::min(config.block_n, limits.block_n);
@@ -911,10 +909,8 @@ GemmFusionAutotunerImpl::GenerateTritonConfigs(const HloDotInstruction& dot) {
     // Sparse meta should have at least one element per thread.
     // Note: only 2:4 structured sparsity is currently supported.
     if (dot.sparse_operands()) {
-      if (is_hopper) {
-        config.block_m = std::max(config.block_m, 64);
-        config.num_warps = std::max(config.num_warps, 4);
-      }
+      config.block_m = std::max(config.block_m, 64);
+      config.num_warps = std::max(config.num_warps, 4);
      config.block_k = std::max(
          config.block_k,
          2 * std::max(kMinTileSize, kLdmatrixGranularity / minBitWidth));
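
The standalone sketch below summarizes the post-change behavior: the minimum tile sizes and warp count for sparse (2:4) dots are now enforced unconditionally, on Ampere as well as Hopper, instead of being gated on is_hopper. The struct and the constant values (kMinTileSize, kLdmatrixGranularity) are simplified assumptions for illustration, not the real XLA definitions; only the field names and the clamping arithmetic mirror the diff above.

#include <algorithm>
#include <cstdio>

namespace {

// Hypothetical stand-in for XLA's TritonGemmConfig; only the fields touched
// by the diff are modeled here.
struct TritonGemmConfig {
  int block_m;
  int block_n;
  int block_k;
  int num_warps;
};

constexpr int kMinTileSize = 16;           // assumed value
constexpr int kLdmatrixGranularity = 256;  // assumed value, in bits

// After this commit the sparse-dot minimums apply on every supported GPU
// generation rather than only when the device is Hopper.
void AdjustSparseDotConfig(TritonGemmConfig& config, int min_bit_width) {
  config.block_m = std::max(config.block_m, 64);
  config.num_warps = std::max(config.num_warps, 4);
  config.block_k = std::max(
      config.block_k,
      2 * std::max(kMinTileSize, kLdmatrixGranularity / min_bit_width));
}

}  // namespace

int main() {
  TritonGemmConfig config{/*block_m=*/16, /*block_n=*/64, /*block_k=*/16,
                          /*num_warps=*/2};
  AdjustSparseDotConfig(config, /*min_bit_width=*/16);
  // With the assumed constants this prints: block_m=64 block_k=32 num_warps=4
  std::printf("block_m=%d block_k=%d num_warps=%d\n", config.block_m,
              config.block_k, config.num_warps);
  return 0;
}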
