diff --git a/mkdocs/docs/api.md b/mkdocs/docs/api.md index 9d97d4f676..6f79835db0 100644 --- a/mkdocs/docs/api.md +++ b/mkdocs/docs/api.md @@ -293,6 +293,47 @@ with table.transaction() as transaction: # ... Update properties etc ``` +### Union by Name + +Using `.union_by_name()` you can merge another schema into an existing schema without having to worry about field-IDs: + +```python +from pyiceberg.catalog import load_catalog +from pyiceberg.schema import Schema +from pyiceberg.types import NestedField, StringType, DoubleType, LongType + +catalog = load_catalog() + +schema = Schema( + NestedField(1, "city", StringType(), required=False), + NestedField(2, "lat", DoubleType(), required=False), + NestedField(3, "long", DoubleType(), required=False), +) + +table = catalog.create_table("default.locations", schema) + +new_schema = Schema( + NestedField(1, "city", StringType(), required=False), + NestedField(2, "lat", DoubleType(), required=False), + NestedField(3, "long", DoubleType(), required=False), + NestedField(10, "population", LongType(), required=False), +) + +with table.update_schema() as update: + update.union_by_name(new_schema) +``` + +Now the table has the union of the two schemas `print(table.schema())`: + +``` +table { + 1: city: optional string + 2: lat: optional double + 3: long: optional double + 4: population: optional long +} +``` + ### Add column Using `add_column` you can add a column, without having to worry about the field-id: diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 057dd8427c..221a609e5c 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -69,11 +69,14 @@ ) from pyiceberg.partitioning import PartitionSpec from pyiceberg.schema import ( + PartnerAccessor, Schema, SchemaVisitor, + SchemaWithPartnerVisitor, assign_fresh_schema_ids, promote, visit, + visit_with_partner, ) from pyiceberg.table.metadata import ( INITIAL_SEQUENCE_NUMBER, @@ -1379,7 +1382,7 @@ class Move: class UpdateSchema: 
- _table: Table + _table: Optional[Table] _schema: Schema _last_column_id: itertools.count[int] _identifier_field_names: Set[str] @@ -1398,14 +1401,23 @@ class UpdateSchema: def __init__( self, - table: Table, + table: Optional[Table], transaction: Optional[Transaction] = None, allow_incompatible_changes: bool = False, case_sensitive: bool = True, + schema: Optional[Schema] = None, ) -> None: self._table = table - self._schema = table.schema() - self._last_column_id = itertools.count(table.metadata.last_column_id + 1) + + if isinstance(schema, Schema): + self._schema = schema + self._last_column_id = itertools.count(1 + schema.highest_field_id) + elif table is not None: + self._schema = table.schema() + self._last_column_id = itertools.count(1 + table.metadata.last_column_id) + else: + raise ValueError("Either provide a table or a schema") + self._identifier_field_names = self._schema.identifier_field_names() self._adds = {} @@ -1449,6 +1461,15 @@ def case_sensitive(self, case_sensitive: bool) -> UpdateSchema: self._case_sensitive = case_sensitive return self + def union_by_name(self, new_schema: Schema) -> UpdateSchema: + visit_with_partner( + new_schema, + -1, + UnionByNameVisitor(update_schema=self, existing_schema=self._schema, case_sensitive=self._case_sensitive), # type: ignore + PartnerIdByNameAccessor(partner_schema=self._schema, case_sensitive=self._case_sensitive), + ) + return self + def add_column( self, path: Union[str, Tuple[str, ...]], field_type: IcebergType, doc: Optional[str] = None, required: bool = False ) -> UpdateSchema: @@ -1816,6 +1837,9 @@ def move_after(self, path: Union[str, Tuple[str, ...]], after_name: Union[str, T def commit(self) -> None: """Apply the pending changes and commit.""" + if self._table is None: + raise ValueError("Requires a table to commit to") + new_schema = self._apply() existing_schema_id = next((schema.schema_id for schema in self._table.metadata.schemas if schema == new_schema), None) @@ -1862,7 +1886,8 @@ def 
_apply(self) -> Schema: field_ids.add(field.field_id) - return Schema(*struct.fields, schema_id=1 + max(self._table.schemas().keys()), identifier_field_ids=field_ids) + next_schema_id = 1 + (max(self._table.schemas().keys()) if self._table is not None else self._schema.schema_id) + return Schema(*struct.fields, schema_id=next_schema_id, identifier_field_ids=field_ids) def assign_new_column_id(self) -> int: return next(self._last_column_id) @@ -1995,6 +2020,159 @@ def primitive(self, primitive: PrimitiveType) -> Optional[IcebergType]: return primitive +class UnionByNameVisitor(SchemaWithPartnerVisitor[int, bool]): + update_schema: UpdateSchema + existing_schema: Schema + case_sensitive: bool + + def __init__(self, update_schema: UpdateSchema, existing_schema: Schema, case_sensitive: bool) -> None: + self.update_schema = update_schema + self.existing_schema = existing_schema + self.case_sensitive = case_sensitive + + def schema(self, schema: Schema, partner_id: Optional[int], struct_result: bool) -> bool: + return struct_result + + def struct(self, struct: StructType, partner_id: Optional[int], missing_positions: List[bool]) -> bool: + if partner_id is None: + return True + + fields = struct.fields + partner_struct = self._find_field_type(partner_id) + + if not partner_struct.is_struct: + raise ValueError(f"Expected a struct, got: {partner_struct}") + + for pos, missing in enumerate(missing_positions): + if missing: + self._add_column(partner_id, fields[pos]) + else: + field = fields[pos] + if nested_field := partner_struct.field_by_name(field.name, case_sensitive=self.case_sensitive): + self._update_column(field, nested_field) + + return False + + def _add_column(self, parent_id: int, field: NestedField) -> None: + if parent_name := self.existing_schema.find_column_name(parent_id): + path: Tuple[str, ...] 
= (parent_name, field.name) + else: + path = (field.name,) + + self.update_schema.add_column(path=path, field_type=field.field_type, required=field.required, doc=field.doc) + + def _update_column(self, field: NestedField, existing_field: NestedField) -> None: + full_name = self.existing_schema.find_column_name(existing_field.field_id) + + if full_name is None: + raise ValueError(f"Could not find field: {existing_field}") + + if field.optional and existing_field.required: + self.update_schema.make_column_optional(full_name) + + if field.field_type.is_primitive and field.field_type != existing_field.field_type: + self.update_schema.update_column(full_name, field_type=field.field_type) + + if field.doc is not None and field.doc != existing_field.doc:  # only rewrite the doc when it actually differs + self.update_schema.update_column(full_name, doc=field.doc) + + def _find_field_type(self, field_id: int) -> IcebergType: + if field_id == -1: + return self.existing_schema.as_struct() + else: + return self.existing_schema.find_field(field_id).field_type + + def field(self, field: NestedField, partner_id: Optional[int], field_result: bool) -> bool: + return partner_id is None + + def list(self, list_type: ListType, list_partner_id: Optional[int], element_missing: bool) -> bool: + if list_partner_id is None: + return True + + if element_missing: + raise ValueError("Error traversing schemas: element is missing, but list is present") + + partner_list_type = self._find_field_type(list_partner_id) + if not isinstance(partner_list_type, ListType): + raise ValueError(f"Expected list-type, got: {partner_list_type}") + + self._update_column(list_type.element_field, partner_list_type.element_field) + + return False + + def map(self, map_type: MapType, map_partner_id: Optional[int], key_missing: bool, value_missing: bool) -> bool: + if map_partner_id is None: + return True + + if key_missing: + raise ValueError("Error traversing schemas: key is missing, but map is present") + + if value_missing: + raise ValueError("Error 
traversing schemas: value is missing, but map is present") + + partner_map_type = self._find_field_type(map_partner_id) + if not isinstance(partner_map_type, MapType): + raise ValueError(f"Expected map-type, got: {partner_map_type}") + + self._update_column(map_type.key_field, partner_map_type.key_field) + self._update_column(map_type.value_field, partner_map_type.value_field) + + return False + + def primitive(self, primitive: PrimitiveType, primitive_partner_id: Optional[int]) -> bool: + return primitive_partner_id is None + + +class PartnerIdByNameAccessor(PartnerAccessor[int]): + partner_schema: Schema + case_sensitive: bool + + def __init__(self, partner_schema: Schema, case_sensitive: bool) -> None: + self.partner_schema = partner_schema + self.case_sensitive = case_sensitive + + def schema_partner(self, partner: Optional[int]) -> Optional[int]: + return -1 + + def field_partner(self, partner_field_id: Optional[int], field_id: int, field_name: str) -> Optional[int]: + if partner_field_id is not None: + if partner_field_id == -1: + struct = self.partner_schema.as_struct() + else: + struct = self.partner_schema.find_field(partner_field_id).field_type + if not struct.is_struct: + raise ValueError(f"Expected StructType: {struct}") + + if field := struct.field_by_name(name=field_name, case_sensitive=self.case_sensitive): + return field.field_id + + return None + + def list_element_partner(self, partner_list_id: Optional[int]) -> Optional[int]: + if partner_list_id is not None and (field := self.partner_schema.find_field(partner_list_id)): + if not isinstance(field.field_type, ListType): + raise ValueError(f"Expected ListType: {field}") + return field.field_type.element_field.field_id + else: + return None + + def map_key_partner(self, partner_map_id: Optional[int]) -> Optional[int]: + if partner_map_id is not None and (field := self.partner_schema.find_field(partner_map_id)): + if not isinstance(field.field_type, MapType): + raise ValueError(f"Expected MapType: 
{field}") + return field.field_type.key_field.field_id + else: + return None + + def map_value_partner(self, partner_map_id: Optional[int]) -> Optional[int]: + if partner_map_id is not None and (field := self.partner_schema.find_field(partner_map_id)): + if not isinstance(field.field_type, MapType): + raise ValueError(f"Expected MapType: {field}") + return field.field_type.value_field.field_id + else: + return None + + def _add_fields(fields: Tuple[NestedField, ...], adds: Optional[List[NestedField]]) -> Tuple[NestedField, ...]: adds = adds or [] return fields + tuple(adds) diff --git a/pyiceberg/types.py b/pyiceberg/types.py index 5e7c193295..eb215121dc 100644 --- a/pyiceberg/types.py +++ b/pyiceberg/types.py @@ -350,6 +350,19 @@ def field(self, field_id: int) -> Optional[NestedField]: return field return None + def field_by_name(self, name: str, case_sensitive: bool = True) -> Optional[NestedField]: + """Return the field matching name (ignoring case when case_sensitive is False), or None.""" + if case_sensitive: + for field in self.fields: + if field.name == name: + return field + else: + name_lower = name.lower() + for field in self.fields: + if field.name.lower() == name_lower: + return field + return None + def __str__(self) -> str: """Return the string representation of the StructType class.""" return f"struct<{', '.join(map(str, self.fields))}>" diff --git a/tests/test_schema.py b/tests/test_schema.py index 8e34423cf9..a5487b7fd9 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -16,12 +16,12 @@ # under the License. 
from textwrap import dedent -from typing import Any, Dict +from typing import Any, Dict, List import pytest from pyiceberg import schema -from pyiceberg.exceptions import ResolveError +from pyiceberg.exceptions import ResolveError, ValidationError from pyiceberg.expressions import Accessor from pyiceberg.schema import ( Schema, @@ -30,6 +30,7 @@ prune_columns, sanitize_column_names, ) +from pyiceberg.table import UpdateSchema from pyiceberg.typedef import EMPTY_DICT, StructProtocol from pyiceberg.types import ( BinaryType, @@ -45,6 +46,7 @@ LongType, MapType, NestedField, + PrimitiveType, StringType, StructType, TimestampType, @@ -912,3 +914,668 @@ def test_promotion(file_type: IcebergType, read_type: IcebergType) -> None: else: with pytest.raises(ResolveError): promote(file_type, read_type) + + +@pytest.fixture() +def primitive_fields() -> List[NestedField]: + return [ + NestedField(field_id=1, name=str(primitive_type), field_type=primitive_type, required=False) + for primitive_type in TEST_PRIMITIVE_TYPES + ] + + +def test_add_top_level_primitives(primitive_fields: NestedField) -> None: + for primitive_field in primitive_fields: + new_schema = Schema(primitive_field) + applied = UpdateSchema(None, schema=Schema()).union_by_name(new_schema)._apply() + assert applied == new_schema + + +def test_add_top_level_list_of_primitives(primitive_fields: NestedField) -> None: + for primitive_type in TEST_PRIMITIVE_TYPES: + new_schema = Schema( + NestedField( + field_id=1, + name="aList", + field_type=ListType(element_id=2, element_type=primitive_type, element_required=False), + required=False, + ) + ) + applied = UpdateSchema(None, schema=Schema()).union_by_name(new_schema)._apply() + assert applied.as_struct() == new_schema.as_struct() + + +def test_add_top_level_map_of_primitives(primitive_fields: NestedField) -> None: + for primitive_type in TEST_PRIMITIVE_TYPES: + new_schema = Schema( + NestedField( + field_id=1, + name="aMap", + field_type=MapType( + key_id=2, 
key_type=primitive_type, value_id=3, value_type=primitive_type, value_required=False + ), + required=False, + ) + ) + applied = UpdateSchema(None, schema=Schema()).union_by_name(new_schema)._apply() + assert applied.as_struct() == new_schema.as_struct() + + +def test_add_top_struct_of_primitives(primitive_fields: NestedField) -> None: + for primitive_type in TEST_PRIMITIVE_TYPES: + new_schema = Schema( + NestedField( + field_id=1, + name="aStruct", + field_type=StructType(NestedField(field_id=2, name="primitive", field_type=primitive_type, required=False)), + required=False, + ) + ) + applied = UpdateSchema(None, schema=Schema()).union_by_name(new_schema)._apply() + assert applied.as_struct() == new_schema.as_struct() + + +def test_add_nested_primitive(primitive_fields: NestedField) -> None: + for primitive_type in TEST_PRIMITIVE_TYPES: + current_schema = Schema(NestedField(field_id=1, name="aStruct", field_type=StructType(), required=False)) + new_schema = Schema( + NestedField( + field_id=1, + name="aStruct", + field_type=StructType(NestedField(field_id=2, name="primitive", field_type=primitive_type, required=False)), + required=False, + ) + ) + applied = UpdateSchema(None, schema=current_schema).union_by_name(new_schema)._apply() + assert applied.as_struct() == new_schema.as_struct() + + +def _primitive_fields(types: List[PrimitiveType], start_id: int = 0) -> List[NestedField]: + fields = [] + for iceberg_type in types: + fields.append(NestedField(field_id=start_id, name=str(iceberg_type), field_type=iceberg_type, required=False)) + start_id = start_id + 1 + + return fields + + +def test_add_nested_primitives(primitive_fields: NestedField) -> None: + current_schema = Schema(NestedField(field_id=1, name="aStruct", field_type=StructType(), required=False)) + new_schema = Schema( + NestedField( + field_id=1, name="aStruct", field_type=StructType(*_primitive_fields(TEST_PRIMITIVE_TYPES, 2)), required=False + ) + ) + applied = UpdateSchema(None, 
schema=current_schema).union_by_name(new_schema)._apply() + assert applied.as_struct() == new_schema.as_struct() + + +def test_add_nested_lists(primitive_fields: NestedField) -> None: + new_schema = Schema( + NestedField( + field_id=1, + name="aList", + type=ListType( + element_id=2, + element_type=ListType( + element_id=3, + element_type=ListType( + element_id=4, + element_type=ListType( + element_id=5, + element_type=ListType( + element_id=6, + element_type=ListType( + element_id=7, + element_type=ListType( + element_id=8, + element_type=ListType(element_id=9, element_type=DecimalType(precision=11, scale=20)), + element_required=False, + ), + element_required=False, + ), + element_required=False, + ), + element_required=False, + ), + element_required=False, + ), + element_required=False, + ), + element_required=False, + ), + required=False, + ) + ) + applied = UpdateSchema(None, schema=Schema()).union_by_name(new_schema)._apply() + assert applied.as_struct() == new_schema.as_struct() + + +def test_add_nested_struct(primitive_fields: NestedField) -> None: + new_schema = Schema( + NestedField( + field_id=1, + name="struct1", + type=StructType( + NestedField( + field_id=2, + name="struct2", + type=StructType( + NestedField( + field_id=3, + name="struct3", + type=StructType( + NestedField( + field_id=4, + name="struct4", + type=StructType( + NestedField( + field_id=5, + name="struct5", + type=StructType( + NestedField( + field_id=6, + name="struct6", + type=StructType( + NestedField(field_id=7, name="aString", field_type=StringType()) + ), + required=False, + ) + ), + required=False, + ) + ), + required=False, + ) + ), + required=False, + ) + ), + required=False, + ) + ), + required=False, + ) + ) + applied = UpdateSchema(None, schema=Schema()).union_by_name(new_schema)._apply() + assert applied.as_struct() == new_schema.as_struct() + + +def test_add_nested_maps(primitive_fields: NestedField) -> None: + new_schema = Schema( + NestedField( + field_id=1, + 
name="struct", + field_type=MapType( + key_id=2, + value_id=3, + key_type=StringType(), + value_type=MapType( + key_id=4, + value_id=5, + key_type=StringType(), + value_type=MapType( + key_id=6, + value_id=7, + key_type=StringType(), + value_type=MapType( + key_id=8, + value_id=9, + key_type=StringType(), + value_type=MapType( + key_id=10, + value_id=11, + key_type=StringType(), + value_type=MapType(key_id=12, value_id=13, key_type=StringType(), value_type=StringType()), + value_required=False, + ), + value_required=False, + ), + value_required=False, + ), + value_required=False, + ), + value_required=False, + ), + required=False, + ) + ) + applied = UpdateSchema(None, schema=Schema()).union_by_name(new_schema)._apply() + assert applied.as_struct() == new_schema.as_struct() + + +def test_detect_invalid_top_level_list() -> None: + current_schema = Schema( + NestedField( + field_id=1, + name="aList", + field_type=ListType(element_id=2, element_type=StringType(), element_required=False), + required=False, + ) + ) + new_schema = Schema( + NestedField( + field_id=1, + name="aList", + field_type=ListType(element_id=2, element_type=DoubleType(), element_required=False), + required=False, + ) + ) + + with pytest.raises(ValidationError, match="Cannot change column type: aList.element: string -> double"): + _ = UpdateSchema(None, schema=current_schema).union_by_name(new_schema)._apply() + + +def test_detect_invalid_top_level_maps() -> None: + current_schema = Schema( + NestedField( + field_id=1, + name="aMap", + field_type=MapType(key_id=2, value_id=3, key_type=StringType(), value_type=StringType(), value_required=False), + required=False, + ) + ) + new_schema = Schema( + NestedField( + field_id=1, + name="aMap", + field_type=MapType(key_id=2, value_id=3, key_type=UUIDType(), value_type=StringType(), value_required=False), + required=False, + ) + ) + + with pytest.raises(ValidationError, match="Cannot change column type: aMap.key: string -> uuid"): + _ = UpdateSchema(None, 
schema=current_schema).union_by_name(new_schema)._apply() + + +def test_promote_float_to_double() -> None: + current_schema = Schema(NestedField(field_id=1, name="aCol", field_type=FloatType(), required=False)) + new_schema = Schema(NestedField(field_id=1, name="aCol", field_type=DoubleType(), required=False)) + + applied = UpdateSchema(None, schema=current_schema).union_by_name(new_schema)._apply() + + assert applied.as_struct() == new_schema.as_struct() + assert len(applied.fields) == 1 + assert isinstance(applied.fields[0].field_type, DoubleType) + + +def test_detect_invalid_promotion_double_to_float() -> None: + current_schema = Schema(NestedField(field_id=1, name="aCol", field_type=DoubleType(), required=False)) + new_schema = Schema(NestedField(field_id=1, name="aCol", field_type=FloatType(), required=False)) + + with pytest.raises(ValidationError, match="Cannot change column type: aCol: double -> float"): + _ = UpdateSchema(None, schema=current_schema).union_by_name(new_schema)._apply() + + +# decimal(P,S) Fixed-point decimal; precision P, scale S -> Scale is fixed [1], +# precision must be 38 or less +def test_type_promote_decimal_to_fixed_scale_with_wider_precision() -> None: + current_schema = Schema(NestedField(field_id=1, name="aCol", field_type=DecimalType(precision=20, scale=1), required=False)) + new_schema = Schema(NestedField(field_id=1, name="aCol", field_type=DecimalType(precision=22, scale=1), required=False)) + + applied = UpdateSchema(None, schema=current_schema).union_by_name(new_schema)._apply() + + assert applied.as_struct() == new_schema.as_struct() + assert len(applied.fields) == 1 + field = applied.fields[0] + decimal_type = field.field_type + assert isinstance(decimal_type, DecimalType) + assert decimal_type.precision == 22 + assert decimal_type.scale == 1 + + +def test_add_nested_structs(primitive_fields: NestedField) -> None: + schema = Schema( + NestedField( + field_id=1, + name="struct1", + field_type=StructType( + NestedField( + 
field_id=2, + name="struct2", + field_type=StructType( + NestedField( + field_id=3, + name="list", + field_type=ListType( + element_id=4, + element_type=StructType( + NestedField(field_id=5, name="value", field_type=StringType(), required=False) + ), + element_required=False, + ), + required=False, + ) + ), + required=False, + ) + ), + required=False, + ) + ) + new_schema = Schema( + NestedField( + field_id=1, + name="struct1", + field_type=StructType( + NestedField( + field_id=2, + name="struct2", + field_type=StructType( + NestedField( + field_id=3, + name="list", + field_type=ListType( + element_id=4, + element_type=StructType( + NestedField(field_id=5, name="time", field_type=TimeType(), required=False) + ), + element_required=False, + ), + required=False, + ) + ), + required=False, + ) + ), + required=False, + ) + ) + applied = UpdateSchema(None, schema=schema).union_by_name(new_schema)._apply() + + expected = Schema( + NestedField( + field_id=1, + name="struct1", + field_type=StructType( + NestedField( + field_id=2, + name="struct2", + field_type=StructType( + NestedField( + field_id=3, + name="list", + field_type=ListType( + element_id=4, + element_type=StructType( + NestedField(field_id=5, name="value", field_type=StringType(), required=False), + NestedField(field_id=6, name="time", field_type=TimeType(), required=False), + ), + element_required=False, + ), + required=False, + ) + ), + required=False, + ) + ), + required=False, + ) + ) + + assert applied.as_struct() == expected.as_struct() + + +def test_replace_list_with_primitive() -> None: + current_schema = Schema(NestedField(field_id=1, name="aCol", field_type=ListType(element_id=2, element_type=StringType()))) + new_schema = Schema(NestedField(field_id=1, name="aCol", field_type=StringType())) + + with pytest.raises(ValidationError, match="Cannot change column type: list is not a primitive"): + _ = UpdateSchema(None, schema=current_schema).union_by_name(new_schema)._apply() + + +def 
test_mirrored_schemas() -> None: + current_schema = Schema( + NestedField(9, "struct1", StructType(NestedField(8, "string1", StringType(), required=False)), required=False), + NestedField(6, "list1", ListType(element_id=7, element_type=StringType(), element_required=False), required=False), + NestedField(5, "string2", StringType(), required=False), + NestedField(4, "string3", StringType(), required=False), + NestedField(3, "string4", StringType(), required=False), + NestedField(2, "string5", StringType(), required=False), + NestedField(1, "string6", StringType(), required=False), + ) + mirrored_schema = Schema( + NestedField(1, "struct1", StructType(NestedField(2, "string1", StringType(), required=False))), + NestedField(3, "list1", ListType(element_id=4, element_type=StringType(), element_required=False), required=False), + NestedField(5, "string2", StringType(), required=False), + NestedField(6, "string3", StringType(), required=False), + NestedField(7, "string4", StringType(), required=False), + NestedField(8, "string5", StringType(), required=False), + NestedField(9, "string6", StringType(), required=False), + ) + + applied = UpdateSchema(None, schema=current_schema).union_by_name(mirrored_schema)._apply() + + assert applied.as_struct() == current_schema.as_struct() + + +def test_add_new_top_level_struct() -> None: + current_schema = Schema( + NestedField( + 1, + "map1", + MapType( + key_id=2, + value_id=3, + key_type=StringType(), + value_type=ListType( + element_id=4, + element_type=StructType(NestedField(field_id=5, name="string", field_type=StringType(), required=False)), + ), + value_required=False, + ), + ) + ) + observed_schema = Schema( + NestedField( + 1, + "map1", + MapType( + key_id=2, + value_id=3, + key_type=StringType(), + value_type=ListType( + element_id=4, + element_type=StructType(NestedField(field_id=5, name="string", field_type=StringType(), required=False)), + ), + value_required=False, + ), + ), + NestedField( + field_id=6, + 
name="struct1", + field_type=StructType( + NestedField( + field_id=7, + name="d1", + field_type=StructType(NestedField(field_id=8, name="d2", field_type=StringType(), required=False)), + required=False, + ) + ), + required=False, + ), + ) + + applied = UpdateSchema(None, schema=current_schema).union_by_name(observed_schema)._apply() + + assert applied.as_struct() == observed_schema.as_struct() + + +def test_append_nested_struct() -> None: + current_schema = Schema( + NestedField( + field_id=1, + name="s1", + field_type=StructType( + NestedField( + field_id=2, + name="s2", + field_type=StructType( + NestedField( + field_id=3, + name="s3", + field_type=StructType(NestedField(field_id=4, name="s4", field_type=StringType(), required=False)), + ) + ), + required=False, + ) + ), + ) + ) + observed_schema = Schema( + NestedField( + field_id=1, + name="s1", + field_type=StructType( + NestedField( + field_id=2, + name="s2", + field_type=StructType( + NestedField( + field_id=3, + name="s3", + field_type=StructType(NestedField(field_id=4, name="s4", field_type=StringType(), required=False)), + required=False, + ), + NestedField( + field_id=5, + name="repeat", + field_type=StructType( + NestedField( + field_id=6, + name="s1", + field_type=StructType( + NestedField( + field_id=7, + name="s2", + field_type=StructType( + NestedField( + field_id=8, + name="s3", + field_type=StructType( + NestedField(field_id=9, name="s4", field_type=StringType()) + ), + required=False, + ) + ), + required=False, + ) + ), + required=False, + ) + ), + required=False, + ), + required=False, + ), + required=False, + ) + ), + required=False, + ) + ) + + applied = UpdateSchema(None, schema=current_schema).union_by_name(observed_schema)._apply() + + assert applied.as_struct() == observed_schema.as_struct() + + +def test_append_nested_lists() -> None: + current_schema = Schema( + NestedField( + field_id=1, + name="s1", + field_type=StructType( + NestedField( + field_id=2, + name="s2", + 
field_type=StructType( + NestedField( + field_id=3, + name="s3", + field_type=StructType( + NestedField( + field_id=4, + name="list1", + field_type=ListType(element_id=5, element_type=StringType(), element_required=False), + required=False, + ) + ), + required=False, + ) + ), + required=False, + ) + ), + required=False, + ) + ) + + observed_schema = Schema( + NestedField( + field_id=1, + name="s1", + field_type=StructType( + NestedField( + field_id=2, + name="s2", + field_type=StructType( + NestedField( + field_id=3, + name="s3", + field_type=StructType( + NestedField( + field_id=4, + name="list2", + field_type=ListType(element_id=5, element_type=StringType(), element_required=False), + required=False, + ) + ), + required=False, + ) + ), + required=False, + ) + ), + required=False, + ) + ) + union = UpdateSchema(None, schema=current_schema).union_by_name(observed_schema)._apply() + + expected = Schema( + NestedField( + field_id=1, + name="s1", + field_type=StructType( + NestedField( + field_id=2, + name="s2", + field_type=StructType( + NestedField( + field_id=3, + name="s3", + field_type=StructType( + NestedField( + field_id=4, + name="list1", + field_type=ListType(element_id=5, element_type=StringType(), element_required=False), + required=False, + ), + NestedField( + field_id=6, + name="list2", + field_type=ListType(element_id=7, element_type=StringType(), element_required=False), + required=False, + ), + ), + required=False, + ) + ), + required=False, + ) + ), + required=False, + ) + ) + + assert union.as_struct() == expected.as_struct()