Skip to content

Commit 0a0e829

Browse files
committed
Add name-mapping
All the things to (de)serialize the name-mapping, and all the necessary visitors and such
1 parent 2bd8cf2 commit 0a0e829

File tree

2 files changed

+495
-0
lines changed

2 files changed

+495
-0
lines changed

pyiceberg/table/name_mapping.py

Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""
18+
Contains everything around the name mapping.
19+
20+
More information can be found here:
21+
https://iceberg.apache.org/spec/#name-mapping-serialization
22+
"""
23+
from __future__ import annotations
24+
25+
from abc import ABC, abstractmethod
26+
from collections import ChainMap
27+
from functools import cached_property, singledispatch
28+
from typing import Any, Dict, Generic, List, Set, TypeVar, Union
29+
30+
from pydantic import Field, conset, model_serializer
31+
32+
from pyiceberg.schema import Schema, SchemaVisitor, visit
33+
from pyiceberg.typedef import IcebergBaseModel, IcebergRootModel
34+
from pyiceberg.types import ListType, MapType, NestedField, PrimitiveType, StructType
35+
36+
37+
class MappedField(IcebergBaseModel):
    """A single entry of a name mapping.

    Ties a field ID to the set of names that refer to it, together with the
    mapped fields nested underneath it.
    """

    field_id: int = Field(alias="field-id")
    # NOTE(review): conset(...) is assigned as the field default rather than used
    # in the annotation -- confirm the min_length=1 constraint is actually enforced.
    names: Set[str] = conset(str, min_length=1)
    fields: List[MappedField] = Field(default_factory=list)

    @model_serializer
    def ser_model(self) -> Dict[str, Any]:
        """Set custom serializer to leave out the field when it is empty.

        Returns:
            A dict with 'field-id', the sorted 'names', and 'fields' only when
            there is at least one nested field.
        """
        serialized: Dict[str, Any] = {
            'field-id': self.field_id,
            # Sort the names to give a consistent output in json. The set itself
            # must be sorted: wrapping it in a list would emit [{...}] instead.
            'names': sorted(self.names),
        }
        if self.fields:
            serialized['fields'] = self.fields
        return serialized

    def __len__(self) -> int:
        """Return the number of fields."""
        return len(self.fields)

    def __str__(self) -> str:
        """Convert the mapped-field into a nicely formatted string."""
        nested = ", ".join(str(child) for child in self.fields)
        fields_str = " " + nested if nested else ""
        # Sort the names, otherwise the UTs fail because the order of the set can change
        names_str = ", ".join(sorted(self.names))
        return f"([{names_str}] -> {self.field_id}{fields_str})"
63+
64+
65+
class NameMapping(IcebergRootModel[List[MappedField]]):
    """Root model wrapping the top-level list of mapped fields.

    Lookups by name and by field ID are served from indexes that are built
    on first use and then cached on the instance.
    """

    root: List[MappedField]

    @cached_property
    def _field_by_id(self) -> Dict[int, MappedField]:
        indexer = _IndexById()
        return visit_name_mapping(self, indexer)

    @cached_property
    def _field_by_name(self) -> Dict[str, MappedField]:
        indexer = _IndexByName()
        return visit_name_mapping(self, indexer)

    def id(self, name: str) -> int:
        """Return the field ID mapped to ``name``.

        Raises:
            ValueError: When no mapped field is indexed under ``name``.
        """
        try:
            mapped = self._field_by_name[name]
        except KeyError as e:
            raise ValueError(f"Could not find field with name: {name}") from e
        return mapped.field_id

    def field(self, field_id: int) -> MappedField:
        """Return the mapped field registered under ``field_id``.

        Raises:
            ValueError: When no mapped field carries ``field_id``.
        """
        try:
            found = self._field_by_id[field_id]
        except KeyError as e:
            raise ValueError(f"Could not find field-id: {field_id}") from e
        else:
            return found

    def __len__(self) -> int:
        """Return the number of top-level mapped fields."""
        return len(self.root)

    def __str__(self) -> str:
        """Render the name-mapping as a bracketed block, one mapped field per line."""
        if not self.root:
            return "[]"

        rendered = [str(entry) for entry in self.root]
        return "[\n " + "\n ".join(rendered) + "\n]"
98+
99+
100+
# Result type produced by NameMappingVisitor implementations and returned by visit_name_mapping.
T = TypeVar("T")
101+
102+
103+
class NameMappingVisitor(Generic[T], ABC):
    """Abstract visitor over a name mapping, producing a result of type ``T``.

    ``visit_name_mapping`` drives the traversal and hands each callback the
    result(s) already computed for the nested level.
    """

    @abstractmethod
    def mapping(self, nm: NameMapping, field_results: T) -> T:
        """Visit a NameMapping.

        Args:
            nm: The name mapping being visited.
            field_results: The result of visiting the mapping's root list.
        """

    @abstractmethod
    def fields(self, struct: List[MappedField], field_results: List[T]) -> T:
        """Visit a List[MappedField].

        Args:
            struct: The list of mapped fields being visited.
            field_results: One result per entry of ``struct``, as returned by ``field``.
        """

    @abstractmethod
    def field(self, field: MappedField, field_result: T) -> T:
        """Visit a MappedField.

        Args:
            field: The mapped field being visited.
            field_result: The result of visiting the field's nested ``fields``.
        """
115+
116+
117+
class _IndexById(NameMappingVisitor[Dict[int, MappedField]]):
    """Visitor that indexes every mapped field by its field ID.

    All entries are accumulated in one shared dict (``self.result``), which
    every callback hands back.
    """

    result: Dict[int, MappedField]

    def __init__(self) -> None:
        self.result = dict()

    def mapping(self, nm: NameMapping, field_results: Dict[int, MappedField]) -> Dict[int, MappedField]:
        """Pass the index built for the root list through unchanged."""
        return field_results

    def fields(self, struct: List[MappedField], field_results: List[Dict[int, MappedField]]) -> Dict[int, MappedField]:
        """Return the shared index; the per-field results are not consulted."""
        return self.result

    def field(self, field: MappedField, field_result: Dict[int, MappedField]) -> Dict[int, MappedField]:
        """Register ``field`` under its ID.

        Raises:
            ValueError: When the ID has already been registered.
        """
        index = self.result
        if field.field_id not in index:
            index[field.field_id] = field
            return index

        raise ValueError(f"Invalid mapping: ID {field.field_id} is not unique")
136+
137+
138+
class _IndexByName(NameMappingVisitor[Dict[str, MappedField]]):
    """Visitor that indexes mapped fields by their (dot-separated) names."""

    def mapping(self, nm: NameMapping, field_results: Dict[str, MappedField]) -> Dict[str, MappedField]:
        """Pass the index built for the root list through unchanged."""
        return field_results

    def fields(self, struct: List[MappedField], field_results: List[Dict[str, MappedField]]) -> Dict[str, MappedField]:
        """Merge the per-field indexes; on a key clash the earliest index wins."""
        merged: Dict[str, MappedField] = {}
        # Apply the partial indexes back to front so earlier ones overwrite later ones.
        for partial in reversed(field_results):
            merged.update(partial)
        return merged

    def field(self, field: MappedField, field_result: Dict[str, MappedField]) -> Dict[str, MappedField]:
        """Index ``field`` under each of its names and prefix the nested keys with them."""
        indexed: Dict[str, MappedField] = {}

        # Nested entries become reachable as "<name>.<nested key>" for every name of this field.
        for key, nested in field_result.items():
            for alias in field.names:
                indexed[f"{alias}.{key}"] = nested

        # The field itself is reachable under each of its own names.
        for alias in field.names:
            indexed[alias] = field

        return indexed
154+
155+
156+
@singledispatch
def visit_name_mapping(obj: Union[NameMapping, List[MappedField], MappedField], visitor: NameMappingVisitor[T]) -> T:
    """Traverse the name mapping in post-order traversal.

    This is the fallback used for types without a registered implementation.

    Raises:
        NotImplementedError: Always, since ``obj`` has no registered visitor function.
    """
    message = f"Cannot visit non-type: {obj}"
    raise NotImplementedError(message)
160+
161+
162+
@visit_name_mapping.register(NameMapping)
def _(obj: NameMapping, visitor: NameMappingVisitor[T]) -> T:
    """Visit the root list first, then hand its result to ``visitor.mapping``."""
    root_result = visit_name_mapping(obj.root, visitor)
    return visitor.mapping(obj, root_result)
165+
166+
167+
@visit_name_mapping.register(list)
def _(fields: List[MappedField], visitor: NameMappingVisitor[T]) -> T:
    """Visit each mapped field (children first), then the list as a whole."""
    per_field = []
    for mapped in fields:
        nested_result = visit_name_mapping(mapped.fields, visitor)
        per_field.append(visitor.field(mapped, nested_result))
    return visitor.fields(fields, per_field)
171+
172+
173+
def load_mapping_from_json(mapping: str) -> NameMapping:
    """Parse a JSON string into a NameMapping via ``NameMapping.model_validate_json``."""
    parsed = NameMapping.model_validate_json(mapping)
    return parsed
175+
176+
177+
class _CreateMapping(SchemaVisitor[List[MappedField]]):
    """Schema visitor that builds the list of mapped fields for a schema."""

    def schema(self, schema: Schema, struct_result: List[MappedField]) -> List[MappedField]:
        """Return the mapped fields produced for the schema's struct."""
        return struct_result

    def struct(self, struct: StructType, field_results: List[List[MappedField]]) -> List[MappedField]:
        """Create one mapped field per struct field, carrying its ID, name and nested result."""
        mapped: List[MappedField] = []
        for nested_field, children in zip(struct.fields, field_results):
            mapped.append(MappedField(field_id=nested_field.field_id, names={nested_field.name}, fields=children))
        return mapped

    def field(self, field: NestedField, field_result: List[MappedField]) -> List[MappedField]:
        """Pass the result of the field's type through unchanged."""
        return field_result

    def list(self, list_type: ListType, element_result: List[MappedField]) -> List[MappedField]:
        """Map the list's element under the name "element"."""
        element = MappedField(field_id=list_type.element_id, names={"element"}, fields=element_result)
        return [element]

    def map(self, map_type: MapType, key_result: List[MappedField], value_result: List[MappedField]) -> List[MappedField]:
        """Map the key and value of the map under the names "key" and "value"."""
        key_field = MappedField(field_id=map_type.key_id, names={"key"}, fields=key_result)
        value_field = MappedField(field_id=map_type.value_id, names={"value"}, fields=value_result)
        return [key_field, value_field]

    def primitive(self, primitive: PrimitiveType) -> List[MappedField]:
        """Return an empty list for a primitive type."""
        return []
201+
202+
203+
def create_mapping_from_schema(schema: Schema) -> NameMapping:
    """Build a NameMapping by visiting ``schema`` with ``_CreateMapping``."""
    mapped_fields = visit(schema, _CreateMapping())
    return NameMapping(mapped_fields)

0 commit comments

Comments
 (0)