Skip to content

Commit

Permalink
PR #15904: [XLA:GPU]implement sycl platform id
Browse files Browse the repository at this point in the history
Imported from GitHub PR #15904

Copybara import of the project:

--
df9b82a by mayuyuace <[email protected]>:

implement sycl platform id

--
72cf11f by mayuyuace <[email protected]>:

remove override' of GetUncachedExecutor

Merging this change closes #15904

COPYBARA_INTEGRATE_REVIEW=#15904 from Intel-tensorflow:qiming/implement_sycl_id d43e108
PiperOrigin-RevId: 678642780
  • Loading branch information
mayuyuace authored and Google-ML-Automation committed Sep 25, 2024
1 parent 464d309 commit 6bccc1d
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 1 deletion.
1 change: 1 addition & 0 deletions xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -4357,6 +4357,7 @@ cc_library(
"//xla/stream_executor/cuda:cuda_platform_id",
"//xla/stream_executor/host:host_platform_id",
"//xla/stream_executor/rocm:rocm_platform_id",
"//xla/stream_executor/sycl:sycl_platform_id",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
Expand Down
3 changes: 3 additions & 0 deletions xla/service/computation_placer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ limitations under the License.
#include "xla/stream_executor/cuda/cuda_platform_id.h"
#include "xla/stream_executor/host/host_platform_id.h"
#include "xla/stream_executor/rocm/rocm_platform_id.h"
#include "xla/stream_executor/sycl/sycl_platform_id.h"
#include "xla/types.h"
#include "xla/util.h"
#include "tsl/platform/errors.h"
Expand Down Expand Up @@ -222,6 +223,8 @@ static bool InitModule() {
stream_executor::cuda::kCudaPlatformId, &CreateComputationPlacer);
xla::ComputationPlacer::RegisterComputationPlacer(
stream_executor::rocm::kROCmPlatformId, &CreateComputationPlacer);
xla::ComputationPlacer::RegisterComputationPlacer(
stream_executor::sycl::kSyclPlatformId, &CreateComputationPlacer);
return true;
}
static bool module_initialized = InitModule();
2 changes: 2 additions & 0 deletions xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,7 @@ cc_library(
"//xla/stream_executor/gpu:gpu_executor_header",
"//xla/stream_executor/gpu:scoped_activate_context",
"//xla/stream_executor/rocm:rocm_platform_id",
"//xla/stream_executor/sycl:sycl_platform_id",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
Expand Down Expand Up @@ -1211,6 +1212,7 @@ cc_library(
"//xla/stream_executor:memory_allocation",
"//xla/stream_executor/cuda:cuda_platform_id",
"//xla/stream_executor/rocm:rocm_platform_id",
"//xla/stream_executor/sycl:sycl_platform_id",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/cleanup",
"@com_google_absl//absl/container:flat_hash_map",
Expand Down
3 changes: 3 additions & 0 deletions xla/service/gpu/gpu_executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ limitations under the License.
#include "xla/stream_executor/scoped_module_handle.h"
#include "xla/stream_executor/stream.h"
#include "xla/stream_executor/stream_executor.h"
#include "xla/stream_executor/sycl/sycl_platform_id.h"
#include "xla/util.h"
#include "tsl/platform/env.h"
#include "tsl/platform/errors.h"
Expand Down Expand Up @@ -177,6 +178,8 @@ absl::Status GpuExecutable::CheckCompatibilityWithServiceExecutableRunOptions(
<< std::get<se::CudaComputeCapability>(gpu_version_).ToString()
<< "}, but was {" << std::get<se::CudaComputeCapability>(cc).ToString()
<< "}";
} else if (platform_id == stream_executor::sycl::kSyclPlatformId) {
// TODO: Add check.
} else {
return Internal("Unknown platform");
}
Expand Down
10 changes: 10 additions & 0 deletions xla/service/gpu/gpu_transfer_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ limitations under the License.
#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/rocm/rocm_platform_id.h"
#include "xla/stream_executor/stream_executor.h"
#include "xla/stream_executor/sycl/sycl_platform_id.h"
#include "xla/util.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/logging.h"
Expand Down Expand Up @@ -369,11 +370,20 @@ static std::unique_ptr<xla::TransferManager> CreateAMDGPUTransferManager() {
.getPointerSize(0 /* default address space */));
}

static std::unique_ptr<xla::TransferManager> CreateSYCLTransferManager() {
return std::make_unique<xla::gpu::GpuTransferManager>(
/*id=*/stream_executor::sycl::kSyclPlatformId,
/*pointer_size=*/llvm::DataLayout(xla::gpu::spir::DataLayout())
.getPointerSize(0 /* default address space */));
}

static bool InitModule() {
xla::TransferManager::RegisterTransferManager(
stream_executor::cuda::kCudaPlatformId, &CreateNVPTXTransferManager);
xla::TransferManager::RegisterTransferManager(
stream_executor::rocm::kROCmPlatformId, &CreateAMDGPUTransferManager);
xla::TransferManager::RegisterTransferManager(
stream_executor::sycl::kSyclPlatformId, &CreateSYCLTransferManager);
return true;
}

Expand Down
2 changes: 1 addition & 1 deletion xla/stream_executor/sycl/sycl_platform.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class SyclPlatform : public Platform {
// looking in or storing to the Platform's executor cache.
// Ownership IS transferred to the caller.
absl::StatusOr<std::unique_ptr<StreamExecutor>> GetUncachedExecutor(
int ordinal) override;
int ordinal);

// This platform's name.
std::string name_;
Expand Down

0 comments on commit 6bccc1d

Please sign in to comment.