[xla:cpu] NFC: Micro-optimizations for KernelThunk

FUTURE_COPYBARA_INTEGRATE_REVIEW=#14073 from apivovarov:select_compare_algsimp 6fe68d7
PiperOrigin-RevId: 646316826

ezhulenev authored and copybara-github committed Jun 26, 2024
1 parent ec96f90 · commit 2307d94
Showing 8 changed files with 349 additions and 52 deletions.
27 changes: 27 additions & 0 deletions xla/service/algebraic_simplifier.cc
@@ -8188,6 +8188,33 @@ absl::Status AlgebraicSimplifierVisitor::HandleSelect(HloInstruction* select) {
select->mutable_operand(0)->shape(), HloOpcode::kNot,
select->mutable_operand(0)));
}
// select(compare(a, b, GT/GE), a, b) => or(a, b)
// select(compare(a, b, LT/LE), a, b) => and(a, b)
// select(compare(a, b, EQ), a, b) => b
// select(compare(a, b, NE), a, b) => a
HloInstruction *compare, *lhs, *rhs;
if (Match(select, m::Select(m::Op(&compare), m::Op(&lhs), m::Op(&rhs))) &&
Match(compare, m::Compare(m::Op().Is(lhs), m::Op().Is(rhs)))) {
auto cmp_dir = compare->comparison_direction();
if (cmp_dir == ComparisonDirection::kGt ||
cmp_dir == ComparisonDirection::kGe) {
return ReplaceWithNewInstruction(
select, HloInstruction::CreateBinary(select->shape(),
HloOpcode::kOr, lhs, rhs));
}
if (cmp_dir == ComparisonDirection::kLt ||
cmp_dir == ComparisonDirection::kLe) {
return ReplaceWithNewInstruction(
select, HloInstruction::CreateBinary(select->shape(),
HloOpcode::kAnd, lhs, rhs));
}
if (cmp_dir == ComparisonDirection::kEq) {
return ReplaceInstruction(select, rhs);
}
if (cmp_dir == ComparisonDirection::kNe) {
return ReplaceInstruction(select, lhs);
}
}
}

// select(pred, xs, dynamic_update_slice(xs, x, i))
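These four rewrites apply to boolean data: they are added to the branch of HandleSelect that already rewrites select(pred, false, true) to not(pred) just above, and the tests below exercise them with pred[8] operands only. As a standalone sanity check of the underlying identities — not part of the commit, with C++ bool standing in for a PRED element:

  #include <cassert>

  // Exhaustively check the boolean identities behind the new rewrites:
  //   select(compare(a, b, GT/GE), a, b) == or(a, b)
  //   select(compare(a, b, LT/LE), a, b) == and(a, b)
  //   select(compare(a, b, EQ), a, b)    == b
  //   select(compare(a, b, NE), a, b)    == a
  // GT and GE (likewise LT and LE) pick the same value here: they only
  // differ when a == b, and then both branches of the select are equal.
  int main() {
    for (bool a : {false, true}) {
      for (bool b : {false, true}) {
        assert((a > b ? a : b) == (a || b));  // GT/GE => or(a, b)
        assert((a < b ? a : b) == (a && b));  // LT/LE => and(a, b)
        assert((a == b ? a : b) == b);        // EQ => b
        assert((a != b ? a : b) == a);        // NE => a
      }
    }
  }
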
89 changes: 89 additions & 0 deletions xla/service/algebraic_simplifier_test.cc
@@ -736,6 +736,95 @@ TEST_F(AlgebraicSimplifierTest, SelectPredPred2) {
GmockMatch(m::Not(m::Parameter(0))));
}

// select(compare(a, b, GT/GE), a, b) => or(a, b), a,b ∈ PRED
TEST_F(AlgebraicSimplifierTest, SelectGtCompare) {
for (const auto cmp_dir : {"GT", "GE"}) {
const auto kModuleStr = absl::StrFormat(R"(
HloModule m
test {
p0 = pred[8]{0} parameter(0)
p1 = pred[8]{0} parameter(1)
compare = pred[8]{0} compare(p0, p1), direction=%s
ROOT select = pred[8]{0} select(compare, p0, p1)
}
)",
cmp_dir);
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value());
EXPECT_THAT(m->entry_computation()->root_instruction(),
GmockMatch(m::Or(m::Parameter(0), m::Parameter(1))));
}
}

// select(compare(a, b, LT/LE), a, b) => and(a, b), a,b ∈ PRED
TEST_F(AlgebraicSimplifierTest, SelectLtCompare) {
for (const auto cmp_dir : {"LT", "LE"}) {
const auto kModuleStr = absl::StrFormat(R"(
HloModule m
test {
p0 = pred[8]{0} parameter(0)
p1 = pred[8]{0} parameter(1)
compare = pred[8]{0} compare(p0, p1), direction=%s
ROOT select = pred[8]{0} select(compare, p0, p1)
}
)",
cmp_dir);
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value());
EXPECT_THAT(m->entry_computation()->root_instruction(),
GmockMatch(m::And(m::Parameter(0), m::Parameter(1))));
}
}

// select(compare(a, b, EQ), a, b) => b, a,b ∈ PRED
TEST_F(AlgebraicSimplifierTest, SelectEqCompare) {
const char* kModuleStr = R"(
HloModule m
test {
p0 = pred[8]{0} parameter(0)
p1 = pred[8]{0} parameter(1)
compare = pred[8]{0} compare(p0, p1), direction=EQ
ROOT select = pred[8]{0} select(compare, p0, p1)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value());
EXPECT_THAT(m->entry_computation()->root_instruction(),
GmockMatch(m::Parameter(1)));
}

// select(compare(a, b, NE), a, b) => a, a,b ∈ PRED
TEST_F(AlgebraicSimplifierTest, SelectNeCompare) {
const char* kModuleStr = R"(
HloModule m
test {
p0 = pred[8]{0} parameter(0)
p1 = pred[8]{0} parameter(1)
compare = pred[8]{0} compare(p0, p1), direction=NE
ROOT select = pred[8]{0} select(compare, p0, p1)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value());
EXPECT_THAT(m->entry_computation()->root_instruction(),
GmockMatch(m::Parameter(0)));
}

// select(compare(a, b, NE), b, a) ≠> a - wrong operands order
TEST_F(AlgebraicSimplifierTest, SelectNeCompare_NegativeTestCase) {
const char* kModuleStr = R"(
HloModule m
test {
p0 = pred[8]{0} parameter(0)
p1 = pred[8]{0} parameter(1)
compare = pred[8]{0} compare(p0, p1), direction=NE
ROOT select = pred[8]{0} select(compare, p1, p0)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).value());
}

// Test that select(pred, xs, dynamic_update_slice(xs, x, i)) is simplified
// to dynamic_update_slice(xs, select(pred, dynamic_slice(xs, i), x), i)
TEST_F(AlgebraicSimplifierTest, SelectDUSWithShapedPred) {
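A note on the negative test: select(compare(a, b, NE), b, a) always evaluates to b (when a != b it picks b, and otherwise a equals b), so a rewrite to b would be sound. What the test guards against is the unsound rewrite to a — the pass only fires when the select's true/false operands appear in the same (lhs, rhs) order as the compare operands. A tiny standalone counterexample, not part of the commit:

  #include <cassert>

  int main() {
    // select(a != b, b, a) always equals b; simplifying it to `a` would be
    // wrong, as this pair shows.
    bool a = false, b = true;
    assert((a != b ? b : a) == b);
    assert((a != b ? b : a) != a);
  }
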
@@ -76,7 +76,6 @@ BENCHMARK(BM_SelectAndScatterF32)
->Arg(64)
->Arg(128)
->Arg(256)
- ->Arg(512)
- ->Arg(1024);
+ ->Arg(512);

} // namespace xla::cpu
30 changes: 18 additions & 12 deletions xla/service/cpu/runtime/kernel_thunk.cc
@@ -87,40 +87,46 @@ tsl::AsyncValueRef<Thunk::ExecuteEvent> KernelThunk::Execute(
kernel_name_, arguments_buffers_.size(), results_buffers_.size(),
thread_dim_.ToString());

- absl::InlinedVector<SE_HOST_KernelArg, 8> kernel_args;
- kernel_args.reserve(arguments_buffers_.size() + results_buffers_.size());
+ int64_t num_args = arguments_buffers_.size() + results_buffers_.size();
+ absl::InlinedVector<SE_HOST_KernelArg, 8> kernel_args(num_args);

+ // We initialize `kernel_args` array using pointer to the first argument,
+ // because individual elements access adds up measurable overhead, and this
+ // code is on the critical path.
+ SE_HOST_KernelArg* kernel_args_ptr = kernel_args.data();
+ int64_t kernel_arg_idx = 0;
+
int64_t arg_num = 0;
for (BufferAllocation::Slice& buffer : arguments_buffers_) {
TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase arg_data,
params.buffer_allocations->GetDeviceAddress(buffer));
- kernel_args.push_back(
- SE_HOST_KernelArg{arg_data.opaque(), arg_data.size()});
VLOG(3) << absl::StreamFormat(" arg #%d: %s (%p)", arg_num++,
- buffer.ToString(), kernel_args.back().data);
+ buffer.ToString(), arg_data.opaque());
+ kernel_args_ptr[kernel_arg_idx++] =
+ SE_HOST_KernelArg{arg_data.opaque(), arg_data.size()};
}

int64_t res_num = 0;
for (BufferAllocation::Slice& buffer : results_buffers_) {
TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase result_data,
params.buffer_allocations->GetDeviceAddress(buffer));
- kernel_args.push_back(
- SE_HOST_KernelArg{result_data.opaque(), result_data.size()});
VLOG(3) << absl::StreamFormat(" res #%d: %s (%p)", res_num++,
- buffer.ToString(), kernel_args.back().data);
+ buffer.ToString(), result_data.opaque());
+ kernel_args_ptr[kernel_arg_idx++] =
+ SE_HOST_KernelArg{result_data.opaque(), result_data.size()};
}

// Check that all buffers are aligned to the minimum alignment. We codegen
// with the assumption that all buffers are aligned, and if they are not, we
// will crash with a segmentation fault, or worse, produce incorrect results.
if (min_alignment_.has_value()) {
- for (int64_t i = 0; i < kernel_args.size(); ++i) {
- auto ptr = reinterpret_cast<uintptr_t>(kernel_args[i].data);
+ for (int64_t i = 0; i < num_args; ++i) {
+ auto ptr = reinterpret_cast<uintptr_t>(kernel_args_ptr[i].data);
if (ABSL_PREDICT_FALSE((ptr & (*min_alignment_ - 1)) != 0)) {
return Internal(
"Host kernel %s buffer argument #%d (%p) is not aligned to a "
"required minimum alignment of %d bytes",
- info().op_name, i, kernel_args[i].data, *min_alignment_);
+ info().op_name, i, kernel_args_ptr[i].data, *min_alignment_);
}
}
}
@@ -136,7 +142,7 @@ tsl::AsyncValueRef<Thunk::ExecuteEvent> KernelThunk::Execute(
params.host_kernels->Find(kernel_name_));

absl::MutexLock lock(&mutex_);
- kernel_.emplace(kernel_args.size(), kernel_fn, nullptr);
+ kernel_.emplace(num_args, kernel_fn, nullptr);
kernel_ptr_.store(kernel = &kernel_.value());
}
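Two ideas from this diff, shown in isolation as a minimal sketch (std::vector and the Arg struct are stand-ins for absl::InlinedVector and SE_HOST_KernelArg; all names here are hypothetical): size the container once and fill it through a raw pointer rather than repeated push_back calls, and test alignment with the power-of-two bit trick used in the buffer check.

  #include <cassert>
  #include <cstdint>
  #include <vector>

  struct Arg {  // stand-in for SE_HOST_KernelArg
    void* data;
    uint64_t size;
  };

  int main() {
    std::vector<int> buffers(16);

    // Size once, then write through a raw pointer: push_back re-checks
    // capacity and bumps the size on every call, the per-element overhead
    // the new comment in the diff calls out on the critical path.
    std::vector<Arg> args(buffers.size());
    Arg* args_ptr = args.data();
    int64_t idx = 0;
    for (int& buf : buffers) {
      args_ptr[idx++] = Arg{&buf, sizeof(buf)};
    }
    assert(idx == static_cast<int64_t>(args.size()));

    // Alignment check: for a power-of-two alignment A, ptr & (A - 1) keeps
    // the low bits of the address, which must all be zero if ptr is aligned.
    constexpr uintptr_t kMinAlign = 16;
    alignas(16) static int aligned = 0;
    assert((reinterpret_cast<uintptr_t>(&aligned) & (kMinAlign - 1)) == 0);
  }
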
149 changes: 125 additions & 24 deletions xla/service/gpu/gpu_windowed_einsum_handler.cc
@@ -378,8 +378,16 @@ absl::StatusOr<bool> HandleAgWindowedEinsumLoop(HloComputation* comp,
return changed;
}

+ static int64_t GetAgActivationCacheIndex(const HloInstruction* while_loop) {
+ const HloInstruction* loop_tuple = while_loop->operand(0);
+ const Shape& tuple_shape = loop_tuple->shape();
+ CHECK(tuple_shape.IsTuple());
+ return tuple_shape.tuple_shapes_size();
+ }
+
absl::Status ProcessWindowedEinsumLoopForActivationCaching(
- GpuWindowedEinsumHandler::WindowedEinsumAgLoops& ag_loop) {
+ GpuWindowedEinsumHandler::WindowedEinsumAgLoops& ag_loop,
+ HloInstruction* ag_with_shared_operand) {
HloInstruction* loop = ag_loop.loop;
// Transform the while body to cache the allgathered result in the
// output buffer to be consumed by the dot
@@ -392,15 +400,61 @@ absl::Status ProcessWindowedEinsumLoopForActivationCaching(
}
// Get the output operand of the full buffer.
HloInstruction* root = while_body->root_instruction();
+ // Change loop body to include the new input and output element.
+ HloInstruction* input_tuple = while_body->parameter_instruction(0);
+ const Shape& input_shape = input_tuple->shape();
// The full buffer that we will use to cache the accumulated activation
- // is the 4th operand in the output tuple.
- int64_t full_cache_buffer_index = 3;
+ // is the last operand in the output tuple.
+ int64_t full_cache_buffer_index = GetAgActivationCacheIndex(loop);
+ std::vector<Shape> new_input_shapes(input_shape.tuple_shapes().begin(),
+ input_shape.tuple_shapes().end());
+ new_input_shapes.push_back(ag_with_shared_operand->shape());
+ // Update body input shape
+ Shape new_input_shape = ShapeUtil::MakeTupleShape(new_input_shapes);
+ *input_tuple->mutable_shape() = new_input_shape;
HloInstruction* full_buffer_output_gte =
- root->mutable_operand(full_cache_buffer_index);
- HloInstruction* new_full_buffer_output;
+ while_body->AddInstruction(HloInstruction::CreateGetTupleElement(
+ ag_with_shared_operand->shape(), input_tuple,
+ full_cache_buffer_index));
+
+ // Update condition input shape
+ HloComputation* cond_comp = loop->while_condition();
+ HloInstruction* cond_input_tuple = cond_comp->parameter_instruction(0);
+ *cond_input_tuple->mutable_shape() = new_input_shape;
+
+ // Update input to the while instruction in parent computation
+ HloInstruction* original_while_input = loop->mutable_operand(0);
+ HloComputation* parent_comp = loop->parent();
+ std::vector<HloInstruction*> new_operands(
+ original_while_input->operands().begin(),
+ original_while_input->operands().end());
+ new_operands.push_back(
+ parent_comp->AddInstruction(HloInstruction::CreateBroadcast(
+ ag_with_shared_operand->shape(),
+ parent_comp->AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::Zero(new_input_shapes[0].element_type()))),
+ {})));
+ HloInstruction* new_while_input =
+ parent_comp->AddInstruction(HloInstruction::CreateTuple(new_operands));
+ TF_RETURN_IF_ERROR(
+ loop->ReplaceOperandWithDifferentShape(0, new_while_input));
+ TF_RETURN_IF_ERROR(parent_comp->ReplaceInstructionWithDifferentShape(
+ original_while_input, new_while_input));
+ *loop->mutable_shape() = new_input_shape;
+
+ HloInstruction* new_full_buffer_output = nullptr;
// Find the DUS in the loop body and re-use the slice indices
// This should just be a constant(0)
HloInstruction* dus_boundary_constant;
+ // The slice we need this time is the output of the first
+ // collective-permute
+ HloInstruction* first_cp_output;
+ for (HloInstruction* gte_user : input_gte->users()) {
+ if (gte_user->opcode() == HloOpcode::kCollectivePermute) {
+ first_cp_output = gte_user;
+ break;
+ }
+ }
for (HloInstruction* inst : while_body->MakeInstructionPostOrder()) {
HloInstruction* slice_indices;
// If we have a DUS(PARAM,DS) pattern, we need to update the output
@@ -434,24 +488,68 @@ absl::Status ProcessWindowedEinsumLoopForActivationCaching(
dus_boundary_constant->shape(), slice_indices));
VLOG(5) << "Created slice op for second slice: "
<< slice_indices->ToString();
- // The slice we need this time is the output of the first
- // collective-permute
- HloInstruction* cp_output;
- for (HloInstruction* gte_user : input_gte->users()) {
- if (gte_user->opcode() == HloOpcode::kCollectivePermute) {
- cp_output = gte_user;
- break;
- }
- }
new_full_buffer_output =
while_body->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
full_buffer_output_gte->shape(), full_buffer_output_gte,
- cp_output,
+ first_cp_output,
{dus_boundary_constant, slice_indices, dus_boundary_constant}));
}
+
+ // If we have a Dot(DS(parameter_index1)), then operands are sharded along
+ // the contracting dim. Slice indices will be the contracting dim's slices.
+ HloInstruction* slice_index;
+ HloInstruction* ds_index_constant;
+ HloInstruction* remainder;
+ HloInstruction* ds_param;
+ // There will be 2 dynamic-slices for unrolled loops, match for each one to
+ // get the slice index which will be used to write the corresponding
+ // received shard into cached activation buffer. For unrolled loops, we need
+ // to write to the final buffer twice per iteration, so we need to match for
+ // the correct slice index based on each DS.
+ if (Match(inst, m::Dot(m::Op(), m::DynamicSlice(&ds_param))) &&
+ Match(ds_param->operand(0), m::GetTupleElement(m::Parameter(), 1))) {
+ for (int64_t ds_op_i = 1; ds_op_i < ds_param->operands().size();
+ ds_op_i++) {
+ if (!Match(
+ ds_param->mutable_operand(ds_op_i),
+ m::Reshape(&slice_index, m::DynamicSlice(m::Constant(),
+ m::Op(&remainder)))) &&
+ !Match(ds_param->mutable_operand(ds_op_i),
+ m::Constant(&ds_index_constant))) {
+ return absl::OkStatus();
+ }
+ }
+ // First DS has slice index calculated based on loop iterator
+ // Remainder(add(gte, partition_id))
+ if (Match(remainder,
+ m::Remainder(m::Add(m::GetTupleElement(), m::Op()), m::Op()))) {
+ full_buffer_output_gte =
+ while_body->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
+ full_buffer_output_gte->shape(), full_buffer_output_gte,
+ input_gte,
+ {ds_index_constant, ds_index_constant, slice_index}));
+ }
+ // Second DS has slice index calculated based on loop iterator+1 hence
+ // Remainder(add(add(gte, 1), partition_id))
+ if (Match(remainder,
+ m::Remainder(
+ m::Add(m::Add(m::GetTupleElement(), m::Op()), m::Op()),
+ m::Op()))) {
+ new_full_buffer_output =
+ while_body->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
+ full_buffer_output_gte->shape(), full_buffer_output_gte,
+ first_cp_output,
+ {ds_index_constant, ds_index_constant, slice_index}));
+ }
+ }
}
- TF_RETURN_IF_ERROR(root->ReplaceOperandWith(full_cache_buffer_index,
- new_full_buffer_output));
+ std::vector<HloInstruction*> original_operands(root->operands().begin(),
+ root->operands().end());
+ original_operands.push_back(new_full_buffer_output);
+ HloInstruction* new_output_tuple = while_body->AddInstruction(
+ HloInstruction::CreateTuple(original_operands));
+ TF_RETURN_IF_ERROR(
+ while_body->ReplaceInstructionWithDifferentShape(root, new_output_tuple));
return absl::OkStatus();
}

@@ -620,17 +718,20 @@ class WindowedEinsumVisitor : public DfsHloRewriteVisitor {
VLOG(5) << "Found all-gather that shares the same operand with a "
"windowed einsum loop : "
<< loop->ToString();

+ if (!ag_loop.consumed) {
+ TF_RETURN_IF_ERROR(ProcessWindowedEinsumLoopForActivationCaching(
+ ag_loop, ag_with_shared_operand));
+ ag_loop.consumed = true;
+ }
int64_t cache_output_index = dot->operand_index(ag_with_shared_operand);
- HloInstruction* new_gte = comp->AddInstruction(
- HloInstruction::CreateGetTupleElement(loop, 3));
+ HloComputation* comp = dot->parent();
+ HloInstruction* new_gte =
+ comp->AddInstruction(HloInstruction::CreateGetTupleElement(
+ loop, GetAgActivationCacheIndex(loop) - 1));
TF_RETURN_IF_ERROR(
dot->ReplaceOperandWith(cache_output_index, new_gte));
TF_RETURN_IF_ERROR(comp->RemoveInstruction(ag_with_shared_operand));
- if (!ag_loop.consumed) {
- TF_RETURN_IF_ERROR(
- ProcessWindowedEinsumLoopForActivationCaching(ag_loop));
- ag_loop.consumed = true;
- }
}
}
// Rewrites an all-to-all+gemm into multiple independent partial a2a+gemms
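Most of the new lines in ProcessWindowedEinsumLoopForActivationCaching are the bookkeeping needed to widen a while loop's carried tuple by one element: the body parameter, the condition parameter, the init tuple in the parent computation, and the while instruction's own shape all have to agree. A condensed sketch of that skeleton, using only APIs that appear in the diff and assuming the usual XLA headers (the helper name is hypothetical; the body root must also be rebuilt with a matching extra element, as the diff does via ReplaceInstructionWithDifferentShape):

  absl::Status AppendLoopCarriedElement(HloInstruction* loop,
                                        const Shape& element_shape) {
    // 1) Widen the carried tuple shape by one element.
    HloComputation* body = loop->while_body();
    HloInstruction* body_param = body->parameter_instruction(0);
    std::vector<Shape> shapes(body_param->shape().tuple_shapes().begin(),
                              body_param->shape().tuple_shapes().end());
    shapes.push_back(element_shape);
    Shape new_tuple_shape = ShapeUtil::MakeTupleShape(shapes);

    // 2) Body and condition parameters must both carry the new shape.
    *body_param->mutable_shape() = new_tuple_shape;
    *loop->while_condition()->parameter_instruction(0)->mutable_shape() =
        new_tuple_shape;

    // 3) Extend the init tuple in the parent computation (zeros here).
    HloComputation* parent = loop->parent();
    HloInstruction* init = loop->mutable_operand(0);
    std::vector<HloInstruction*> operands(init->operands().begin(),
                                          init->operands().end());
    operands.push_back(parent->AddInstruction(HloInstruction::CreateBroadcast(
        element_shape,
        parent->AddInstruction(HloInstruction::CreateConstant(
            LiteralUtil::Zero(element_shape.element_type()))),
        {})));
    HloInstruction* new_init =
        parent->AddInstruction(HloInstruction::CreateTuple(operands));
    TF_RETURN_IF_ERROR(loop->ReplaceOperandWithDifferentShape(0, new_init));

    // 4) The while instruction's result shape matches the carried tuple.
    *loop->mutable_shape() = new_tuple_shape;
    return absl::OkStatus();
  }

This is also why GetAgActivationCacheIndex returns the size of the loop's input tuple: while the tuple still has its original shape, that size is exactly the index at which the cache element will be appended; after the rewrite the tuple is one longer, hence the GetAgActivationCacheIndex(loop) - 1 when the visitor reads the cached activations back out of the rewritten loop.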