Skip to content

Commit

Permalink
Add avx512_argselect
Browse files Browse the repository at this point in the history
  • Loading branch information
r-devulap committed Jul 18, 2023
1 parent 85f4e9c commit c1419da
Show file tree
Hide file tree
Showing 6 changed files with 678 additions and 302 deletions.
102 changes: 102 additions & 0 deletions src/avx512-64bit-argsort.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,39 @@ inline void argsort_64bit_(type_t *arr,
argsort_64bit_<vtype>(arr, arg, pivot_index, right, max_iters - 1);
}

/*
 * Quickselect-style driver working on the index array `arg` over the
 * inclusive range [left, right] of positions. Each round partitions the range
 * around a pivot and then continues only with the side that contains `pos`.
 * `max_iters` is the remaining budget of partition rounds.
 */
template <typename vtype, typename type_t>
static void argselect_64bit_(type_t *arr,
                             int64_t *arg,
                             int64_t pos,
                             int64_t left,
                             int64_t right,
                             int64_t max_iters)
{
    for (;;) {
        /*
         * Budget exhausted: partitioning isn't making progress, so hand the
         * remaining range to std_argsort.
         */
        if (max_iters <= 0) {
            std_argsort(arr, arg, left, right + 1);
            return;
        }

        /*
         * Base case: ranges of at most 64 entries are finished off by
         * argsort_64_64bit.
         */
        const int64_t span = right + 1 - left;
        if (span <= 64) {
            argsort_64_64bit<vtype>(arr, arg + left, (int32_t)span);
            return;
        }

        type_t pivot = get_pivot_64bit<vtype>(arr, arg, left, right);
        /* Seeded with the opposite extremes; the partition call receives
         * pointers to both and may update them. */
        type_t smallest = vtype::type_max();
        type_t biggest = vtype::type_min();
        int64_t pivot_index = partition_avx512_unrolled<vtype, 4>(
                arr, arg, left, right + 1, pivot, &smallest, &biggest);

        if ((pivot != smallest) && (pos < pivot_index)) {
            /* Target position lies in the lower part. */
            right = pivot_index - 1;
        }
        else if ((pivot != biggest) && (pos >= pivot_index)) {
            /* Target position lies in the upper part. */
            left = pivot_index;
        }
        else {
            /* Neither side needs further work. */
            return;
        }
        --max_iters;
    }
}

template <typename vtype, typename type_t>
bool has_nan(type_t* arr, int64_t arrsize)
{
Expand Down Expand Up @@ -310,6 +343,8 @@ bool has_nan(type_t* arr, int64_t arrsize)
return found_nan;
}


/* argsort methods for 32-bit and 64-bit dtypes */
template <typename T>
void avx512_argsort(T* arr, int64_t *arg, int64_t arrsize)
{
Expand Down Expand Up @@ -375,4 +410,71 @@ std::vector<int64_t> avx512_argsort(T* arr, int64_t arrsize)
return indices;
}

/* argselect methods for 32-bit and 64-bit dtypes */
/*
 * Generic argselect entry point: dispatches to argselect_64bit_ using
 * zmm_vector<T> over the whole array [0, arrsize - 1]. Arrays with fewer than
 * two elements are left untouched.
 */
template <typename T>
void avx512_argselect(T* arr, int64_t *arg, int64_t k, int64_t arrsize)
{
    if (arrsize <= 1) { return; }

    /* Budget of partition rounds handed to the selection routine. */
    const int64_t max_iters = 2 * static_cast<int64_t>(log2(arrsize));
    argselect_64bit_<zmm_vector<T>>(arr, arg, k, 0, arrsize - 1, max_iters);
}

/*
 * double specialization: if has_nan reports a NaN in the input, the whole
 * array is handled by std_argsort_withnan; otherwise argselect_64bit_ runs
 * with zmm_vector<double>. Arrays with fewer than two elements are left
 * untouched.
 */
template <>
void avx512_argselect(double* arr, int64_t *arg, int64_t k, int64_t arrsize)
{
    if (arrsize <= 1) { return; }

    if (has_nan<zmm_vector<double>>(arr, arrsize)) {
        /* FIXME: no need to do a full argsort */
        std_argsort_withnan(arr, arg, 0, arrsize);
        return;
    }

    const int64_t max_iters = 2 * static_cast<int64_t>(log2(arrsize));
    argselect_64bit_<zmm_vector<double>>(
            arr, arg, k, 0, arrsize - 1, max_iters);
}

/*
 * int32_t specialization: runs argselect_64bit_ with ymm_vector<int32_t>
 * over [0, arrsize - 1]. Arrays with fewer than two elements are left
 * untouched.
 */
