Skip to content

Commit

Permalink
[IFRT] Add donated_input_indices attribute to CallOp to distinguish b…
Browse files Browse the repository at this point in the history
…etween donation and aliasing.

PiperOrigin-RevId: 679751788
  • Loading branch information
ICGog authored and Google-ML-Automation committed Sep 27, 2024
1 parent ca8bb43 commit 81724d4
Show file tree
Hide file tree
Showing 9 changed files with 296 additions and 105 deletions.
37 changes: 26 additions & 11 deletions xla/python/ifrt/ir/ifrt_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -248,12 +248,24 @@ mlir::LogicalResult VerifyIoAlias(mlir::Operation* op, IoAlias io_alias,
return mlir::success();
}

mlir::LogicalResult VerifyIoAliases(mlir::Operation* op,
mlir::ArrayAttr io_aliases,
llvm::ArrayRef<IfrtArrayType> inputs,
llvm::ArrayRef<IfrtArrayType> outputs) {
llvm::SmallSet<int, 4> aliased_inputs;
mlir::LogicalResult VerifyIoAliasesAndDonations(
mlir::Operation* op, mlir::ArrayAttr io_aliases,
llvm::ArrayRef<int32_t> donated_input_indices,
llvm::ArrayRef<IfrtArrayType> inputs,
llvm::ArrayRef<IfrtArrayType> outputs) {
llvm::SmallSet<int, 4> aliased_or_donated_inputs;
llvm::SmallSet<int, 4> aliased_outputs;
for (const int32_t donated_input_index : donated_input_indices) {
if (donated_input_index < 0 || donated_input_index >= inputs.size()) {
return op->emitOpError()
<< "can't donate input #" << donated_input_index
<< " as only having " << inputs.size() << " inputs";
}
if (!aliased_or_donated_inputs.insert(donated_input_index).second) {
return op->emitOpError() << "can't donate input #" << donated_input_index
<< " more than once";
}
}
for (const auto& raw_io_alias :
io_aliases.getAsRange<mlir::DenseI32ArrayAttr>()) {
llvm::ArrayRef<int> io_alias_as_array = raw_io_alias.asArrayRef();
Expand All @@ -263,9 +275,9 @@ mlir::LogicalResult VerifyIoAliases(mlir::Operation* op,
inputs, outputs))) {
return mlir::failure();
}
if (!aliased_inputs.insert(aliased_input).second) {
return op->emitOpError()
<< "can't alias input #" << aliased_input << " more than once";
if (!aliased_or_donated_inputs.insert(aliased_input).second) {
return op->emitOpError() << "can't alias or donate input #"
<< aliased_input << " more than once";
}
if (!aliased_outputs.insert(aliased_output).second) {
return op->emitOpError()
Expand Down Expand Up @@ -618,8 +630,9 @@ mlir::LogicalResult CallOp::verify() {

if (mlir::failed(VerifyDevicePlacement(*this, getDevices(), input_arrays,
output_arrays)) ||
mlir::failed(VerifyIoAliases(*this, getIoAliases(), input_arrays,
output_arrays))) {
mlir::failed(VerifyIoAliasesAndDonations(*this, getIoAliases(),
getDonatedInputIndices(),
input_arrays, output_arrays))) {
return mlir::failure();
}
return mlir::success();
Expand Down Expand Up @@ -680,7 +693,9 @@ mlir::LogicalResult CallLoadedExecutableOp::verify() {
output_arrays.push_back(mlir::cast<IfrtArrayType>(output.getType()));
}

return VerifyIoAliases(*this, getIoAliases(), input_arrays, output_arrays);
return VerifyIoAliasesAndDonations(*this, getIoAliases(),
getDonatedInputIndices(), input_arrays,
output_arrays);
}

mlir::LogicalResult LoadedExecutableOp::verify() {
Expand Down
22 changes: 14 additions & 8 deletions xla/python/ifrt/ir/ifrt_ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -182,17 +182,20 @@ def Ifrt_CallOp : Ifrt_Op<"Call",
a subset of these devices.

`io_aliases` represents pairs of inputs and outputs, where the input buffer
may be donated and used as the output buffer. The aliased pair must have the
same Ifrt_ArrayType. It's up to IFRT implementations whether to respect this
hint or not.
may be aliased and used as the output buffer. The aliased pair must have the
same byte size. It's up to IFRT implementations whether to respect this
hint or not. Alternatively, if the index of an input is in
`donated_input_indices` then the input buffer might be donated to the
callee if an output with the same byte size is found.
}];

let arguments = (ins
Variadic<Ifrt_ArrayType>:$inputs,
Variadic<Ifrt_ControlType>:$control_inputs,
SymbolRefAttr:$callee,
Ifrt_DevicesAttr:$devices,
DefaultValuedAttr<IoAliasesAttr, "{}">:$io_aliases);
DefaultValuedAttr<IoAliasesAttr, "{}">:$io_aliases,
DefaultValuedAttr<DenseI32ArrayAttr, "{}">:$donated_input_indices);
let results = (outs
Variadic<Ifrt_ArrayType>:$outputs,
Ifrt_ControlType:$control_output);
Expand Down Expand Up @@ -220,16 +223,19 @@ def Ifrt_CallLoadedExecutableOp : Ifrt_Op<"CallLoadedExecutable",
be placed on a subset of these devices.

`io_aliases` represents pairs of inputs and outputs, where the input buffer
may be donated and used as the output buffer. The aliased pair must have the
same Ifrt_ArrayType. It's up to IFRT implementations whether to respect this
hint or not.
may be aliased and used as the output buffer. The aliased pair must have the
same byte size. It's up to IFRT implementations whether to respect this
hint or not. Alternatively, if the index of an input is in
`donated_input_indices` then the input buffer might be donated to the
callee if an output with the same byte size is found.
}];

