Skip to content

Commit

Permalink
[xla:gpu] Do not deduplicate kernel arguments when compiling Pallas G…
Browse files Browse the repository at this point in the history
…PU kerneles

PiperOrigin-RevId: 619611227
  • Loading branch information
superbobry authored and tensorflower-gardener committed Mar 27, 2024
1 parent 9eee895 commit a4ca0ab
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 11 deletions.
6 changes: 4 additions & 2 deletions third_party/xla/xla/service/gpu/ir_emitter_unnested.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1653,7 +1653,8 @@ absl::Status IrEmitterUnnested::EmitTritonCustomCall(
TF_ASSIGN_OR_RETURN(
auto kernel_arguments,
KernelArguments::Create(ir_emitter_context_->buffer_assignment(), instr,
instr->operands()));
instr->operands(),
/*dedup=*/false));
auto launch_dimensions =
LaunchDimensions(se::BlockDim(call.grid_x, call.grid_y, call.grid_z),
se::ThreadDim(call.num_warps * 32));
Expand Down Expand Up @@ -1696,7 +1697,8 @@ absl::Status IrEmitterUnnested::EmitTritonCustomCall(
TF_ASSIGN_OR_RETURN(
auto kernel_arguments,
KernelArguments::Create(ir_emitter_context_->buffer_assignment(), instr,
instr->operands()));
instr->operands(),
/*dedup=*/false));

AddThunkToThunkSequence(std::make_unique<KernelThunk>(
instr, entry->kernel_name, kernel_arguments.args(),
Expand Down
10 changes: 5 additions & 5 deletions third_party/xla/xla/service/gpu/kernel_arguments.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,11 @@ absl::StatusOr<KernelArguments> KernelArguments::Create(
return absl::OkStatus();
}));

return KernelArguments{std::move(kernel_arguments)};
return KernelArguments{std::move(kernel_arguments), /*dedup=*/true};
}

std::vector<KernelArgument> KernelArguments::ProcessArguments(
std::vector<KernelArgument> kernel_arguments) {
std::vector<KernelArgument> kernel_arguments, bool dedup) {
absl::flat_hash_set<BufferAllocation::Slice> buffers_written;
for (const KernelArgument& kernel_argument : kernel_arguments) {
if (kernel_argument.written()) {
Expand All @@ -79,7 +79,7 @@ std::vector<KernelArgument> KernelArguments::ProcessArguments(
KernelArgument& kernel_argument = kernel_arguments[i];

auto& first_index = first_indices_for_slices[kernel_argument.slice_];
if (first_index) {
if (dedup && first_index) {
const KernelArgument& same = kernel_arguments[*first_index];
kernel_argument.first_with_same_slice_ = first_index;
kernel_argument.alignment_ = same.alignment_;
Expand Down Expand Up @@ -128,7 +128,7 @@ std::vector<KernelArgument> KernelArguments::ProcessArguments(
absl::StatusOr<KernelArguments> KernelArguments::Create(
const BufferAssignment& buffer_assignment,
const HloInstruction* non_fusion_hlo,
absl::Span<const HloInstruction* const> needed_operands) {
absl::Span<const HloInstruction* const> needed_operands, bool dedup) {
std::vector<KernelArgument> kernel_arguments;
for (const HloInstruction* operand : needed_operands) {
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice,
Expand All @@ -151,7 +151,7 @@ absl::StatusOr<KernelArguments> KernelArguments::Create(
return absl::OkStatus();
}));

return KernelArguments{std::move(kernel_arguments)};
return KernelArguments{std::move(kernel_arguments), dedup};
}

} // namespace gpu
Expand Down
9 changes: 5 additions & 4 deletions third_party/xla/xla/service/gpu/kernel_arguments.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,16 +70,17 @@ class KernelArguments {
static absl::StatusOr<KernelArguments> Create(
const BufferAssignment& buffer_assignment,
const HloInstruction* non_fusion_hlo,
absl::Span<const HloInstruction* const> needed_operands);
absl::Span<const HloInstruction* const> needed_operands,
bool dedup = true);

const std::vector<KernelArgument>& args() const { return args_; }

private:
explicit KernelArguments(std::vector<KernelArgument> args)
: args_(ProcessArguments(std::move(args))) {}
explicit KernelArguments(std::vector<KernelArgument> args, bool dedup = true)
: args_(ProcessArguments(std::move(args), dedup)) {}

static std::vector<KernelArgument> ProcessArguments(
std::vector<KernelArgument> kernel_arguments);
std::vector<KernelArgument> kernel_arguments, bool dedup);

std::vector<KernelArgument> args_;
};
Expand Down

0 comments on commit a4ca0ab

Please sign in to comment.