diff --git a/rust/lance-index/src/vector/ivf/builder.rs b/rust/lance-index/src/vector/ivf/builder.rs index 72ab2b0e3f..573fcba116 100644 --- a/rust/lance-index/src/vector/ivf/builder.rs +++ b/rust/lance-index/src/vector/ivf/builder.rs @@ -55,6 +55,9 @@ pub struct IvfBuildParams { pub shuffle_partition_batches: usize, pub shuffle_partition_concurrency: usize, + + /// Use residual vectors to build sub-vector. + pub use_residual: bool, } impl Default for IvfBuildParams { @@ -68,6 +71,7 @@ impl Default for IvfBuildParams { precomputed_shuffle_buffers: None, shuffle_partition_batches: 1024 * 10, shuffle_partition_concurrency: 2, + use_residual: true, } } } diff --git a/rust/lance-index/src/vector/pq.rs b/rust/lance-index/src/vector/pq.rs index 72820226c2..c0033e747f 100644 --- a/rust/lance-index/src/vector/pq.rs +++ b/rust/lance-index/src/vector/pq.rs @@ -476,7 +476,7 @@ impl ProductQuantizer for Produ } fn use_residual(&self) -> bool { - self.metric_type != MetricType::Cosine + matches!(self.metric_type, MetricType::L2 | MetricType::Cosine) } } @@ -535,27 +535,6 @@ mod tests { assert_eq!(tensor.shape, vec![256, 16]); } - #[test] - fn test_cosine_pq_does_not_use_residual() { - let pq = ProductQuantizerImpl::<Float32Type> { - num_bits: 8, - num_sub_vectors: 4, - dimension: 16, - codebook: Arc::new(Float32Array::from_iter_values(repeat(0.0).take(128))), - metric_type: MetricType::Cosine, - }; - assert!(!pq.use_residual()); - - let pq = ProductQuantizerImpl::<Float32Type> { - num_bits: 8, - num_sub_vectors: 4, - dimension: 16, - codebook: Arc::new(Float32Array::from_iter_values(repeat(0.0).take(128))), - metric_type: MetricType::L2, - }; - assert!(pq.use_residual()); - } - #[tokio::test] async fn test_empty_dist_iter() { let pq = ProductQuantizerImpl::<Float32Type> { diff --git a/rust/lance/src/index/vector/ivf.rs b/rust/lance/src/index/vector/ivf.rs index c89f2be1e0..26c6cd2825 100644 --- a/rust/lance/src/index/vector/ivf.rs +++ b/rust/lance/src/index/vector/ivf.rs @@ -800,20 +800,6 @@ pub async fn 
build_ivf_pq_index( } Ivf::new(centroids.clone()) } else { - // Pre-transforms - if pq_params.use_opq { - #[cfg(not(feature = "opq"))] - return Err(Error::Index { - message: "Feature 'opq' is not installed.".to_string(), - location: location!(), - }); - #[cfg(feature = "opq")] - { - let opq = train_opq(&training_data, pq_params).await?; - transforms.push(Box::new(opq)); - } - } - // Transform training data if necessary. for transform in transforms.iter() { if let Some(training_data) = &mut training_data { @@ -906,13 +892,12 @@ pub async fn build_ivf_pq_index( // the time to compute them is not that bad. let part_ids = ivf2.compute_partitions(&training_data).await?; - let training_data = if metric_type == MetricType::Cosine { - // Do not run residual distance for cosine distance. - training_data - } else { + let training_data = if ivf_params.use_residual { span!(Level::INFO, "compute residual for PQ training") .in_scope(|| ivf2.compute_residual(&training_data, Some(&part_ids))) .await? + } else { + training_data }; info!("Start train PQ: params={:#?}", pq_params); pq_params.build(&training_data, metric_type).await? 
@@ -1205,6 +1190,7 @@ mod tests { use std::collections::HashMap; use std::iter::repeat; + use arrow_array::types::UInt64Type; use arrow_array::{cast::AsArray, RecordBatchIterator, RecordBatchReader, UInt64Array}; use arrow_schema::{DataType, Field, Schema}; use lance_core::utils::address::RowAddress; @@ -1893,4 +1879,94 @@ mod tests { ])) ); } + + #[tokio::test] + async fn test_check_cosine_normalization() { + let test_dir = tempdir().unwrap(); + let test_uri = test_dir.path().to_str().unwrap(); + const DIM: usize = 32; + + let schema = Arc::new(Schema::new(vec![Field::new( + "vector", + DataType::FixedSizeList( + Arc::new(Field::new("item", DataType::Float32, true)), + DIM as i32, + ), + true, + )])); + + let arr = generate_random_array(1000 * DIM) + .values() + .iter() + .map(|&v| v + 1000.0) + .collect::<Float32Array>(); + let fsl = FixedSizeListArray::try_new_from_values(arr.clone(), DIM as i32).unwrap(); + let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(fsl)]).unwrap(); + let batches = RecordBatchIterator::new(vec![batch].into_iter().map(Ok), schema.clone()); + let mut dataset = Dataset::write(batches, test_uri, None).await.unwrap(); + + let params = VectorIndexParams::ivf_pq(2, 8, 4, false, MetricType::Cosine, 50); + dataset + .create_index(&[&"vector"], IndexType::Vector, None, &params, false) + .await + .unwrap(); + let indices = dataset.load_indices().await.unwrap(); + let idx = dataset + .open_generic_index("vector", indices[0].uuid.to_string().as_str()) + .await + .unwrap(); + let ivf_idx = idx.as_any().downcast_ref::<IVFIndex>().unwrap(); + // All centroids are normalized. 
+ // + // If not normalized, the centroids should be on the mean of original vector space + assert!(ivf_idx + .ivf + .centroids + .values() + .as_primitive::<Float32Type>() + .values() + .iter() + .all(|v| (0.0..=1.0).contains(v))); + + let pq_idx = ivf_idx + .sub_index + .as_any() + .downcast_ref::<PQIndex>() + .unwrap(); + assert!(pq_idx + .pq + .codebook_as_fsl() + .values() + .as_primitive::<Float32Type>() + .values() + .iter() + .all(|v| (0.0..=1.0).contains(v))); + + let dataset = Dataset::open(test_uri).await.unwrap(); + + let mut correct_times = 0; + for query_id in 0..10 { + let query = &arr.slice(query_id * DIM, DIM); + let results = dataset + .scan() + .with_row_id() + .nearest("vector", query, 1) + .unwrap() + .try_into_batch() + .await + .unwrap(); + assert_eq!(results.num_rows(), 1); + let row_id = results + .column_by_name("_rowid") + .unwrap() + .as_primitive::<UInt64Type>() + .value(0); + println!("Row id: {}", row_id); + if row_id == (query_id as u64) { + correct_times += 1; + } + } + + assert!(correct_times >= 9, "correct: {}", correct_times); + } }