diff --git a/xla/backends/cpu/runtime/BUILD b/xla/backends/cpu/runtime/BUILD
index 831b87fc0f8af1..8664cecddc3529 100644
--- a/xla/backends/cpu/runtime/BUILD
+++ b/xla/backends/cpu/runtime/BUILD
@@ -950,14 +950,12 @@ xla_cc_test(
         "//xla/service:buffer_assignment",
         "//xla/service:maybe_owning_device_memory",
         "//xla/stream_executor",
-        "//xla/stream_executor/host:host_kernel_c_api",
         "//xla/tsl/concurrency:async_value",
-        "//xla/tsl/lib/core:status_test_util",
         "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings",
        "@tsl//tsl/platform:logging",
         "@tsl//tsl/platform:statusor",
         "@tsl//tsl/platform:test",
+        "@tsl//tsl/platform:test_benchmark",
         "@tsl//tsl/platform:test_main",
     ],
 )
diff --git a/xla/backends/cpu/runtime/sort_thunk.cc b/xla/backends/cpu/runtime/sort_thunk.cc
index 30c5e1a1b34897..990bd523cae461 100644
--- a/xla/backends/cpu/runtime/sort_thunk.cc
+++ b/xla/backends/cpu/runtime/sort_thunk.cc
@@ -26,6 +26,7 @@ limitations under the License.
 #include 
 #include 
 #include 
+#include <vector>
 
 #include "absl/algorithm/container.h"
 #include "absl/base/dynamic_annotations.h"
@@ -131,6 +132,7 @@ static constexpr size_t kMaxElementSize = 16;
 // Forward declare reference type defined below.
 template <size_t n>
 struct Ref;
+struct DRef;
 
 // Value type to store values loaded from the input buffers.
 template <size_t n>
@@ -145,6 +147,18 @@ struct Value {
   std::array<uint8_t, n> value_sizes;
 };
 
+struct DValue {
+  DValue(const DRef& ref);  // NOLINT
+
+  const void* compared_value(size_t i) const { return value[i].data(); }
+
+  // Use properly aligned byte array to store primitive values.
+  using ValueStorage = std::array<std::byte, kMaxElementSize>;
+  std::vector<ValueStorage> value;
+  std::vector<uint8_t> value_sizes;
+  size_t n;
+};
+
 // Reference to values stored in the input buffers.
 template <size_t n>
 struct Ref {
@@ -160,6 +174,20 @@ struct Ref {
   std::array<uint8_t, n> ptr_sizes;
 };
 
+struct DRef {
+  DRef(std::vector<std::byte*> ptr, std::vector<uint8_t> ptr_sizes)
+      : ptr(ptr), ptr_sizes(ptr_sizes), n(ptr.size()) {}
+
+  DRef& operator=(const DValue& value);
+  DRef& operator=(const DRef& other);
+
+  const void* compared_value(size_t i) const { return ptr[i]; }
+
+  std::vector<std::byte*> ptr;
+  std::vector<uint8_t> ptr_sizes;
+  const size_t n;
+};
+
 template <size_t n>
 Value<n>::Value(const Ref<n>& ref) : value_sizes(ref.ptr_sizes) {
   for (size_t i = 0; i < n; ++i) {
@@ -167,6 +195,15 @@ Value<n>::Value(const Ref<n>& ref) : value_sizes(ref.ptr_sizes) {
   }
 }
 
+DValue::DValue(const DRef& ref)
+    : value_sizes(ref.ptr_sizes), n(ref.ptr.size()) {
+  value.reserve(n);
+  for (size_t i = 0; i < n; ++i) {
+    value.emplace_back();
+    std::memcpy(value[i].data(), ref.ptr[i], ref.ptr_sizes[i]);
+  }
+}
+
 template <size_t n>
 Ref<n>& Ref<n>::operator=(const Value<n>& value) {
   DCHECK(ptr_sizes == value.value_sizes);
@@ -176,6 +213,14 @@ Ref<n>& Ref<n>::operator=(const Value<n>& value) {
   return *this;
 }
 
+DRef& DRef::operator=(const DValue& value) {
+  DCHECK(ptr_sizes == value.value_sizes);
+  for (size_t i = 0; i < n; ++i) {
+    std::memcpy(ptr[i], value.value[i].data(), value.value_sizes[i]);
+  }
+  return *this;
+}
+
 template <size_t n>
 Ref<n>& Ref<n>::operator=(const Ref<n>& other) {
   DCHECK(ptr_sizes == other.ptr_sizes);
@@ -185,6 +230,15 @@ Ref<n>& Ref<n>::operator=(const Ref<n>& other) {
   return *this;
 }
 
+DRef& DRef::operator=(const DRef& other) {
+  DCHECK(ptr_sizes == other.ptr_sizes);
+  const size_t n = other.ptr.size();
+  for (size_t i = 0; i < n; ++i) {
+    std::memcpy(ptr[i], other.ptr[i], other.ptr_sizes[i]);
+  }
+  return *this;
+}
+
 // Swap function required by `std::sort` and `std::stable_sort` implementations.
 template <size_t n>
 void swap(const Ref<n>& lhs, const Ref<n>& rhs) {
@@ -196,6 +250,17 @@ void swap(const Ref<n>& lhs, const Ref<n>& rhs) {
   }
 }
 
+void swap(const DRef& lhs, const DRef& rhs) {
+  DCHECK(lhs.ptr_sizes == rhs.ptr_sizes);
+  const size_t n = lhs.ptr.size();
+  for (size_t i = 0; i < n; ++i) {
+    std::array<std::byte, kMaxElementSize> tmp;
+    std::memcpy(tmp.data(), lhs.ptr[i], lhs.ptr_sizes[i]);
+    std::memcpy(lhs.ptr[i], rhs.ptr[i], rhs.ptr_sizes[i]);
+    std::memcpy(rhs.ptr[i], tmp.data(), lhs.ptr_sizes[i]);
+  }
+}
+
 // An array of pointers to the input data.
 template <size_t n>
 struct Ptr {
@@ -250,19 +315,72 @@ struct Ptr {
   std::array<uint8_t, n> ptr_sizes;  // pointers sizes in bytes
 };
 
+struct DPtr {
+  using difference_type = std::ptrdiff_t;
+
+  DPtr() = default;
+
+  DPtr(std::vector<std::byte*> ptr, std::vector<uint8_t> ptr_sizes)
+      : ptr(ptr), ptr_sizes(ptr_sizes), n(ptr.size()) {}
+
+  DRef operator*() const { return DRef{ptr, ptr_sizes}; }
+
+  DPtr& operator+=(difference_type diff) {
+    for (size_t i = 0; i < n; ++i) ptr[i] += diff * ptr_sizes[i];
+    return *this;
+  }
+
+  DPtr& operator-=(difference_type diff) {
+    for (size_t i = 0; i < n; ++i) ptr[i] -= diff * ptr_sizes[i];
+    return *this;
+  }
+
+  DPtr operator+(difference_type diff) const {
+    std::vector<std::byte*> upd(n);
+    for (size_t i = 0; i < n; ++i) upd[i] = ptr[i] + diff * ptr_sizes[i];
+    return DPtr{upd, ptr_sizes};
+  }
+
+  DPtr operator-(difference_type diff) const {
+    std::vector<std::byte*> upd(n);
+    for (size_t i = 0; i < n; ++i) upd[i] = ptr[i] - diff * ptr_sizes[i];
+    return DPtr{upd, ptr_sizes};
+  }
+
+  // In all comparison operators defined below we use only the ptr at index 0,
+  // because we know that all pointers change together and this is an
+  // implementation detail of sort iterator.
+
+  difference_type operator-(const DPtr& rhs) const {
+    DCHECK(ptr_sizes == rhs.ptr_sizes);
+    return (ptr[0] - rhs.ptr[0]) / ptr_sizes[0];
+  }
+
+  bool operator==(const DPtr& rhs) const { return ptr[0] == rhs.ptr[0]; }
+  bool operator!=(const DPtr& rhs) const { return ptr[0] != rhs.ptr[0]; }
+  bool operator>(const DPtr& rhs) const { return ptr[0] > rhs.ptr[0]; }
+  bool operator<(const DPtr& rhs) const { return ptr[0] < rhs.ptr[0]; }
+  bool operator>=(const DPtr& rhs) const { return ptr[0] >= rhs.ptr[0]; }
+  bool operator<=(const DPtr& rhs) const { return ptr[0] <= rhs.ptr[0]; }
+
+  std::vector<std::byte*> ptr;     // pointers into the input buffers
+  std::vector<uint8_t> ptr_sizes;  // pointers sizes in bytes
+  size_t n;
+};
+
 // We rely on `std::sort` and `std::stable_sort` to sort the raw data. We sort
 // multiple input buffers together using the same comparator function, so we
 // need to provide a custom iterator that can access the data of all input
 // buffers at the same time and swap elements in them.
-template <size_t n>
+template <class Value, class Ref, class Ptr>
 class SortIterator {
  public:
   using iterator_category = std::random_access_iterator_tag;
   using difference_type = std::ptrdiff_t;
 
-  using value_type = Value<n>;
-  using reference = Ref<n>;
-  using pointer = Ptr<n>;
+  using value_type = Value;
+  using reference = Ref;
+  using pointer = Ptr;
 
   SortIterator() = default;
   SortIterator(pointer ptr, difference_type stride)
@@ -388,8 +506,40 @@ static void SortInplace(const SortDims& sort_dims, int64_t offset,
     return (*less_than)(data.data());
   };
 
-  SortIterator<n> begin(Ptr<n>(ptr, ptr_sizes),
-                        /*stride=*/sort_dims.inner_dim_size);
+  SortIterator<Value<n>, Ref<n>, Ptr<n>> begin(
+      Ptr<n>(ptr, ptr_sizes),
+      /*stride=*/sort_dims.inner_dim_size);
+  if (is_stable) {
+    std::stable_sort(begin, begin + sort_dims.sort_dim_size, compare);
+  } else {
+    std::sort(begin, begin + sort_dims.sort_dim_size, compare);
+  }
+}
+
+static void DSortInplace(const SortDims& sort_dims, int64_t offset,
+                         absl::Span<se::DeviceMemoryBase> data,
+                         absl::Span<const Shape> shapes, bool is_stable,
+                         SortThunk::LessThan* less_than, size_t n) {
+  std::vector<std::byte*> ptr(n);
+  std::vector<uint8_t> ptr_sizes(n);
+
+  for (size_t i = 0; i < n; ++i) {
+    std::byte* base = reinterpret_cast<std::byte*>(data[i].opaque());
+    ptr_sizes[i] = primitive_util::ByteWidth(shapes[i].element_type());
+    ptr[i] = base + offset * ptr_sizes[i];
+  }
+
+  auto compare = [&](const auto& a, const auto& b) {
+    std::vector<const void*> data(2 * n);
+    for (size_t i = 0, j = 0; i < n; i += 1, j += 2) {
+      data[j] = a.compared_value(i);
+      data[j + 1] = b.compared_value(i);
+    }
+    return (*less_than)(data.data());
+  };
+
+  SortIterator<DValue, DRef, DPtr> begin(DPtr(ptr, ptr_sizes),
+                                         /*stride=*/sort_dims.inner_dim_size);
   if (is_stable) {
     std::stable_sort(begin, begin + sort_dims.sort_dim_size, compare);
   } else {
     std::sort(begin, begin + sort_dims.sort_dim_size, compare);
@@ -416,9 +566,15 @@ static absl::Status SortInplace(absl::Span<se::DeviceMemoryBase> data,
                 is_stable, less_than);
   };
 
-  // TODO(ezhulenev): We can replace statically known number of sorted inputs
-  // with a dynamic value, however statically known number of inputs allows
-  // compiler to generate better code. Benchmark if it really matters.
+  auto dsort = [&](size_t num_inputs) {
+    DSortInplace(sort_dims, offset, data, shapes, is_stable, less_than,
+                 num_inputs);
+  };
+
+  // Use "sort" for a statically known number of sorted inputs (expected to
+  // be faster) and "dsort" for a dynamically known number of inputs. For 100
+  // elements, stable sort is 1.5 times faster than stable dsort, and unstable
+  // sort is 2.47 times faster than unstable dsort.
   switch (data.size()) {
     case 1:
       sort(std::integral_constant<size_t, 1>{});
@@ -495,11 +651,9 @@ static absl::Status SortInplace(absl::Span<se::DeviceMemoryBase> data,
     case 25:
       sort(std::integral_constant<size_t, 25>{});
       break;
-    case 29:
-      sort(std::integral_constant<size_t, 29>{});
-      break;
     default:
-      return Internal("Unsupported number of sorted inputs: %d", data.size());
+      dsort(data.size());
+      break;
   }
 }
diff --git a/xla/backends/cpu/runtime/sort_thunk_test.cc b/xla/backends/cpu/runtime/sort_thunk_test.cc
index 1f450f77548d70..6c8dfae0b65a8e 100644
--- a/xla/backends/cpu/runtime/sort_thunk_test.cc
+++ b/xla/backends/cpu/runtime/sort_thunk_test.cc
@@ -15,8 +15,10 @@ limitations under the License.
 
 #include "xla/backends/cpu/runtime/sort_thunk.h"
 
+#include <array>
 #include 
 #include 
+#include <numeric>
 #include 
 #include 
 
@@ -34,6 +36,7 @@ limitations under the License.
#include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" +#include "tsl/platform/test_benchmark.h" namespace xla::cpu { namespace { @@ -100,6 +103,83 @@ TEST_P(SortThunkTest, Sort1D) { EXPECT_EQ(indices, expected_indices); } +TEST_P(SortThunkTest, DynamicSort1D) { + bool is_stable = GetParam(); + + // 33 empty slices + 2 slices with data = 35 slices + // This amount of slices will call the dynamic sort implementation. + constexpr int num_of_empty_slices = 33; + constexpr int total_num_of_slices = num_of_empty_slices + 2; + + // size of each of 33 data buffers + constexpr int data_size = 31; + + // values range will be [5.0, 35.0] + constexpr float starting_value = 5.0f; + + std::array data{ + 17.0f, 16.0f, 5.0f, 10.0f, 30.0f, 8.0f, 9.0f, 21.0f, + 14.0f, 32.0f, 29.0f, 28.0f, 19.0f, 12.0f, 25.0f, 22.0f, + 18.0f, 35.0f, 34.0f, 23.0f, 7.0f, 13.0f, 26.0f, 33.0f, + 15.0f, 24.0f, 20.0f, 31.0f, 6.0f, 27.0f, 11.0f}; + std::array indices; + std::iota(indices.begin(), indices.end(), 0); + + // This is a container for the rest of the buffers. + std::array empty; + + const size_t data_size_in_bytes = data.size() * sizeof(float); + const size_t ind_size_in_bytes = indices.size() * sizeof(int32_t); + const size_t empty_size_in_bytes = empty.size() * sizeof(uint32_t); + + const BufferAllocation alloc0(0, data_size_in_bytes, 0); + const BufferAllocation alloc1(1, ind_size_in_bytes, 0); + const BufferAllocation rest(2, empty_size_in_bytes, 0); + + const BufferAllocation::Slice slice0(&alloc0, 0, data_size_in_bytes); + const BufferAllocation::Slice slice1(&alloc1, 0, ind_size_in_bytes); + + const Shape data_shape = ShapeUtil::MakeShape(F32, {data_size}); + const Shape indices_shape = ShapeUtil::MakeShape(S32, {data_size}); + const Shape rest_shape = ShapeUtil::MakeShape(U32, {data_size}); + + std::vector buffers; + buffers.emplace_back(se::DeviceMemoryBase(data.data(), data_size_in_bytes)); + buffers.emplace_back(se::DeviceMemoryBase(indices.data(), ind_size_in_bytes)); + buffers.emplace_back(se::DeviceMemoryBase(empty.data(), empty_size_in_bytes)); + + BufferAllocations allocations(buffers); + + std::array inputs{ + {{slice0, data_shape}, {slice1, indices_shape}}}; + for (int i = 0; i < num_of_empty_slices; ++i) { + constexpr size_t empty_slice_in_bytes = data_size * sizeof(uint32_t); + inputs[i + 2].slice = BufferAllocation::Slice( + &rest, i * empty_slice_in_bytes, empty_slice_in_bytes); + inputs[i + 2].shape = rest_shape; + } + + TF_ASSERT_OK_AND_ASSIGN( + auto thunk, SortThunk::Create({"sort"}, inputs, + /*dimension=*/0, is_stable, LessThan)); + + Thunk::ExecuteParams params; + params.buffer_allocations = &allocations; + + auto execute_event = thunk->Execute(params); + tsl::BlockUntilReady(execute_event); + ASSERT_FALSE(execute_event.IsError()); + + std::array expected_data; + std::iota(expected_data.begin(), expected_data.end(), starting_value); + const std::array expected_indices{ + 2, 28, 20, 5, 6, 3, 30, 13, 21, 8, 24, 1, 0, 16, 12, 26, + 7, 15, 19, 25, 14, 22, 29, 11, 10, 4, 27, 9, 23, 18, 17}; + + EXPECT_EQ(data, expected_data); + EXPECT_EQ(indices, expected_indices); +} + TEST_P(SortThunkTest, Sort2D) { bool is_stable = GetParam(); @@ -237,6 +317,99 @@ TEST_P(SortThunkTest, Sort2DWithLayout) { EXPECT_EQ(indices, expected_indices); } +void BM_DynamicSort1D(::testing::benchmark::State& state, bool is_stable) { + const int total_num_of_slices = state.range(0); + const int num_of_empty_slices = total_num_of_slices - 2; + + // size of each of data buffers + 
+  constexpr int data_size = 31;
+
+  const std::array<float, data_size> data{
+      17.0f, 16.0f, 5.0f,  10.0f, 30.0f, 8.0f,  9.0f,  21.0f,
+      14.0f, 32.0f, 29.0f, 28.0f, 19.0f, 12.0f, 25.0f, 22.0f,
+      18.0f, 35.0f, 34.0f, 23.0f, 7.0f,  13.0f, 26.0f, 33.0f,
+      15.0f, 24.0f, 20.0f, 31.0f, 6.0f,  27.0f, 11.0f};
+  std::array<int32_t, data_size> indices;
+  std::iota(indices.begin(), indices.end(), 0);
+
+  // This is the container for the rest of the buffers.
+  std::vector<uint32_t> empty(data_size * num_of_empty_slices);
+
+  const size_t data_size_in_bytes = data.size() * sizeof(float);
+  const size_t ind_size_in_bytes = indices.size() * sizeof(int32_t);
+  const size_t empty_size_in_bytes = empty.size() * sizeof(uint32_t);
+
+  const BufferAllocation alloc0(0, data_size_in_bytes, 0);
+  const BufferAllocation alloc1(1, ind_size_in_bytes, 0);
+  const BufferAllocation rest(2, empty_size_in_bytes, 0);
+
+  const BufferAllocation::Slice slice0(&alloc0, 0, data_size_in_bytes);
+  const BufferAllocation::Slice slice1(&alloc1, 0, ind_size_in_bytes);
+
+  const Shape data_shape = ShapeUtil::MakeShape(F32, {data_size});
+  const Shape indices_shape = ShapeUtil::MakeShape(S32, {data_size});
+  const Shape rest_shape = ShapeUtil::MakeShape(U32, {data_size});
+
+  for (auto s : state) {
+    // Pause timing to avoid counting the time spent in the setup.
+    state.PauseTiming();
+    auto data_clone(data);
+    auto indices_clone(indices);
+
+    std::vector<MaybeOwningDeviceMemory> buffers;
+    buffers.emplace_back(
+        se::DeviceMemoryBase(data_clone.data(), data_size_in_bytes));
+    buffers.emplace_back(
+        se::DeviceMemoryBase(indices_clone.data(), ind_size_in_bytes));
+    buffers.emplace_back(
+        se::DeviceMemoryBase(empty.data(), empty_size_in_bytes));
+
+    BufferAllocations allocations(buffers);
+
+    std::vector<SortThunk::Input> inputs(total_num_of_slices);
+    inputs[0] = {slice0, data_shape};
+    inputs[1] = {slice1, indices_shape};
+    for (int i = 0; i < num_of_empty_slices; ++i) {
+      constexpr size_t empty_slice_in_bytes = data_size * sizeof(uint32_t);
+      inputs[i + 2].slice = BufferAllocation::Slice(
+          &rest, i * empty_slice_in_bytes, empty_slice_in_bytes);
+      inputs[i + 2].shape = rest_shape;
+    }
+
+    Thunk::ExecuteParams params;
+    params.buffer_allocations = &allocations;
+
+    state.ResumeTiming();
+    TF_ASSERT_OK_AND_ASSIGN(
+        auto thunk, SortThunk::Create({"sort"}, inputs,
+                                      /*dimension=*/0, is_stable, LessThan));
+
+    auto execute_event = thunk->Execute(params);
+    tsl::BlockUntilReady(execute_event);
+    ASSERT_FALSE(execute_event.IsError());
+  }
+}
+
+void BM_StableDynamicSort1D(::testing::benchmark::State& state) {
+  BM_DynamicSort1D(state, /*is_stable=*/true);
+}
+
+void BM_UnstableDynamicSort1D(::testing::benchmark::State& state) {
+  BM_DynamicSort1D(state, /*is_stable=*/false);
+}
+
+BENCHMARK(BM_StableDynamicSort1D)
+    ->MeasureProcessCPUTime()
+    ->Arg(35)
+    ->Arg(50)
+    ->Arg(100);
+
+BENCHMARK(BM_UnstableDynamicSort1D)
+    ->MeasureProcessCPUTime()
+    ->Arg(35)
+    ->Arg(50)
+    ->Arg(100);
+
 INSTANTIATE_TEST_SUITE_P(SortThunk, SortThunkTest, testing::Bool(),
                          testing::PrintToStringParamName());
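
Note on the dispatch in SortInplace above: the patch keeps the fully unrolled switch for up to 25 statically known inputs (a compile-time element count lets the compiler unroll the per-input loops) and only falls back to the new dynamically sized path for any other count, instead of returning an "unsupported number of sorted inputs" error. The following standalone sketch, which is not part of the patch and uses illustrative names (SortStatic, SortDynamic, Dispatch), shows the same pattern in isolation.

// Standalone sketch (not part of the patch): switch over a runtime input
// count to a compile-time-sized implementation, falling back to a
// dynamically sized one for everything else.
#include <cstddef>
#include <cstdio>
#include <type_traits>

// Statically sized path: `n` is a compile-time constant, so loops over the
// parallel buffers can be fully unrolled.
template <size_t n>
void SortStatic(std::integral_constant<size_t, n>) {
  std::printf("static path, n = %zu\n", n);
}

// Dynamically sized path: `n` is only known at run time.
void SortDynamic(size_t n) { std::printf("dynamic path, n = %zu\n", n); }

void Dispatch(size_t num_inputs) {
  switch (num_inputs) {
    case 1:
      SortStatic(std::integral_constant<size_t, 1>{});
      break;
    case 2:
      SortStatic(std::integral_constant<size_t, 2>{});
      break;
    default:
      // Fall back to the slower dynamic implementation instead of failing.
      SortDynamic(num_inputs);
      break;
  }
}

int main() {
  Dispatch(2);   // takes the static path
  Dispatch(35);  // takes the dynamic path
}

In the patch, `sort(std::integral_constant<size_t, N>{})` plays the role of SortStatic and `dsort(data.size())` the role of SortDynamic.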
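The `compare` lambdas in both SortInplace and DSortInplace hand the comparator a flat array of `const void*` in which slots 2*i and 2*i+1 hold the i-th input's values for the two rows being compared, so the same callback works for any number of inputs. Below is a minimal standalone sketch of that interleaved layout. It is not the patch's in-place approach: it sorts an index permutation instead of swapping buffer contents through the DRef/DPtr proxy machinery, and `LessThanFn`/`SortedOrder` are illustrative names rather than the real SortThunk API.

// Standalone sketch (not part of the patch): interleaved comparator
// arguments for a runtime number of parallel inputs.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <functional>
#include <numeric>
#include <vector>

using LessThanFn = std::function<bool(const void** data)>;

std::vector<int64_t> SortedOrder(const std::vector<const float*>& columns,
                                 size_t num_rows, LessThanFn* less_than) {
  const size_t n = columns.size();  // number of parallel inputs (runtime)
  std::vector<int64_t> order(num_rows);
  std::iota(order.begin(), order.end(), 0);

  auto compare = [&](int64_t a, int64_t b) {
    std::vector<const void*> data(2 * n);
    for (size_t i = 0, j = 0; i < n; ++i, j += 2) {
      data[j] = &columns[i][a];      // value of input i in row `a`
      data[j + 1] = &columns[i][b];  // value of input i in row `b`
    }
    return (*less_than)(data.data());
  };

  std::sort(order.begin(), order.end(), compare);
  return order;
}

int main() {
  std::vector<float> keys = {3.0f, 1.0f, 2.0f};
  std::vector<float> vals = {30.0f, 10.0f, 20.0f};

  // Compare by the first input only, similar to the LessThan helper used in
  // the tests above.
  LessThanFn less_than = [](const void** data) {
    return *static_cast<const float*>(data[0]) <
           *static_cast<const float*>(data[1]);
  };

  for (int64_t i :
       SortedOrder({keys.data(), vals.data()}, keys.size(), &less_than)) {
    std::printf("%g %g\n", keys[i], vals[i]);
  }
}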