# Summary of this patch (pyiceberg/table/__init__.py):
#   * UnionByNameVisitor: the attribute and constructor parameter `new_schema`
#     are renamed to `existing_schema`; the call site in union_by_name passes
#     the new keyword.
#     NOTE(review): any caller outside this diff that still passes
#     `new_schema=` by keyword would need the same rename - confirm.
#   * UnionByNameVisitor.struct: raises ValueError when the partner type's
#     `is_struct` is falsy.
#   * UnionByNameVisitor.list / .map: now return True (previously False) when
#     the partner id is None.
#   * Partner parameters are renamed with an `_id` suffix in
#     UnionByNameVisitor.field/list/map and in list_element_partner,
#     map_key_partner and map_value_partner.
# NOTE(review): line breaks, indentation and blank context lines below were
# restored from the hunk line counts; verify against the target file before
# applying.
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index f8ff33c203..6812b7cdfa 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -1465,7 +1465,7 @@ def union_by_name(self, new_schema: Schema) -> UpdateSchema:
         visit_with_partner(
             new_schema,
             -1,
-            UnionByNameVisitor(update_schema=self, new_schema=self._schema, case_sensitive=self._case_sensitive),  # type: ignore
+            UnionByNameVisitor(update_schema=self, existing_schema=self._schema, case_sensitive=self._case_sensitive),  # type: ignore
             PartnerIdByNameAccessor(partner_schema=self._schema, case_sensitive=self._case_sensitive),
         )
         return self
@@ -2022,12 +2022,12 @@ def primitive(self, primitive: PrimitiveType) -> Optional[IcebergType]:
 
 class UnionByNameVisitor(SchemaWithPartnerVisitor[int, bool]):
     update_schema: UpdateSchema
-    new_schema: Schema
+    existing_schema: Schema
     case_sensitive: bool
 
-    def __init__(self, update_schema: UpdateSchema, new_schema: Schema, case_sensitive: bool) -> None:
+    def __init__(self, update_schema: UpdateSchema, existing_schema: Schema, case_sensitive: bool) -> None:
         self.update_schema = update_schema
-        self.new_schema = new_schema
+        self.existing_schema = existing_schema
         self.case_sensitive = case_sensitive
 
     def schema(self, schema: Schema, partner_id: Optional[int], struct_result: bool) -> bool:
