     visit_with_partner,
 )
 from pyiceberg.table import WriteTask
+from pyiceberg.table.name_mapping import NameMapping
 from pyiceberg.transforms import TruncateTransform
 from pyiceberg.typedef import EMPTY_DICT, Properties, Record
 from pyiceberg.types import (

 # The PARQUET: in front means that it is Parquet specific, in this case the field_id
 PYARROW_PARQUET_FIELD_ID_KEY = b"PARQUET:field_id"
 PYARROW_FIELD_DOC_KEY = b"doc"
+LIST_ELEMENT_NAME = "element"
+MAP_KEY_NAME = "key"
+MAP_VALUE_NAME = "value"

 T = TypeVar("T")

@@ -631,8 +635,16 @@ def _combine_positional_deletes(positional_deletes: List[pa.ChunkedArray], rows:
     return np.setdiff1d(np.arange(rows), all_chunks, assume_unique=False)


-def pyarrow_to_schema(schema: pa.Schema) -> Schema:
-    visitor = _ConvertToIceberg()
+def pyarrow_to_schema(schema: pa.Schema, name_mapping: Optional[NameMapping] = None) -> Schema:
+    has_ids = visit_pyarrow(schema, _HasIds())
+    if has_ids:
+        visitor = _ConvertToIceberg()
+    elif name_mapping is not None:
+        visitor = _ConvertToIceberg(name_mapping=name_mapping)
+    else:
+        raise ValueError(
+            "Parquet file does not have field-ids and the Iceberg table does not have 'schema.name-mapping.default' defined"
+        )
     return visit_pyarrow(schema, visitor)

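For orientation, a minimal sketch of how the reworked entry point behaves (field names and values below are made up for illustration): when the Arrow fields carry Parquet field-id metadata the conversion proceeds as before, otherwise a name mapping must be supplied.

import pyarrow as pa

from pyiceberg.io.pyarrow import pyarrow_to_schema

# Fields carrying Parquet field-ids in their metadata: _HasIds reports True and
# no name mapping is needed.
with_ids = pa.schema([
    pa.field("id", pa.int32(), nullable=False, metadata={"PARQUET:field_id": "1"}),
    pa.field("data", pa.string(), metadata={"PARQUET:field_id": "2"}),
])
iceberg_schema = pyarrow_to_schema(with_ids)

# Without field-ids, the same call raises ValueError unless a NameMapping is
# passed, e.g. the one returned by table.name_mapping() for an already-loaded
# Iceberg table (see project_table further down):
# pyarrow_to_schema(pa.schema([pa.field("id", pa.int32())]), table.name_mapping())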
@@ -653,50 +665,47 @@ def visit_pyarrow(obj: Union[pa.DataType, pa.Schema], visitor: PyArrowSchemaVisitor


 @visit_pyarrow.register(pa.Schema)
-def _(obj: pa.Schema, visitor: PyArrowSchemaVisitor[T]) -> Optional[T]:
-    struct_results: List[Optional[T]] = []
-    for field in obj:
-        visitor.before_field(field)
-        struct_result = visit_pyarrow(field.type, visitor)
-        visitor.after_field(field)
-        struct_results.append(struct_result)
-
-    return visitor.schema(obj, struct_results)
+def _(obj: pa.Schema, visitor: PyArrowSchemaVisitor[T]) -> T:
+    return visitor.schema(obj, visit_pyarrow(pa.struct(obj), visitor))


 @visit_pyarrow.register(pa.StructType)
-def _(obj: pa.StructType, visitor: PyArrowSchemaVisitor[T]) -> Optional[T]:
-    struct_results: List[Optional[T]] = []
+def _(obj: pa.StructType, visitor: PyArrowSchemaVisitor[T]) -> T:
+    results = []
+
     for field in obj:
         visitor.before_field(field)
-        struct_result = visit_pyarrow(field.type, visitor)
+        result = visit_pyarrow(field.type, visitor)
+        results.append(visitor.field(field, result))
         visitor.after_field(field)
-        struct_results.append(struct_result)

-    return visitor.struct(obj, struct_results)
+    return visitor.struct(obj, results)


 @visit_pyarrow.register(pa.ListType)
-def _(obj: pa.ListType, visitor: PyArrowSchemaVisitor[T]) -> Optional[T]:
-    visitor.before_field(obj.value_field)
-    list_result = visit_pyarrow(obj.value_field.type, visitor)
-    visitor.after_field(obj.value_field)
-    return visitor.list(obj, list_result)
+def _(obj: pa.ListType, visitor: PyArrowSchemaVisitor[T]) -> T:
+    visitor.before_list_element(obj.value_field)
+    result = visit_pyarrow(obj.value_type, visitor)
+    visitor.after_list_element(obj.value_field)
+
+    return visitor.list(obj, result)


 @visit_pyarrow.register(pa.MapType)
-def _(obj: pa.MapType, visitor: PyArrowSchemaVisitor[T]) -> Optional[T]:
-    visitor.before_field(obj.key_field)
-    key_result = visit_pyarrow(obj.key_field.type, visitor)
-    visitor.after_field(obj.key_field)
-    visitor.before_field(obj.item_field)
-    value_result = visit_pyarrow(obj.item_field.type, visitor)
-    visitor.after_field(obj.item_field)
+def _(obj: pa.MapType, visitor: PyArrowSchemaVisitor[T]) -> T:
+    visitor.before_map_key(obj.key_field)
+    key_result = visit_pyarrow(obj.key_type, visitor)
+    visitor.after_map_key(obj.key_field)
+
+    visitor.before_map_value(obj.item_field)
+    value_result = visit_pyarrow(obj.item_type, visitor)
+    visitor.after_map_value(obj.item_field)
+
     return visitor.map(obj, key_result, value_result)


 @visit_pyarrow.register(pa.DataType)
-def _(obj: pa.DataType, visitor: PyArrowSchemaVisitor[T]) -> Optional[T]:
+def _(obj: pa.DataType, visitor: PyArrowSchemaVisitor[T]) -> T:
     if pa.types.is_nested(obj):
         raise TypeError(f"Expected primitive type, got: {type(obj)}")
     return visitor.primitive(obj)
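The schema dispatch above now simply wraps the schema in a struct and reuses the StructType path; a quick illustration of why that is equivalent (schema contents are arbitrary):

import pyarrow as pa

# Iterating a pa.Schema yields its fields, and pa.struct() accepts any iterable
# of fields, so visiting pa.struct(obj) walks exactly the schema's top-level fields.
s = pa.schema([pa.field("id", pa.int32()), pa.field("data", pa.string())])
assert pa.struct(s) == pa.struct([pa.field("id", pa.int32()), pa.field("data", pa.string())])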
@@ -709,24 +718,46 @@ def before_field(self, field: pa.Field) -> None:
     def after_field(self, field: pa.Field) -> None:
         """Override this method to perform an action immediately after visiting a field."""

+    def before_list_element(self, element: pa.Field) -> None:
+        """Override this method to perform an action immediately before visiting an element within a ListType."""
+
+    def after_list_element(self, element: pa.Field) -> None:
+        """Override this method to perform an action immediately after visiting an element within a ListType."""
+
+    def before_map_key(self, key: pa.Field) -> None:
+        """Override this method to perform an action immediately before visiting a key within a MapType."""
+
+    def after_map_key(self, key: pa.Field) -> None:
+        """Override this method to perform an action immediately after visiting a key within a MapType."""
+
+    def before_map_value(self, value: pa.Field) -> None:
+        """Override this method to perform an action immediately before visiting a value within a MapType."""
+
+    def after_map_value(self, value: pa.Field) -> None:
+        """Override this method to perform an action immediately after visiting a value within a MapType."""
+
     @abstractmethod
-    def schema(self, schema: pa.Schema, field_results: List[Optional[T]]) -> Optional[T]:
+    def schema(self, schema: pa.Schema, struct_result: T) -> T:
         """Visit a schema."""

     @abstractmethod
-    def struct(self, struct: pa.StructType, field_results: List[Optional[T]]) -> Optional[T]:
+    def struct(self, struct: pa.StructType, field_results: List[T]) -> T:
         """Visit a struct."""

     @abstractmethod
-    def list(self, list_type: pa.ListType, element_result: Optional[T]) -> Optional[T]:
+    def field(self, field: pa.Field, field_result: T) -> T:
+        """Visit a field."""
+
+    @abstractmethod
+    def list(self, list_type: pa.ListType, element_result: T) -> T:
         """Visit a list."""

     @abstractmethod
-    def map(self, map_type: pa.MapType, key_result: Optional[T], value_result: Optional[T]) -> Optional[T]:
+    def map(self, map_type: pa.MapType, key_result: T, value_result: T) -> T:
         """Visit a map."""

     @abstractmethod
-    def primitive(self, primitive: pa.DataType) -> Optional[T]:
+    def primitive(self, primitive: pa.DataType) -> T:
         """Visit a primitive type."""

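To make the updated contract concrete, here is a hypothetical visitor (not part of this change) that counts primitive leaves: field() wraps each child result, struct() folds the per-field results, and schema() forwards the struct result.

# Hypothetical example visitor, reusing the module's imports (pa, List) and the
# PyArrowSchemaVisitor base class defined above.
class _CountLeaves(PyArrowSchemaVisitor[int]):
    def schema(self, schema: pa.Schema, struct_result: int) -> int:
        return struct_result

    def struct(self, struct: pa.StructType, field_results: List[int]) -> int:
        return sum(field_results)

    def field(self, field: pa.Field, field_result: int) -> int:
        return field_result

    def list(self, list_type: pa.ListType, element_result: int) -> int:
        return element_result

    def map(self, map_type: pa.MapType, key_result: int, value_result: int) -> int:
        return key_result + value_result

    def primitive(self, primitive: pa.DataType) -> int:
        return 1

# visit_pyarrow(pa.schema([pa.field("tags", pa.list_(pa.string()))]), _CountLeaves()) == 1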
@@ -738,42 +769,84 @@ def _get_field_id(field: pa.Field) -> Optional[int]:
     )


-class _ConvertToIceberg(PyArrowSchemaVisitor[Union[IcebergType, Schema]]):
-    def _convert_fields(self, arrow_fields: Iterable[pa.Field], field_results: List[Optional[IcebergType]]) -> List[NestedField]:
-        fields = []
-        for i, field in enumerate(arrow_fields):
-            field_id = _get_field_id(field)
-            field_doc = doc_str.decode() if (field.metadata and (doc_str := field.metadata.get(PYARROW_FIELD_DOC_KEY))) else None
-            field_type = field_results[i]
-            if field_type is not None and field_id is not None:
-                fields.append(NestedField(field_id, field.name, field_type, required=not field.nullable, doc=field_doc))
-        return fields
-
-    def schema(self, schema: pa.Schema, field_results: List[Optional[IcebergType]]) -> Schema:
-        return Schema(*self._convert_fields(schema, field_results))
-
-    def struct(self, struct: pa.StructType, field_results: List[Optional[IcebergType]]) -> IcebergType:
-        return StructType(*self._convert_fields(struct, field_results))
-
-    def list(self, list_type: pa.ListType, element_result: Optional[IcebergType]) -> Optional[IcebergType]:
+class _HasIds(PyArrowSchemaVisitor[bool]):
+    def schema(self, schema: pa.Schema, struct_result: bool) -> bool:
+        return struct_result
+
+    def struct(self, struct: pa.StructType, field_results: List[bool]) -> bool:
+        return all(field_results)
+
+    def field(self, field: pa.Field, field_result: bool) -> bool:
+        return all([_get_field_id(field) is not None, field_result])
+
+    def list(self, list_type: pa.ListType, element_result: bool) -> bool:
         element_field = list_type.value_field
         element_id = _get_field_id(element_field)
-        if element_result is not None and element_id is not None:
-            return ListType(element_id, element_result, element_required=not element_field.nullable)
-        return None
+        return element_result and element_id is not None

-    def map(
-        self, map_type: pa.MapType, key_result: Optional[IcebergType], value_result: Optional[IcebergType]
-    ) -> Optional[IcebergType]:
+    def map(self, map_type: pa.MapType, key_result: bool, value_result: bool) -> bool:
         key_field = map_type.key_field
         key_id = _get_field_id(key_field)
         value_field = map_type.item_field
         value_id = _get_field_id(value_field)
-        if key_result is not None and value_result is not None and key_id is not None and value_id is not None:
-            return MapType(key_id, key_result, value_id, value_result, value_required=not value_field.nullable)
-        return None
+        return all([key_id is not None, value_id is not None, key_result, value_result])
+
+    def primitive(self, primitive: pa.DataType) -> bool:
+        return True

-    def primitive(self, primitive: pa.DataType) -> IcebergType:
+
+class _ConvertToIceberg(PyArrowSchemaVisitor[Union[IcebergType, Schema]]):
+    """Converts PyArrowSchema to Iceberg Schema. Applies the IDs from name_mapping if provided."""
+
+    _field_names: List[str]
+    _name_mapping: Optional[NameMapping]
+
+    def __init__(self, name_mapping: Optional[NameMapping] = None) -> None:
+        self._field_names = []
+        self._name_mapping = name_mapping
+
+    def _current_path(self) -> str:
+        return ".".join(self._field_names)
+
+    def _field_id(self, field: pa.Field) -> int:
+        if self._name_mapping:
+            return self._name_mapping.find(self._current_path()).field_id
+        elif (field_id := _get_field_id(field)) is not None:
+            return field_id
+        else:
+            raise ValueError(f"Cannot convert {field} to Iceberg Field as field_id is empty.")
+
+    def schema(self, schema: pa.Schema, struct_result: StructType) -> Schema:
+        return Schema(*struct_result.fields)
+
+    def struct(self, struct: pa.StructType, field_results: List[NestedField]) -> StructType:
+        return StructType(*field_results)
+
+    def field(self, field: pa.Field, field_result: IcebergType) -> NestedField:
+        field_id = self._field_id(field)
+        field_doc = doc_str.decode() if (field.metadata and (doc_str := field.metadata.get(PYARROW_FIELD_DOC_KEY))) else None
+        field_type = field_result
+        return NestedField(field_id, field.name, field_type, required=not field.nullable, doc=field_doc)
+
+    def list(self, list_type: pa.ListType, element_result: IcebergType) -> ListType:
+        element_field = list_type.value_field
+        self._field_names.append(LIST_ELEMENT_NAME)
+        element_id = self._field_id(element_field)
+        self._field_names.pop()
+        return ListType(element_id, element_result, element_required=not element_field.nullable)
+
+    def map(self, map_type: pa.MapType, key_result: IcebergType, value_result: IcebergType) -> MapType:
+        key_field = map_type.key_field
+        self._field_names.append(MAP_KEY_NAME)
+        key_id = self._field_id(key_field)
+        self._field_names.pop()
+        value_field = map_type.item_field
+        self._field_names.append(MAP_VALUE_NAME)
+        value_id = self._field_id(value_field)
+        self._field_names.pop()
+        return MapType(key_id, key_result, value_id, value_result, value_required=not value_field.nullable)
+
+    def primitive(self, primitive: pa.DataType) -> PrimitiveType:
         if pa.types.is_boolean(primitive):
             return BooleanType()
         elif pa.types.is_int32(primitive):
@@ -808,6 +881,30 @@ def primitive(self, primitive: pa.DataType) -> IcebergType:

         raise TypeError(f"Unsupported type: {primitive}")

+    def before_field(self, field: pa.Field) -> None:
+        self._field_names.append(field.name)
+
+    def after_field(self, field: pa.Field) -> None:
+        self._field_names.pop()
+
+    def before_list_element(self, element: pa.Field) -> None:
+        self._field_names.append(LIST_ELEMENT_NAME)
+
+    def after_list_element(self, element: pa.Field) -> None:
+        self._field_names.pop()
+
+    def before_map_key(self, key: pa.Field) -> None:
+        self._field_names.append(MAP_KEY_NAME)
+
+    def after_map_key(self, element: pa.Field) -> None:
+        self._field_names.pop()
+
+    def before_map_value(self, value: pa.Field) -> None:
+        self._field_names.append(MAP_VALUE_NAME)
+
+    def after_map_value(self, element: pa.Field) -> None:
+        self._field_names.pop()
+

 def _task_to_table(
     fs: FileSystem,
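The hooks above maintain a dotted path while _ConvertToIceberg walks the Arrow types, and _field_id() uses that path to query the name mapping. For a nested (made-up) schema the lookups work out as follows:

import pyarrow as pa

# Paths produced by _current_path() at each lookup:
#   "tags"         - the list field itself (via field())
#   "tags.element" - the list element id (via list())
#   "props"        - the map field itself
#   "props.key"    - the map key id
#   "props.value"  - the map value id
nested = pa.schema([
    pa.field("tags", pa.list_(pa.string())),
    pa.field("props", pa.map_(pa.string(), pa.int64())),
])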
@@ -819,6 +916,7 @@ def _task_to_table(
     case_sensitive: bool,
     row_counts: List[int],
     limit: Optional[int] = None,
+    name_mapping: Optional[NameMapping] = None,
 ) -> Optional[pa.Table]:
     if limit and sum(row_counts) >= limit:
         return None
@@ -831,9 +929,9 @@ def _task_to_table(
         schema_raw = None
         if metadata := physical_schema.metadata:
             schema_raw = metadata.get(ICEBERG_SCHEMA)
-        # TODO: if field_ids are not present, Name Mapping should be implemented to look them up in the table schema,
-        # see https://github.com/apache/iceberg/issues/7451
-        file_schema = Schema.model_validate_json(schema_raw) if schema_raw is not None else pyarrow_to_schema(physical_schema)
+        file_schema = (
+            Schema.model_validate_json(schema_raw) if schema_raw is not None else pyarrow_to_schema(physical_schema, name_mapping)
+        )

         pyarrow_filter = None
         if bound_row_filter is not AlwaysTrue():
@@ -970,6 +1068,7 @@ def project_table(
             case_sensitive,
             row_counts,
             limit,
+            table.name_mapping(),
        )
        for task in tasks
    ]
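For reference, the mapping passed here by table.name_mapping() is typically read from the 'schema.name-mapping.default' table property, which stores the spec's name-mapping JSON; a minimal illustrative value (field names made up):

# Illustrative property value; table.name_mapping() is expected to parse this JSON.
name_mapping_json = '[{"field-id": 1, "names": ["id"]}, {"field-id": 2, "names": ["data"]}]'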