let arguments = (ins
Variadic<Ifrt_ArrayType>:$inputs,
Variadic<Ifrt_ControlType>:$control_inputs,
SymbolRefAttr:$callee,
DefaultValuedAttr<IoAliasesAttr, "{}">:$io_aliases);
DefaultValuedAttr<IoAliasesAttr, "{}">:$io_aliases,
DefaultValuedAttr<DenseI32ArrayAttr, "{}">:$donated_input_indices);
let results = (outs
Variadic<Ifrt_ArrayType>:$outputs,
Ifrt_ControlType:$control_output);
Expand Down
28 changes: 18 additions & 10 deletions xla/python/ifrt/ir/tests/ifrt_populate_atom_program_metadata.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -153,16 +153,17 @@ module @call_twice_with_different_sharding {

!array = !ifrt.array<tensor<2x2xi32>,
#ifrt.sharding_param<2x1 to [0] on 2>, [0,1]>
// CHECK-LABEL: @populate_io_alias
module @populate_io_alias {
func.func @main(%arg0: !array) attributes {ifrt.function} {
// CHECK: ifrt.Call @[[CALLEE_0:.+]]::@main(%arg0)
%0, %ctrl_0 = ifrt.Call @callee::@main(%arg0) on devices [0,1]
{io_aliases=[array<i32: 0, 0>]} : (!array) -> !array
// CHECK-LABEL: @populate_io_alias_and_donation
module @populate_io_alias_and_donation {
func.func @main(%arg0: !array, %arg1: !array) attributes {ifrt.function} {
// CHECK: ifrt.Call @[[CALLEE_0:.+]]::@main(%arg0, %arg1)
%0, %ctrl_0 = ifrt.Call @callee::@main(%arg0, %arg1) on devices [0,1]
{io_aliases=[array<i32: 0, 0>], donated_input_indices=array<i32: 1>}
: (!array, !array) -> !array
// Verify that the module is cloned if io_aliases differ.
// CHECK: ifrt.Call @[[CALLEE_1:.+]]::@main(%arg0)
%1, %ctrl_1 = ifrt.Call @callee::@main(%arg0) on devices [0,1]
: (!array) -> !array
// CHECK: ifrt.Call @[[CALLEE_1:.+]]::@main(%arg0, %arg1)
%1, %ctrl_1 = ifrt.Call @callee::@main(%arg0, %arg1) on devices [0,1]
: (!array, !array) -> !array
return
}

Expand All @@ -188,8 +189,15 @@ module @populate_io_alias {
// CHECK-DAG: ifrt.devices = #ifrt<devices[0, 1]>
// CHECK-DAG: tf.aliasing_output = 0 : i32
// CHECK-SAME: }
// CHECK: %arg1: tensor<2x2xi32>
// CHECK-SAME: {
// CHECK-DAG: ifrt.sharding = #ifrt.sharding_param<2x1 to [0] on 2>
// CHECK-DAG: ifrt.devices = #ifrt<devices[0, 1]>
// CHECK-DAG: jax.buffer_donor = true
// CHECK-SAME: }
module @callee attributes {sym_visibility = "private"} {
func.func private @main(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> {
func.func private @main(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>)
-> tensor<2x2xi32> {
return %arg0: tensor<2x2xi32>
}
}
Expand Down
42 changes: 39 additions & 3 deletions xla/python/ifrt/ir/tests/ifrt_verify_donation.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ module @donate_to_reshard_duplicated_arg {
// -----

!array = !ifrt.array<tensor<2xi32>, #ifrt.sharding_param<2 to [0] on 2>, [0, 1]>
module @donate_to_two_calls_error {
module @alias_to_two_calls_error {
func.func @main(%arg0: !array {ifrt.donated}) -> (!array, !array)
attributes {ifrt.function} {
%0, %ctrl_0 = ifrt.Call @identity(%arg0) on devices [0,1]
Expand All @@ -59,13 +59,49 @@ module @donate_to_two_calls_error {

// -----

!array = !ifrt.array<tensor<2xi32>, #ifrt.sharding_param<2 to [0] on 2>, [0, 1]>
module @donate_to_two_calls_error {
func.func @main(%arg0: !array {ifrt.donated}) -> (!array, !array)
attributes {ifrt.function} {
%0, %ctrl_0 = ifrt.Call @identity(%arg0) on devices [0,1]
{donated_input_indices=array<i32: 0>} : (!array) -> !array
// expected-error @+1 {{'ifrt.Call' op input #0 of @identity was already donated}}
%1, %ctrl_1 = ifrt.Call @identity(%arg0) on devices [0,1]
{donated_input_indices=array<i32: 0>} : (!array) -> !array
return %0, %1 : !array, !array
}

func.func private @identity(%arg0: tensor<2xi32>) -> tensor<2xi32> {
return %arg0 : tensor<2xi32>
}
}

// -----

!array = !ifrt.array<tensor<2xi32>, #ifrt.sharding_param<2 to [0] on 2>, [0, 1]>
module @arg_donated_to_call_not_donated_to_program {
func.func @main(%arg0: !array) -> (!array)
attributes {ifrt.function} {
// expected-error @+1 {{'ifrt.Call' op input #0 has not been donated to the program.}}
%0, %ctrl_0 = ifrt.Call @identity(%arg0) on devices [0,1]
{donated_input_indices=array<i32: 0>} : (!array) -> !array
return %0 : !array
}

func.func private @identity(%arg0: tensor<2xi32>) -> tensor<2xi32> {
return %arg0 : tensor<2xi32>
}
}

// -----

!array0 = !ifrt.array<tensor<2xi32>,
#ifrt.sharding_param<2 to [0] on 2>, [0, 1]>
!array1 = !ifrt.array<tensor<2xi32>,
#ifrt.sharding_param<2 to [0] on 2>, [2, 3]>
module @program_arg_not_donated_error {
func.func @main(%arg0: !array0) -> (!array1) attributes {ifrt.function} {
// expected-error @+1 {{'ifrt.Reshard' op input has not been donated to the program.}}
// expected-error @+1 {{'ifrt.Reshard' op input #0 has not been donated to the program.}}
%0, %ctrl_0 = ifrt.Reshard(%arg0) {donated=true} : (!array0) -> !array1
return %0 : !array1
}
Expand Down Expand Up @@ -167,7 +203,7 @@ module @donate_to_two_copy_arrays_error {
module @program_arg_not_donated_to_remap_error {
func.func @main(%arg0: !array {ifrt.donated}, %arg1: !array) -> (!array)
attributes {ifrt.function} {
// expected-error @+1 {{'ifrt.RemapArrays' op input has not been donated to the program.}}
// expected-error @+1 {{'ifrt.RemapArrays' op input #1 has not been donated to the program.}}
%0 = ifrt.RemapArrays(%arg0, %arg1)
mappings=[#ifrt.array_mapping<0, 0, [#ifrt.mapping<[0:1:1] to [0:1:1]>]>,
#ifrt.array_mapping<1, 0, [#ifrt.mapping<[0:1:1] to [1:2:1]>]>]
Expand Down
55 changes: 54 additions & 1 deletion xla/python/ifrt/ir/tests/verify_call.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ func.func @io_aliases_should_only_alias_input_once(
%arg0: !ifrt.array<tensor<2x2xi32>, #ifrt.sharding_param<1x1 to [0] on 2>,
[0,1]>)
attributes {ifrt.function} {
// expected-error@+1 {{'ifrt.Call' op can't alias input #0 more than once}}
// expected-error@+1 {{'ifrt.Call' op can't alias or donate input #0 more than once}}
%0, %1, %ctrl_0 = ifrt.Call @callee(%arg0) on devices [0,1]
{io_aliases=[array<i32: 0, 0>, array<i32: 0, 1>]}
: (!ifrt.array<tensor<2x2xi32>, #ifrt.sharding_param<1x1 to [0] on 2>,
Expand Down Expand Up @@ -429,4 +429,57 @@ func.func @call_local_view_should_have_valid_shape(

func.func @callee(%arg0: tensor<4x4xi32>) -> tensor<4x4xi32> {
return %arg0 : tensor<4x4xi32>
}

// -----

!array = !ifrt.array<tensor<2x2xi32>,
#ifrt.sharding_param<1x1 to [0] on 2>, [0,1]>
func.func @donate_an_arg_and_alias_another(%arg0: !array, %arg1: !array)
attributes {ifrt.function} {
%0, %ctrl_0 = ifrt.Call @callee(%arg0, %arg1) on devices [0,1]
{donated_input_indices=array<i32: 0>, io_aliases=[array<i32: 1, 0>]}
: (!array, !array) -> !array
return
}

func.func @callee(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>)
-> tensor<2x2xi32> {
return %arg0 : tensor<2x2xi32>
}

// -----

!array = !ifrt.array<tensor<2x2xi32>,
#ifrt.sharding_param<1x1 to [0] on 2>, [0,1]>
func.func @should_only_donate_once(%arg0: !array, %arg1: !array)
attributes {ifrt.function} {
// expected-error@+1 {{'ifrt.Call' op can't donate input #0 more than once}}
%0, %ctrl_0 = ifrt.Call @callee(%arg0, %arg1) on devices [0,1]
{donated_input_indices=array<i32: 0, 0>}
: (!array, !array) -> !array
return
}

func.func @callee(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>)
-> tensor<2x2xi32> {
return %arg0 : tensor<2x2xi32>
}

// -----

!array = !ifrt.array<tensor<2x2xi32>,
#ifrt.sharding_param<1x1 to [0] on 2>, [0,1]>
func.func @should_not_both_donate_and_alias_the_same_arg(
%arg0: !array, %arg1: !array) attributes {ifrt.function} {
// expected-error@+1 {{'ifrt.Call' op can't alias or donate input #0 more than once}}
%0, %ctrl_0 = ifrt.Call @callee(%arg0, %arg1) on devices [0,1]
{donated_input_indices=array<i32: 0>, io_aliases=[array<i32: 0, 0>]}
: (!array, !array) -> !array
return
}

func.func @callee(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>)
-> tensor<2x2xi32> {
return %arg0 : tensor<2x2xi32>
}
46 changes: 45 additions & 1 deletion xla/python/ifrt/ir/tests/verify_call_loaded_executable.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ func.func @io_aliases_should_only_alias_input_once(
%arg0: !ifrt.array<tensor<2x2xi32>, #ifrt.sharding_param<1x1 to [0] on 2>,
[0,1]>)
attributes {ifrt.function} {
// expected-error@+1 {{'ifrt.CallLoadedExecutable' op can't alias input #0 more than once}}
// expected-error@+1 {{'ifrt.CallLoadedExecutable' op can't alias or donate input #0 more than once}}
%0, %1, %ctrl_0 = ifrt.CallLoadedExecutable @callee(%arg0)
{io_aliases=[array<i32: 0, 0>, array<i32: 0, 1>]}
: (!ifrt.array<tensor<2x2xi32>, #ifrt.sharding_param<1x1 to [0] on 2>,
Expand Down Expand Up @@ -230,3 +230,47 @@ ifrt.LoadedExecutable @callee on devices [0,1]
[0,1]>)
-> !ifrt.array<tensor<2x2xi32>, #ifrt.sharding_param<2x1 to [0] on 2>,
[0,1]>


// -----

!array = !ifrt.array<tensor<2x2xi32>,
#ifrt.sharding_param<1x1 to [0] on 2>, [0,1]>
func.func @donate_one_arg_and_alias_another_arg(%arg0: !array, %arg1: !array)
attributes {ifrt.function} {
%0, %ctrl_0 = ifrt.CallLoadedExecutable @callee(%arg0, %arg1)
{donated_input_indices=array<i32: 0>, io_aliases=[array<i32: 1, 0>]}
: (!array, !array) -> !array
return
}

ifrt.LoadedExecutable @callee on devices [0,1] : (!array, !array) -> !array

// -----

!array = !ifrt.array<tensor<2x2xi32>,
#ifrt.sharding_param<1x1 to [0] on 2>, [0,1]>
func.func @should_only_donate_once(%arg0: !array, %arg1: !array)
attributes {ifrt.function} {
// expected-error@+1 {{'ifrt.CallLoadedExecutable' op can't donate input #0 more than once}}
%0, %ctrl_0 = ifrt.CallLoadedExecutable @callee(%arg0, %arg1)
{donated_input_indices=array<i32: 0, 0>} : (!array, !array) -> !array
return
}

ifrt.LoadedExecutable @callee on devices [0,1] : (!array, !array) -> !array

// -----

!array = !ifrt.array<tensor<2x2xi32>,
#ifrt.sharding_param<1x1 to [0] on 2>, [0,1]>
func.func @should_not_both_donate_and_alias_the_same_arg(
%arg0: !array, %arg1: !array) attributes {ifrt.function} {
// expected-error@+1 {{'ifrt.CallLoadedExecutable' op can't alias or donate input #0 more than once}}
%0, %ctrl_0 = ifrt.CallLoadedExecutable @callee(%arg0, %arg1)
{donated_input_indices=array<i32: 0>, io_aliases=[array<i32: 0, 0>]}
: (!array, !array) -> !array
return
}

ifrt.LoadedExecutable @callee on devices [0,1] : (!array, !array) -> !array
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,9 @@ mlir::LogicalResult PopulateMetadata(xla::ifrt::CallOp call_op,
callee_op.setArgAttr(io_alias_as_array[0], "tf.aliasing_output",
builder.getI32IntegerAttr(io_alias_as_array[1]));
}
for (const auto idx : call_op.getDonatedInputIndices()) {
callee_op.setArgAttr(idx, "jax.buffer_donor", builder.getBoolAttr(true));
}
return mlir::success();
}

Expand Down
Loading

0 comments on commit 81724d4

Please sign in to comment.