@@ -2040,6 +2040,9 @@ def struct(self, struct: StructType, partner_id: Optional[int], missing_position
         fields = struct.fields
         partner_struct = self._find_field_type(partner_id)
 
+        if not partner_struct.is_struct:
+            raise ValueError(f"Expected a struct, got: {partner_struct}")
+
         for pos, missing in enumerate(missing_positions):
             if missing:
                 self._add_column(partner_id, fields[pos])
@@ -2051,7 +2054,7 @@ def struct(self, struct: StructType, partner_id: Optional[int], missing_position
         return False
 
     def _add_column(self, parent_id: int, field: NestedField) -> None:
-        if parent_name := self.new_schema.find_column_name(parent_id):
+        if parent_name := self.existing_schema.find_column_name(parent_id):
             path: Tuple[str, ...] = (parent_name, field.name)
         else:
             path = (field.name,)
@@ -2059,7 +2062,7 @@ def _add_column(self, parent_id: int, field: NestedField) -> None:
         self.update_schema.add_column(path=path, field_type=field.field_type, required=field.required, doc=field.doc)
 
     def _update_column(self, field: NestedField, existing_field: NestedField) -> None:
-        full_name = self.new_schema.find_column_name(existing_field.field_id)
+        full_name = self.existing_schema.find_column_name(existing_field.field_id)
 
         if full_name is None:
             raise ValueError(f"Could not find field: {existing_field}")
@@ -2075,21 +2078,21 @@ def _update_column(self, field: NestedField, existing_field: NestedField) -> Non
 
     def _find_field_type(self, field_id: int) -> IcebergType:
         if field_id == -1:
-            return self.new_schema.as_struct()
+            return self.existing_schema.as_struct()
         else:
-            return self.new_schema.find_field(field_id).field_type
+            return self.existing_schema.find_field(field_id).field_type
 
-    def field(self, field: NestedField, field_partner: Optional[int], field_result: bool) -> bool:
-        return field_partner is None
+    def field(self, field: NestedField, partner_id: Optional[int], field_result: bool) -> bool:
+        return partner_id is None
 
-    def list(self, list_type: ListType, list_partner: Optional[int], element_missing: bool) -> bool:
-        if list_partner is None:
-            return False
+    def list(self, list_type: ListType, list_partner_id: Optional[int], element_missing: bool) -> bool:
+        if list_partner_id is None:
+            return True
 
         if element_missing:
             raise ValueError("Error traversing schemas: element is missing, but list is present")
 
-        partner_list_type = self._find_field_type(list_partner)
+        partner_list_type = self._find_field_type(list_partner_id)
         if not isinstance(partner_list_type, ListType):
             raise ValueError(f"Expected list-type, got: {partner_list_type}")
 
@@ -2097,9 +2100,9 @@ def list(self, list_type: ListType, list_partner: Optional[int], element_missing
 
         return False
 
-    def map(self, map_type: MapType, map_partner: Optional[int], key_missing: bool, value_missing: bool) -> bool:
-        if map_partner is None:
-            return False
+    def map(self, map_type: MapType, map_partner_id: Optional[int], key_missing: bool, value_missing: bool) -> bool:
+        if map_partner_id is None:
+            return True
 
         if key_missing:
             raise ValueError("Error traversing schemas: key is missing, but map is present")
@@ -2107,7 +2110,7 @@ def map(self, map_type: MapType, map_partner: Optional[int], key_missing: bool,
         if value_missing:
             raise ValueError("Error traversing schemas: value is missing, but map is present")
 
-        partner_map_type = self._find_field_type(map_partner)
+        partner_map_type = self._find_field_type(map_partner_id)
         if not isinstance(partner_map_type, MapType):
             raise ValueError(f"Expected map-type, got: {partner_map_type}")
 
@@ -2145,24 +2148,24 @@ def field_partner(self, partner_field_id: Optional[int], field_id: int, field_na
 
         return None
 
-    def list_element_partner(self, partner_list: Optional[int]) -> Optional[int]:
-        if partner_list is not None and (field := self.partner_schema.find_field(partner_list)):
+    def list_element_partner(self, partner_list_id: Optional[int]) -> Optional[int]:
+        if partner_list_id is not None and (field := self.partner_schema.find_field(partner_list_id)):
             if not isinstance(field.field_type, ListType):
                 raise ValueError(f"Expected ListType: {field}")
             return field.field_type.element_field.field_id
         else:
             return None
 
-    def map_key_partner(self, partner_map: Optional[int]) -> Optional[int]:
-        if partner_map is not None and (field := self.partner_schema.find_field(partner_map)):
+    def map_key_partner(self, partner_map_id: Optional[int]) -> Optional[int]:
+        if partner_map_id is not None and (field := self.partner_schema.find_field(partner_map_id)):
             if not isinstance(field.field_type, MapType):
                 raise ValueError(f"Expected MapType: {field}")
             return field.field_type.key_field.field_id
         else:
             return None
 
-    def map_value_partner(self, partner_map: Optional[int]) -> Optional[int]:
-        if partner_map is not None and (field := self.partner_schema.find_field(partner_map)):
+    def map_value_partner(self, partner_map_id: Optional[int]) -> Optional[int]:
+        if partner_map_id is not None and (field := self.partner_schema.find_field(partner_map_id)):
             if not isinstance(field.field_type, MapType):
                 raise ValueError(f"Expected MapType: {field}")
             return field.field_type.value_field.field_id