Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[xla:cpu] Move BufferAllocations implementation to header file #14148

Merged
merged 1 commit into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions xla/service/cpu/runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,36 @@ package_group(

cc_library(
name = "buffer_allocations",
srcs = ["buffer_allocations.cc"],
hdrs = ["buffer_allocations.h"],
deps = [
"//xla:util",
"//xla/service:buffer_assignment",
"//xla/service:maybe_owning_device_memory",
"//xla/stream_executor",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@tsl//tsl/platform:statusor",
],
)

xla_cc_test(
name = "buffer_allocations_test",
srcs = ["buffer_allocations_test.cc"],
deps = [
":buffer_allocations",
"//xla/service:buffer_assignment",
"//xla/service:maybe_owning_device_memory",
"//xla/stream_executor:device_memory",
"@com_google_absl//absl/status",
"@tsl//tsl/lib/core:status_test_util",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/platform:test",
"@tsl//tsl/platform:test_main",
],
)

cc_library(
name = "task",
hdrs = ["task.h"],
Expand Down
76 changes: 0 additions & 76 deletions xla/service/cpu/runtime/buffer_allocations.cc

This file was deleted.

85 changes: 74 additions & 11 deletions xla/service/cpu/runtime/buffer_allocations.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,39 +16,102 @@ limitations under the License.
#ifndef XLA_SERVICE_CPU_RUNTIME_BUFFER_ALLOCATIONS_H_
#define XLA_SERVICE_CPU_RUNTIME_BUFFER_ALLOCATIONS_H_

#include <cstddef>
#include <cstdint>
#include <vector>

#include "absl/base/attributes.h"
#include "absl/base/optimization.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "xla/service/buffer_assignment.h"
#include "xla/service/maybe_owning_device_memory.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/util.h"

namespace xla::cpu {

// Buffer allocation is a container for device buffers allocated for a
// particular XLA execution. Buffers are indexed by the buffer allocation index.
//
// TODO(b/342513610): BufferAllocations should be unified with a same class in
// the XLA:GPU runtime, probably as a part of `buffer_assignment.h`.
class BufferAllocations {
public:
explicit BufferAllocations(absl::Span<const MaybeOwningDeviceMemory> buffers)
: buffers_(buffers) {}
explicit inline BufferAllocations(
absl::Span<const MaybeOwningDeviceMemory> buffers);

// Returns the device address of buffer `buffer_index`. `buffer_index` must be
// a valid index, i.e., in [0, buffer_count).
absl::StatusOr<se::DeviceMemoryBase> GetDeviceAddress(
BufferAllocation::Index buffer_index) const;
inline ABSL_ATTRIBUTE_ALWAYS_INLINE absl::StatusOr<se::DeviceMemoryBase>
GetDeviceAddress(BufferAllocation::Index buffer_index) const;

// Same as above, but also adjusts the returned address for the offset and
// size contained in the given slice.
absl::StatusOr<se::DeviceMemoryBase> GetDeviceAddress(
const BufferAllocation::Slice& buffer_slice) const;
inline ABSL_ATTRIBUTE_ALWAYS_INLINE absl::StatusOr<se::DeviceMemoryBase>
GetDeviceAddress(const BufferAllocation::Slice& buffer_slice) const;

private:
// TODO(ezhulenev): Make BufferAllocations an owner of the buffers.
absl::Span<const MaybeOwningDeviceMemory> buffers_; // not owned
std::vector<se::DeviceMemoryBase> buffers_;
size_t num_buffers_;
};

BufferAllocations::BufferAllocations(
absl::Span<const MaybeOwningDeviceMemory> buffers)
: buffers_(buffers.size()), num_buffers_(buffers_.size()) {
for (size_t i = 0; i < buffers.size(); ++i) {
buffers_[i] = buffers[i].AsDeviceMemoryBase();
}
}

absl::StatusOr<se::DeviceMemoryBase> BufferAllocations::GetDeviceAddress(
BufferAllocation::Index index) const {
if (ABSL_PREDICT_FALSE(index < 0 || index >= num_buffers_)) {
return InvalidArgument(
"Invalid buffer index %d. It must be in the range [0, %d)", index,
num_buffers_);
}

return buffers_[index];
}

absl::StatusOr<se::DeviceMemoryBase> BufferAllocations::GetDeviceAddress(
const BufferAllocation::Slice& buffer_slice) const {
// Handle empty slices explicitly and return a null pointer device memory to
// guarantee that we do not accidentally write through the empty slice which
// would hide a real bug in the code.
if (ABSL_PREDICT_FALSE(buffer_slice.size() == 0)) {
return se::DeviceMemoryBase(nullptr, 0);
}

int64_t index = buffer_slice.index();
if (ABSL_PREDICT_FALSE(index < 0 || index >= num_buffers_)) {
return InvalidArgument(
"Invalid buffer index %d. It must be in the range [0, %d)", index,
num_buffers_);
}
const se::DeviceMemoryBase& base = buffers_[index];

int64_t offset = buffer_slice.offset();
int64_t extent = offset + buffer_slice.size();

if (ABSL_PREDICT_FALSE(offset < 0)) {
return InvalidArgument("Buffer slice offset %d must be non-negative",
offset);
}

if (ABSL_PREDICT_FALSE(offset >= base.size())) {
return InvalidArgument(
"Buffer slice offset %d is out of range for buffer #%d of size %d",
offset, index, base.size());
}

if (ABSL_PREDICT_FALSE(extent > base.size())) {
return InvalidArgument(
"Buffer slice extent %d is out of range for buffer #%d of size %d",
extent, index, base.size());
}

return base.GetByteSlice(offset, buffer_slice.size());
}

} // namespace xla::cpu

#endif // XLA_SERVICE_CPU_RUNTIME_BUFFER_ALLOCATIONS_H_
53 changes: 53 additions & 0 deletions xla/service/cpu/runtime/buffer_allocations_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/* Copyright 2024 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/service/cpu/runtime/buffer_allocations.h"

#include <cstddef>
#include <vector>

#include "xla/service/buffer_assignment.h"
#include "xla/service/maybe_owning_device_memory.h"
#include "xla/stream_executor/device_memory.h"
#include "tsl/platform/statusor.h"
#include "tsl/platform/test.h"

namespace xla::cpu {
namespace {

TEST(BufferAllocationsTest, GetDeviceAddress) {
std::vector<MaybeOwningDeviceMemory> buffers;
std::vector<float> data = {1.0, 2.0, 3.0, 4.0};

size_t size_in_bytes = data.size() * sizeof(float);
buffers.emplace_back(se::DeviceMemoryBase(data.data(), size_in_bytes));

BufferAllocations allocations(buffers);

BufferAllocation alloc(0, size_in_bytes, 0);
BufferAllocation::Slice slice(&alloc, /*offset=*/2 * sizeof(float),
/*size=*/sizeof(float));

TF_ASSERT_OK_AND_ASSIGN(se::DeviceMemoryBase alloc_mem,
allocations.GetDeviceAddress(0));
EXPECT_EQ(alloc_mem.opaque(), &data[0]);

TF_ASSERT_OK_AND_ASSIGN(se::DeviceMemoryBase slice_mem,
allocations.GetDeviceAddress(slice));
EXPECT_EQ(slice_mem.opaque(), &data[2]);
}

} // namespace
} // namespace xla::cpu
Loading