PR #16841: Delete FP8 Scaling Factors in GEMM Rewriter
Imported from GitHub PR openxla/xla#16841

Removes the scaling factors of C and D (matrix bias and result) from FP8 Custom Calls created in the GEMM rewriter when their data types are not FP8. See openxla/xla#15795.
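As a rough orientation for the diffs below, here is a minimal sketch of the operand count this implies, assuming the conventional operand ordering {A, B, optional matrix bias C, a_scale, b_scale, optional vector bias, optional d_scale}; the helper and its names are illustrative only and not part of the change.

// Illustrative helper (not XLA code): operand count of the rewritten FP8
// custom call, assuming the layout
//   {A, B, [C], a_scale, b_scale, [vector_bias], [d_scale if D is FP8]}.
int ExpectedFp8GemmOperandCount(bool has_matrix_bias, bool has_vector_bias,
                                bool d_is_fp8) {
  int count = 4;                 // A, B, a_scale, b_scale are always present.
  if (has_matrix_bias) ++count;  // C, when beta != 0.
  if (has_vector_bias) ++count;  // Per-column bias vector.
  if (d_is_fp8) ++count;         // d_scale is kept only for FP8 outputs.
  return count;                  // Always in [4, 7].
}

Under these assumptions the count ranges from 4 to 7, which matches the relaxed operand-count check in ir_emitter_unnested.cc below.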
Copybara import of the project:

--
fd9750fa8474fe72fe641c7b3bc005ff30396e0a by Philipp Hack <[email protected]>:

Removes superfluous FP8 scaling factors in GEMM rewriter.

Merging this change closes #16841

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#16841 from philipphack:u_fp8_scales_xla fd9750fa8474fe72fe641c7b3bc005ff30396e0a
PiperOrigin-RevId: 679766659
philipphack authored and tensorflower-gardener committed Sep 27, 2024
1 parent 4e9171c commit b026e99
Showing 10 changed files with 106 additions and 108 deletions.
2 changes: 2 additions & 0 deletions .bazelrc
@@ -242,6 +242,8 @@ build:cuda_clang --copt=-Qunused-arguments
# major release. Example: sm_80 kernels can run on sm_89 GPUs but
# not on sm_90 GPUs. compute_80 kernels though can also run on sm_90 GPUs.
build:cuda_clang --repo_env=HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_60,sm_70,sm_80,sm_89,compute_90"
# Permit newer CUDA versions than Clang is aware of
build:cuda_clang --copt="-Wno-unknown-cuda-version"
# Set lld as the linker.
build:cuda_clang --host_linkopt="-fuse-ld=lld"
build:cuda_clang --host_linkopt="-lm"
@@ -219,6 +219,10 @@ def _create_libcuda_symlinks(
repository_ctx.symlink(nvidia_driver_path, "lib/libcuda.so.1")
repository_ctx.symlink("lib/libcuda.so.1", "lib/libcuda.so")

def _create_cuda_header_symlinks(repository_ctx):
if repository_ctx.name == "cuda_nvcc":
repository_ctx.symlink("../cuda_cudart/include/cuda.h", "include/cuda.h")

def use_local_path(repository_ctx, local_path, dirs):
# buildifier: disable=function-docstring-args
"""Creates repository using local redistribution paths."""
@@ -339,6 +343,7 @@ def _use_downloaded_cuda_redistribution(repository_ctx):
repository_ctx,
lib_name_to_version_dict,
)
_create_cuda_header_symlinks(repository_ctx)
repository_ctx.file("version.txt", major_version)

def _cuda_repo_impl(repository_ctx):
2 changes: 2 additions & 0 deletions third_party/xla/.bazelrc
@@ -242,6 +242,8 @@ build:cuda_clang --copt=-Qunused-arguments
# major release. Example: sm_80 kernels can run on sm_89 GPUs but
# not on sm_90 GPUs. compute_80 kernels though can also run on sm_90 GPUs.
build:cuda_clang --repo_env=HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_60,sm_70,sm_80,sm_89,compute_90"
# Permit newer CUDA versions than Clang is aware of
build:cuda_clang --copt="-Wno-unknown-cuda-version"
# Set lld as the linker.
build:cuda_clang --host_linkopt="-fuse-ld=lld"
build:cuda_clang --host_linkopt="-lm"
2 changes: 2 additions & 0 deletions third_party/xla/third_party/tsl/.bazelrc
@@ -242,6 +242,8 @@ build:cuda_clang --copt=-Qunused-arguments
# major release. Example: sm_80 kernels can run on sm_89 GPUs but
# not on sm_90 GPUs. compute_80 kernels though can also run on sm_90 GPUs.
build:cuda_clang --repo_env=HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_60,sm_70,sm_80,sm_89,compute_90"
# Permit newer CUDA versions than Clang is aware of
build:cuda_clang --copt="-Wno-unknown-cuda-version"
# Set lld as the linker.
build:cuda_clang --host_linkopt="-fuse-ld=lld"
build:cuda_clang --host_linkopt="-lm"
@@ -219,6 +219,10 @@ def _create_libcuda_symlinks(
repository_ctx.symlink(nvidia_driver_path, "lib/libcuda.so.1")
repository_ctx.symlink("lib/libcuda.so.1", "lib/libcuda.so")

def _create_cuda_header_symlinks(repository_ctx):
if repository_ctx.name == "cuda_nvcc":
repository_ctx.symlink("../cuda_cudart/include/cuda.h", "include/cuda.h")

def use_local_path(repository_ctx, local_path, dirs):
# buildifier: disable=function-docstring-args
"""Creates repository using local redistribution paths."""
@@ -339,6 +343,7 @@ def _use_downloaded_cuda_redistribution(repository_ctx):
repository_ctx,
lib_name_to_version_dict,
)
_create_cuda_header_symlinks(repository_ctx)
repository_ctx.file("version.txt", major_version)

def _cuda_repo_impl(repository_ctx):
23 changes: 11 additions & 12 deletions third_party/xla/xla/service/gpu/ir_emitter_unnested.cc
@@ -740,8 +740,7 @@ absl::Status IrEmitterUnnested::EmitCublasLtMatmulThunk(

absl::Status IrEmitterUnnested::EmitCublasLtMatmulThunkF8(
const HloCustomCallInstruction* instr) {
TF_RET_CHECK(instr->operand_count() == 6 || instr->operand_count() == 7 ||
instr->operand_count() == 8);
TF_RET_CHECK(instr->operand_count() > 3 && instr->operand_count() < 8);
TF_ASSIGN_OR_RETURN(const auto gpu_config,
instr->backend_config<xla::gpu::GpuBackendConfig>());
const xla::gpu::GemmBackendConfig& config = gpu_config.gemm_backend_config();
@@ -777,22 +776,22 @@ absl::Status IrEmitterUnnested::EmitCublasLtMatmulThunkF8(
TF_ASSIGN_OR_RETURN(
BufferAllocation::Slice b_scale,
GetAllocationSliceForHlo(instr->operand(a_scale_index + 1)));

// cublasLT requires c_scale/d_scale to be null when C/D is not FP8.
// Currently, C cannot be FP8.
BufferAllocation::Slice c_scale, d_scale;
#if GOOGLE_CUDA
TF_ASSIGN_OR_RETURN(
BufferAllocation::Slice c_scale,
GetAllocationSliceForHlo(instr->operand(a_scale_index + 2)));
TF_ASSIGN_OR_RETURN(
BufferAllocation::Slice d_scale,
GetAllocationSliceForHlo(instr->operand(a_scale_index + 3)));
#else // TENSORFLOW_USE_ROCM
BufferAllocation::Slice c_scale;
BufferAllocation::Slice d_scale;
if (instr->shape().tuple_shapes(0).element_type() == F8E4M3FN ||
instr->shape().tuple_shapes(0).element_type() == F8E5M2) {
TF_ASSIGN_OR_RETURN(d_scale,
GetAllocationSliceForHlo(instr->operands().back()));
}
#endif

BufferAllocation::Slice bias;
if (has_vector_bias) {
TF_ASSIGN_OR_RETURN(
bias, GetAllocationSliceForHlo(instr->operand(a_scale_index + 4)));
bias, GetAllocationSliceForHlo(instr->operand(a_scale_index + 2)));
}

BufferAllocation::Slice d_amax;
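The comment added in this file ("cublasLT requires c_scale/d_scale to be null when C/D is not FP8") reflects a cuBLASLt constraint: the C and D scale pointers may only be set when the corresponding matrix is an FP8 type. A hedged sketch of that host-side pattern, assuming standard cuBLASLt usage (not code from this commit):

#include <cublasLt.h>

// Leave CUBLASLT_MATMUL_DESC_D_SCALE_POINTER unset (null) unless the output D
// is FP8; cuBLASLt requires it to be null otherwise.
cublasStatus_t MaybeSetDScale(cublasLtMatmulDesc_t desc,
                              const float* d_scale_device, bool d_is_fp8) {
  if (!d_is_fp8) return CUBLAS_STATUS_SUCCESS;
  return cublasLtMatmulDescSetAttribute(desc,
                                        CUBLASLT_MATMUL_DESC_D_SCALE_POINTER,
                                        &d_scale_device,
                                        sizeof(d_scale_device));
}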
4 changes: 2 additions & 2 deletions third_party/xla/xla/service/gpu/matmul_utils.cc
@@ -484,8 +484,8 @@ bool IsTf32Allowed(PrecisionConfig::Algorithm algorithm,
if (has_vector_bias) {
int vector_bias_index = has_matrix_bias ? 3 : 2;
if (primitive_util::IsF8Type(lhs_shape.element_type())) {
// FP8 gemms have 4 scales as inputs which come before the vector bias.
vector_bias_index += 4;
// FP8 gemms have 2 scales as inputs which come before the vector bias.
vector_bias_index += 2;
}
vector_bias_shape = gemm->operand(vector_bias_index)->shape();
}
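With c_scale and d_scale no longer passed, the vector bias sits two operand slots earlier than before. A small arithmetic sketch of the new index, using illustrative names rather than XLA code:

#include <cassert>

// Index of the vector bias in the custom-call operand list, assuming the
// layout {A, B, [C], a_scale, b_scale, vector_bias, ...}.
int VectorBiasIndex(bool has_matrix_bias, bool is_fp8) {
  int index = has_matrix_bias ? 3 : 2;  // Skip A, B and the optional C.
  if (is_fp8) index += 2;               // Only a_scale and b_scale remain.
  return index;
}

int main() {
  assert(VectorBiasIndex(/*has_matrix_bias=*/false, /*is_fp8=*/true) == 4);
  assert(VectorBiasIndex(/*has_matrix_bias=*/true, /*is_fp8=*/true) == 5);
  return 0;
}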
27 changes: 18 additions & 9 deletions third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc
@@ -1083,12 +1083,18 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {

// cuBLASLt FP8 GEMM kernels require the scaling factors to be in F32
// format. Set the factors to one when no scaling factors were captured.
Literal one_literal = LiteralUtil::One(F32);
HloInstruction *one = instr->AddInstruction(
HloInstruction::CreateConstant(one_literal.Clone()));
std::array<bool, 2> mult_scale{a.mult_scale, b.mult_scale};
std::array<HloInstruction *, 2> scales{a.scale, b.scale}, inv_scales,
scales_f32;
HloInstruction *one_constant = nullptr;
auto one = [&one_constant, instr]() -> HloInstruction * {
if (!one_constant) {
one_constant = instr->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::One(F32)));
}
return one_constant;
};

for (int i = 0; i < scales.size(); ++i) {
if (scales[i]) {
if (!ShapeUtil::IsScalar(scales[i]->shape())) {
@@ -1099,15 +1105,15 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
}
if (!mult_scale[i]) {
inv_scales[i] = instr->AddInstruction(HloInstruction::CreateBinary(
scales[i]->shape(), HloOpcode::kDivide, one, scales[i]));
scales[i]->shape(), HloOpcode::kDivide, one(), scales[i]));
}
scales_f32[i] = mult_scale[i] ? scales[i] : inv_scales[i];
if (scales_f32[i]->shape().element_type() != F32) {
scales_f32[i] = instr->AddInstruction(HloInstruction::CreateConvert(
ShapeUtil::MakeScalarShape(F32), scales_f32[i]));
}
} else {
scales_f32[i] = one;
scales_f32[i] = one();
}
}

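The one() lambda introduced above memoizes the F32 constant so it is materialized in the HLO graph at most once, and only when a scale is actually missing or needs to be inverted. A generic sketch of the same memoization idiom in plain C++ (illustrative only, no XLA types):

#include <iostream>

int main() {
  int storage = 0;
  int* cached = nullptr;  // Stands in for the cached HloInstruction*.
  int creations = 0;

  // Create the "constant" lazily, at most once, on first request.
  auto one = [&]() -> int* {
    if (cached == nullptr) {
      storage = 1;
      ++creations;
      cached = &storage;
    }
    return cached;
  };

  // Repeated calls reuse the same object; if nothing calls one(), nothing is
  // created at all.
  std::cout << *one() + *one() << ", created " << creations << " time(s)\n";
  return 0;
}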
@@ -1249,7 +1255,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
PadShapeToMultipleOf16(instr->shape(), out_batch_dims);

std::vector<HloInstruction *> operands_list = {
a.fp8_input, b.fp8_input, scales_f32[0], scales_f32[1], one, one};
a.fp8_input, b.fp8_input, scales_f32[0], scales_f32[1]};

HloInstruction *new_custom_call =
instr->AddInstruction(HloInstruction::CreateCustomCall(
@@ -1415,13 +1421,16 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
}
}

// If necessary, invert the scaling factor of D and convert to F32.
// If necessary, invert the scaling factor of D and convert to F32. When no
// scaling factor was captured, set the factor to one.
if (d_scale) {
TF_ASSIGN_OR_RETURN(d_scale,
InvertAndConvertScalar(d_scale, !mult_scale));
TF_RETURN_IF_ERROR(existing_gemm->ReplaceOperandWith(
gemm_backend_config.beta() == 0.0 ? 5 : 6, d_scale));
} else {
d_scale = instr->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::One(F32)));
}
existing_gemm->AppendOperand(d_scale);

// If present, elide the calculation of the maximum of the absolute values
// of the result of the GEMM.
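For FP8 outputs the rewriter now appends the D scale as the final operand, inverting it when the matched pattern divides by the scale, and substitutes a constant 1.0 when no scale was captured. A scalar sketch of the resulting effective scale, under those assumed semantics (this is not the actual InvertAndConvertScalar helper):

// Effective D scale handed to the custom call, assuming cuBLASLt multiplies
// the output by it.
float EffectiveDScale(const float* captured_scale, bool mult_scale) {
  if (captured_scale == nullptr) return 1.0f;  // No scale captured: use one.
  // mult_scale: the HLO multiplies by the scale; otherwise it divides, so the
  // reciprocal is passed instead.
  return mult_scale ? *captured_scale : 1.0f / *captured_scale;
}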
(Diffs for the remaining changed files are not shown in this view.)
