Skip to content

Commit

Permalink
Thanks Honah!
Browse files: browse the repository at this point in the history
  • Loading branch information
Fokko committed Jan 25, 2024
1 parent 8b5ed1b commit 4b3fa28
Showing 1 changed file with 27 additions and 24 deletions.
51 changes: 27 additions & 24 deletions pyiceberg/table/__init__.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -1465,7 +1465,7 @@ def union_by_name(self, new_schema: Schema) -> UpdateSchema:
visit_with_partner(
new_schema,
-1,
UnionByNameVisitor(update_schema=self, new_schema=self._schema, case_sensitive=self._case_sensitive), # type: ignore
UnionByNameVisitor(update_schema=self, existing_schema=self._schema, case_sensitive=self._case_sensitive), # type: ignore
PartnerIdByNameAccessor(partner_schema=self._schema, case_sensitive=self._case_sensitive),
)
return self
Expand Down Expand Up @@ -2022,12 +2022,12 @@ def primitive(self, primitive: PrimitiveType) -> Optional[IcebergType]:

class UnionByNameVisitor(SchemaWithPartnerVisitor[int, bool]):
update_schema: UpdateSchema
new_schema: Schema
existing_schema: Schema
case_sensitive: bool

def __init__(self, update_schema: UpdateSchema, new_schema: Schema, case_sensitive: bool) -> None:
def __init__(self, update_schema: UpdateSchema, existing_schema: Schema, case_sensitive: bool) -> None:
self.update_schema = update_schema
self.new_schema = new_schema
self.existing_schema = existing_schema
self.case_sensitive = case_sensitive

def schema(self, schema: Schema, partner_id: Optional[int], struct_result: bool) -> bool:
Expand All @@ -2040,6 +2040,9 @@ def struct(self, struct: StructType, partner_id: Optional[int], missing_position
fields = struct.fields
partner_struct = self._find_field_type(partner_id)

if not partner_struct.is_struct:
raise ValueError(f"Expected a struct, got: {partner_struct}")

for pos, missing in enumerate(missing_positions):
if missing:
self._add_column(partner_id, fields[pos])
Expand All @@ -2051,15 +2054,15 @@ def struct(self, struct: StructType, partner_id: Optional[int], missing_position
return False

def _add_column(self, parent_id: int, field: NestedField) -> None:
if parent_name := self.new_schema.find_column_name(parent_id):
if parent_name := self.existing_schema.find_column_name(parent_id):
path: Tuple[str, ...] = (parent_name, field.name)
else:
path = (field.name,)

self.update_schema.add_column(path=path, field_type=field.field_type, required=field.required, doc=field.doc)

def _update_column(self, field: NestedField, existing_field: NestedField) -> None:
full_name = self.new_schema.find_column_name(existing_field.field_id)
full_name = self.existing_schema.find_column_name(existing_field.field_id)

if full_name is None:
raise ValueError(f"Could not find field: {existing_field}")
Expand All @@ -2075,39 +2078,39 @@ def _update_column(self, field: NestedField, existing_field: NestedField) -> Non

def _find_field_type(self, field_id: int) -> IcebergType:
if field_id == -1:
return self.new_schema.as_struct()
return self.existing_schema.as_struct()
else:
return self.new_schema.find_field(field_id).field_type
return self.existing_schema.find_field(field_id).field_type

def field(self, field: NestedField, field_partner: Optional[int], field_result: bool) -> bool:
return field_partner is None
def field(self, field: NestedField, partner_id: Optional[int], field_result: bool) -> bool:
return partner_id is None

def list(self, list_type: ListType, list_partner: Optional[int], element_missing: bool) -> bool:
if list_partner is None:
return False
def list(self, list_type: ListType, list_partner_id: Optional[int], element_missing: bool) -> bool:
if list_partner_id is None:
return True

if element_missing:
raise ValueError("Error traversing schemas: element is missing, but list is present")

partner_list_type = self._find_field_type(list_partner)
partner_list_type = self._find_field_type(list_partner_id)
if not isinstance(partner_list_type, ListType):
raise ValueError(f"Expected list-type, got: {partner_list_type}")

self._update_column(list_type.element_field, partner_list_type.element_field)

return False

def map(self, map_type: MapType, map_partner: Optional[int], key_missing: bool, value_missing: bool) -> bool:
if map_partner is None:
return False
def map(self, map_type: MapType, map_partner_id: Optional[int], key_missing: bool, value_missing: bool) -> bool:
if map_partner_id is None:
return True

if key_missing:
raise ValueError("Error traversing schemas: key is missing, but map is present")

if value_missing:
raise ValueError("Error traversing schemas: value is missing, but map is present")

partner_map_type = self._find_field_type(map_partner)
partner_map_type = self._find_field_type(map_partner_id)
if not isinstance(partner_map_type, MapType):
raise ValueError(f"Expected map-type, got: {partner_map_type}")

Expand Down Expand Up @@ -2145,24 +2148,24 @@ def field_partner(self, partner_field_id: Optional[int], field_id: int, field_na

return None

def list_element_partner(self, partner_list: Optional[int]) -> Optional[int]:
if partner_list is not None and (field := self.partner_schema.find_field(partner_list)):
def list_element_partner(self, partner_list_id: Optional[int]) -> Optional[int]:
if partner_list_id is not None and (field := self.partner_schema.find_field(partner_list_id)):
if not isinstance(field.field_type, ListType):
raise ValueError(f"Expected ListType: {field}")
return field.field_type.element_field.field_id
else:
return None

def map_key_partner(self, partner_map: Optional[int]) -> Optional[int]:
if partner_map is not None and (field := self.partner_schema.find_field(partner_map)):
def map_key_partner(self, partner_map_id: Optional[int]) -> Optional[int]:
if partner_map_id is not None and (field := self.partner_schema.find_field(partner_map_id)):
if not isinstance(field.field_type, MapType):
raise ValueError(f"Expected MapType: {field}")
return field.field_type.key_field.field_id
else:
return None

def map_value_partner(self, partner_map: Optional[int]) -> Optional[int]:
if partner_map is not None and (field := self.partner_schema.find_field(partner_map)):
def map_value_partner(self, partner_map_id: Optional[int]) -> Optional[int]:
if partner_map_id is not None and (field := self.partner_schema.find_field(partner_map_id)):
if not isinstance(field.field_type, MapType):
raise ValueError(f"Expected MapType: {field}")
return field.field_type.value_field.field_id
Expand Down

0 comments on commit 4b3fa28

Please sign in to comment.