Skip to content

Commit

Permalink
Add UnionByName functionality (#296)
Browse files Browse the repository at this point in the history
* Add UnionByName functionality

* Thanks Honah!

* Add `_id`

Co-authored-by: Honah J. <[email protected]>

* Fix

---------

Co-authored-by: Honah J. <[email protected]>
  • Loading branch information
Fokko and HonahX authored Jan 26, 2024
1 parent 2a27f2b commit cd7fb50
Show file tree
Hide file tree
Showing 4 changed files with 905 additions and 7 deletions.
41 changes: 41 additions & 0 deletions mkdocs/docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,47 @@ with table.transaction() as transaction:
# ... Update properties etc
```

### Union by Name

Using `.union_by_name()` you can merge another schema into an existing schema without having to worry about field-IDs:

```python
from pyiceberg.catalog import load_catalog
from pyiceberg.schema import Schema
from pyiceberg.types import NestedField, StringType, DoubleType, LongType
catalog = load_catalog()
schema = Schema(
NestedField(1, "city", StringType(), required=False),
NestedField(2, "lat", DoubleType(), required=False),
NestedField(3, "long", DoubleType(), required=False),
)
table = catalog.create_table("default.locations", schema)
new_schema = Schema(
NestedField(1, "city", StringType(), required=False),
NestedField(2, "lat", DoubleType(), required=False),
NestedField(3, "long", DoubleType(), required=False),
NestedField(10, "population", LongType(), required=False),
)
with table.update_schema() as update:
update.union_by_name(new_schema)
```

Now the table has the union of the two schemas `print(table.schema())`:

```
table {
1: city: optional string
2: lat: optional double
3: long: optional double
4: population: optional long
}
```

### Add column

Using `add_column` you can add a column, without having to worry about the field-id:
Expand Down
188 changes: 183 additions & 5 deletions pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,14 @@
)
from pyiceberg.partitioning import PartitionSpec
from pyiceberg.schema import (
PartnerAccessor,
Schema,
SchemaVisitor,
SchemaWithPartnerVisitor,
assign_fresh_schema_ids,
promote,
visit,
visit_with_partner,
)
from pyiceberg.table.metadata import (
INITIAL_SEQUENCE_NUMBER,
Expand Down Expand Up @@ -1379,7 +1382,7 @@ class Move:


class UpdateSchema:
_table: Table
_table: Optional[Table]
_schema: Schema
_last_column_id: itertools.count[int]
_identifier_field_names: Set[str]
Expand All @@ -1398,14 +1401,23 @@ class UpdateSchema:

def __init__(
self,
table: Table,
table: Optional[Table],
transaction: Optional[Transaction] = None,
allow_incompatible_changes: bool = False,
case_sensitive: bool = True,
schema: Optional[Schema] = None,
) -> None:
self._table = table
self._schema = table.schema()
self._last_column_id = itertools.count(table.metadata.last_column_id + 1)

if isinstance(schema, Schema):
self._schema = schema
self._last_column_id = itertools.count(1 + schema.highest_field_id)
elif table is not None:
self._schema = table.schema()
self._last_column_id = itertools.count(1 + table.metadata.last_column_id)
else:
raise ValueError("Either provide a table or a schema")

self._identifier_field_names = self._schema.identifier_field_names()

self._adds = {}
Expand Down Expand Up @@ -1449,6 +1461,15 @@ def case_sensitive(self, case_sensitive: bool) -> UpdateSchema:
self._case_sensitive = case_sensitive
return self

def union_by_name(self, new_schema: Schema) -> UpdateSchema:
visit_with_partner(
new_schema,
-1,
UnionByNameVisitor(update_schema=self, existing_schema=self._schema, case_sensitive=self._case_sensitive), # type: ignore
PartnerIdByNameAccessor(partner_schema=self._schema, case_sensitive=self._case_sensitive),
)
return self

def add_column(
self, path: Union[str, Tuple[str, ...]], field_type: IcebergType, doc: Optional[str] = None, required: bool = False
) -> UpdateSchema:
Expand Down Expand Up @@ -1816,6 +1837,9 @@ def move_after(self, path: Union[str, Tuple[str, ...]], after_name: Union[str, T

def commit(self) -> None:
"""Apply the pending changes and commit."""
if self._table is None:
raise ValueError("Requires a table to commit to")

new_schema = self._apply()

existing_schema_id = next((schema.schema_id for schema in self._table.metadata.schemas if schema == new_schema), None)
Expand Down Expand Up @@ -1862,7 +1886,8 @@ def _apply(self) -> Schema:

field_ids.add(field.field_id)

return Schema(*struct.fields, schema_id=1 + max(self._table.schemas().keys()), identifier_field_ids=field_ids)
next_schema_id = 1 + (max(self._table.schemas().keys()) if self._table is not None else self._schema.schema_id)
return Schema(*struct.fields, schema_id=next_schema_id, identifier_field_ids=field_ids)

def assign_new_column_id(self) -> int:
return next(self._last_column_id)
Expand Down Expand Up @@ -1995,6 +2020,159 @@ def primitive(self, primitive: PrimitiveType) -> Optional[IcebergType]:
return primitive


class UnionByNameVisitor(SchemaWithPartnerVisitor[int, bool]):
update_schema: UpdateSchema
existing_schema: Schema
case_sensitive: bool

def __init__(self, update_schema: UpdateSchema, existing_schema: Schema, case_sensitive: bool) -> None:
self.update_schema = update_schema
self.existing_schema = existing_schema
self.case_sensitive = case_sensitive

def schema(self, schema: Schema, partner_id: Optional[int], struct_result: bool) -> bool:
return struct_result

def struct(self, struct: StructType, partner_id: Optional[int], missing_positions: List[bool]) -> bool:
if partner_id is None:
return True

fields = struct.fields
partner_struct = self._find_field_type(partner_id)

if not partner_struct.is_struct:
raise ValueError(f"Expected a struct, got: {partner_struct}")

for pos, missing in enumerate(missing_positions):
if missing:
self._add_column(partner_id, fields[pos])
else:
field = fields[pos]
if nested_field := partner_struct.field_by_name(field.name, case_sensitive=self.case_sensitive):
self._update_column(field, nested_field)

return False

def _add_column(self, parent_id: int, field: NestedField) -> None:
if parent_name := self.existing_schema.find_column_name(parent_id):
path: Tuple[str, ...] = (parent_name, field.name)
else:
path = (field.name,)

self.update_schema.add_column(path=path, field_type=field.field_type, required=field.required, doc=field.doc)

def _update_column(self, field: NestedField, existing_field: NestedField) -> None:
full_name = self.existing_schema.find_column_name(existing_field.field_id)

if full_name is None:
raise ValueError(f"Could not find field: {existing_field}")

if field.optional and existing_field.required:
self.update_schema.make_column_optional(full_name)

if field.field_type.is_primitive and field.field_type != existing_field.field_type:
self.update_schema.update_column(full_name, field_type=field.field_type)

if field.doc is not None and not field.doc != existing_field.doc:
self.update_schema.update_column(full_name, doc=field.doc)

def _find_field_type(self, field_id: int) -> IcebergType:
if field_id == -1:
return self.existing_schema.as_struct()
else:
return self.existing_schema.find_field(field_id).field_type

def field(self, field: NestedField, partner_id: Optional[int], field_result: bool) -> bool:
return partner_id is None

def list(self, list_type: ListType, list_partner_id: Optional[int], element_missing: bool) -> bool:
if list_partner_id is None:
return True

if element_missing:
raise ValueError("Error traversing schemas: element is missing, but list is present")

partner_list_type = self._find_field_type(list_partner_id)
if not isinstance(partner_list_type, ListType):
raise ValueError(f"Expected list-type, got: {partner_list_type}")

self._update_column(list_type.element_field, partner_list_type.element_field)

return False

def map(self, map_type: MapType, map_partner_id: Optional[int], key_missing: bool, value_missing: bool) -> bool:
if map_partner_id is None:
return True

if key_missing:
raise ValueError("Error traversing schemas: key is missing, but map is present")

if value_missing:
raise ValueError("Error traversing schemas: value is missing, but map is present")

partner_map_type = self._find_field_type(map_partner_id)
if not isinstance(partner_map_type, MapType):
raise ValueError(f"Expected map-type, got: {partner_map_type}")

self._update_column(map_type.key_field, partner_map_type.key_field)
self._update_column(map_type.value_field, partner_map_type.value_field)

return False

def primitive(self, primitive: PrimitiveType, primitive_partner_id: Optional[int]) -> bool:
return primitive_partner_id is None


class PartnerIdByNameAccessor(PartnerAccessor[int]):
partner_schema: Schema
case_sensitive: bool

def __init__(self, partner_schema: Schema, case_sensitive: bool) -> None:
self.partner_schema = partner_schema
self.case_sensitive = case_sensitive

def schema_partner(self, partner: Optional[int]) -> Optional[int]:
return -1

def field_partner(self, partner_field_id: Optional[int], field_id: int, field_name: str) -> Optional[int]:
if partner_field_id is not None:
if partner_field_id == -1:
struct = self.partner_schema.as_struct()
else:
struct = self.partner_schema.find_field(partner_field_id).field_type
if not struct.is_struct:
raise ValueError(f"Expected StructType: {struct}")

if field := struct.field_by_name(name=field_name, case_sensitive=self.case_sensitive):
return field.field_id

return None

def list_element_partner(self, partner_list_id: Optional[int]) -> Optional[int]:
if partner_list_id is not None and (field := self.partner_schema.find_field(partner_list_id)):
if not isinstance(field.field_type, ListType):
raise ValueError(f"Expected ListType: {field}")
return field.field_type.element_field.field_id
else:
return None

def map_key_partner(self, partner_map_id: Optional[int]) -> Optional[int]:
if partner_map_id is not None and (field := self.partner_schema.find_field(partner_map_id)):
if not isinstance(field.field_type, MapType):
raise ValueError(f"Expected MapType: {field}")
return field.field_type.key_field.field_id
else:
return None

def map_value_partner(self, partner_map_id: Optional[int]) -> Optional[int]:
if partner_map_id is not None and (field := self.partner_schema.find_field(partner_map_id)):
if not isinstance(field.field_type, MapType):
raise ValueError(f"Expected MapType: {field}")
return field.field_type.value_field.field_id
else:
return None


def _add_fields(fields: Tuple[NestedField, ...], adds: Optional[List[NestedField]]) -> Tuple[NestedField, ...]:
adds = adds or []
return fields + tuple(adds)
Expand Down
12 changes: 12 additions & 0 deletions pyiceberg/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,18 @@ def field(self, field_id: int) -> Optional[NestedField]:
return field
return None

def field_by_name(self, name: str, case_sensitive: bool = True) -> Optional[NestedField]:
if case_sensitive:
name_lower = name.lower()
for field in self.fields:
if field.name.lower() == name_lower:
return field
else:
for field in self.fields:
if field.name == name:
return field
return None

def __str__(self) -> str:
"""Return the string representation of the StructType class."""
return f"struct<{', '.join(map(str, self.fields))}>"
Expand Down
Loading

0 comments on commit cd7fb50

Please sign in to comment.