From cde12085106279b6b69346286f7f2e9f06b324cc Mon Sep 17 00:00:00 2001
From: Lei Xu
Date: Wed, 1 Nov 2023 10:24:39 -0700
Subject: [PATCH] feat: provide a f32x16 abstraction to make unrolling 256-bit code easier (#1495)

---
 rust/lance-linalg/Cargo.toml              |   4 +
 rust/lance-linalg/benches/norm_l2.rs      |  85 ++++++
 rust/lance-linalg/src/distance/l2.rs      |  29 +--
 rust/lance-linalg/src/distance/norm_l2.rs |  95 ++-----
 rust/lance-linalg/src/simd.rs             |  16 +-
 rust/lance-linalg/src/simd/f32.rs         | 301 +++++++++++++++++++++-
 6 files changed, 437 insertions(+), 93 deletions(-)
 create mode 100644 rust/lance-linalg/benches/norm_l2.rs

diff --git a/rust/lance-linalg/Cargo.toml b/rust/lance-linalg/Cargo.toml
index 9ffec161d4..47ec71afb1 100644
--- a/rust/lance-linalg/Cargo.toml
+++ b/rust/lance-linalg/Cargo.toml
@@ -48,6 +48,10 @@ harness = false
 name = "cosine"
 harness = false
 
+[[bench]]
+name = "norm_l2"
+harness = false
+
 [[bench]]
 name = "kmeans"
 harness = false
diff --git a/rust/lance-linalg/benches/norm_l2.rs b/rust/lance-linalg/benches/norm_l2.rs
new file mode 100644
index 0000000000..c03eaa801c
--- /dev/null
+++ b/rust/lance-linalg/benches/norm_l2.rs
@@ -0,0 +1,85 @@
+// Copyright 2023 Lance Developers.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use arrow_arith::{aggregate::sum, numeric::mul};
+use arrow_array::{cast::AsArray, types::Float32Type, Float32Array};
+use criterion::{black_box, criterion_group, criterion_main, Criterion};
+
+#[cfg(target_os = "linux")]
+use pprof::criterion::{Output, PProfProfiler};
+
+use lance_linalg::distance::norm_l2::Normalize;
+use lance_testing::datagen::generate_random_array_with_seed;
+
+#[inline]
+fn norm_l2_arrow(x: &Float32Array) -> f32 {
+    let m = mul(&x, &x).unwrap();
+    sum(m.as_primitive::<Float32Type>()).unwrap()
+}
+
+#[inline]
+fn norm_l2_auto_vectorization(x: &[f32]) -> f32 {
+    x.iter().map(|v| v * v).sum::<f32>()
+}
+
+fn bench_distance(c: &mut Criterion) {
+    const DIMENSION: usize = 1024;
+    const TOTAL: usize = 1024 * 1024; // 1M vectors
+
+    // 1M of 1024 D vectors. 4GB in memory.
+ let target = generate_random_array_with_seed(TOTAL * DIMENSION, [42; 32]); + + c.bench_function("norm_l2(arrow)", |b| { + b.iter(|| unsafe { + Float32Array::from_trusted_len_iter((0..target.len() / DIMENSION).map(|idx| { + let arr = target.slice(idx * DIMENSION, DIMENSION); + Some(norm_l2_arrow(&arr)) + })) + }); + }); + + c.bench_function("norm_l2(auto-vectorization)", |b| { + b.iter(|| unsafe { + Float32Array::from_trusted_len_iter((0..target.len() / DIMENSION).map(|idx| { + let arr = target.slice(idx * DIMENSION, DIMENSION); + Some(norm_l2_auto_vectorization(arr.values())) + })) + }); + }); + + c.bench_function("norm_l2(SIMD)", |b| { + b.iter(|| unsafe { + Float32Array::from_trusted_len_iter((0..target.len() / DIMENSION).map(|idx| { + let arr = &target.values()[idx * DIMENSION..(idx + 1) * DIMENSION]; + Some(arr.norm_l2()) + })) + }); + }); +} + +#[cfg(target_os = "linux")] +criterion_group!( + name=benches; + config = Criterion::default().significance_level(0.1).sample_size(10) + .with_profiler(PProfProfiler::new(100, Output::Flamegraph(None))); + targets = bench_distance); + +// Non-linux version does not support pprof. +#[cfg(not(target_os = "linux"))] +criterion_group!( + name=benches; + config = Criterion::default().significance_level(0.1).sample_size(10); + targets = bench_distance); + +criterion_main!(benches); diff --git a/rust/lance-linalg/src/distance/l2.rs b/rust/lance-linalg/src/distance/l2.rs index dec85fd3a9..7b2efbf5d6 100644 --- a/rust/lance-linalg/src/distance/l2.rs +++ b/rust/lance-linalg/src/distance/l2.rs @@ -22,7 +22,10 @@ use arrow_array::{cast::AsArray, types::Float32Type, Array, FixedSizeListArray, use half::{bf16, f16}; use num_traits::real::Real; -use crate::simd::{f32::f32x8, SIMD}; +use crate::simd::{ + f32::{f32x16, f32x8}, + SIMD, +}; /// Calculate the L2 distance between two vectors. /// @@ -73,33 +76,29 @@ impl L2 for [f32] { if len % 16 == 0 { // Likely let dim = self.len(); - let mut sum1 = f32x8::splat(0.0); - let mut sum2 = f32x8::splat(0.0); + let mut sum = f32x16::zeros(); for i in (0..dim).step_by(16) { unsafe { - let mut x1 = f32x8::load_unaligned(self.as_ptr().add(i)); - let mut x2 = f32x8::load_unaligned(self.as_ptr().add(i + 8)); - let y1 = f32x8::load_unaligned(other.as_ptr().add(i)); - let y2 = f32x8::load_unaligned(other.as_ptr().add(i + 8)); - x1 -= y1; - x2 -= y2; - sum1.multiply_add(x1, x1); - sum2.multiply_add(x2, x2); + let mut x = f32x16::load_unaligned(self.as_ptr().add(i)); + + let y = f32x16::load_unaligned(other.as_ptr().add(i)); + x -= y; + sum.multiply_add(x, x); } } - (sum1 + sum2).reduce_sum() + sum.reduce_sum() } else if len % 8 == 0 { - let mut sum1 = f32x8::splat(0.0); + let mut sum = f32x8::zeros(); for i in (0..len).step_by(8) { unsafe { let mut x = f32x8::load_unaligned(self.as_ptr().add(i)); let y = f32x8::load_unaligned(other.as_ptr().add(i)); x -= y; - sum1.multiply_add(x, x); + sum.multiply_add(x, x); } } - sum1.reduce_sum() + sum.reduce_sum() } else { // Fallback to scalar l2_scalar(self, other) diff --git a/rust/lance-linalg/src/distance/norm_l2.rs b/rust/lance-linalg/src/distance/norm_l2.rs index 0c31b4e23b..77e4b5e2a1 100644 --- a/rust/lance-linalg/src/distance/norm_l2.rs +++ b/rust/lance-linalg/src/distance/norm_l2.rs @@ -12,6 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+use crate::simd::{
+    f32::{f32x16, f32x8},
+    SIMD,
+};
 use half::{bf16, f16};
 use num_traits::Float;
 
@@ -48,20 +52,25 @@ impl Normalize for &[f32] {
 
     #[inline]
     fn norm_l2(&self) -> Self::Output {
-        #[cfg(target_arch = "aarch64")]
-        {
-            aarch64::neon::norm_l2(self)
-        }
-
-        #[cfg(target_arch = "x86_64")]
-        {
-            if is_x86_feature_detected!("fma") {
-                return x86_64::avx::norm_l2_f32(self);
+        let dim = self.len();
+        if dim % 16 == 0 {
+            let mut sum = f32x16::zeros();
+            for i in (0..dim).step_by(16) {
+                let x = unsafe { f32x16::load_unaligned(self.as_ptr().add(i)) };
+                sum += x * x;
+            }
+            sum.reduce_sum().sqrt()
+        } else if dim % 8 == 0 {
+            let mut sum = f32x8::zeros();
+            for i in (0..dim).step_by(8) {
+                let x = unsafe { f32x8::load_unaligned(self.as_ptr().add(i)) };
+                sum += x * x;
             }
+            sum.reduce_sum().sqrt()
+        } else {
+            // Fallback to scalar
+            return self.iter().map(|v| v * v).sum::<f32>().sqrt();
         }
-
-        #[cfg(not(target_arch = "aarch64"))]
-        self.iter().map(|v| v * v).sum::<f32>().sqrt()
     }
 }
 
@@ -80,67 +89,7 @@ impl Normalize for &[f64] {
 /// Arrow Arrays, i.e., Float32Array
 #[inline]
 pub fn norm_l2(vector: &[f32]) -> f32 {
-    #[cfg(target_arch = "aarch64")]
-    {
-        aarch64::neon::norm_l2(vector)
-    }
-
-    #[cfg(target_arch = "x86_64")]
-    {
-        if is_x86_feature_detected!("fma") {
-            return x86_64::avx::norm_l2_f32(vector);
-        }
-    }
-
-    #[cfg(not(target_arch = "aarch64"))]
-    vector.iter().map(|v| v * v).sum::<f32>().sqrt()
-}
-
-#[cfg(target_arch = "x86_64")]
-mod x86_64 {
-
-    pub mod avx {
-        use crate::distance::x86_64::avx::*;
-        use std::arch::x86_64::*;
-
-        #[inline]
-        pub fn norm_l2_f32(vector: &[f32]) -> f32 {
-            let len = vector.len() / 8 * 8;
-            let mut sum = unsafe {
-                let mut sums = _mm256_setzero_ps();
-                vector.chunks_exact(8).for_each(|chunk| {
-                    let x = _mm256_loadu_ps(chunk.as_ptr());
-                    sums = _mm256_fmadd_ps(x, x, sums);
-                });
-                add_f32_register(sums)
-            };
-            sum += vector[len..].iter().map(|v| v * v).sum::<f32>();
-            sum.sqrt()
-        }
-    }
-}
-
-#[cfg(target_arch = "aarch64")]
-mod aarch64 {
-    pub mod neon {
-        use std::arch::aarch64::*;
-
-        #[inline]
-        pub fn norm_l2(vector: &[f32]) -> f32 {
-            let len = vector.len() / 4 * 4;
-            let mut sum = unsafe {
-                let buf = [0.0_f32; 4];
-                let mut sum = vld1q_f32(buf.as_ptr());
-                for i in (0..len).step_by(4) {
-                    let x = vld1q_f32(vector.as_ptr().add(i));
-                    sum = vfmaq_f32(sum, x, x);
-                }
-                vaddvq_f32(sum)
-            };
-            sum += vector[len..].iter().map(|v| v.powi(2)).sum::<f32>();
-            sum.sqrt()
-        }
-    }
+    vector.norm_l2()
 }
 
 #[cfg(test)]
diff --git a/rust/lance-linalg/src/simd.rs b/rust/lance-linalg/src/simd.rs
index d52188af78..56db1f2e58 100644
--- a/rust/lance-linalg/src/simd.rs
+++ b/rust/lance-linalg/src/simd.rs
@@ -31,12 +31,17 @@ use num_traits::Float;
 
 /// Lance SIMD lib
 ///
-pub trait SIMD<T: Float>:
+pub trait SIMD<T: Float, const N: usize>:
     std::fmt::Debug + AddAssign + Add + Mul + Sub + SubAssign + Copy + Clone + Sized
 {
+    const LANES: usize = N;
+
     /// Create a new instance with all lanes set to `val`.
     fn splat(val: T) -> Self;
 
+    /// Create a new instance with all lanes set to zero.
+    fn zeros() -> Self;
+
     /// Load aligned data from aligned memory.
     ///
    /// # Safety
@@ -61,6 +66,15 @@ pub trait SIMD<T: Float>:
     /// # Safety
     unsafe fn store_unaligned(&self, ptr: *mut T);
 
+    /// Return the values as an array.
+    fn as_array(&self) -> [T; N] {
+        let mut arr = [T::zero(); N];
+        unsafe {
+            self.store_unaligned(arr.as_mut_ptr());
+        }
+        arr
+    }
+
     /// fused multiply-add
     ///
     /// c = a * b + c
diff --git a/rust/lance-linalg/src/simd/f32.rs b/rust/lance-linalg/src/simd/f32.rs
index 26ecd75f83..b4ea8b78a5 100644
--- a/rust/lance-linalg/src/simd/f32.rs
+++ b/rust/lance-linalg/src/simd/f32.rs
@@ -12,14 +12,14 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-//! `f32x8`, 8 of f32 values.s
+//! `f32x8`, 8 of `f32` values.
 
 use std::fmt::Formatter;
 
 #[cfg(target_arch = "aarch64")]
 use std::arch::aarch64::{
-    float32x4x2_t, vaddq_f32, vaddvq_f32, vdupq_n_f32, vfmaq_f32, vld1q_f32_x2, vmulq_f32,
-    vst1q_f32_x2, vsubq_f32,
+    float32x4x2_t, float32x4x4_t, vaddq_f32, vaddvq_f32, vdupq_n_f32, vfmaq_f32, vld1q_f32_x2,
+    vld1q_f32_x4, vmulq_f32, vst1q_f32_x2, vst1q_f32_x4, vsubq_f32,
 };
 #[cfg(target_arch = "x86_64")]
 use std::arch::x86_64::*;
@@ -49,7 +49,7 @@ impl std::fmt::Debug for f32x8 {
     }
 }
 
-impl SIMD<f32> for f32x8 {
+impl SIMD<f32, 8> for f32x8 {
     fn splat(val: f32) -> Self {
         #[cfg(target_arch = "x86_64")]
         unsafe {
@@ -61,6 +61,15 @@ impl SIMD<f32> for f32x8 {
         }
     }
 
+    fn zeros() -> Self {
+        #[cfg(target_arch = "x86_64")]
+        unsafe {
+            Self(_mm256_setzero_ps())
+        }
+        #[cfg(target_arch = "aarch64")]
+        Self::splat(0.0)
+    }
+
     #[inline]
     unsafe fn load(ptr: *const f32) -> Self {
         #[cfg(target_arch = "x86_64")]
@@ -226,6 +235,246 @@ impl Mul for f32x8 {
     }
 }
 
+/// 16 of 32-bit `f32` values. Use 512-bit SIMD if possible.
+#[allow(non_camel_case_types)]
+#[cfg(target_arch = "x86_64")]
+#[derive(Clone, Copy)]
+pub struct f32x16(__m256, __m256);
+
+/// 16 of 32-bit `f32` values. Use 512-bit SIMD if possible.
+#[allow(non_camel_case_types)]
+#[cfg(target_arch = "aarch64")]
+#[derive(Clone, Copy)]
+pub struct f32x16(float32x4x4_t);
+
+impl std::fmt::Debug for f32x16 {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        let mut arr = [0.0_f32; 16];
+        unsafe {
+            self.store_unaligned(arr.as_mut_ptr());
+        }
+        write!(f, "f32x16({:?})", arr)
+    }
+}
+
+impl SIMD<f32, 16> for f32x16 {
+    #[inline]
+    fn splat(val: f32) -> Self {
+        #[cfg(target_arch = "x86_64")]
+        unsafe {
+            Self(_mm256_set1_ps(val), _mm256_set1_ps(val))
+        }
+        #[cfg(target_arch = "aarch64")]
+        unsafe {
+            Self(float32x4x4_t(
+                vdupq_n_f32(val),
+                vdupq_n_f32(val),
+                vdupq_n_f32(val),
+                vdupq_n_f32(val),
+            ))
+        }
+    }
+
+    #[inline]
+    fn zeros() -> Self {
+        #[cfg(target_arch = "x86_64")]
+        unsafe {
+            Self(_mm256_setzero_ps(), _mm256_setzero_ps())
+        }
+        #[cfg(target_arch = "aarch64")]
+        Self::splat(0.0)
+    }
+
+    #[inline]
+    unsafe fn load(ptr: *const f32) -> Self {
+        #[cfg(target_arch = "x86_64")]
+        unsafe {
+            Self(_mm256_load_ps(ptr), _mm256_load_ps(ptr.add(8)))
+        }
+        #[cfg(target_arch = "aarch64")]
+        Self::load_unaligned(ptr)
+    }
+
+    #[inline]
+    unsafe fn load_unaligned(ptr: *const f32) -> Self {
+        #[cfg(target_arch = "x86_64")]
+        unsafe {
+            Self(_mm256_loadu_ps(ptr), _mm256_loadu_ps(ptr.add(8)))
+        }
+        #[cfg(target_arch = "aarch64")]
+        Self(vld1q_f32_x4(ptr))
+    }
+
+    #[inline]
+    unsafe fn store(&self, ptr: *mut f32) {
+        #[cfg(target_arch = "x86_64")]
+        unsafe {
+            _mm256_store_ps(ptr, self.0);
+            _mm256_store_ps(ptr.add(8), self.1);
+        }
+        #[cfg(target_arch = "aarch64")]
+        unsafe {
+            vst1q_f32_x4(ptr, self.0);
+        }
+    }
+
+    #[inline]
+    unsafe fn store_unaligned(&self, ptr: *mut f32) {
+        #[cfg(target_arch = "x86_64")]
+        unsafe {
+            _mm256_storeu_ps(ptr, self.0);
+            _mm256_storeu_ps(ptr.add(8), self.1);
+        }
+        #[cfg(target_arch = 
"aarch64")] + unsafe { + vst1q_f32_x4(ptr, self.0); + } + } + + #[inline] + fn multiply_add(&mut self, a: Self, b: Self) { + #[cfg(target_arch = "x86_64")] + unsafe { + self.0 = _mm256_fmadd_ps(a.0, b.0, self.0); + self.1 = _mm256_fmadd_ps(a.1, b.1, self.1); + } + #[cfg(target_arch = "aarch64")] + unsafe { + self.0 .0 = vfmaq_f32(self.0 .0, a.0 .0, b.0 .0); + self.0 .1 = vfmaq_f32(self.0 .1, a.0 .1, b.0 .1); + self.0 .2 = vfmaq_f32(self.0 .2, a.0 .2, b.0 .2); + self.0 .3 = vfmaq_f32(self.0 .3, a.0 .3, b.0 .3); + } + } + + fn reduce_sum(&self) -> f32 { + #[cfg(target_arch = "x86_64")] + unsafe { + let mut sum = _mm256_add_ps(self.0, self.1); + // Shift and add vector, until only 1 value left. + // sums = [x0-x7], shift = [x4-x7] + let mut shift = _mm256_permute2f128_ps(sum, sum, 1); + // [x0+x4, x1+x5, ..] + sum = _mm256_add_ps(sum, shift); + shift = _mm256_permute_ps(sum, 14); + sum = _mm256_add_ps(sum, shift); + sum = _mm256_hadd_ps(sum, sum); + let mut results: [f32; 8] = [0f32; 8]; + _mm256_storeu_ps(results.as_mut_ptr(), sum); + results[0] + } + #[cfg(target_arch = "aarch64")] + unsafe { + let mut sum1 = vaddq_f32(self.0 .0, self.0 .1); + let sum2 = vaddq_f32(self.0 .2, self.0 .3); + sum1 = vaddq_f32(sum1, sum2); + vaddvq_f32(sum1) + } + } +} + +impl Add for f32x16 { + type Output = Self; + + #[inline] + fn add(self, rhs: Self) -> Self::Output { + #[cfg(target_arch = "x86_64")] + unsafe { + Self(_mm256_add_ps(self.0, rhs.0), _mm256_add_ps(self.1, rhs.1)) + } + #[cfg(target_arch = "aarch64")] + unsafe { + Self(float32x4x4_t( + vaddq_f32(self.0 .0, rhs.0 .0), + vaddq_f32(self.0 .1, rhs.0 .1), + vaddq_f32(self.0 .2, rhs.0 .2), + vaddq_f32(self.0 .3, rhs.0 .3), + )) + } + } +} + +impl AddAssign for f32x16 { + #[inline] + fn add_assign(&mut self, rhs: Self) { + #[cfg(target_arch = "x86_64")] + unsafe { + self.0 = _mm256_add_ps(self.0, rhs.0); + self.1 = _mm256_add_ps(self.1, rhs.1); + } + #[cfg(target_arch = "aarch64")] + unsafe { + self.0 .0 = vaddq_f32(self.0 .0, rhs.0 .0); + self.0 .1 = vaddq_f32(self.0 .1, rhs.0 .1); + self.0 .2 = vaddq_f32(self.0 .2, rhs.0 .2); + self.0 .3 = vaddq_f32(self.0 .3, rhs.0 .3); + } + } +} + +impl Mul for f32x16 { + type Output = Self; + + #[inline] + fn mul(self, rhs: Self) -> Self::Output { + #[cfg(target_arch = "x86_64")] + unsafe { + Self(_mm256_mul_ps(self.0, rhs.0), _mm256_mul_ps(self.1, rhs.1)) + } + #[cfg(target_arch = "aarch64")] + unsafe { + Self(float32x4x4_t( + vmulq_f32(self.0 .0, rhs.0 .0), + vmulq_f32(self.0 .1, rhs.0 .1), + vmulq_f32(self.0 .2, rhs.0 .2), + vmulq_f32(self.0 .3, rhs.0 .3), + )) + } + } +} + +impl Sub for f32x16 { + type Output = Self; + + #[inline] + fn sub(self, rhs: Self) -> Self::Output { + #[cfg(target_arch = "x86_64")] + unsafe { + Self(_mm256_sub_ps(self.0, rhs.0), _mm256_sub_ps(self.1, rhs.1)) + } + #[cfg(target_arch = "aarch64")] + unsafe { + Self(float32x4x4_t( + vsubq_f32(self.0 .0, rhs.0 .0), + vsubq_f32(self.0 .1, rhs.0 .1), + vsubq_f32(self.0 .2, rhs.0 .2), + vsubq_f32(self.0 .3, rhs.0 .3), + )) + } + } +} + +impl SubAssign for f32x16 { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + #[cfg(target_arch = "x86_64")] + unsafe { + self.0 = _mm256_sub_ps(self.0, rhs.0); + self.1 = _mm256_sub_ps(self.1, rhs.1); + } + #[cfg(target_arch = "aarch64")] + unsafe { + self.0 .0 = vsubq_f32(self.0 .0, rhs.0 .0); + self.0 .1 = vsubq_f32(self.0 .1, rhs.0 .1); + self.0 .2 = vsubq_f32(self.0 .2, rhs.0 .2); + self.0 .3 = vsubq_f32(self.0 .3, rhs.0 .3); + } + } +} + #[cfg(test)] mod tests { @@ -238,6 +487,20 @@ mod tests { let 
mut simd_a = unsafe { f32x8::load_unaligned(a.as_ptr()) };
         let simd_b = unsafe { f32x8::load_unaligned(b.as_ptr()) };
+
+        let simd_add = simd_a + simd_b;
+        assert!((0..8)
+            .zip(simd_add.as_array().iter())
+            .all(|(x, &y)| (x + x + 10) as f32 == y));
+
+        let simd_mul = simd_a * simd_b;
+        assert!((0..8)
+            .zip(simd_mul.as_array().iter())
+            .all(|(x, &y)| (x * (x + 10)) as f32 == y));
+
+        let simd_sub = simd_b - simd_a;
+        assert!(simd_sub.as_array().iter().all(|&v| v == 10.0));
+
         simd_a -= simd_b;
         assert_eq!(simd_a.reduce_sum(), -80.0);
 
@@ -249,4 +512,34 @@
             format!("{:?}", simd_power)
         );
     }
+
+    #[test]
+    fn test_basic_f32x16_ops() {
+        let a = (0..16).map(|f| f as f32).collect::<Vec<f32>>();
+        let b = (10..26).map(|f| f as f32).collect::<Vec<f32>>();
+
+        let mut simd_a = unsafe { f32x16::load_unaligned(a.as_ptr()) };
+        let simd_b = unsafe { f32x16::load_unaligned(b.as_ptr()) };
+
+        let simd_add = simd_a + simd_b;
+        assert!((0..16)
+            .zip(simd_add.as_array().iter())
+            .all(|(x, &y)| (x + x + 10) as f32 == y));
+
+        let simd_mul = simd_a * simd_b;
+        assert!((0..16)
+            .zip(simd_mul.as_array().iter())
+            .all(|(x, &y)| (x * (x + 10)) as f32 == y));
+
+        simd_a -= simd_b;
+        assert_eq!(simd_a.reduce_sum(), -160.0);
+
+        let mut simd_power = f32x16::zeros();
+        simd_power.multiply_add(simd_a, simd_a);
+
+        assert_eq!(
+            format!("f32x16({:?})", [100.0; 16]),
+            format!("{:?}", simd_power)
+        );
+    }
 }
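
Example usage (a minimal sketch, not part of the patch itself): the helper below shows how the new 16-lane type is meant to replace the manual two-`f32x8` unrolling at call sites, using only APIs introduced above (`zeros`, `load_unaligned`, `multiply_add`, `reduce_sum`). It assumes the `simd` module is publicly reachable as `lance_linalg::simd`; inside the crate the path is `crate::simd`, as in the hunks above, and the function name is illustrative only.

use lance_linalg::simd::{f32::f32x16, SIMD};

/// Squared L2 norm of a slice whose length is a multiple of 16.
fn norm_l2_squared_16(v: &[f32]) -> f32 {
    debug_assert_eq!(v.len() % 16, 0);
    let mut sum = f32x16::zeros();
    for i in (0..v.len()).step_by(16) {
        // One unaligned 16-lane load instead of two 8-lane loads.
        let x = unsafe { f32x16::load_unaligned(v.as_ptr().add(i)) };
        // sum = x * x + sum, fused on targets with FMA support.
        sum.multiply_add(x, x);
    }
    // Horizontal reduction across all 16 lanes.
    sum.reduce_sum()
}

Keeping a single 16-lane accumulator is the point of the abstraction: on x86_64 it still lowers to a pair of 256-bit registers and on aarch64 to a `float32x4x4_t`, as the implementation above shows, but call sites such as `l2.rs` no longer have to unroll and re-combine two `f32x8` values by hand.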