diff --git a/.github/workflows/c-cpp.yml b/.github/workflows/c-cpp.yml index 96d24cb7..3cf380b9 100644 --- a/.github/workflows/c-cpp.yml +++ b/.github/workflows/c-cpp.yml @@ -125,3 +125,34 @@ jobs: - name: Run test suite on SPR run: sde -spr -- ./builddir/testexe + + SPR-gcc13-min-networksort: + + runs-on: intel-ubuntu-latest + + steps: + - uses: actions/checkout@v3 + + - name: Install dependencies + run: | + sudo apt update + sudo apt -y install g++-13 libgtest-dev meson curl git cmake + + - name: Install Intel SDE + run: | + curl -o /tmp/sde.tar.xz https://downloadmirror.intel.com/784319/sde-external-9.24.0-2023-07-13-lin.tar.xz + mkdir /tmp/sde && tar -xvf /tmp/sde.tar.xz -C /tmp/sde/ + sudo mv /tmp/sde/* /opt/sde && sudo ln -s /opt/sde/sde64 /usr/bin/sde + + - name: Build + env: + CXX: g++-13 + CXXFLAGS: -DXSS_MINIMAL_NETWORK_SORT + run: | + make clean + meson setup --warnlevel 2 --werror --buildtype release builddir + cd builddir + ninja + + - name: Run test suite on SPR + run: sde -spr -- ./builddir/testexe diff --git a/Makefile b/Makefile index f25c8dad..b54dc288 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,6 @@ # When unset, discover g++. Prioritise the latest version on the path. ifeq (, $(and $(strip $(CXX)), $(filter-out default undefined, $(origin CXX)))) - override CXX := $(shell which g++-12 g++-11 g++-10 g++-9 g++-8 g++ 2>/dev/null | head -n 1) + override CXX := $(shell which g++-13 g++-12 g++-11 g++-10 g++-9 g++-8 g++ 2>/dev/null | head -n 1) ifeq (, $(strip $(CXX))) $(error Could not locate the g++ compiler. Please manually specify its path using the CXX variable) endif diff --git a/README.md b/README.md index 47ef2460..5ad30ce7 100644 --- a/README.md +++ b/README.md @@ -170,3 +170,4 @@ Skylake https://arxiv.org/pdf/1704.08579.pdf * [4] http://mitp-content-server.mit.edu:18180/books/content/sectbyfn?collid=books_pres_0&fn=Chapter%2027.pdf&id=8030 +* [5] https://bertdobbelaere.github.io/sorting_networks.html \ No newline at end of file diff --git a/src/avx512-16bit-common.h b/src/avx512-16bit-common.h index e51ac14a..532da825 100644 --- a/src/avx512-16bit-common.h +++ b/src/avx512-16bit-common.h @@ -8,7 +8,6 @@ #define AVX512_16BIT_COMMON #include "avx512-common-qsort.h" -#include "xss-network-qsort.hpp" /* * Constants used in sorting 32 elements in a ZMM registers. Based on Bitonic @@ -93,30 +92,221 @@ X86_SIMD_SORT_INLINE reg_t sort_zmm_16bit(reg_t zmm) return zmm; } -// Assumes zmm is bitonic and performs a recursive half cleaner -template -X86_SIMD_SORT_INLINE reg_t bitonic_merge_zmm_16bit(reg_t zmm) -{ - // 1) half_cleaner[32]: compare 1-17, 2-18, 3-19 etc .. - zmm = cmp_merge( - zmm, vtype::permutexvar(vtype::get_network(6), zmm), 0xFFFF0000); - // 2) half_cleaner[16]: compare 1-9, 2-10, 3-11 etc .. 
- zmm = cmp_merge( - zmm, vtype::permutexvar(vtype::get_network(5), zmm), 0xFF00FF00); - // 3) half_cleaner[8] - zmm = cmp_merge( - zmm, vtype::permutexvar(vtype::get_network(3), zmm), 0xF0F0F0F0); - // 3) half_cleaner[4] - zmm = cmp_merge( - zmm, - vtype::template shuffle(zmm), - 0xCCCCCCCC); - // 3) half_cleaner[2] - zmm = cmp_merge( - zmm, - vtype::template shuffle(zmm), - 0xAAAAAAAA); - return zmm; -} +struct avx512_16bit_swizzle_ops { + template + X86_SIMD_SORT_INLINE typename vtype::reg_t swap_n(typename vtype::reg_t reg) + { + __m512i v = vtype::cast_to(reg); + + if constexpr (scale == 2) { + __m512i mask = _mm512_set_epi16(30, + 31, + 28, + 29, + 26, + 27, + 24, + 25, + 22, + 23, + 20, + 21, + 18, + 19, + 16, + 17, + 14, + 15, + 12, + 13, + 10, + 11, + 8, + 9, + 6, + 7, + 4, + 5, + 2, + 3, + 0, + 1); + v = _mm512_permutexvar_epi16(mask, v); + } + else if constexpr (scale == 4) { + v = _mm512_shuffle_epi32(v, (_MM_PERM_ENUM)0b10110001); + } + else if constexpr (scale == 8) { + v = _mm512_shuffle_epi32(v, (_MM_PERM_ENUM)0b01001110); + } + else if constexpr (scale == 16) { + v = _mm512_shuffle_i64x2(v, v, 0b10110001); + } + else if constexpr (scale == 32) { + v = _mm512_shuffle_i64x2(v, v, 0b01001110); + } + else { + static_assert(scale == -1, "should not be reached"); + } + + return vtype::cast_from(v); + } + + template + X86_SIMD_SORT_INLINE typename vtype::reg_t + reverse_n(typename vtype::reg_t reg) + { + __m512i v = vtype::cast_to(reg); + + if constexpr (scale == 2) { return swap_n(reg); } + else if constexpr (scale == 4) { + __m512i mask = _mm512_set_epi16(28, + 29, + 30, + 31, + 24, + 25, + 26, + 27, + 20, + 21, + 22, + 23, + 16, + 17, + 18, + 19, + 12, + 13, + 14, + 15, + 8, + 9, + 10, + 11, + 4, + 5, + 6, + 7, + 0, + 1, + 2, + 3); + v = _mm512_permutexvar_epi16(mask, v); + } + else if constexpr (scale == 8) { + __m512i mask = _mm512_set_epi16(24, + 25, + 26, + 27, + 28, + 29, + 30, + 31, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7); + v = _mm512_permutexvar_epi16(mask, v); + } + else if constexpr (scale == 16) { + __m512i mask = _mm512_set_epi16(16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + 27, + 28, + 29, + 30, + 31, + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15); + v = _mm512_permutexvar_epi16(mask, v); + } + else if constexpr (scale == 32) { + return vtype::reverse(reg); + } + else { + static_assert(scale == -1, "should not be reached"); + } + + return vtype::cast_from(v); + } + + template + X86_SIMD_SORT_INLINE typename vtype::reg_t + merge_n(typename vtype::reg_t reg, typename vtype::reg_t other) + { + __m512i v1 = vtype::cast_to(reg); + __m512i v2 = vtype::cast_to(other); + + if constexpr (scale == 2) { + v1 = _mm512_mask_blend_epi16( + 0b01010101010101010101010101010101, v1, v2); + } + else if constexpr (scale == 4) { + v1 = _mm512_mask_blend_epi16( + 0b00110011001100110011001100110011, v1, v2); + } + else if constexpr (scale == 8) { + v1 = _mm512_mask_blend_epi16( + 0b00001111000011110000111100001111, v1, v2); + } + else if constexpr (scale == 16) { + v1 = _mm512_mask_blend_epi16( + 0b00000000111111110000000011111111, v1, v2); + } + else if constexpr (scale == 32) { + v1 = _mm512_mask_blend_epi16( + 0b00000000000000001111111111111111, v1, v2); + } + else { + static_assert(scale == -1, "should not be reached"); + } + + return vtype::cast_from(v1); + } +}; #endif // AVX512_16BIT_COMMON diff --git a/src/avx512-16bit-qsort.hpp 
b/src/avx512-16bit-qsort.hpp index edd118b3..fdfba924 100644 --- a/src/avx512-16bit-qsort.hpp +++ b/src/avx512-16bit-qsort.hpp @@ -20,8 +20,14 @@ struct zmm_vector { using halfreg_t = __m256i; using opmask_t = __mmask32; static const uint8_t numlanes = 32; +#ifdef XSS_MINIMAL_NETWORK_SORT + static constexpr int network_sort_threshold = numlanes; +#else static constexpr int network_sort_threshold = 512; - static constexpr int partition_unroll_factor = 0; +#endif + static constexpr int partition_unroll_factor = 8; + + using swizzle_ops = avx512_16bit_swizzle_ops; static reg_t get_network(int index) { @@ -159,14 +165,18 @@ struct zmm_vector { const auto rev_index = get_network(4); return permutexvar(rev_index, zmm); } - static reg_t bitonic_merge(reg_t x) - { - return bitonic_merge_zmm_16bit>(x); - } static reg_t sort_vec(reg_t x) { return sort_zmm_16bit>(x); } + static reg_t cast_from(__m512i v) + { + return v; + } + static __m512i cast_to(reg_t v) + { + return v; + } }; template <> @@ -176,8 +186,14 @@ struct zmm_vector { using halfreg_t = __m256i; using opmask_t = __mmask32; static const uint8_t numlanes = 32; +#ifdef XSS_MINIMAL_NETWORK_SORT + static constexpr int network_sort_threshold = numlanes; +#else static constexpr int network_sort_threshold = 512; - static constexpr int partition_unroll_factor = 0; +#endif + static constexpr int partition_unroll_factor = 8; + + using swizzle_ops = avx512_16bit_swizzle_ops; static reg_t get_network(int index) { @@ -273,14 +289,18 @@ struct zmm_vector { const auto rev_index = get_network(4); return permutexvar(rev_index, zmm); } - static reg_t bitonic_merge(reg_t x) - { - return bitonic_merge_zmm_16bit>(x); - } static reg_t sort_vec(reg_t x) { return sort_zmm_16bit>(x); } + static reg_t cast_from(__m512i v) + { + return v; + } + static __m512i cast_to(reg_t v) + { + return v; + } }; template <> struct zmm_vector { @@ -289,8 +309,14 @@ struct zmm_vector { using halfreg_t = __m256i; using opmask_t = __mmask32; static const uint8_t numlanes = 32; +#ifdef XSS_MINIMAL_NETWORK_SORT + static constexpr int network_sort_threshold = numlanes; +#else static constexpr int network_sort_threshold = 512; - static constexpr int partition_unroll_factor = 0; +#endif + static constexpr int partition_unroll_factor = 8; + + using swizzle_ops = avx512_16bit_swizzle_ops; static reg_t get_network(int index) { @@ -384,14 +410,18 @@ struct zmm_vector { const auto rev_index = get_network(4); return permutexvar(rev_index, zmm); } - static reg_t bitonic_merge(reg_t x) - { - return bitonic_merge_zmm_16bit>(x); - } static reg_t sort_vec(reg_t x) { return sort_zmm_16bit>(x); } + static reg_t cast_from(__m512i v) + { + return v; + } + static __m512i cast_to(reg_t v) + { + return v; + } }; template <> @@ -445,8 +475,7 @@ arrsize_t replace_nan_with_inf>(uint16_t *arr, template <> bool is_a_nan(uint16_t elem) { - return ((elem & 0x7c00u) == 0x7c00u) && - ((elem & 0x03ffu) != 0); + return ((elem & 0x7c00u) == 0x7c00u) && ((elem & 0x03ffu) != 0); } X86_SIMD_SORT_INLINE diff --git a/src/avx512-32bit-qsort.hpp b/src/avx512-32bit-qsort.hpp index fd427c28..dc56e370 100644 --- a/src/avx512-32bit-qsort.hpp +++ b/src/avx512-32bit-qsort.hpp @@ -9,7 +9,6 @@ #define AVX512_QSORT_32BIT #include "avx512-common-qsort.h" -#include "xss-network-qsort.hpp" /* * Constants used in sorting 16 elements in a ZMM registers. 
Based on Bitonic @@ -27,8 +26,7 @@ template X86_SIMD_SORT_INLINE reg_t sort_zmm_32bit(reg_t zmm); -template -X86_SIMD_SORT_INLINE reg_t bitonic_merge_zmm_32bit(reg_t zmm); +struct avx512_32bit_swizzle_ops; template <> struct zmm_vector { @@ -37,8 +35,14 @@ struct zmm_vector { using halfreg_t = __m256i; using opmask_t = __mmask16; static const uint8_t numlanes = 16; - static constexpr int network_sort_threshold = 256; - static constexpr int partition_unroll_factor = 2; +#ifdef XSS_MINIMAL_NETWORK_SORT + static constexpr int network_sort_threshold = numlanes; +#else + static constexpr int network_sort_threshold = 512; +#endif + static constexpr int partition_unroll_factor = 8; + + using swizzle_ops = avx512_32bit_swizzle_ops; static type_t type_max() { @@ -138,14 +142,18 @@ struct zmm_vector { const auto rev_index = _mm512_set_epi32(NETWORK_32BIT_5); return permutexvar(rev_index, zmm); } - static reg_t bitonic_merge(reg_t x) - { - return bitonic_merge_zmm_32bit>(x); - } static reg_t sort_vec(reg_t x) { return sort_zmm_32bit>(x); } + static reg_t cast_from(__m512i v) + { + return v; + } + static __m512i cast_to(reg_t v) + { + return v; + } }; template <> struct zmm_vector { @@ -154,8 +162,14 @@ struct zmm_vector { using halfreg_t = __m256i; using opmask_t = __mmask16; static const uint8_t numlanes = 16; - static constexpr int network_sort_threshold = 256; - static constexpr int partition_unroll_factor = 2; +#ifdef XSS_MINIMAL_NETWORK_SORT + static constexpr int network_sort_threshold = numlanes; +#else + static constexpr int network_sort_threshold = 512; +#endif + static constexpr int partition_unroll_factor = 8; + + using swizzle_ops = avx512_32bit_swizzle_ops; static type_t type_max() { @@ -255,14 +269,18 @@ struct zmm_vector { const auto rev_index = _mm512_set_epi32(NETWORK_32BIT_5); return permutexvar(rev_index, zmm); } - static reg_t bitonic_merge(reg_t x) - { - return bitonic_merge_zmm_32bit>(x); - } static reg_t sort_vec(reg_t x) { return sort_zmm_32bit>(x); } + static reg_t cast_from(__m512i v) + { + return v; + } + static __m512i cast_to(reg_t v) + { + return v; + } }; template <> struct zmm_vector { @@ -271,8 +289,14 @@ struct zmm_vector { using halfreg_t = __m256; using opmask_t = __mmask16; static const uint8_t numlanes = 16; - static constexpr int network_sort_threshold = 256; - static constexpr int partition_unroll_factor = 2; +#ifdef XSS_MINIMAL_NETWORK_SORT + static constexpr int network_sort_threshold = numlanes; +#else + static constexpr int network_sort_threshold = 512; +#endif + static constexpr int partition_unroll_factor = 8; + + using swizzle_ops = avx512_32bit_swizzle_ops; static type_t type_max() { @@ -386,14 +410,18 @@ struct zmm_vector { const auto rev_index = _mm512_set_epi32(NETWORK_32BIT_5); return permutexvar(rev_index, zmm); } - static reg_t bitonic_merge(reg_t x) - { - return bitonic_merge_zmm_32bit>(x); - } static reg_t sort_vec(reg_t x) { return sort_zmm_32bit>(x); } + static reg_t cast_from(__m512i v) + { + return _mm512_castsi512_ps(v); + } + static __m512i cast_to(reg_t v) + { + return _mm512_castps_si512(v); + } }; /* @@ -446,31 +474,83 @@ X86_SIMD_SORT_INLINE reg_t sort_zmm_32bit(reg_t zmm) return zmm; } -// Assumes zmm is bitonic and performs a recursive half cleaner -template -X86_SIMD_SORT_INLINE reg_t bitonic_merge_zmm_32bit(reg_t zmm) -{ - // 1) half_cleaner[16]: compare 1-9, 2-10, 3-11 etc .. - zmm = cmp_merge( - zmm, - vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_7), zmm), - 0xFF00); - // 2) half_cleaner[8]: compare 1-5, 2-6, 3-7 etc .. 
- zmm = cmp_merge( - zmm, - vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_6), zmm), - 0xF0F0); - // 3) half_cleaner[4] - zmm = cmp_merge( - zmm, - vtype::template shuffle(zmm), - 0xCCCC); - // 3) half_cleaner[1] - zmm = cmp_merge( - zmm, - vtype::template shuffle(zmm), - 0xAAAA); - return zmm; -} +struct avx512_32bit_swizzle_ops { + template + X86_SIMD_SORT_INLINE typename vtype::reg_t swap_n(typename vtype::reg_t reg) + { + __m512i v = vtype::cast_to(reg); + + if constexpr (scale == 2) { + v = _mm512_shuffle_epi32(v, (_MM_PERM_ENUM)0b10110001); + } + else if constexpr (scale == 4) { + v = _mm512_shuffle_epi32(v, (_MM_PERM_ENUM)0b01001110); + } + else if constexpr (scale == 8) { + v = _mm512_shuffle_i64x2(v, v, 0b10110001); + } + else if constexpr (scale == 16) { + v = _mm512_shuffle_i64x2(v, v, 0b01001110); + } + else { + static_assert(scale == -1, "should not be reached"); + } + + return vtype::cast_from(v); + } + + template + X86_SIMD_SORT_INLINE typename vtype::reg_t + reverse_n(typename vtype::reg_t reg) + { + __m512i v = vtype::cast_to(reg); + + if constexpr (scale == 2) { return swap_n(reg); } + else if constexpr (scale == 4) { + __m512i mask = _mm512_set_epi32( + 12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 0, 1, 2, 3); + v = _mm512_permutexvar_epi32(mask, v); + } + else if constexpr (scale == 8) { + __m512i mask = _mm512_set_epi32( + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7); + v = _mm512_permutexvar_epi32(mask, v); + } + else if constexpr (scale == 16) { + return vtype::reverse(reg); + } + else { + static_assert(scale == -1, "should not be reached"); + } + + return vtype::cast_from(v); + } + + template + X86_SIMD_SORT_INLINE typename vtype::reg_t + merge_n(typename vtype::reg_t reg, typename vtype::reg_t other) + { + __m512i v1 = vtype::cast_to(reg); + __m512i v2 = vtype::cast_to(other); + + if constexpr (scale == 2) { + v1 = _mm512_mask_blend_epi32(0b0101010101010101, v1, v2); + } + else if constexpr (scale == 4) { + v1 = _mm512_mask_blend_epi32(0b0011001100110011, v1, v2); + } + else if constexpr (scale == 8) { + v1 = _mm512_mask_blend_epi32(0b0000111100001111, v1, v2); + } + else if constexpr (scale == 16) { + v1 = _mm512_mask_blend_epi32(0b0000000011111111, v1, v2); + } + else { + static_assert(scale == -1, "should not be reached"); + } + + return vtype::cast_from(v1); + } +}; #endif //AVX512_QSORT_32BIT diff --git a/src/avx512-64bit-argsort.hpp b/src/avx512-64bit-argsort.hpp index 4571a469..b5072ccc 100644 --- a/src/avx512-64bit-argsort.hpp +++ b/src/avx512-64bit-argsort.hpp @@ -413,15 +413,14 @@ template X86_SIMD_SORT_INLINE void avx512_argselect(T *arr, int64_t *arg, arrsize_t k, arrsize_t arrsize) { - avx512_argselect(arr, reinterpret_cast(arg), k, arrsize); + avx512_argselect(arr, reinterpret_cast(arg), k, arrsize); } template X86_SIMD_SORT_INLINE void avx512_argsort(T *arr, int64_t *arg, arrsize_t arrsize) { - avx512_argsort(arr, reinterpret_cast(arg), arrsize); + avx512_argsort(arr, reinterpret_cast(arg), arrsize); } - #endif // AVX512_ARGSORT_64BIT diff --git a/src/avx512-64bit-common.h b/src/avx512-64bit-common.h index f9018231..387d8b57 100644 --- a/src/avx512-64bit-common.h +++ b/src/avx512-64bit-common.h @@ -22,8 +22,7 @@ template X86_SIMD_SORT_INLINE reg_t sort_zmm_64bit(reg_t zmm); -template -X86_SIMD_SORT_INLINE reg_t bitonic_merge_zmm_64bit(reg_t zmm); +struct avx512_64bit_swizzle_ops; template <> struct ymm_vector { @@ -485,9 +484,15 @@ struct zmm_vector { using halfreg_t = __m512i; using opmask_t = __mmask8; static const uint8_t numlanes = 8; 
+#ifdef XSS_MINIMAL_NETWORK_SORT + static constexpr int network_sort_threshold = numlanes; +#else static constexpr int network_sort_threshold = 256; +#endif static constexpr int partition_unroll_factor = 8; + using swizzle_ops = avx512_64bit_swizzle_ops; + static type_t type_max() { return X86_SIMD_SORT_MAX_INT64; @@ -618,14 +623,18 @@ struct zmm_vector { const regi_t rev_index = seti(NETWORK_64BIT_2); return permutexvar(rev_index, zmm); } - static reg_t bitonic_merge(reg_t x) - { - return bitonic_merge_zmm_64bit>(x); - } static reg_t sort_vec(reg_t x) { return sort_zmm_64bit>(x); } + static reg_t cast_from(__m512i v) + { + return v; + } + static __m512i cast_to(reg_t v) + { + return v; + } }; template <> struct zmm_vector { @@ -635,9 +644,15 @@ struct zmm_vector { using halfreg_t = __m512i; using opmask_t = __mmask8; static const uint8_t numlanes = 8; +#ifdef XSS_MINIMAL_NETWORK_SORT + static constexpr int network_sort_threshold = numlanes; +#else static constexpr int network_sort_threshold = 256; +#endif static constexpr int partition_unroll_factor = 8; + using swizzle_ops = avx512_64bit_swizzle_ops; + static type_t type_max() { return X86_SIMD_SORT_MAX_UINT64; @@ -760,14 +775,18 @@ struct zmm_vector { const regi_t rev_index = seti(NETWORK_64BIT_2); return permutexvar(rev_index, zmm); } - static reg_t bitonic_merge(reg_t x) - { - return bitonic_merge_zmm_64bit>(x); - } static reg_t sort_vec(reg_t x) { return sort_zmm_64bit>(x); } + static reg_t cast_from(__m512i v) + { + return v; + } + static __m512i cast_to(reg_t v) + { + return v; + } }; template <> struct zmm_vector { @@ -777,9 +796,15 @@ struct zmm_vector { using halfreg_t = __m512d; using opmask_t = __mmask8; static const uint8_t numlanes = 8; +#ifdef XSS_MINIMAL_NETWORK_SORT + static constexpr int network_sort_threshold = numlanes; +#else static constexpr int network_sort_threshold = 256; +#endif static constexpr int partition_unroll_factor = 8; + using swizzle_ops = avx512_64bit_swizzle_ops; + static type_t type_max() { return X86_SIMD_SORT_INFINITY; @@ -908,14 +933,18 @@ struct zmm_vector { const regi_t rev_index = seti(NETWORK_64BIT_2); return permutexvar(rev_index, zmm); } - static reg_t bitonic_merge(reg_t x) - { - return bitonic_merge_zmm_64bit>(x); - } static reg_t sort_vec(reg_t x) { return sort_zmm_64bit>(x); } + static reg_t cast_from(__m512i v) + { + return _mm512_castsi512_pd(v); + } + static __m512i cast_to(reg_t v) + { + return _mm512_castpd_si512(v); + } }; /* @@ -940,24 +969,71 @@ X86_SIMD_SORT_INLINE reg_t sort_zmm_64bit(reg_t zmm) return zmm; } -// Assumes zmm is bitonic and performs a recursive half cleaner -template -X86_SIMD_SORT_INLINE reg_t bitonic_merge_zmm_64bit(reg_t zmm) -{ - // 1) half_cleaner[8]: compare 0-4, 1-5, 2-6, 3-7 - zmm = cmp_merge( - zmm, - vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_4), zmm), - 0xF0); - // 2) half_cleaner[4] - zmm = cmp_merge( - zmm, - vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), zmm), - 0xCC); - // 3) half_cleaner[1] - zmm = cmp_merge( - zmm, vtype::template shuffle(zmm), 0xAA); - return zmm; -} +struct avx512_64bit_swizzle_ops { + template + X86_SIMD_SORT_INLINE typename vtype::reg_t swap_n(typename vtype::reg_t reg) + { + __m512i v = vtype::cast_to(reg); + + if constexpr (scale == 2) { + v = _mm512_shuffle_epi32(v, (_MM_PERM_ENUM)0b01001110); + } + else if constexpr (scale == 4) { + v = _mm512_shuffle_i64x2(v, v, 0b10110001); + } + else if constexpr (scale == 8) { + v = _mm512_shuffle_i64x2(v, v, 0b01001110); + } + else { + static_assert(scale == -1, 
"should not be reached"); + } + + return vtype::cast_from(v); + } + + template + X86_SIMD_SORT_INLINE typename vtype::reg_t + reverse_n(typename vtype::reg_t reg) + { + __m512i v = vtype::cast_to(reg); + + if constexpr (scale == 2) { return swap_n(reg); } + else if constexpr (scale == 4) { + constexpr uint64_t mask = 0b00011011; + v = _mm512_permutex_epi64(v, mask); + } + else if constexpr (scale == 8) { + return vtype::reverse(reg); + } + else { + static_assert(scale == -1, "should not be reached"); + } + + return vtype::cast_from(v); + } + + template + X86_SIMD_SORT_INLINE typename vtype::reg_t + merge_n(typename vtype::reg_t reg, typename vtype::reg_t other) + { + __m512i v1 = vtype::cast_to(reg); + __m512i v2 = vtype::cast_to(other); + + if constexpr (scale == 2) { + v1 = _mm512_mask_blend_epi64(0b01010101, v1, v2); + } + else if constexpr (scale == 4) { + v1 = _mm512_mask_blend_epi64(0b00110011, v1, v2); + } + else if constexpr (scale == 8) { + v1 = _mm512_mask_blend_epi64(0b00001111, v1, v2); + } + else { + static_assert(scale == -1, "should not be reached"); + } + + return vtype::cast_from(v1); + } +}; #endif diff --git a/src/avx512-64bit-qsort.hpp b/src/avx512-64bit-qsort.hpp index 2dae1622..1d15ef55 100644 --- a/src/avx512-64bit-qsort.hpp +++ b/src/avx512-64bit-qsort.hpp @@ -8,6 +8,5 @@ #define AVX512_QSORT_64BIT #include "avx512-64bit-common.h" -#include "xss-network-qsort.hpp" #endif // AVX512_QSORT_64BIT diff --git a/src/avx512-common-qsort.h b/src/avx512-common-qsort.h index 6ea13ce2..b969a069 100644 --- a/src/avx512-common-qsort.h +++ b/src/avx512-common-qsort.h @@ -85,7 +85,7 @@ #define X86_SIMD_SORT_FINLINE static __attribute__((always_inline)) #elif defined(__GNUC__) #define X86_SIMD_SORT_INLINE static inline -#define X86_SIMD_SORT_FINLINE static __attribute__((always_inline)) +#define X86_SIMD_SORT_FINLINE static inline __attribute__((always_inline)) #define LIKELY(x) __builtin_expect((x), 1) #define UNLIKELY(x) __builtin_expect((x), 0) #else @@ -103,6 +103,9 @@ typedef size_t arrsize_t; +#include "xss-pivot-selection.hpp" +#include "xss-network-qsort.hpp" + template struct zmm_vector; @@ -203,9 +206,7 @@ X86_SIMD_SORT_INLINE arrsize_t move_nans_to_end_of_array(T *arr, arrsize_t size) } } /* Haven't checked for nan when ii == jj */ - if (is_a_nan(arr[ii])) { - count++; - } + if (is_a_nan(arr[ii])) { count++; } return size - count - 1; } @@ -240,23 +241,23 @@ X86_SIMD_SORT_INLINE reg_t cmp_merge(reg_t in1, reg_t in2, opmask_t mask) * number of elements that are greater than or equal to the pivot. 
*/ template -X86_SIMD_SORT_INLINE int32_t partition_vec(type_t *arr, - arrsize_t left, - arrsize_t right, - const reg_t curr_vec, - const reg_t pivot_vec, - reg_t *smallest_vec, - reg_t *biggest_vec) +X86_SIMD_SORT_INLINE arrsize_t partition_vec(type_t *l_store, + type_t *r_store, + const reg_t curr_vec, + const reg_t pivot_vec, + reg_t &smallest_vec, + reg_t &biggest_vec) { - /* which elements are larger than or equal to the pivot */ typename vtype::opmask_t ge_mask = vtype::ge(curr_vec, pivot_vec); - int32_t amount_ge_pivot = _mm_popcnt_u32((int32_t)ge_mask); - vtype::mask_compressstoreu( - arr + left, vtype::knot_opmask(ge_mask), curr_vec); + arrsize_t amount_ge_pivot = _mm_popcnt_u64(ge_mask); + + vtype::mask_compressstoreu(l_store, vtype::knot_opmask(ge_mask), curr_vec); vtype::mask_compressstoreu( - arr + right - amount_ge_pivot, ge_mask, curr_vec); - *smallest_vec = vtype::min(curr_vec, *smallest_vec); - *biggest_vec = vtype::max(curr_vec, *biggest_vec); + r_store + vtype::numlanes - amount_ge_pivot, ge_mask, curr_vec); + + smallest_vec = vtype::min(curr_vec, smallest_vec); + biggest_vec = vtype::max(curr_vec, biggest_vec); + return amount_ge_pivot; } /* @@ -293,23 +294,27 @@ X86_SIMD_SORT_INLINE arrsize_t partition_avx512(type_t *arr, if (right - left == vtype::numlanes) { reg_t vec = vtype::loadu(arr + left); - int32_t amount_ge_pivot = partition_vec(arr, - left, - left + vtype::numlanes, - vec, - pivot_vec, - &min_vec, - &max_vec); + arrsize_t unpartitioned = right - left - vtype::numlanes; + arrsize_t l_store = left; + + arrsize_t amount_ge_pivot + = partition_vec(arr + l_store, + arr + l_store + unpartitioned, + vec, + pivot_vec, + min_vec, + max_vec); + l_store += (vtype::numlanes - amount_ge_pivot); *smallest = vtype::reducemin(min_vec); *biggest = vtype::reducemax(max_vec); - return left + (vtype::numlanes - amount_ge_pivot); + return l_store; } // first and last vtype::numlanes values are partitioned at the end reg_t vec_left = vtype::loadu(arr + left); reg_t vec_right = vtype::loadu(arr + (right - vtype::numlanes)); // store points of the vectors - arrsize_t r_store = right - vtype::numlanes; + arrsize_t unpartitioned = right - left - vtype::numlanes; arrsize_t l_store = left; // indices for loading the elements left += vtype::numlanes; @@ -321,7 +326,8 @@ X86_SIMD_SORT_INLINE arrsize_t partition_avx512(type_t *arr, * then next elements are loaded from the right side, * otherwise from the left side */ - if ((r_store + vtype::numlanes) - right < left - l_store) { + if ((l_store + unpartitioned + vtype::numlanes) - right + < left - l_store) { right -= vtype::numlanes; curr_vec = vtype::loadu(arr + right); } @@ -330,36 +336,37 @@ X86_SIMD_SORT_INLINE arrsize_t partition_avx512(type_t *arr, left += vtype::numlanes; } // partition the current vector and save it on both sides of the array - int32_t amount_ge_pivot - = partition_vec(arr, - l_store, - r_store + vtype::numlanes, + arrsize_t amount_ge_pivot + = partition_vec(arr + l_store, + arr + l_store + unpartitioned, curr_vec, pivot_vec, - &min_vec, - &max_vec); - ; - r_store -= amount_ge_pivot; + min_vec, + max_vec); l_store += (vtype::numlanes - amount_ge_pivot); + unpartitioned -= vtype::numlanes; } /* partition and save vec_left and vec_right */ - int32_t amount_ge_pivot = partition_vec(arr, - l_store, - r_store + vtype::numlanes, - vec_left, - pivot_vec, - &min_vec, - &max_vec); + arrsize_t amount_ge_pivot + = partition_vec(arr + l_store, + arr + l_store + unpartitioned, + vec_left, + pivot_vec, + min_vec, + max_vec); 
l_store += (vtype::numlanes - amount_ge_pivot); - amount_ge_pivot = partition_vec(arr, - l_store, - l_store + vtype::numlanes, + unpartitioned -= vtype::numlanes; + + amount_ge_pivot = partition_vec(arr + l_store, + arr + l_store + unpartitioned, vec_right, pivot_vec, - &min_vec, - &max_vec); + min_vec, + max_vec); l_store += (vtype::numlanes - amount_ge_pivot); + unpartitioned -= vtype::numlanes; + *smallest = vtype::reducemin(min_vec); *biggest = vtype::reducemax(max_vec); return l_store; @@ -380,13 +387,14 @@ X86_SIMD_SORT_INLINE arrsize_t partition_avx512_unrolled(type_t *arr, arr, left, right, pivot, smallest, biggest); } - if (right - left <= 2 * num_unroll * vtype::numlanes) { + /* Use regular partition_avx512 for smaller arrays */ + if (right - left < 3 * num_unroll * vtype::numlanes) { return partition_avx512( arr, left, right, pivot, smallest, biggest); } - /* make array length divisible by 8*vtype::numlanes , shortening the array */ - for (int32_t i = ((right - left) % (num_unroll * vtype::numlanes)); i > 0; - --i) { + + /* make array length divisible by vtype::numlanes, shortening the array */ + for (int32_t i = ((right - left) % (vtype::numlanes)); i > 0; --i) { *smallest = std::min(*smallest, arr[left], comparison_func); *biggest = std::max(*biggest, arr[left], comparison_func); if (!comparison_func(arr[left], pivot)) { @@ -397,16 +405,28 @@ X86_SIMD_SORT_INLINE arrsize_t partition_avx512_unrolled(type_t *arr, } } - if (left == right) - return left; /* less than vtype::numlanes elements in the array */ + arrsize_t unpartitioned = right - left - vtype::numlanes; + arrsize_t l_store = left; using reg_t = typename vtype::reg_t; reg_t pivot_vec = vtype::set1(pivot); reg_t min_vec = vtype::set1(*smallest); reg_t max_vec = vtype::set1(*biggest); - // We will now have atleast 16 registers worth of data to process: - // left and right vtype::numlanes values are partitioned at the end + /* Calculate and load more registers to make the rest of the array a + * multiple of num_unroll. These registers will be partitioned at the very + * end. */ + int vecsToPartition = ((right - left) / vtype::numlanes) % num_unroll; + reg_t vec_align[num_unroll]; + for (int i = 0; i < vecsToPartition; i++) { + vec_align[i] = vtype::loadu(arr + left + i * vtype::numlanes); + } + left += vecsToPartition * vtype::numlanes; + + /* We will now have atleast 3*num_unroll registers worth of data to + * process. Load left and right vtype::numlanes*num_unroll values into + * registers to make space for in-place parition. 
The vec_left and + * vec_right registers are partitioned at the end */ reg_t vec_left[num_unroll], vec_right[num_unroll]; X86_SIMD_SORT_UNROLL_LOOP(8) for (int ii = 0; ii < num_unroll; ++ii) { @@ -414,10 +434,7 @@ X86_SIMD_SORT_INLINE arrsize_t partition_avx512_unrolled(type_t *arr, vec_right[ii] = vtype::loadu( arr + (right - vtype::numlanes * (num_unroll - ii))); } - // store points of the vectors - arrsize_t r_store = right - vtype::numlanes; - arrsize_t l_store = left; - // indices for loading the elements + /* indices for loading the elements */ left += num_unroll * vtype::numlanes; right -= num_unroll * vtype::numlanes; while (right - left != 0) { @@ -427,63 +444,83 @@ X86_SIMD_SORT_INLINE arrsize_t partition_avx512_unrolled(type_t *arr, * then next elements are loaded from the right side, * otherwise from the left side */ - if ((r_store + vtype::numlanes) - right < left - l_store) { + if ((l_store + unpartitioned + vtype::numlanes) - right + < left - l_store) { right -= num_unroll * vtype::numlanes; X86_SIMD_SORT_UNROLL_LOOP(8) for (int ii = 0; ii < num_unroll; ++ii) { curr_vec[ii] = vtype::loadu(arr + right + ii * vtype::numlanes); + _mm_prefetch(arr + right + ii * vtype::numlanes + - num_unroll * vtype::numlanes, + _MM_HINT_T0); } } else { X86_SIMD_SORT_UNROLL_LOOP(8) for (int ii = 0; ii < num_unroll; ++ii) { curr_vec[ii] = vtype::loadu(arr + left + ii * vtype::numlanes); + _mm_prefetch(arr + left + ii * vtype::numlanes + + num_unroll * vtype::numlanes, + _MM_HINT_T0); } left += num_unroll * vtype::numlanes; } - // partition the current vector and save it on both sides of the array + /* partition the current vector and save it on both sides of the array + * */ X86_SIMD_SORT_UNROLL_LOOP(8) for (int ii = 0; ii < num_unroll; ++ii) { - int32_t amount_ge_pivot - = partition_vec(arr, - l_store, - r_store + vtype::numlanes, + arrsize_t amount_ge_pivot + = partition_vec(arr + l_store, + arr + l_store + unpartitioned, curr_vec[ii], pivot_vec, - &min_vec, - &max_vec); + min_vec, + max_vec); l_store += (vtype::numlanes - amount_ge_pivot); - r_store -= amount_ge_pivot; + unpartitioned -= vtype::numlanes; } } - /* partition and save vec_left[8] and vec_right[8] */ + /* partition and save vec_left[num_unroll] and vec_right[num_unroll] */ X86_SIMD_SORT_UNROLL_LOOP(8) for (int ii = 0; ii < num_unroll; ++ii) { - int32_t amount_ge_pivot - = partition_vec(arr, - l_store, - r_store + vtype::numlanes, + arrsize_t amount_ge_pivot + = partition_vec(arr + l_store, + arr + l_store + unpartitioned, vec_left[ii], pivot_vec, - &min_vec, - &max_vec); + min_vec, + max_vec); l_store += (vtype::numlanes - amount_ge_pivot); - r_store -= amount_ge_pivot; + unpartitioned -= vtype::numlanes; } X86_SIMD_SORT_UNROLL_LOOP(8) for (int ii = 0; ii < num_unroll; ++ii) { - int32_t amount_ge_pivot - = partition_vec(arr, - l_store, - r_store + vtype::numlanes, + arrsize_t amount_ge_pivot + = partition_vec(arr + l_store, + arr + l_store + unpartitioned, vec_right[ii], pivot_vec, - &min_vec, - &max_vec); + min_vec, + max_vec); + l_store += (vtype::numlanes - amount_ge_pivot); + unpartitioned -= vtype::numlanes; + } + + /* partition and save vec_align[vecsToPartition] */ + X86_SIMD_SORT_UNROLL_LOOP(8) + for (int ii = 0; ii < vecsToPartition; ++ii) { + arrsize_t amount_ge_pivot + = partition_vec(arr + l_store, + arr + l_store + unpartitioned, + vec_align[ii], + pivot_vec, + min_vec, + max_vec); l_store += (vtype::numlanes - amount_ge_pivot); - r_store -= amount_ge_pivot; + unpartitioned -= vtype::numlanes; } + *smallest = 
vtype::reducemin(min_vec); *biggest = vtype::reducemax(max_vec); return l_store; @@ -697,140 +734,11 @@ X86_SIMD_SORT_INLINE arrsize_t partition_avx512(type_t1 *keys, return l_store; } -template -X86_SIMD_SORT_INLINE type_t get_pivot_scalar(type_t *arr, - const arrsize_t left, - const arrsize_t right) -{ - constexpr arrsize_t numSamples = vtype::numlanes; - type_t samples[numSamples]; - - arrsize_t delta = (right - left) / numSamples; - - for (int i = 0; i < numSamples; i++) { - samples[i] = arr[left + i * delta]; - } - - auto vec = vtype::loadu(samples); - vec = vtype::sort_vec(vec); - return ((type_t *)&vec)[numSamples / 2]; -} - -template -X86_SIMD_SORT_INLINE type_t get_pivot_16bit(type_t *arr, - const arrsize_t left, - const arrsize_t right) -{ - // median of 32 - arrsize_t size = (right - left) / 32; - type_t vec_arr[32] = {arr[left], - arr[left + size], - arr[left + 2 * size], - arr[left + 3 * size], - arr[left + 4 * size], - arr[left + 5 * size], - arr[left + 6 * size], - arr[left + 7 * size], - arr[left + 8 * size], - arr[left + 9 * size], - arr[left + 10 * size], - arr[left + 11 * size], - arr[left + 12 * size], - arr[left + 13 * size], - arr[left + 14 * size], - arr[left + 15 * size], - arr[left + 16 * size], - arr[left + 17 * size], - arr[left + 18 * size], - arr[left + 19 * size], - arr[left + 20 * size], - arr[left + 21 * size], - arr[left + 22 * size], - arr[left + 23 * size], - arr[left + 24 * size], - arr[left + 25 * size], - arr[left + 26 * size], - arr[left + 27 * size], - arr[left + 28 * size], - arr[left + 29 * size], - arr[left + 30 * size], - arr[left + 31 * size]}; - typename vtype::reg_t rand_vec = vtype::loadu(vec_arr); - typename vtype::reg_t sort = vtype::sort_vec(rand_vec); - return ((type_t *)&sort)[16]; -} - -template -X86_SIMD_SORT_INLINE type_t get_pivot_32bit(type_t *arr, - const arrsize_t left, - const arrsize_t right) -{ - // median of 16 - arrsize_t size = (right - left) / 16; - using reg_t = typename vtype::reg_t; - type_t vec_arr[16] = {arr[left + size], - arr[left + 2 * size], - arr[left + 3 * size], - arr[left + 4 * size], - arr[left + 5 * size], - arr[left + 6 * size], - arr[left + 7 * size], - arr[left + 8 * size], - arr[left + 9 * size], - arr[left + 10 * size], - arr[left + 11 * size], - arr[left + 12 * size], - arr[left + 13 * size], - arr[left + 14 * size], - arr[left + 15 * size], - arr[left + 16 * size]}; - reg_t rand_vec = vtype::loadu(vec_arr); - reg_t sort = vtype::sort_vec(rand_vec); - // pivot will never be a nan, since there are no nan's! - return ((type_t *)&sort)[8]; -} - -template -X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr, - const arrsize_t left, - const arrsize_t right) -{ - // median of 8 - arrsize_t size = (right - left) / 8; - using reg_t = typename vtype::reg_t; - reg_t rand_vec = vtype::set(arr[left + size], - arr[left + 2 * size], - arr[left + 3 * size], - arr[left + 4 * size], - arr[left + 5 * size], - arr[left + 6 * size], - arr[left + 7 * size], - arr[left + 8 * size]); - // pivot will never be a nan, since there are no nan's! 
- reg_t sort = vtype::sort_vec(rand_vec); - return ((type_t *)&sort)[4]; -} - -template -X86_SIMD_SORT_INLINE type_t get_pivot(type_t *arr, - const arrsize_t left, - const arrsize_t right) -{ - if constexpr (vtype::numlanes == 8) - return get_pivot_64bit(arr, left, right); - else if constexpr (vtype::numlanes == 16) - return get_pivot_32bit(arr, left, right); - else if constexpr (vtype::numlanes == 32) - return get_pivot_16bit(arr, left, right); - else - return get_pivot_scalar(arr, left, right); -} - template void sort_n(typename vtype::type_t *arr, int N); template -X86_SIMD_SORT_INLINE void +static void qsort_(type_t *arr, arrsize_t left, arrsize_t right, arrsize_t max_iters) { /* @@ -849,7 +757,7 @@ qsort_(type_t *arr, arrsize_t left, arrsize_t right, arrsize_t max_iters) return; } - type_t pivot = get_pivot(arr, left, right); + type_t pivot = get_pivot_blocks(arr, left, right); type_t smallest = vtype::type_max(); type_t biggest = vtype::type_min(); diff --git a/src/avx512fp16-16bit-qsort.hpp b/src/avx512fp16-16bit-qsort.hpp index 7d0f0a06..94e508f0 100644 --- a/src/avx512fp16-16bit-qsort.hpp +++ b/src/avx512fp16-16bit-qsort.hpp @@ -8,7 +8,6 @@ #define AVX512FP16_QSORT_16BIT #include "avx512-16bit-common.h" -#include "xss-network-qsort.hpp" typedef union { _Float16 f_; @@ -25,6 +24,8 @@ struct zmm_vector<_Float16> { static constexpr int network_sort_threshold = 128; static constexpr int partition_unroll_factor = 0; + using swizzle_ops = avx512_16bit_swizzle_ops; + static __m512i get_network(int index) { return _mm512_loadu_si512(&network[index - 1][0]); @@ -132,14 +133,18 @@ struct zmm_vector<_Float16> { const auto rev_index = get_network(4); return permutexvar(rev_index, zmm); } - static reg_t bitonic_merge(reg_t x) - { - return bitonic_merge_zmm_16bit>(x); - } static reg_t sort_vec(reg_t x) { return sort_zmm_16bit>(x); } + static reg_t cast_from(__m512i v) + { + return _mm512_castsi512_ph(v); + } + static __m512i cast_to(reg_t v) + { + return _mm512_castph_si512(v); + } }; template <> diff --git a/src/xss-network-qsort.hpp b/src/xss-network-qsort.hpp index 67afc2d4..a768a580 100644 --- a/src/xss-network-qsort.hpp +++ b/src/xss-network-qsort.hpp @@ -1,75 +1,148 @@ #ifndef XSS_NETWORK_QSORT #define XSS_NETWORK_QSORT -#include "avx512-common-qsort.h" +#include "xss-optimal-networks.hpp" template -X86_SIMD_SORT_INLINE void bitonic_clean_n_vec(reg_t *regs) +X86_SIMD_SORT_FINLINE void bitonic_sort_n_vec(reg_t *regs) { - X86_SIMD_SORT_UNROLL_LOOP(64) - for (int num = numVecs / 2; num >= 2; num /= 2) { - X86_SIMD_SORT_UNROLL_LOOP(64) - for (int j = 0; j < numVecs; j += num) { - X86_SIMD_SORT_UNROLL_LOOP(64) - for (int i = 0; i < num / 2; i++) { - COEX(regs[i + j], regs[i + j + num / 2]); - } - } + if constexpr (numVecs == 1) { + UNUSED(regs); + return; + } + else if constexpr (numVecs == 2) { + COEX(regs[0], regs[1]); + } + else if constexpr (numVecs == 4) { + optimal_sort_4(regs); + } + else if constexpr (numVecs == 8) { + optimal_sort_8(regs); + } + else if constexpr (numVecs == 16) { + optimal_sort_16(regs); + } + else if constexpr (numVecs == 32) { + optimal_sort_32(regs); + } + else { + static_assert(numVecs == -1, "should not reach here"); } } -template -X86_SIMD_SORT_INLINE void bitonic_merge_n_vec(reg_t *regs) +/* + * Swizzle ops explained: + * swap_n: swap neighbouring blocks of size within block of size + * reg i = [7,6,5,4,3,2,1,0] + * swap_n<2>: = [[6,7],[4,5],[2,3],[0,1]] + * swap_n<4>: = [[5,4,7,6],[1,0,3,2]] + * swap_n<8>: = [[3,2,1,0,7,6,5,4]] + * reverse_n: reverse elements 
within block of size + * reg i = [7,6,5,4,3,2,1,0] + * rev_n<2>: = [[6,7],[4,5],[2,3],[0,1]] + * rev_n<4>: = [[4,5,6,7],[0,1,2,3]] + * rev_n<8>: = [[0,1,2,3,4,5,6,7]] + * merge_n: merge blocks of elements from two regs + * reg b,a = [a,a,a,a,a,a,a,a], [b,b,b,b,b,b,b,b] + * merge_n<2> = [a,b,a,b,a,b,a,b] + * merge_n<4> = [a,a,b,b,a,a,b,b] + * merge_n<8> = [a,a,a,a,b,b,b,b] + */ + +template +X86_SIMD_SORT_FINLINE void internal_merge_n_vec(typename vtype::reg_t *reg) { - // Do the reverse part - if constexpr (numVecs == 2) { - regs[1] = vtype::reverse(regs[1]); - COEX(regs[0], regs[1]); + using reg_t = typename vtype::reg_t; + using swizzle = typename vtype::swizzle_ops; + if constexpr (scale <= 1) { + UNUSED(reg); + return; } - else if constexpr (numVecs > 2) { - // Reverse upper half - X86_SIMD_SORT_UNROLL_LOOP(64) - for (int i = 0; i < numVecs / 2; i++) { - reg_t rev = vtype::reverse(regs[numVecs - i - 1]); - reg_t maxV = vtype::max(regs[i], rev); - reg_t minV = vtype::min(regs[i], rev); - regs[numVecs - i - 1] = vtype::reverse(maxV); - regs[i] = minV; + else { + if constexpr (first) { + // Use reverse then merge + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = 0; i < numVecs; i++) { + reg_t &v = reg[i]; + reg_t rev = swizzle::template reverse_n(v); + COEX(rev, v); + v = swizzle::template merge_n(v, rev); + } + } + else { + // Use swap then merge + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = 0; i < numVecs; i++) { + reg_t &v = reg[i]; + reg_t swap = swizzle::template swap_n(v); + COEX(swap, v); + v = swizzle::template merge_n(v, swap); + } } + internal_merge_n_vec(reg); } +} - // Call cleaner - bitonic_clean_n_vec(regs); +template +X86_SIMD_SORT_FINLINE void merge_substep_n_vec(reg_t *regs) +{ + using swizzle = typename vtype::swizzle_ops; + if constexpr (numVecs <= 1) { + UNUSED(regs); + return; + } - // Now do bitonic_merge + // Reverse upper half of vectors + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = numVecs / 2; i < numVecs; i++) { + regs[i] = swizzle::template reverse_n(regs[i]); + } + // Do compare exchanges X86_SIMD_SORT_UNROLL_LOOP(64) - for (int i = 0; i < numVecs; i++) { - regs[i] = vtype::bitonic_merge(regs[i]); + for (int i = 0; i < numVecs / 2; i++) { + COEX(regs[i], regs[numVecs - 1 - i]); } + + merge_substep_n_vec(regs); + merge_substep_n_vec(regs + numVecs / 2); +} + +template +X86_SIMD_SORT_FINLINE void merge_step_n_vec(reg_t *regs) +{ + // Do cross vector merges + merge_substep_n_vec(regs); + + // Do internal vector merges + internal_merge_n_vec(regs); } template -X86_SIMD_SORT_INLINE void bitonic_fullmerge_n_vec(reg_t *regs) +X86_SIMD_SORT_FINLINE void merge_n_vec(reg_t *regs) { - if constexpr (numPer > numVecs) { + if constexpr (numPer > vtype::numlanes) { UNUSED(regs); return; } else { - X86_SIMD_SORT_UNROLL_LOOP(64) - for (int i = 0; i < numVecs / numPer; i++) { - bitonic_merge_n_vec(regs + i * numPer); - } - bitonic_fullmerge_n_vec(regs); + merge_step_n_vec(regs); + merge_n_vec(regs); } } template X86_SIMD_SORT_INLINE void sort_n_vec(typename vtype::type_t *arr, int N) { + static_assert(numVecs > 0, "numVecs should be > 0"); if constexpr (numVecs > 1) { if (N * 2 <= numVecs * vtype::numlanes) { sort_n_vec(arr, N); @@ -101,14 +174,13 @@ X86_SIMD_SORT_INLINE void sort_n_vec(typename vtype::type_t *arr, int N) vtype::zmm_max(), ioMasks[j], arr + i * vtype::numlanes); } - // Sort each loaded vector - X86_SIMD_SORT_UNROLL_LOOP(64) - for (int i = 0; i < numVecs; i++) { - vecs[i] = vtype::sort_vec(vecs[i]); - } + /* Run the initial sorting network to sort the columns of 
the [numVecs x + * num_lanes] matrix + */ + bitonic_sort_n_vec(vecs); - // Run the full merger - bitonic_fullmerge_n_vec(&vecs[0]); + // Merge the vectors using bitonic merging networks + merge_n_vec(vecs); // Unmasked part of the store X86_SIMD_SORT_UNROLL_LOOP(64) @@ -133,5 +205,4 @@ X86_SIMD_SORT_INLINE void sort_n(typename vtype::type_t *arr, int N) sort_n_vec(arr, N); } - #endif diff --git a/src/xss-optimal-networks.hpp b/src/xss-optimal-networks.hpp new file mode 100644 index 00000000..3dfa5281 --- /dev/null +++ b/src/xss-optimal-networks.hpp @@ -0,0 +1,320 @@ +// All of these sources files are generated from the optimal networks described in +// https://bertdobbelaere.github.io/sorting_networks.html + +template +X86_SIMD_SORT_FINLINE void optimal_sort_4(reg_t *vecs) +{ + COEX(vecs[0], vecs[2]); + COEX(vecs[1], vecs[3]); + + COEX(vecs[0], vecs[1]); + COEX(vecs[2], vecs[3]); + + COEX(vecs[1], vecs[2]); +} + +template +X86_SIMD_SORT_FINLINE void optimal_sort_8(reg_t *vecs) +{ + COEX(vecs[0], vecs[2]); + COEX(vecs[1], vecs[3]); + COEX(vecs[4], vecs[6]); + COEX(vecs[5], vecs[7]); + + COEX(vecs[0], vecs[4]); + COEX(vecs[1], vecs[5]); + COEX(vecs[2], vecs[6]); + COEX(vecs[3], vecs[7]); + + COEX(vecs[0], vecs[1]); + COEX(vecs[2], vecs[3]); + COEX(vecs[4], vecs[5]); + COEX(vecs[6], vecs[7]); + + COEX(vecs[2], vecs[4]); + COEX(vecs[3], vecs[5]); + + COEX(vecs[1], vecs[4]); + COEX(vecs[3], vecs[6]); + + COEX(vecs[1], vecs[2]); + COEX(vecs[3], vecs[4]); + COEX(vecs[5], vecs[6]); +} + +template +X86_SIMD_SORT_FINLINE void optimal_sort_16(reg_t *vecs) +{ + COEX(vecs[0], vecs[13]); + COEX(vecs[1], vecs[12]); + COEX(vecs[2], vecs[15]); + COEX(vecs[3], vecs[14]); + COEX(vecs[4], vecs[8]); + COEX(vecs[5], vecs[6]); + COEX(vecs[7], vecs[11]); + COEX(vecs[9], vecs[10]); + + COEX(vecs[0], vecs[5]); + COEX(vecs[1], vecs[7]); + COEX(vecs[2], vecs[9]); + COEX(vecs[3], vecs[4]); + COEX(vecs[6], vecs[13]); + COEX(vecs[8], vecs[14]); + COEX(vecs[10], vecs[15]); + COEX(vecs[11], vecs[12]); + + COEX(vecs[0], vecs[1]); + COEX(vecs[2], vecs[3]); + COEX(vecs[4], vecs[5]); + COEX(vecs[6], vecs[8]); + COEX(vecs[7], vecs[9]); + COEX(vecs[10], vecs[11]); + COEX(vecs[12], vecs[13]); + COEX(vecs[14], vecs[15]); + + COEX(vecs[0], vecs[2]); + COEX(vecs[1], vecs[3]); + COEX(vecs[4], vecs[10]); + COEX(vecs[5], vecs[11]); + COEX(vecs[6], vecs[7]); + COEX(vecs[8], vecs[9]); + COEX(vecs[12], vecs[14]); + COEX(vecs[13], vecs[15]); + + COEX(vecs[1], vecs[2]); + COEX(vecs[3], vecs[12]); + COEX(vecs[4], vecs[6]); + COEX(vecs[5], vecs[7]); + COEX(vecs[8], vecs[10]); + COEX(vecs[9], vecs[11]); + COEX(vecs[13], vecs[14]); + + COEX(vecs[1], vecs[4]); + COEX(vecs[2], vecs[6]); + COEX(vecs[5], vecs[8]); + COEX(vecs[7], vecs[10]); + COEX(vecs[9], vecs[13]); + COEX(vecs[11], vecs[14]); + + COEX(vecs[2], vecs[4]); + COEX(vecs[3], vecs[6]); + COEX(vecs[9], vecs[12]); + COEX(vecs[11], vecs[13]); + + COEX(vecs[3], vecs[5]); + COEX(vecs[6], vecs[8]); + COEX(vecs[7], vecs[9]); + COEX(vecs[10], vecs[12]); + + COEX(vecs[3], vecs[4]); + COEX(vecs[5], vecs[6]); + COEX(vecs[7], vecs[8]); + COEX(vecs[9], vecs[10]); + COEX(vecs[11], vecs[12]); + + COEX(vecs[6], vecs[7]); + COEX(vecs[8], vecs[9]); +} + +template +X86_SIMD_SORT_FINLINE void optimal_sort_32(reg_t *vecs) +{ + COEX(vecs[0], vecs[1]); + COEX(vecs[2], vecs[3]); + COEX(vecs[4], vecs[5]); + COEX(vecs[6], vecs[7]); + COEX(vecs[8], vecs[9]); + COEX(vecs[10], vecs[11]); + COEX(vecs[12], vecs[13]); + COEX(vecs[14], vecs[15]); + COEX(vecs[16], vecs[17]); + COEX(vecs[18], vecs[19]); + COEX(vecs[20], 
vecs[21]); + COEX(vecs[22], vecs[23]); + COEX(vecs[24], vecs[25]); + COEX(vecs[26], vecs[27]); + COEX(vecs[28], vecs[29]); + COEX(vecs[30], vecs[31]); + + COEX(vecs[0], vecs[2]); + COEX(vecs[1], vecs[3]); + COEX(vecs[4], vecs[6]); + COEX(vecs[5], vecs[7]); + COEX(vecs[8], vecs[10]); + COEX(vecs[9], vecs[11]); + COEX(vecs[12], vecs[14]); + COEX(vecs[13], vecs[15]); + COEX(vecs[16], vecs[18]); + COEX(vecs[17], vecs[19]); + COEX(vecs[20], vecs[22]); + COEX(vecs[21], vecs[23]); + COEX(vecs[24], vecs[26]); + COEX(vecs[25], vecs[27]); + COEX(vecs[28], vecs[30]); + COEX(vecs[29], vecs[31]); + + COEX(vecs[0], vecs[4]); + COEX(vecs[1], vecs[5]); + COEX(vecs[2], vecs[6]); + COEX(vecs[3], vecs[7]); + COEX(vecs[8], vecs[12]); + COEX(vecs[9], vecs[13]); + COEX(vecs[10], vecs[14]); + COEX(vecs[11], vecs[15]); + COEX(vecs[16], vecs[20]); + COEX(vecs[17], vecs[21]); + COEX(vecs[18], vecs[22]); + COEX(vecs[19], vecs[23]); + COEX(vecs[24], vecs[28]); + COEX(vecs[25], vecs[29]); + COEX(vecs[26], vecs[30]); + COEX(vecs[27], vecs[31]); + + COEX(vecs[0], vecs[8]); + COEX(vecs[1], vecs[9]); + COEX(vecs[2], vecs[10]); + COEX(vecs[3], vecs[11]); + COEX(vecs[4], vecs[12]); + COEX(vecs[5], vecs[13]); + COEX(vecs[6], vecs[14]); + COEX(vecs[7], vecs[15]); + COEX(vecs[16], vecs[24]); + COEX(vecs[17], vecs[25]); + COEX(vecs[18], vecs[26]); + COEX(vecs[19], vecs[27]); + COEX(vecs[20], vecs[28]); + COEX(vecs[21], vecs[29]); + COEX(vecs[22], vecs[30]); + COEX(vecs[23], vecs[31]); + + COEX(vecs[0], vecs[16]); + COEX(vecs[1], vecs[8]); + COEX(vecs[2], vecs[4]); + COEX(vecs[3], vecs[12]); + COEX(vecs[5], vecs[10]); + COEX(vecs[6], vecs[9]); + COEX(vecs[7], vecs[14]); + COEX(vecs[11], vecs[13]); + COEX(vecs[15], vecs[31]); + COEX(vecs[17], vecs[24]); + COEX(vecs[18], vecs[20]); + COEX(vecs[19], vecs[28]); + COEX(vecs[21], vecs[26]); + COEX(vecs[22], vecs[25]); + COEX(vecs[23], vecs[30]); + COEX(vecs[27], vecs[29]); + + COEX(vecs[1], vecs[2]); + COEX(vecs[3], vecs[5]); + COEX(vecs[4], vecs[8]); + COEX(vecs[6], vecs[22]); + COEX(vecs[7], vecs[11]); + COEX(vecs[9], vecs[25]); + COEX(vecs[10], vecs[12]); + COEX(vecs[13], vecs[14]); + COEX(vecs[17], vecs[18]); + COEX(vecs[19], vecs[21]); + COEX(vecs[20], vecs[24]); + COEX(vecs[23], vecs[27]); + COEX(vecs[26], vecs[28]); + COEX(vecs[29], vecs[30]); + + COEX(vecs[1], vecs[17]); + COEX(vecs[2], vecs[18]); + COEX(vecs[3], vecs[19]); + COEX(vecs[4], vecs[20]); + COEX(vecs[5], vecs[10]); + COEX(vecs[7], vecs[23]); + COEX(vecs[8], vecs[24]); + COEX(vecs[11], vecs[27]); + COEX(vecs[12], vecs[28]); + COEX(vecs[13], vecs[29]); + COEX(vecs[14], vecs[30]); + COEX(vecs[21], vecs[26]); + + COEX(vecs[3], vecs[17]); + COEX(vecs[4], vecs[16]); + COEX(vecs[5], vecs[21]); + COEX(vecs[6], vecs[18]); + COEX(vecs[7], vecs[9]); + COEX(vecs[8], vecs[20]); + COEX(vecs[10], vecs[26]); + COEX(vecs[11], vecs[23]); + COEX(vecs[13], vecs[25]); + COEX(vecs[14], vecs[28]); + COEX(vecs[15], vecs[27]); + COEX(vecs[22], vecs[24]); + + COEX(vecs[1], vecs[4]); + COEX(vecs[3], vecs[8]); + COEX(vecs[5], vecs[16]); + COEX(vecs[7], vecs[17]); + COEX(vecs[9], vecs[21]); + COEX(vecs[10], vecs[22]); + COEX(vecs[11], vecs[19]); + COEX(vecs[12], vecs[20]); + COEX(vecs[14], vecs[24]); + COEX(vecs[15], vecs[26]); + COEX(vecs[23], vecs[28]); + COEX(vecs[27], vecs[30]); + + COEX(vecs[2], vecs[5]); + COEX(vecs[7], vecs[8]); + COEX(vecs[9], vecs[18]); + COEX(vecs[11], vecs[17]); + COEX(vecs[12], vecs[16]); + COEX(vecs[13], vecs[22]); + COEX(vecs[14], vecs[20]); + COEX(vecs[15], vecs[19]); + COEX(vecs[23], vecs[24]); + COEX(vecs[26], 
vecs[29]); + + COEX(vecs[2], vecs[4]); + COEX(vecs[6], vecs[12]); + COEX(vecs[9], vecs[16]); + COEX(vecs[10], vecs[11]); + COEX(vecs[13], vecs[17]); + COEX(vecs[14], vecs[18]); + COEX(vecs[15], vecs[22]); + COEX(vecs[19], vecs[25]); + COEX(vecs[20], vecs[21]); + COEX(vecs[27], vecs[29]); + + COEX(vecs[5], vecs[6]); + COEX(vecs[8], vecs[12]); + COEX(vecs[9], vecs[10]); + COEX(vecs[11], vecs[13]); + COEX(vecs[14], vecs[16]); + COEX(vecs[15], vecs[17]); + COEX(vecs[18], vecs[20]); + COEX(vecs[19], vecs[23]); + COEX(vecs[21], vecs[22]); + COEX(vecs[25], vecs[26]); + + COEX(vecs[3], vecs[5]); + COEX(vecs[6], vecs[7]); + COEX(vecs[8], vecs[9]); + COEX(vecs[10], vecs[12]); + COEX(vecs[11], vecs[14]); + COEX(vecs[13], vecs[16]); + COEX(vecs[15], vecs[18]); + COEX(vecs[17], vecs[20]); + COEX(vecs[19], vecs[21]); + COEX(vecs[22], vecs[23]); + COEX(vecs[24], vecs[25]); + COEX(vecs[26], vecs[28]); + + COEX(vecs[3], vecs[4]); + COEX(vecs[5], vecs[6]); + COEX(vecs[7], vecs[8]); + COEX(vecs[9], vecs[10]); + COEX(vecs[11], vecs[12]); + COEX(vecs[13], vecs[14]); + COEX(vecs[15], vecs[16]); + COEX(vecs[17], vecs[18]); + COEX(vecs[19], vecs[20]); + COEX(vecs[21], vecs[22]); + COEX(vecs[23], vecs[24]); + COEX(vecs[25], vecs[26]); + COEX(vecs[27], vecs[28]); +} diff --git a/src/xss-pivot-selection.hpp b/src/xss-pivot-selection.hpp new file mode 100644 index 00000000..15fe36a2 --- /dev/null +++ b/src/xss-pivot-selection.hpp @@ -0,0 +1,156 @@ +template +X86_SIMD_SORT_INLINE void COEX(mm_t &a, mm_t &b); + +template +X86_SIMD_SORT_INLINE type_t get_pivot_16bit(type_t *arr, + const arrsize_t left, + const arrsize_t right) +{ + // median of 32 + arrsize_t size = (right - left) / 32; + type_t vec_arr[32] = {arr[left], + arr[left + size], + arr[left + 2 * size], + arr[left + 3 * size], + arr[left + 4 * size], + arr[left + 5 * size], + arr[left + 6 * size], + arr[left + 7 * size], + arr[left + 8 * size], + arr[left + 9 * size], + arr[left + 10 * size], + arr[left + 11 * size], + arr[left + 12 * size], + arr[left + 13 * size], + arr[left + 14 * size], + arr[left + 15 * size], + arr[left + 16 * size], + arr[left + 17 * size], + arr[left + 18 * size], + arr[left + 19 * size], + arr[left + 20 * size], + arr[left + 21 * size], + arr[left + 22 * size], + arr[left + 23 * size], + arr[left + 24 * size], + arr[left + 25 * size], + arr[left + 26 * size], + arr[left + 27 * size], + arr[left + 28 * size], + arr[left + 29 * size], + arr[left + 30 * size], + arr[left + 31 * size]}; + typename vtype::reg_t rand_vec = vtype::loadu(vec_arr); + typename vtype::reg_t sort = vtype::sort_vec(rand_vec); + return ((type_t *)&sort)[16]; +} + +template +X86_SIMD_SORT_INLINE type_t get_pivot_32bit(type_t *arr, + const arrsize_t left, + const arrsize_t right) +{ + // median of 16 + arrsize_t size = (right - left) / 16; + using reg_t = typename vtype::reg_t; + type_t vec_arr[16] = {arr[left + size], + arr[left + 2 * size], + arr[left + 3 * size], + arr[left + 4 * size], + arr[left + 5 * size], + arr[left + 6 * size], + arr[left + 7 * size], + arr[left + 8 * size], + arr[left + 9 * size], + arr[left + 10 * size], + arr[left + 11 * size], + arr[left + 12 * size], + arr[left + 13 * size], + arr[left + 14 * size], + arr[left + 15 * size], + arr[left + 16 * size]}; + reg_t rand_vec = vtype::loadu(vec_arr); + reg_t sort = vtype::sort_vec(rand_vec); + // pivot will never be a nan, since there are no nan's! 
+ return ((type_t *)&sort)[8]; +} + +template +X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr, + const arrsize_t left, + const arrsize_t right) +{ + // median of 8 + arrsize_t size = (right - left) / 8; + using reg_t = typename vtype::reg_t; + reg_t rand_vec = vtype::set(arr[left + size], + arr[left + 2 * size], + arr[left + 3 * size], + arr[left + 4 * size], + arr[left + 5 * size], + arr[left + 6 * size], + arr[left + 7 * size], + arr[left + 8 * size]); + // pivot will never be a nan, since there are no nan's! + reg_t sort = vtype::sort_vec(rand_vec); + return ((type_t *)&sort)[4]; +} + +template +X86_SIMD_SORT_INLINE type_t get_pivot(type_t *arr, + const arrsize_t left, + const arrsize_t right) +{ + if constexpr (vtype::numlanes == 8) + return get_pivot_64bit(arr, left, right); + else if constexpr (vtype::numlanes == 16) + return get_pivot_32bit(arr, left, right); + else if constexpr (vtype::numlanes == 32) + return get_pivot_16bit(arr, left, right); + else + return arr[right]; +} + +template +X86_SIMD_SORT_INLINE type_t get_pivot_blocks(type_t *arr, + arrsize_t left, + arrsize_t right) +{ + + if (right - left <= 1024) { return get_pivot(arr, left, right); } + + using reg_t = typename vtype::reg_t; + constexpr int numVecs = 5; + + arrsize_t width = (right - vtype::numlanes) - left; + arrsize_t delta = width / numVecs; + + reg_t vecs[numVecs]; + // Load data + for (int i = 0; i < numVecs; i++) { + vecs[i] = vtype::loadu(arr + left + delta * i); + } + + // Implement sorting network (from https://bertdobbelaere.github.io/sorting_networks.html) + COEX(vecs[0], vecs[3]); + COEX(vecs[1], vecs[4]); + + COEX(vecs[0], vecs[2]); + COEX(vecs[1], vecs[3]); + + COEX(vecs[0], vecs[1]); + COEX(vecs[2], vecs[4]); + + COEX(vecs[1], vecs[2]); + COEX(vecs[3], vecs[4]); + + COEX(vecs[2], vecs[3]); + + // Calculate median of the middle vector + reg_t &vec = vecs[numVecs / 2]; + vec = vtype::sort_vec(vec); + + type_t data[vtype::numlanes]; + vtype::storeu(data, vec); + return data[vtype::numlanes / 2]; +} diff --git a/tests/test-qsort-common.h b/tests/test-qsort-common.h index 6b8241b3..9638387f 100644 --- a/tests/test-qsort-common.h +++ b/tests/test-qsort-common.h @@ -1,6 +1,8 @@ #ifndef AVX512_TEST_COMMON #define AVX512_TEST_COMMON +#define XSS_DO_NOT_SET_SEED + #include "custom-compare.h" #include "rand_array.h" #include "x86simdsort.h" @@ -21,7 +23,7 @@ template void IS_SORTED(std::vector sorted, std::vector arr, std::string type) { - if (memcmp(arr.data(), sorted.data(), arr.size() * sizeof(T) != 0)) { + if (memcmp(arr.data(), sorted.data(), arr.size() * sizeof(T)) != 0) { REPORT_FAIL("Array not sorted", arr.size(), type, -1); } } diff --git a/utils/rand_array.h b/utils/rand_array.h index 562c67bf..22607a88 100644 --- a/utils/rand_array.h +++ b/utils/rand_array.h @@ -22,7 +22,9 @@ static std::vector get_uniform_rand_array( std::random_device rd; if constexpr(std::is_floating_point_v) { std::mt19937 gen(rd()); +#ifndef XSS_DO_NOT_SET_SEED gen.seed(42); +#endif std::uniform_real_distribution dis(min, max); for (int64_t ii = 0; ii < arrsize; ++ii) { arr.emplace_back(dis(gen)); @@ -39,7 +41,9 @@ static std::vector get_uniform_rand_array( #endif else if constexpr(std::is_integral_v) { std::default_random_engine e1(rd()); +#ifndef XSS_DO_NOT_SET_SEED e1.seed(42); +#endif std::uniform_int_distribution uniform_dist(min, max); for (int64_t ii = 0; ii < arrsize; ++ii) { arr.emplace_back(uniform_dist(e1));
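
The avx512_*_swizzle_ops structs introduced in this patch expose three lane-permutation primitives (swap_n, reverse_n, merge_n) whose semantics are described in the comment block added to src/xss-network-qsort.hpp. A scalar model can make those semantics easier to check against the blend masks and permute indices used by the intrinsics. The sketch below is an illustration only, under the assumption of a plain std::array standing in for a ZMM register; the template names mirror, but are not, the library's vtype::swizzle_ops API. Lane 0 here is the lowest lane, i.e. the rightmost element of the [7,6,5,4,3,2,1,0] notation used in that comment.

#include <array>
#include <cstddef>
#include <iostream>
#include <utility>

// Scalar stand-in for one SIMD register; N plays the role of vtype::numlanes.

template <std::size_t scale, std::size_t N>
std::array<int, N> swap_n(std::array<int, N> r)
{
    // Swap neighbouring sub-blocks of size scale/2 within each block of size scale.
    for (std::size_t b = 0; b < N; b += scale)
        for (std::size_t i = 0; i < scale / 2; ++i)
            std::swap(r[b + i], r[b + i + scale / 2]);
    return r;
}

template <std::size_t scale, std::size_t N>
std::array<int, N> reverse_n(std::array<int, N> r)
{
    // Reverse the elements within each block of size scale.
    for (std::size_t b = 0; b < N; b += scale)
        for (std::size_t i = 0; i < scale / 2; ++i)
            std::swap(r[b + i], r[b + scale - 1 - i]);
    return r;
}

template <std::size_t scale, std::size_t N>
std::array<int, N> merge_n(std::array<int, N> a, std::array<int, N> b)
{
    // Within each block of size scale, the low scale/2 lanes come from b and
    // the high scale/2 lanes from a, matching the mask_blend calls above.
    for (std::size_t i = 0; i < N; ++i)
        if (i % scale < scale / 2) a[i] = b[i];
    return a;
}

int main()
{
    std::array<int, 8> reg {0, 1, 2, 3, 4, 5, 6, 7};
    auto sw = swap_n<2>(reg);      // lanes: 1,0,3,2,5,4,7,6
    auto rv = reverse_n<4>(reg);   // lanes: 3,2,1,0,7,6,5,4
    auto mg = merge_n<4>(reg, sw); // lanes: 1,0,2,3,5,4,6,7
    for (int x : mg)
        std::cout << x << ' ';
    std::cout << '\n'; // prints: 1 0 2 3 5 4 6 7
    (void)rv;
}

The expected lane values in the comments can be compared directly against the worked examples in the swizzle-ops comment block; for instance, swap_n<2> on [7,6,5,4,3,2,1,0] yields [[6,7],[4,5],[2,3],[0,1]], which is the 1,0,3,2,5,4,7,6 ordering shown above when read from the low lane upward.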