Skip to content

Commit

Permalink
chore: use residual in cosine inference (#1984)
Browse files Browse the repository at this point in the history
BREAKING CHANGE: use residual to calculate cosine in PQ
  • Loading branch information
eddyxu authored Feb 22, 2024
1 parent bc999ea commit dc01633
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 40 deletions.
4 changes: 4 additions & 0 deletions rust/lance-index/src/vector/ivf/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ pub struct IvfBuildParams {
pub shuffle_partition_batches: usize,

pub shuffle_partition_concurrency: usize,

/// Train PQ on residual vectors instead of on the original vectors.
pub use_residual: bool,
}

impl Default for IvfBuildParams {
Expand All @@ -68,6 +71,7 @@ impl Default for IvfBuildParams {
precomputed_shuffle_buffers: None,
shuffle_partition_batches: 1024 * 10,
shuffle_partition_concurrency: 2,
use_residual: true,
}
}
}
Expand Down
23 changes: 1 addition & 22 deletions rust/lance-index/src/vector/pq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ impl<T: ArrowFloatType + Cosine + Dot + L2 + 'static> ProductQuantizer for Produ
}

// Reports whether this quantizer works on residual vectors; the answer is
// derived solely from `self.metric_type`.
//
// NOTE(review): this block is a diff hunk, so both sides of the change appear
// in the body. The page summary lists a single added line for pq.rs; the first
// expression (true for every metric except cosine) is presumably the removed
// line and the `matches!` form (true only for L2 and cosine) its replacement.
// Confirm against the committed file before treating this as compilable code.
fn use_residual(&self) -> bool {
self.metric_type != MetricType::Cosine
matches!(self.metric_type, MetricType::L2 | MetricType::Cosine)
}
}

Expand Down Expand Up @@ -535,27 +535,6 @@ mod tests {
assert_eq!(tensor.shape, vec![256, 16]);
}

/// Checks what `use_residual()` reports per metric type: a cosine quantizer
/// must answer `false`, an L2 quantizer must answer `true`.
#[test]
fn test_cosine_pq_does_not_use_residual() {
    // Both quantizers share every field except the metric type, so build them
    // through one local constructor.
    let make_pq = |metric_type: MetricType| ProductQuantizerImpl::<Float32Type> {
        num_bits: 8,
        num_sub_vectors: 4,
        dimension: 16,
        codebook: Arc::new(Float32Array::from_iter_values(repeat(0.0).take(128))),
        metric_type,
    };

    let cosine_pq = make_pq(MetricType::Cosine);
    assert!(!cosine_pq.use_residual());

    let l2_pq = make_pq(MetricType::L2);
    assert!(l2_pq.use_residual());
}

