diff --git a/xla/python/ifrt/BUILD b/xla/python/ifrt/BUILD index 27118ddf854c31..78cb7d039bcd0f 100644 --- a/xla/python/ifrt/BUILD +++ b/xla/python/ifrt/BUILD @@ -87,6 +87,7 @@ cc_library( ":attribute_map", ":device_proto_cc", ":dtype_proto_cc", + ":execute_options_proto_cc", ":remap_plan_proto_cc", ":serdes", ":shape_proto_cc", @@ -109,6 +110,7 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/container:node_hash_set", "@com_google_absl//absl/functional:function_ref", @@ -165,6 +167,11 @@ xla_cc_test( ], ) +tf_proto_library( + name = "execute_options_proto", + srcs = ["execute_options.proto"], +) + xla_cc_test( name = "future_test", size = "small", diff --git a/xla/python/ifrt/executable.cc b/xla/python/ifrt/executable.cc index 77cabe7f6a9389..509c8bbd4d47c7 100644 --- a/xla/python/ifrt/executable.cc +++ b/xla/python/ifrt/executable.cc @@ -15,11 +15,40 @@ limitations under the License. 
#include "xla/python/ifrt/executable.h" +#include "absl/status/statusor.h" +#include "xla/python/ifrt/execute_options.pb.h" + namespace xla { namespace ifrt { char Executable::ID = 0; char LoadedExecutable::ID = 0; +absl::StatusOr ExecuteOptions::ToProto() const { + ExecuteOptionsProto proto; + + proto.set_arguments_are_tupled(arguments_are_tupled); + proto.set_untuple_result(untuple_result); + proto.set_launch_id(launch_id); + proto.mutable_non_donatable_input_indices()->Add( + non_donatable_input_indices.begin(), non_donatable_input_indices.end()); + + return proto; +} + +absl::StatusOr ExecuteOptions::FromProto( + const xla::ifrt::ExecuteOptionsProto& proto) { + ExecuteOptions options; + + options.arguments_are_tupled = proto.arguments_are_tupled(); + options.untuple_result = proto.untuple_result(); + options.launch_id = proto.launch_id(); + options.non_donatable_input_indices.insert( + proto.non_donatable_input_indices().begin(), + proto.non_donatable_input_indices().end()); + + return options; +} + } // namespace ifrt } // namespace xla diff --git a/xla/python/ifrt/executable.h b/xla/python/ifrt/executable.h index 6b642bd5d178d6..de7cb1928c947e 100644 --- a/xla/python/ifrt/executable.h +++ b/xla/python/ifrt/executable.h @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" @@ -33,6 +34,7 @@ limitations under the License. 
#include "xla/python/ifrt/attribute_map.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/execute_options.pb.h" #include "xla/python/ifrt/future.h" #include "xla/tsl/concurrency/ref_count.h" @@ -104,6 +106,38 @@ class Executable : public llvm::RTTIExtends { static char ID; // NOLINT }; +struct ExecuteOptions { + // If true, the client must pass a single IFRT array which contains all of the + // arguments as a single XLA tuple, otherwise each argument must be passed in + // its own IFRT array. May only be true if the executable was compiled with + // parameter_is_tupled_arguments==true. + bool arguments_are_tupled = false; + + // If true, the computation must return a tuple, which will be destructured + // into its elements. + bool untuple_result = false; + + // If non-zero, identifies this execution as part of a potentially + // multi-device launch. This can be used to detect scheduling errors, e.g. if + // multi-host programs are launched in different orders on different hosts, + // the launch IDs may be used by the runtime to detect the mismatch. + int32_t launch_id = 0; + + // A set of indices denoting the input arrays that should not be donated. An + // input array may be non-donatable, for example, if it is referenced more than + // once. Since such runtime information is not available at compile time, the + // compiler might mark the input as `may-alias`, which could lead IFRT to + // donate the input array when it should not. By defining this set of indices, + // a higher-level IFRT caller can instruct the IFRT client not to donate specific + // input arrays. + absl::flat_hash_set non_donatable_input_indices; + + absl::StatusOr ToProto() const; + + static absl::StatusOr FromProto( + const ExecuteOptionsProto& proto); +}; + // Wraps a computation that has been fully compiled and loaded for execution. 
class LoadedExecutable : public llvm::RTTIExtends { @@ -176,8 +210,7 @@ class LoadedExecutable // `LoadedExecutable` methods. - // Short-term alias. - using ExecuteOptions = ::xla::ExecuteOptions; + using ExecuteOptions = xla::ifrt::ExecuteOptions; // Result from an execution. struct ExecuteResult { diff --git a/xla/python/ifrt/execute_options.proto b/xla/python/ifrt/execute_options.proto new file mode 100644 index 00000000000000..68ad3974f03174 --- /dev/null +++ b/xla/python/ifrt/execute_options.proto @@ -0,0 +1,27 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +syntax = "proto3"; + +package xla.ifrt; + +message ExecuteOptionsProto { + bool arguments_are_tupled = 1; + bool untuple_result = 2; + int32 launch_id = 3; + repeated int32 non_donatable_input_indices = 7; + + reserved 4 to 6, 8; +} diff --git a/xla/python/ifrt_proxy/common/BUILD b/xla/python/ifrt_proxy/common/BUILD index aa971c34b7155c..9d6b3bbecb4d96 100644 --- a/xla/python/ifrt_proxy/common/BUILD +++ b/xla/python/ifrt_proxy/common/BUILD @@ -72,6 +72,7 @@ tf_proto_library( "//xla/pjrt:execute_options_proto", "//xla/python/ifrt:attribute_map_proto", "//xla/python/ifrt:dtype_proto", + "//xla/python/ifrt:execute_options_proto", "//xla/python/ifrt:remap_plan_proto", "//xla/python/ifrt:serdes_proto", "//xla/python/ifrt:shape_proto", diff --git a/xla/python/ifrt_proxy/common/ifrt_service.proto b/xla/python/ifrt_proxy/common/ifrt_service.proto index 4a342c253af9cc..3f17ee69abbd57 100644 --- a/xla/python/ifrt_proxy/common/ifrt_service.proto +++ b/xla/python/ifrt_proxy/common/ifrt_service.proto @@ -17,9 +17,9 @@ syntax = "proto3"; package xla.ifrt.proxy; import "google/protobuf/any.proto"; -import "xla/pjrt/execute_options.proto"; import "xla/python/ifrt/attribute_map.proto"; import "xla/python/ifrt/dtype.proto"; +import "xla/python/ifrt/execute_options.proto"; import "xla/python/ifrt/remap_plan.proto"; import "xla/python/ifrt/serdes.proto"; import "xla/python/ifrt/shape.proto"; @@ -428,7 +428,7 @@ message LoadedExecutableMetadataResponse { message LoadedExecutableExecuteRequest { fixed64 loaded_executable_handle = 1; repeated fixed64 args_handles = 2; - xla.ExecuteOptionsProto execute_options = 3; + xla.ifrt.ExecuteOptionsProto execute_options = 3; repeated int32 device_ids = 4; } message LoadedExecutableExecuteResponse { diff --git a/xla/python/pjrt_ifrt/pjrt_executable.cc b/xla/python/pjrt_ifrt/pjrt_executable.cc index 19c441cdea448d..243c8583b6d6f7 100644 --- 
a/xla/python/pjrt_ifrt/pjrt_executable.cc +++ b/xla/python/pjrt_ifrt/pjrt_executable.cc @@ -542,7 +542,12 @@ PjRtLoadedExecutable::Execute( const bool returned_future_supported = pjrt_loaded_executable_->IsReturnedFutureSupported(); - auto opts = options; + xla::ExecuteOptions opts; + opts.arguments_are_tupled = options.arguments_are_tupled; + opts.untuple_result = options.untuple_result; + opts.launch_id = options.launch_id; + opts.use_major_to_minor_data_layout_for_callbacks = true; + opts.non_donatable_input_indices = options.non_donatable_input_indices; if (!all_loaded_host_callbacks_->empty() && !returned_future_supported) { return Internal( @@ -565,9 +570,7 @@ PjRtLoadedExecutable::Execute( contexts.push_back(CreateHostCallbackStateAndAppendSendRecvCallbacks( host_send_recv_callback->host_callback(), /*host_memory_for_device_manager=*/nullptr, send_callbacks, - recv_callbacks, - /*use_major_to_minor_data_layout_for_callbacks=*/ - options.use_major_to_minor_data_layout_for_callbacks)); + recv_callbacks, opts.use_major_to_minor_data_layout_for_callbacks)); } } opts.send_callbacks = host_callback_states->send_callbacks; diff --git a/xla/python/py_executable.cc b/xla/python/py_executable.cc index b7395ad7793050..07655f212d7666 100644 --- a/xla/python/py_executable.cc +++ b/xla/python/py_executable.cc @@ -99,7 +99,6 @@ PyLoadedExecutable::PyLoadedExecutable( VLOG(1) << "Fingerprint for executable " << ifrt_loaded_executable_->name() << ": " << *fingerprint_; } - options_.use_major_to_minor_data_layout_for_callbacks = true; } PyLoadedExecutable::~PyLoadedExecutable() { @@ -203,7 +202,7 @@ void PopulateExecuteShardedResults( template > absl::StatusOr ExecuteShardedOnLocalDevicesInternal( - const ExecuteOptions& options, const nb_class_ptr& client, + const ifrt::ExecuteOptions& options, const nb_class_ptr& client, ifrt::LoadedExecutable* ifrt_loaded_executable, absl::Span args, std::optional>>& returned_futures, bool attach_status_to_results) { diff --git 
a/xla/python/py_executable.h b/xla/python/py_executable.h index ed34ce99ef1a89..e032ee7b4acdda 100644 --- a/xla/python/py_executable.h +++ b/xla/python/py_executable.h @@ -227,7 +227,7 @@ class PyLoadedExecutable { return exec->shared_ptr_pjrt_loaded_executable(); } - const ExecuteOptions& options() const { return options_; } + const ifrt::ExecuteOptions& options() const { return options_; } const std::optional& fingerprint() const { return fingerprint_; } // Keep `obj` alive as long as PyLoadedExecutable. @@ -246,7 +246,7 @@ class PyLoadedExecutable { std::optional fingerprint_; // The options to pass to `executable_.Execute`. - ExecuteOptions options_; + ifrt::ExecuteOptions options_; // Python objects to keep alive as requested by user. std::vector keepalives_;