Skip to content

Commit 70972d9

Browse files
authored
Apply Name mapping (#219)
1 parent e1018e5 commit 70972d9

File tree

4 files changed

+484 additions, -65 deletions (lines changed)

4 files changed

+484 additions, -65 deletions (lines changed)

pyiceberg/io/pyarrow.py

Lines changed: 164 additions & 65 deletions
Original file line number | Diff line number | Diff line change
@@ -124,6 +124,7 @@
124124
visit_with_partner,
125125
)
126126
from pyiceberg.table import WriteTask
127+
from pyiceberg.table.name_mapping import NameMapping
127128
from pyiceberg.transforms import TruncateTransform
128129
from pyiceberg.typedef import EMPTY_DICT, Properties, Record
129130
from pyiceberg.types import (
@@ -164,6 +165,9 @@
164165
# The PARQUET: in front means that it is Parquet specific, in this case the field_id
165166
PYARROW_PARQUET_FIELD_ID_KEY = b"PARQUET:field_id"
166167
PYARROW_FIELD_DOC_KEY = b"doc"
168+
LIST_ELEMENT_NAME = "element"
169+
MAP_KEY_NAME = "key"
170+
MAP_VALUE_NAME = "value"
167171

168172
T = TypeVar("T")
169173

@@ -631,8 +635,16 @@ def _combine_positional_deletes(positional_deletes: List[pa.ChunkedArray], rows:
631635
return np.setdiff1d(np.arange(rows), all_chunks, assume_unique=False)
632636

633637

634-
def pyarrow_to_schema(schema: pa.Schema) -> Schema:
635-
visitor = _ConvertToIceberg()
638+
def pyarrow_to_schema(schema: pa.Schema, name_mapping: Optional[NameMapping] = None) -> Schema:
639+
has_ids = visit_pyarrow(schema, _HasIds())
640+
if has_ids:
641+
visitor = _ConvertToIceberg()
642+
elif name_mapping is not None:
643+
visitor = _ConvertToIceberg(name_mapping=name_mapping)
644+
else:
645+
raise ValueError(
646+
"Parquet file does not have field-ids and the Iceberg table does not have 'schema.name-mapping.default' defined"
647+
)
636648
return visit_pyarrow(schema, visitor)
637649

638650

@@ -653,50 +665,47 @@ def visit_pyarrow(obj: Union[pa.DataType, pa.Schema], visitor: PyArrowSchemaVisi
653665

654666

655667
@visit_pyarrow.register(pa.Schema)
656-
def _(obj: pa.Schema, visitor: PyArrowSchemaVisitor[T]) -> Optional[T]:
657-
struct_results: List[Optional[T]] = []
658-
for field in obj:
659-
visitor.before_field(field)
660-
struct_result = visit_pyarrow(field.type, visitor)
661-
visitor.after_field(field)
662-
struct_results.append(struct_result)
663-
664-
return visitor.schema(obj, struct_results)
668+
def _(obj: pa.Schema, visitor: PyArrowSchemaVisitor[T]) -> T:
669+
return visitor.schema(obj, visit_pyarrow(pa.struct(obj), visitor))
665670

666671

667672
@visit_pyarrow.register(pa.StructType)
668-
def _(obj: pa.StructType, visitor: PyArrowSchemaVisitor[T]) -> Optional[T]:
669-
struct_results: List[Optional[T]] = []
673+
def _(obj: pa.StructType, visitor: PyArrowSchemaVisitor[T]) -> T:
674+
results = []
675+
670676
for field in obj:
671677
visitor.before_field(field)
672-
struct_result = visit_pyarrow(field.type, visitor)
678+
result = visit_pyarrow(field.type, visitor)
679+
results.append(visitor.field(field, result))
673680
visitor.after_field(field)
674-
struct_results.append(struct_result)
675681

676-
return visitor.struct(obj, struct_results)
682+
return visitor.struct(obj, results)
677683

678684

679685
@visit_pyarrow.register(pa.ListType)
680-
def _(obj: pa.ListType, visitor: PyArrowSchemaVisitor[T]) -> Optional[T]:
681-
visitor.before_field(obj.value_field)
682-
list_result = visit_pyarrow(obj.value_field.type, visitor)
683-
visitor.after_field(obj.value_field)
684-
return visitor.list(obj, list_result)
686+
def _(obj: pa.ListType, visitor: PyArrowSchemaVisitor[T]) -> T:
687+
visitor.before_list_element(obj.value_field)
688+
result = visit_pyarrow(obj.value_type, visitor)
689+
visitor.after_list_element(obj.value_field)
690+
691+
return visitor.list(obj, result)
685692

686693

687694
@visit_pyarrow.register(pa.MapType)
688-
def _(obj: pa.MapType, visitor: PyArrowSchemaVisitor[T]) -> Optional[T]:
689-
visitor.before_field(obj.key_field)
690-
key_result = visit_pyarrow(obj.key_field.type, visitor)
691-
visitor.after_field(obj.key_field)
692-
visitor.before_field(obj.item_field)
693-
value_result = visit_pyarrow(obj.item_field.type, visitor)
694-
visitor.after_field(obj.item_field)
695+
def _(obj: pa.MapType, visitor: PyArrowSchemaVisitor[T]) -> T:
696+
visitor.before_map_key(obj.key_field)
697+
key_result = visit_pyarrow(obj.key_type, visitor)
698+
visitor.after_map_key(obj.key_field)
699+
700+
visitor.before_map_value(obj.item_field)
701+
value_result = visit_pyarrow(obj.item_type, visitor)
702+
visitor.after_map_value(obj.item_field)
703+
695704
return visitor.map(obj, key_result, value_result)
696705

697706

698707
@visit_pyarrow.register(pa.DataType)
699-
def _(obj: pa.DataType, visitor: PyArrowSchemaVisitor[T]) -> Optional[T]:
708+
def _(obj: pa.DataType, visitor: PyArrowSchemaVisitor[T]) -> T:
700709
if pa.types.is_nested(obj):
701710
raise TypeError(f"Expected primitive type, got: {type(obj)}")
702711
return visitor.primitive(obj)
@@ -709,24 +718,46 @@ def before_field(self, field: pa.Field) -> None:
709718
def after_field(self, field: pa.Field) -> None:
710719
"""Override this method to perform an action immediately after visiting a field."""
711720

721+
def before_list_element(self, element: pa.Field) -> None:
722+
"""Override this method to perform an action immediately before visiting an element within a ListType."""
723+
724+
def after_list_element(self, element: pa.Field) -> None:
725+
"""Override this method to perform an action immediately after visiting an element within a ListType."""
726+
727+
def before_map_key(self, key: pa.Field) -> None:
728+
"""Override this method to perform an action immediately before visiting a key within a MapType."""
729+
730+
def after_map_key(self, key: pa.Field) -> None:
731+
"""Override this method to perform an action immediately after visiting a key within a MapType."""
732+
733+
def before_map_value(self, value: pa.Field) -> None:
734+
"""Override this method to perform an action immediately before visiting a value within a MapType."""
735+
736+
def after_map_value(self, value: pa.Field) -> None:
737+
"""Override this method to perform an action immediately after visiting a value within a MapType."""
738+
712739
@abstractmethod
713-
def schema(self, schema: pa.Schema, field_results: List[Optional[T]]) -> Optional[T]:
740+
def schema(self, schema: pa.Schema, struct_result: T) -> T:
714741
"""Visit a schema."""
715742

716743
@abstractmethod
717-
def struct(self, struct: pa.StructType, field_results: List[Optional[T]]) -> Optional[T]:
744+
def struct(self, struct: pa.StructType, field_results: List[T]) -> T:
718745
"""Visit a struct."""
719746

720747
@abstractmethod
721-
def list(self, list_type: pa.ListType, element_result: Optional[T]) -> Optional[T]:
748+
def field(self, field: pa.Field, field_result: T) -> T:
749+
"""Visit a field."""
750+
751+
@abstractmethod
752+
def list(self, list_type: pa.ListType, element_result: T) -> T:
722753
"""Visit a list."""
723754

724755
@abstractmethod
725-
def map(self, map_type: pa.MapType, key_result: Optional[T], value_result: Optional[T]) -> Optional[T]:
756+
def map(self, map_type: pa.MapType, key_result: T, value_result: T) -> T:
726757
"""Visit a map."""
727758

728759
@abstractmethod
729-
def primitive(self, primitive: pa.DataType) -> Optional[T]:
760+
def primitive(self, primitive: pa.DataType) -> T:
730761
"""Visit a primitive type."""
731762

732763

@@ -738,42 +769,84 @@ def _get_field_id(field: pa.Field) -> Optional[int]:
738769
)
739770

740771

741-
class _ConvertToIceberg(PyArrowSchemaVisitor[Union[IcebergType, Schema]]):
742-
def _convert_fields(self, arrow_fields: Iterable[pa.Field], field_results: List[Optional[IcebergType]]) -> List[NestedField]:
743-
fields = []
744-
for i, field in enumerate(arrow_fields):
745-
field_id = _get_field_id(field)
746-
field_doc = doc_str.decode() if (field.metadata and (doc_str := field.metadata.get(PYARROW_FIELD_DOC_KEY))) else None
747-
field_type = field_results[i]
748-
if field_type is not None and field_id is not None:
749-
fields.append(NestedField(field_id, field.name, field_type, required=not field.nullable, doc=field_doc))
750-
return fields
751-
752-
def schema(self, schema: pa.Schema, field_results: List[Optional[IcebergType]]) -> Schema:
753-
return Schema(*self._convert_fields(schema, field_results))
754-
755-
def struct(self, struct: pa.StructType, field_results: List[Optional[IcebergType]]) -> IcebergType:
756-
return StructType(*self._convert_fields(struct, field_results))
757-
758-
def list(self, list_type: pa.ListType, element_result: Optional[IcebergType]) -> Optional[IcebergType]:
772+
class _HasIds(PyArrowSchemaVisitor[bool]):
773+
def schema(self, schema: pa.Schema, struct_result: bool) -> bool:
774+
return struct_result
775+
776+
def struct(self, struct: pa.StructType, field_results: List[bool]) -> bool:
777+
return all(field_results)
778+
779+
def field(self, field: pa.Field, field_result: bool) -> bool:
780+
return all([_get_field_id(field) is not None, field_result])
781+
782+
def list(self, list_type: pa.ListType, element_result: bool) -> bool:
759783
element_field = list_type.value_field
760784
element_id = _get_field_id(element_field)
761-
if element_result is not None and element_id is not None:
762-
return ListType(element_id, element_result, element_required=not element_field.nullable)
763-
return None
785+
return element_result and element_id is not None
764786

765-
def map(
766-
self, map_type: pa.MapType, key_result: Optional[IcebergType], value_result: Optional[IcebergType]
767-
) -> Optional[IcebergType]:
787+
def map(self, map_type: pa.MapType, key_result: bool, value_result: bool) -> bool:
768788
key_field = map_type.key_field
769789
key_id = _get_field_id(key_field)
770790
value_field = map_type.item_field
771791
value_id = _get_field_id(value_field)
772-
if key_result is not None and value_result is not None and key_id is not None and value_id is not None:
773-
return MapType(key_id, key_result, value_id, value_result, value_required=not value_field.nullable)
774-
return None
792+
return all([key_id is not None, value_id is not None, key_result, value_result])
793+
794+
def primitive(self, primitive: pa.DataType) -> bool:
795+
return True
775796

776-
def primitive(self, primitive: pa.DataType) -> IcebergType:
797+
798+
class _ConvertToIceberg(PyArrowSchemaVisitor[Union[IcebergType, Schema]]):
799+
"""Converts PyArrowSchema to Iceberg Schema. Applies the IDs from name_mapping if provided."""
800+
801+
_field_names: List[str]
802+
_name_mapping: Optional[NameMapping]
803+
804+
def __init__(self, name_mapping: Optional[NameMapping] = None) -> None:
805+
self._field_names = []
806+
self._name_mapping = name_mapping
807+
808+
def _current_path(self) -> str:
809+
return ".".join(self._field_names)
810+
811+
def _field_id(self, field: pa.Field) -> int:
812+
if self._name_mapping:
813+
return self._name_mapping.find(self._current_path()).field_id
814+
elif (field_id := _get_field_id(field)) is not None:
815+
return field_id
816+
else:
817+
raise ValueError(f"Cannot convert {field} to Iceberg Field as field_id is empty.")
818+
819+
def schema(self, schema: pa.Schema, struct_result: StructType) -> Schema:
820+
return Schema(*struct_result.fields)
821+
822+
def struct(self, struct: pa.StructType, field_results: List[NestedField]) -> StructType:
823+
return StructType(*field_results)
824+
825+
def field(self, field: pa.Field, field_result: IcebergType) -> NestedField:
826+
field_id = self._field_id(field)
827+
field_doc = doc_str.decode() if (field.metadata and (doc_str := field.metadata.get(PYARROW_FIELD_DOC_KEY))) else None
828+
field_type = field_result
829+
return NestedField(field_id, field.name, field_type, required=not field.nullable, doc=field_doc)
830+
831+
def list(self, list_type: pa.ListType, element_result: IcebergType) -> ListType:
832+
element_field = list_type.value_field
833+
self._field_names.append(LIST_ELEMENT_NAME)
834+
element_id = self._field_id(element_field)
835+
self._field_names.pop()
836+
return ListType(element_id, element_result, element_required=not element_field.nullable)
837+
838+
def map(self, map_type: pa.MapType, key_result: IcebergType, value_result: IcebergType) -> MapType:
839+
key_field = map_type.key_field
840+
self._field_names.append(MAP_KEY_NAME)
841+
key_id = self._field_id(key_field)
842+
self._field_names.pop()
843+
value_field = map_type.item_field
844+
self._field_names.append(MAP_VALUE_NAME)
845+
value_id = self._field_id(value_field)
846+
self._field_names.pop()
847+
return MapType(key_id, key_result, value_id, value_result, value_required=not value_field.nullable)
848+
849+
def primitive(self, primitive: pa.DataType) -> PrimitiveType:
777850
if pa.types.is_boolean(primitive):
778851
return BooleanType()
779852
elif pa.types.is_int32(primitive):
@@ -808,6 +881,30 @@ def primitive(self, primitive: pa.DataType) -> IcebergType:
808881

809882
raise TypeError(f"Unsupported type: {primitive}")
810883

884+
def before_field(self, field: pa.Field) -> None:
885+
self._field_names.append(field.name)
886+
887+
def after_field(self, field: pa.Field) -> None:
888+
self._field_names.pop()
889+
890+
def before_list_element(self, element: pa.Field) -> None:
891+
self._field_names.append(LIST_ELEMENT_NAME)
892+
893+
def after_list_element(self, element: pa.Field) -> None:
894+
self._field_names.pop()
895+
896+
def before_map_key(self, key: pa.Field) -> None:
897+
self._field_names.append(MAP_KEY_NAME)
898+
899+
def after_map_key(self, element: pa.Field) -> None:
900+
self._field_names.pop()
901+
902+
def before_map_value(self, value: pa.Field) -> None:
903+
self._field_names.append(MAP_VALUE_NAME)
904+
905+
def after_map_value(self, element: pa.Field) -> None:
906+
self._field_names.pop()
907+
811908

812909
def _task_to_table(
813910
fs: FileSystem,
@@ -819,6 +916,7 @@ def _task_to_table(
819916
case_sensitive: bool,
820917
row_counts: List[int],
821918
limit: Optional[int] = None,
919+
name_mapping: Optional[NameMapping] = None,
822920
) -> Optional[pa.Table]:
823921
if limit and sum(row_counts) >= limit:
824922
return None
@@ -831,9 +929,9 @@ def _task_to_table(
831929
schema_raw = None
832930
if metadata := physical_schema.metadata:
833931
schema_raw = metadata.get(ICEBERG_SCHEMA)
834-
# TODO: if field_ids are not present, Name Mapping should be implemented to look them up in the table schema,
835-
# see https://github.com/apache/iceberg/issues/7451
836-
file_schema = Schema.model_validate_json(schema_raw) if schema_raw is not None else pyarrow_to_schema(physical_schema)
932+
file_schema = (
933+
Schema.model_validate_json(schema_raw) if schema_raw is not None else pyarrow_to_schema(physical_schema, name_mapping)
934+
)
837935

838936
pyarrow_filter = None
839937
if bound_row_filter is not AlwaysTrue():
@@ -970,6 +1068,7 @@ def project_table(
9701068
case_sensitive,
9711069
row_counts,
9721070
limit,
1071+
table.name_mapping(),
9731072
)
9741073
for task in tasks
9751074
]

pyiceberg/table/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -81,6 +81,12 @@
8181
TableMetadata,
8282
TableMetadataUtil,
8383
)
84+
from pyiceberg.table.name_mapping import (
85+
SCHEMA_NAME_MAPPING_DEFAULT,
86+
NameMapping,
87+
create_mapping_from_schema,
88+
parse_mapping_from_json,
89+
)
8490
from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef
8591
from pyiceberg.table.snapshots import (
8692
Operation,
@@ -909,6 +915,13 @@ def history(self) -> List[SnapshotLogEntry]:
909915
def update_schema(self, allow_incompatible_changes: bool = False, case_sensitive: bool = True) -> UpdateSchema:
910916
return UpdateSchema(self, allow_incompatible_changes=allow_incompatible_changes, case_sensitive=case_sensitive)
911917

918+
def name_mapping(self) -> NameMapping:
919+
"""Return the table's field-id NameMapping."""
920+
if name_mapping_json := self.properties.get(SCHEMA_NAME_MAPPING_DEFAULT):
921+
return parse_mapping_from_json(name_mapping_json)
922+
else:
923+
return create_mapping_from_schema(self.schema())
924+
912925
def append(self, df: pa.Table) -> None:
913926
"""
914927
Append data to the table.

pyiceberg/table/name_mapping.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -34,6 +34,8 @@
3434
from pyiceberg.typedef import IcebergBaseModel, IcebergRootModel
3535
from pyiceberg.types import ListType, MapType, NestedField, PrimitiveType, StructType
3636

37+
SCHEMA_NAME_MAPPING_DEFAULT = "schema.name-mapping.default"
38+
3739

3840
class MappedField(IcebergBaseModel):
3941
field_id: int = Field(alias="field-id")

0 commit comments

Comments
 (0)