PR #14088: Create executable buffer at the appropriate memory space on GPUs

Imported from GitHub PR #14088

Copybara import of the project:

--
fa7d607 by Jaroslav Sevcik <[email protected]>:

Create executable buffer at the appropriate memory space

Merging this change closes #14088

FUTURE_COPYBARA_INTEGRATE_REVIEW=#14088 from jaro-sevcik:execute-output-buffers-with-memory-kind fa7d607
PiperOrigin-RevId: 646386007
jaro-sevcik authored and copybara-github committed Jun 25, 2024
1 parent ca0f4b3 commit 74cf39b
Showing 3 changed files with 110 additions and 1 deletion.
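
The core of the change is in OutputBufferHelper (diff below): when the output shape's layout places the buffer in host memory, the PJRT output buffer is now created in the device's pinned-host memory space rather than the device's default memory space. The following standalone sketch condenses that selection logic; the free function ChooseOutputMemorySpace is illustrative only and not part of the patch, but the calls it makes (default_memory_space, memory_space_by_kind, PinnedHostMemorySpace::kKind, Layout::kHostMemorySpace) mirror the diff.

#include "absl/status/statusor.h"
#include "xla/layout.h"
#include "xla/pjrt/host_memory_spaces.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/shape.h"

namespace xla {

// Illustrative helper (not in the patch): pick the memory space for an
// executable's output buffer. Default to the device's default memory space,
// but switch to the pinned-host space when the output layout says the buffer
// lives in host memory, e.g. as produced by the "annotate_device_placement"
// custom call with _xla_buffer_placement="pinned_host".
PjRtMemorySpace* ChooseOutputMemorySpace(const Shape& on_device_shape,
                                         PjRtDevice* device) {
  PjRtMemorySpace* memory_space =
      device->default_memory_space().value_or(nullptr);
  if (on_device_shape.has_layout() &&
      on_device_shape.layout().memory_space() == Layout::kHostMemorySpace) {
    absl::StatusOr<PjRtMemorySpace*> pinned =
        device->memory_space_by_kind(PinnedHostMemorySpace::kKind);
    if (pinned.ok()) {
      memory_space = *pinned;
    }
  }
  return memory_space;
}

}  // namespace xla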
1 change: 1 addition & 0 deletions xla/pjrt/BUILD
@@ -477,6 +477,7 @@ cc_library(
"//xla/client:local_client",
"//xla/client:xla_computation",
"//xla/hlo/ir:hlo",
"//xla/pjrt:host_memory_spaces",
"//xla/pjrt/distributed:protocol_proto_cc",
"//xla/service:compiler",
"//xla/service:computation_layout",
96 changes: 96 additions & 0 deletions xla/pjrt/gpu/se_gpu_pjrt_client_test.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include <cstring>
#include <memory>
#include <numeric>
#include <optional>
#include <string>
#include <utility>
#include <vector>
@@ -780,5 +781,100 @@ TEST(StreamExecutorGpuClientTest, MockNcclClientTest) {
}
}

namespace {

absl::StatusOr<std::unique_ptr<PjRtBuffer>> CreateDeviceBufferForTest(
    xla::PjRtClient* client) {
  auto device = client->addressable_devices()[0];
  TF_EXPECT_OK(device->default_memory_space());

  std::vector<int32_t> data{1, 2, 3, 4};
  Shape shape = ShapeUtil::MakeShapeWithDenseLayout(S32, {4}, {0});
  TF_ASSIGN_OR_RETURN(
      auto input, client->BufferFromHostBuffer(
                      data.data(), shape.element_type(), shape.dimensions(),
                      /*byte_strides=*/std::nullopt,
                      PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall,
                      /*on_done_with_host_buffer=*/nullptr, device));
  EXPECT_EQ(input->memory_space()->kind(), "device");
  return input;
}

}  // namespace

TEST(StreamExecutorGpuClientTest, ExecutePinnedHostOutputTest) {
  TF_ASSERT_OK_AND_ASSIGN(auto client,
                          GetStreamExecutorGpuClient(GpuClientOptions()));
  TF_ASSERT_OK_AND_ASSIGN(auto input, CreateDeviceBufferForTest(client.get()));

  static constexpr char const* kD2HProgram = R"(
  HloModule f
  ENTRY main.5 {
    p = s32[4]{0} parameter(0)
    ROOT cc = s32[4] custom-call(p),
        custom_call_target="annotate_device_placement",
        frontend_attributes={_xla_buffer_placement="pinned_host"}
  }
)";

  TF_ASSERT_OK_AND_ASSIGN(auto executable,
                          CompileExecutable(kD2HProgram, *client));
  TF_ASSERT_OK_AND_ASSIGN(
      auto result, executable->Execute({{input.get()}}, ExecuteOptions()));

  std::vector<std::unique_ptr<xla::PjRtBuffer>>& result_buffers = result[0];
  EXPECT_EQ(result_buffers[0]->memory_space()->kind(), "pinned_host");

  TF_ASSERT_OK_AND_ASSIGN(auto memory_stats,
                          executable->GetCompiledMemoryStats());
  EXPECT_EQ(memory_stats.output_size_in_bytes, 0);
  EXPECT_EQ(memory_stats.host_output_size_in_bytes, 16);
}

TEST(StreamExecutorGpuClientTest, ExecutePinnedHostOutputTupleTest) {
  TF_ASSERT_OK_AND_ASSIGN(auto client,
                          GetStreamExecutorGpuClient(GpuClientOptions()));
  TF_ASSERT_OK_AND_ASSIGN(auto input, CreateDeviceBufferForTest(client.get()));

  static constexpr char const* kD2HProgram = R"(
  HloModule f
  ENTRY main.5 {
    p = s32[4]{0} parameter(0)
    cc = s32[4] custom-call(p),
        custom_call_target="annotate_device_placement",
        frontend_attributes={_xla_buffer_placement="pinned_host"}
    ROOT tuple = (s32[4]{0}, s32[4]{0}) tuple(s32[4]{0} p, s32[4]{0} cc)
  }
)";

  // Build the output shape with the correct memory space set.
  Shape host_shape = input->on_device_shape();
  host_shape.mutable_layout()->set_memory_space(Layout::kHostMemorySpace);
  Shape out_shape =
      ShapeUtil::MakeTupleShape({input->on_device_shape(), host_shape});

  // Set the result layout so that the compiler assertions on memory
  // spaces pass.
  xla::CompileOptions options;
  options.executable_build_options.set_result_layout(out_shape);

  TF_ASSERT_OK_AND_ASSIGN(auto executable,
                          CompileExecutable(kD2HProgram, *client, options));

  // Untuple the result so that we get separate buffers.
  // This is how JAX invokes XLA.
  ExecuteOptions execute_options;
  execute_options.untuple_result = true;
  TF_ASSERT_OK_AND_ASSIGN(
      auto result, executable->Execute({{input.get()}}, execute_options));

  std::vector<std::unique_ptr<xla::PjRtBuffer>>& result_buffers = result[0];
  EXPECT_EQ(result_buffers.size(), 2);
  EXPECT_EQ(result_buffers[0]->memory_space()->kind(), "device");
  EXPECT_EQ(result_buffers[1]->memory_space()->kind(), "pinned_host");
}

} // namespace
} // namespace xla
14 changes: 13 additions & 1 deletion xla/pjrt/pjrt_stream_executor_client.cc
@@ -104,6 +104,7 @@ limitations under the License.
#include "xla/pjrt/distributed/protocol.pb.h"
#include "xla/pjrt/event_pool.h"
#include "xla/pjrt/host_callback.h"
#include "xla/pjrt/host_memory_spaces.h"
#include "xla/pjrt/local_device_state.h"
#include "xla/pjrt/metrics.h"
#include "xla/pjrt/mlir_to_hlo.h"
@@ -2240,9 +2241,20 @@ std::unique_ptr<PjRtBuffer> OutputBufferHelper(
   std::shared_ptr<TrackedDeviceBuffer> out_buffer =
       TrackedDeviceBuffer::FromScopedShapedBuffer(result_buffer,
                                                   {definition_event});
+  Shape shape = result_buffer->on_device_shape();
+  PjRtMemorySpace* memory_space =
+      device->default_memory_space().value_or(nullptr);
+  if (shape.has_layout() &&
+      shape.layout().memory_space() == Layout::kHostMemorySpace) {
+    absl::StatusOr<PjRtMemorySpace*> memory_space_or =
+        device->memory_space_by_kind(PinnedHostMemorySpace::kKind);
+    if (memory_space_or.ok()) {
+      memory_space = memory_space_or.value();
+    }
+  }
   auto pjrt_buffer = std::make_unique<PjRtStreamExecutorBuffer>(
       result_buffer->on_device_shape(), std::move(out_buffer), client, device,
-      device->default_memory_space().value_or(nullptr));
+      memory_space);
   RecordUsage(pjrt_buffer->GetBufferWithUsageHold(), local_device, local_device,
               definition_event, local_device->compute_stream(),
               /*prefer_to_retain_reference=*/false, &buffers_to_release);
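
For context, a caller-side sketch of how the change surfaces through the PJRT API, assuming executable was compiled from an HLO module whose output is annotated with _xla_buffer_placement="pinned_host" (as in the tests above). RunAndReadBack is a hypothetical helper and is not part of this commit.

#include <memory>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "tsl/platform/statusor.h"
#include "xla/literal.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_executable.h"

namespace xla {

// Hypothetical usage (not in the patch): with this change, the returned
// buffer reports the "pinned_host" memory space when the output layout asks
// for host memory, instead of always reporting the device default.
absl::StatusOr<std::shared_ptr<Literal>> RunAndReadBack(
    PjRtLoadedExecutable& executable, PjRtBuffer& input) {
  TF_ASSIGN_OR_RETURN(auto results,
                      executable.Execute({{&input}}, ExecuteOptions()));
  std::unique_ptr<PjRtBuffer>& output = results[0][0];
  if (output->memory_space()->kind() != "pinned_host") {
    return absl::InternalError("expected a pinned_host output buffer");
  }
  // The data already lives in (pinned) host memory.
  return output->ToLiteralSync();
}

}  // namespace xla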
