Add SqlCatalog _commit_table support (#265)
* sql commit

* SqlCatalog _commit_table

* better variable names

* fallback to FOR UPDATE commit when engine.dialect.supports_sane_rowcount is False

* remove stray print

* wait

* better logging
sungwy authored Jan 17, 2024
1 parent 2d30119 commit 7deb739
Showing 2 changed files with 165 additions and 22 deletions.
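
The commit messages above mention falling back to a FOR UPDATE commit when engine.dialect.supports_sane_rowcount is False. As a reading aid, the sketch below shows the two concurrency-control paths used throughout this diff in isolation: a conditional UPDATE whose rowcount is checked, versus a SELECT ... FOR UPDATE row lock mutated through the ORM. This is a simplified illustration, not the pyiceberg code itself; CatalogRow and its columns are hypothetical stand-ins for the IcebergTables model.

# Simplified sketch of the two commit paths; CatalogRow is a hypothetical ORM model.
from sqlalchemy import update
from sqlalchemy.exc import NoResultFound
from sqlalchemy.orm import Session


def swap_metadata_location(engine, CatalogRow, table_name, old_location, new_location):
    with Session(engine) as session:
        if engine.dialect.supports_sane_rowcount:
            # Path 1: conditional UPDATE; a rowcount of 0 means the row was
            # missing or already changed by another process.
            result = session.execute(
                update(CatalogRow)
                .where(
                    CatalogRow.table_name == table_name,
                    CatalogRow.metadata_location == old_location,
                )
                .values(metadata_location=new_location)
            )
            if result.rowcount < 1:
                raise RuntimeError("row missing or changed by another process")
        else:
            # Path 2: for dialects whose rowcount is unreliable, lock the row
            # with SELECT ... FOR UPDATE and mutate it in the ORM session.
            try:
                row = (
                    session.query(CatalogRow)
                    .with_for_update(of=CatalogRow)
                    .filter(
                        CatalogRow.table_name == table_name,
                        CatalogRow.metadata_location == old_location,
                    )
                    .one()
                )
                row.metadata_location = new_location
            except NoResultFound as e:
                raise RuntimeError("row missing or changed by another process") from e
        session.commit()
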
132 changes: 110 additions & 22 deletions pyiceberg/catalog/sql.py
@@ -31,7 +31,7 @@
union,
update,
)
from sqlalchemy.exc import IntegrityError, OperationalError
from sqlalchemy.exc import IntegrityError, NoResultFound, OperationalError
from sqlalchemy.orm import (
DeclarativeBase,
Mapped,
@@ -48,6 +48,7 @@
PropertiesUpdateSummary,
)
from pyiceberg.exceptions import (
CommitFailedException,
NamespaceAlreadyExistsError,
NamespaceNotEmptyError,
NoSuchNamespaceError,
@@ -59,7 +60,7 @@
from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionSpec
from pyiceberg.schema import Schema
from pyiceberg.serializers import FromInputFile
from pyiceberg.table import CommitTableRequest, CommitTableResponse, Table
from pyiceberg.table import CommitTableRequest, CommitTableResponse, Table, update_table_metadata
from pyiceberg.table.metadata import new_table_metadata
from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder
from pyiceberg.typedef import EMPTY_DICT
@@ -268,16 +269,32 @@ def drop_table(self, identifier: Union[str, Identifier]) -> None:
identifier_tuple = self.identifier_to_tuple_without_catalog(identifier)
database_name, table_name = self.identifier_to_database_and_table(identifier_tuple, NoSuchTableError)
with Session(self.engine) as session:
res = session.execute(
delete(IcebergTables).where(
IcebergTables.catalog_name == self.name,
IcebergTables.table_namespace == database_name,
IcebergTables.table_name == table_name,
if self.engine.dialect.supports_sane_rowcount:
res = session.execute(
delete(IcebergTables).where(
IcebergTables.catalog_name == self.name,
IcebergTables.table_namespace == database_name,
IcebergTables.table_name == table_name,
)
)
)
if res.rowcount < 1:
raise NoSuchTableError(f"Table does not exist: {database_name}.{table_name}")
else:
try:
tbl = (
session.query(IcebergTables)
.with_for_update(of=IcebergTables)
.filter(
IcebergTables.catalog_name == self.name,
IcebergTables.table_namespace == database_name,
IcebergTables.table_name == table_name,
)
.one()
)
session.delete(tbl)
except NoResultFound as e:
raise NoSuchTableError(f"Table does not exist: {database_name}.{table_name}") from e
session.commit()
if res.rowcount < 1:
raise NoSuchTableError(f"Table does not exist: {database_name}.{table_name}")

def rename_table(self, from_identifier: Union[str, Identifier], to_identifier: Union[str, Identifier]) -> Table:
"""Rename a fully classified table name.
@@ -301,18 +318,35 @@ def rename_table(self, from_identifier: Union[str, Identifier], to_identifier: U
raise NoSuchNamespaceError(f"Namespace does not exist: {to_database_name}")
with Session(self.engine) as session:
try:
stmt = (
update(IcebergTables)
.where(
IcebergTables.catalog_name == self.name,
IcebergTables.table_namespace == from_database_name,
IcebergTables.table_name == from_table_name,
if self.engine.dialect.supports_sane_rowcount:
stmt = (
update(IcebergTables)
.where(
IcebergTables.catalog_name == self.name,
IcebergTables.table_namespace == from_database_name,
IcebergTables.table_name == from_table_name,
)
.values(table_namespace=to_database_name, table_name=to_table_name)
)
.values(table_namespace=to_database_name, table_name=to_table_name)
)
result = session.execute(stmt)
if result.rowcount < 1:
raise NoSuchTableError(f"Table does not exist: {from_table_name}")
result = session.execute(stmt)
if result.rowcount < 1:
raise NoSuchTableError(f"Table does not exist: {from_table_name}")
else:
try:
tbl = (
session.query(IcebergTables)
.with_for_update(of=IcebergTables)
.filter(
IcebergTables.catalog_name == self.name,
IcebergTables.table_namespace == from_database_name,
IcebergTables.table_name == from_table_name,
)
.one()
)
tbl.table_namespace = to_database_name
tbl.table_name = to_table_name
except NoResultFound as e:
raise NoSuchTableError(f"Table does not exist: {from_table_name}") from e
session.commit()
except IntegrityError as e:
raise TableAlreadyExistsError(f"Table {to_database_name}.{to_table_name} already exists") from e
@@ -329,8 +363,62 @@ def _commit_table(self, table_request: CommitTableRequest) -> CommitTableRespons
Raises:
NoSuchTableError: If a table with the given identifier does not exist.
CommitFailedException: If the commit failed.
"""
raise NotImplementedError
identifier_tuple = self.identifier_to_tuple_without_catalog(
tuple(table_request.identifier.namespace.root + [table_request.identifier.name])
)
current_table = self.load_table(identifier_tuple)
database_name, table_name = self.identifier_to_database_and_table(identifier_tuple, NoSuchTableError)
base_metadata = current_table.metadata
for requirement in table_request.requirements:
requirement.validate(base_metadata)

updated_metadata = update_table_metadata(base_metadata, table_request.updates)
if updated_metadata == base_metadata:
# no changes, do nothing
return CommitTableResponse(metadata=base_metadata, metadata_location=current_table.metadata_location)

# write new metadata
new_metadata_version = self._parse_metadata_version(current_table.metadata_location) + 1
new_metadata_location = self._get_metadata_location(current_table.metadata.location, new_metadata_version)
self._write_metadata(updated_metadata, current_table.io, new_metadata_location)

with Session(self.engine) as session:
if self.engine.dialect.supports_sane_rowcount:
stmt = (
update(IcebergTables)
.where(
IcebergTables.catalog_name == self.name,
IcebergTables.table_namespace == database_name,
IcebergTables.table_name == table_name,
IcebergTables.metadata_location == current_table.metadata_location,
)
.values(metadata_location=new_metadata_location, previous_metadata_location=current_table.metadata_location)
)
result = session.execute(stmt)
if result.rowcount < 1:
raise CommitFailedException(f"Table has been updated by another process: {database_name}.{table_name}")
else:
try:
tbl = (
session.query(IcebergTables)
.with_for_update(of=IcebergTables)
.filter(
IcebergTables.catalog_name == self.name,
IcebergTables.table_namespace == database_name,
IcebergTables.table_name == table_name,
IcebergTables.metadata_location == current_table.metadata_location,
)
.one()
)
tbl.metadata_location = new_metadata_location
tbl.previous_metadata_location = current_table.metadata_location
except NoResultFound as e:
raise CommitFailedException(f"Table has been updated by another process: {database_name}.{table_name}") from e
session.commit()

return CommitTableResponse(metadata=updated_metadata, metadata_location=new_metadata_location)

def _namespace_exists(self, identifier: Union[str, Identifier]) -> bool:
namespace = self.identifier_to_database(identifier)
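Note on the metadata write in _commit_table above: the new file gets version _parse_metadata_version(current_table.metadata_location) + 1, and _get_metadata_location builds its path. Those helpers are existing pyiceberg internals not shown in this diff; the sketch below only illustrates the naming convention they imply, assuming the common Iceberg layout of <table location>/metadata/<version>-<uuid>.metadata.json, and its function names are hypothetical.

# Hypothetical illustration of the assumed metadata file naming; the real
# helpers live elsewhere in pyiceberg and may differ in detail.
import re
import uuid


def parse_metadata_version(metadata_location: str) -> int:
    # e.g. ".../metadata/00000-1a2b3c.metadata.json" -> 0
    match = re.search(r"(\d+)-[0-9a-f-]+\.metadata\.json$", metadata_location)
    return int(match.group(1)) if match else -1


def next_metadata_location(table_location: str, current_metadata_location: str) -> str:
    new_version = parse_metadata_version(current_metadata_location) + 1
    return f"{table_location}/metadata/{new_version:05d}-{uuid.uuid4()}.metadata.json"
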
55 changes: 55 additions & 0 deletions tests/catalog/test_sql.py
@@ -42,6 +42,7 @@
SortOrder,
)
from pyiceberg.transforms import IdentityTransform
from pyiceberg.types import IntegerType


@pytest.fixture(name="warehouse", scope="session")
@@ -87,6 +88,19 @@ def catalog_sqlite(warehouse: Path) -> Generator[SqlCatalog, None, None]:
catalog.destroy_tables()


@pytest.fixture(scope="module")
def catalog_sqlite_without_rowcount(warehouse: Path) -> Generator[SqlCatalog, None, None]:
props = {
"uri": "sqlite:////tmp/sql-catalog.db",
"warehouse": f"file://{warehouse}",
}
catalog = SqlCatalog("test_sql_catalog", **props)
catalog.engine.dialect.supports_sane_rowcount = False
catalog.create_tables()
yield catalog
catalog.destroy_tables()


def test_creation_with_no_uri() -> None:
with pytest.raises(NoSuchPropertyException):
SqlCatalog("test_ddb_catalog", not_uri="unused")
@@ -305,6 +319,7 @@ def test_load_table_from_self_identifier(catalog: SqlCatalog, table_schema_neste
[
lazy_fixture('catalog_memory'),
lazy_fixture('catalog_sqlite'),
lazy_fixture('catalog_sqlite_without_rowcount'),
],
)
def test_drop_table(catalog: SqlCatalog, table_schema_nested: Schema, random_identifier: Identifier) -> None:
@@ -322,6 +337,7 @@ def test_drop_table(catalog: SqlCatalog, table_schema_nested: Schema, random_ide
[
lazy_fixture('catalog_memory'),
lazy_fixture('catalog_sqlite'),
lazy_fixture('catalog_sqlite_without_rowcount'),
],
)
def test_drop_table_from_self_identifier(catalog: SqlCatalog, table_schema_nested: Schema, random_identifier: Identifier) -> None:
@@ -341,6 +357,7 @@ def test_drop_table_from_self_identifier(catalog: SqlCatalog, table_schema_neste
[
lazy_fixture('catalog_memory'),
lazy_fixture('catalog_sqlite'),
lazy_fixture('catalog_sqlite_without_rowcount'),
],
)
def test_drop_table_that_does_not_exist(catalog: SqlCatalog, random_identifier: Identifier) -> None:
@@ -353,6 +370,7 @@ def test_drop_table_that_does_not_exist(catalog: SqlCatalog, random_identifier:
[
lazy_fixture('catalog_memory'),
lazy_fixture('catalog_sqlite'),
lazy_fixture('catalog_sqlite_without_rowcount'),
],
)
def test_rename_table(
@@ -377,6 +395,7 @@ def test_rename_table(
[
lazy_fixture('catalog_memory'),
lazy_fixture('catalog_sqlite'),
lazy_fixture('catalog_sqlite_without_rowcount'),
],
)
def test_rename_table_from_self_identifier(
@@ -403,6 +422,7 @@ def test_rename_table_from_self_identifier(
[
lazy_fixture('catalog_memory'),
lazy_fixture('catalog_sqlite'),
lazy_fixture('catalog_sqlite_without_rowcount'),
],
)
def test_rename_table_to_existing_one(
@@ -425,6 +445,7 @@ def test_rename_table_to_existing_one(
[
lazy_fixture('catalog_memory'),
lazy_fixture('catalog_sqlite'),
lazy_fixture('catalog_sqlite_without_rowcount'),
],
)
def test_rename_missing_table(catalog: SqlCatalog, random_identifier: Identifier, another_random_identifier: Identifier) -> None:
@@ -439,6 +460,7 @@ def test_rename_missing_table(catalog: SqlCatalog, random_identifier: Identifier
[
lazy_fixture('catalog_memory'),
lazy_fixture('catalog_sqlite'),
lazy_fixture('catalog_sqlite_without_rowcount'),
],
)
def test_rename_table_to_missing_namespace(
@@ -664,3 +686,36 @@ def test_update_namespace_properties(catalog: SqlCatalog, database_name: str) ->
else:
assert k in update_report.removed
assert "updated test description" == catalog.load_namespace_properties(database_name)["comment"]


@pytest.mark.parametrize(
'catalog',
[
lazy_fixture('catalog_memory'),
lazy_fixture('catalog_sqlite'),
lazy_fixture('catalog_sqlite_without_rowcount'),
],
)
def test_commit_table(catalog: SqlCatalog, table_schema_nested: Schema, random_identifier: Identifier) -> None:
database_name, _table_name = random_identifier
catalog.create_namespace(database_name)
table = catalog.create_table(random_identifier, table_schema_nested)

assert catalog._parse_metadata_version(table.metadata_location) == 0
assert table.metadata.current_schema_id == 0

transaction = table.transaction()
update = transaction.update_schema()
update.add_column(path="b", field_type=IntegerType())
update.commit()
transaction.commit_transaction()

updated_table_metadata = table.metadata

assert catalog._parse_metadata_version(table.metadata_location) == 1
assert updated_table_metadata.current_schema_id == 1
assert len(updated_table_metadata.schemas) == 2
new_schema = next(schema for schema in updated_table_metadata.schemas if schema.schema_id == 1)
assert new_schema
assert new_schema == update._apply()
assert new_schema.find_field("b").field_type == IntegerType()
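
The new test above exercises a single writer. When two writers race, the losing _commit_table call raises CommitFailedException instead of silently overwriting metadata; a hypothetical caller-side pattern (not part of this commit) is to reload the table and re-apply the change, reusing the same transaction API as the test:

# Hypothetical retry loop for a losing writer; not part of this commit.
from pyiceberg.exceptions import CommitFailedException
from pyiceberg.types import IntegerType


def add_column_with_retry(catalog, identifier, attempts: int = 3) -> None:
    for attempt in range(attempts):
        table = catalog.load_table(identifier)  # re-read the latest metadata
        transaction = table.transaction()
        update = transaction.update_schema()
        update.add_column(path="b", field_type=IntegerType())
        update.commit()
        try:
            transaction.commit_transaction()
            return
        except CommitFailedException:
            if attempt == attempts - 1:
                raise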
