Skip to content

Commit 6e4fc72

Browse files
alexeykudinkin authored and ZacAttack committed
[Data] Fixing empty projection handling in ParquetDataSource (ray-project#56299)
<!-- Thank you for your contribution! Please review https://github.com/ray-project/ray/blob/master/CONTRIBUTING.rst before opening a pull request. -->
<!-- Please add a reviewer to the assignee section when you create a PR. If you don't have the access to it, we will shortly find a reviewer and assign them to your PR. -->

## Why are these changes needed?

1. Fixing empty projection handling in `ParquetDataSource`
2. Adding tests

## Related issue number

<!-- For example: "Closes ray-project#1234" -->

## Checks

- [ ] I've signed off every commit (by using the -s flag, i.e., `git commit -s`) in this PR.
- [ ] I've run `scripts/format.sh` to lint the changes in this PR.
- [ ] I've included any doc changes needed for https://docs.ray.io/en/master/.
- [ ] I've added any new APIs to the API Reference. For example, if I added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file.
- [ ] I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
- Testing Strategy
   - [ ] Unit tests
   - [ ] Release tests
   - [ ] This PR is not tested :(

---------

Signed-off-by: Alexey Kudinkin <ak@anyscale.com>
Signed-off-by: zac <zac@anyscale.com>
1 parent 177c947 commit 6e4fc72

File tree

3 files changed

+280
-47
lines changed

3 files changed

+280
-47
lines changed

python/ray/data/_internal/datasource/parquet_datasource.py

Lines changed: 141 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import logging
22
import math
3+
import os
34
import warnings
45
from dataclasses import dataclass
56
from typing import (
67
TYPE_CHECKING,
78
Any,
89
Callable,
910
Dict,
11+
Iterable,
1012
Iterator,
1113
List,
1214
Literal,
@@ -20,6 +22,7 @@
2022

2123
import ray
2224
from ray._private.arrow_utils import get_pyarrow_version
25+
from ray.data._internal.arrow_block import ArrowBlockAccessor
2326
from ray.data._internal.progress_bar import ProgressBar
2427
from ray.data._internal.remote_fn import cached_remote_fn
2528
from ray.data._internal.util import (
@@ -52,6 +55,7 @@
5255

5356
if TYPE_CHECKING:
5457
import pyarrow
58+
from pyarrow import parquet as pq
5559
from pyarrow.dataset import ParquetFileFragment
5660

5761

@@ -100,6 +104,9 @@
100104
PARQUET_ENCODING_RATIO_ESTIMATE_NUM_ROWS = 1024
101105

102106

107+
# Name of the all-null stub column that is appended to a table read with an
# empty projection (rows but no columns), so that its row count is preserved.
_BATCH_SIZE_PRESERVING_STUB_COL_NAME = "__bsp_stub"
108+
109+
103110
class _ParquetFragment:
104111
"""This wrapper class is created to avoid utilizing `ParquetFileFragment` original
105112
serialization protocol that actually does network RPCs during serialization
@@ -434,51 +441,25 @@ def read_fragments(
434441
# Ensure that we're reading at least one dataset fragment.
435442
assert len(fragments) > 0
436443

437-
import pyarrow as pa
438-
439444
logger.debug(f"Reading {len(fragments)} parquet fragments")
440-
441-
use_threads = to_batches_kwargs.pop("use_threads", False)
442-
batch_size = to_batches_kwargs.pop("batch_size", default_read_batch_size_rows)
443445
for fragment in fragments:
444-
partitions = {}
445-
if partitioning is not None:
446-
parse = PathPartitionParser(partitioning)
447-
partitions = parse(fragment.original.path)
448-
449-
# Filter out partitions that aren't in the user-specified columns list.
450-
if partition_columns is not None:
451-
partitions = {
452-
field_name: value
453-
for field_name, value in partitions.items()
454-
if field_name in partition_columns
455-
}
456-
457-
def get_batch_iterable():
458-
if batch_size is not None:
459-
to_batches_kwargs["batch_size"] = batch_size
460-
461-
return fragment.original.to_batches(
462-
use_threads=use_threads,
463-
columns=data_columns,
464-
schema=schema,
465-
**to_batches_kwargs,
466-
)
467-
468446
# S3 can raise transient errors during iteration, and PyArrow doesn't expose a
469447
# way to retry specific batches.
470448
ctx = ray.data.DataContext.get_current()
471-
for batch in iterate_with_retry(
472-
get_batch_iterable, "load batch", match=ctx.retried_io_errors
449+
for table in iterate_with_retry(
450+
lambda: _read_batches_from(
451+
fragment.original,
452+
schema=schema,
453+
data_columns=data_columns,
454+
partition_columns=partition_columns,
455+
partitioning=partitioning,
456+
include_path=include_paths,
457+
batch_size=default_read_batch_size_rows,
458+
to_batches_kwargs=to_batches_kwargs,
459+
),
460+
"reading batches",
461+
match=ctx.retried_io_errors,
473462
):
474-
table = pa.Table.from_batches([batch], schema=schema)
475-
if include_paths:
476-
table = BlockAccessor.for_block(table).fill_column(
477-
"path", fragment.original.path
478-
)
479-
if partitions:
480-
table = _add_partitions_to_table(partitions, table)
481-
482463
# If the table is empty, drop it.
483464
if table.num_rows > 0:
484465
if block_udf is not None:
@@ -487,6 +468,112 @@ def get_batch_iterable():
487468
yield table
488469

489470

471+
def _read_batches_from(
    fragment: "ParquetFileFragment",
    *,
    schema: "pyarrow.Schema",
    data_columns: Optional[List[str]],
    partition_columns: Optional[List[str]],
    partitioning: Optional["Partitioning"],
    filter_expr: Optional["pyarrow.dataset.Expression"] = None,
    batch_size: Optional[int] = None,
    include_path: bool = False,
    use_threads: bool = False,
    to_batches_kwargs: Optional[Dict[str, Any]] = None,
) -> Iterable["pyarrow.Table"]:
    """Read a Parquet fragment and yield one Arrow table per record batch.

    Args:
        fragment: The Parquet file fragment to read.
        schema: Schema forwarded to ``fragment.to_batches``.
        data_columns: Columns to read from the file, forwarded as ``columns``
            to ``fragment.to_batches``.
        partition_columns: If not ``None``, only partition fields named here
            are added to the yielded tables.
        partitioning: Partitioning scheme used to parse partition values from
            the fragment's path. ``None`` means no partition values are added.
        filter_expr: Filter forwarded as ``filter`` to ``fragment.to_batches``.
        batch_size: Default batch size; applied only when ``to_batches_kwargs``
            does not already carry a ``batch_size``.
        include_path: Whether to add a ``"path"`` column holding the
            fragment's path.
        use_threads: Forwarded to ``fragment.to_batches``.
        to_batches_kwargs: Extra keyword arguments for ``fragment.to_batches``.
            Its ``use_threads``, ``filter`` and ``batch_size`` entries take
            precedence over the corresponding explicit arguments. The passed
            dict is not modified.

    Yields:
        One ``pyarrow.Table`` per record batch. A table with rows but no
        columns gets a stub column of nulls named
        ``_BATCH_SIZE_PRESERVING_STUB_COL_NAME``.

    Raises:
        RuntimeError: If reading fails with ``ArrowInvalid`` whose message
            contains "No match for FieldRef.Name" while a filter is set. The
            original error is chained as the cause.
        pyarrow.lib.ArrowInvalid: Any other ``ArrowInvalid`` is re-raised as is.
    """

    import pyarrow as pa

    # Copy to avoid modifying passed in arg
    to_batches_kwargs = dict(to_batches_kwargs or {})

    # NOTE: Passed in kwargs overrides always take precedence
    # TODO deprecate to_batches_kwargs
    use_threads = to_batches_kwargs.pop("use_threads", use_threads)
    filter_expr = to_batches_kwargs.pop("filter", filter_expr)
    # NOTE: Arrow's ``to_batches`` expects ``batch_size`` as an int
    if batch_size is not None:
        to_batches_kwargs.setdefault("batch_size", batch_size)

    partition_col_values = _parse_partition_column_values(
        fragment, partition_columns, partitioning
    )

    try:
        for batch in fragment.to_batches(
            columns=data_columns,
            filter=filter_expr,
            schema=schema,
            use_threads=use_threads,
            **to_batches_kwargs,
        ):
            table = pa.Table.from_batches([batch])

            if include_path:
                table = ArrowBlockAccessor.for_block(table).fill_column(
                    "path", fragment.path
                )

            if partition_col_values:
                table = _add_partitions_to_table(partition_col_values, table)

            # ``ParquetFileFragment.to_batches`` returns ``RecordBatch``es,
            # which can have an empty projection (ie ``num_columns`` == 0)
            # while still having rows (ie ``num_rows`` > 0). This happens
            # when the list of requested columns is empty.
            #
            # However, when such tables are concatenated using
            # ``pyarrow.concat_tables``, the result is a single ``Table``
            # with 0 columns and therefore 0 rows (since a ``Table``'s number
            # of rows is determined as the length of its columns).
            #
            # To avoid this pitfall, we add a stub column holding just nulls
            # so that the number of rows is preserved.
            #
            # NOTE: There's no impact from this as the binary size of the
            #       extra column is basically 0
            if table.num_columns == 0 and table.num_rows > 0:
                table = table.append_column(
                    _BATCH_SIZE_PRESERVING_STUB_COL_NAME, pa.nulls(table.num_rows)
                )

            yield table

    except pa.lib.ArrowInvalid as e:
        error_message = str(e)
        if "No match for FieldRef.Name" in error_message and filter_expr is not None:
            filename = os.path.basename(fragment.path)
            file_columns = set(fragment.physical_schema.names)
            # Chain the Arrow error explicitly so the root cause stays
            # attached to the more descriptive error.
            raise RuntimeError(
                f"Filter expression: '{filter_expr}' failed on parquet "
                f"file: '{filename}' with columns: {file_columns}"
            ) from e
        raise
553+
554+
555+
def _parse_partition_column_values(
556+
fragment: "ParquetFileFragment",
557+
partition_columns: Optional[List[str]],
558+
partitioning: Partitioning,
559+
):
560+
partitions = {}
561+
562+
if partitioning is not None:
563+
parse = PathPartitionParser(partitioning)
564+
partitions = parse(fragment.path)
565+
566+
# Filter out partitions that aren't in the user-specified columns list.
567+
if partition_columns is not None:
568+
partitions = {
569+
field_name: value
570+
for field_name, value in partitions.items()
571+
if field_name in partition_columns
572+
}
573+
574+
return partitions
575+
576+
490577
def _fetch_parquet_file_info(
491578
fragment: _ParquetFragment,
492579
*,
@@ -690,13 +777,18 @@ def _sample_fragments(
690777

691778

692779
def _add_partitions_to_table(
    partition_col_values: Dict[str, PartitionDataType], table: "pyarrow.Table"
) -> "pyarrow.Table":
    """Add partition values to ``table`` as columns.

    A partition field that is missing from the table's schema is added via
    ``fill_column``. For a field the schema already has, the table is left
    untouched and a warning is logged, gated by ``log_once`` per field name.

    Args:
        partition_col_values: Mapping from partition field name to its value.
        table: The table to add the partition columns to.

    Returns:
        The table, with the missing partition columns added.
    """
    for partition_col, value in partition_col_values.items():
        missing_from_table = table.schema.get_field_index(partition_col) == -1

        if missing_from_table:
            table = BlockAccessor.for_block(table).fill_column(partition_col, value)
            continue

        # The column is already in the file: keep the file's values.
        if log_once(f"duplicate_partition_field_{partition_col}"):
            logger.warning(
                f"The partition field '{partition_col}' also exists in the Parquet "
                f"file. Ray Data will default to using the value in the Parquet file."
            )

    return table
702794

@@ -747,7 +839,11 @@ def emit_file_extensions_future_warning(future_file_extensions: List[str]):
747839

748840

749841
def _infer_schema(
750-
parquet_dataset, schema, columns, partitioning, _block_udf
842+
parquet_dataset: "pq.ParquetDataset",
843+
schema: "pyarrow.Schema",
844+
columns: Optional[List[str]],
845+
partitioning,
846+
_block_udf,
751847
) -> "pyarrow.Schema":
752848
"""Infer the schema of read data using the user-specified parameters."""
753849
import pyarrow as pa
@@ -760,7 +856,7 @@ def _infer_schema(
760856
partitioning, inferred_schema, parquet_dataset
761857
)
762858

763-
if columns:
859+
if columns is not None:
764860
inferred_schema = pa.schema(
765861
[inferred_schema.field(column) for column in columns],
766862
inferred_schema.metadata,

python/ray/data/_internal/output_buffer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def has_next(self) -> bool:
9090
self._exceeded_buffer_row_limit() or self._exceeded_buffer_size_limit()
9191
)
9292

93-
def _exceeded_block_size_slice_limit(self, block: Block) -> bool:
93+
def _exceeded_block_size_slice_limit(self, block: BlockAccessor) -> bool:
9494
# Slice a block to respect the target max block size. We only do this if we are
9595
# more than 50% above the target block size, because this ensures that the last
9696
# block produced will be at least half the target block size.
@@ -101,7 +101,7 @@ def _exceeded_block_size_slice_limit(self, block: Block) -> bool:
101101
* self._output_block_size_option.target_max_block_size
102102
)
103103

104-
def _exceeded_block_row_slice_limit(self, block: Block) -> bool:
104+
def _exceeded_block_row_slice_limit(self, block: BlockAccessor) -> bool:
105105
# Slice a block to respect the target max rows per block. We only do this if we
106106
# are more than 50% above the target rows per block, because this ensures that
107107
# the last block produced will be at least half the target row count.

0 commit comments

Comments
 (0)