Skip to content

Commit

Permalink
Merge pull request #13 from LaihoE/unroll_simd_loops
Browse files Browse the repository at this point in the history
  • Loading branch information
LaihoE authored Jul 28, 2024
2 parents 722b30e + 20ccf1e commit 74d18f5
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 10 deletions.
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#![feature(is_sorted)]
#![feature(sort_floats)]

pub const UNROLL_FACTOR: usize = 4;
pub const SIMD_LEN: usize = 32;

mod all_equal;
Expand Down
43 changes: 33 additions & 10 deletions src/position.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::SIMD_LEN;
use crate::UNROLL_FACTOR;
use std::simd::cmp::SimdPartialEq;
use std::simd::Mask;
use std::simd::{Simd, SimdElement};
Expand Down Expand Up @@ -26,16 +27,38 @@ where
}
// SIMD
let simd_needle = Simd::splat(needle);
for (chunk_idx, chunk) in simd_data.iter().enumerate() {
let mut unrolled_loops = 0;
// Unrolled loops
let mut chunks_iter = simd_data.chunks_exact(UNROLL_FACTOR);
for chunks in chunks_iter.by_ref() {
let mut mask = Mask::default();
for chunk in chunks {
mask |= chunk.simd_eq(simd_needle);
}
if mask.any() {
for (mask_idx, c) in chunks.iter().enumerate() {
let mask = c.simd_eq(simd_needle);
if mask.any() {
return Some(
prefix.len()
+ (unrolled_loops * (SIMD_LEN * UNROLL_FACTOR)) // Full outer loops
+ mask_idx * SIMD_LEN // nth inner loop
+ mask.to_bitmask().trailing_zeros() as usize, // nth element in matching mask
);
}
}
}
unrolled_loops += 1;
}
// Remaining simd loops that where not divisible by UNROLL_FACTOR
for (idx, chunk) in chunks_iter.remainder().iter().enumerate() {
let mask = chunk.simd_eq(simd_needle).to_bitmask();
if mask != 0 {
// Example:
// needle = 10
// prefix = [1,2,3]
// SIMD = [[4,5,6,7], [8,9,10,11]]
// 3 + (1 * 4) + (trailing_zeros(0b0010) == 2) = 9
return Some(
prefix.len() + (chunk_idx * SIMD_LEN) + (mask.trailing_zeros() as usize),
prefix.len()
+ (unrolled_loops * UNROLL_FACTOR * SIMD_LEN)
+ (idx * SIMD_LEN)
+ (mask.trailing_zeros() as usize),
);
}
}
Expand Down Expand Up @@ -67,14 +90,13 @@ mod tests {
Simd<T, SIMD_LEN>: SimdPartialEq<Mask = Mask<T::Mask, SIMD_LEN>>,
Standard: Distribution<T>,
{
for len in 0..100 {
for _ in 0..5 {
for len in 0..500 {
for _ in 0..200 {
let mut v: Vec<T> = vec![T::default(); len];
let mut rng = rand::thread_rng();
for x in v.iter_mut() {
*x = rng.gen()
}

let needle = match rng.gen_bool(0.5) {
true => v.choose(&mut rng).cloned().unwrap_or(T::default()),
false => loop {
Expand All @@ -84,6 +106,7 @@ mod tests {
}
},
};

let ans = v.iter().position_simd(needle);
let correct = v.iter().position(|x| *x == needle);

Expand Down

0 comments on commit 74d18f5

Please sign in to comment.