feat: initial setup for S3TablesCatalog
felixscherz committed Jan 6, 2025
1 parent 551f524 commit e41c428
Showing 5 changed files with 307 additions and 2 deletions.
4 changes: 2 additions & 2 deletions pyiceberg/catalog/__init__.py
@@ -903,8 +903,8 @@ def _get_default_warehouse_location(self, database_name: str, table_name: str) -
raise ValueError("No default path is set, please specify a location when creating a table")

@staticmethod
def _write_metadata(metadata: TableMetadata, io: FileIO, metadata_path: str) -> None:
ToOutputFile.table_metadata(metadata, io.new_output(metadata_path))
def _write_metadata(metadata: TableMetadata, io: FileIO, metadata_path: str, overwrite: bool = False) -> None:
ToOutputFile.table_metadata(metadata, io.new_output(metadata_path), overwrite=overwrite)

@staticmethod
def _get_metadata_location(location: str, new_version: int = 0) -> str:
228 changes: 228 additions & 0 deletions pyiceberg/catalog/s3tables.py
@@ -0,0 +1,228 @@
import re
from typing import TYPE_CHECKING
from typing import List
from typing import Optional
from typing import Set
from typing import Tuple
from typing import Union

import boto3

from pyiceberg.catalog import DEPRECATED_BOTOCORE_SESSION, MetastoreCatalog
from pyiceberg.catalog import WAREHOUSE_LOCATION
from pyiceberg.catalog import Catalog
from pyiceberg.catalog import PropertiesUpdateSummary
from pyiceberg.exceptions import S3TablesError
from pyiceberg.exceptions import TableBucketNotFound
from pyiceberg.io import AWS_ACCESS_KEY_ID
from pyiceberg.io import AWS_REGION
from pyiceberg.io import AWS_SECRET_ACCESS_KEY
from pyiceberg.io import AWS_SESSION_TOKEN
from pyiceberg.io import load_file_io
from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionSpec
from pyiceberg.schema import Schema
from pyiceberg.serializers import FromInputFile
from pyiceberg.table import CommitTableResponse
from pyiceberg.table import CreateTableTransaction
from pyiceberg.table import Table
from pyiceberg.table import TableRequirement
from pyiceberg.table.metadata import new_table_metadata
from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder
from pyiceberg.table.update import TableUpdate
from pyiceberg.typedef import EMPTY_DICT, Identifier
from pyiceberg.typedef import Properties
from pyiceberg.utils.properties import get_first_property_value

if TYPE_CHECKING:
import pyarrow as pa

S3TABLES_PROFILE_NAME = "s3tables.profile-name"
S3TABLES_REGION = "s3tables.region"
S3TABLES_ACCESS_KEY_ID = "s3tables.access-key-id"
S3TABLES_SECRET_ACCESS_KEY = "s3tables.secret-access-key"
S3TABLES_SESSION_TOKEN = "s3tables.session-token"

S3TABLES_ENDPOINT = "s3tables.endpoint"


class S3TableCatalog(MetastoreCatalog):
def __init__(self, name: str, **properties: str):
super().__init__(name, **properties)

session = boto3.Session(
profile_name=properties.get(S3TABLES_PROFILE_NAME),
region_name=get_first_property_value(properties, S3TABLES_REGION, AWS_REGION),
botocore_session=properties.get(DEPRECATED_BOTOCORE_SESSION),
aws_access_key_id=get_first_property_value(properties, S3TABLES_ACCESS_KEY_ID, AWS_ACCESS_KEY_ID),
aws_secret_access_key=get_first_property_value(
properties, S3TABLES_SECRET_ACCESS_KEY, AWS_SECRET_ACCESS_KEY
),
aws_session_token=get_first_property_value(properties, S3TABLES_SESSION_TOKEN, AWS_SESSION_TOKEN),
)
# TODO: s3tables client only supported from boto3>=1.35.74 so this can crash
# TODO: set a custom user-agent for api calls like the Java implementation
self.s3tables = session.client("s3tables")
# TODO: handle malformed properties instead of just raising a key error here
self.table_bucket_arn = self.properties[WAREHOUSE_LOCATION]
try:
self.s3tables.get_table_bucket(tableBucketARN=self.table_bucket_arn)
except self.s3tables.exceptions.NotFoundException as e:
raise TableBucketNotFound(e) from e

def commit_table(
self, table: Table, requirements: Tuple[TableRequirement, ...], updates: Tuple[TableUpdate, ...]
) -> CommitTableResponse:
return super().commit_table(table, requirements, updates)

def create_namespace(self, namespace: Union[str, Identifier], properties: Properties = ...) -> None:
valid_namespace: str = self._validate_namespace_identifier(namespace)
self.s3tables.create_namespace(tableBucketARN=self.table_bucket_arn, namespace=[valid_namespace])

def _validate_namespace_identifier(self, namespace: Union[str, Identifier]) -> str:
# for naming rules see: https://docs.aws.amazon.com/AmazonS3/latest/userguide/s3-tables-buckets-naming.html
# TODO: extract into constant variables
pattern = re.compile("[a-z0-9][a-z0-9_]{2,62}")
reserved = "aws_s3_metadata"

namespace = self.identifier_to_database(namespace)

if not pattern.fullmatch(namespace):
...

if namespace == reserved:
...

return namespace

def _validate_database_and_table_identifier(self, identifier: Union[str, Identifier]) -> Tuple[str, str]:
# for naming rules see: https://docs.aws.amazon.com/AmazonS3/latest/userguide/s3-tables-buckets-naming.html
# TODO: extract into constant variables
pattern = re.compile("[a-z0-9][a-z0-9_]{2,62}")

namespace, table_name = self.identifier_to_database_and_table(identifier)

namespace = self._validate_namespace_identifier(namespace)

if not pattern.fullmatch(table_name):
...

return namespace, table_name

def create_table(
self,
identifier: Union[str, Identifier],
schema: Union[Schema, "pa.Schema"],
location: Optional[str] = None,
partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC,
sort_order: SortOrder = UNSORTED_SORT_ORDER,
properties: Properties = EMPTY_DICT,
) -> Table:
namespace, table_name = self._validate_database_and_table_identifier(identifier)

schema: Schema = self._convert_schema_if_needed(schema) # type: ignore

# TODO: check whether namespace exists and if it does, whether table_name already exists
self.s3tables.create_table(
tableBucketARN=self.table_bucket_arn, namespace=namespace, name=table_name, format="ICEBERG"
)

# location is given by s3 table bucket
response = self.s3tables.get_table(tableBucketARN=self.table_bucket_arn, namespace=namespace, name=table_name)
version_token = response["versionToken"]

location = response["warehouseLocation"]
metadata_location = self._get_metadata_location(location=location)
metadata = new_table_metadata(
location=location,
schema=schema,
partition_spec=partition_spec,
sort_order=sort_order,
properties=properties,
)

io = load_file_io(properties=self.properties, location=metadata_location)
# TODO: this triggers unsupported list operation error, setting overwrite=True is a workaround for now
# TODO: we can perform this check manually maybe?
self._write_metadata(metadata, io, metadata_location, overwrite=True)
# TODO: after writing need to update table metadata location
# can this operation fail if the version token does not match?
self.s3tables.update_table_metadata_location(
tableBucketARN=self.table_bucket_arn,
namespace=namespace,
name=table_name,
versionToken=version_token,
metadataLocation=metadata_location,
)

return self.load_table(identifier=identifier)

def create_table_transaction(
self,
identifier: Union[str, Identifier],
schema: Union[Schema, "pa.Schema"],
location: Optional[str] = None,
partition_spec: PartitionSpec = ...,
sort_order: SortOrder = ...,
properties: Properties = ...,
) -> CreateTableTransaction:
return super().create_table_transaction(identifier, schema, location, partition_spec, sort_order, properties)

def drop_namespace(self, namespace: Union[str, Identifier]) -> None:
return super().drop_namespace(namespace)

def drop_table(self, identifier: Union[str, Identifier]) -> None:
return super().drop_table(identifier)

def drop_view(self, identifier: Union[str, Identifier]) -> None:
return super().drop_view(identifier)

def list_namespaces(self, namespace: Union[str, Identifier] = ...) -> List[Identifier]:
# TODO: handle pagination
response = self.s3tables.list_namespaces(tableBucketARN=self.table_bucket_arn)
return [tuple(namespace["namespace"]) for namespace in response["namespaces"]]

def list_tables(self, namespace: Union[str, Identifier]) -> List[Identifier]:
return super().list_tables(namespace)

def list_views(self, namespace: Union[str, Identifier]) -> List[Identifier]:
return super().list_views(namespace)

def load_namespace_properties(self, namespace: Union[str, Identifier]) -> Properties:
return super().load_namespace_properties(namespace)

def load_table(self, identifier: Union[str, Identifier]) -> Table:
namespace, table_name = self._validate_database_and_table_identifier(identifier)
# TODO: raise a NoSuchTableError if it does not exist
response = self.s3tables.get_table_metadata_location(
tableBucketARN=self.table_bucket_arn, namespace=namespace, name=table_name
)
# TODO: we might need to catch if table is not initialized i.e. does not have metadata setup yet
metadata_location = response["metadataLocation"]

io = load_file_io(properties=self.properties, location=metadata_location)
file = io.new_input(metadata_location)
metadata = FromInputFile.table_metadata(file)
return Table(
identifier=(namespace, table_name),
metadata=metadata,
metadata_location=metadata_location,
io=self._load_file_io(metadata.properties, metadata_location),
catalog=self,
)

def purge_table(self, identifier: Union[str, Identifier]) -> None:
return super().purge_table(identifier)

def register_table(self, identifier: Union[str, Identifier], metadata_location: str) -> Table:
return super().register_table(identifier, metadata_location)

def rename_table(self, from_identifier: Union[str, Identifier], to_identifier: Union[str, Identifier]) -> Table:
return super().rename_table(from_identifier, to_identifier)

def table_exists(self, identifier: Union[str, Identifier]) -> bool:
return super().table_exists(identifier)

def update_namespace_properties(
self, namespace: Union[str, Identifier], removals: Optional[Set[str]] = None, updates: Properties = ...
) -> PropertiesUpdateSummary:
return super().update_namespace_properties(namespace, removals, updates)
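As a rough usage sketch of the new catalog: direct construction with the property keys defined above, where the ARN, region, and schema below are placeholders rather than values from this change.

from pyiceberg.catalog.s3tables import S3TableCatalog
from pyiceberg.schema import Schema
from pyiceberg.types import NestedField, StringType

# placeholder table bucket ARN and region; "warehouse" must point at an existing S3 table bucket
catalog = S3TableCatalog(
    name="s3tables_catalog",
    **{
        "warehouse": "arn:aws:s3tables:us-east-1:111111111111:bucket/example-table-bucket",
        "s3tables.region": "us-east-1",
    },
)

# namespace and table names must satisfy the naming rules checked above (lowercase letters, digits, underscores)
catalog.create_namespace(namespace="example_namespace")
schema = Schema(NestedField(field_id=1, name="city", field_type=StringType(), required=False))
table = catalog.create_table(identifier=("example_namespace", "example_table"), schema=schema)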
5 changes: 5 additions & 0 deletions pyiceberg/exceptions.py
@@ -111,6 +111,11 @@ class ConditionalCheckFailedException(DynamoDbError):
class GenericDynamoDbError(DynamoDbError):
pass

class S3TablesError(Exception):
pass

class TableBucketNotFound(S3TablesError):
pass

class CommitFailedException(Exception):
"""Commit failed, refresh and try again."""
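A small sketch of how the new exception surfaces a bad warehouse ARN at catalog construction time, mirroring the test added below (the ARN is a placeholder):

from pyiceberg.catalog.s3tables import S3TableCatalog
from pyiceberg.exceptions import TableBucketNotFound

try:
    # placeholder ARN that does not resolve to an existing table bucket
    S3TableCatalog(name="catalog", **{"warehouse": "arn:aws:s3tables:us-east-1:111111111111:bucket/does-not-exist"})
except TableBucketNotFound:
    # raised when get_table_bucket reports NotFoundException for the configured ARN
    pass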
1 change: 1 addition & 0 deletions pyproject.toml
@@ -1206,6 +1206,7 @@ sql-postgres = ["sqlalchemy", "psycopg2-binary"]
sql-sqlite = ["sqlalchemy"]
gcsfs = ["gcsfs"]
rest-sigv4 = ["boto3"]
s3tables = ["boto3"]

[tool.pytest.ini_options]
markers = [
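Note that the new s3tables extra only declares boto3, while the TODO in pyiceberg/catalog/s3tables.py points out that the s3tables client is only available from boto3 >= 1.35.74, so installs of pyiceberg[s3tables] need at least that boto3 version.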
71 changes: 71 additions & 0 deletions tests/catalog/test_s3tables.py
@@ -0,0 +1,71 @@
import boto3
from pyiceberg.schema import Schema
import pytest

from pyiceberg.catalog.s3tables import S3TableCatalog
from pyiceberg.exceptions import TableBucketNotFound


@pytest.fixture
def database_name(database_name):
# naming rules prevent "-" in namespaces for s3 table buckets
return database_name.replace("-", "_")


@pytest.fixture
def table_name(table_name):
# naming rules prevent "-" in table names for s3 table buckets
return table_name.replace("-", "_")

@pytest.fixture
def table_bucket_arn():
import os
# since the moto library does not support s3tables as of 2024-12-14 we have to test against a real AWS endpoint
# in one of the supported regions.

return os.environ["ARN"]


def test_s3tables_boto_api(table_bucket_arn):
client = boto3.client("s3tables")
response = client.list_namespaces(tableBucketARN=table_bucket_arn)
print(response["namespaces"])

response = client.get_table_bucket(tableBucketARN=table_bucket_arn + "abc")
print(response)



def test_s3tables_namespaces_api(table_bucket_arn):
client = boto3.client("s3tables")
response = client.create_namespace(tableBucketARN=table_bucket_arn, namespace=["one", "two"])
print(response)
response = client.list_namespaces(tableBucketARN=table_bucket_arn)
print(response)

def test_creating_catalog_validates_s3_table_bucket_exists(table_bucket_arn):
properties = {"warehouse": f"{table_bucket_arn}-modified"}
with pytest.raises(TableBucketNotFound):
S3TableCatalog(name="test_s3tables_catalog", **properties)


def test_create_namespace(table_bucket_arn, database_name: str):
properties = {"warehouse": table_bucket_arn}
catalog = S3TableCatalog(name="test_s3tables_catalog", **properties)
catalog.create_namespace(namespace=database_name)
namespaces = catalog.list_namespaces()
assert (database_name,) in namespaces


def test_create_table(table_bucket_arn, database_name: str, table_name: str, table_schema_nested: Schema):
properties = {"warehouse": table_bucket_arn}
catalog = S3TableCatalog(name="test_s3tables_catalog", **properties)
identifier = (database_name, table_name)

catalog.create_namespace(namespace=database_name)
print(database_name, table_name)
# this fails with
# OSError: When completing multiple part upload for key 'metadata/00000-55a9c37c-b822-4a81-ac0e-1efbcd145dba.metadata.json' in bucket '14e4e036-d4ae-44f8-koana45eruw
# Unable to parse ExceptionName: S3TablesUnsupportedHeader Message: S3 Tables does not support the following header: x-amz-api-version value: 2006-03-01
table = catalog.create_table(identifier=identifier, schema=table_schema_nested)
print(table)
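
Since moto does not cover s3tables yet (see the fixture comment above), these tests can only run against a real table bucket. A possible guard, assuming the ARN environment variable that the fixture reads:

import os

import pytest

# skip the whole module when no real S3 Tables bucket ARN is configured
pytestmark = pytest.mark.skipif(
    "ARN" not in os.environ, reason="requires a real S3 Tables table bucket ARN in the ARN environment variable"
)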