template <>
void avx512_argselect(int32_t* arr, int64_t *arg, int64_t k, int64_t arrsize)
{
    if (arrsize <= 1) { return; }

    const int64_t max_iters = 2 * static_cast<int64_t>(log2(arrsize));
    argselect_64bit_<ymm_vector<int32_t>>(
            arr, arg, k, 0, arrsize - 1, max_iters);
}

/*
 * uint32_t specialization: runs argselect_64bit_ with ymm_vector<uint32_t>
 * over [0, arrsize - 1]. Arrays with fewer than two elements are left
 * untouched.
 */
template <>
void avx512_argselect(uint32_t* arr, int64_t *arg, int64_t k, int64_t arrsize)
{
    if (arrsize <= 1) { return; }

    const int64_t max_iters = 2 * static_cast<int64_t>(log2(arrsize));
    argselect_64bit_<ymm_vector<uint32_t>>(
            arr, arg, k, 0, arrsize - 1, max_iters);
}

/*
 * float specialization: if has_nan reports a NaN in the input, the whole
 * array is handled by std_argsort_withnan; otherwise argselect_64bit_ runs
 * with ymm_vector<float>. Arrays with fewer than two elements are left
 * untouched.
 */
template <>
void avx512_argselect(float* arr, int64_t *arg, int64_t k, int64_t arrsize)
{
    if (arrsize <= 1) { return; }

    if (has_nan<ymm_vector<float>>(arr, arrsize)) {
        /* FIXME: no need to do a full argsort */
        std_argsort_withnan(arr, arg, 0, arrsize);
        return;
    }

    const int64_t max_iters = 2 * static_cast<int64_t>(log2(arrsize));
    argselect_64bit_<ymm_vector<float>>(
            arr, arg, k, 0, arrsize - 1, max_iters);
}

/*
 * Convenience overload: builds the identity index vector 0, 1, ...,
 * arrsize - 1, runs the pointer-based avx512_argselect on it and returns the
 * resulting indices by value.
 */
template <typename T>
std::vector<int64_t> avx512_argselect(T* arr, int64_t k, int64_t arrsize)
{
    std::vector<int64_t> order(arrsize);
    int64_t next = 0;
    for (auto &slot : order) {
        slot = next++;
    }
    avx512_argselect<T>(arr, order.data(), k, arrsize);
    return order;
}

#endif // AVX512_ARGSORT_64BIT
5 changes: 5 additions & 0 deletions src/avx512-common-argsort.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ void avx512_argsort(T *arr, int64_t *arg, int64_t arrsize);
/* Argsort overload that returns the index vector by value. */
template <typename T>
std::vector<int64_t> avx512_argsort(T *arr, int64_t arrsize);

/*
 * Argselect on a caller-supplied index buffer `arg`; `k` is the position of
 * interest and `arrsize` the number of elements in `arr`.
 * NOTE(review): presumably `arg` must already hold the indices
 * 0..arrsize-1 on entry -- confirm against the callers.
 */
template <typename T>
void avx512_argselect(T *arr, int64_t *arg, int64_t k, int64_t arrsize);

/* Argselect overload that returns the index vector by value. */
template <typename T>
std::vector<int64_t> avx512_argselect(T *arr, int64_t k, int64_t arrsize);
/*
* Partition one ZMM register based on the pivot and return the index of the
* last element that is less than or equal to the pivot.
Expand Down
265 changes: 265 additions & 0 deletions tests/test-argselect.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,265 @@
/*******************************************
* * Copyright (C) 2023 Intel Corporation
* * SPDX-License-Identifier: BSD-3-Clause
* *******************************************/

/*
 * Type-parameterized GoogleTest fixture for the avx512_argselect tests. It
 * holds no state; it only gives the typed tests a suite to attach to.
 */
template <typename T>
class avx512argselect : public ::testing::Test {
};
TYPED_TEST_SUITE_P(avx512argselect);

/*
 * For a random array of 1024 elements, checks every position k: the element
 * referenced by the argselect result at k must equal the element referenced
 * by a full std_argsort at k, and the returned indices must pass
 * EXPECT_UNIQUE.
 */
TYPED_TEST_P(avx512argselect, test_random)
{
    if (!cpu_has_avx512bw()) {
        GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA";
    }

    const int arrsize = 1024;
    auto arr = get_uniform_rand_array<TypeParam>(arrsize);
    /* Reference ordering computed once; arr itself is not reordered below. */
    std::vector<int64_t> sorted_inx = std_argsort(arr);

    for (int64_t k = 0; k < arrsize; ++k) {
        std::vector<int64_t> inx
                = avx512_argselect<TypeParam>(arr.data(), k, arr.size());
        EXPECT_EQ(arr[sorted_inx[k]], arr[inx[k]])
                << "Failed at index k = " << k;
        EXPECT_UNIQUE(inx)
    }
}

