diff --git a/xla/service/algebraic_simplifier.cc b/xla/service/algebraic_simplifier.cc
index dd13ffebd9e852..641fedf0c72405 100644
--- a/xla/service/algebraic_simplifier.cc
+++ b/xla/service/algebraic_simplifier.cc
@@ -8188,6 +8188,33 @@ absl::Status AlgebraicSimplifierVisitor::HandleSelect(HloInstruction* select) {
           select->mutable_operand(0)->shape(), HloOpcode::kNot,
           select->mutable_operand(0)));
     }
+    // select(compare(a, b, GT/GE), a, b) => or(a, b)
+    // select(compare(a, b, LT/LE), a, b) => and(a, b)
+    // select(compare(a, b, EQ), a, b) => b
+    // select(compare(a, b, NE), a, b) => a
+    HloInstruction *compare, *lhs, *rhs;
+    if (Match(select, m::Select(m::Op(&compare), m::Op(&lhs), m::Op(&rhs))) &&
+        Match(compare, m::Compare(m::Op().Is(lhs), m::Op().Is(rhs)))) {
+      auto cmp_dir = compare->comparison_direction();
+      if (cmp_dir == ComparisonDirection::kGt ||
+          cmp_dir == ComparisonDirection::kGe) {
+        return ReplaceWithNewInstruction(
+            select, HloInstruction::CreateBinary(select->shape(),
+                                                 HloOpcode::kOr, lhs, rhs));
+      }
+      if (cmp_dir == ComparisonDirection::kLt ||
+          cmp_dir == ComparisonDirection::kLe) {
+        return ReplaceWithNewInstruction(
+            select, HloInstruction::CreateBinary(select->shape(),
+                                                 HloOpcode::kAnd, lhs, rhs));
+      }
+      if (cmp_dir == ComparisonDirection::kEq) {
+        return ReplaceInstruction(select, rhs);
+      }
+      if (cmp_dir == ComparisonDirection::kNe) {
+        return ReplaceInstruction(select, lhs);
+      }
+    }
   }
 
   // select(pred, xs, dynamic_update_slice(xs, x, i))
diff --git a/xla/service/algebraic_simplifier_test.cc b/xla/service/algebraic_simplifier_test.cc
index 00970a51546b1a..921098aa7565e8 100644
--- a/xla/service/algebraic_simplifier_test.cc
+++ b/xla/service/algebraic_simplifier_test.cc
@@ -736,6 +736,95 @@ TEST_F(AlgebraicSimplifierTest, SelectPredPred2) {
               GmockMatch(m::Not(m::Parameter(0))));
 }
 
+// select(compare(a, b, GT/GE), a, b) => or(a, b), a,b ∈ PRED
+TEST_F(AlgebraicSimplifierTest, SelectGtCompare) {
+  for (const auto cmp_dir : {"GT", "GE"}) {
+    const auto kModuleStr = absl::StrFormat(R"(
+      HloModule m
+      test {
+        p0 = pred[8]{0} parameter(0)
+        p1 = pred[8]{0} parameter(1)
+        compare = pred[8]{0} compare(p0, p1), direction=%s
+        ROOT select = pred[8]{0} select(compare, p0, p1)
+      }
+    )",
+                                            cmp_dir);
+    TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+    ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value());
+    EXPECT_THAT(m->entry_computation()->root_instruction(),
+                GmockMatch(m::Or(m::Parameter(0), m::Parameter(1))));
+  }
+}
+
+// select(compare(a, b, LT/LE), a, b) => and(a, b), a,b ∈ PRED
+TEST_F(AlgebraicSimplifierTest, SelectLtCompare) {
+  for (const auto cmp_dir : {"LT", "LE"}) {
+    const auto kModuleStr = absl::StrFormat(R"(
+      HloModule m
+      test {
+        p0 = pred[8]{0} parameter(0)
+        p1 = pred[8]{0} parameter(1)
+        compare = pred[8]{0} compare(p0, p1), direction=%s
+        ROOT select = pred[8]{0} select(compare, p0, p1)
+      }
+    )",
+                                            cmp_dir);
+    TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+    ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value());
+    EXPECT_THAT(m->entry_computation()->root_instruction(),
+                GmockMatch(m::And(m::Parameter(0), m::Parameter(1))));
+  }
+}
+
+// select(compare(a, b, EQ), a, b) => b, a,b ∈ PRED
+TEST_F(AlgebraicSimplifierTest, SelectEqCompare) {
+  const char* kModuleStr = R"(
+    HloModule m
+    test {
+      p0 = pred[8]{0} parameter(0)
+      p1 = pred[8]{0} parameter(1)
+      compare = pred[8]{0} compare(p0, p1), direction=EQ
+      ROOT select = pred[8]{0} select(compare, p0, p1)
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+  ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value());
+  EXPECT_THAT(m->entry_computation()->root_instruction(),
+              GmockMatch(m::Parameter(1)));
+}
+
+// select(compare(a, b, NE), a, b) => a, a,b ∈ PRED
+TEST_F(AlgebraicSimplifierTest, SelectNeCompare) {
+  const char* kModuleStr = R"(
+    HloModule m
+    test {
+      p0 = pred[8]{0} parameter(0)
+      p1 = pred[8]{0} parameter(1)
+      compare = pred[8]{0} compare(p0, p1), direction=NE
+      ROOT select = pred[8]{0} select(compare, p0, p1)
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+  ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value());
+  EXPECT_THAT(m->entry_computation()->root_instruction(),
+              GmockMatch(m::Parameter(0)));
+}
+
+// select(compare(a, b, NE), b, a) ≠> a - wrong operand order
+TEST_F(AlgebraicSimplifierTest, SelectNeCompare_NegativeTestCase) {
+  const char* kModuleStr = R"(
+    HloModule m
+    test {
+      p0 = pred[8]{0} parameter(0)
+      p1 = pred[8]{0} parameter(1)
+      compare = pred[8]{0} compare(p0, p1), direction=NE
+      ROOT select = pred[8]{0} select(compare, p1, p0)
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+  ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).value());
+}
+
 // Test that select(pred, xs, dynamic_update_slice(xs, x, i)) is simplified
 // to dynamic_update_slice(xs, select(pred, dynamic_slice(xs, i), x), i)
 TEST_F(AlgebraicSimplifierTest, SelectDUSWithShapedPred) {
diff --git a/xla/service/cpu/benchmarks/select_and_scatter_benchmark_test.cc b/xla/service/cpu/benchmarks/select_and_scatter_benchmark_test.cc
index bbc32250444b0f..b03df23e41e764 100644
--- a/xla/service/cpu/benchmarks/select_and_scatter_benchmark_test.cc
+++ b/xla/service/cpu/benchmarks/select_and_scatter_benchmark_test.cc
@@ -76,7 +76,6 @@ BENCHMARK(BM_SelectAndScatterF32)
     ->Arg(64)
     ->Arg(128)
     ->Arg(256)
-    ->Arg(512)
-    ->Arg(1024);
+    ->Arg(512);
 
 }  // namespace xla::cpu
diff --git a/xla/service/cpu/runtime/kernel_thunk.cc b/xla/service/cpu/runtime/kernel_thunk.cc
index a8d793d1076071..21c57fef35f940 100644
--- a/xla/service/cpu/runtime/kernel_thunk.cc
+++ b/xla/service/cpu/runtime/kernel_thunk.cc
@@ -87,40 +87,46 @@ tsl::AsyncValueRef<Thunk::ExecuteEvent> KernelThunk::Execute(
       kernel_name_, arguments_buffers_.size(), results_buffers_.size(),
       thread_dim_.ToString());
 
-  absl::InlinedVector<SE_HOST_KernelArg, 8> kernel_args;
-  kernel_args.reserve(arguments_buffers_.size() + results_buffers_.size());
+  int64_t num_args = arguments_buffers_.size() + results_buffers_.size();
+  absl::InlinedVector<SE_HOST_KernelArg, 8> kernel_args(num_args);
+
+  // We initialize the `kernel_args` array via a raw pointer to its first
+  // element, because per-element accessor calls add measurable overhead, and
+  // this code is on the critical path.
+  SE_HOST_KernelArg* kernel_args_ptr = kernel_args.data();
+  int64_t kernel_arg_idx = 0;
 
   int64_t arg_num = 0;
   for (BufferAllocation::Slice& buffer : arguments_buffers_) {
     TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase arg_data,
                         params.buffer_allocations->GetDeviceAddress(buffer));
-    kernel_args.push_back(
-        SE_HOST_KernelArg{arg_data.opaque(), arg_data.size()});
     VLOG(3) << absl::StreamFormat("  arg #%d: %s (%p)", arg_num++,
-                                  buffer.ToString(), kernel_args.back().data);
+                                  buffer.ToString(), arg_data.opaque());
+    kernel_args_ptr[kernel_arg_idx++] =
+        SE_HOST_KernelArg{arg_data.opaque(), arg_data.size()};
   }
 
   int64_t res_num = 0;
   for (BufferAllocation::Slice& buffer : results_buffers_) {
     TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase result_data,
                         params.buffer_allocations->GetDeviceAddress(buffer));
-    kernel_args.push_back(
-        SE_HOST_KernelArg{result_data.opaque(), result_data.size()});
     VLOG(3) << absl::StreamFormat("  res #%d: %s (%p)", res_num++,
-                                  buffer.ToString(), kernel_args.back().data);
+                                  buffer.ToString(), result_data.opaque());
+    kernel_args_ptr[kernel_arg_idx++] =
+        SE_HOST_KernelArg{result_data.opaque(), result_data.size()};
   }
 
   // Check that all buffers are aligned to the minimum alignment. We codegen
   // with the assumption that all buffers are aligned, and if they are not, we
   // will crash with a segmentation fault, or worse, produce incorrect results.
   if (min_alignment_.has_value()) {
-    for (int64_t i = 0; i < kernel_args.size(); ++i) {
-      auto ptr = reinterpret_cast<uintptr_t>(kernel_args[i].data);
+    for (int64_t i = 0; i < num_args; ++i) {
+      auto ptr = reinterpret_cast<uintptr_t>(kernel_args_ptr[i].data);
       if (ABSL_PREDICT_FALSE((ptr & (*min_alignment_ - 1)) != 0)) {
         return Internal(
             "Host kernel %s buffer argument #%d (%p) is not aligned to a "
             "required minimum alignment of %d bytes",
-            info().op_name, i, kernel_args[i].data, *min_alignment_);
+            info().op_name, i, kernel_args_ptr[i].data, *min_alignment_);
       }
     }
   }
@@ -136,7 +142,7 @@ tsl::AsyncValueRef<Thunk::ExecuteEvent> KernelThunk::Execute(
                         params.host_kernels->Find(kernel_name_));
 
     absl::MutexLock lock(&mutex_);
-    kernel_.emplace(kernel_args.size(), kernel_fn, nullptr);
+    kernel_.emplace(num_args, kernel_fn, nullptr);
     kernel_ptr_.store(kernel = &kernel_.value());
   }
 
diff --git a/xla/service/gpu/gpu_windowed_einsum_handler.cc b/xla/service/gpu/gpu_windowed_einsum_handler.cc
index ce1dfa4f1c7863..8f5e26124f24a4 100644
--- a/xla/service/gpu/gpu_windowed_einsum_handler.cc
+++ b/xla/service/gpu/gpu_windowed_einsum_handler.cc
@@ -378,8 +378,16 @@ absl::StatusOr<bool> HandleAgWindowedEinsumLoop(HloComputation* comp,
   return changed;
 }
 
+static int64_t GetAgActivationCacheIndex(const HloInstruction* while_loop) {
+  const HloInstruction* loop_tuple = while_loop->operand(0);
+  const Shape& tuple_shape = loop_tuple->shape();
+  CHECK(tuple_shape.IsTuple());
+  return tuple_shape.tuple_shapes_size();
+}
+
 absl::Status ProcessWindowedEinsumLoopForActivationCaching(
-    GpuWindowedEinsumHandler::WindowedEinsumAgLoops& ag_loop) {
+    GpuWindowedEinsumHandler::WindowedEinsumAgLoops& ag_loop,
+    HloInstruction* ag_with_shared_operand) {
   HloInstruction* loop = ag_loop.loop;
   // Transform the while body to cache the allgathered result in the
   // output buffer to be consumed by the dot
@@ -392,15 +400,61 @@ absl::Status ProcessWindowedEinsumLoopForActivationCaching(
   }
   // Get the output operand of the full buffer.
   HloInstruction* root = while_body->root_instruction();
+  // Change the loop body to include the new input and output element.
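+  // The cached activation buffer is appended as a new trailing tuple
+  // element, so GetAgActivationCacheIndex(loop) — the size of the original
+  // loop tuple — is exactly the index at which it will live.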
+  HloInstruction* input_tuple = while_body->parameter_instruction(0);
+  const Shape& input_shape = input_tuple->shape();
   // The full buffer that we will use to cache the accumulated activation
-  // is the 4th operand in the output tuple.
-  int64_t full_cache_buffer_index = 3;
+  // is the last operand in the output tuple.
+  int64_t full_cache_buffer_index = GetAgActivationCacheIndex(loop);
+  std::vector<Shape> new_input_shapes(input_shape.tuple_shapes().begin(),
+                                      input_shape.tuple_shapes().end());
+  new_input_shapes.push_back(ag_with_shared_operand->shape());
+  // Update the body input shape.
+  Shape new_input_shape = ShapeUtil::MakeTupleShape(new_input_shapes);
+  *input_tuple->mutable_shape() = new_input_shape;
   HloInstruction* full_buffer_output_gte =
-      root->mutable_operand(full_cache_buffer_index);
-  HloInstruction* new_full_buffer_output;
+      while_body->AddInstruction(HloInstruction::CreateGetTupleElement(
+          ag_with_shared_operand->shape(), input_tuple,
+          full_cache_buffer_index));
+
+  // Update the condition input shape.
+  HloComputation* cond_comp = loop->while_condition();
+  HloInstruction* cond_input_tuple = cond_comp->parameter_instruction(0);
+  *cond_input_tuple->mutable_shape() = new_input_shape;
+
+  // Update the input to the while instruction in the parent computation.
+  HloInstruction* original_while_input = loop->mutable_operand(0);
+  HloComputation* parent_comp = loop->parent();
+  std::vector<HloInstruction*> new_operands(
+      original_while_input->operands().begin(),
+      original_while_input->operands().end());
+  new_operands.push_back(
+      parent_comp->AddInstruction(HloInstruction::CreateBroadcast(
+          ag_with_shared_operand->shape(),
+          parent_comp->AddInstruction(HloInstruction::CreateConstant(
+              LiteralUtil::Zero(new_input_shapes[0].element_type()))),
+          {})));
+  HloInstruction* new_while_input =
+      parent_comp->AddInstruction(HloInstruction::CreateTuple(new_operands));
+  TF_RETURN_IF_ERROR(
+      loop->ReplaceOperandWithDifferentShape(0, new_while_input));
+  TF_RETURN_IF_ERROR(parent_comp->ReplaceInstructionWithDifferentShape(
+      original_while_input, new_while_input));
+  *loop->mutable_shape() = new_input_shape;
+
+  HloInstruction* new_full_buffer_output = nullptr;
   // Find the DUS in the loop body and re-use the slice indices
   // This should just be a constant(0)
   HloInstruction* dus_boundary_constant;
+  // The slice we need this time is the output of the first
+  // collective-permute.
+  HloInstruction* first_cp_output;
+  for (HloInstruction* gte_user : input_gte->users()) {
+    if (gte_user->opcode() == HloOpcode::kCollectivePermute) {
+      first_cp_output = gte_user;
+      break;
+    }
+  }
   for (HloInstruction* inst : while_body->MakeInstructionPostOrder()) {
     HloInstruction* slice_indices;
     // If we have a DUS(PARAM,DS) pattern, we need to update the output
@@ -434,24 +488,68 @@ absl::Status ProcessWindowedEinsumLoopForActivationCaching(
               dus_boundary_constant->shape(), slice_indices));
       VLOG(5) << "Created slice op for second slice: "
              << slice_indices->ToString();
-      // The slice we need this time is the output of the first
-      // collective-permute
-      HloInstruction* cp_output;
-      for (HloInstruction* gte_user : input_gte->users()) {
-        if (gte_user->opcode() == HloOpcode::kCollectivePermute) {
-          cp_output = gte_user;
-          break;
-        }
-      }
       new_full_buffer_output =
           while_body->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
               full_buffer_output_gte->shape(), full_buffer_output_gte,
-              cp_output,
+              first_cp_output,
              {dus_boundary_constant, slice_indices, dus_boundary_constant}));
     }
+
+    // If we have a Dot(DS(parameter_index1)), then operands are sharded along
+    // the contracting dim. Slice indices will be the contracting dim's slices.
+    HloInstruction* slice_index;
+    HloInstruction* ds_index_constant;
+    HloInstruction* remainder;
+    HloInstruction* ds_param;
+    // There will be 2 dynamic-slices for unrolled loops; match each one to
+    // get the slice index that will be used to write the corresponding
+    // received shard into the cached activation buffer. For unrolled loops,
+    // we write to the final buffer twice per iteration, so we must match
+    // the correct slice index for each DS.
+    if (Match(inst, m::Dot(m::Op(), m::DynamicSlice(&ds_param))) &&
+        Match(ds_param->operand(0), m::GetTupleElement(m::Parameter(), 1))) {
+      for (int64_t ds_op_i = 1; ds_op_i < ds_param->operands().size();
+           ds_op_i++) {
+        if (!Match(
+                ds_param->mutable_operand(ds_op_i),
+                m::Reshape(&slice_index, m::DynamicSlice(m::Constant(),
+                                                         m::Op(&remainder)))) &&
+            !Match(ds_param->mutable_operand(ds_op_i),
+                   m::Constant(&ds_index_constant))) {
+          return absl::OkStatus();
+        }
+      }
+      // The first DS has its slice index derived from the loop iterator:
+      // Remainder(add(gte, partition_id)).
+      if (Match(remainder,
+                m::Remainder(m::Add(m::GetTupleElement(), m::Op()), m::Op()))) {
+        full_buffer_output_gte =
+            while_body->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
+                full_buffer_output_gte->shape(), full_buffer_output_gte,
+                input_gte,
+                {ds_index_constant, ds_index_constant, slice_index}));
+      }
+      // The second DS has its slice index derived from the loop iterator+1,
+      // hence Remainder(add(add(gte, 1), partition_id)).
+      if (Match(remainder,
+                m::Remainder(
+                    m::Add(m::Add(m::GetTupleElement(), m::Op()), m::Op()),
+                    m::Op()))) {
+        new_full_buffer_output =
+            while_body->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
+                full_buffer_output_gte->shape(), full_buffer_output_gte,
+                first_cp_output,
+                {ds_index_constant, ds_index_constant, slice_index}));
+      }
+    }
   }
-  TF_RETURN_IF_ERROR(root->ReplaceOperandWith(full_cache_buffer_index,
-                                              new_full_buffer_output));
+  std::vector<HloInstruction*> original_operands(root->operands().begin(),
+                                                 root->operands().end());
+  original_operands.push_back(new_full_buffer_output);
+  HloInstruction* new_output_tuple = while_body->AddInstruction(
+      HloInstruction::CreateTuple(original_operands));
+  TF_RETURN_IF_ERROR(
+      while_body->ReplaceInstructionWithDifferentShape(root, new_output_tuple));
   return absl::OkStatus();
 }
 
@@ -620,17 +718,20 @@ class WindowedEinsumVisitor : public DfsHloRewriteVisitor {
           VLOG(5) << "Found all-gather that shares the same operand with a "
                      "windowed einsum loop : "
                   << loop->ToString();
+
+          if (!ag_loop.consumed) {
+            TF_RETURN_IF_ERROR(ProcessWindowedEinsumLoopForActivationCaching(
+                ag_loop, ag_with_shared_operand));
+            ag_loop.consumed = true;
+          }
           int64_t cache_output_index = dot->operand_index(ag_with_shared_operand);
-          HloInstruction* new_gte = comp->AddInstruction(
-              HloInstruction::CreateGetTupleElement(loop, 3));
+          HloComputation* comp = dot->parent();
+          HloInstruction* new_gte =
+              comp->AddInstruction(HloInstruction::CreateGetTupleElement(
+                  loop, GetAgActivationCacheIndex(loop) - 1));
           TF_RETURN_IF_ERROR(
               dot->ReplaceOperandWith(cache_output_index, new_gte));
           TF_RETURN_IF_ERROR(comp->RemoveInstruction(ag_with_shared_operand));
-          if (!ag_loop.consumed) {
-            TF_RETURN_IF_ERROR(
-                ProcessWindowedEinsumLoopForActivationCaching(ag_loop));
-            ag_loop.consumed = true;
-          }
         }
       }
 
   // Rewrites an all-to-all+gemm into multiple independent partial a2a+gemms
diff --git a/xla/service/gpu/gpu_windowed_einsum_handler_test.cc b/xla/service/gpu/gpu_windowed_einsum_handler_test.cc
index 23257e1c71a34b..6f23319980e90c 100644
--- a/xla/service/gpu/gpu_windowed_einsum_handler_test.cc
+++ b/xla/service/gpu/gpu_windowed_einsum_handler_test.cc
@@ -269,23 +269,22 @@ ENTRY main.12_spmd {
       FindInstructionByName(module->entry_computation(), "dot.7");
   // dot.7 should now consume output of the windowed einsum while loop.
   EXPECT_EQ(inst->operand(0)->opcode(), HloOpcode::kGetTupleElement);
-  EXPECT_EQ(inst->operand(0)->tuple_index(), 3);
+  EXPECT_EQ(inst->operand(0)->tuple_index(), 5);
   EXPECT_EQ(inst->operand(0)->operand(0), ag_loop);
   // while loop's root should now have a chain of DUS.
   HloInstruction* ag_while_root = ag_loop->while_body()->root_instruction();
   EXPECT_THAT(ag_while_root,
               GmockMatch(m::Tuple(
-                  m::Op(), m::Op(), m::Op(),
+                  m::Op(), m::Op(), m::Op(), m::Op(), m::Op(),
                   m::DynamicUpdateSlice(
                       m::DynamicUpdateSlice(
                           m::GetTupleElement(m::Parameter())
                               .WithPredicate([](const HloInstruction* instr) {
-                                return instr->tuple_index() == 3;
+                                return instr->tuple_index() == 5;
                               }),
                           m::Op(), m::Op(), m::Op(), m::Op()),
-                      m::Op(), m::Op(), m::Op(), m::Op()),
-                  m::Op())));
+                      m::Op(), m::Op(), m::Op(), m::Op()))));
 }
 
 TEST_F(GpuWindowedEinsumHanlderTest, A2aGemmHaveStreamIds) {
   constexpr absl::string_view kHloString = R"(
@@ -838,5 +837,82 @@ ENTRY main.9_spmd {
 )");
 }
 
+TEST_F(GpuWindowedEinsumHanlderTest,
+       AgLoopsMultipleConsumersAreChainedWithShardedContractingDim) {
+  constexpr absl::string_view kHloString = R"(
+HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0})->bf16[4096,6288]{1,0}}, num_partitions=8
+
+windowed_dot_general_body_ag {
+  param.195 = (bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0}, bf16[16,2048,6288]{2,1,0}, u32[]) parameter(0)
+  get-tuple-element.588 = bf16[16,2048,512]{2,1,0} get-tuple-element(param.195), index=0
+  collective-permute.194 = bf16[16,2048,512]{2,1,0} collective-permute(get-tuple-element.588), channel_id=446, source_target_pairs={{0,7},{1,0},{2,1},{3,2},{4,3},{5,4},{6,5},{7,6}}
+  collective-permute.195 = bf16[16,2048,512]{2,1,0} collective-permute(collective-permute.194), channel_id=447, source_target_pairs={{0,7},{1,0},{2,1},{3,2},{4,3},{5,4},{6,5},{7,6}}
+  get-tuple-element.589 = bf16[4096,6288]{1,0} get-tuple-element(param.195), index=1
+  get-tuple-element.590 = bf16[16,2048,6288]{2,1,0} get-tuple-element(param.195), index=2
+  constant.11432 = s32[8]{0} constant({0, 512, 1024, 1536, 2048, 2560, 3072, 3584})
+  get-tuple-element.592 = u32[] get-tuple-element(param.195), index=4
+  partition-id.194 = u32[] partition-id()
+  add.4309 = u32[] add(get-tuple-element.592, partition-id.194)
+  constant.11431 = u32[] constant(8)
+  remainder.194 = u32[] remainder(add.4309, constant.11431)
+  dynamic-slice.388 = s32[1]{0} dynamic-slice(constant.11432, remainder.194), dynamic_slice_sizes={1}
+  reshape.12959 = s32[] reshape(dynamic-slice.388)
+  constant.11433 = s32[] constant(0)
+  dynamic-slice.389 = bf16[512,6288]{1,0} dynamic-slice(get-tuple-element.589, reshape.12959, constant.11433), dynamic_slice_sizes={512,6288}
+  dot.244 = bf16[16,2048,6288]{2,1,0} dot(get-tuple-element.588, dynamic-slice.389), lhs_contracting_dims={2}, rhs_contracting_dims={0}
+  add.4310 = bf16[16,2048,6288]{2,1,0} add(get-tuple-element.590, dot.244)
+  constant.11434 = u32[] constant(1)
+  add.4312 = u32[] add(get-tuple-element.592, constant.11434)
+  add.4313 = u32[] add(add.4312, partition-id.194)
+  remainder.195 = u32[] remainder(add.4313, constant.11431)
+  dynamic-slice.390 = s32[1]{0} dynamic-slice(constant.11432, remainder.195), dynamic_slice_sizes={1}
+  reshape.12960 = s32[] reshape(dynamic-slice.390)
+  dynamic-slice.391 = bf16[512,6288]{1,0} dynamic-slice(get-tuple-element.589, reshape.12960, constant.11433), dynamic_slice_sizes={512,6288}
+  dot.245 = bf16[16,2048,6288]{2,1,0} dot(collective-permute.194, dynamic-slice.391), lhs_contracting_dims={2}, rhs_contracting_dims={0}
+  add.4314 = bf16[16,2048,6288]{2,1,0} add(add.4310, dot.245)
+  get-tuple-element.591 = bf16[16,2048,6288]{2,1,0} get-tuple-element(param.195), index=3
+  add.4315 = u32[] add(add.4312, constant.11434)
+  ROOT tuple.98 = (bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0}, bf16[16,2048,6288]{2,1,0}, u32[]) tuple(collective-permute.195, get-tuple-element.589, add.4314, get-tuple-element.591, add.4315)
+} // windowed_dot_general_body_ag
+
+windowed_dot_general_cond_ag {
+  param = (bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0}, bf16[16,2048,6288]{2,1,0}, u32[]) parameter(0)
+  get-tuple-element = u32[] get-tuple-element(param), index=4
+  constant = u32[] constant(4)
+  ROOT compare = pred[] compare(get-tuple-element, constant), direction=LT
+}
+
+ENTRY main.12_spmd {
+  param.4 = bf16[16,2048,512]{2,1,0} parameter(0)
+  param.5 = bf16[4096,6288]{1,0} parameter(1)
+  constant.22 = bf16[] constant(0)
+  broadcast = bf16[16,2048,6288]{2,1,0} broadcast(constant.22), dimensions={}
+  constant.24 = u32[] constant(0)
+  tuple.2 = (bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0}, bf16[16,2048,6288]{2,1,0}, u32[]) tuple(param.4, param.5, broadcast, broadcast, constant.24)
+  while = (bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0}, bf16[16,2048,6288]{2,1,0}, u32[]) while(tuple.2), condition=windowed_dot_general_cond_ag, body=windowed_dot_general_body_ag
+  get-tuple-element.13 = bf16[16,2048,6288]{2,1,0} get-tuple-element(while), index=2
+  all-gather = bf16[16,2048,4096]{2,1,0} all-gather(param.4), channel_id=1, replica_groups={{0,1,2,3,4,5,6,7}}, dimensions={2}, use_global_device_ids=true
+  param.6 = bf16[16,2048,6288]{2,1,0} parameter(2)
+  ROOT dot.7 = bf16[4096,6288]{1,0} dot(all-gather, param.6), lhs_contracting_dims={0,1}, rhs_contracting_dims={0,1}
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(kHloString));
+
+  GpuWindowedEinsumHandler gpu_handler;
+  bool changed;
+  TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get()));
+  EXPECT_TRUE(changed);
+
+  HloInstruction* ag_loop =
+      FindInstructionByName(module->entry_computation(), "while");
+  HloInstruction* inst =
+      FindInstructionByName(module->entry_computation(), "dot.7");
+  // dot.7 should now consume output of the windowed einsum while loop.
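+  // The original loop tuple had five elements (indices 0-4); the cached
+  // all-gather result was appended behind them, hence tuple index 5.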
+  EXPECT_EQ(inst->operand(0)->opcode(), HloOpcode::kGetTupleElement);
+  EXPECT_EQ(inst->operand(0)->tuple_index(), 5);
+  EXPECT_EQ(inst->operand(0)->operand(0), ag_loop);
+}
+
 }  // namespace
 }  // namespace xla::gpu
diff --git a/xla/stream_executor/host/host_kernel.cc b/xla/stream_executor/host/host_kernel.cc
index 04586b5272432b..cad37e1bfa4fb0 100644
--- a/xla/stream_executor/host/host_kernel.cc
+++ b/xla/stream_executor/host/host_kernel.cc
@@ -67,8 +67,7 @@ class HostKernelExecuteState : public tsl::ReferenceCounted<HostKernelExecuteState> {
  public:
   HostKernelExecuteState(HostKernel::TaskRunner task_runner,
-                         HostKernel::KernelFunction* function,
-                         ThreadDim thread_dims,
+                         SE_HOST_Kernel* kernel, ThreadDim thread_dims,
                          absl::Span<const SE_HOST_KernelArg> args);
 
   // Notify of a completion of a host kernel task.
@@ -112,6 +111,7 @@ HostKernel::HostKernel(std::shared_ptr<tsl::thread::ThreadPool> thread_pool)
 HostKernel::HostKernel(unsigned arity, SE_HOST_Kernel* kernel,
                        std::shared_ptr<tsl::thread::ThreadPool> thread_pool)
     : function_(std::make_unique<KernelFunctionPtr>(kernel)),
+      kernel_(function_->kernel()),
       arity_(arity),
       thread_pool_(thread_pool) {}
 
@@ -130,8 +130,6 @@ absl::Status HostKernel::Launch(
     thread_dims.z,
   };
 
-  SE_HOST_Kernel* kernel = function_->kernel();
-
   for (uint64_t z = 0; z < thread_dims.z; ++z) {
     for (uint64_t y = 0; y < thread_dims.y; ++y) {
       for (uint64_t x = 0; x < thread_dims.x; ++x) {
@@ -140,7 +138,7 @@ absl::Status HostKernel::Launch(
         SE_HOST_KernelCallFrame call_frame = {
             &kernel_thread_dims, &kernel_thread, args.size(), args.data()};
 
-        SE_HOST_KernelError* error = (*kernel)(&call_frame);
+        SE_HOST_KernelError* error = (*kernel_)(&call_frame);
 
         if (ABSL_PREDICT_FALSE(error != nullptr)) {
           return absl::InternalError("Failed to call host kernel");
@@ -174,8 +172,8 @@ tsl::AsyncValueRef<HostKernel::LaunchEvent> HostKernel::Launch(
   }
 
   // Allocate a control structure that will orchestrate kernel execution.
-  auto state = tsl::MakeRef<HostKernelExecuteState>(
-      std::move(task_runner), function_.get(), thread_dims, args);
+  auto state = tsl::MakeRef<HostKernelExecuteState>(std::move(task_runner),
+                                                    kernel_, thread_dims, args);
 
   state->CallAsync(/*start_index=*/0, /*end_index=*/num_tasks);
 
@@ -183,11 +181,11 @@ tsl::AsyncValueRef<HostKernel::LaunchEvent> HostKernel::Launch(
 }
 
 HostKernelExecuteState::HostKernelExecuteState(
-    HostKernel::TaskRunner task_runner, HostKernel::KernelFunction* function,
+    HostKernel::TaskRunner task_runner, SE_HOST_Kernel* kernel,
     ThreadDim thread_dims, absl::Span<const SE_HOST_KernelArg> args)
     : task_runner_(std::move(task_runner)),
       num_tasks_(thread_dims.x * thread_dims.y * thread_dims.z),
-      kernel_(function->kernel()),
+      kernel_(kernel),
       thread_dims_({thread_dims.x, thread_dims.y, thread_dims.z}),
       args_(args.begin(), args.end()),
       abort_(false),
diff --git a/xla/stream_executor/host/host_kernel.h b/xla/stream_executor/host/host_kernel.h
index 9d278b2b79c357..003b093795e998 100644
--- a/xla/stream_executor/host/host_kernel.h
+++ b/xla/stream_executor/host/host_kernel.h
@@ -117,6 +117,7 @@ class HostKernel : public Kernel {
 
  private:
   std::unique_ptr<KernelFunction> function_;
+  SE_HOST_Kernel* kernel_;
   unsigned arity_;
 
   std::shared_ptr<tsl::thread::ThreadPool> thread_pool_;
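For intuition on the HandleSelect rewrites in the first hunk: with PRED modeled as a boolean ordered false < true, select(compare(a, b, dir), a, b) collapses to a single logical op or to one of the operands. Below is a minimal standalone sketch that exhaustively checks all six directions; modeling PRED as C++ bool is an assumption of this sketch, not something the patch itself states.

#include <cassert>

// Truth-table check of the select(compare(a, b, dir), a, b) rewrites,
// with PRED modeled as bool and ordered as false < true.
int main() {
  for (bool a : {false, true}) {
    for (bool b : {false, true}) {
      assert((a > b ? a : b) == (a || b));   // GT => or(a, b)
      assert((a >= b ? a : b) == (a || b));  // GE => or(a, b)
      assert((a < b ? a : b) == (a && b));   // LT => and(a, b)
      assert((a <= b ? a : b) == (a && b));  // LE => and(a, b)
      assert((a == b ? a : b) == b);         // EQ => b
      assert((a != b ? a : b) == a);         // NE => a
    }
  }
}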