Skip to content

Commit

Permalink
Better naming for loop fusion.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 679349330
  • Loading branch information
Google-ML-Automation committed Sep 27, 2024
1 parent 9f7cb88 commit 9f35601
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 9 deletions.
6 changes: 4 additions & 2 deletions xla/hlo/ir/hlo_instruction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2174,8 +2174,10 @@ HloInstruction::CreateDynamicReshape(
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion(
const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root) {
return std::make_unique<HloFusionInstruction>(shape, fusion_kind, fused_root);
const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root,
absl::string_view prefix) {
return std::make_unique<HloFusionInstruction>(shape, fusion_kind, fused_root,
prefix);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion(
Expand Down
3 changes: 2 additions & 1 deletion xla/hlo/ir/hlo_instruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -1347,7 +1347,8 @@ class HloInstruction {
// "fused_root". Additional instructions can be added to the fusion
// instruction with the method FuseInstruction.
static std::unique_ptr<HloInstruction> CreateFusion(
const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root);
const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root,
absl::string_view prefix = "");

static std::unique_ptr<HloInstruction> CreateFusion(
const Shape& shape, FusionKind fusion_kind,
Expand Down
6 changes: 4 additions & 2 deletions xla/hlo/ir/hlo_instructions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2172,11 +2172,13 @@ void HloCallableInstruction::RecursivelySetComputationsThreadName(

HloFusionInstruction::HloFusionInstruction(const Shape& shape,
FusionKind fusion_kind,
HloInstruction* fused_root)
HloInstruction* fused_root,
absl::string_view prefix)
: HloCallableInstruction(HloOpcode::kFusion, shape),
fusion_kind_(fusion_kind) {
CHECK(fused_root != nullptr);
SetAndSanitizeName(HloOpcodeString(opcode()));
SetAndSanitizeName(absl::StrCat(prefix, HloOpcodeString(opcode())));

set_parent(fused_root->parent());
set_metadata(fused_root->metadata());
set_frontend_attributes(fused_root->frontend_attributes());
Expand Down
3 changes: 2 additions & 1 deletion xla/hlo/ir/hlo_instructions.h
Original file line number Diff line number Diff line change
Expand Up @@ -1439,7 +1439,8 @@ class HloCallableInstruction : public HloInstruction {
class HloFusionInstruction : public HloCallableInstruction {
public:
explicit HloFusionInstruction(const Shape& shape, FusionKind fusion_kind,
HloInstruction* fused_root);
HloInstruction* fused_root,
absl::string_view prefix = "");

explicit HloFusionInstruction(const Shape& shape, FusionKind fusion_kind,
absl::Span<HloInstruction* const> operands,
Expand Down
7 changes: 5 additions & 2 deletions xla/service/instruction_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -718,8 +718,11 @@ HloInstruction* InstructionFusion::AddFusionInstruction(
fusion_instruction->set_fusion_kind(kind);
}
} else {
fusion_instruction = computation->AddInstruction(
HloInstruction::CreateFusion(consumer->shape(), kind, consumer));
fusion_instruction =
computation->AddInstruction(HloInstruction::CreateFusion(
consumer->shape(), kind, consumer,
absl::StrCat(HloOpcodeString(producer->opcode()), "_",
HloOpcodeString(consumer->opcode()), "_")));
TF_CHECK_OK(computation->ReplaceInstruction(consumer, fusion_instruction));
}
fusion_instruction->set_called_computations_execution_thread(
Expand Down
2 changes: 1 addition & 1 deletion xla/service/propagate_original_value_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ CHECK: ROOT %[[ADD:.*]] = u32[2]{0:T(256)} add(%[[PAD]], %[[PAD1]]), origin={{
CHECK: ENTRY %test
CHECK: %Arg_0 = s32[]{:T(256)} parameter(0), origin={{[{]}}{"Arg_0"}
CHECK: ROOT %fusion = u32[2]{0:T(256)} fusion(%Arg_0), kind=kLoop, calls=%fused_computation
CHECK: ROOT %pad_add_fusion = u32[2]{0:T(256)} fusion(%Arg_0), kind=kLoop, calls=%fused_computation
)");
}

Expand Down

0 comments on commit 9f35601

Please sign in to comment.