//TYPED_TEST_P(avx512argselect, test_constant)
//{
// if (cpu_has_avx512bw()) {
// std::vector<int64_t> arrsizes;
// for (int64_t ii = 0; ii <= 1024; ++ii) {
// arrsizes.push_back(ii);
// }
// std::vector<TypeParam> arr;
// for (auto &size : arrsizes) {
// /* constant array */
// auto elem = get_uniform_rand_array<TypeParam>(1)[0];
// for (int64_t jj = 0; jj < size; ++jj) {
// arr.push_back(elem);
// }
// std::vector<int64_t> inx1 = std_argsort(arr);
// std::vector<int64_t> inx2
// = avx512_argsort<TypeParam>(arr.data(), arr.size());
// std::vector<TypeParam> sort1, sort2;
// for (size_t jj = 0; jj < size; ++jj) {
// sort1.push_back(arr[inx1[jj]]);
// sort2.push_back(arr[inx2[jj]]);
// }
// EXPECT_EQ(sort1, sort2) << "Array size =" << size;
// EXPECT_UNIQUE(inx2)
// arr.clear();
// }
// }
// else {
// GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA";
// }
//}
//
//TYPED_TEST_P(avx512argselect, test_small_range)
//{
// if (cpu_has_avx512bw()) {
// std::vector<int64_t> arrsizes;
// for (int64_t ii = 0; ii <= 1024; ++ii) {
// arrsizes.push_back(ii);
// }
// std::vector<TypeParam> arr;
// for (auto &size : arrsizes) {
// /* array with a smaller range of values */
// arr = get_uniform_rand_array<TypeParam>(size, 20, 1);
// std::vector<int64_t> inx1 = std_argsort(arr);
// std::vector<int64_t> inx2
// = avx512_argsort<TypeParam>(arr.data(), arr.size());
// std::vector<TypeParam> sort1, sort2;
// for (size_t jj = 0; jj < size; ++jj) {
// sort1.push_back(arr[inx1[jj]]);
// sort2.push_back(arr[inx2[jj]]);
// }
// EXPECT_EQ(sort1, sort2) << "Array size = " << size;
// EXPECT_UNIQUE(inx2)
// arr.clear();
// }
// }
// else {
// GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA";
// }
//}
//
//TYPED_TEST_P(avx512argselect, test_sorted)
//{
// if (cpu_has_avx512bw()) {
// std::vector<int64_t> arrsizes;
// for (int64_t ii = 0; ii <= 1024; ++ii) {
// arrsizes.push_back(ii);
// }
// std::vector<TypeParam> arr;
// for (auto &size : arrsizes) {
// arr = get_uniform_rand_array<TypeParam>(size);
// std::sort(arr.begin(), arr.end());
// std::vector<int64_t> inx1 = std_argsort(arr);
// std::vector<int64_t> inx2
// = avx512_argsort<TypeParam>(arr.data(), arr.size());
// std::vector<TypeParam> sort1, sort2;
// for (size_t jj = 0; jj < size; ++jj) {
// sort1.push_back(arr[inx1[jj]]);
// sort2.push_back(arr[inx2[jj]]);
// }
// EXPECT_EQ(sort1, sort2) << "Array size =" << size;
// EXPECT_UNIQUE(inx2)
// arr.clear();
// }
// }
// else {
// GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA";
// }
//}
//
//TYPED_TEST_P(avx512argselect, test_reverse)
//{
// if (cpu_has_avx512bw()) {
// std::vector<int64_t> arrsizes;
// for (int64_t ii = 0; ii <= 1024; ++ii) {
// arrsizes.push_back(ii);
// }
// std::vector<TypeParam> arr;
// for (auto &size : arrsizes) {
// arr = get_uniform_rand_array<TypeParam>(size);
// std::sort(arr.begin(), arr.end());
// std::reverse(arr.begin(), arr.end());
// std::vector<int64_t> inx1 = std_argsort(arr);
// std::vector<int64_t> inx2
// = avx512_argsort<TypeParam>(arr.data(), arr.size());
// std::vector<TypeParam> sort1, sort2;
// for (size_t jj = 0; jj < size; ++jj) {
// sort1.push_back(arr[inx1[jj]]);
// sort2.push_back(arr[inx2[jj]]);
// }
// EXPECT_EQ(sort1, sort2) << "Array size =" << size;
// EXPECT_UNIQUE(inx2)
// arr.clear();
// }
// }
// else {
// GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA";
// }
//}
//
//TYPED_TEST_P(avx512argselect, test_array_with_nan)
//{
// if (!cpu_has_avx512bw()) {
// GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA";
// }
// if (!std::is_floating_point<TypeParam>::value) {
// GTEST_SKIP() << "Skipping this test, it is meant for float/double";
// }
// std::vector<int64_t> arrsizes;
// for (int64_t ii = 2; ii <= 1024; ++ii) {
// arrsizes.push_back(ii);
// }
// std::vector<TypeParam> arr;
// for (auto &size : arrsizes) {
// arr = get_uniform_rand_array<TypeParam>(size);
// arr[0] = std::numeric_limits<TypeParam>::quiet_NaN();
// arr[1] = std::numeric_limits<TypeParam>::quiet_NaN();
// std::vector<int64_t> inx
// = avx512_argsort<TypeParam>(arr.data(), arr.size());
// std::vector<TypeParam> sort1;
// for (size_t jj = 0; jj < size; ++jj) {
// sort1.push_back(arr[inx[jj]]);
// }
// if ((!std::isnan(sort1[size - 1])) || (!std::isnan(sort1[size - 2]))) {
// FAIL() << "NAN's aren't sorted to the end";
// }
// if (!std::is_sorted(sort1.begin(), sort1.end() - 2)) {
// FAIL() << "Array isn't sorted";
// }
// EXPECT_UNIQUE(inx)
// arr.clear();
// }
//}
//
//TYPED_TEST_P(avx512argselect, test_max_value_at_end_of_array)
//{
// if (!cpu_has_avx512bw()) {
// GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA";
// }
// std::vector<int64_t> arrsizes;
// for (int64_t ii = 1; ii <= 256; ++ii) {
// arrsizes.push_back(ii);
// }
// std::vector<TypeParam> arr;
// for (auto &size : arrsizes) {
// arr = get_uniform_rand_array<TypeParam>(size);
// if (std::numeric_limits<TypeParam>::has_infinity) {
// arr[size - 1] = std::numeric_limits<TypeParam>::infinity();
// }
// else {
// arr[size - 1] = std::numeric_limits<TypeParam>::max();
// }
// std::vector<int64_t> inx = avx512_argsort(arr.data(), arr.size());
// std::vector<TypeParam> sorted;
// for (size_t jj = 0; jj < size; ++jj) {
// sorted.push_back(arr[inx[jj]]);
// }
// if (!std::is_sorted(sorted.begin(), sorted.end())) {
// EXPECT_TRUE(false) << "Array of size " << size << "is not sorted";
// }
// EXPECT_UNIQUE(inx)
// arr.clear();
// }
//}
//
//TYPED_TEST_P(avx512argselect, test_all_inf_array)
//{
// if (!cpu_has_avx512bw()) {
// GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA";
// }
// std::vector<int64_t> arrsizes;
// for (int64_t ii = 1; ii <= 256; ++ii) {
// arrsizes.push_back(ii);
// }
// std::vector<TypeParam> arr;
// for (auto &size : arrsizes) {
// arr = get_uniform_rand_array<TypeParam>(size);
// if (std::numeric_limits<TypeParam>::has_infinity) {
// for (int64_t jj = 1; jj <= size; ++jj) {
// if (rand() % 0x1) {
// arr.push_back(std::numeric_limits<TypeParam>::infinity());
// }
// }
// }
// else {
// for (int64_t jj = 1; jj <= size; ++jj) {
// if (rand() % 0x1) {
// arr.push_back(std::numeric_limits<TypeParam>::max());
// }
// }
// }
// std::vector<int64_t> inx = avx512_argsort(arr.data(), arr.size());
// std::vector<TypeParam> sorted;
// for (size_t jj = 0; jj < size; ++jj) {
// sorted.push_back(arr[inx[jj]]);
// }
// if (!std::is_sorted(sorted.begin(), sorted.end())) {
// EXPECT_TRUE(false) << "Array of size " << size << "is not sorted";
// }
// EXPECT_UNIQUE(inx)
// arr.clear();
// }
//}

/*
 * Only test_random is registered for this suite; the remaining test names
 * listed below are commented out, as are their test bodies.
 */
REGISTER_TYPED_TEST_SUITE_P(avx512argselect,
test_random);
//test_reverse,
//test_constant,
//test_sorted,
//test_small_range,
//test_all_inf_array,
//test_array_with_nan,
//test_max_value_at_end_of_array);
Loading

0 comments on commit c1419da

Please sign in to comment.