From 243fbe416c044907620f9fd7436cb9892f4d796c Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Tue, 18 Jul 2023 12:52:24 -0700 Subject: [PATCH 1/6] Add avx512_argselect --- src/avx512-64bit-argsort.hpp | 102 ++++++++++++ src/avx512-common-argsort.h | 5 + tests/test-argselect.hpp | 265 ++++++++++++++++++++++++++++++ tests/test-argsort-common.h | 27 +++ tests/test-argsort.cpp | 308 +---------------------------------- tests/test-argsort.hpp | 273 +++++++++++++++++++++++++++++++ 6 files changed, 678 insertions(+), 302 deletions(-) create mode 100644 tests/test-argselect.hpp create mode 100644 tests/test-argsort-common.h create mode 100644 tests/test-argsort.hpp diff --git a/src/avx512-64bit-argsort.hpp b/src/avx512-64bit-argsort.hpp index 80c6ce4a..bf7d8ec6 100644 --- a/src/avx512-64bit-argsort.hpp +++ b/src/avx512-64bit-argsort.hpp @@ -283,6 +283,39 @@ inline void argsort_64bit_(type_t *arr, argsort_64bit_(arr, arg, pivot_index, right, max_iters - 1); } +template +static void argselect_64bit_(type_t *arr, + int64_t *arg, + int64_t pos, + int64_t left, + int64_t right, + int64_t max_iters) +{ + /* + * Resort to std::sort if quicksort isnt making any progress + */ + if (max_iters <= 0) { + std_argsort(arr, arg, left, right + 1); + return; + } + /* + * Base case: use bitonic networks to sort arrays <= 64 + */ + if (right + 1 - left <= 64) { + argsort_64_64bit(arr, arg + left, (int32_t)(right + 1 - left)); + return; + } + type_t pivot = get_pivot_64bit(arr, arg, left, right); + type_t smallest = vtype::type_max(); + type_t biggest = vtype::type_min(); + int64_t pivot_index = partition_avx512_unrolled( + arr, arg, left, right + 1, pivot, &smallest, &biggest); + if ((pivot != smallest) && (pos < pivot_index)) + argselect_64bit_(arr, arg, pos, left, pivot_index - 1, max_iters - 1); + else if ((pivot != biggest) && (pos >= pivot_index)) + argselect_64bit_(arr, arg, pos, pivot_index, right, max_iters - 1); +} + template bool has_nan(type_t* arr, int64_t arrsize) { @@ -310,6 +343,8 @@ bool has_nan(type_t* arr, int64_t arrsize) return found_nan; } + +/* argsort methods for 32-bit and 64-bit dtypes */ template void avx512_argsort(T* arr, int64_t *arg, int64_t arrsize) { @@ -375,4 +410,71 @@ std::vector avx512_argsort(T* arr, int64_t arrsize) return indices; } +/* argselect methods for 32-bit and 64-bit dtypes */ +template +void avx512_argselect(T* arr, int64_t *arg, int64_t k, int64_t arrsize) +{ + if (arrsize > 1) { + argselect_64bit_>( + arr, arg, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); + } +} + +template <> +void avx512_argselect(double* arr, int64_t *arg, int64_t k, int64_t arrsize) +{ + if (arrsize > 1) { + if (has_nan>(arr, arrsize)) { + /* FIXME: no need to do a full argsort */ + std_argsort_withnan(arr, arg, 0, arrsize); + } + else { + argselect_64bit_>( + arr, arg, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); + } + } +} + +template <> +void avx512_argselect(int32_t* arr, int64_t *arg, int64_t k, int64_t arrsize) +{ + if (arrsize > 1) { + argselect_64bit_>( + arr, arg, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); + } +} + +template <> +void avx512_argselect(uint32_t* arr, int64_t *arg, int64_t k, int64_t arrsize) +{ + if (arrsize > 1) { + argselect_64bit_>( + arr, arg, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); + } +} + +template <> +void avx512_argselect(float* arr, int64_t *arg, int64_t k, int64_t arrsize) +{ + if (arrsize > 1) { + if (has_nan>(arr, arrsize)) { + /* FIXME: no need to do a full argsort */ + std_argsort_withnan(arr, arg, 0, arrsize); + } + else { + argselect_64bit_>( + arr, arg, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); + } + } +} + +template +std::vector avx512_argselect(T* arr, int64_t k, int64_t arrsize) +{ + std::vector indices(arrsize); + std::iota(indices.begin(), indices.end(), 0); + avx512_argselect(arr, indices.data(), k, arrsize); + return indices; +} + #endif // AVX512_ARGSORT_64BIT diff --git a/src/avx512-common-argsort.h b/src/avx512-common-argsort.h index e0dcaccc..0ae50c49 100644 --- a/src/avx512-common-argsort.h +++ b/src/avx512-common-argsort.h @@ -21,6 +21,11 @@ void avx512_argsort(T *arr, int64_t *arg, int64_t arrsize); template std::vector avx512_argsort(T *arr, int64_t arrsize); +template +void avx512_argselect(T *arr, int64_t *arg, int64_t k, int64_t arrsize); + +template +std::vector avx512_argselect(T *arr, int64_t k, int64_t arrsize); /* * Parition one ZMM register based on the pivot and returns the index of the * last element that is less than equal to the pivot. diff --git a/tests/test-argselect.hpp b/tests/test-argselect.hpp new file mode 100644 index 00000000..33dee37b --- /dev/null +++ b/tests/test-argselect.hpp @@ -0,0 +1,265 @@ +/******************************************* + * * Copyright (C) 2023 Intel Corporation + * * SPDX-License-Identifier: BSD-3-Clause + * *******************************************/ + +template +class avx512argselect : public ::testing::Test { +}; +TYPED_TEST_SUITE_P(avx512argselect); + +TYPED_TEST_P(avx512argselect, test_random) +{ + if (cpu_has_avx512bw()) { + const int arrsize = 1024; + auto arr = get_uniform_rand_array(arrsize); + std::vector kth; + for (int64_t ii = 0; ii < arrsize; ++ii) { + kth.push_back(ii); + } + std::vector sorted_inx = std_argsort(arr); + for (auto &k : kth) { + std::vector inx + = avx512_argselect(arr.data(), k, arr.size()); + EXPECT_EQ(arr[sorted_inx[k]], arr[inx[k]]) << "Failed at index k = " << k; + EXPECT_UNIQUE(inx) + } + } + else { + GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA"; + } +} + +//TYPED_TEST_P(avx512argselect, test_constant) +//{ +// if (cpu_has_avx512bw()) { +// std::vector arrsizes; +// for (int64_t ii = 0; ii <= 1024; ++ii) { +// arrsizes.push_back(ii); +// } +// std::vector arr; +// for (auto &size : arrsizes) { +// /* constant array */ +// auto elem = get_uniform_rand_array(1)[0]; +// for (int64_t jj = 0; jj < size; ++jj) { +// arr.push_back(elem); +// } +// std::vector inx1 = std_argsort(arr); +// std::vector inx2 +// = avx512_argsort(arr.data(), arr.size()); +// std::vector sort1, sort2; +// for (size_t jj = 0; jj < size; ++jj) { +// sort1.push_back(arr[inx1[jj]]); +// sort2.push_back(arr[inx2[jj]]); +// } +// EXPECT_EQ(sort1, sort2) << "Array size =" << size; +// EXPECT_UNIQUE(inx2) +// arr.clear(); +// } +// } +// else { +// GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA"; +// } +//} +// +//TYPED_TEST_P(avx512argselect, test_small_range) +//{ +// if (cpu_has_avx512bw()) { +// std::vector arrsizes; +// for (int64_t ii = 0; ii <= 1024; ++ii) { +// arrsizes.push_back(ii); +// } +// std::vector arr; +// for (auto &size : arrsizes) { +// /* array with a smaller range of values */ +// arr = get_uniform_rand_array(size, 20, 1); +// std::vector inx1 = std_argsort(arr); +// std::vector inx2 +// = avx512_argsort(arr.data(), arr.size()); +// std::vector sort1, sort2; +// for (size_t jj = 0; jj < size; ++jj) { +// sort1.push_back(arr[inx1[jj]]); +// sort2.push_back(arr[inx2[jj]]); +// } +// EXPECT_EQ(sort1, sort2) << "Array size = " << size; +// EXPECT_UNIQUE(inx2) +// arr.clear(); +// } +// } +// else { +// GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA"; +// } +//} +// +//TYPED_TEST_P(avx512argselect, test_sorted) +//{ +// if (cpu_has_avx512bw()) { +// std::vector arrsizes; +// for (int64_t ii = 0; ii <= 1024; ++ii) { +// arrsizes.push_back(ii); +// } +// std::vector arr; +// for (auto &size : arrsizes) { +// arr = get_uniform_rand_array(size); +// std::sort(arr.begin(), arr.end()); +// std::vector inx1 = std_argsort(arr); +// std::vector inx2 +// = avx512_argsort(arr.data(), arr.size()); +// std::vector sort1, sort2; +// for (size_t jj = 0; jj < size; ++jj) { +// sort1.push_back(arr[inx1[jj]]); +// sort2.push_back(arr[inx2[jj]]); +// } +// EXPECT_EQ(sort1, sort2) << "Array size =" << size; +// EXPECT_UNIQUE(inx2) +// arr.clear(); +// } +// } +// else { +// GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA"; +// } +//} +// +//TYPED_TEST_P(avx512argselect, test_reverse) +//{ +// if (cpu_has_avx512bw()) { +// std::vector arrsizes; +// for (int64_t ii = 0; ii <= 1024; ++ii) { +// arrsizes.push_back(ii); +// } +// std::vector arr; +// for (auto &size : arrsizes) { +// arr = get_uniform_rand_array(size); +// std::sort(arr.begin(), arr.end()); +// std::reverse(arr.begin(), arr.end()); +// std::vector inx1 = std_argsort(arr); +// std::vector inx2 +// = avx512_argsort(arr.data(), arr.size()); +// std::vector sort1, sort2; +// for (size_t jj = 0; jj < size; ++jj) { +// sort1.push_back(arr[inx1[jj]]); +// sort2.push_back(arr[inx2[jj]]); +// } +// EXPECT_EQ(sort1, sort2) << "Array size =" << size; +// EXPECT_UNIQUE(inx2) +// arr.clear(); +// } +// } +// else { +// GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA"; +// } +//} +// +//TYPED_TEST_P(avx512argselect, test_array_with_nan) +//{ +// if (!cpu_has_avx512bw()) { +// GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA"; +// } +// if (!std::is_floating_point::value) { +// GTEST_SKIP() << "Skipping this test, it is meant for float/double"; +// } +// std::vector arrsizes; +// for (int64_t ii = 2; ii <= 1024; ++ii) { +// arrsizes.push_back(ii); +// } +// std::vector arr; +// for (auto &size : arrsizes) { +// arr = get_uniform_rand_array(size); +// arr[0] = std::numeric_limits::quiet_NaN(); +// arr[1] = std::numeric_limits::quiet_NaN(); +// std::vector inx +// = avx512_argsort(arr.data(), arr.size()); +// std::vector sort1; +// for (size_t jj = 0; jj < size; ++jj) { +// sort1.push_back(arr[inx[jj]]); +// } +// if ((!std::isnan(sort1[size - 1])) || (!std::isnan(sort1[size - 2]))) { +// FAIL() << "NAN's aren't sorted to the end"; +// } +// if (!std::is_sorted(sort1.begin(), sort1.end() - 2)) { +// FAIL() << "Array isn't sorted"; +// } +// EXPECT_UNIQUE(inx) +// arr.clear(); +// } +//} +// +//TYPED_TEST_P(avx512argselect, test_max_value_at_end_of_array) +//{ +// if (!cpu_has_avx512bw()) { +// GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA"; +// } +// std::vector arrsizes; +// for (int64_t ii = 1; ii <= 256; ++ii) { +// arrsizes.push_back(ii); +// } +// std::vector arr; +// for (auto &size : arrsizes) { +// arr = get_uniform_rand_array(size); +// if (std::numeric_limits::has_infinity) { +// arr[size - 1] = std::numeric_limits::infinity(); +// } +// else { +// arr[size - 1] = std::numeric_limits::max(); +// } +// std::vector inx = avx512_argsort(arr.data(), arr.size()); +// std::vector sorted; +// for (size_t jj = 0; jj < size; ++jj) { +// sorted.push_back(arr[inx[jj]]); +// } +// if (!std::is_sorted(sorted.begin(), sorted.end())) { +// EXPECT_TRUE(false) << "Array of size " << size << "is not sorted"; +// } +// EXPECT_UNIQUE(inx) +// arr.clear(); +// } +//} +// +//TYPED_TEST_P(avx512argselect, test_all_inf_array) +//{ +// if (!cpu_has_avx512bw()) { +// GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA"; +// } +// std::vector arrsizes; +// for (int64_t ii = 1; ii <= 256; ++ii) { +// arrsizes.push_back(ii); +// } +// std::vector arr; +// for (auto &size : arrsizes) { +// arr = get_uniform_rand_array(size); +// if (std::numeric_limits::has_infinity) { +// for (int64_t jj = 1; jj <= size; ++jj) { +// if (rand() % 0x1) { +// arr.push_back(std::numeric_limits::infinity()); +// } +// } +// } +// else { +// for (int64_t jj = 1; jj <= size; ++jj) { +// if (rand() % 0x1) { +// arr.push_back(std::numeric_limits::max()); +// } +// } +// } +// std::vector inx = avx512_argsort(arr.data(), arr.size()); +// std::vector sorted; +// for (size_t jj = 0; jj < size; ++jj) { +// sorted.push_back(arr[inx[jj]]); +// } +// if (!std::is_sorted(sorted.begin(), sorted.end())) { +// EXPECT_TRUE(false) << "Array of size " << size << "is not sorted"; +// } +// EXPECT_UNIQUE(inx) +// arr.clear(); +// } +//} + +REGISTER_TYPED_TEST_SUITE_P(avx512argselect, + test_random); + //test_reverse, + //test_constant, + //test_sorted, + //test_small_range, + //test_all_inf_array, + //test_array_with_nan, + //test_max_value_at_end_of_array); diff --git a/tests/test-argsort-common.h b/tests/test-argsort-common.h new file mode 100644 index 00000000..a21110a5 --- /dev/null +++ b/tests/test-argsort-common.h @@ -0,0 +1,27 @@ +#include +#include +#include +#include "cpuinfo.h" +#include "rand_array.h" +#include "avx512-64bit-argsort.hpp" + +template +std::vector std_argsort(const std::vector &array) +{ + std::vector indices(array.size()); + std::iota(indices.begin(), indices.end(), 0); + std::sort(indices.begin(), + indices.end(), + [&array](int left, int right) -> bool { + // sort indices according to corresponding array sizeent + return array[left] < array[right]; + }); + + return indices; +} + +#define EXPECT_UNIQUE(sorted_arg) \ + std::sort(sorted_arg.begin(), sorted_arg.end()); \ + std::vector expected_arg(sorted_arg.size()); \ + std::iota(expected_arg.begin(), expected_arg.end(), 0); \ + EXPECT_EQ(sorted_arg, expected_arg) << "Indices aren't unique. Array size = " << sorted_arg.size(); diff --git a/tests/test-argsort.cpp b/tests/test-argsort.cpp index 8048d751..b4d2b9e2 100644 --- a/tests/test-argsort.cpp +++ b/tests/test-argsort.cpp @@ -1,305 +1,9 @@ -/******************************************* - * * Copyright (C) 2023 Intel Corporation - * * SPDX-License-Identifier: BSD-3-Clause - * *******************************************/ +#include "test-argsort-common.h" +#include "test-argselect.hpp" +#include "test-argsort.hpp" -#include "avx512-64bit-argsort.hpp" -#include "cpuinfo.h" -#include "rand_array.h" -#include -#include -#include - -template -class avx512argsort : public ::testing::Test { -}; -TYPED_TEST_SUITE_P(avx512argsort); - -template -std::vector std_argsort(const std::vector &array) -{ - std::vector indices(array.size()); - std::iota(indices.begin(), indices.end(), 0); - std::sort(indices.begin(), - indices.end(), - [&array](int left, int right) -> bool { - // sort indices according to corresponding array sizeent - return array[left] < array[right]; - }); - - return indices; -} - -#define EXPECT_UNIQUE(sorted_arg) \ - std::sort(sorted_arg.begin(), sorted_arg.end()); \ - std::vector expected_arg(sorted_arg.size()); \ - std::iota(expected_arg.begin(), expected_arg.end(), 0); \ - EXPECT_EQ(sorted_arg, expected_arg) << "Indices aren't unique. Array size = " << sorted_arg.size(); - -TYPED_TEST_P(avx512argsort, test_random) -{ - if (cpu_has_avx512bw()) { - std::vector arrsizes; - for (int64_t ii = 0; ii <= 1024; ++ii) { - arrsizes.push_back(ii); - } - std::vector arr; - for (auto &size : arrsizes) { - /* Random array */ - arr = get_uniform_rand_array(size); - std::vector inx1 = std_argsort(arr); - std::vector inx2 - = avx512_argsort(arr.data(), arr.size()); - std::vector sort1, sort2; - for (size_t jj = 0; jj < size; ++jj) { - sort1.push_back(arr[inx1[jj]]); - sort2.push_back(arr[inx2[jj]]); - } - EXPECT_EQ(sort1, sort2) << "Array size =" << size; - EXPECT_UNIQUE(inx2) - arr.clear(); - } - } - else { - GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA"; - } -} - -TYPED_TEST_P(avx512argsort, test_constant) -{ - if (cpu_has_avx512bw()) { - std::vector arrsizes; - for (int64_t ii = 0; ii <= 1024; ++ii) { - arrsizes.push_back(ii); - } - std::vector arr; - for (auto &size : arrsizes) { - /* constant array */ - auto elem = get_uniform_rand_array(1)[0]; - for (int64_t jj = 0; jj < size; ++jj) { - arr.push_back(elem); - } - std::vector inx1 = std_argsort(arr); - std::vector inx2 - = avx512_argsort(arr.data(), arr.size()); - std::vector sort1, sort2; - for (size_t jj = 0; jj < size; ++jj) { - sort1.push_back(arr[inx1[jj]]); - sort2.push_back(arr[inx2[jj]]); - } - EXPECT_EQ(sort1, sort2) << "Array size =" << size; - EXPECT_UNIQUE(inx2) - arr.clear(); - } - } - else { - GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA"; - } -} - -TYPED_TEST_P(avx512argsort, test_small_range) -{ - if (cpu_has_avx512bw()) { - std::vector arrsizes; - for (int64_t ii = 0; ii <= 1024; ++ii) { - arrsizes.push_back(ii); - } - std::vector arr; - for (auto &size : arrsizes) { - /* array with a smaller range of values */ - arr = get_uniform_rand_array(size, 20, 1); - std::vector inx1 = std_argsort(arr); - std::vector inx2 - = avx512_argsort(arr.data(), arr.size()); - std::vector sort1, sort2; - for (size_t jj = 0; jj < size; ++jj) { - sort1.push_back(arr[inx1[jj]]); - sort2.push_back(arr[inx2[jj]]); - } - EXPECT_EQ(sort1, sort2) << "Array size = " << size; - EXPECT_UNIQUE(inx2) - arr.clear(); - } - } - else { - GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA"; - } -} - -TYPED_TEST_P(avx512argsort, test_sorted) -{ - if (cpu_has_avx512bw()) { - std::vector arrsizes; - for (int64_t ii = 0; ii <= 1024; ++ii) { - arrsizes.push_back(ii); - } - std::vector arr; - for (auto &size : arrsizes) { - arr = get_uniform_rand_array(size); - std::sort(arr.begin(), arr.end()); - std::vector inx1 = std_argsort(arr); - std::vector inx2 - = avx512_argsort(arr.data(), arr.size()); - std::vector sort1, sort2; - for (size_t jj = 0; jj < size; ++jj) { - sort1.push_back(arr[inx1[jj]]); - sort2.push_back(arr[inx2[jj]]); - } - EXPECT_EQ(sort1, sort2) << "Array size =" << size; - EXPECT_UNIQUE(inx2) - arr.clear(); - } - } - else { - GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA"; - } -} - -TYPED_TEST_P(avx512argsort, test_reverse) -{ - if (cpu_has_avx512bw()) { - std::vector arrsizes; - for (int64_t ii = 0; ii <= 1024; ++ii) { - arrsizes.push_back(ii); - } - std::vector arr; - for (auto &size : arrsizes) { - arr = get_uniform_rand_array(size); - std::sort(arr.begin(), arr.end()); - std::reverse(arr.begin(), arr.end()); - std::vector inx1 = std_argsort(arr); - std::vector inx2 - = avx512_argsort(arr.data(), arr.size()); - std::vector sort1, sort2; - for (size_t jj = 0; jj < size; ++jj) { - sort1.push_back(arr[inx1[jj]]); - sort2.push_back(arr[inx2[jj]]); - } - EXPECT_EQ(sort1, sort2) << "Array size =" << size; - EXPECT_UNIQUE(inx2) - arr.clear(); - } - } - else { - GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA"; - } -} - -TYPED_TEST_P(avx512argsort, test_array_with_nan) -{ - if (!cpu_has_avx512bw()) { - GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA"; - } - if (!std::is_floating_point::value) { - GTEST_SKIP() << "Skipping this test, it is meant for float/double"; - } - std::vector arrsizes; - for (int64_t ii = 2; ii <= 1024; ++ii) { - arrsizes.push_back(ii); - } - std::vector arr; - for (auto &size : arrsizes) { - arr = get_uniform_rand_array(size); - arr[0] = std::numeric_limits::quiet_NaN(); - arr[1] = std::numeric_limits::quiet_NaN(); - std::vector inx - = avx512_argsort(arr.data(), arr.size()); - std::vector sort1; - for (size_t jj = 0; jj < size; ++jj) { - sort1.push_back(arr[inx[jj]]); - } - if ((!std::isnan(sort1[size - 1])) || (!std::isnan(sort1[size - 2]))) { - FAIL() << "NAN's aren't sorted to the end"; - } - if (!std::is_sorted(sort1.begin(), sort1.end() - 2)) { - FAIL() << "Array isn't sorted"; - } - EXPECT_UNIQUE(inx) - arr.clear(); - } -} - -TYPED_TEST_P(avx512argsort, test_max_value_at_end_of_array) -{ - if (!cpu_has_avx512bw()) { - GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA"; - } - std::vector arrsizes; - for (int64_t ii = 1; ii <= 256; ++ii) { - arrsizes.push_back(ii); - } - std::vector arr; - for (auto &size : arrsizes) { - arr = get_uniform_rand_array(size); - if (std::numeric_limits::has_infinity) { - arr[size - 1] = std::numeric_limits::infinity(); - } - else { - arr[size - 1] = std::numeric_limits::max(); - } - std::vector inx = avx512_argsort(arr.data(), arr.size()); - std::vector sorted; - for (size_t jj = 0; jj < size; ++jj) { - sorted.push_back(arr[inx[jj]]); - } - if (!std::is_sorted(sorted.begin(), sorted.end())) { - EXPECT_TRUE(false) << "Array of size " << size << "is not sorted"; - } - EXPECT_UNIQUE(inx) - arr.clear(); - } -} - -TYPED_TEST_P(avx512argsort, test_all_inf_array) -{ - if (!cpu_has_avx512bw()) { - GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA"; - } - std::vector arrsizes; - for (int64_t ii = 1; ii <= 256; ++ii) { - arrsizes.push_back(ii); - } - std::vector arr; - for (auto &size : arrsizes) { - arr = get_uniform_rand_array(size); - if (std::numeric_limits::has_infinity) { - for (int64_t jj = 1; jj <= size; ++jj) { - if (rand() % 0x1) { - arr.push_back(std::numeric_limits::infinity()); - } - } - } - else { - for (int64_t jj = 1; jj <= size; ++jj) { - if (rand() % 0x1) { - arr.push_back(std::numeric_limits::max()); - } - } - } - std::vector inx = avx512_argsort(arr.data(), arr.size()); - std::vector sorted; - for (size_t jj = 0; jj < size; ++jj) { - sorted.push_back(arr[inx[jj]]); - } - if (!std::is_sorted(sorted.begin(), sorted.end())) { - EXPECT_TRUE(false) << "Array of size " << size << "is not sorted"; - } - EXPECT_UNIQUE(inx) - arr.clear(); - } -} - -REGISTER_TYPED_TEST_SUITE_P(avx512argsort, - test_random, - test_reverse, - test_constant, - test_sorted, - test_small_range, - test_all_inf_array, - test_array_with_nan, - test_max_value_at_end_of_array); - -using ArgSortTestTypes +using ArgTestTypes = testing::Types; -INSTANTIATE_TYPED_TEST_SUITE_P(T, avx512argsort, ArgSortTestTypes); +INSTANTIATE_TYPED_TEST_SUITE_P(T, avx512argsort, ArgTestTypes); +INSTANTIATE_TYPED_TEST_SUITE_P(T, avx512argselect, ArgTestTypes); diff --git a/tests/test-argsort.hpp b/tests/test-argsort.hpp new file mode 100644 index 00000000..5ef9e6ea --- /dev/null +++ b/tests/test-argsort.hpp @@ -0,0 +1,273 @@ +/******************************************* + * * Copyright (C) 2023 Intel Corporation + * * SPDX-License-Identifier: BSD-3-Clause + * *******************************************/ + +template +class avx512argsort : public ::testing::Test { +}; +TYPED_TEST_SUITE_P(avx512argsort); + +TYPED_TEST_P(avx512argsort, test_random) +{ + if (cpu_has_avx512bw()) { + std::vector arrsizes; + for (int64_t ii = 0; ii <= 1024; ++ii) { + arrsizes.push_back(ii); + } + std::vector arr; + for (auto &size : arrsizes) { + /* Random array */ + arr = get_uniform_rand_array(size); + std::vector inx1 = std_argsort(arr); + std::vector inx2 + = avx512_argsort(arr.data(), arr.size()); + std::vector sort1, sort2; + for (size_t jj = 0; jj < size; ++jj) { + sort1.push_back(arr[inx1[jj]]); + sort2.push_back(arr[inx2[jj]]); + } + EXPECT_EQ(sort1, sort2) << "Array size =" << size; + EXPECT_UNIQUE(inx2) + arr.clear(); + } + } + else { + GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA"; + } +} + +TYPED_TEST_P(avx512argsort, test_constant) +{ + if (cpu_has_avx512bw()) { + std::vector arrsizes; + for (int64_t ii = 0; ii <= 1024; ++ii) { + arrsizes.push_back(ii); + } + std::vector arr; + for (auto &size : arrsizes) { + /* constant array */ + auto elem = get_uniform_rand_array(1)[0]; + for (int64_t jj = 0; jj < size; ++jj) { + arr.push_back(elem); + } + std::vector inx1 = std_argsort(arr); + std::vector inx2 + = avx512_argsort(arr.data(), arr.size()); + std::vector sort1, sort2; + for (size_t jj = 0; jj < size; ++jj) { + sort1.push_back(arr[inx1[jj]]); + sort2.push_back(arr[inx2[jj]]); + } + EXPECT_EQ(sort1, sort2) << "Array size =" << size; + EXPECT_UNIQUE(inx2) + arr.clear(); + } + } + else { + GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA"; + } +} + +TYPED_TEST_P(avx512argsort, test_small_range) +{ + if (cpu_has_avx512bw()) { + std::vector arrsizes; + for (int64_t ii = 0; ii <= 1024; ++ii) { + arrsizes.push_back(ii); + } + std::vector arr; + for (auto &size : arrsizes) { + /* array with a smaller range of values */ + arr = get_uniform_rand_array(size, 20, 1); + std::vector inx1 = std_argsort(arr); + std::vector inx2 + = avx512_argsort(arr.data(), arr.size()); + std::vector sort1, sort2; + for (size_t jj = 0; jj < size; ++jj) { + sort1.push_back(arr[inx1[jj]]); + sort2.push_back(arr[inx2[jj]]); + } + EXPECT_EQ(sort1, sort2) << "Array size = " << size; + EXPECT_UNIQUE(inx2) + arr.clear(); + } + } + else { + GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA"; + } +} + +TYPED_TEST_P(avx512argsort, test_sorted) +{ + if (cpu_has_avx512bw()) { + std::vector arrsizes; + for (int64_t ii = 0; ii <= 1024; ++ii) { + arrsizes.push_back(ii); + } + std::vector arr; + for (auto &size : arrsizes) { + arr = get_uniform_rand_array(size); + std::sort(arr.begin(), arr.end()); + std::vector inx1 = std_argsort(arr); + std::vector inx2 + = avx512_argsort(arr.data(), arr.size()); + std::vector sort1, sort2; + for (size_t jj = 0; jj < size; ++jj) { + sort1.push_back(arr[inx1[jj]]); + sort2.push_back(arr[inx2[jj]]); + } + EXPECT_EQ(sort1, sort2) << "Array size =" << size; + EXPECT_UNIQUE(inx2) + arr.clear(); + } + } + else { + GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA"; + } +} + +TYPED_TEST_P(avx512argsort, test_reverse) +{ + if (cpu_has_avx512bw()) { + std::vector arrsizes; + for (int64_t ii = 0; ii <= 1024; ++ii) { + arrsizes.push_back(ii); + } + std::vector arr; + for (auto &size : arrsizes) { + arr = get_uniform_rand_array(size); + std::sort(arr.begin(), arr.end()); + std::reverse(arr.begin(), arr.end()); + std::vector inx1 = std_argsort(arr); + std::vector inx2 + = avx512_argsort(arr.data(), arr.size()); + std::vector sort1, sort2; + for (size_t jj = 0; jj < size; ++jj) { + sort1.push_back(arr[inx1[jj]]); + sort2.push_back(arr[inx2[jj]]); + } + EXPECT_EQ(sort1, sort2) << "Array size =" << size; + EXPECT_UNIQUE(inx2) + arr.clear(); + } + } + else { + GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA"; + } +} + +TYPED_TEST_P(avx512argsort, test_array_with_nan) +{ + if (!cpu_has_avx512bw()) { + GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA"; + } + if (!std::is_floating_point::value) { + GTEST_SKIP() << "Skipping this test, it is meant for float/double"; + } + std::vector arrsizes; + for (int64_t ii = 2; ii <= 1024; ++ii) { + arrsizes.push_back(ii); + } + std::vector arr; + for (auto &size : arrsizes) { + arr = get_uniform_rand_array(size); + arr[0] = std::numeric_limits::quiet_NaN(); + arr[1] = std::numeric_limits::quiet_NaN(); + std::vector inx + = avx512_argsort(arr.data(), arr.size()); + std::vector sort1; + for (size_t jj = 0; jj < size; ++jj) { + sort1.push_back(arr[inx[jj]]); + } + if ((!std::isnan(sort1[size - 1])) || (!std::isnan(sort1[size - 2]))) { + FAIL() << "NAN's aren't sorted to the end"; + } + if (!std::is_sorted(sort1.begin(), sort1.end() - 2)) { + FAIL() << "Array isn't sorted"; + } + EXPECT_UNIQUE(inx) + arr.clear(); + } +} + +TYPED_TEST_P(avx512argsort, test_max_value_at_end_of_array) +{ + if (!cpu_has_avx512bw()) { + GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA"; + } + std::vector arrsizes; + for (int64_t ii = 1; ii <= 256; ++ii) { + arrsizes.push_back(ii); + } + std::vector arr; + for (auto &size : arrsizes) { + arr = get_uniform_rand_array(size); + if (std::numeric_limits::has_infinity) { + arr[size - 1] = std::numeric_limits::infinity(); + } + else { + arr[size - 1] = std::numeric_limits::max(); + } + std::vector inx = avx512_argsort(arr.data(), arr.size()); + std::vector sorted; + for (size_t jj = 0; jj < size; ++jj) { + sorted.push_back(arr[inx[jj]]); + } + if (!std::is_sorted(sorted.begin(), sorted.end())) { + EXPECT_TRUE(false) << "Array of size " << size << "is not sorted"; + } + EXPECT_UNIQUE(inx) + arr.clear(); + } +} + +TYPED_TEST_P(avx512argsort, test_all_inf_array) +{ + if (!cpu_has_avx512bw()) { + GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA"; + } + std::vector arrsizes; + for (int64_t ii = 1; ii <= 256; ++ii) { + arrsizes.push_back(ii); + } + std::vector arr; + for (auto &size : arrsizes) { + arr = get_uniform_rand_array(size); + if (std::numeric_limits::has_infinity) { + for (int64_t jj = 1; jj <= size; ++jj) { + if (rand() % 0x1) { + arr.push_back(std::numeric_limits::infinity()); + } + } + } + else { + for (int64_t jj = 1; jj <= size; ++jj) { + if (rand() % 0x1) { + arr.push_back(std::numeric_limits::max()); + } + } + } + std::vector inx = avx512_argsort(arr.data(), arr.size()); + std::vector sorted; + for (size_t jj = 0; jj < size; ++jj) { + sorted.push_back(arr[inx[jj]]); + } + if (!std::is_sorted(sorted.begin(), sorted.end())) { + EXPECT_TRUE(false) << "Array of size " << size << "is not sorted"; + } + EXPECT_UNIQUE(inx) + arr.clear(); + } +} + +REGISTER_TYPED_TEST_SUITE_P(avx512argsort, + test_random, + test_reverse, + test_constant, + test_sorted, + test_small_range, + test_all_inf_array, + test_array_with_nan, + test_max_value_at_end_of_array); + From 608538dd639b977a0d1fd54bc0b0cb517558bc7d Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Tue, 18 Jul 2023 14:53:43 -0700 Subject: [PATCH 2/6] More test condition for argselect --- tests/test-argselect.hpp | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/tests/test-argselect.hpp b/tests/test-argselect.hpp index 33dee37b..2b907c61 100644 --- a/tests/test-argselect.hpp +++ b/tests/test-argselect.hpp @@ -8,6 +8,27 @@ class avx512argselect : public ::testing::Test { }; TYPED_TEST_SUITE_P(avx512argselect); + +template +T std_min_element(std::vector arr, std::vector arg, int64_t left, int64_t right) +{ + std::vector::iterator res = + std::min_element(arg.begin() + left, + arg.begin() + right, + [arr](int64_t a, int64_t b) -> bool {return arr[a] < arr[b];}); + return arr[*res]; +} + +template +T std_max_element(std::vector arr, std::vector arg, int64_t left, int64_t right) +{ + std::vector::iterator res = + std::max_element(arg.begin() + left, + arg.begin() + right, + [arr](int64_t a, int64_t b) -> bool {return arr[a] > arr[b];}); + return arr[*res]; +} + TYPED_TEST_P(avx512argselect, test_random) { if (cpu_has_avx512bw()) { @@ -21,7 +42,12 @@ TYPED_TEST_P(avx512argselect, test_random) for (auto &k : kth) { std::vector inx = avx512_argselect(arr.data(), k, arr.size()); - EXPECT_EQ(arr[sorted_inx[k]], arr[inx[k]]) << "Failed at index k = " << k; + auto true_kth = arr[sorted_inx[k]]; + EXPECT_EQ(true_kth, arr[inx[k]]) << "Failed at index k = " << k; + if (k >= 1) + EXPECT_GE(true_kth, std_max_element(arr, inx, 0, k-1)); + if (k != arrsize-1) + EXPECT_LE(true_kth, std_min_element(arr, inx, k+1, arrsize-1)); EXPECT_UNIQUE(inx) } } From 8bd9c42a30f633e8713da50dbf70de5ee5109146 Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Thu, 20 Jul 2023 11:03:38 -0700 Subject: [PATCH 3/6] Use std_argselect_withnan --- src/avx512-64bit-argsort.hpp | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/src/avx512-64bit-argsort.hpp b/src/avx512-64bit-argsort.hpp index bf7d8ec6..8162c453 100644 --- a/src/avx512-64bit-argsort.hpp +++ b/src/avx512-64bit-argsort.hpp @@ -11,6 +11,20 @@ #include "avx512-common-argsort.h" #include "avx512-64bit-keyvalue-networks.hpp" +template +void std_argselect_withnan(T *arr, int64_t *arg, int64_t k, int64_t left, int64_t right) +{ + std::nth_element(arg + left, + arg + k, + arg + right, + [arr](int64_t a, int64_t b) -> bool { + if ((!std::isnan(arr[a])) && (!std::isnan(arr[b]))) {return arr[a] < arr[b];} + else if (std::isnan(arr[a])) {return false;} + else {return true;} + }); +} + + /* argsort using std::sort */ template void std_argsort_withnan(T *arr, int64_t *arg, int64_t left, int64_t right) @@ -425,8 +439,7 @@ void avx512_argselect(double* arr, int64_t *arg, int64_t k, int64_t arrsize) { if (arrsize > 1) { if (has_nan>(arr, arrsize)) { - /* FIXME: no need to do a full argsort */ - std_argsort_withnan(arr, arg, 0, arrsize); + std_argselect_withnan(arr, arg, 0, arrsize); } else { argselect_64bit_>( @@ -458,8 +471,7 @@ void avx512_argselect(float* arr, int64_t *arg, int64_t k, int64_t arrsize) { if (arrsize > 1) { if (has_nan>(arr, arrsize)) { - /* FIXME: no need to do a full argsort */ - std_argsort_withnan(arr, arg, 0, arrsize); + std_argselect_withnan(arr, arg, 0, arrsize); } else { argselect_64bit_>( From 437fd4abdd8ad491d373f117d580ad5261d2aecf Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Thu, 20 Jul 2023 14:32:38 -0700 Subject: [PATCH 4/6] Add tests for arrays with NAN --- src/avx512-64bit-argsort.hpp | 4 +- tests/test-argselect.hpp | 259 +++-------------------------------- tests/test-argsort-common.h | 13 +- 3 files changed, 29 insertions(+), 247 deletions(-) diff --git a/src/avx512-64bit-argsort.hpp b/src/avx512-64bit-argsort.hpp index 8162c453..64e7cfdf 100644 --- a/src/avx512-64bit-argsort.hpp +++ b/src/avx512-64bit-argsort.hpp @@ -439,7 +439,7 @@ void avx512_argselect(double* arr, int64_t *arg, int64_t k, int64_t arrsize) { if (arrsize > 1) { if (has_nan>(arr, arrsize)) { - std_argselect_withnan(arr, arg, 0, arrsize); + std_argselect_withnan(arr, arg, k, 0, arrsize); } else { argselect_64bit_>( @@ -471,7 +471,7 @@ void avx512_argselect(float* arr, int64_t *arg, int64_t k, int64_t arrsize) { if (arrsize > 1) { if (has_nan>(arr, arrsize)) { - std_argselect_withnan(arr, arg, 0, arrsize); + std_argselect_withnan(arr, arg, k, 0, arrsize); } else { argselect_64bit_>( diff --git a/tests/test-argselect.hpp b/tests/test-argselect.hpp index 2b907c61..23375b54 100644 --- a/tests/test-argselect.hpp +++ b/tests/test-argselect.hpp @@ -15,7 +15,11 @@ T std_min_element(std::vector arr, std::vector arg, int64_t left, in std::vector::iterator res = std::min_element(arg.begin() + left, arg.begin() + right, - [arr](int64_t a, int64_t b) -> bool {return arr[a] < arr[b];}); + [arr](int64_t a, int64_t b) -> bool { + if ((!std::isnan(arr[a])) && (!std::isnan(arr[b]))) {return arr[a] < arr[b];} + else if (std::isnan(arr[a])) {return false;} + else {return true;} + }); return arr[*res]; } @@ -25,7 +29,11 @@ T std_max_element(std::vector arr, std::vector arg, int64_t left, in std::vector::iterator res = std::max_element(arg.begin() + left, arg.begin() + right, - [arr](int64_t a, int64_t b) -> bool {return arr[a] > arr[b];}); + [arr](int64_t a, int64_t b) -> bool { + if ((!std::isnan(arr[a])) && (!std::isnan(arr[b]))) {return arr[a] > arr[b];} + else if (std::isnan(arr[a])) {return true;} + else {return false;} + }); return arr[*res]; } @@ -34,20 +42,25 @@ TYPED_TEST_P(avx512argselect, test_random) if (cpu_has_avx512bw()) { const int arrsize = 1024; auto arr = get_uniform_rand_array(arrsize); + std::vector sorted_inx; + if (std::is_floating_point::value) { + arr[0] = std::numeric_limits::quiet_NaN(); + arr[1] = std::numeric_limits::quiet_NaN(); + } + sorted_inx = std_argsort(arr); std::vector kth; - for (int64_t ii = 0; ii < arrsize; ++ii) { + for (int64_t ii = 0; ii < arrsize-3; ++ii) { kth.push_back(ii); } - std::vector sorted_inx = std_argsort(arr); for (auto &k : kth) { std::vector inx = avx512_argselect(arr.data(), k, arr.size()); auto true_kth = arr[sorted_inx[k]]; EXPECT_EQ(true_kth, arr[inx[k]]) << "Failed at index k = " << k; if (k >= 1) - EXPECT_GE(true_kth, std_max_element(arr, inx, 0, k-1)); + EXPECT_GE(true_kth, std_max_element(arr, inx, 0, k-1)) << "failed at k = " << k; if (k != arrsize-1) - EXPECT_LE(true_kth, std_min_element(arr, inx, k+1, arrsize-1)); + EXPECT_LE(true_kth, std_min_element(arr, inx, k+1, arrsize-1)) << "failed at k = " << k; EXPECT_UNIQUE(inx) } } @@ -56,236 +69,4 @@ TYPED_TEST_P(avx512argselect, test_random) } } -//TYPED_TEST_P(avx512argselect, test_constant) -//{ -// if (cpu_has_avx512bw()) { -// std::vector arrsizes; -// for (int64_t ii = 0; ii <= 1024; ++ii) { -// arrsizes.push_back(ii); -// } -// std::vector arr; -// for (auto &size : arrsizes) { -// /* constant array */ -// auto elem = get_uniform_rand_array(1)[0]; -// for (int64_t jj = 0; jj < size; ++jj) { -// arr.push_back(elem); -// } -// std::vector inx1 = std_argsort(arr); -// std::vector inx2 -// = avx512_argsort(arr.data(), arr.size()); -// std::vector sort1, sort2; -// for (size_t jj = 0; jj < size; ++jj) { -// sort1.push_back(arr[inx1[jj]]); -// sort2.push_back(arr[inx2[jj]]); -// } -// EXPECT_EQ(sort1, sort2) << "Array size =" << size; -// EXPECT_UNIQUE(inx2) -// arr.clear(); -// } -// } -// else { -// GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA"; -// } -//} -// -//TYPED_TEST_P(avx512argselect, test_small_range) -//{ -// if (cpu_has_avx512bw()) { -// std::vector arrsizes; -// for (int64_t ii = 0; ii <= 1024; ++ii) { -// arrsizes.push_back(ii); -// } -// std::vector arr; -// for (auto &size : arrsizes) { -// /* array with a smaller range of values */ -// arr = get_uniform_rand_array(size, 20, 1); -// std::vector inx1 = std_argsort(arr); -// std::vector inx2 -// = avx512_argsort(arr.data(), arr.size()); -// std::vector sort1, sort2; -// for (size_t jj = 0; jj < size; ++jj) { -// sort1.push_back(arr[inx1[jj]]); -// sort2.push_back(arr[inx2[jj]]); -// } -// EXPECT_EQ(sort1, sort2) << "Array size = " << size; -// EXPECT_UNIQUE(inx2) -// arr.clear(); -// } -// } -// else { -// GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA"; -// } -//} -// -//TYPED_TEST_P(avx512argselect, test_sorted) -//{ -// if (cpu_has_avx512bw()) { -// std::vector arrsizes; -// for (int64_t ii = 0; ii <= 1024; ++ii) { -// arrsizes.push_back(ii); -// } -// std::vector arr; -// for (auto &size : arrsizes) { -// arr = get_uniform_rand_array(size); -// std::sort(arr.begin(), arr.end()); -// std::vector inx1 = std_argsort(arr); -// std::vector inx2 -// = avx512_argsort(arr.data(), arr.size()); -// std::vector sort1, sort2; -// for (size_t jj = 0; jj < size; ++jj) { -// sort1.push_back(arr[inx1[jj]]); -// sort2.push_back(arr[inx2[jj]]); -// } -// EXPECT_EQ(sort1, sort2) << "Array size =" << size; -// EXPECT_UNIQUE(inx2) -// arr.clear(); -// } -// } -// else { -// GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA"; -// } -//} -// -//TYPED_TEST_P(avx512argselect, test_reverse) -//{ -// if (cpu_has_avx512bw()) { -// std::vector arrsizes; -// for (int64_t ii = 0; ii <= 1024; ++ii) { -// arrsizes.push_back(ii); -// } -// std::vector arr; -// for (auto &size : arrsizes) { -// arr = get_uniform_rand_array(size); -// std::sort(arr.begin(), arr.end()); -// std::reverse(arr.begin(), arr.end()); -// std::vector inx1 = std_argsort(arr); -// std::vector inx2 -// = avx512_argsort(arr.data(), arr.size()); -// std::vector sort1, sort2; -// for (size_t jj = 0; jj < size; ++jj) { -// sort1.push_back(arr[inx1[jj]]); -// sort2.push_back(arr[inx2[jj]]); -// } -// EXPECT_EQ(sort1, sort2) << "Array size =" << size; -// EXPECT_UNIQUE(inx2) -// arr.clear(); -// } -// } -// else { -// GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA"; -// } -//} -// -//TYPED_TEST_P(avx512argselect, test_array_with_nan) -//{ -// if (!cpu_has_avx512bw()) { -// GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA"; -// } -// if (!std::is_floating_point::value) { -// GTEST_SKIP() << "Skipping this test, it is meant for float/double"; -// } -// std::vector arrsizes; -// for (int64_t ii = 2; ii <= 1024; ++ii) { -// arrsizes.push_back(ii); -// } -// std::vector arr; -// for (auto &size : arrsizes) { -// arr = get_uniform_rand_array(size); -// arr[0] = std::numeric_limits::quiet_NaN(); -// arr[1] = std::numeric_limits::quiet_NaN(); -// std::vector inx -// = avx512_argsort(arr.data(), arr.size()); -// std::vector sort1; -// for (size_t jj = 0; jj < size; ++jj) { -// sort1.push_back(arr[inx[jj]]); -// } -// if ((!std::isnan(sort1[size - 1])) || (!std::isnan(sort1[size - 2]))) { -// FAIL() << "NAN's aren't sorted to the end"; -// } -// if (!std::is_sorted(sort1.begin(), sort1.end() - 2)) { -// FAIL() << "Array isn't sorted"; -// } -// EXPECT_UNIQUE(inx) -// arr.clear(); -// } -//} -// -//TYPED_TEST_P(avx512argselect, test_max_value_at_end_of_array) -//{ -// if (!cpu_has_avx512bw()) { -// GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA"; -// } -// std::vector arrsizes; -// for (int64_t ii = 1; ii <= 256; ++ii) { -// arrsizes.push_back(ii); -// } -// std::vector arr; -// for (auto &size : arrsizes) { -// arr = get_uniform_rand_array(size); -// if (std::numeric_limits::has_infinity) { -// arr[size - 1] = std::numeric_limits::infinity(); -// } -// else { -// arr[size - 1] = std::numeric_limits::max(); -// } -// std::vector inx = avx512_argsort(arr.data(), arr.size()); -// std::vector sorted; -// for (size_t jj = 0; jj < size; ++jj) { -// sorted.push_back(arr[inx[jj]]); -// } -// if (!std::is_sorted(sorted.begin(), sorted.end())) { -// EXPECT_TRUE(false) << "Array of size " << size << "is not sorted"; -// } -// EXPECT_UNIQUE(inx) -// arr.clear(); -// } -//} -// -//TYPED_TEST_P(avx512argselect, test_all_inf_array) -//{ -// if (!cpu_has_avx512bw()) { -// GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA"; -// } -// std::vector arrsizes; -// for (int64_t ii = 1; ii <= 256; ++ii) { -// arrsizes.push_back(ii); -// } -// std::vector arr; -// for (auto &size : arrsizes) { -// arr = get_uniform_rand_array(size); -// if (std::numeric_limits::has_infinity) { -// for (int64_t jj = 1; jj <= size; ++jj) { -// if (rand() % 0x1) { -// arr.push_back(std::numeric_limits::infinity()); -// } -// } -// } -// else { -// for (int64_t jj = 1; jj <= size; ++jj) { -// if (rand() % 0x1) { -// arr.push_back(std::numeric_limits::max()); -// } -// } -// } -// std::vector inx = avx512_argsort(arr.data(), arr.size()); -// std::vector sorted; -// for (size_t jj = 0; jj < size; ++jj) { -// sorted.push_back(arr[inx[jj]]); -// } -// if (!std::is_sorted(sorted.begin(), sorted.end())) { -// EXPECT_TRUE(false) << "Array of size " << size << "is not sorted"; -// } -// EXPECT_UNIQUE(inx) -// arr.clear(); -// } -//} - -REGISTER_TYPED_TEST_SUITE_P(avx512argselect, - test_random); - //test_reverse, - //test_constant, - //test_sorted, - //test_small_range, - //test_all_inf_array, - //test_array_with_nan, - //test_max_value_at_end_of_array); +REGISTER_TYPED_TEST_SUITE_P(avx512argselect, test_random); diff --git a/tests/test-argsort-common.h b/tests/test-argsort-common.h index a21110a5..d757e968 100644 --- a/tests/test-argsort-common.h +++ b/tests/test-argsort-common.h @@ -6,16 +6,17 @@ #include "avx512-64bit-argsort.hpp" template -std::vector std_argsort(const std::vector &array) +std::vector std_argsort(const std::vector &arr) { - std::vector indices(array.size()); + std::vector indices(arr.size()); std::iota(indices.begin(), indices.end(), 0); std::sort(indices.begin(), indices.end(), - [&array](int left, int right) -> bool { - // sort indices according to corresponding array sizeent - return array[left] < array[right]; - }); + [&arr](int64_t left, int64_t right) -> bool { + if ((!std::isnan(arr[left])) && (!std::isnan(arr[right]))) {return arr[left] < arr[right];} + else if (std::isnan(arr[left])) {return false;} + else {return true;} + }); return indices; } From e1b7fa7ad1f879a376a7755556ac3913e5b504ec Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Thu, 20 Jul 2023 14:35:07 -0700 Subject: [PATCH 5/6] clang format --- src/avx512-16bit-qsort.hpp | 4 +- src/avx512-32bit-qsort.hpp | 14 ++++-- src/avx512-64bit-argsort.hpp | 66 ++++++++++++++---------- src/avx512-64bit-keyvalue-networks.hpp | 64 ++++++++++++------------ src/avx512-64bit-qsort.hpp | 19 +++++-- src/avx512fp16-16bit-qsort.hpp | 4 +- tests/test-argselect.hpp | 66 +++++++++++++++--------- tests/test-argsort-common.h | 69 +++++++++++++++++++++++--- tests/test-argsort.cpp | 4 +- tests/test-argsort.hpp | 1 - tests/test-qsort-fp.hpp | 10 ++-- tests/test-qsort.cpp | 2 +- 12 files changed, 212 insertions(+), 111 deletions(-) diff --git a/src/avx512-16bit-qsort.hpp b/src/avx512-16bit-qsort.hpp index 1efcf1e9..b5202f46 100644 --- a/src/avx512-16bit-qsort.hpp +++ b/src/avx512-16bit-qsort.hpp @@ -433,11 +433,11 @@ void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan) { int64_t indx_last_elem = arrsize - 1; if (UNLIKELY(hasnan)) { - indx_last_elem = move_nans_to_end_of_array(arr, arrsize); + indx_last_elem = move_nans_to_end_of_array(arr, arrsize); } if (indx_last_elem >= k) { qselect_16bit_, uint16_t>( - arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); + arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); } } diff --git a/src/avx512-32bit-qsort.hpp b/src/avx512-32bit-qsort.hpp index bfd4a151..a0dd7f7e 100644 --- a/src/avx512-32bit-qsort.hpp +++ b/src/avx512-32bit-qsort.hpp @@ -715,7 +715,10 @@ replace_inf_with_nan(float *arr, int64_t arrsize, int64_t nan_count) } template <> -void avx512_qselect(int32_t *arr, int64_t k, int64_t arrsize, bool hasnan) +void avx512_qselect(int32_t *arr, + int64_t k, + int64_t arrsize, + bool hasnan) { if (arrsize > 1) { qselect_32bit_, int32_t>( @@ -724,7 +727,10 @@ void avx512_qselect(int32_t *arr, int64_t k, int64_t arrsize, bool hasn } template <> -void avx512_qselect(uint32_t *arr, int64_t k, int64_t arrsize, bool hasnan) +void avx512_qselect(uint32_t *arr, + int64_t k, + int64_t arrsize, + bool hasnan) { if (arrsize > 1) { qselect_32bit_, uint32_t>( @@ -737,11 +743,11 @@ void avx512_qselect(float *arr, int64_t k, int64_t arrsize, bool hasnan) { int64_t indx_last_elem = arrsize - 1; if (UNLIKELY(hasnan)) { - indx_last_elem = move_nans_to_end_of_array(arr, arrsize); + indx_last_elem = move_nans_to_end_of_array(arr, arrsize); } if (indx_last_elem >= k) { qselect_32bit_, float>( - arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); + arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); } } diff --git a/src/avx512-64bit-argsort.hpp b/src/avx512-64bit-argsort.hpp index 64e7cfdf..3626ab63 100644 --- a/src/avx512-64bit-argsort.hpp +++ b/src/avx512-64bit-argsort.hpp @@ -8,23 +8,29 @@ #define AVX512_ARGSORT_64BIT #include "avx512-64bit-common.h" -#include "avx512-common-argsort.h" #include "avx512-64bit-keyvalue-networks.hpp" +#include "avx512-common-argsort.h" template -void std_argselect_withnan(T *arr, int64_t *arg, int64_t k, int64_t left, int64_t right) +void std_argselect_withnan( + T *arr, int64_t *arg, int64_t k, int64_t left, int64_t right) { std::nth_element(arg + left, arg + k, arg + right, [arr](int64_t a, int64_t b) -> bool { - if ((!std::isnan(arr[a])) && (!std::isnan(arr[b]))) {return arr[a] < arr[b];} - else if (std::isnan(arr[a])) {return false;} - else {return true;} + if ((!std::isnan(arr[a])) && (!std::isnan(arr[b]))) { + return arr[a] < arr[b]; + } + else if (std::isnan(arr[a])) { + return false; + } + else { + return true; + } }); } - /* argsort using std::sort */ template void std_argsort_withnan(T *arr, int64_t *arg, int64_t left, int64_t right) @@ -32,9 +38,15 @@ void std_argsort_withnan(T *arr, int64_t *arg, int64_t left, int64_t right) std::sort(arg + left, arg + right, [arr](int64_t left, int64_t right) -> bool { - if ((!std::isnan(arr[left])) && (!std::isnan(arr[right]))) {return arr[left] < arr[right];} - else if (std::isnan(arr[left])) {return false;} - else {return true;} + if ((!std::isnan(arr[left])) && (!std::isnan(arr[right]))) { + return arr[left] < arr[right]; + } + else if (std::isnan(arr[left])) { + return false; + } + else { + return true; + } }); } @@ -325,13 +337,15 @@ static void argselect_64bit_(type_t *arr, int64_t pivot_index = partition_avx512_unrolled( arr, arg, left, right + 1, pivot, &smallest, &biggest); if ((pivot != smallest) && (pos < pivot_index)) - argselect_64bit_(arr, arg, pos, left, pivot_index - 1, max_iters - 1); + argselect_64bit_( + arr, arg, pos, left, pivot_index - 1, max_iters - 1); else if ((pivot != biggest) && (pos >= pivot_index)) - argselect_64bit_(arr, arg, pos, pivot_index, right, max_iters - 1); + argselect_64bit_( + arr, arg, pos, pivot_index, right, max_iters - 1); } template -bool has_nan(type_t* arr, int64_t arrsize) +bool has_nan(type_t *arr, int64_t arrsize) { using opmask_t = typename vtype::opmask_t; using zmm_t = typename vtype::zmm_t; @@ -346,7 +360,7 @@ bool has_nan(type_t* arr, int64_t arrsize) else { in = vtype::loadu(arr); } - opmask_t nanmask = vtype::template fpclass<0x01|0x80>(in); + opmask_t nanmask = vtype::template fpclass<0x01 | 0x80>(in); arr += vtype::numlanes; arrsize -= vtype::numlanes; if (nanmask != 0x00) { @@ -357,10 +371,9 @@ bool has_nan(type_t* arr, int64_t arrsize) return found_nan; } - /* argsort methods for 32-bit and 64-bit dtypes */ template -void avx512_argsort(T* arr, int64_t *arg, int64_t arrsize) +void avx512_argsort(T *arr, int64_t *arg, int64_t arrsize) { if (arrsize > 1) { argsort_64bit_>( @@ -369,7 +382,7 @@ void avx512_argsort(T* arr, int64_t *arg, int64_t arrsize) } template <> -void avx512_argsort(double* arr, int64_t *arg, int64_t arrsize) +void avx512_argsort(double *arr, int64_t *arg, int64_t arrsize) { if (arrsize > 1) { if (has_nan>(arr, arrsize)) { @@ -382,9 +395,8 @@ void avx512_argsort(double* arr, int64_t *arg, int64_t arrsize) } } - template <> -void avx512_argsort(int32_t* arr, int64_t *arg, int64_t arrsize) +void avx512_argsort(int32_t *arr, int64_t *arg, int64_t arrsize) { if (arrsize > 1) { argsort_64bit_>( @@ -393,7 +405,7 @@ void avx512_argsort(int32_t* arr, int64_t *arg, int64_t arrsize) } template <> -void avx512_argsort(uint32_t* arr, int64_t *arg, int64_t arrsize) +void avx512_argsort(uint32_t *arr, int64_t *arg, int64_t arrsize) { if (arrsize > 1) { argsort_64bit_>( @@ -402,7 +414,7 @@ void avx512_argsort(uint32_t* arr, int64_t *arg, int64_t arrsize) } template <> -void avx512_argsort(float* arr, int64_t *arg, int64_t arrsize) +void avx512_argsort(float *arr, int64_t *arg, int64_t arrsize) { if (arrsize > 1) { if (has_nan>(arr, arrsize)) { @@ -416,7 +428,7 @@ void avx512_argsort(float* arr, int64_t *arg, int64_t arrsize) } template -std::vector avx512_argsort(T* arr, int64_t arrsize) +std::vector avx512_argsort(T *arr, int64_t arrsize) { std::vector indices(arrsize); std::iota(indices.begin(), indices.end(), 0); @@ -426,7 +438,7 @@ std::vector avx512_argsort(T* arr, int64_t arrsize) /* argselect methods for 32-bit and 64-bit dtypes */ template -void avx512_argselect(T* arr, int64_t *arg, int64_t k, int64_t arrsize) +void avx512_argselect(T *arr, int64_t *arg, int64_t k, int64_t arrsize) { if (arrsize > 1) { argselect_64bit_>( @@ -435,7 +447,7 @@ void avx512_argselect(T* arr, int64_t *arg, int64_t k, int64_t arrsize) } template <> -void avx512_argselect(double* arr, int64_t *arg, int64_t k, int64_t arrsize) +void avx512_argselect(double *arr, int64_t *arg, int64_t k, int64_t arrsize) { if (arrsize > 1) { if (has_nan>(arr, arrsize)) { @@ -449,7 +461,7 @@ void avx512_argselect(double* arr, int64_t *arg, int64_t k, int64_t arrsize) } template <> -void avx512_argselect(int32_t* arr, int64_t *arg, int64_t k, int64_t arrsize) +void avx512_argselect(int32_t *arr, int64_t *arg, int64_t k, int64_t arrsize) { if (arrsize > 1) { argselect_64bit_>( @@ -458,7 +470,7 @@ void avx512_argselect(int32_t* arr, int64_t *arg, int64_t k, int64_t arrsize) } template <> -void avx512_argselect(uint32_t* arr, int64_t *arg, int64_t k, int64_t arrsize) +void avx512_argselect(uint32_t *arr, int64_t *arg, int64_t k, int64_t arrsize) { if (arrsize > 1) { argselect_64bit_>( @@ -467,7 +479,7 @@ void avx512_argselect(uint32_t* arr, int64_t *arg, int64_t k, int64_t arrsize) } template <> -void avx512_argselect(float* arr, int64_t *arg, int64_t k, int64_t arrsize) +void avx512_argselect(float *arr, int64_t *arg, int64_t k, int64_t arrsize) { if (arrsize > 1) { if (has_nan>(arr, arrsize)) { @@ -481,7 +493,7 @@ void avx512_argselect(float* arr, int64_t *arg, int64_t k, int64_t arrsize) } template -std::vector avx512_argselect(T* arr, int64_t k, int64_t arrsize) +std::vector avx512_argselect(T *arr, int64_t k, int64_t arrsize) { std::vector indices(arrsize); std::iota(indices.begin(), indices.end(), 0); diff --git a/src/avx512-64bit-keyvalue-networks.hpp b/src/avx512-64bit-keyvalue-networks.hpp index af3a2a98..b930a42b 100644 --- a/src/avx512-64bit-keyvalue-networks.hpp +++ b/src/avx512-64bit-keyvalue-networks.hpp @@ -136,14 +136,14 @@ X86_SIMD_SORT_INLINE void bitonic_merge_four_zmm_64bit(zmm_t *key_zmm, typename vtype1::opmask_t movmask1 = vtype1::eq(key_zmm_t1, key_zmm[0]); typename vtype1::opmask_t movmask2 = vtype1::eq(key_zmm_t2, key_zmm[1]); - index_type index_zmm_t1 = vtype2::mask_mov( - index_zmm3r, movmask1, index_zmm[0]); - index_type index_zmm_m1 = vtype2::mask_mov( - index_zmm[0], movmask1, index_zmm3r); - index_type index_zmm_t2 = vtype2::mask_mov( - index_zmm2r, movmask2, index_zmm[1]); - index_type index_zmm_m2 = vtype2::mask_mov( - index_zmm[1], movmask2, index_zmm2r); + index_type index_zmm_t1 + = vtype2::mask_mov(index_zmm3r, movmask1, index_zmm[0]); + index_type index_zmm_m1 + = vtype2::mask_mov(index_zmm[0], movmask1, index_zmm3r); + index_type index_zmm_t2 + = vtype2::mask_mov(index_zmm2r, movmask2, index_zmm[1]); + index_type index_zmm_m2 + = vtype2::mask_mov(index_zmm[1], movmask2, index_zmm2r); // 2) Recursive half clearer: 16 zmm_t key_zmm_t3 = vtype1::permutexvar(rev_index1, key_zmm_m2); @@ -159,14 +159,14 @@ X86_SIMD_SORT_INLINE void bitonic_merge_four_zmm_64bit(zmm_t *key_zmm, movmask1 = vtype1::eq(key_zmm0, key_zmm_t1); movmask2 = vtype1::eq(key_zmm2, key_zmm_t3); - index_type index_zmm0 = vtype2::mask_mov( - index_zmm_t2, movmask1, index_zmm_t1); - index_type index_zmm1 = vtype2::mask_mov( - index_zmm_t1, movmask1, index_zmm_t2); - index_type index_zmm2 = vtype2::mask_mov( - index_zmm_t4, movmask2, index_zmm_t3); - index_type index_zmm3 = vtype2::mask_mov( - index_zmm_t3, movmask2, index_zmm_t4); + index_type index_zmm0 + = vtype2::mask_mov(index_zmm_t2, movmask1, index_zmm_t1); + index_type index_zmm1 + = vtype2::mask_mov(index_zmm_t1, movmask1, index_zmm_t2); + index_type index_zmm2 + = vtype2::mask_mov(index_zmm_t4, movmask2, index_zmm_t3); + index_type index_zmm3 + = vtype2::mask_mov(index_zmm_t3, movmask2, index_zmm_t4); key_zmm[0] = bitonic_merge_zmm_64bit(key_zmm0, index_zmm0); key_zmm[1] = bitonic_merge_zmm_64bit(key_zmm1, index_zmm1); @@ -212,22 +212,22 @@ X86_SIMD_SORT_INLINE void bitonic_merge_eight_zmm_64bit(zmm_t *key_zmm, typename vtype1::opmask_t movmask3 = vtype1::eq(key_zmm_t3, key_zmm[2]); typename vtype1::opmask_t movmask4 = vtype1::eq(key_zmm_t4, key_zmm[3]); - index_type index_zmm_t1 = vtype2::mask_mov( - index_zmm7r, movmask1, index_zmm[0]); - index_type index_zmm_m1 = vtype2::mask_mov( - index_zmm[0], movmask1, index_zmm7r); - index_type index_zmm_t2 = vtype2::mask_mov( - index_zmm6r, movmask2, index_zmm[1]); - index_type index_zmm_m2 = vtype2::mask_mov( - index_zmm[1], movmask2, index_zmm6r); - index_type index_zmm_t3 = vtype2::mask_mov( - index_zmm5r, movmask3, index_zmm[2]); - index_type index_zmm_m3 = vtype2::mask_mov( - index_zmm[2], movmask3, index_zmm5r); - index_type index_zmm_t4 = vtype2::mask_mov( - index_zmm4r, movmask4, index_zmm[3]); - index_type index_zmm_m4 = vtype2::mask_mov( - index_zmm[3], movmask4, index_zmm4r); + index_type index_zmm_t1 + = vtype2::mask_mov(index_zmm7r, movmask1, index_zmm[0]); + index_type index_zmm_m1 + = vtype2::mask_mov(index_zmm[0], movmask1, index_zmm7r); + index_type index_zmm_t2 + = vtype2::mask_mov(index_zmm6r, movmask2, index_zmm[1]); + index_type index_zmm_m2 + = vtype2::mask_mov(index_zmm[1], movmask2, index_zmm6r); + index_type index_zmm_t3 + = vtype2::mask_mov(index_zmm5r, movmask3, index_zmm[2]); + index_type index_zmm_m3 + = vtype2::mask_mov(index_zmm[2], movmask3, index_zmm5r); + index_type index_zmm_t4 + = vtype2::mask_mov(index_zmm4r, movmask4, index_zmm[3]); + index_type index_zmm_m4 + = vtype2::mask_mov(index_zmm[3], movmask4, index_zmm4r); zmm_t key_zmm_t5 = vtype1::permutexvar(rev_index1, key_zmm_m4); zmm_t key_zmm_t6 = vtype1::permutexvar(rev_index1, key_zmm_m3); diff --git a/src/avx512-64bit-qsort.hpp b/src/avx512-64bit-qsort.hpp index aa5d7958..d59a1788 100644 --- a/src/avx512-64bit-qsort.hpp +++ b/src/avx512-64bit-qsort.hpp @@ -784,7 +784,10 @@ static void qselect_64bit_(type_t *arr, } template <> -void avx512_qselect(int64_t *arr, int64_t k, int64_t arrsize, bool hasnan) +void avx512_qselect(int64_t *arr, + int64_t k, + int64_t arrsize, + bool hasnan) { if (arrsize > 1) { qselect_64bit_, int64_t>( @@ -793,7 +796,10 @@ void avx512_qselect(int64_t *arr, int64_t k, int64_t arrsize, bool hasn } template <> -void avx512_qselect(uint64_t *arr, int64_t k, int64_t arrsize, bool hasnan) +void avx512_qselect(uint64_t *arr, + int64_t k, + int64_t arrsize, + bool hasnan) { if (arrsize > 1) { qselect_64bit_, uint64_t>( @@ -802,15 +808,18 @@ void avx512_qselect(uint64_t *arr, int64_t k, int64_t arrsize, bool ha } template <> -void avx512_qselect(double *arr, int64_t k, int64_t arrsize, bool hasnan) +void avx512_qselect(double *arr, + int64_t k, + int64_t arrsize, + bool hasnan) { int64_t indx_last_elem = arrsize - 1; if (UNLIKELY(hasnan)) { - indx_last_elem = move_nans_to_end_of_array(arr, arrsize); + indx_last_elem = move_nans_to_end_of_array(arr, arrsize); } if (indx_last_elem >= k) { qselect_64bit_, double>( - arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); + arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); } } diff --git a/src/avx512fp16-16bit-qsort.hpp b/src/avx512fp16-16bit-qsort.hpp index 8e87ac29..5bb4c6c0 100644 --- a/src/avx512fp16-16bit-qsort.hpp +++ b/src/avx512fp16-16bit-qsort.hpp @@ -157,11 +157,11 @@ void avx512_qselect(_Float16 *arr, int64_t k, int64_t arrsize, bool hasnan) { int64_t indx_last_elem = arrsize - 1; if (UNLIKELY(hasnan)) { - indx_last_elem = move_nans_to_end_of_array(arr, arrsize); + indx_last_elem = move_nans_to_end_of_array(arr, arrsize); } if (indx_last_elem >= k) { qselect_16bit_, _Float16>( - arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); + arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); } } diff --git a/tests/test-argselect.hpp b/tests/test-argselect.hpp index 23375b54..c660d83f 100644 --- a/tests/test-argselect.hpp +++ b/tests/test-argselect.hpp @@ -8,32 +8,49 @@ class avx512argselect : public ::testing::Test { }; TYPED_TEST_SUITE_P(avx512argselect); - template -T std_min_element(std::vector arr, std::vector arg, int64_t left, int64_t right) +T std_min_element(std::vector arr, + std::vector arg, + int64_t left, + int64_t right) { - std::vector::iterator res = - std::min_element(arg.begin() + left, - arg.begin() + right, - [arr](int64_t a, int64_t b) -> bool { - if ((!std::isnan(arr[a])) && (!std::isnan(arr[b]))) {return arr[a] < arr[b];} - else if (std::isnan(arr[a])) {return false;} - else {return true;} - }); + std::vector::iterator res = std::min_element( + arg.begin() + left, + arg.begin() + right, + [arr](int64_t a, int64_t b) -> bool { + if ((!std::isnan(arr[a])) && (!std::isnan(arr[b]))) { + return arr[a] < arr[b]; + } + else if (std::isnan(arr[a])) { + return false; + } + else { + return true; + } + }); return arr[*res]; } template -T std_max_element(std::vector arr, std::vector arg, int64_t left, int64_t right) +T std_max_element(std::vector arr, + std::vector arg, + int64_t left, + int64_t right) { - std::vector::iterator res = - std::max_element(arg.begin() + left, - arg.begin() + right, - [arr](int64_t a, int64_t b) -> bool { - if ((!std::isnan(arr[a])) && (!std::isnan(arr[b]))) {return arr[a] > arr[b];} - else if (std::isnan(arr[a])) {return true;} - else {return false;} - }); + std::vector::iterator res = std::max_element( + arg.begin() + left, + arg.begin() + right, + [arr](int64_t a, int64_t b) -> bool { + if ((!std::isnan(arr[a])) && (!std::isnan(arr[b]))) { + return arr[a] > arr[b]; + } + else if (std::isnan(arr[a])) { + return true; + } + else { + return false; + } + }); return arr[*res]; } @@ -49,7 +66,7 @@ TYPED_TEST_P(avx512argselect, test_random) } sorted_inx = std_argsort(arr); std::vector kth; - for (int64_t ii = 0; ii < arrsize-3; ++ii) { + for (int64_t ii = 0; ii < arrsize - 3; ++ii) { kth.push_back(ii); } for (auto &k : kth) { @@ -58,9 +75,12 @@ TYPED_TEST_P(avx512argselect, test_random) auto true_kth = arr[sorted_inx[k]]; EXPECT_EQ(true_kth, arr[inx[k]]) << "Failed at index k = " << k; if (k >= 1) - EXPECT_GE(true_kth, std_max_element(arr, inx, 0, k-1)) << "failed at k = " << k; - if (k != arrsize-1) - EXPECT_LE(true_kth, std_min_element(arr, inx, k+1, arrsize-1)) << "failed at k = " << k; + EXPECT_GE(true_kth, std_max_element(arr, inx, 0, k - 1)) + << "failed at k = " << k; + if (k != arrsize - 1) + EXPECT_LE(true_kth, + std_min_element(arr, inx, k + 1, arrsize - 1)) + << "failed at k = " << k; EXPECT_UNIQUE(inx) } } diff --git a/tests/test-argsort-common.h b/tests/test-argsort-common.h index d757e968..2e293620 100644 --- a/tests/test-argsort-common.h +++ b/tests/test-argsort-common.h @@ -1,9 +1,9 @@ +#include "avx512-64bit-argsort.hpp" +#include "cpuinfo.h" +#include "rand_array.h" #include #include #include -#include "cpuinfo.h" -#include "rand_array.h" -#include "avx512-64bit-argsort.hpp" template std::vector std_argsort(const std::vector &arr) @@ -13,16 +13,69 @@ std::vector std_argsort(const std::vector &arr) std::sort(indices.begin(), indices.end(), [&arr](int64_t left, int64_t right) -> bool { - if ((!std::isnan(arr[left])) && (!std::isnan(arr[right]))) {return arr[left] < arr[right];} - else if (std::isnan(arr[left])) {return false;} - else {return true;} - }); + if ((!std::isnan(arr[left])) && (!std::isnan(arr[right]))) { + return arr[left] < arr[right]; + } + else if (std::isnan(arr[left])) { + return false; + } + else { + return true; + } + }); return indices; } +template +T std_min_element(std::vector arr, + std::vector arg, + int64_t left, + int64_t right) +{ + std::vector::iterator res = std::min_element( + arg.begin() + left, + arg.begin() + right, + [arr](int64_t a, int64_t b) -> bool { + if ((!std::isnan(arr[a])) && (!std::isnan(arr[b]))) { + return arr[a] < arr[b]; + } + else if (std::isnan(arr[a])) { + return false; + } + else { + return true; + } + }); + return arr[*res]; +} + +template +T std_max_element(std::vector arr, + std::vector arg, + int64_t left, + int64_t right) +{ + std::vector::iterator res = std::max_element( + arg.begin() + left, + arg.begin() + right, + [arr](int64_t a, int64_t b) -> bool { + if ((!std::isnan(arr[a])) && (!std::isnan(arr[b]))) { + return arr[a] > arr[b]; + } + else if (std::isnan(arr[a])) { + return true; + } + else { + return false; + } + }); + return arr[*res]; +} + #define EXPECT_UNIQUE(sorted_arg) \ std::sort(sorted_arg.begin(), sorted_arg.end()); \ std::vector expected_arg(sorted_arg.size()); \ std::iota(expected_arg.begin(), expected_arg.end(), 0); \ - EXPECT_EQ(sorted_arg, expected_arg) << "Indices aren't unique. Array size = " << sorted_arg.size(); + EXPECT_EQ(sorted_arg, expected_arg) \ + << "Indices aren't unique. Array size = " << sorted_arg.size(); diff --git a/tests/test-argsort.cpp b/tests/test-argsort.cpp index b4d2b9e2..321a84c2 100644 --- a/tests/test-argsort.cpp +++ b/tests/test-argsort.cpp @@ -1,6 +1,6 @@ -#include "test-argsort-common.h" -#include "test-argselect.hpp" #include "test-argsort.hpp" +#include "test-argselect.hpp" +#include "test-argsort-common.h" using ArgTestTypes = testing::Types; diff --git a/tests/test-argsort.hpp b/tests/test-argsort.hpp index 5ef9e6ea..f7a4a23f 100644 --- a/tests/test-argsort.hpp +++ b/tests/test-argsort.hpp @@ -270,4 +270,3 @@ REGISTER_TYPED_TEST_SUITE_P(avx512argsort, test_all_inf_array, test_array_with_nan, test_max_value_at_end_of_array); - diff --git a/tests/test-qsort-fp.hpp b/tests/test-qsort-fp.hpp index 9000fb38..438305b1 100644 --- a/tests/test-qsort-fp.hpp +++ b/tests/test-qsort-fp.hpp @@ -26,15 +26,17 @@ TYPED_TEST_P(avx512_sort_fp, test_random_nan) /* Random array */ arr = get_uniform_rand_array(size); for (auto ii = 1; ii <= num_nans; ++ii) { - arr[size-ii] = std::numeric_limits::quiet_NaN(); + arr[size - ii] = std::numeric_limits::quiet_NaN(); } sortedarr = arr; - std::sort(sortedarr.begin(), sortedarr.end()-3); + std::sort(sortedarr.begin(), sortedarr.end() - 3); std::random_shuffle(arr.begin(), arr.end()); avx512_qsort(arr.data(), arr.size()); for (auto ii = 1; ii <= num_nans; ++ii) { - if (!std::isnan(arr[size-ii])) { - ASSERT_TRUE(false) << "NAN's aren't sorted to the end. Arr size = " << size; + if (!std::isnan(arr[size - ii])) { + ASSERT_TRUE(false) + << "NAN's aren't sorted to the end. Arr size = " + << size; } } if (!std::is_sorted(arr.begin(), arr.end() - num_nans)) { diff --git a/tests/test-qsort.cpp b/tests/test-qsort.cpp index eb3d5f77..a35d8e8c 100644 --- a/tests/test-qsort.cpp +++ b/tests/test-qsort.cpp @@ -1,7 +1,7 @@ +#include "test-qsort.hpp" #include "test-partial-qsort.hpp" #include "test-qselect.hpp" #include "test-qsort-fp.hpp" -#include "test-qsort.hpp" using QSortTestTypes = testing::Types Date: Thu, 20 Jul 2023 20:55:36 -0700 Subject: [PATCH 6/6] Fix build issues --- tests/test-argselect.hpp | 47 +--------------------------------------- tests/test-argsort.cpp | 2 +- 2 files changed, 2 insertions(+), 47 deletions(-) diff --git a/tests/test-argselect.hpp b/tests/test-argselect.hpp index c660d83f..298000d4 100644 --- a/tests/test-argselect.hpp +++ b/tests/test-argselect.hpp @@ -6,53 +6,8 @@ template class avx512argselect : public ::testing::Test { }; -TYPED_TEST_SUITE_P(avx512argselect); - -template -T std_min_element(std::vector arr, - std::vector arg, - int64_t left, - int64_t right) -{ - std::vector::iterator res = std::min_element( - arg.begin() + left, - arg.begin() + right, - [arr](int64_t a, int64_t b) -> bool { - if ((!std::isnan(arr[a])) && (!std::isnan(arr[b]))) { - return arr[a] < arr[b]; - } - else if (std::isnan(arr[a])) { - return false; - } - else { - return true; - } - }); - return arr[*res]; -} -template -T std_max_element(std::vector arr, - std::vector arg, - int64_t left, - int64_t right) -{ - std::vector::iterator res = std::max_element( - arg.begin() + left, - arg.begin() + right, - [arr](int64_t a, int64_t b) -> bool { - if ((!std::isnan(arr[a])) && (!std::isnan(arr[b]))) { - return arr[a] > arr[b]; - } - else if (std::isnan(arr[a])) { - return true; - } - else { - return false; - } - }); - return arr[*res]; -} +TYPED_TEST_SUITE_P(avx512argselect); TYPED_TEST_P(avx512argselect, test_random) { diff --git a/tests/test-argsort.cpp b/tests/test-argsort.cpp index 321a84c2..41ce5ca4 100644 --- a/tests/test-argsort.cpp +++ b/tests/test-argsort.cpp @@ -1,6 +1,6 @@ +#include "test-argsort-common.h" #include "test-argsort.hpp" #include "test-argselect.hpp" -#include "test-argsort-common.h" using ArgTestTypes = testing::Types;