Skip to content

Commit

Permalink
chore: expose utils for infering vector dim and element type (#3385)
Browse files Browse the repository at this point in the history
found this is useful so that we don't need to repeat these lines
everywhere, and avoid diff error message

---------

Signed-off-by: BubbleCal <[email protected]>
  • Loading branch information
BubbleCal authored Jan 16, 2025
1 parent 8b8b8c8 commit 4149457
Showing 1 changed file with 36 additions and 14 deletions.
50 changes: 36 additions & 14 deletions rust/lance/src/index/vector/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use tokio::sync::Mutex;
use crate::dataset::Dataset;
use crate::{Error, Result};

/// Get the vector dimension of the given column in the schema.
pub fn get_vector_dim(schema: &Schema, column: &str) -> Result<usize> {
let field = schema.field(column).ok_or(Error::Index {
message: format!("Column {} does not exist in schema {}", column, schema),
Expand All @@ -20,20 +21,25 @@ pub fn get_vector_dim(schema: &Schema, column: &str) -> Result<usize> {
infer_vector_dim(&field.data_type())
}

fn infer_vector_dim(data_type: &arrow::datatypes::DataType) -> Result<usize> {
match data_type {
arrow::datatypes::DataType::FixedSizeList(_, dim) => Ok(*dim as usize),
arrow::datatypes::DataType::List(inner) => infer_vector_dim(inner.data_type()),
/// Infer the vector dimension from the given data type.
pub fn infer_vector_dim(data_type: &arrow::datatypes::DataType) -> Result<usize> {
infer_vector_dim_impl(data_type, false)
}

fn infer_vector_dim_impl(data_type: &arrow::datatypes::DataType, in_list: bool) -> Result<usize> {
match (data_type,in_list) {
(arrow::datatypes::DataType::FixedSizeList(_, dim),_) => Ok(*dim as usize),
(arrow::datatypes::DataType::List(inner), false) => infer_vector_dim_impl(inner.data_type(),true),
_ => Err(Error::Index {
message: format!("Data type is not a FixedSizeListArray, but {:?}", data_type),
message: format!("Data type is not a vector (FixedSizeListArray or List<FixedSizeListArray>), but {:?}", data_type),
location: location!(),
}),
}
}

// this checks whether the given column is with a valid vector type
// returns the vector type (FixedSizeList for vectors, or List for multivectors),
// and element type (Float16/Float32/Float64 or UInt8 for binary vectors).
/// Checks whether the given column is with a valid vector type
/// returns the vector type (FixedSizeList for vectors, or List for multivectors),
/// and element type (Float16/Float32/Float64 or UInt8 for binary vectors).
pub fn get_vector_type(
schema: &Schema,
column: &str,
Expand All @@ -48,28 +54,44 @@ pub fn get_vector_type(
))
}

fn infer_vector_element_type(
/// If the data type is a fixed size list or list of fixed size list return the inner element type
/// and verify it is a type we can create a vector index on.
///
/// Return an error if the data type is any other type
pub fn infer_vector_element_type(
data_type: &arrow::datatypes::DataType,
) -> Result<arrow_schema::DataType> {
infer_vector_element_type_impl(data_type, false)
}

fn infer_vector_element_type_impl(
data_type: &arrow::datatypes::DataType,
in_list: bool,
) -> Result<arrow_schema::DataType> {
match data_type {
arrow::datatypes::DataType::FixedSizeList(element_field, _) => {
match (data_type, in_list) {
(arrow::datatypes::DataType::FixedSizeList(element_field, _), _) => {
match element_field.data_type() {
arrow::datatypes::DataType::Float16
| arrow::datatypes::DataType::Float32
| arrow::datatypes::DataType::Float64
| arrow::datatypes::DataType::UInt8 => Ok(element_field.data_type().clone()),
_ => Err(Error::Index {
message: format!(
"vector element is not expected type (Float16/Float32/Float64 or UInt8) {:?}",
"vector element is not expected type (Float16/Float32/Float64 or UInt8): {:?}",
element_field.data_type()
),
location: location!(),
}),
}
}
arrow::datatypes::DataType::List(inner) => infer_vector_element_type(inner.data_type()),
(arrow::datatypes::DataType::List(inner), false) => {
infer_vector_element_type_impl(inner.data_type(), true)
}
_ => Err(Error::Index {
message: format!("vector is not with valid data type: {:?}", data_type),
message: format!(
"Data type is not a vector (FixedSizeListArray or List<FixedSizeListArray>), but {:?}",
data_type
),
location: location!(),
}),
}
Expand Down

0 comments on commit 4149457

Please sign in to comment.