Skip to content

Commit

Permalink
Move XLA:CPU plugin API integration point into the PJRT Plugin direct…
Browse files Browse the repository at this point in the history
…ory.

PiperOrigin-RevId: 694571064
  • Loading branch information
changm authored and Google-ML-Automation committed Nov 8, 2024
1 parent d9e4ec5 commit a6d880d
Show file tree
Hide file tree
Showing 6 changed files with 138 additions and 4 deletions.
11 changes: 7 additions & 4 deletions xla/pjrt/cpu/cpu_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ limitations under the License.
#include <utility>
#include <vector>

#include "absl/base/attributes.h"
#include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/inlined_vector.h"
Expand Down Expand Up @@ -723,12 +724,14 @@ struct CpuClientOptions {
std::function<void(HloModuleConfig&)> customize_hlo_module_config;
};

absl::StatusOr<std::unique_ptr<PjRtClient>> GetTfrtCpuClient(
CpuClientOptions options);
absl::StatusOr<std::unique_ptr<PjRtClient>> ABSL_DEPRECATED(
"Use public XLA:CPU GetXlaPjrtCpuClient instead")
GetTfrtCpuClient(CpuClientOptions options);

// Deprecated. Use the overload that takes 'options' instead.
inline absl::StatusOr<std::unique_ptr<PjRtClient>> GetTfrtCpuClient(
bool asynchronous) {
inline absl::StatusOr<std::unique_ptr<PjRtClient>> ABSL_DEPRECATED(
"Use public XLA:CPU GetXlaPjrtCpuClient instead")
GetTfrtCpuClient(bool asynchronous) {
CpuClientOptions options;
options.asynchronous = asynchronous;
return GetTfrtCpuClient(std::move(options));
Expand Down
34 changes: 34 additions & 0 deletions xla/pjrt/plugin/xla_cpu/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
load("@tsl//tsl/platform:rules_cc.bzl", "cc_library")
load("//xla:xla.bzl", "xla_cc_test")

package(
# copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
default_visibility = ["//visibility:public"],
licenses = ["notice"],
)

cc_library(
name = "xla_cpu_pjrt_client",
srcs = [
"xla_cpu_pjrt_client.cc",
],
hdrs = ["xla_cpu_pjrt_client.h"],
deps = [
"//xla/pjrt:pjrt_client",
"//xla/pjrt/cpu:cpu_client",
"@com_google_absl//absl/status:statusor",
],
)

xla_cc_test(
name = "xla_cpu_pjrt_client_test",
srcs = ["xla_cpu_pjrt_client_test.cc"],
deps = [
":xla_cpu_pjrt_client",
"//xla/pjrt:pjrt_client",
"//xla/pjrt/cpu:cpu_client",
"//xla/tests:xla_internal_test_main",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/platform:test",
],
)
2 changes: 2 additions & 0 deletions xla/pjrt/plugin/xla_cpu/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Public PJRT entry point for XLA:CPU. Please use PJRT to access XLA:CPU
functionality.
32 changes: 32 additions & 0 deletions xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client.h"

#include <memory>

#include "absl/status/statusor.h"
#include "xla/pjrt/cpu/cpu_client.h"
#include "xla/pjrt/pjrt_client.h"

namespace xla {

absl::StatusOr<std::unique_ptr<PjRtClient>> GetXlaPjrtCpuClient(
CpuClientOptions options) {
// TODO(masonchang): Wrap the TFRTCPU Client inside the PJRT Sandwich
return xla::GetTfrtCpuClient(options);
}

} // namespace xla
33 changes: 33 additions & 0 deletions xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_PJRT_PLUGIN_XLA_CPU_XLA_CPU_PJRT_CLIENT_H_
#define XLA_PJRT_PLUGIN_XLA_CPU_XLA_CPU_PJRT_CLIENT_H_

#include <memory>

#include "absl/status/statusor.h"
#include "xla/pjrt/cpu/cpu_client.h"
#include "xla/pjrt/pjrt_client.h"

namespace xla {

// Public entry point to get an XLA:CPU PjRtClient
absl::StatusOr<std::unique_ptr<PjRtClient>> GetXlaPjrtCpuClient(
CpuClientOptions options);

} // namespace xla

#endif // XLA_PJRT_PLUGIN_XLA_CPU_XLA_CPU_PJRT_CLIENT_H_
30 changes: 30 additions & 0 deletions xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client.h"

#include "xla/pjrt/cpu/cpu_client.h"
#include "xla/pjrt/pjrt_client.h"
#include "tsl/platform/statusor.h"
#include "tsl/platform/test.h"

namespace xla {

TEST(XlaCpuPjrtClientTest, GetXlaPjrtCpuClient) {
TF_ASSERT_OK_AND_ASSIGN(auto client, GetXlaPjrtCpuClient(CpuClientOptions()));
EXPECT_EQ(client->platform_name(), "cpu");
}

} // namespace xla

0 comments on commit a6d880d

Please sign in to comment.