Skip to content

Commit

Permalink
feat: provide a f32x16 abstraction to make unrolling 256-bit code eas…
Browse files Browse the repository at this point in the history
…ier (#1495)
  • Loading branch information
eddyxu authored Nov 1, 2023
1 parent e9a7d83 commit cde1208
Show file tree
Hide file tree
Showing 6 changed files with 437 additions and 93 deletions.
4 changes: 4 additions & 0 deletions rust/lance-linalg/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ harness = false
name = "cosine"
harness = false

[[bench]]
name = "norm_l2"
harness = false

[[bench]]
name = "kmeans"
harness = false
Expand Down
85 changes: 85 additions & 0 deletions rust/lance-linalg/benches/norm_l2.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
// Copyright 2023 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use arrow_arith::{aggregate::sum, numeric::mul};
use arrow_array::{cast::AsArray, types::Float32Type, Float32Array};
use criterion::{black_box, criterion_group, criterion_main, Criterion};

#[cfg(target_os = "linux")]
use pprof::criterion::{Output, PProfProfiler};

use lance_linalg::distance::norm_l2::Normalize;
use lance_testing::datagen::generate_random_array_with_seed;

#[inline]
fn norm_l2_arrow(x: &Float32Array) -> f32 {
let m = mul(&x, &x).unwrap();
sum(m.as_primitive::<Float32Type>()).unwrap()
}

#[inline]
fn norm_l2_auto_vectorization(x: &[f32]) -> f32 {
x.iter().map(|v| v * v).sum::<f32>()
}

fn bench_distance(c: &mut Criterion) {
const DIMENSION: usize = 1024;
const TOTAL: usize = 1024 * 1024; // 1M vectors

// 1M of 1024 D vectors. 4GB in memory.
let target = generate_random_array_with_seed(TOTAL * DIMENSION, [42; 32]);

c.bench_function("norm_l2(arrow)", |b| {
b.iter(|| unsafe {
Float32Array::from_trusted_len_iter((0..target.len() / DIMENSION).map(|idx| {
let arr = target.slice(idx * DIMENSION, DIMENSION);
Some(norm_l2_arrow(&arr))
}))
});
});

c.bench_function("norm_l2(auto-vectorization)", |b| {
b.iter(|| unsafe {
Float32Array::from_trusted_len_iter((0..target.len() / DIMENSION).map(|idx| {
let arr = target.slice(idx * DIMENSION, DIMENSION);
Some(norm_l2_auto_vectorization(arr.values()))
}))
});
});

c.bench_function("norm_l2(SIMD)", |b| {
b.iter(|| unsafe {
Float32Array::from_trusted_len_iter((0..target.len() / DIMENSION).map(|idx| {
let arr = &target.values()[idx * DIMENSION..(idx + 1) * DIMENSION];
Some(arr.norm_l2())
}))
});
});
}

#[cfg(target_os = "linux")]
criterion_group!(
name=benches;
config = Criterion::default().significance_level(0.1).sample_size(10)
.with_profiler(PProfProfiler::new(100, Output::Flamegraph(None)));
targets = bench_distance);

// Non-linux version does not support pprof.
#[cfg(not(target_os = "linux"))]
criterion_group!(
name=benches;
config = Criterion::default().significance_level(0.1).sample_size(10);
targets = bench_distance);

criterion_main!(benches);
29 changes: 14 additions & 15 deletions rust/lance-linalg/src/distance/l2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@ use arrow_array::{cast::AsArray, types::Float32Type, Array, FixedSizeListArray,
use half::{bf16, f16};
use num_traits::real::Real;

use crate::simd::{f32::f32x8, SIMD};
use crate::simd::{
f32::{f32x16, f32x8},
SIMD,
};

/// Calculate the L2 distance between two vectors.
///
Expand Down Expand Up @@ -73,33 +76,29 @@ impl L2 for [f32] {
if len % 16 == 0 {
// Likely
let dim = self.len();
let mut sum1 = f32x8::splat(0.0);
let mut sum2 = f32x8::splat(0.0);
let mut sum = f32x16::zeros();

for i in (0..dim).step_by(16) {
unsafe {
let mut x1 = f32x8::load_unaligned(self.as_ptr().add(i));
let mut x2 = f32x8::load_unaligned(self.as_ptr().add(i + 8));
let y1 = f32x8::load_unaligned(other.as_ptr().add(i));
let y2 = f32x8::load_unaligned(other.as_ptr().add(i + 8));
x1 -= y1;
x2 -= y2;
sum1.multiply_add(x1, x1);
sum2.multiply_add(x2, x2);
let mut x = f32x16::load_unaligned(self.as_ptr().add(i));

let y = f32x16::load_unaligned(other.as_ptr().add(i));
x -= y;
sum.multiply_add(x, x);
}
}
(sum1 + sum2).reduce_sum()
sum.reduce_sum()
} else if len % 8 == 0 {
let mut sum1 = f32x8::splat(0.0);
let mut sum = f32x8::zeros();
for i in (0..len).step_by(8) {
unsafe {
let mut x = f32x8::load_unaligned(self.as_ptr().add(i));
let y = f32x8::load_unaligned(other.as_ptr().add(i));
x -= y;
sum1.multiply_add(x, x);
sum.multiply_add(x, x);
}
}
sum1.reduce_sum()
sum.reduce_sum()
} else {
// Fallback to scalar
l2_scalar(self, other)
Expand Down
95 changes: 22 additions & 73 deletions rust/lance-linalg/src/distance/norm_l2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use crate::simd::{
f32::{f32x16, f32x8},
SIMD,
};
use half::{bf16, f16};
use num_traits::Float;

Expand Down Expand Up @@ -48,20 +52,25 @@ impl Normalize<f32> for &[f32] {

#[inline]
fn norm_l2(&self) -> Self::Output {
#[cfg(target_arch = "aarch64")]
{
aarch64::neon::norm_l2(self)
}

#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("fma") {
return x86_64::avx::norm_l2_f32(self);
let dim = self.len();
if dim % 16 == 0 {
let mut sum = f32x16::zeros();
for i in (0..dim).step_by(16) {
let x = unsafe { f32x16::load_unaligned(self.as_ptr().add(i)) };
sum += x * x;
}
sum.reduce_sum().sqrt()
} else if dim % 8 == 0 {
let mut sum = f32x8::zeros();
for i in (0..dim).step_by(8) {
let x = unsafe { f32x8::load_unaligned(self.as_ptr().add(i)) };
sum += x * x;
}
sum.reduce_sum().sqrt()
} else {
// Fallback to scalar
return self.iter().map(|v| v * v).sum::<f32>().sqrt();
}

#[cfg(not(target_arch = "aarch64"))]
self.iter().map(|v| v * v).sum::<f32>().sqrt()
}
}

Expand All @@ -80,67 +89,7 @@ impl Normalize<f64> for &[f64] {
/// Arrow Arrays, i.e., Float32Array
#[inline]
pub fn norm_l2(vector: &[f32]) -> f32 {
#[cfg(target_arch = "aarch64")]
{
aarch64::neon::norm_l2(vector)
}

#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("fma") {
return x86_64::avx::norm_l2_f32(vector);
}
}

#[cfg(not(target_arch = "aarch64"))]
vector.iter().map(|v| v * v).sum::<f32>().sqrt()
}

#[cfg(target_arch = "x86_64")]
mod x86_64 {

pub mod avx {
use crate::distance::x86_64::avx::*;
use std::arch::x86_64::*;

#[inline]
pub fn norm_l2_f32(vector: &[f32]) -> f32 {
let len = vector.len() / 8 * 8;
let mut sum = unsafe {
let mut sums = _mm256_setzero_ps();
vector.chunks_exact(8).for_each(|chunk| {
let x = _mm256_loadu_ps(chunk.as_ptr());
sums = _mm256_fmadd_ps(x, x, sums);
});
add_f32_register(sums)
};
sum += vector[len..].iter().map(|v| v * v).sum::<f32>();
sum.sqrt()
}
}
}

#[cfg(target_arch = "aarch64")]
mod aarch64 {
pub mod neon {
use std::arch::aarch64::*;

#[inline]
pub fn norm_l2(vector: &[f32]) -> f32 {
let len = vector.len() / 4 * 4;
let mut sum = unsafe {
let buf = [0.0_f32; 4];
let mut sum = vld1q_f32(buf.as_ptr());
for i in (0..len).step_by(4) {
let x = vld1q_f32(vector.as_ptr().add(i));
sum = vfmaq_f32(sum, x, x);
}
vaddvq_f32(sum)
};
sum += vector[len..].iter().map(|v| v.powi(2)).sum::<f32>();
sum.sqrt()
}
}
vector.norm_l2()
}

#[cfg(test)]
Expand Down
16 changes: 15 additions & 1 deletion rust/lance-linalg/src/simd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,17 @@ use num_traits::Float;

/// Lance SIMD lib
///
pub trait SIMD<T: Float>:
pub trait SIMD<T: Float, const N: usize>:
std::fmt::Debug + AddAssign + Add + Mul + Sub + SubAssign + Copy + Clone + Sized
{
const LANES: usize = N;

/// Create a new instance with all lanes set to `val`.
fn splat(val: T) -> Self;

/// Create a new instance with all lanes set to zero.
fn zeros() -> Self;

/// Load aligned data from aligned memory.
///
/// # Safety
Expand All @@ -61,6 +66,15 @@ pub trait SIMD<T: Float>:
/// # Safety
unsafe fn store_unaligned(&self, ptr: *mut T);

/// Return the values as an array.
fn as_array(&self) -> [T; N] {
let mut arr = [T::zero(); N];
unsafe {
self.store_unaligned(arr.as_mut_ptr());
}
arr
}

/// fused multiply-add
///
/// c = a * b + c
Expand Down
Loading

0 comments on commit cde1208

Please sign in to comment.