Commit b267c70

steinitzu and rudolfix authored
Add load_id to arrow tables in extract step instead of normalize (#1449)
* Add load_id to arrow tables in extract step instead of normalize
* Test arrow load id in extract
* Get normalize config without decorator
* Normalize load ID column name
* Load ID column goes last
* adds update_table column order tests

---------

Co-authored-by: Marcin Rudolf <[email protected]>
1 parent 14f06e4 commit b267c70

File tree

6 files changed: +200 additions, -47 deletions

dlt/common/libs/pyarrow.py

Lines changed: 68 additions & 8 deletions
@@ -226,7 +226,8 @@ def should_normalize_arrow_schema(
     schema: pyarrow.Schema,
     columns: TTableSchemaColumns,
     naming: NamingConvention,
-) -> Tuple[bool, Mapping[str, str], Dict[str, str], Dict[str, bool], TTableSchemaColumns]:
+    add_load_id: bool = False,
+) -> Tuple[bool, Mapping[str, str], Dict[str, str], Dict[str, bool], bool, TTableSchemaColumns]:
     rename_mapping = get_normalized_arrow_fields_mapping(schema, naming)
     rev_mapping = {v: k for k, v in rename_mapping.items()}
     nullable_mapping = {k: v.get("nullable", True) for k, v in columns.items()}
@@ -238,39 +239,62 @@ def should_normalize_arrow_schema(
         if norm_name in nullable_mapping and field.nullable != nullable_mapping[norm_name]:
             nullable_updates[norm_name] = nullable_mapping[norm_name]

-    dlt_tables = list(map(naming.normalize_table_identifier, ("_dlt_id", "_dlt_load_id")))
+    dlt_load_id_col = naming.normalize_table_identifier("_dlt_load_id")
+    dlt_id_col = naming.normalize_table_identifier("_dlt_id")
+    dlt_columns = {dlt_load_id_col, dlt_id_col}
+
+    # Do we need to add a load id column?
+    if add_load_id and dlt_load_id_col in columns:
+        try:
+            schema.field(dlt_load_id_col)
+            needs_load_id = False
+        except KeyError:
+            needs_load_id = True
+    else:
+        needs_load_id = False

     # remove all columns that are dlt columns but are not present in arrow schema. we do not want to add such columns
     # that should happen in the normalizer
     columns = {
         name: column
         for name, column in columns.items()
-        if name not in dlt_tables or name in rev_mapping
+        if name not in dlt_columns or name in rev_mapping
     }

     # check if nothing to rename
     skip_normalize = (
-        list(rename_mapping.keys()) == list(rename_mapping.values()) == list(columns.keys())
-    ) and not nullable_updates
-    return not skip_normalize, rename_mapping, rev_mapping, nullable_updates, columns
+        (list(rename_mapping.keys()) == list(rename_mapping.values()) == list(columns.keys()))
+        and not nullable_updates
+        and not needs_load_id
+    )
+    return (
+        not skip_normalize,
+        rename_mapping,
+        rev_mapping,
+        nullable_updates,
+        needs_load_id,
+        columns,
+    )


 def normalize_py_arrow_item(
     item: TAnyArrowItem,
     columns: TTableSchemaColumns,
     naming: NamingConvention,
     caps: DestinationCapabilitiesContext,
+    load_id: Optional[str] = None,
 ) -> TAnyArrowItem:
     """Normalize arrow `item` schema according to the `columns`.

     1. arrow schema field names will be normalized according to `naming`
     2. arrows columns will be reordered according to `columns`
     3. empty columns will be inserted if they are missing, types will be generated using `caps`
     4. arrow columns with different nullability than corresponding schema columns will be updated
+    5. Add `_dlt_load_id` column if it is missing and `load_id` is provided
     """
     schema = item.schema
-    should_normalize, rename_mapping, rev_mapping, nullable_updates, columns = (
-        should_normalize_arrow_schema(schema, columns, naming)
+    should_normalize, rename_mapping, rev_mapping, nullable_updates, needs_load_id, columns = (
+        should_normalize_arrow_schema(schema, columns, naming, load_id is not None)
     )
     if not should_normalize:
         return item
@@ -307,6 +331,18 @@ def normalize_py_arrow_item(
             new_fields.append(schema.field(idx).with_name(column_name))
             new_columns.append(item.column(idx))

+    if needs_load_id and load_id:
+        # Storage efficient type for a column with constant value
+        load_id_type = pyarrow.dictionary(pyarrow.int8(), pyarrow.string())
+        new_fields.append(
+            pyarrow.field(
+                naming.normalize_table_identifier("_dlt_load_id"),
+                load_id_type,
+                nullable=False,
+            )
+        )
+        new_columns.append(pyarrow.array([load_id] * item.num_rows, type=load_id_type))
+
     # create desired type
     return item.__class__.from_arrays(new_columns, schema=pyarrow.schema(new_fields))

@@ -383,6 +419,30 @@ def from_arrow_scalar(arrow_value: pyarrow.Scalar) -> Any:
     """Sequence of tuples: (field index, field, generating function)"""


+def add_constant_column(
+    item: TAnyArrowItem,
+    name: str,
+    data_type: pyarrow.DataType,
+    value: Any = None,
+    nullable: bool = True,
+    index: int = -1,
+) -> TAnyArrowItem:
+    """Add column with a single value to the table.
+
+    Args:
+        item: Arrow table or record batch
+        name: The new column name
+        data_type: The data type of the new column
+        nullable: Whether the new column is nullable
+        value: The value to fill the new column with
+        index: The index at which to insert the new column. Defaults to -1 (append)
+    """
+    field = pyarrow.field(name, pyarrow.dictionary(pyarrow.int8(), data_type), nullable=nullable)
+    if index == -1:
+        return item.append_column(field, pyarrow.array([value] * item.num_rows, type=field.type))
+    return item.add_column(index, field, pyarrow.array([value] * item.num_rows, type=field.type))
+
+
 def pq_stream_with_new_columns(
     parquet_file: TFileOrPath, columns: TNewColumns, row_groups_per_read: int = 1
 ) -> Iterator[pyarrow.Table]:
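
The new `add_constant_column` helper is self-contained, so it can be exercised directly. Below is a minimal sketch, assuming `pyarrow` and a dlt build containing this commit are installed; the table contents and the `batch_tag` column name are illustrative only:

```python
import pyarrow as pa

from dlt.common.libs.pyarrow import add_constant_column

table = pa.table({"id": [1, 2, 3]})
# append a dictionary-encoded constant column, the same mechanism used for load ids
tagged = add_constant_column(table, "batch_tag", pa.string(), value="2024-06-01", nullable=False)

assert tagged.column_names == ["id", "batch_tag"]
assert tagged["batch_tag"].to_pylist() == ["2024-06-01"] * 3
```

The dictionary-encoded type keeps a constant column cheap to store, which is why `normalize_py_arrow_item` uses the same encoding for the `_dlt_load_id` column it appends.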

dlt/extract/extractors.py

Lines changed: 41 additions & 3 deletions
@@ -1,8 +1,8 @@
 from copy import copy
 from typing import Set, Dict, Any, Optional, List

+from dlt.common.configuration import known_sections, resolve_configuration, with_config
 from dlt.common import logger
-from dlt.common.configuration.inject import with_config
 from dlt.common.configuration.specs import BaseConfiguration, configspec
 from dlt.common.destination.capabilities import DestinationCapabilitiesContext
 from dlt.common.exceptions import MissingDependencyException
@@ -21,6 +21,7 @@
 from dlt.extract.resource import DltResource
 from dlt.extract.items import TableNameMeta
 from dlt.extract.storage import ExtractorItemStorage
+from dlt.normalize.configuration import ItemsNormalizerConfiguration

 try:
     from dlt.common.libs import pyarrow
@@ -215,13 +216,29 @@ class ObjectExtractor(Extractor):
 class ArrowExtractor(Extractor):
     """Extracts arrow data items into parquet. Normalizes arrow items column names.
     Compares the arrow schema to actual dlt table schema to reorder the columns and to
-    insert missing columns (without data).
+    insert missing columns (without data). Adds _dlt_load_id column to the table if
+    `add_dlt_load_id` is set to True in normalizer config.

     We do things that normalizer should do here so we do not need to load and save parquet
     files again later.

+    Handles the following types:
+    - `pyarrow.Table`
+    - `pyarrow.RecordBatch`
+    - `pandas.DataFrame` (is converted to arrow `Table` before processing)
     """

+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super().__init__(*args, **kwargs)
+        self._normalize_config = self._retrieve_normalize_config()
+
+    def _retrieve_normalize_config(self) -> ItemsNormalizerConfiguration:
+        """Get normalizer settings that are used here"""
+        return resolve_configuration(
+            ItemsNormalizerConfiguration(),
+            sections=(known_sections.NORMALIZE, "parquet_normalizer"),
+        )
+
     def write_items(self, resource: DltResource, items: TDataItems, meta: Any) -> None:
         static_table_name = self._get_static_table_name(resource, meta)
         items = [
@@ -294,7 +311,13 @@ def _write_item(
         columns = columns or self.schema.get_table_columns(table_name)
         # Note: `items` is always a list here due to the conversion in `write_table`
         items = [
-            pyarrow.normalize_py_arrow_item(item, columns, self.naming, self._caps)
+            pyarrow.normalize_py_arrow_item(
+                item,
+                columns,
+                self.naming,
+                self._caps,
+                load_id=self.load_id if self._normalize_config.add_dlt_load_id else None,
+            )
             for item in items
         ]
         # write items one by one
@@ -316,8 +339,22 @@ def _compute_table(
         else:
             arrow_table = copy(computed_table)
             arrow_table["columns"] = pyarrow.py_arrow_to_table_schema_columns(item.schema)
+
+        # Add load_id column if needed
+        dlt_load_id_col = self.naming.normalize_table_identifier("_dlt_load_id")
+        if (
+            self._normalize_config.add_dlt_load_id
+            and dlt_load_id_col not in arrow_table["columns"]
+        ):
+            arrow_table["columns"][dlt_load_id_col] = {
+                "name": dlt_load_id_col,
+                "data_type": "text",
+                "nullable": False,
+            }
+
         # normalize arrow table before merging
         arrow_table = self.schema.normalize_table_identifiers(arrow_table)
+
         # issue warnings when overriding computed with arrow
         override_warn: bool = False
         for col_name, column in arrow_table["columns"].items():
@@ -343,6 +380,7 @@ def _compute_table(
             utils.merge_columns(
                 arrow_table["columns"], computed_table["columns"], merge_columns=True
            )
+
        return arrow_table

    def _compute_and_update_table(
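
The `_retrieve_normalize_config` pattern above resolves the normalizer settings without a `@with_config` decorator, so the same lookup can be reproduced standalone, for example to check whether load-id tagging is enabled in the active configuration. A short sketch, assuming a dlt build that contains this commit:

```python
from dlt.common.configuration import known_sections, resolve_configuration
from dlt.normalize.configuration import ItemsNormalizerConfiguration

# resolve [normalize.parquet_normalizer] from the active config providers
# (config.toml, secrets.toml, environment variables)
config = resolve_configuration(
    ItemsNormalizerConfiguration(),
    sections=(known_sections.NORMALIZE, "parquet_normalizer"),
)
print(config.add_dlt_load_id)
```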

dlt/normalize/items_normalizers.py

Lines changed: 8 additions & 35 deletions
@@ -228,37 +228,13 @@ class ArrowItemsNormalizer(ItemsNormalizer):
     REWRITE_ROW_GROUPS = 1

     def _write_with_dlt_columns(
-        self, extracted_items_file: str, root_table_name: str, add_load_id: bool, add_dlt_id: bool
+        self, extracted_items_file: str, root_table_name: str, add_dlt_id: bool
     ) -> List[TSchemaUpdate]:
         new_columns: List[Any] = []
         schema = self.schema
         load_id = self.load_id
         schema_update: TSchemaUpdate = {}

-        if add_load_id:
-            table_update = schema.update_table(
-                {
-                    "name": root_table_name,
-                    "columns": {
-                        "_dlt_load_id": {
-                            "name": "_dlt_load_id",
-                            "data_type": "text",
-                            "nullable": False,
-                        }
-                    },
-                }
-            )
-            table_updates = schema_update.setdefault(root_table_name, [])
-            table_updates.append(table_update)
-            load_id_type = pa.dictionary(pa.int8(), pa.string())
-            new_columns.append(
-                (
-                    -1,
-                    pa.field("_dlt_load_id", load_id_type, nullable=False),
-                    lambda batch: pa.array([load_id] * batch.num_rows, type=load_id_type),
-                )
-            )
-
         if add_dlt_id:
             table_update = schema.update_table(
                 {
@@ -292,9 +268,9 @@ def _write_with_dlt_columns(
                 items_count += batch.num_rows
             # we may need to normalize
             if is_native_arrow_writer and should_normalize is None:
-                should_normalize, _, _, _, _ = pyarrow.should_normalize_arrow_schema(
+                should_normalize = pyarrow.should_normalize_arrow_schema(
                     batch.schema, columns_schema, schema.naming
-                )
+                )[0]
             if should_normalize:
                 logger.info(
                     f"When writing arrow table to {root_table_name} the schema requires"
@@ -366,25 +342,22 @@ def __call__(self, extracted_items_file: str, root_table_name: str) -> List[TSchemaUpdate]:
         base_schema_update = self._fix_schema_precisions(root_table_name, arrow_schema)

         add_dlt_id = self.config.parquet_normalizer.add_dlt_id
-        add_dlt_load_id = self.config.parquet_normalizer.add_dlt_load_id
         # if we need to add any columns or the file format is not parquet, we can't just import files
-        must_rewrite = (
-            add_dlt_id or add_dlt_load_id or self.item_storage.writer_spec.file_format != "parquet"
-        )
+        must_rewrite = add_dlt_id or self.item_storage.writer_spec.file_format != "parquet"
         if not must_rewrite:
             # in rare cases normalization may be needed
-            must_rewrite, _, _, _, _ = pyarrow.should_normalize_arrow_schema(
+            must_rewrite = pyarrow.should_normalize_arrow_schema(
                 arrow_schema, self.schema.get_table_columns(root_table_name), self.schema.naming
-            )
+            )[0]
         if must_rewrite:
             logger.info(
                 f"Table {root_table_name} parquet file {extracted_items_file} must be rewritten:"
-                f" add_dlt_id: {add_dlt_id} add_dlt_load_id: {add_dlt_load_id} destination file"
+                f" add_dlt_id: {add_dlt_id} destination file"
                 f" format: {self.item_storage.writer_spec.file_format} or due to required"
                 " normalization "
             )
             schema_update = self._write_with_dlt_columns(
-                extracted_items_file, root_table_name, add_dlt_load_id, add_dlt_id
+                extracted_items_file, root_table_name, add_dlt_id
             )
         return base_schema_update + schema_update

docs/website/docs/dlt-ecosystem/verified-sources/arrow-pandas.md

Lines changed: 4 additions & 1 deletion
@@ -84,7 +84,10 @@ add_dlt_load_id = true
 add_dlt_id = true
 ```

-Keep in mind that enabling these incurs some performance overhead because the `parquet` file needs to be read back from disk in chunks, processed and rewritten with new columns.
+Keep in mind that enabling these incurs some performance overhead:
+
+- `add_dlt_load_id` has minimal overhead since the column is added to arrow table in memory during `extract` stage, before parquet file is written to disk
+- `add_dlt_id` adds the column during `normalize` stage after file has been extracted to disk. The file needs to be read back from disk in chunks, processed and rewritten with new columns

 ## Incremental loading with Arrow tables
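
As a rough end-to-end illustration of the behavior described in the docs change above, the sketch below enables `add_dlt_load_id` through dlt's environment-variable configuration (equivalent to setting it under `[normalize.parquet_normalizer]` in `config.toml`) and loads a small arrow table. The pipeline name, dataset name, and `duckdb` destination are assumptions for the example, not part of this change:

```python
import os

import dlt
import pyarrow as pa

# assumed equivalent of: [normalize.parquet_normalizer] add_dlt_load_id = true
os.environ["NORMALIZE__PARQUET_NORMALIZER__ADD_DLT_LOAD_ID"] = "true"

pipeline = dlt.pipeline(
    pipeline_name="arrow_load_id_demo", destination="duckdb", dataset_name="demo"
)
# the _dlt_load_id column is appended to the arrow table in memory during extract,
# so the parquet file does not need to be rewritten in the normalize stage
info = pipeline.run(pa.table({"value": [1, 2, 3]}), table_name="items")
print(info)
```
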

tests/common/schema/test_inference.py

Lines changed: 21 additions & 0 deletions
@@ -565,3 +565,24 @@ def test_infer_on_incomplete_column(schema: Schema) -> None:
     assert i_column["x-special"] == "spec"  # type: ignore[typeddict-item]
     assert i_column["primary_key"] is True
     assert i_column["data_type"] == "text"
+
+
+def test_update_table_adds_at_end(schema: Schema) -> None:
+    row = {"evm": Wei(1)}
+    _, new_table = schema.coerce_row("eth", None, row)
+    schema.update_table(new_table)
+    schema.update_table(
+        {
+            "name": new_table["name"],
+            "columns": {
+                "_dlt_load_id": {
+                    "name": "_dlt_load_id",
+                    "data_type": "text",
+                    "nullable": False,
+                }
+            },
+        }
+    )
+    table = schema.tables["eth"]
+    # place new columns at the end
+    assert list(table["columns"].keys()) == ["evm", "_dlt_load_id"]
