Skip to content

Commit

Permalink
PR #16841: Delete FP8 Scaling Factors in GEMM Rewriter
Browse files Browse the repository at this point in the history
Imported from GitHub PR #16841

Removes the scaling factors of C and D (matrix bias and result) from FP8 Custom Calls created in the GEMM rewriter when their data types are not FP8. See #15795.
Copybara import of the project:

--
fd9750f by Philipp Hack <[email protected]>:

Removes superfluous FP8 scaling factors in GEMM rewriter.

Merging this change closes #16841

COPYBARA_INTEGRATE_REVIEW=#16841 from philipphack:u_fp8_scales_xla fd9750f
PiperOrigin-RevId: 679784586
  • Loading branch information
philipphack authored and Google-ML-Automation committed Sep 28, 2024
1 parent 683725f commit 3c5c920
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 108 deletions.
23 changes: 11 additions & 12 deletions xla/service/gpu/ir_emitter_unnested.cc
Original file line number Diff line number Diff line change
Expand Up @@ -740,8 +740,7 @@ absl::Status IrEmitterUnnested::EmitCublasLtMatmulThunk(

absl::Status IrEmitterUnnested::EmitCublasLtMatmulThunkF8(
const HloCustomCallInstruction* instr) {
TF_RET_CHECK(instr->operand_count() == 6 || instr->operand_count() == 7 ||
instr->operand_count() == 8);
TF_RET_CHECK(instr->operand_count() > 3 && instr->operand_count() < 8);
TF_ASSIGN_OR_RETURN(const auto gpu_config,
instr->backend_config<xla::gpu::GpuBackendConfig>());
const xla::gpu::GemmBackendConfig& config = gpu_config.gemm_backend_config();
Expand Down Expand Up @@ -777,22 +776,22 @@ absl::Status IrEmitterUnnested::EmitCublasLtMatmulThunkF8(
TF_ASSIGN_OR_RETURN(
BufferAllocation::Slice b_scale,
GetAllocationSliceForHlo(instr->operand(a_scale_index + 1)));

// cublasLT requires c_scale/d_scale to be null when C/D is not FP8.
// Currently, C cannot be FP8.
BufferAllocation::Slice c_scale, d_scale;
#if GOOGLE_CUDA
TF_ASSIGN_OR_RETURN(
BufferAllocation::Slice c_scale,
GetAllocationSliceForHlo(instr->operand(a_scale_index + 2)));
TF_ASSIGN_OR_RETURN(
BufferAllocation::Slice d_scale,
GetAllocationSliceForHlo(instr->operand(a_scale_index + 3)));
#else // TENSORFLOW_USE_ROCM
BufferAllocation::Slice c_scale;
BufferAllocation::Slice d_scale;
if (instr->shape().tuple_shapes(0).element_type() == F8E4M3FN ||
instr->shape().tuple_shapes(0).element_type() == F8E5M2) {
TF_ASSIGN_OR_RETURN(d_scale,
GetAllocationSliceForHlo(instr->operands().back()));
}
#endif

BufferAllocation::Slice bias;
if (has_vector_bias) {
TF_ASSIGN_OR_RETURN(
bias, GetAllocationSliceForHlo(instr->operand(a_scale_index + 4)));
bias, GetAllocationSliceForHlo(instr->operand(a_scale_index + 2)));
}

BufferAllocation::Slice d_amax;
Expand Down
4 changes: 2 additions & 2 deletions xla/service/gpu/matmul_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -484,8 +484,8 @@ bool IsTf32Allowed(PrecisionConfig::Algorithm algorithm,
if (has_vector_bias) {
int vector_bias_index = has_matrix_bias ? 3 : 2;
if (primitive_util::IsF8Type(lhs_shape.element_type())) {
// FP8 gemms have 4 scales as inputs which come before the vector bias.
vector_bias_index += 4;
// FP8 gemms have 2 scales as inputs which come before the vector bias.
vector_bias_index += 2;
}
vector_bias_shape = gemm->operand(vector_bias_index)->shape();
}
Expand Down
27 changes: 18 additions & 9 deletions xla/service/gpu/transforms/gemm_rewriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1083,12 +1083,18 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {

// cuBLASLt FP8 GEMM kernels require the scaling factors to be in F32
// format. Set the factors to one when no scaling factors were captured.
Literal one_literal = LiteralUtil::One(F32);
HloInstruction *one = instr->AddInstruction(
HloInstruction::CreateConstant(one_literal.Clone()));
std::array<bool, 2> mult_scale{a.mult_scale, b.mult_scale};
std::array<HloInstruction *, 2> scales{a.scale, b.scale}, inv_scales,
scales_f32;
HloInstruction *one_constant = nullptr;
auto one = [&one_constant, instr]() -> HloInstruction * {
if (!one_constant) {
one_constant = instr->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::One(F32)));
}
return one_constant;
};

for (int i = 0; i < scales.size(); ++i) {
if (scales[i]) {
if (!ShapeUtil::IsScalar(scales[i]->shape())) {
Expand All @@ -1099,15 +1105,15 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
}
if (!mult_scale[i]) {
inv_scales[i] = instr->AddInstruction(HloInstruction::CreateBinary(
scales[i]->shape(), HloOpcode::kDivide, one, scales[i]));
scales[i]->shape(), HloOpcode::kDivide, one(), scales[i]));
}
scales_f32[i] = mult_scale[i] ? scales[i] : inv_scales[i];
if (scales_f32[i]->shape().element_type() != F32) {
scales_f32[i] = instr->AddInstruction(HloInstruction::CreateConvert(
ShapeUtil::MakeScalarShape(F32), scales_f32[i]));
}
} else {
scales_f32[i] = one;
scales_f32[i] = one();
}
}

Expand Down Expand Up @@ -1249,7 +1255,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
PadShapeToMultipleOf16(instr->shape(), out_batch_dims);

std::vector<HloInstruction *> operands_list = {
a.fp8_input, b.fp8_input, scales_f32[0], scales_f32[1], one, one};
a.fp8_input, b.fp8_input, scales_f32[0], scales_f32[1]};

HloInstruction *new_custom_call =
instr->AddInstruction(HloInstruction::CreateCustomCall(
Expand Down Expand Up @@ -1415,13 +1421,16 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
}
}

// If necessary, invert the scaling factor of D and convert to F32.
// If necessary, invert the scaling factor of D and convert to F32. When no
// scaling factor was captured, set the factor to one.
if (d_scale) {
TF_ASSIGN_OR_RETURN(d_scale,
InvertAndConvertScalar(d_scale, !mult_scale));
TF_RETURN_IF_ERROR(existing_gemm->ReplaceOperandWith(
gemm_backend_config.beta() == 0.0 ? 5 : 6, d_scale));
} else {
d_scale = instr->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::One(F32)));
}
existing_gemm->AppendOperand(d_scale);

// If present, elide the calculation of the maximum of the absolute values
// of the result of the GEMM.
Expand Down
Loading

0 comments on commit 3c5c920

Please sign in to comment.