Skip to content

Commit

Permalink
Add function WrapClientAroundCApi().
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 694514486
  • Loading branch information
matthiaskramm authored and Google-ML-Automation committed Nov 8, 2024
1 parent 9dee234 commit db16547
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 0 deletions.
6 changes: 6 additions & 0 deletions xla/pjrt/pjrt_c_api_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2339,7 +2339,13 @@ absl::StatusOr<std::unique_ptr<PjRtClient>> GetCApiClient(
if (c_api == nullptr) {
return Internal("PJRT C API is nullptr for %s", device_type);
}
return WrapClientAroundCApi(c_api, create_options, kv_store);
}

absl::StatusOr<std::unique_ptr<PjRtClient>> WrapClientAroundCApi(
const PJRT_Api* c_api,
const absl::flat_hash_map<std::string, PjRtValueType>& create_options,
std::shared_ptr<KeyValueStoreInterface> kv_store) {
PJRT_Client_Create_Args init_args;
init_args.struct_size = PJRT_Client_Create_Args_STRUCT_SIZE;
init_args.extension_start = nullptr;
Expand Down
5 changes: 5 additions & 0 deletions xla/pjrt/pjrt_c_api_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -817,6 +817,11 @@ absl::StatusOr<std::unique_ptr<PjRtClient>> GetCApiClient(
const absl::flat_hash_map<std::string, PjRtValueType>& create_options = {},
std::shared_ptr<KeyValueStoreInterface> kv_store = nullptr);

absl::StatusOr<std::unique_ptr<PjRtClient>> WrapClientAroundCApi(
const PJRT_Api* c_api,
const absl::flat_hash_map<std::string, PjRtValueType>& create_options = {},
std::shared_ptr<KeyValueStoreInterface> kv_store = nullptr);

absl::StatusOr<std::unique_ptr<PjRtTopologyDescription>> GetCApiTopology(
const PJRT_Api* c_api, absl::string_view topology_name,
const absl::flat_hash_map<std::string, PjRtValueType>& create_options);
Expand Down
8 changes: 8 additions & 0 deletions xla/pjrt/pjrt_c_api_client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -199,5 +199,13 @@ TEST(PjRtClientTest, CompileUsesStableHloVersion) {
const_cast<PJRT_Api*>(c_api)->PJRT_Client_Compile = PJRT_Client_Compile_Orig;
}

TEST(PjRtCApiClientTest, WrapClientAroundCApi) {
const PJRT_Api* c_api = ::pjrt::cpu_plugin::GetCpuPjrtApi();
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<PjRtClient> client,
WrapClientAroundCApi(c_api));
EXPECT_EQ(client->platform_name(), xla::CpuName());
EXPECT_EQ(client->platform_id(), xla::CpuId());
}

} // namespace
} // namespace xla

0 comments on commit db16547

Please sign in to comment.