From 6bccc1d99fcae0c921348815fe1910e273c83d90 Mon Sep 17 00:00:00 2001 From: mayuyuace Date: Wed, 25 Sep 2024 04:29:28 -0700 Subject: [PATCH] PR #15904: [XLA:GPU]implement sycl platform id Imported from GitHub PR https://github.com/openxla/xla/pull/15904 Copybara import of the project: -- df9b82ad0c35cb3f8ad8253b20a38a74f9318d73 by mayuyuace : implement sycl platform id -- 72cf11f61eed4f729d0e5800401fb26da8693a06 by mayuyuace : remove override' of GetUncachedExecutor Merging this change closes #15904 COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/15904 from Intel-tensorflow:qiming/implement_sycl_id d43e1084f3f9fa7338bae947dc258330707cbe52 PiperOrigin-RevId: 678642780 --- xla/service/BUILD | 1 + xla/service/computation_placer.cc | 3 +++ xla/service/gpu/BUILD | 2 ++ xla/service/gpu/gpu_executable.cc | 3 +++ xla/service/gpu/gpu_transfer_manager.cc | 10 ++++++++++ xla/stream_executor/sycl/sycl_platform.h | 2 +- 6 files changed, 20 insertions(+), 1 deletion(-) diff --git a/xla/service/BUILD b/xla/service/BUILD index 711faef4367d6..ec78bdfc9c93c 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -4357,6 +4357,7 @@ cc_library( "//xla/stream_executor/cuda:cuda_platform_id", "//xla/stream_executor/host:host_platform_id", "//xla/stream_executor/rocm:rocm_platform_id", + "//xla/stream_executor/sycl:sycl_platform_id", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", diff --git a/xla/service/computation_placer.cc b/xla/service/computation_placer.cc index ee0cf2932a1e8..43f351a548959 100644 --- a/xla/service/computation_placer.cc +++ b/xla/service/computation_placer.cc @@ -31,6 +31,7 @@ limitations under the License. #include "xla/stream_executor/cuda/cuda_platform_id.h" #include "xla/stream_executor/host/host_platform_id.h" #include "xla/stream_executor/rocm/rocm_platform_id.h" +#include "xla/stream_executor/sycl/sycl_platform_id.h" #include "xla/types.h" #include "xla/util.h" #include "tsl/platform/errors.h" @@ -222,6 +223,8 @@ static bool InitModule() { stream_executor::cuda::kCudaPlatformId, &CreateComputationPlacer); xla::ComputationPlacer::RegisterComputationPlacer( stream_executor::rocm::kROCmPlatformId, &CreateComputationPlacer); + xla::ComputationPlacer::RegisterComputationPlacer( + stream_executor::sycl::kSyclPlatformId, &CreateComputationPlacer); return true; } static bool module_initialized = InitModule(); diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index 19d7259600dd4..eebd86b4fa186 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -597,6 +597,7 @@ cc_library( "//xla/stream_executor/gpu:gpu_executor_header", "//xla/stream_executor/gpu:scoped_activate_context", "//xla/stream_executor/rocm:rocm_platform_id", + "//xla/stream_executor/sycl:sycl_platform_id", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", @@ -1211,6 +1212,7 @@ cc_library( "//xla/stream_executor:memory_allocation", "//xla/stream_executor/cuda:cuda_platform_id", "//xla/stream_executor/rocm:rocm_platform_id", + "//xla/stream_executor/sycl:sycl_platform_id", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", diff --git a/xla/service/gpu/gpu_executable.cc b/xla/service/gpu/gpu_executable.cc index e3d939e873a22..e43835397d095 100644 --- a/xla/service/gpu/gpu_executable.cc +++ b/xla/service/gpu/gpu_executable.cc @@ -81,6 +81,7 @@ limitations under the License. #include "xla/stream_executor/scoped_module_handle.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/stream_executor/sycl/sycl_platform_id.h" #include "xla/util.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" @@ -177,6 +178,8 @@ absl::Status GpuExecutable::CheckCompatibilityWithServiceExecutableRunOptions( << std::get(gpu_version_).ToString() << "}, but was {" << std::get(cc).ToString() << "}"; + } else if (platform_id == stream_executor::sycl::kSyclPlatformId) { + // TODO: Add check. } else { return Internal("Unknown platform"); } diff --git a/xla/service/gpu/gpu_transfer_manager.cc b/xla/service/gpu/gpu_transfer_manager.cc index dc770514bdda2..ffff4acdd1dfb 100644 --- a/xla/service/gpu/gpu_transfer_manager.cc +++ b/xla/service/gpu/gpu_transfer_manager.cc @@ -48,6 +48,7 @@ limitations under the License. #include "xla/stream_executor/platform.h" #include "xla/stream_executor/rocm/rocm_platform_id.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/stream_executor/sycl/sycl_platform_id.h" #include "xla/util.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" @@ -369,11 +370,20 @@ static std::unique_ptr CreateAMDGPUTransferManager() { .getPointerSize(0 /* default address space */)); } +static std::unique_ptr CreateSYCLTransferManager() { + return std::make_unique( + /*id=*/stream_executor::sycl::kSyclPlatformId, + /*pointer_size=*/llvm::DataLayout(xla::gpu::spir::DataLayout()) + .getPointerSize(0 /* default address space */)); +} + static bool InitModule() { xla::TransferManager::RegisterTransferManager( stream_executor::cuda::kCudaPlatformId, &CreateNVPTXTransferManager); xla::TransferManager::RegisterTransferManager( stream_executor::rocm::kROCmPlatformId, &CreateAMDGPUTransferManager); + xla::TransferManager::RegisterTransferManager( + stream_executor::sycl::kSyclPlatformId, &CreateSYCLTransferManager); return true; } diff --git a/xla/stream_executor/sycl/sycl_platform.h b/xla/stream_executor/sycl/sycl_platform.h index 61f0eb3d5372b..7c70e5d17e0f6 100644 --- a/xla/stream_executor/sycl/sycl_platform.h +++ b/xla/stream_executor/sycl/sycl_platform.h @@ -60,7 +60,7 @@ class SyclPlatform : public Platform { // looking in or storing to the Platform's executor cache. // Ownership IS transferred to the caller. absl::StatusOr> GetUncachedExecutor( - int ordinal) override; + int ordinal); // This platform's name. std::string name_;