diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py
index 92ca66b63c..27361148fc 100644
--- a/python/python/lance/dataset.py
+++ b/python/python/lance/dataset.py
@@ -26,6 +26,7 @@
     Optional,
     Sequence,
     Set,
+    Tuple,
     TypedDict,
     Union,
 )
@@ -102,6 +103,30 @@ def execute(self, data_obj: ReaderLike, *, schema: Optional[pa.Schema] = None):
 
         return super(MergeInsertBuilder, self).execute(reader)
 
+    def execute_uncommitted(
+        self, data_obj: ReaderLike, *, schema: Optional[pa.Schema] = None
+    ) -> Tuple[Transaction, Dict[str, Any]]:
+        """Executes the merge insert operation without committing.
+
+        Unlike :meth:`execute`, this does not modify the original dataset. It
+        returns the uncommitted transaction together with a dictionary of merge
+        statistics - i.e. the number of inserted, updated, and deleted rows. The
+        transaction can be committed later with :meth:`LanceDataset.commit`.
+
+        Parameters
+        ----------
+        data_obj: ReaderLike
+            The new data to use as the source table for the operation. This parameter
+            can be any source of data (e.g. table / dataset) that
+            :func:`~lance.write_dataset` accepts.
+        schema: Optional[pa.Schema]
+            The schema of the data. This only needs to be supplied whenever the data
+            source is some kind of generator.
+        """
+        reader = _coerce_reader(data_obj, schema)
+
+        return super(MergeInsertBuilder, self).execute_uncommitted(reader)
+
     # These next three overrides exist only to document the methods
 
     def when_matched_update_all(
@@ -2220,7 +2245,7 @@ def _commit(
     @staticmethod
     def commit(
         base_uri: Union[str, Path, LanceDataset],
-        operation: LanceOperation.BaseOperation,
+        operation: Union[LanceOperation.BaseOperation, Transaction],
         blobs_op: Optional[LanceOperation.BaseOperation] = None,
         read_version: Optional[int] = None,
         commit_lock: Optional[CommitLock] = None,
@@ -2326,24 +2351,45 @@ def commit(
                 f"commit_lock must be a function, got {type(commit_lock)}"
             )
 
-        if read_version is None and not isinstance(
-            operation, (LanceOperation.Overwrite, LanceOperation.Restore)
+        if (
+            isinstance(operation, LanceOperation.BaseOperation)
+            and read_version is None
+            and not isinstance(
+                operation, (LanceOperation.Overwrite, LanceOperation.Restore)
+            )
         ):
             raise ValueError(
                 "read_version is required for all operations except "
                 "Overwrite and Restore"
             )
-        new_ds = _Dataset.commit(
-            base_uri,
-            operation,
-            blobs_op,
-            read_version,
-            commit_lock,
-            storage_options=storage_options,
-            enable_v2_manifest_paths=enable_v2_manifest_paths,
-            detached=detached,
-            max_retries=max_retries,
-        )
+        if isinstance(operation, Transaction):
+            new_ds = _Dataset.commit_transaction(
+                base_uri,
+                operation,
+                commit_lock,
+                storage_options=storage_options,
+                enable_v2_manifest_paths=enable_v2_manifest_paths,
+                detached=detached,
+                max_retries=max_retries,
+            )
+        elif isinstance(operation, LanceOperation.BaseOperation):
+            new_ds = _Dataset.commit(
+                base_uri,
+                operation,
+                blobs_op,
+                read_version,
+                commit_lock,
+                storage_options=storage_options,
+                enable_v2_manifest_paths=enable_v2_manifest_paths,
+                detached=detached,
+                max_retries=max_retries,
+            )
+        else:
+            raise TypeError(
+                "operation must be a LanceOperation.BaseOperation or Transaction, "
+                f"got {type(operation)}"
+            )
+
         ds = LanceDataset.__new__(LanceDataset)
         ds._storage_options = storage_options
         ds._ds = new_ds
@@ -2722,6 +2768,29 @@ class Delete(BaseOperation):
         def __post_init__(self):
             LanceOperation._validate_fragments(self.updated_fragments)
 
+    @dataclass
+    class Update(BaseOperation):
+        """
+        Operation that updates rows in the dataset.
+
+        Attributes
+        ----------
+        removed_fragment_ids: list[int]
+            The ids of the fragments that have been removed entirely.
+        updated_fragments: list[FragmentMetadata]
+            The fragments that have been updated with new deletion vectors.
+        new_fragments: list[FragmentMetadata]
+            The fragments that contain the new rows.
+        """
+
+        removed_fragment_ids: List[int]
+        updated_fragments: List[FragmentMetadata]
+        new_fragments: List[FragmentMetadata]
+
+        def __post_init__(self):
+            LanceOperation._validate_fragments(self.updated_fragments)
+            LanceOperation._validate_fragments(self.new_fragments)
+
     @dataclass
     class Merge(BaseOperation):
         """
diff --git a/python/python/tests/test_dataset.py b/python/python/tests/test_dataset.py
index fb9b177ab9..73fae33863 100644
--- a/python/python/tests/test_dataset.py
+++ b/python/python/tests/test_dataset.py
@@ -1015,6 +1015,31 @@ def test_restore_with_commit(tmp_path: Path):
     assert tbl == table
 
 
+def test_merge_insert_with_commit():
+    table = pa.table({"id": range(10), "updated": [False] * 10})
+    dataset = lance.write_dataset(table, "memory://test")
+
+    updates = pa.Table.from_pylist([{"id": 1, "updated": True}])
+    transaction, stats = (
+        dataset.merge_insert(on="id")
+        .when_matched_update_all()
+        .execute_uncommitted(updates)
+    )
+
+    assert isinstance(stats, dict)
+    assert stats["num_updated_rows"] == 1
+    assert stats["num_inserted_rows"] == 0
+    assert stats["num_deleted_rows"] == 0
+
+    assert isinstance(transaction, lance.Transaction)
+    assert isinstance(transaction.operation, lance.LanceOperation.Update)
+
+    dataset = lance.LanceDataset.commit(dataset, transaction)
+    assert dataset.to_table().sort_by("id") == pa.table(
+        {"id": range(10), "updated": [False] + [True] + [False] * 8}
+    )
+
+
 def test_merge_with_commit(tmp_path: Path):
     table = pa.Table.from_pydict({"a": range(100), "b": range(100)})
     base_dir = tmp_path / "test"
diff --git a/python/src/dataset.rs b/python/src/dataset.rs
index 4eff1ea0a4..4072978436 100644
--- a/python/src/dataset.rs
+++ b/python/src/dataset.rs
@@ -43,7 +43,8 @@ use lance::dataset::{
     WriteParams,
 };
 use lance::dataset::{
-    BatchInfo, BatchUDF, CommitBuilder, NewColumnTransform, UDFCheckpointStore, WriteDestination,
+    BatchInfo, BatchUDF, CommitBuilder, MergeStats, NewColumnTransform, UDFCheckpointStore,
+    WriteDestination,
 };
 use lance::dataset::{ColumnAlteration, ProjectionRequest};
 use lance::index::vector::utils::get_vector_type;
@@ -199,20 +200,46 @@ impl MergeInsertBuilder {
             .try_build()
             .map_err(|err| PyValueError::new_err(err.to_string()))?;
 
-        let new_self = RT
+        let (new_dataset, stats) = RT
             .spawn(Some(py), job.execute_reader(new_data))?
             .map_err(|err| PyIOError::new_err(err.to_string()))?;
 
         let dataset = self.dataset.bind(py);
-        dataset.borrow_mut().ds = new_self.0;
-        let merge_stats = new_self.1;
-        let merge_dict = PyDict::new_bound(py);
-        merge_dict.set_item("num_inserted_rows", merge_stats.num_inserted_rows)?;
-        merge_dict.set_item("num_updated_rows", merge_stats.num_updated_rows)?;
-        merge_dict.set_item("num_deleted_rows", merge_stats.num_deleted_rows)?;
+        dataset.borrow_mut().ds = new_dataset;
 
-        Ok(merge_dict.into())
+        Ok(Self::build_stats(&stats, py)?.into())
+    }
+
+    pub fn execute_uncommitted<'a>(
+        &mut self,
+        new_data: &Bound<'a, PyAny>,
+    ) -> PyResult<(PyLance<Transaction>, Bound<'a, PyDict>)> {
+        let py = new_data.py();
+        let new_data = convert_reader(new_data)?;
+
+        let job = self
+            .builder
+            .try_build()
+            .map_err(|err| PyValueError::new_err(err.to_string()))?;
+
+        let (transaction, stats) = RT
+            .spawn(Some(py), job.execute_uncommitted(new_data))?
+            .map_err(|err| PyIOError::new_err(err.to_string()))?;
+
+        let stats = Self::build_stats(&stats, py)?;
+
+        Ok((PyLance(transaction), stats))
+    }
+}
+
+impl MergeInsertBuilder {
+    fn build_stats<'a>(stats: &MergeStats, py: Python<'a>) -> PyResult<Bound<'a, PyDict>> {
+        let dict = PyDict::new_bound(py);
+        dict.set_item("num_inserted_rows", stats.num_inserted_rows)?;
+        dict.set_item("num_updated_rows", stats.num_updated_rows)?;
+        dict.set_item("num_deleted_rows", stats.num_deleted_rows)?;
+        Ok(dict)
+    }
 }
@@ -1312,6 +1339,36 @@ impl Dataset {
        enable_v2_manifest_paths: Option<bool>,
        detached: Option<bool>,
        max_retries: Option<u32>,
+    ) -> PyResult<Self> {
+        let transaction = Transaction::new(
+            read_version.unwrap_or_default(),
+            operation.0,
+            blobs_op.map(|op| op.0),
+            None,
+        );
+
+        Self::commit_transaction(
+            dest,
+            PyLance(transaction),
+            commit_lock,
+            storage_options,
+            enable_v2_manifest_paths,
+            detached,
+            max_retries,
+        )
+    }
+
+    #[allow(clippy::too_many_arguments)]
+    #[staticmethod]
+    #[pyo3(signature = (dest, transaction, commit_lock = None, storage_options = None, enable_v2_manifest_paths = None, detached = None, max_retries = None))]
+    fn commit_transaction(
+        dest: &Bound<'_, PyAny>,
+        transaction: PyLance<Transaction>,
+        commit_lock: Option<&Bound<'_, PyAny>>,
+        storage_options: Option<HashMap<String, String>>,
+        enable_v2_manifest_paths: Option<bool>,
+        detached: Option<bool>,
+        max_retries: Option<u32>,
     ) -> PyResult<Self> {
         let object_store_params =
             storage_options
@@ -1333,13 +1390,6 @@ impl Dataset {
             WriteDestination::Uri(dest.extract()?)
         };
 
-        let transaction = Transaction::new(
-            read_version.unwrap_or_default(),
-            operation.0,
-            blobs_op.map(|op| op.0),
-            None,
-        );
-
         let mut builder = CommitBuilder::new(dest)
             .enable_v2_manifest_paths(enable_v2_manifest_paths.unwrap_or(false))
             .with_detached(detached.unwrap_or(false))
@@ -1354,7 +1404,10 @@ impl Dataset {
         }
 
         let ds = RT
-            .block_on(commit_lock.map(|cl| cl.py()), builder.execute(transaction))?
+            .block_on(
+                commit_lock.map(|cl| cl.py()),
+                builder.execute(transaction.0),
+            )?
             .map_err(|err| PyIOError::new_err(err.to_string()))?;
 
         let uri = ds.uri().to_string();
diff --git a/python/src/transaction.rs b/python/src/transaction.rs
index 63b31ae611..ee549503d1 100644
--- a/python/src/transaction.rs
+++ b/python/src/transaction.rs
@@ -47,6 +47,20 @@ impl FromPyObject<'_> for PyLance<Operation> {
                 };
                 Ok(Self(op))
             }
+            "Update" => {
+                let removed_fragment_ids = ob.getattr("removed_fragment_ids")?.extract()?;
+
+                let updated_fragments = extract_vec(&ob.getattr("updated_fragments")?)?;
+
+                let new_fragments = extract_vec(&ob.getattr("new_fragments")?)?;
+
+                let op = Operation::Update {
+                    removed_fragment_ids,
+                    updated_fragments,
+                    new_fragments,
+                };
+                Ok(Self(op))
+            }
             "Merge" => {
                 let schema = extract_schema(&ob.getattr("schema")?)?;
 
@@ -143,6 +157,21 @@ impl ToPyObject for PyLance<&Operation> {
                     .expect("Failed to create Overwrite instance")
                     .to_object(py)
             }
+            Operation::Update {
+                removed_fragment_ids,
+                updated_fragments,
+                new_fragments,
+            } => {
+                let removed_fragment_ids = removed_fragment_ids.to_object(py);
+                let updated_fragments = export_vec(py, updated_fragments.as_slice());
+                let new_fragments = export_vec(py, new_fragments.as_slice());
+                let cls = namespace
+                    .getattr("Update")
+                    .expect("Failed to get Update class");
+                cls.call1((removed_fragment_ids, updated_fragments, new_fragments))
+                    .unwrap()
+                    .to_object(py)
+            }
             _ => todo!(),
         }
     }
diff --git a/rust/lance/src/dataset.rs b/rust/lance/src/dataset.rs
index 8e06360129..72c754c5ba 100644
--- a/rust/lance/src/dataset.rs
+++ b/rust/lance/src/dataset.rs
@@ -85,7 +85,8 @@ pub use schema_evolution::{
 };
 pub use take::TakeBuilder;
 pub use write::merge_insert::{
-    MergeInsertBuilder, MergeInsertJob, WhenMatched, WhenNotMatched, WhenNotMatchedBySource,
+    MergeInsertBuilder, MergeInsertJob, MergeStats, WhenMatched, WhenNotMatched,
+    WhenNotMatchedBySource,
 };
 pub use write::update::{UpdateBuilder, UpdateJob};
 #[allow(deprecated)]
diff --git a/rust/lance/src/dataset/write/merge_insert.rs b/rust/lance/src/dataset/write/merge_insert.rs
index ddf6de7d9e..4ad998f8b7 100644
--- a/rust/lance/src/dataset/write/merge_insert.rs
+++ b/rust/lance/src/dataset/write/merge_insert.rs
@@ -84,17 +84,13 @@ use crate::{
         write::open_writer,
     },
     index::DatasetIndexInternalExt,
-    io::{
-        commit::commit_transaction,
-        exec::{
-            project, scalar_index::MapIndexExec, utils::ReplayExec, AddRowAddrExec, Planner,
-            TakeExec,
-        },
+    io::exec::{
+        project, scalar_index::MapIndexExec, utils::ReplayExec, AddRowAddrExec, Planner, TakeExec,
     },
     Dataset,
 };
-use super::{write_fragments_internal, WriteParams};
+use super::{write_fragments_internal, CommitBuilder, WriteParams};
 
 // "update if" expressions typically compare fields from the source table to the target table.
 // These tables have the same schema and so filter expressions need to differentiate. To do that
@@ -1001,6 +997,27 @@ impl MergeInsertJob {
         self,
         source: SendableRecordBatchStream,
     ) -> Result<(Arc<Dataset>, MergeStats)> {
+        let ds = self.dataset.clone();
+        let (transaction, stats) = self.execute_uncommitted_impl(source).await?;
+        let dataset = CommitBuilder::new(ds).execute(transaction).await?;
+        Ok((Arc::new(dataset), stats))
+    }
+
+    /// Execute the merge insert job without committing the changes.
+    ///
+    /// Use [`CommitBuilder`] to commit the returned transaction.
+    pub async fn execute_uncommitted(
+        self,
+        source: impl StreamingWriteSource,
+    ) -> Result<(Transaction, MergeStats)> {
+        let stream = source.into_stream();
+        self.execute_uncommitted_impl(stream).await
+    }
+
+    async fn execute_uncommitted_impl(
+        self,
+        source: SendableRecordBatchStream,
+    ) -> Result<(Transaction, MergeStats)> {
         let schema = source.schema();
 
         let full_schema = Schema::from(self.dataset.local_schema());
@@ -1016,7 +1033,7 @@ impl MergeInsertJob {
             .try_flatten();
         let stream = RecordBatchStreamAdapter::new(merger_schema, stream);
 
-        let committed_ds = if !is_full_schema {
+        let operation = if !is_full_schema {
             if !matches!(
                 self.params.delete_not_matched_by_source,
                 WhenNotMatchedBySource::Keep
@@ -1030,7 +1047,11 @@ impl MergeInsertJob {
             let (updated_fragments, new_fragments) =
                 Self::update_fragments(self.dataset.clone(), Box::pin(stream)).await?;
 
-            Self::commit(self.dataset, Vec::new(), updated_fragments, new_fragments).await?
+            Operation::Update {
+                removed_fragment_ids: Vec::new(),
+                updated_fragments,
+                new_fragments,
+            }
         } else {
             let written = write_fragments_internal(
                 Some(&self.dataset),
@@ -1052,13 +1073,11 @@ impl MergeInsertJob {
             Self::apply_deletions(&self.dataset, &removed_row_ids).await?;
 
             // Commit updated and new fragments
-            Self::commit(
-                self.dataset,
+            Operation::Update {
                 removed_fragment_ids,
-                old_fragments,
+                updated_fragments: old_fragments,
                 new_fragments,
-            )
-            .await?
+            }
         };
 
         let stats = Arc::into_inner(merge_statistics)
@@ -1066,7 +1085,14 @@ impl MergeInsertJob {
             .into_inner()
             .unwrap();
 
-        Ok((committed_ds, stats))
+        let transaction = Transaction::new(
+            self.dataset.manifest.version,
+            operation,
+            /*blobs_op=*/ None,
+            None,
+        );
+
+        Ok((transaction, stats))
     }
 
     // Delete a batch of rows by id, returns the fragments modified and the fragments removed
@@ -1115,43 +1141,6 @@ impl MergeInsertJob {
 
         Ok((updated_fragments, removed_fragments))
     }
-
-    // Commit the operation
-    async fn commit(
-        dataset: Arc<Dataset>,
-        removed_fragment_ids: Vec<u64>,
-        updated_fragments: Vec<Fragment>,
-        new_fragments: Vec<Fragment>,
-    ) -> Result<Arc<Dataset>> {
-        let operation = Operation::Update {
-            removed_fragment_ids,
-            updated_fragments,
-            new_fragments,
-        };
-        let transaction = Transaction::new(
-            dataset.manifest.version,
-            operation,
-            /*blobs_op=*/ None,
-            None,
-        );
-
-        let (manifest, manifest_path) = commit_transaction(
-            dataset.as_ref(),
-            dataset.object_store(),
-            dataset.commit_handler.as_ref(),
-            &transaction,
-            &Default::default(),
-            &Default::default(),
-            dataset.manifest_naming_scheme,
-        )
-        .await?;
-
-        let mut dataset = dataset.as_ref().clone();
-        dataset.manifest = Arc::new(manifest);
-        dataset.manifest_file = manifest_path;
-
-        Ok(Arc::new(dataset))
-    }
 }
 
 /// Merger will store these statistics as it runs (for each batch)
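
For reference, a minimal Rust sketch of the two-phase flow this change enables: stage a merge insert with `execute_uncommitted`, then apply the returned transaction with `CommitBuilder` (the same path `execute_reader` now takes internally). The join key ("id"), the `WhenMatched::UpdateAll` action, the helper name, and the `StreamingWriteSource` import path are illustrative assumptions rather than part of the change; only `execute_uncommitted`, `CommitBuilder`, and the `MergeStats` fields come from the diff above.

use std::sync::Arc;

use lance::dataset::{CommitBuilder, Dataset, MergeInsertBuilder, WhenMatched};

// Hypothetical helper: stage a merge insert without committing, report the
// MergeStats, then commit the staged transaction separately.
async fn merge_then_commit(
    dataset: Arc<Dataset>,
    // Assumed: any source implementing StreamingWriteSource (e.g. an Arrow
    // RecordBatchReader); the exact import path may differ.
    source: impl lance::dataset::StreamingWriteSource,
) -> lance::Result<Dataset> {
    // Join on "id" and update matched rows (illustrative configuration).
    let job = MergeInsertBuilder::try_new(dataset.clone(), vec!["id".into()])?
        .when_matched(WhenMatched::UpdateAll)
        .try_build()?;

    // Stage the changes; no new manifest version is written yet.
    let (transaction, stats) = job.execute_uncommitted(source).await?;
    println!(
        "staged merge: {} inserted, {} updated, {} deleted",
        stats.num_inserted_rows, stats.num_updated_rows, stats.num_deleted_rows
    );

    // Apply the staged transaction when ready.
    let committed = CommitBuilder::new(dataset).execute(transaction).await?;
    Ok(committed)
}

The Python side of the same flow is exercised by test_merge_insert_with_commit above, where the transaction returned by execute_uncommitted is applied with LanceDataset.commit.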