diff --git a/xla/python/ifrt_proxy/client/BUILD b/xla/python/ifrt_proxy/client/BUILD index f187553813a5b..bfbf24d0a8d9d 100644 --- a/xla/python/ifrt_proxy/client/BUILD +++ b/xla/python/ifrt_proxy/client/BUILD @@ -568,6 +568,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/time", "@nanobind", "@tsl//tsl/platform:env", "@tsl//tsl/platform:statusor", diff --git a/xla/python/ifrt_proxy/client/py_module.cc b/xla/python/ifrt_proxy/client/py_module.cc index 222719b0b5fb6..c431f2ae9dead 100644 --- a/xla/python/ifrt_proxy/client/py_module.cc +++ b/xla/python/ifrt_proxy/client/py_module.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "xla/python/ifrt_proxy/client/py_module.h" +#include #include #include #include @@ -26,6 +27,7 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "absl/time/time.h" #include "nanobind/nanobind.h" #include "nanobind/stl/function.h" // IWYU pragma: keep #include "nanobind/stl/optional.h" // IWYU pragma: keep @@ -48,6 +50,7 @@ namespace { struct PyClientConnectionOptions { std::optional> on_disconnect; std::optional> on_connection_update; + std::optional connection_timeout_in_seconds; }; absl::StatusOr> GetClient( @@ -90,6 +93,11 @@ absl::StatusOr> GetClient( }; } + if (py_options.connection_timeout_in_seconds.has_value()) { + options.connection_timeout = + absl::Seconds(*py_options.connection_timeout_in_seconds); + } + { nb::gil_scoped_release gil_release; TF_ASSIGN_OR_RETURN(client, CreateClient(proxy_server_address, options)); @@ -110,6 +118,9 @@ void BuildIfrtProxySubmodule(nb::module_& m) { nb::arg().none()) .def_rw("on_connection_update", &PyClientConnectionOptions::on_connection_update, + nb::arg().none()) + .def_rw("connection_timeout_in_seconds", + &PyClientConnectionOptions::connection_timeout_in_seconds, nb::arg().none()); sub_module.def("get_client", xla::ValueOrThrowWrapper(GetClient), diff --git a/xla/python/ifrt_proxy/jax/ifrt_proxy_internal.py b/xla/python/ifrt_proxy/jax/ifrt_proxy_internal.py index 27b5ff1b06089..cceadeb6ef2eb 100644 --- a/xla/python/ifrt_proxy/jax/ifrt_proxy_internal.py +++ b/xla/python/ifrt_proxy/jax/ifrt_proxy_internal.py @@ -35,10 +35,13 @@ class ConnectionOptions: on_connection_update: Optional, a callback that will be called with status updates about initial connection establishment. The updates will be provided as human-readable strings, and an end-user may find them helpful. + connection_timeout_in_seconds: Optional, the timeout for establishing a + connection to the proxy server. """ on_disconnect: Optional[Callable[[str], None]] = None on_connection_update: Optional[Callable[[str], None]] = None + connection_timeout_in_seconds: Optional[int] = None _backend_created: bool = False @@ -52,6 +55,9 @@ def get_client(proxy_server_address: str) -> xla_client.Client: cpp_options = py_module.ClientConnectionOptions() cpp_options.on_disconnect = _connection_options.on_disconnect cpp_options.on_connection_update = _connection_options.on_connection_update + cpp_options.connection_timeout_in_seconds = ( + _connection_options.connection_timeout_in_seconds + ) client = py_module.get_client(proxy_server_address, cpp_options) if client is not None: _backend_created = True diff --git a/xla/python/xla_extension/ifrt_proxy.pyi b/xla/python/xla_extension/ifrt_proxy.pyi index f65685025e516..b3137a04501e8 100644 --- a/xla/python/xla_extension/ifrt_proxy.pyi +++ b/xla/python/xla_extension/ifrt_proxy.pyi @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================== +import datetime from typing import Any, Optional, Callable from xla.python import xla_extension @@ -24,6 +25,7 @@ Client = xla_extension.Client class ClientConnectionOptions: on_disconnect: Optional[Callable[[_Status], None]] = None on_connection_update: Optional[Callable[[str], None]] = None + connection_timeout_in_seconds: Optional[int] = None def get_client(