Skip to content

Arrow: Allow missing field-ids from Schema #183

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 47 additions & 18 deletions pyiceberg/io/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,11 @@
from __future__ import annotations

import concurrent.futures
import itertools
import logging
import os
import re
import warnings
from abc import ABC, abstractmethod
from concurrent.futures import Future
from dataclasses import dataclass
Expand Down Expand Up @@ -110,6 +112,7 @@
Schema,
SchemaVisitorPerPrimitiveType,
SchemaWithPartnerVisitor,
assign_fresh_schema_ids,
pre_order_visit,
promote,
prune_columns,
Expand Down Expand Up @@ -616,7 +619,12 @@ def _combine_positional_deletes(positional_deletes: List[pa.ChunkedArray], rows:

def pyarrow_to_schema(schema: pa.Schema) -> Schema:
    """Convert a PyArrow schema into an Iceberg schema.

    When the Arrow fields carry no Iceberg field-id metadata, the visitor
    hands out provisional ids; in that case the converted schema is
    re-numbered with fresh, consistent field-ids before being returned.
    """
    visitor = _ConvertToIceberg()
    converted = visit_pyarrow(schema, visitor)
    return assign_fresh_schema_ids(converted) if visitor.missing_id_metadata else converted


@singledispatch
Expand Down Expand Up @@ -713,28 +721,49 @@ def primitive(self, primitive: pa.DataType) -> Optional[T]:
"""Visit a primitive type."""


def _get_field_id(field: pa.Field) -> Optional[int]:
    """Return the Iceberg field-id stored in *field*'s metadata, or None.

    Checks each known metadata key in order and returns the first id found.
    """
    # field.metadata is None when the field carries no metadata at all;
    # guard against that before calling .get() on it.
    if field.metadata:
        for pyarrow_field_id_key in PYARROW_FIELD_ID_KEYS:
            if field_id_str := field.metadata.get(pyarrow_field_id_key):
                return int(field_id_str.decode())
    return None
class _ConvertToIceberg(PyArrowSchemaVisitor[Union[IcebergType, Schema]]):
counter: itertools.count[int]
missing_id_metadata: Optional[bool]

def __init__(self) -> None:
    """Initialize the visitor with a fresh id sequence and an unset flag."""
    # Sequence of fallback ids, consumed only when metadata lacks field-ids.
    self.counter = itertools.count(start=1)
    # Tri-state: None until the first field is seen, then True/False
    # recording whether ids were missing from the metadata.
    self.missing_id_metadata = None

def _get_field_doc(field: pa.Field) -> Optional[str]:
    """Return the doc string stored in *field*'s metadata, or None.

    Checks each known metadata key in order and returns the first doc found.
    """
    # field.metadata is None when the field has no metadata at all;
    # guard against that before calling .get() on it.
    if field.metadata:
        for pyarrow_doc_key in PYARROW_FIELD_DOC_KEYS:
            if doc_str := field.metadata.get(pyarrow_doc_key):
                return doc_str.decode()
    return None
def _get_field_id(self, field: pa.Field) -> int:
    """Return the Iceberg field-id for *field*.

    Reads the id from the field's metadata when present; otherwise assigns
    the next id from the visitor's counter. A schema must be consistent:
    either every field carries an id or none does.

    Raises:
        ValueError: if some fields carry ids and others do not.
    """
    field_id: Optional[int] = None

    if field.metadata:
        for pyarrow_field_id_key in PYARROW_FIELD_ID_KEYS:
            if field_id_str := field.metadata.get(pyarrow_field_id_key):
                # First matching key wins, consistent with _get_field_doc.
                field_id = int(field_id_str.decode())
                break

    if field_id is None:
        # Warn only once, on the first field encountered without an id.
        if self.missing_id_metadata is None:
            warnings.warn("Field-ids are missing, new IDs will be set")
        field_id = next(self.counter)
        id_was_missing = True
    else:
        id_was_missing = False

    # A mix of fields with and without ids is ambiguous: reject it.
    if self.missing_id_metadata is not None and self.missing_id_metadata != id_was_missing:
        raise ValueError("Parquet file contains partial field-ids")
    self.missing_id_metadata = id_was_missing

    return field_id

def _get_field_doc(self, field: pa.Field) -> Optional[str]:
    """Return the field's doc string from its metadata, or None when absent."""
    metadata = field.metadata
    if not metadata:
        return None
    for doc_key in PYARROW_FIELD_DOC_KEYS:
        raw_doc = metadata.get(doc_key)
        if raw_doc:
            return raw_doc.decode()
    return None

class _ConvertToIceberg(PyArrowSchemaVisitor[Union[IcebergType, Schema]]):
def _convert_fields(self, arrow_fields: Iterable[pa.Field], field_results: List[Optional[IcebergType]]) -> List[NestedField]:
    """Build NestedFields from Arrow fields and their converted types.

    Fields whose type could not be converted (None result) are dropped.
    """
    converted: List[NestedField] = []
    for index, arrow_field in enumerate(arrow_fields):
        # Resolve id and doc before inspecting the type: id assignment and
        # missing-id consistency tracking must run for every field, even
        # ones that end up being skipped.
        field_id = self._get_field_id(arrow_field)
        field_doc = self._get_field_doc(arrow_field)
        iceberg_type = field_results[index]
        if iceberg_type is None:
            continue
        converted.append(
            NestedField(field_id, arrow_field.name, iceberg_type, required=not arrow_field.nullable, doc=field_doc)
        )
    return converted

Expand All @@ -746,7 +775,7 @@ def struct(self, struct: pa.StructType, field_results: List[Optional[IcebergType

def list(self, list_type: pa.ListType, element_result: Optional[IcebergType]) -> Optional[IcebergType]:
    """Convert an Arrow list type to an Iceberg ListType, or None on failure."""
    element_field = list_type.value_field
    # Resolve the id unconditionally so id bookkeeping runs even when the
    # element type could not be converted.
    element_id = self._get_field_id(element_field)
    if element_result is None or element_id is None:
        return None
    return ListType(element_id, element_result, element_required=not element_field.nullable)
Expand All @@ -755,9 +784,9 @@ def map(
self, map_type: pa.MapType, key_result: Optional[IcebergType], value_result: Optional[IcebergType]
) -> Optional[IcebergType]:
key_field = map_type.key_field
key_id = _get_field_id(key_field)
key_id = self._get_field_id(key_field)
value_field = map_type.item_field
value_id = _get_field_id(value_field)
value_id = self._get_field_id(value_field)
if key_result is not None and value_result is not None and key_id is not None and value_id is not None:
return MapType(key_id, key_result, value_id, value_result, value_required=not value_field.nullable)
return None
Expand Down
30 changes: 30 additions & 0 deletions tests/io/test_pyarrow_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
# pylint: disable=protected-access,unused-argument,redefined-outer-name
import re
from unittest.mock import Mock, patch

import pyarrow as pa
import pytest
Expand Down Expand Up @@ -269,3 +270,32 @@ def test_round_schema_conversion_nested(table_schema_nested: Schema) -> None:
15: person: optional struct<16: name: optional string, 17: age: required int>
}"""
assert actual == expected


@patch("warnings.warn")
def test_schema_to_pyarrow_schema_missing_ids(warn: Mock) -> None:
    """An id-less Arrow schema converts with freshly assigned ids and a warning."""
    arrow_schema = pa.schema(
        [
            pa.field('some_int', pa.int32(), nullable=True),
            pa.field('some_string', pa.string(), nullable=False),
        ]
    )

    result = pyarrow_to_schema(arrow_schema)

    assert result == Schema(
        NestedField(field_id=1, name="some_int", field_type=IntegerType(), required=False),
        NestedField(field_id=2, name="some_string", field_type=StringType(), required=True),
    )
    assert warn.called


@patch("warnings.warn")
def test_schema_to_pyarrow_schema_missing_id(warn: Mock) -> None:
    """A schema where only some fields carry field-ids is rejected as partial."""
    arrow_schema = pa.schema(
        [
            pa.field('some_int', pa.int32(), nullable=True),
            pa.field('some_string', pa.string(), nullable=False, metadata={b"field_id": "22"}),
        ]
    )

    with pytest.raises(ValueError) as exc_info:
        _ = pyarrow_to_schema(arrow_schema)

    assert "Parquet file contains partial field-ids" in str(exc_info.value)
    assert warn.called