Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 64 additions & 2 deletions python/ray/data/_internal/datasource/parquet_datasource.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import hashlib
import logging
import math
import os
Expand Down Expand Up @@ -308,6 +309,7 @@ def __init__(
partitioning: Optional[Partitioning] = Partitioning("hive"),
shuffle: Union[Literal["files"], None] = None,
include_paths: bool = False,
include_row_hash: bool = False,
file_extensions: Optional[List[str]] = None,
):
super().__init__()
Expand Down Expand Up @@ -440,6 +442,12 @@ def __init__(
)
self._file_metadata_shuffler = None
self._include_paths = include_paths
self._include_row_hash = include_row_hash
if self._include_row_hash and "row_hash" in pq_ds.schema.names:
logger.warning(
"The Parquet file(s) already contain a column named 'row_hash'. "
"It will be overwritten by the generated row hash column."
)
self._partitioning = partitioning
_validate_shuffle_arg(shuffle)
self._shuffle = shuffle
Expand Down Expand Up @@ -503,6 +511,7 @@ def get_read_tasks(
projected_columns=self.get_current_projection(),
_block_udf=self._block_udf,
include_paths=self._include_paths,
include_row_hash=self._include_row_hash,
)

read_tasks = []
Expand Down Expand Up @@ -535,6 +544,7 @@ def get_read_tasks(
partition_columns,
read_schema,
include_paths,
include_row_hash,
partitioning,
) = (
self._block_udf,
Expand All @@ -545,6 +555,7 @@ def get_read_tasks(
self._get_partition_columns(),
self._read_schema,
self._include_paths,
self._include_row_hash,
self._partitioning,
)

Expand All @@ -560,6 +571,7 @@ def get_read_tasks(
read_schema,
f,
include_paths,
include_row_hash,
partitioning,
filter_expr,
),
Expand Down Expand Up @@ -603,6 +615,8 @@ def get_current_projection(self) -> Optional[List[str]]:
# via _derive_schema, so we only need to add it when there is a projection.
if self._include_paths and "path" not in result:
result = result + ["path"]
if self._include_row_hash and "row_hash" not in result:
result = result + ["row_hash"]

return result

Expand Down Expand Up @@ -662,11 +676,13 @@ def _get_data_columns(self) -> Optional[List[str]]:

# Get partition columns and filter them out from the projection
partition_cols = self._partition_columns
# Also filter out "path" column if include_paths is True, as it's a
# synthetic column added after reading from the file
# Also filter out synthetic columns (path, row_hash) as they are
# added after reading from the file
cols_to_filter = set(partition_cols)
if self._include_paths:
cols_to_filter.add("path")
if self._include_row_hash:
cols_to_filter.add("row_hash")
data_cols = [
col for col in self._projection_map.keys() if col not in cols_to_filter
]
Expand Down Expand Up @@ -744,6 +760,7 @@ def _derive_schema(
projected_columns: Optional[List[str]],
_block_udf,
include_paths: bool = False,
include_row_hash: bool = False,
) -> "pyarrow.Schema":
"""Derives target schema for read operation"""

Expand Down Expand Up @@ -777,6 +794,9 @@ def _derive_schema(
if include_paths and target_schema.get_field_index("path") == -1:
target_schema = target_schema.append(pa.field("path", pa.string()))

if include_row_hash and target_schema.get_field_index("row_hash") == -1:
target_schema = target_schema.append(pa.field("row_hash", pa.uint64()))

# Project schema if necessary
if projected_columns is not None:
target_schema = pa.schema(
Expand Down Expand Up @@ -813,6 +833,7 @@ def read_fragments(
schema: Optional[Union[type, "pyarrow.lib.Schema"]],
fragments: List[_ParquetFragment],
include_paths: bool,
include_row_hash: bool,
partitioning: Partitioning,
filter_expr: Optional["pyarrow.dataset.Expression"] = None,
) -> Iterator["pyarrow.Table"]:
Expand All @@ -836,6 +857,7 @@ def read_fragments(
partition_columns=partition_columns,
partitioning=partitioning,
include_path=include_paths,
include_row_hash=include_row_hash,
filter_expr=filter_expr,
batch_size=default_read_batch_size_rows,
to_batches_kwargs=to_batches_kwargs,
Expand All @@ -862,6 +884,7 @@ def _read_batches_from(
filter_expr: Optional["pyarrow.dataset.Expression"] = None,
batch_size: Optional[int] = None,
include_path: bool = False,
include_row_hash: bool = False,
use_threads: bool = False,
to_batches_kwargs: Optional[Dict[str, Any]] = None,
) -> Iterable["pyarrow.Table"]:
Expand Down Expand Up @@ -893,8 +916,11 @@ def _read_batches_from(
fragment, partition_columns, partitioning
)

row_offset = 0

def _generate_tables() -> "pa.Table":
"""Inner generator that yields tables without renaming."""
nonlocal row_offset
try:
for batch in fragment.to_batches(
columns=data_columns,
Expand All @@ -913,6 +939,15 @@ def _generate_tables() -> "pa.Table":
"path", fragment.path
)

if include_row_hash:
hashes = _compute_row_hashes(
fragment.path, row_offset, table.num_rows
)
table = ArrowBlockAccessor.for_block(table).fill_column(
"row_hash", pa.array(hashes, type=pa.uint64())
)
row_offset += table.num_rows

# ``ParquetFileFragment.to_batches`` returns ``RecordBatch``,
# which could have empty projection (ie ``num_columns`` == 0)
# while having non-empty rows (ie ``num_rows`` > 0), which
Expand Down Expand Up @@ -955,6 +990,33 @@ def _generate_tables() -> "pa.Table":
)


def _compute_row_hashes(file_path: str, start_row: int, num_rows: int) -> np.ndarray:
"""Compute deterministic uint64 hashes from file path and row position.

Hashes the file path with MD5 to obtain a 64-bit seed, adds the row indices,
then applies the splitmix64 finalizer (a bijective 64-bit mixing function) to
produce well-distributed, reproducible hashes. Fully vectorized via numpy.
"""
path_seed = np.uint64(
int.from_bytes(
hashlib.md5(file_path.encode("utf-8")).digest()[:8], byteorder="little"
)
)
keys = path_seed + np.arange(start_row, start_row + num_rows, dtype=np.uint64)

# splitmix64 finalizer – a bijective 64-bit mixing function from
# Steele, Lea & Flood, "Fast Splittable Pseudorandom Number Generators",
# OOPSLA 2014. Also used in Java's SplittableRandom.
# Reference: https://xorshift.di.unimi.it/splitmix64.c
keys ^= keys >> np.uint64(30)
keys *= np.uint64(0xBF58476D1CE4E5B9)
keys ^= keys >> np.uint64(27)
keys *= np.uint64(0x94D049BB133111EB)
keys ^= keys >> np.uint64(31)

return keys


def _parse_partition_column_values(
fragment: "ParquetFileFragment",
partition_columns: Optional[List[str]],
Expand Down
8 changes: 8 additions & 0 deletions python/ray/data/read_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -912,6 +912,7 @@ def read_parquet(
partitioning: Optional[Partitioning] = Partitioning("hive"),
shuffle: Optional[Union[Literal["files"], FileShuffleConfig]] = None,
include_paths: bool = False,
include_row_hash: bool = False,
file_extensions: Optional[List[str]] = ParquetDatasource._FILE_EXTENSIONS,
concurrency: Optional[int] = None,
override_num_blocks: Optional[int] = None,
Expand Down Expand Up @@ -1024,6 +1025,12 @@ def read_parquet(
shuffle the input files. Defaults to not shuffle with ``None``.
include_paths: If ``True``, include the path to each file. File paths are
stored in the ``'path'`` column.
include_row_hash: If ``True``, include a deterministic hash for each row.
The hash is a uint64 computed from the source file path and the row's
position within that file, making it reproducible across repeated reads
of the same data. Stored in the ``'row_hash'`` column. If a column
named ``'row_hash'`` already exists in the file, it will be
overwritten.
file_extensions: A list of file extensions to filter files by.
concurrency: The maximum number of Ray tasks to run concurrently. Set this
to control number of tasks to run concurrently. This doesn't change the
Expand Down Expand Up @@ -1073,6 +1080,7 @@ def read_parquet(
partitioning=partitioning,
shuffle=shuffle,
include_paths=include_paths,
include_row_hash=include_row_hash,
file_extensions=file_extensions,
)
return read_datasource(
Expand Down
140 changes: 140 additions & 0 deletions python/ray/data/tests/datasource/test_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,146 @@ def test_include_paths_with_column_projection(
assert row["path"] == path


def test_include_row_hash(
    ray_start_regular_shared, tmp_path, target_max_block_size_infinite_or_default
):
    """Smoke test: row_hash column appears and holds unique integer values."""
    file_path = os.path.join(tmp_path, "test.parquet")
    pq.write_table(
        pa.Table.from_pydict({"animals": ["cat", "dog", "bird"]}), file_path
    )

    ds = ray.data.read_parquet(file_path, include_row_hash=True)

    assert "row_hash" in ds.schema().names

    hashes = [record["row_hash"] for record in ds.take_all()]
    assert len(hashes) == 3
    assert len(set(hashes)) == 3, "Hashes must be unique"
    for h in hashes:
        assert isinstance(h, int)


def test_include_row_hash_reproducible(
    ray_start_regular_shared, tmp_path, target_max_block_size_infinite_or_default
):
    """Reading the same file twice must yield identical row hashes."""
    file_path = os.path.join(tmp_path, "test.parquet")
    pq.write_table(pa.Table.from_pydict({"val": list(range(10))}), file_path)

    def _read_hashes():
        ds = ray.data.read_parquet(file_path, include_row_hash=True)
        return [record["row_hash"] for record in ds.take_all()]

    assert _read_hashes() == _read_hashes(), "Hashes must be reproducible across reads"


def test_include_row_hash_unique_across_files(
    ray_start_regular_shared, tmp_path, target_max_block_size_infinite_or_default
):
    """Rows read from different files must not collide on row_hash."""
    for i in range(3):
        pq.write_table(
            pa.Table.from_pydict({"val": [i * 10, i * 10 + 1]}),
            os.path.join(tmp_path, f"file{i}.parquet"),
        )

    hashes = [
        record["row_hash"]
        for record in ray.data.read_parquet(
            str(tmp_path), include_row_hash=True
        ).take_all()
    ]
    assert len(hashes) == 6
    assert len(set(hashes)) == 6, "Hashes must be unique across files"


def test_include_row_hash_same_data_different_files(
    ray_start_regular_shared, tmp_path, target_max_block_size_infinite_or_default
):
    """Files with identical content must produce different hashes because
    the hash is derived from the file path, not the data."""
    identical_table = pa.Table.from_pydict({"val": [1, 2, 3]})
    for name in ("a.parquet", "b.parquet", "c.parquet"):
        pq.write_table(identical_table, os.path.join(tmp_path, name))

    all_hashes = [
        record["row_hash"]
        for record in ray.data.read_parquet(
            str(tmp_path), include_row_hash=True
        ).take_all()
    ]
    assert len(all_hashes) == 9
    assert (
        len(set(all_hashes)) == 9
    ), "Identical data in different files must produce distinct hashes"


def test_include_row_hash_with_column_projection(
    ray_start_regular_shared, tmp_path, target_max_block_size_infinite_or_default
):
    """row_hash is added even under column projection; unselected columns stay out."""
    file_path = os.path.join(tmp_path, "test.parquet")
    pq.write_table(pa.Table.from_pydict({"a": [1, 2], "b": [3, 4]}), file_path)

    ds = ray.data.read_parquet(file_path, columns=["a"], include_row_hash=True)
    names = ds.schema().names
    assert "a" in names
    assert "b" not in names
    assert "row_hash" in names

    records = ds.take_all()
    assert len(records) == 2
    for record in records:
        assert "row_hash" in record
        assert "a" in record
        assert "b" not in record


def test_include_row_hash_with_include_paths(
    ray_start_regular_shared, tmp_path, target_max_block_size_infinite_or_default
):
    """include_paths and include_row_hash can be combined in one read."""
    file_path = os.path.join(tmp_path, "test.parquet")
    pq.write_table(pa.Table.from_pydict({"val": [1, 2]}), file_path)

    ds = ray.data.read_parquet(file_path, include_paths=True, include_row_hash=True)
    names = ds.schema().names
    assert "path" in names
    assert "row_hash" in names

    frame = ds.to_pandas()
    assert "path" in frame.columns
    assert len(set(frame["row_hash"])) == 2


def test_include_row_hash_existing_column(
    ray_start_regular_shared, tmp_path, target_max_block_size_infinite_or_default
):
    """When the file already has a 'row_hash' column, it should be
    overwritten by the generated one without crashing."""
    file_path = os.path.join(tmp_path, "test.parquet")
    pq.write_table(
        pa.Table.from_pydict({"val": [1, 2, 3], "row_hash": [100, 200, 300]}),
        file_path,
    )

    generated = [
        record["row_hash"]
        for record in ray.data.read_parquet(
            file_path, include_row_hash=True
        ).take_all()
    ]
    assert len(generated) == 3
    assert len(set(generated)) == 3, "Hashes must be unique"
    preexisting = {100, 200, 300}
    assert preexisting.isdisjoint(
        generated
    ), "Generated hashes must overwrite the original column values"


def test_include_row_hash_existing_column_with_projection(
    ray_start_regular_shared, tmp_path, target_max_block_size_infinite_or_default
):
    """Column projection + pre-existing row_hash column should work."""
    file_path = os.path.join(tmp_path, "test.parquet")
    pq.write_table(
        pa.Table.from_pydict({"val": [1, 2], "row_hash": [10, 20]}), file_path
    )

    ds = ray.data.read_parquet(file_path, columns=["val"], include_row_hash=True)
    names = ds.schema().names
    assert "val" in names
    assert "row_hash" in names
    for record in ds.take_all():
        assert record["row_hash"] not in (10, 20)


@pytest.mark.parametrize(
"fs,data_path",
[
Expand Down