#[tokio::test]
async fn test_empty_dist_iter() {
let pq = ProductQuantizerImpl::<Float32Type> {
Expand Down
112 changes: 94 additions & 18 deletions rust/lance/src/index/vector/ivf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -800,20 +800,6 @@ pub async fn build_ivf_pq_index(
}
Ivf::new(centroids.clone())
} else {
// Pre-transforms
if pq_params.use_opq {
#[cfg(not(feature = "opq"))]
return Err(Error::Index {
message: "Feature 'opq' is not installed.".to_string(),
location: location!(),
});
#[cfg(feature = "opq")]
{
let opq = train_opq(&training_data, pq_params).await?;
transforms.push(Box::new(opq));
}
}

// Transform training data if necessary.
for transform in transforms.iter() {
if let Some(training_data) = &mut training_data {
Expand Down Expand Up @@ -906,13 +892,12 @@ pub async fn build_ivf_pq_index(
// the time to compute them is not that bad.
let part_ids = ivf2.compute_partitions(&training_data).await?;

let training_data = if metric_type == MetricType::Cosine {
// Do not run residual distance for cosine distance.
training_data
} else {
let training_data = if ivf_params.use_residual {
span!(Level::INFO, "compute residual for PQ training")
.in_scope(|| ivf2.compute_residual(&training_data, Some(&part_ids)))
.await?
} else {
training_data
};
info!("Start train PQ: params={:#?}", pq_params);
pq_params.build(&training_data, metric_type).await?
Expand Down Expand Up @@ -1205,6 +1190,7 @@ mod tests {
use std::collections::HashMap;
use std::iter::repeat;

use arrow_array::types::UInt64Type;
use arrow_array::{cast::AsArray, RecordBatchIterator, RecordBatchReader, UInt64Array};
use arrow_schema::{DataType, Field, Schema};
use lance_core::utils::address::RowAddress;
Expand Down Expand Up @@ -1893,4 +1879,94 @@ mod tests {
]))
);
}

/// End-to-end check of a cosine IVF_PQ index built over vectors whose every
/// component has been shifted by `+1000.0`.
///
/// The test asserts that:
/// * every stored IVF centroid value lies in `[0.0, 1.0]`,
/// * every PQ codebook value lies in `[0.0, 1.0]`,
/// * for at least 9 of the first 10 rows, a top-1 nearest-neighbour query using
///   that row's own vector returns a `_rowid` equal to the row's index.
#[tokio::test]
async fn test_check_cosine_normalization() {
    const DIM: usize = 32;

    let test_dir = tempdir().unwrap();
    let test_uri = test_dir.path().to_str().unwrap();

    // Single nullable fixed-size-list column named "vector" holding DIM f32 items.
    let item_field = Arc::new(Field::new("item", DataType::Float32, true));
    let vector_field = Field::new(
        "vector",
        DataType::FixedSizeList(item_field, DIM as i32),
        true,
    );
    let schema = Arc::new(Schema::new(vec![vector_field]));

    // 1000 rows of generated values, each component offset by +1000.0.
    let arr = generate_random_array(1000 * DIM)
        .values()
        .iter()
        .map(|&value| value + 1000.0)
        .collect::<Float32Array>();
    let vectors = FixedSizeListArray::try_new_from_values(arr.clone(), DIM as i32).unwrap();
    let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(vectors)]).unwrap();
    let reader = RecordBatchIterator::new(vec![batch].into_iter().map(Ok), schema.clone());
    let mut dataset = Dataset::write(reader, test_uri, None).await.unwrap();

    let index_params = VectorIndexParams::ivf_pq(2, 8, 4, false, MetricType::Cosine, 50);
    dataset
        .create_index(&[&"vector"], IndexType::Vector, None, &index_params, false)
        .await
        .unwrap();

    let indices = dataset.load_indices().await.unwrap();
    let index_uuid = indices[0].uuid.to_string();
    let generic_index = dataset
        .open_generic_index("vector", index_uuid.as_str())
        .await
        .unwrap();
    let ivf_idx = generic_index.as_any().downcast_ref::<IVFIndex>().unwrap();

    // Shared predicate for both range checks below.
    let in_unit_interval = |value: &f32| (0.0..=1.0).contains(value);

    // All centroids are expected to be normalized: every component must fall
    // inside [0, 1]. Centroids of the un-normalized data would instead sit
    // around the mean of the original (shifted) vector space.
    let centroids_in_range = ivf_idx
        .ivf
        .centroids
        .values()
        .as_primitive::<Float32Type>()
        .values()
        .iter()
        .all(in_unit_interval);
    assert!(centroids_in_range);

    // The PQ codebook must satisfy the same bound.
    let pq_idx = ivf_idx
        .sub_index
        .as_any()
        .downcast_ref::<PQIndex>()
        .unwrap();
    let codebook_in_range = pq_idx
        .pq
        .codebook_as_fsl()
        .values()
        .as_primitive::<Float32Type>()
        .values()
        .iter()
        .all(in_unit_interval);
    assert!(codebook_in_range);

    // Re-open the dataset and query it with the first ten stored vectors.
    let dataset = Dataset::open(test_uri).await.unwrap();

    let mut correct_times = 0;
    for query_id in 0..10 {
        let query = &arr.slice(query_id * DIM, DIM);
        let results = dataset
            .scan()
            .with_row_id()
            .nearest("vector", query, 1)
            .unwrap()
            .try_into_batch()
            .await
            .unwrap();
        assert_eq!(results.num_rows(), 1);

        let row_id = results
            .column_by_name("_rowid")
            .unwrap()
            .as_primitive::<UInt64Type>()
            .value(0);
        println!("Row id: {}", row_id);

        // A hit means the query vector's own row came back as the top match.
        correct_times += i32::from(row_id == (query_id as u64));
    }

    assert!(correct_times >= 9, "correct: {}", correct_times);
}
}

0 comments on commit dc01633

Please sign in to comment.