[XLA:CPU] Add a generic sort kernel to SortThunk
This PR duplicates the templated code as non-templated counterparts with a "D" prefix.
This is done for every function that uses "n" directly, e.g. in loop bounds.
The only class that stays unified is the SortIterator, which operates on the
Value, Ref and Ptr abstractions, which in turn differ between the two variants.

PiperOrigin-RevId: 679101106
tvladyslav authored and Google-ML-Automation committed Sep 27, 2024
1 parent a36265a commit 25e52cc
Showing 3 changed files with 341 additions and 16 deletions.
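To make the pattern easier to follow before reading the diff, here is a minimal standalone sketch (illustrative only; StaticRef, DynamicRef, and NumBuffers are hypothetical names, not the types in this commit): the statically sized family carries "n" as a template parameter and stores pointers in a std::array, the dynamic "D" family stores them in a std::vector, and a single generic component works against the shared interface, just as SortIterator does below.

#include <array>
#include <cstddef>
#include <iostream>
#include <vector>

// Static family: n is a compile-time constant, so loops over the n buffers
// can be unrolled and the pointer storage lives inline on the stack.
template <size_t n>
struct StaticRef {
  std::array<std::byte*, n> ptr;
};

// Dynamic "D" family: n is only known at runtime, so the pointers live in a
// heap-allocated vector and every loop bound is a runtime value.
struct DynamicRef {
  std::vector<std::byte*> ptr;
};

// The one piece that stays unified: it is generic over the Ref abstraction
// and relies only on the interface both families share.
template <class Ref>
size_t NumBuffers(const Ref& ref) {
  return ref.ptr.size();  // works for std::array and std::vector alike
}

int main() {
  StaticRef<2> s{};                           // n fixed at compile time
  DynamicRef d{{nullptr, nullptr, nullptr}};  // n chosen at runtime
  std::cout << NumBuffers(s) << " " << NumBuffers(d) << "\n";  // prints: 2 3
}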
4 changes: 1 addition & 3 deletions xla/backends/cpu/runtime/BUILD
@@ -950,14 +950,12 @@ xla_cc_test(
"//xla/service:buffer_assignment",
"//xla/service:maybe_owning_device_memory",
"//xla/stream_executor",
"//xla/stream_executor/host:host_kernel_c_api",
"//xla/tsl/concurrency:async_value",
"//xla/tsl/lib/core:status_test_util",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@tsl//tsl/platform:logging",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/platform:test",
"@tsl//tsl/platform:test_benchmark",
"@tsl//tsl/platform:test_main",
],
)
180 changes: 167 additions & 13 deletions xla/backends/cpu/runtime/sort_thunk.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/base/dynamic_annotations.h"
@@ -131,6 +132,7 @@ static constexpr size_t kMaxElementSize = 16;
// Forward declare reference type defined below.
template <size_t n>
struct Ref;
struct DRef;

// Value type to store values loaded from the input buffers.
template <size_t n>
@@ -145,6 +147,18 @@ struct Value {
std::array<uint8_t, n> value_sizes;
};

struct DValue {
DValue(const DRef& ref); // NOLINT

const void* compared_value(size_t i) const { return value[i].data(); }

// Use properly aligned byte array to store primitive values.
using ValueStorage = std::array<std::byte, kMaxElementSize>;
std::vector<ValueStorage> value;
std::vector<uint8_t> value_sizes;
size_t n;
};

// Reference to values stored in the input buffers.
template <size_t n>
struct Ref {
@@ -160,13 +174,36 @@ struct Ref {
std::array<uint8_t, n> ptr_sizes;
};

struct DRef {
DRef(std::vector<std::byte*> ptr, std::vector<uint8_t> ptr_sizes)
: ptr(ptr), ptr_sizes(ptr_sizes), n(ptr.size()) {}

DRef& operator=(const DValue& value);
DRef& operator=(const DRef& other);

const void* compared_value(size_t i) const { return ptr[i]; }

std::vector<std::byte*> ptr;
std::vector<uint8_t> ptr_sizes;
const size_t n;
};

template <size_t n>
Value<n>::Value(const Ref<n>& ref) : value_sizes(ref.ptr_sizes) {
for (size_t i = 0; i < n; ++i) {
std::memcpy(value[i].data(), ref.ptr[i], ref.ptr_sizes[i]);
}
}

DValue::DValue(const DRef& ref)
: value_sizes(ref.ptr_sizes), n(ref.ptr.size()) {
value.reserve(n);
for (size_t i = 0; i < n; ++i) {
value.emplace_back();
std::memcpy(value[i].data(), ref.ptr[i], ref.ptr_sizes[i]);
}
}

template <size_t n>
Ref<n>& Ref<n>::operator=(const Value<n>& value) {
DCHECK(ptr_sizes == value.value_sizes);
@@ -176,6 +213,14 @@ Ref<n>& Ref<n>::operator=(const Value<n>& value) {
return *this;
}

DRef& DRef::operator=(const DValue& value) {
DCHECK(ptr_sizes == value.value_sizes);
for (size_t i = 0; i < n; ++i) {
std::memcpy(ptr[i], value.value[i].data(), value.value_sizes[i]);
}
return *this;
}

template <size_t n>
Ref<n>& Ref<n>::operator=(const Ref<n>& other) {
DCHECK(ptr_sizes == other.ptr_sizes);
@@ -185,6 +230,15 @@ Ref<n>& Ref<n>::operator=(const Ref<n>& other) {
return *this;
}

DRef& DRef::operator=(const DRef& other) {
DCHECK(ptr_sizes == other.ptr_sizes);
const size_t n = other.ptr.size();
for (size_t i = 0; i < n; ++i) {
std::memcpy(ptr[i], other.ptr[i], other.ptr_sizes[i]);
}
return *this;
}

// Swap function required by `std::sort` and `std::stable_sort` implementations.
template <size_t n>
void swap(const Ref<n>& lhs, const Ref<n>& rhs) {
@@ -196,6 +250,17 @@ void swap(const Ref<n>& lhs, const Ref<n>& rhs) {
}
}

void swap(const DRef& lhs, const DRef& rhs) {
DCHECK(lhs.ptr_sizes == rhs.ptr_sizes);
const size_t n = lhs.ptr.size();
for (size_t i = 0; i < n; ++i) {
std::array<std::byte, kMaxElementSize> tmp;
std::memcpy(tmp.data(), lhs.ptr[i], lhs.ptr_sizes[i]);
std::memcpy(lhs.ptr[i], rhs.ptr[i], rhs.ptr_sizes[i]);
std::memcpy(rhs.ptr[i], tmp.data(), lhs.ptr_sizes[i]);
}
}

// An array of pointers to the input data.
template <size_t n>
struct Ptr {
@@ -250,19 +315,72 @@ struct Ptr {
std::array<uint8_t, n> ptr_sizes; // pointers sizes in bytes
};

struct DPtr {
using difference_type = std::ptrdiff_t;

DPtr() = default;

DPtr(std::vector<std::byte*> ptr, std::vector<uint8_t> ptr_sizes)
: ptr(ptr), ptr_sizes(ptr_sizes), n(ptr.size()) {}

DRef operator*() const { return DRef{ptr, ptr_sizes}; }

DPtr& operator+=(difference_type diff) {
for (size_t i = 0; i < n; ++i) ptr[i] += diff * ptr_sizes[i];
return *this;
}

DPtr& operator-=(difference_type diff) {
for (size_t i = 0; i < n; ++i) ptr[i] -= diff * ptr_sizes[i];
return *this;
}

DPtr operator+(difference_type diff) const {
std::vector<std::byte*> upd(n);
for (size_t i = 0; i < n; ++i) upd[i] = ptr[i] + diff * ptr_sizes[i];
return DPtr{upd, ptr_sizes};
}

DPtr operator-(difference_type diff) const {
std::vector<std::byte*> upd(n);
for (size_t i = 0; i < n; ++i) upd[i] = ptr[i] - diff * ptr_sizes[i];
return DPtr{upd, ptr_sizes};
}

// In all comparison operators defined below we use only the ptr at index 0,
// because we know that all pointers change together and this is an
// implementation detail of sort iterator.

difference_type operator-(const DPtr& rhs) const {
DCHECK(ptr_sizes == rhs.ptr_sizes);
return (ptr[0] - rhs.ptr[0]) / ptr_sizes[0];
}

bool operator==(const DPtr& rhs) const { return ptr[0] == rhs.ptr[0]; }
bool operator!=(const DPtr& rhs) const { return ptr[0] != rhs.ptr[0]; }
bool operator>(const DPtr& rhs) const { return ptr[0] > rhs.ptr[0]; }
bool operator<(const DPtr& rhs) const { return ptr[0] < rhs.ptr[0]; }
bool operator>=(const DPtr& rhs) const { return ptr[0] >= rhs.ptr[0]; }
bool operator<=(const DPtr& rhs) const { return ptr[0] <= rhs.ptr[0]; }

std::vector<std::byte*> ptr; // pointers into the input buffers
std::vector<uint8_t> ptr_sizes; // pointers sizes in bytes
size_t n;
};

// We rely on `std::sort` and `std::stable_sort` to sort the raw data. We sort
// multiple input buffers together using the same comparator function, so we
// need to provide a custom iterator that can access the data of all input
// buffers at the same time and swap elements in them.
template <size_t n>
template <class Value, class Ref, class Ptr>
class SortIterator {
public:
using iterator_category = std::random_access_iterator_tag;
using difference_type = std::ptrdiff_t;

using value_type = Value<n>;
using reference = Ref<n>;
using pointer = Ptr<n>;
using value_type = Value;
using reference = Ref;
using pointer = Ptr;

SortIterator() = default;
SortIterator(pointer ptr, difference_type stride)
@@ -388,8 +506,40 @@ static void SortInplace(const SortDims& sort_dims, int64_t offset,
return (*less_than)(data.data());
};

SortIterator<n> begin(Ptr<n>(ptr, ptr_sizes),
/*stride=*/sort_dims.inner_dim_size);
SortIterator<Value<n>, Ref<n>, Ptr<n>> begin(
Ptr<n>(ptr, ptr_sizes),
/*stride=*/sort_dims.inner_dim_size);
if (is_stable) {
std::stable_sort(begin, begin + sort_dims.sort_dim_size, compare);
} else {
std::sort(begin, begin + sort_dims.sort_dim_size, compare);
}
}

static void DSortInplace(const SortDims& sort_dims, int64_t offset,
absl::Span<se::DeviceMemoryBase> data,
absl::Span<const Shape> shapes, bool is_stable,
SortThunk::LessThan* less_than, size_t n) {
std::vector<std::byte*> ptr(n);
std::vector<uint8_t> ptr_sizes(n);

for (size_t i = 0; i < n; ++i) {
std::byte* base = reinterpret_cast<std::byte*>(data[i].opaque());
ptr_sizes[i] = primitive_util::ByteWidth(shapes[i].element_type());
ptr[i] = base + offset * ptr_sizes[i];
}

auto compare = [&](const auto& a, const auto& b) {
std::vector<const void*> data(2 * n);
for (size_t i = 0, j = 0; i < n; i += 1, j += 2) {
data[j] = a.compared_value(i);
data[j + 1] = b.compared_value(i);
}
return (*less_than)(data.data());
};

SortIterator<DValue, DRef, DPtr> begin(DPtr(ptr, ptr_sizes),
/*stride=*/sort_dims.inner_dim_size);
if (is_stable) {
std::stable_sort(begin, begin + sort_dims.sort_dim_size, compare);
} else {
@@ -416,9 +566,15 @@ static absl::Status SortInplace(absl::Span<se::DeviceMemoryBase> data,
is_stable, less_than);
};

// TODO(ezhulenev): We can replace statically known number of sorted inputs
// with a dynamic value, however statically known number of inputs allows
// compiler to generate better code. Benchmark if it really matters.
auto dsort = [&](size_t num_inputs) {
DSortInplace(sort_dims, offset, data, shapes, is_stable, less_than,
num_inputs);
};

// use "sort" for statically known number of sorted inputs (expected to be
// faster) and "dsort" for dynamically known number of sorted inputs.
// for 100 elements stable sort is 1.5 times faster than stable dsort.
// for 100 elements unstable sort is 2.47 times faster than unstable dsort.
switch (data.size()) {
case 1:
sort(std::integral_constant<size_t, 1>{});
@@ -495,11 +651,9 @@ static absl::Status SortInplace(absl::Span<se::DeviceMemoryBase> data,
case 25:
sort(std::integral_constant<size_t, 25>{});
break;
case 29:
sort(std::integral_constant<size_t, 29>{});
break;
default:
return Internal("Unsupported number of sorted inputs: %d", data.size());
dsort(data.size());
break;
}
}
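
One design note on the dispatch above: before this change, an input count without a dedicated case returned an Internal error, whereas now the dynamically sized dsort path serves as the default fallback; the statically instantiated cases are kept because, per the benchmark numbers in the comment, they are noticeably faster. A hedged sketch of that shape (SortStatic, SortDynamic, and Dispatch are hypothetical stand-ins, not the XLA functions):

#include <cstddef>
#include <iostream>

// Stands in for the templated sort(std::integral_constant<size_t, n>{}) path.
template <size_t n>
void SortStatic() {
  std::cout << "static path, n=" << n << "\n";
}

// Stands in for the new dynamically sized dsort fallback.
void SortDynamic(size_t n) {
  std::cout << "dynamic path, n=" << n << "\n";
}

void Dispatch(size_t num_inputs) {
  switch (num_inputs) {
    case 1:
      SortStatic<1>();
      break;
    case 2:
      SortStatic<2>();
      break;
    // ... one case per statically supported input count ...
    default:
      SortDynamic(num_inputs);  // generic fallback instead of an error
      break;
  }
}

int main() {
  Dispatch(2);   // takes the static path
  Dispatch(42);  // takes the dynamic path
}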
