Skip to content

Commit

Permalink
PR #16893: Unary Ops in FP8 Windowed Einsums
Browse files Browse the repository at this point in the history
Imported from GitHub PR #16893

Adds support for unary ops between dequantization and windowed einsum loop.
Copybara import of the project:

--
fffc93f by Philipp Hack <[email protected]>:

Adds support for unary ops between dequantization and windowed einsum loop.

Merging this change closes #16893

COPYBARA_INTEGRATE_REVIEW=#16893 from philipphack:u_fp8_windowed_unary_xla fffc93f
PiperOrigin-RevId: 679470392
  • Loading branch information
philipphack authored and Google-ML-Automation committed Sep 27, 2024
1 parent 9fb4f21 commit 82c2770
Show file tree
Hide file tree
Showing 3 changed files with 251 additions and 150 deletions.
40 changes: 36 additions & 4 deletions xla/service/gpu/transforms/windowed_einsum_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,15 @@ namespace m = match;
// and type conversions of FP8 operands into the bodies of their while loops,
// i.e. rewrites
//
// inputs --> dequant --> while loop {collective-permute/dot/etc}
// inputs --> dequant --> (unary) --> while loop {collective-permute/dot/etc}
//
// into
//
// inputs --> while loop {dequant --> collective-permute/dot/etc}.
// Returns whether the input computation has been changed.
// inputs --> (unary) --> while loop {dequant --> collective-permute/dot/etc}.
//
// Unary bitcast, broadcast, copy, reshape and transpose ops are allowed between
// dequantization and while loop. Returns whether the input computation has been
// changed.
absl::StatusOr<bool> ShiftDequantizationF8(HloComputation* while_body) {
HloInstruction* while_instr = while_body->WhileCallInstruction();
// The input of the while loop will be modified and must have no other users.
Expand All @@ -73,8 +76,21 @@ absl::StatusOr<bool> ShiftDequantizationF8(HloComputation* while_body) {
// while loop.
HloInstruction* param_tuple = while_instr->mutable_operand(0);
std::array<HloInstruction*, 2> binaries, operands, scales;
std::array<std::vector<HloInstruction*>, 2> unaries;
for (int k = 0; k < 2; ++k) {
if (!Match(param_tuple->mutable_operand(k),
HloInstruction* operand = param_tuple->mutable_operand(k);
// Capture bitcast, broadcast, copy, reshape and transpose ops between
// dequantization and the loop.
while (operand->opcode() == HloOpcode::kBitcast ||
operand->opcode() == HloOpcode::kBroadcast ||
operand->opcode() == HloOpcode::kCopy ||
operand->opcode() == HloOpcode::kReshape ||
operand->opcode() == HloOpcode::kTranspose) {
unaries[k].emplace_back(operand);
operand = operand->mutable_operand(0);
}
std::reverse(unaries[k].begin(), unaries[k].end());
if (!Match(operand,
m::AnyOf<HloInstruction>(
m::Divide(&binaries[k], m::Convert(m::Op(&operands[k])),
m::Broadcast(m::Op(&scales[k]))),
Expand Down Expand Up @@ -156,6 +172,22 @@ absl::StatusOr<bool> ShiftDequantizationF8(HloComputation* while_body) {
return false;
}

// Replace any dequantized bitcast, broadcast, copy, reshape and transpose ops
// before the while loop with FP8 unary ops.
for (int k = 0; k < 2; ++k) {
for (HloInstruction* unary : unaries[k]) {
Shape new_shape = ShapeUtil::MakeShapeWithDenseLayout(
operands[k]->shape().element_type(), unary->shape().dimensions(),
unary->shape().layout().minor_to_major());

operands[k] = unary->AddInstruction(unary->CloneWithNewOperands(
ShapeUtil::MakeShapeWithDenseLayout(
operands[k]->shape().element_type(), unary->shape().dimensions(),
unary->shape().layout().minor_to_major()),
{operands[k]}));
}
}

// Replace the dequantized dot operands in the parameter tuple used by while
// with FP8 operands.
for (int k = 0; k < 2; ++k) {
Expand Down
Loading

0 comments on commit 82c2770

Please sign in to comment.