11import logging
22import math
3+ import os
34import warnings
45from dataclasses import dataclass
56from typing import (
67 TYPE_CHECKING ,
78 Any ,
89 Callable ,
910 Dict ,
11+ Iterable ,
1012 Iterator ,
1113 List ,
1214 Literal ,
2022
2123import ray
2224from ray ._private .arrow_utils import get_pyarrow_version
25+ from ray .data ._internal .arrow_block import ArrowBlockAccessor
2326from ray .data ._internal .progress_bar import ProgressBar
2427from ray .data ._internal .remote_fn import cached_remote_fn
2528from ray .data ._internal .util import (
5255
5356if TYPE_CHECKING :
5457 import pyarrow
58+ from pyarrow import parquet as pq
5559 from pyarrow .dataset import ParquetFileFragment
5660
5761
100104PARQUET_ENCODING_RATIO_ESTIMATE_NUM_ROWS = 1024
101105
102106
107+ _BATCH_SIZE_PRESERVING_STUB_COL_NAME = "__bsp_stub"
108+
109+
103110class _ParquetFragment :
104111 """This wrapper class is created to avoid utilizing `ParquetFileFragment` original
105112 serialization protocol that actually does network RPCs during serialization
@@ -434,51 +441,25 @@ def read_fragments(
434441 # Ensure that we're reading at least one dataset fragment.
435442 assert len (fragments ) > 0
436443
437- import pyarrow as pa
438-
439444 logger .debug (f"Reading { len (fragments )} parquet fragments" )
440-
441- use_threads = to_batches_kwargs .pop ("use_threads" , False )
442- batch_size = to_batches_kwargs .pop ("batch_size" , default_read_batch_size_rows )
443445 for fragment in fragments :
444- partitions = {}
445- if partitioning is not None :
446- parse = PathPartitionParser (partitioning )
447- partitions = parse (fragment .original .path )
448-
449- # Filter out partitions that aren't in the user-specified columns list.
450- if partition_columns is not None :
451- partitions = {
452- field_name : value
453- for field_name , value in partitions .items ()
454- if field_name in partition_columns
455- }
456-
457- def get_batch_iterable ():
458- if batch_size is not None :
459- to_batches_kwargs ["batch_size" ] = batch_size
460-
461- return fragment .original .to_batches (
462- use_threads = use_threads ,
463- columns = data_columns ,
464- schema = schema ,
465- ** to_batches_kwargs ,
466- )
467-
468446 # S3 can raise transient errors during iteration, and PyArrow doesn't expose a
469447 # way to retry specific batches.
470448 ctx = ray .data .DataContext .get_current ()
471- for batch in iterate_with_retry (
472- get_batch_iterable , "load batch" , match = ctx .retried_io_errors
449+ for table in iterate_with_retry (
450+ lambda : _read_batches_from (
451+ fragment .original ,
452+ schema = schema ,
453+ data_columns = data_columns ,
454+ partition_columns = partition_columns ,
455+ partitioning = partitioning ,
456+ include_path = include_paths ,
457+ batch_size = default_read_batch_size_rows ,
458+ to_batches_kwargs = to_batches_kwargs ,
459+ ),
460+ "reading batches" ,
461+ match = ctx .retried_io_errors ,
473462 ):
474- table = pa .Table .from_batches ([batch ], schema = schema )
475- if include_paths :
476- table = BlockAccessor .for_block (table ).fill_column (
477- "path" , fragment .original .path
478- )
479- if partitions :
480- table = _add_partitions_to_table (partitions , table )
481-
482463 # If the table is empty, drop it.
483464 if table .num_rows > 0 :
484465 if block_udf is not None :
@@ -487,6 +468,112 @@ def get_batch_iterable():
487468 yield table
488469
489470
def _read_batches_from(
    fragment: "ParquetFileFragment",
    *,
    schema: "pyarrow.Schema",
    data_columns: Optional[List[str]],
    partition_columns: Optional[List[str]],
    partitioning: Partitioning,
    filter_expr: Optional["pyarrow.dataset.Expression"] = None,
    batch_size: Optional[int] = None,
    include_path: bool = False,
    use_threads: bool = False,
    to_batches_kwargs: Optional[Dict[str, Any]] = None,
) -> Iterable["pyarrow.Table"]:
    """Read a parquet fragment, yielding one ``pyarrow.Table`` per batch.

    Args:
        fragment: The parquet file fragment to read.
        schema: Schema passed through to Arrow's ``to_batches``.
        data_columns: Data columns to project, or ``None`` for all columns.
        partition_columns: Path-partition columns to keep, or ``None`` to
            keep every partition value parsed from the path.
        partitioning: Path-partitioning scheme used to parse partition
            values from the fragment's path (may be ``None``).
        filter_expr: Optional filter expression pushed down to Arrow.
        batch_size: Target rows per batch; only used when
            ``to_batches_kwargs`` doesn't already carry a ``batch_size``.
        include_path: If true, append a "path" column with the fragment path.
        use_threads: Whether Arrow may use multiple threads while reading.
        to_batches_kwargs: Extra kwargs forwarded to
            ``ParquetFileFragment.to_batches``; entries here take precedence
            over the corresponding explicit arguments.

    Yields:
        ``pyarrow.Table`` objects, possibly extended with path/partition
        columns and (for column-less projections) a null stub column.

    Raises:
        RuntimeError: If the filter expression references a column that is
            missing from the file.
    """

    import pyarrow as pa

    # Copy to avoid modifying passed in arg
    to_batches_kwargs = dict(to_batches_kwargs or {})

    # NOTE: Passed in kwargs overrides always take precedence
    # TODO deprecate to_batches_kwargs
    use_threads = to_batches_kwargs.pop("use_threads", use_threads)
    filter_expr = to_batches_kwargs.pop("filter", filter_expr)
    # NOTE: Arrow's ``to_batches`` expects ``batch_size`` as an int
    if batch_size is not None:
        to_batches_kwargs.setdefault("batch_size", batch_size)

    partition_col_values = _parse_partition_column_values(
        fragment, partition_columns, partitioning
    )

    try:
        for batch in fragment.to_batches(
            columns=data_columns,
            filter=filter_expr,
            schema=schema,
            use_threads=use_threads,
            **to_batches_kwargs,
        ):
            table = pa.Table.from_batches([batch])

            if include_path:
                table = ArrowBlockAccessor.for_block(table).fill_column(
                    "path", fragment.path
                )

            if partition_col_values:
                table = _add_partitions_to_table(partition_col_values, table)

            # ``ParquetFileFragment.to_batches`` returns ``RecordBatch``,
            # which could have empty projection (ie ``num_columns`` == 0)
            # while having non-empty rows (ie ``num_rows`` > 0), which
            # could occur when list of requested columns is empty.
            #
            # However, when ``RecordBatches`` are concatenated using
            # ``pyarrow.concat_tables`` it will return a single ``Table``
            # with 0 columns and therefore 0 rows (since ``Table``s number of
            # rows is determined as the length of its columns).
            #
            # To avoid running into this pitfall, we introduce a stub column
            # holding just nulls to maintain invariance of the number of rows.
            #
            # NOTE: There's no impact from this as the binary size of the
            # extra column is basically 0
            if table.num_columns == 0 and table.num_rows > 0:
                table = table.append_column(
                    _BATCH_SIZE_PRESERVING_STUB_COL_NAME, pa.nulls(table.num_rows)
                )

            yield table

    except pa.lib.ArrowInvalid as e:
        error_message = str(e)
        if "No match for FieldRef.Name" in error_message and filter_expr is not None:
            filename = os.path.basename(fragment.path)
            file_columns = set(fragment.physical_schema.names)
            # Chain the original ArrowInvalid so the root cause (and its
            # traceback) isn't lost behind the friendlier message.
            raise RuntimeError(
                f"Filter expression: '{filter_expr}' failed on parquet "
                f"file: '{filename}' with columns: {file_columns}"
            ) from e
        raise
553+
554+
def _parse_partition_column_values(
    fragment: "ParquetFileFragment",
    partition_columns: Optional[List[str]],
    partitioning: Partitioning,
):
    """Parse partition key/value pairs from the fragment's file path.

    Returns an empty dict when no partitioning scheme is configured. When
    ``partition_columns`` is given, only those fields are retained.
    """
    if partitioning is None:
        return {}

    parsed = PathPartitionParser(partitioning)(fragment.path)

    if partition_columns is None:
        return parsed

    # Drop any parsed partition fields the user didn't ask for.
    return {
        field_name: value
        for field_name, value in parsed.items()
        if field_name in partition_columns
    }
575+
576+
490577def _fetch_parquet_file_info (
491578 fragment : _ParquetFragment ,
492579 * ,
@@ -690,13 +777,18 @@ def _sample_fragments(
690777
691778
def _add_partitions_to_table(
    partition_col_values: Dict[str, PartitionDataType], table: "pyarrow.Table"
) -> "pyarrow.Table":
    """Fill path-derived partition values into ``table`` as columns.

    If a partition field already exists as a column in the table, the file's
    value wins and a (once-per-field) warning is logged instead.
    """
    for partition_col, value in partition_col_values.items():
        if table.schema.get_field_index(partition_col) != -1:
            # Column already present in the parquet data: keep the file's
            # values and surface the collision once per field name.
            if log_once(f"duplicate_partition_field_{partition_col}"):
                logger.warning(
                    f"The partition field '{partition_col}' also exists in the Parquet "
                    f"file. Ray Data will default to using the value in the Parquet file."
                )
            continue
        table = BlockAccessor.for_block(table).fill_column(partition_col, value)

    return table
702794
@@ -747,7 +839,11 @@ def emit_file_extensions_future_warning(future_file_extensions: List[str]):
747839
748840
749841def _infer_schema (
750- parquet_dataset , schema , columns , partitioning , _block_udf
842+ parquet_dataset : "pq.ParquetDataset" ,
843+ schema : "pyarrow.Schema" ,
844+ columns : Optional [List [str ]],
845+ partitioning ,
846+ _block_udf ,
751847) -> "pyarrow.Schema" :
752848 """Infer the schema of read data using the user-specified parameters."""
753849 import pyarrow as pa
@@ -760,7 +856,7 @@ def _infer_schema(
760856 partitioning , inferred_schema , parquet_dataset
761857 )
762858
763- if columns :
859+ if columns is not None :
764860 inferred_schema = pa .schema (
765861 [inferred_schema .field (column ) for column in columns ],
766862 inferred_schema .metadata ,
0 commit comments