chore: simplify output schema in scanner (#1999)
chebbyChefNEQ authored Feb 27, 2024
1 parent e79db4f commit ea621cc
Showing 4 changed files with 26 additions and 53 deletions.
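In short: on the Rust side, Scanner::schema() now derives the output schema from the physical execution plan instead of recomputing field types, nullability, and metadata from the output expressions. That makes the method async, so the Python reader and the pyo3 schema getter below are updated to await it. A minimal sketch of the new call shape, assuming a Lance dataset already exists at a hypothetical path:

use lance::dataset::Dataset;

#[tokio::main]
async fn main() -> lance::error::Result<()> {
    // Hypothetical path; any existing Lance dataset works here.
    let dataset = Dataset::open("/tmp/demo.lance").await?;
    let scanner = dataset.scan();
    // After this commit, schema() is async: it builds the physical plan
    // and returns plan.schema(), so projections and the _distance column
    // are reflected without a separate output_schema() code path.
    let schema = scanner.schema().await?;
    println!("{schema:#?}");
    Ok(())
}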
26 changes: 19 additions & 7 deletions python/python/tests/test_vector_index.py
@@ -541,33 +541,45 @@ def test_knn_with_deletions(tmp_path):
 def test_index_cache_size(tmp_path):
     rng = np.random.default_rng(seed=42)
 
-    def query_index(ds, ntimes):
+    def query_index(ds, ntimes, q=None):
         ndim = ds.schema[0].type.list_size
         for _ in range(ntimes):
             ds.to_table(
                 nearest={
                     "column": "vector",
-                    "q": rng.standard_normal(ndim),
+                    "q": q if q is not None else rng.standard_normal(ndim),
                 },
             )
 
     tbl = create_table(nvec=1024, ndim=16)
     dataset = lance.write_dataset(tbl, tmp_path / "test")
 
-    indexed_dataset = dataset.create_index(
+    dataset.create_index(
         "vector",
         index_type="IVF_PQ",
         num_partitions=128,
         num_sub_vectors=2,
         index_cache_size=10,
     )
 
+    indexed_dataset = lance.dataset(tmp_path / "test", index_cache_size=0)
-    # when there is no hit, the hit rate is hard coded to 1.0
-    assert np.isclose(indexed_dataset._ds.index_cache_hit_rate(), 1.0)
-    query_index(indexed_dataset, 1)
-    assert np.isclose(indexed_dataset._ds.index_cache_hit_rate(), 0.4)
-    query_index(indexed_dataset, 128)
-    indexed_dataset = lance.LanceDataset(indexed_dataset.uri, index_cache_size=5)
+    # index cache is size=0, there should be no hit
+    assert np.isclose(indexed_dataset._ds.index_cache_hit_rate(), 0.0)
+
+    indexed_dataset = lance.dataset(tmp_path / "test", index_cache_size=1)
+    # query using the same vector, we should get a very high hit rate
+    query_index(indexed_dataset, 100, q=rng.standard_normal(16))
+    assert indexed_dataset._ds.index_cache_hit_rate() > 0.99
+
+    last_hit_rate = indexed_dataset._ds.index_cache_hit_rate()
+
+    # send a few queries with different vectors, the hit rate should drop
     query_index(indexed_dataset, 128)
-    assert indexed_dataset._ds.index_cache_entry_count() == 6
+
+    assert last_hit_rate > indexed_dataset._ds.index_cache_hit_rate()
 
 
 def test_f16_index(tmp_path: Path):
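The revised test pins down the cache semantics directly: with index_cache_size=0 the hit rate stays at 0.0, a single vector queried repeatedly against a size-1 cache pushes it above 0.99, and varied queries drag it back down; the old hard-coded 1.0/0.4 expectations and the entry-count check are dropped.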
2 changes: 1 addition & 1 deletion python/src/reader.rs
@@ -40,7 +40,7 @@ pub struct LanceReader {
 impl LanceReader {
     pub async fn try_new(scanner: Arc<LanceScanner>) -> ::lance::error::Result<Self> {
         Ok(Self {
-            schema: scanner.schema()?,
+            schema: scanner.schema().await?,
             stream: Arc::new(Mutex::new(scanner.try_into_stream().await?)), // needs tokio Runtime
         })
     }
5 changes: 2 additions & 3 deletions python/src/scanner.rs
@@ -50,9 +50,8 @@ impl Scanner {
 impl Scanner {
     #[getter(schema)]
     fn schema(self_: PyRef<'_, Self>) -> PyResult<PyObject> {
-        self_
-            .scanner
-            .schema()
+        let scanner = self_.scanner.clone();
+        RT.spawn(Some(self_.py()), async move { scanner.schema().await })?
             .map(|s| s.to_pyarrow(self_.py()))
             .map_err(|err| PyValueError::new_err(err.to_string()))?
     }
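pyo3 getters are synchronous, so the binding above clones the scanner, hands the now-async schema() future to pylance's shared runtime handle RT together with the Python token, and blocks until the result is ready. A minimal sketch of that sync-over-async pattern, with a plain tokio runtime standing in for the RT helper (block_on here is a hypothetical stand-in, not pylance's actual API):

use std::future::Future;
use std::sync::OnceLock;
use tokio::runtime::Runtime;

// One process-wide runtime, started lazily on first use.
static RT: OnceLock<Runtime> = OnceLock::new();

/// Drive a future to completion from synchronous code,
/// e.g. a getter that cannot itself be async.
fn block_on<F: Future>(fut: F) -> F::Output {
    RT.get_or_init(|| Runtime::new().expect("failed to start tokio runtime"))
        .block_on(fut)
}

fn main() {
    let out = block_on(async { 40 + 2 });
    assert_eq!(out, 42);
}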
46 changes: 4 additions & 42 deletions rust/lance/src/dataset/scanner.rs
@@ -38,7 +38,6 @@ use datafusion::physical_plan::{
     ExecutionPlan, SendableRecordBatchStream,
 };
 use datafusion::scalar::ScalarValue;
-use datafusion_physical_expr::expressions::Column;
 use datafusion_physical_expr::PhysicalExpr;
 use futures::stream::{Stream, StreamExt};
 use futures::TryStreamExt;
@@ -188,21 +187,6 @@ pub struct Scanner {
     fragments: Option<Vec<Fragment>>,
 }
 
-/// lifted from datafusion
-/// If e is a direct column reference, returns the field level
-/// metadata for that field, if any. Otherwise returns None
-fn get_field_metadata(
-    e: &Arc<dyn PhysicalExpr>,
-    input_schema: &ArrowSchema,
-) -> Option<HashMap<String, String>> {
-    // Look up field by index in schema (not NAME as there can be more than one
-    // column with the same name)
-    e.as_any()
-        .downcast_ref::<Column>()
-        .map(|column| input_schema.field(column.index()).metadata())
-        .cloned()
-}
-
 impl Scanner {
     pub fn new(dataset: Arc<Dataset>) -> Self {
         let projection = dataset.schema().clone();
@@ -282,7 +266,7 @@ impl Scanner {
         &mut self,
         columns: &[(impl AsRef<str>, impl AsRef<str>)],
     ) -> Result<&mut Self> {
-        let planner = Planner::new(self.schema()?);
+        let planner = Planner::new(Arc::new(self.dataset.schema().into()));
         let mut output = HashMap::new();
         let mut physical_cols_set = HashSet::new();
         let mut physical_cols = vec![];
@@ -570,11 +554,9 @@ impl Scanner {
     }
 
     /// The Arrow schema of the output, including projections and vector / _distance
-    pub fn schema(&self) -> Result<SchemaRef> {
-        let schema = self
-            .output_schema()
-            .map(|s| SchemaRef::new(ArrowSchema::from(s.as_ref())))?;
-        Ok(schema)
+    pub async fn schema(&self) -> Result<SchemaRef> {
+        let plan = self.create_plan().await?;
+        Ok(plan.schema())
     }
 
     /// The schema of the Scanner from lance physical takes
@@ -674,26 +656,6 @@ impl Scanner {
         Ok(output_expr)
     }
 
-    /// The schema of the Scanner output
-    pub(crate) fn output_schema(&self) -> Result<Arc<Schema>> {
-        let arrow_schema = self.physical_schema()?.as_ref().into();
-        let output_expr = self.output_expr()?;
-
-        let mut fields = vec![];
-        for (expr, name) in output_expr {
-            let dtype = expr.data_type(&arrow_schema)?;
-            let nullable = expr.nullable(&arrow_schema)?;
-
-            let mut field = ArrowField::new(name, dtype, nullable);
-            field.set_metadata(get_field_metadata(&expr, &arrow_schema).unwrap_or_default());
-
-            fields.push(field);
-        }
-        let schema = (&ArrowSchema::new(fields)).try_into()?;
-
-        Ok(Arc::new(schema))
-    }
-
     /// Create a stream from the Scanner.
     #[instrument(skip_all)]
     pub async fn try_into_stream(&self) -> Result<DatasetRecordBatchStream> {
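Net effect of the scanner hunks: the hand-rolled output_schema() path (per-field dtype and nullability, plus the get_field_metadata helper lifted from DataFusion) is deleted, and schema() simply asks the plan, which by construction cannot drift from what the stream actually yields. The price, visible in the signature, is that answering a schema question now requires building a plan, hence pub async fn. A self-contained sketch of the DataFusion behavior the new code leans on (the table and query here are illustrative, not from this commit):

use std::sync::Arc;

use datafusion::arrow::array::{ArrayRef, Int32Array};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::datasource::MemTable;
use datafusion::error::Result;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> Result<()> {
    // A toy one-column table registered in a DataFusion session.
    let batch = RecordBatch::try_from_iter(vec![(
        "a",
        Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef,
    )])?;
    let table = MemTable::try_new(batch.schema(), vec![vec![batch]])?;
    let ctx = SessionContext::new();
    ctx.register_table("t", Arc::new(table))?;

    // Build the physical plan for a projection; the plan itself reports
    // the output schema, so there is nothing to recompute by hand.
    let plan = ctx
        .sql("SELECT a + 1 AS b FROM t")
        .await?
        .create_physical_plan()
        .await?;
    println!("{:#?}", plan.schema()); // a single field named "b"
    Ok(())
}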
