Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add UnionByName functionality #296

Merged
merged 5 commits into from
Jan 26, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions mkdocs/docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,47 @@ with table.transaction() as transaction:
# ... Update properties etc
```

### Union by Name

Using `.union_by_name()` you can merge another schema into an existing schema without having to worry about field-IDs:

```python
from pyiceberg.catalog import load_catalog
from pyiceberg.schema import Schema
from pyiceberg.types import NestedField, StringType, DoubleType, LongType

catalog = load_catalog()

# The table is created with three optional columns.
schema = Schema(
    NestedField(1, "city", StringType(), required=False),
    NestedField(2, "lat", DoubleType(), required=False),
    NestedField(3, "long", DoubleType(), required=False),
)

table = catalog.create_table("default.locations", schema)

# The schema to merge in: the same three columns plus `population`.
# Its field-ID (10) is arbitrary here.
new_schema = Schema(
    NestedField(1, "city", StringType(), required=False),
    NestedField(2, "lat", DoubleType(), required=False),
    NestedField(3, "long", DoubleType(), required=False),
    NestedField(10, "population", LongType(), required=False),
)

with table.update_schema() as update:
    update.union_by_name(new_schema)
```

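Now the table has the union of the two schemas. Printing it with `print(table.schema())` shows that the new `population` column was assigned field-ID 4 rather than the 10 used in `new_schema`: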

```
table {
1: city: optional string
2: lat: optional double
3: long: optional double
4: population: optional long
}
```

### Add column

Using `add_column` you can add a column, without having to worry about the field-id:
Expand Down
188 changes: 183 additions & 5 deletions pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,14 @@
)
from pyiceberg.partitioning import PartitionSpec
from pyiceberg.schema import (
PartnerAccessor,
Schema,
SchemaVisitor,
SchemaWithPartnerVisitor,
assign_fresh_schema_ids,
promote,
visit,
visit_with_partner,
)
from pyiceberg.table.metadata import (
INITIAL_SEQUENCE_NUMBER,
Expand Down Expand Up @@ -1379,7 +1382,7 @@ class Move:


class UpdateSchema:
_table: Table
_table: Optional[Table]
_schema: Schema
_last_column_id: itertools.count[int]
_identifier_field_names: Set[str]
Expand All @@ -1398,14 +1401,23 @@ class UpdateSchema:

def __init__(
self,
table: Table,
table: Optional[Table],
HonahX marked this conversation as resolved.
Show resolved Hide resolved
transaction: Optional[Transaction] = None,
allow_incompatible_changes: bool = False,
case_sensitive: bool = True,
schema: Optional[Schema] = None,
) -> None:
self._table = table
self._schema = table.schema()
self._last_column_id = itertools.count(table.metadata.last_column_id + 1)

if isinstance(schema, Schema):
self._schema = schema
self._last_column_id = itertools.count(1 + schema.highest_field_id)
elif table is not None:
self._schema = table.schema()
self._last_column_id = itertools.count(1 + table.metadata.last_column_id)
else:
raise ValueError("Either provide a table or a schema")

self._identifier_field_names = self._schema.identifier_field_names()

self._adds = {}
Expand Down Expand Up @@ -1449,6 +1461,15 @@ def case_sensitive(self, case_sensitive: bool) -> UpdateSchema:
self._case_sensitive = case_sensitive
return self

def union_by_name(self, new_schema: Schema) -> UpdateSchema:
    """Stage the changes needed to merge ``new_schema`` into the current schema.

    ``new_schema`` is walked with a ``UnionByNameVisitor`` that records changes
    on this ``UpdateSchema``, while a ``PartnerIdByNameAccessor`` resolves each
    visited field to its counterpart in the current schema by name.

    Args:
        new_schema: The schema whose fields should be merged in.

    Returns:
        This ``UpdateSchema`` instance, so calls can be chained.
    """
    current_schema = self._schema
    match_case = self._case_sensitive

    visitor = UnionByNameVisitor(update_schema=self, existing_schema=current_schema, case_sensitive=match_case)
    accessor = PartnerIdByNameAccessor(partner_schema=current_schema, case_sensitive=match_case)

    # -1 is the partner id that the visitor and accessor treat as the root
    # struct of the current schema.
    visit_with_partner(new_schema, -1, visitor, accessor)  # type: ignore
    return self

def add_column(
self, path: Union[str, Tuple[str, ...]], field_type: IcebergType, doc: Optional[str] = None, required: bool = False
) -> UpdateSchema:
Expand Down Expand Up @@ -1816,6 +1837,9 @@ def move_after(self, path: Union[str, Tuple[str, ...]], after_name: Union[str, T

def commit(self) -> None:
"""Apply the pending changes and commit."""
if self._table is None:
raise ValueError("Requires a table to commit to")

new_schema = self._apply()

existing_schema_id = next((schema.schema_id for schema in self._table.metadata.schemas if schema == new_schema), None)
Expand Down Expand Up @@ -1862,7 +1886,8 @@ def _apply(self) -> Schema:

field_ids.add(field.field_id)

return Schema(*struct.fields, schema_id=1 + max(self._table.schemas().keys()), identifier_field_ids=field_ids)
next_schema_id = 1 + (max(self._table.schemas().keys()) if self._table is not None else self._schema.schema_id)
return Schema(*struct.fields, schema_id=next_schema_id, identifier_field_ids=field_ids)

def assign_new_column_id(self) -> int:
    """Return the next value from the ``_last_column_id`` counter, advancing it."""
    fresh_id = next(self._last_column_id)
    return fresh_id
Expand Down Expand Up @@ -1995,6 +2020,159 @@ def primitive(self, primitive: PrimitiveType) -> Optional[IcebergType]:
return primitive


class UnionByNameVisitor(SchemaWithPartnerVisitor[int, bool]):
    """Record the changes needed to merge a visited schema into an existing one.

    The partner of every visited node is the field-ID of the same-named node
    in ``existing_schema``, ``None`` when there is no such node, or ``-1`` for
    the root struct. Each visit method returns ``True`` when the node has no
    partner (it is "missing") and ``False`` otherwise. Required changes are
    staged on ``update_schema``.
    """

    update_schema: UpdateSchema
    existing_schema: Schema
    case_sensitive: bool

    def __init__(self, update_schema: UpdateSchema, existing_schema: Schema, case_sensitive: bool) -> None:
        """Store the update builder, the schema to merge into and the name-matching mode."""
        self.update_schema = update_schema
        self.existing_schema = existing_schema
        self.case_sensitive = case_sensitive

    def schema(self, schema: Schema, partner_id: Optional[int], struct_result: bool) -> bool:
        """Pass the result of the root struct through unchanged."""
        return struct_result

    def struct(self, struct: StructType, partner_id: Optional[int], missing_positions: List[bool]) -> bool:
        """Add the missing child fields and update the ones that already exist.

        Args:
            struct: The struct being visited.
            partner_id: Field-ID of the matching struct in the existing schema,
                ``-1`` for the root, or ``None`` when there is no match.
            missing_positions: One flag per field of ``struct``, in order; a true
                value means that field has no partner (presumably the per-field
                results of ``field`` - confirm against ``visit_with_partner``).

        Returns:
            ``True`` when the struct itself has no partner, ``False`` otherwise.

        Raises:
            ValueError: If the partner exists but is not a struct.
        """
        if partner_id is None:
            return True

        fields = struct.fields
        partner_struct = self._find_field_type(partner_id)

        if not partner_struct.is_struct:
            raise ValueError(f"Expected a struct, got: {partner_struct}")

        for pos, missing in enumerate(missing_positions):
            field = fields[pos]
            if missing:
                self._add_column(partner_id, field)
            elif nested_field := partner_struct.field_by_name(field.name, case_sensitive=self.case_sensitive):
                self._update_column(field, nested_field)

        return False

    def _add_column(self, parent_id: int, field: NestedField) -> None:
        """Stage ``field`` as a new column below the existing field ``parent_id``.

        When ``parent_id`` has no column name in the existing schema, the field
        is added using only its own name as the path.
        """
        if parent_name := self.existing_schema.find_column_name(parent_id):
            path: Tuple[str, ...] = (parent_name, field.name)
        else:
            path = (field.name,)

        self.update_schema.add_column(path=path, field_type=field.field_type, required=field.required, doc=field.doc)

    def _update_column(self, field: NestedField, existing_field: NestedField) -> None:
        """Stage the nullability, type and doc changes that turn ``existing_field`` into ``field``.

        Raises:
            ValueError: If ``existing_field`` has no column name in the existing schema.
        """
        full_name = self.existing_schema.find_column_name(existing_field.field_id)

        if full_name is None:
            raise ValueError(f"Could not find field: {existing_field}")

        if field.optional and existing_field.required:
            self.update_schema.make_column_optional(full_name)

        if field.field_type.is_primitive and field.field_type != existing_field.field_type:
            self.update_schema.update_column(full_name, field_type=field.field_type)

        # Only stage a doc update when the incoming doc actually differs.
        if field.doc is not None and field.doc != existing_field.doc:
            self.update_schema.update_column(full_name, doc=field.doc)

    def _find_field_type(self, field_id: int) -> IcebergType:
        """Return the type of ``field_id`` in the existing schema; ``-1`` means the root struct."""
        if field_id == -1:
            return self.existing_schema.as_struct()
        else:
            return self.existing_schema.find_field(field_id).field_type

    def field(self, field: NestedField, partner_id: Optional[int], field_result: bool) -> bool:
        """Report whether ``field`` has no partner in the existing schema."""
        return partner_id is None

    def list(self, list_type: ListType, list_partner_id: Optional[int], element_missing: bool) -> bool:
        """Update the element of a list that exists in both schemas.

        Returns:
            ``True`` when the list has no partner, ``False`` otherwise.

        Raises:
            ValueError: If the list has a partner but its element does not, or
                if the partner is not a list.
        """
        if list_partner_id is None:
            return True

        if element_missing:
            raise ValueError("Error traversing schemas: element is missing, but list is present")

        partner_list_type = self._find_field_type(list_partner_id)
        if not isinstance(partner_list_type, ListType):
            raise ValueError(f"Expected list-type, got: {partner_list_type}")

        self._update_column(list_type.element_field, partner_list_type.element_field)

        return False

    def map(self, map_type: MapType, map_partner_id: Optional[int], key_missing: bool, value_missing: bool) -> bool:
        """Update the key and value of a map that exists in both schemas.

        Returns:
            ``True`` when the map has no partner, ``False`` otherwise.

        Raises:
            ValueError: If the map has a partner but its key or value does not,
                or if the partner is not a map.
        """
        if map_partner_id is None:
            return True

        if key_missing:
            raise ValueError("Error traversing schemas: key is missing, but map is present")

        if value_missing:
            raise ValueError("Error traversing schemas: value is missing, but map is present")

        partner_map_type = self._find_field_type(map_partner_id)
        if not isinstance(partner_map_type, MapType):
            raise ValueError(f"Expected map-type, got: {partner_map_type}")

        self._update_column(map_type.key_field, partner_map_type.key_field)
        self._update_column(map_type.value_field, partner_map_type.value_field)

        return False

    def primitive(self, primitive: PrimitiveType, primitive_partner: Optional[int]) -> bool:
        """Report whether the primitive has no partner in the existing schema."""
        return primitive_partner is None


class PartnerIdByNameAccessor(PartnerAccessor[int]):
    """Resolve partners as field-IDs in ``partner_schema``, matching struct fields by name.

    ``-1`` is used as the id of the schema's root struct; ``None`` means that
    there is no partner.
    """

    partner_schema: Schema
    case_sensitive: bool

    def __init__(self, partner_schema: Schema, case_sensitive: bool) -> None:
        """Store the schema to resolve against and the name-matching mode."""
        self.partner_schema = partner_schema
        self.case_sensitive = case_sensitive

    def schema_partner(self, partner: Optional[int]) -> Optional[int]:
        """Return ``-1``, the id that stands for the root struct."""
        return -1

    def field_partner(self, partner_field_id: Optional[int], field_id: int, field_name: str) -> Optional[int]:
        """Return the id of the child called ``field_name`` below ``partner_field_id``.

        Returns ``None`` when the parent has no partner or no child with that
        name. Raises ``ValueError`` when the parent (other than the root) is
        not a struct.
        """
        if partner_field_id is None:
            return None

        if partner_field_id == -1:
            parent_type = self.partner_schema.as_struct()
        else:
            parent_type = self.partner_schema.find_field(partner_field_id).field_type
            if not parent_type.is_struct:
                raise ValueError(f"Expected StructType: {parent_type}")

        found = parent_type.field_by_name(name=field_name, case_sensitive=self.case_sensitive)
        return found.field_id if found else None

    def list_element_partner(self, partner_list_id: Optional[int]) -> Optional[int]:
        """Return the element field-ID of the list ``partner_list_id``, or ``None`` without a partner.

        Raises ``ValueError`` when the partner field is not a list.
        """
        if partner_list_id is None:
            return None

        list_field = self.partner_schema.find_field(partner_list_id)
        if not list_field:
            return None

        if not isinstance(list_field.field_type, ListType):
            raise ValueError(f"Expected ListType: {list_field}")
        return list_field.field_type.element_field.field_id

    def map_key_partner(self, partner_map_id: Optional[int]) -> Optional[int]:
        """Return the key field-ID of the map ``partner_map_id``, or ``None`` without a partner.

        Raises ``ValueError`` when the partner field is not a map.
        """
        if partner_map_id is None:
            return None

        map_field = self.partner_schema.find_field(partner_map_id)
        if not map_field:
            return None

        if not isinstance(map_field.field_type, MapType):
            raise ValueError(f"Expected MapType: {map_field}")
        return map_field.field_type.key_field.field_id

    def map_value_partner(self, partner_map_id: Optional[int]) -> Optional[int]:
        """Return the value field-ID of the map ``partner_map_id``, or ``None`` without a partner.

        Raises ``ValueError`` when the partner field is not a map.
        """
        if partner_map_id is None:
            return None

        map_field = self.partner_schema.find_field(partner_map_id)
        if not map_field:
            return None

        if not isinstance(map_field.field_type, MapType):
            raise ValueError(f"Expected MapType: {map_field}")
        return map_field.field_type.value_field.field_id


def _add_fields(fields: Tuple[NestedField, ...], adds: Optional[List[NestedField]]) -> Tuple[NestedField, ...]:
    """Return ``fields`` followed by ``adds``; a missing or empty ``adds`` appends nothing."""
    extra = tuple(adds) if adds else ()
    return fields + extra
Expand Down
12 changes: 12 additions & 0 deletions pyiceberg/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,18 @@ def field(self, field_id: int) -> Optional[NestedField]:
return field
return None

def field_by_name(self, name: str, case_sensitive: bool = True) -> Optional[NestedField]:
    """Return the direct child field called ``name``, or ``None`` when there is none.

    Args:
        name: Field name to look up among this struct's own fields.
        case_sensitive: When True (the default) the name must match exactly;
            when False, both sides are lower-cased before comparing.

    Returns:
        The first matching field in declaration order, or ``None``.
    """
    if case_sensitive:
        for field in self.fields:
            if field.name == name:
                return field
    else:
        # Lower-case the requested name once rather than on every comparison.
        name_lower = name.lower()
        for field in self.fields:
            if field.name.lower() == name_lower:
                return field
    return None

def __str__(self) -> str:
    """Return the struct rendered as ``struct<...>`` with its fields comma-separated."""
    rendered_fields = ", ".join(str(field) for field in self.fields)
    return f"struct<{rendered_fields}>"
Expand Down
Loading