Skip to content

Commit 40c665c

Browse files
xyuzh, bveeramani and cursoragent
authored and committed
[Data] Add optional filesystem parameter to download expression (ray-project#60677)
## Summary - Add optional `filesystem` parameter to the `download()` expression in Ray Data - Allows users to provide custom PyArrow filesystems with custom authentication credentials - If not specified, the filesystem is auto-detected from the path scheme (existing behavior) ## Test plan - [x] Verify existing download tests still pass - [x] Test with custom S3FileSystem with explicit credentials <!-- BUGBOT_STATUS --><sup><a href="https://cursor.com/dashboard?tab=bugbot">Cursor Bugbot</a> reviewed your changes and found no issues for commit <u>cfb1db1</u></sup><!-- /BUGBOT_STATUS --> --------- Signed-off-by: xyuzh <xinyzng@gmail.com> Signed-off-by: Xinyu Zhang <60529799+xyuzh@users.noreply.github.com> Co-authored-by: Balaji Veeramani <bveeramani@berkeley.edu> Co-authored-by: Cursor <cursoragent@cursor.com> Signed-off-by: tiennguyentony <46289799+tiennguyentony@users.noreply.github.com>
1 parent ac9a081 commit 40c665c

File tree

5 files changed

+99
-21
lines changed

5 files changed

+99
-21
lines changed

python/ray/data/_internal/logical/operators/one_to_one_operator.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from ray.data.block import BlockMetadata
99

1010
if TYPE_CHECKING:
11+
import pyarrow
1112

1213
from ray.data.block import Schema
1314

@@ -115,6 +116,7 @@ def __init__(
115116
uri_column_names: List[str],
116117
output_bytes_column_names: List[str],
117118
ray_remote_args: Optional[Dict[str, Any]] = None,
119+
filesystem: Optional["pyarrow.fs.FileSystem"] = None,
118120
):
119121
super().__init__("Download", input_op, can_modify_num_rows=False)
120122
if len(uri_column_names) != len(output_bytes_column_names):
@@ -125,3 +127,4 @@ def __init__(
125127
self.uri_column_names = uri_column_names
126128
self.output_bytes_column_names = output_bytes_column_names
127129
self.ray_remote_args = ray_remote_args or {}
130+
self.filesystem = filesystem

python/ray/data/_internal/planner/plan_download_op.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22
import math
33
from concurrent.futures import ThreadPoolExecutor, as_completed
4-
from typing import Iterator, List
4+
from typing import Iterator, List, Optional
55
from urllib.parse import urlparse
66

77
import pyarrow as pa
@@ -48,6 +48,7 @@ def plan_download_op(
4848
uri_column_names_str = ", ".join(uri_column_names)
4949
output_bytes_column_names = op.output_bytes_column_names
5050
ray_remote_args = op.ray_remote_args
51+
filesystem = op.filesystem
5152

5253
# Import _get_udf from the main planner file
5354
from ray.data._internal.planner.plan_udf_map_op import (
@@ -70,7 +71,7 @@ def plan_download_op(
7071
PartitionActor,
7172
(),
7273
{},
73-
(uri_column_names, data_context),
74+
(uri_column_names, data_context, filesystem),
7475
{},
7576
compute=partition_compute,
7677
)
@@ -108,7 +109,7 @@ def plan_download_op(
108109

109110
fn, init_fn = _get_udf(
110111
download_bytes_threaded,
111-
(uri_column_names, output_bytes_column_names, data_context),
112+
(uri_column_names, output_bytes_column_names, data_context, filesystem),
112113
{},
113114
None,
114115
None,
@@ -167,10 +168,22 @@ def download_bytes_threaded(
167168
uri_column_names: List[str],
168169
output_bytes_column_names: List[str],
169170
data_context: DataContext,
171+
filesystem: Optional["pa.fs.FileSystem"] = None,
170172
) -> Iterator[pa.Table]:
171173
"""Optimized version that uses make_async_gen for concurrent downloads.
172174
173175
Supports downloading from multiple URI columns in a single operation.
176+
177+
Args:
178+
block: Input PyArrow table containing URI columns.
179+
uri_column_names: Names of columns containing URIs to download.
180+
output_bytes_column_names: Names for the output columns containing downloaded bytes.
181+
data_context: Ray Data context for configuration.
182+
filesystem: PyArrow filesystem to use for reading remote files.
183+
If None, the filesystem is auto-detected from the path scheme.
184+
185+
Yields:
186+
pa.Table: PyArrow table with the downloaded bytes added as new columns.
174187
"""
175188
if not isinstance(block, pa.Table):
176189
block = BlockAccessor.for_block(block).to_arrow()
@@ -192,8 +205,9 @@ def load_uri_bytes(uri_iterator):
192205
193206
Takes an iterator of URIs and yields bytes for each.
194207
Uses lazy filesystem resolution - resolves once and reuses for subsequent URIs.
208+
If a filesystem was provided explicitly, it will be used for all URIs.
195209
"""
196-
cached_fs = None
210+
cached_fs = filesystem
197211
for uri in uri_iterator:
198212
read_bytes = None
199213
try:
@@ -267,9 +281,15 @@ class PartitionActor:
267281

268282
INIT_SAMPLE_BATCH_SIZE = 25
269283

270-
def __init__(self, uri_column_names: List[str], data_context: DataContext):
284+
def __init__(
285+
self,
286+
uri_column_names: List[str],
287+
data_context: DataContext,
288+
filesystem: Optional["pa.fs.FileSystem"] = None,
289+
):
271290
self._uri_column_names = uri_column_names
272291
self._data_context = data_context
292+
self._filesystem = filesystem
273293
self._batch_size_estimate = None
274294

275295
def __call__(self, block: pa.Table) -> Iterator[pa.Table]:
@@ -345,7 +365,7 @@ def get_file_size(uri_path, fs):
345365
# Get the filesystem from the URIs (assumes all URIs use same filesystem for sampling)
346366
# This is for sampling the file sizes which doesn't require a full resolution of the paths.
347367
try:
348-
paths, fs = _resolve_paths_and_filesystem(uris)
368+
paths, fs = _resolve_paths_and_filesystem(uris, filesystem=self._filesystem)
349369
fs = RetryingPyFileSystem.wrap(
350370
fs, retryable_errors=self._data_context.retried_io_errors
351371
)

python/ray/data/dataset.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -896,6 +896,7 @@ def with_column(
896896
uri_column_names=[expr.uri_column_name],
897897
output_bytes_column_names=[column_name],
898898
ray_remote_args=ray_remote_args,
899+
filesystem=expr.filesystem,
899900
)
900901
logical_plan = LogicalPlan(download_op, self.context)
901902
else:

python/ray/data/expressions.py

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -821,14 +821,39 @@ class _CallableClassSpec:
821821
cls: The original callable class type
822822
args: Positional arguments for the constructor
823823
kwargs: Keyword arguments for the constructor
824+
_cached_key: Pre-computed key that survives serialization
824825
"""
825826

826827
cls: type
827828
args: Tuple[Any, ...] = ()
828829
kwargs: Dict[str, Any] = field(default_factory=dict)
830+
_cached_key: Optional[Tuple] = field(default=None, compare=False, repr=False)
831+
832+
def __post_init__(self):
833+
"""Pre-compute and cache the key at construction time.
834+
835+
This ensures the same key survives serialization, since the cached
836+
key tuple (containing the already-computed repr strings) gets pickled
837+
and unpickled as-is.
838+
"""
839+
if self._cached_key is None:
840+
class_id = f"{self.cls.__module__}.{self.cls.__qualname__}"
841+
try:
842+
key = (
843+
class_id,
844+
self.args,
845+
tuple(sorted(self.kwargs.items())),
846+
)
847+
# Verify the key is actually hashable (args may contain lists)
848+
hash(key)
849+
except TypeError:
850+
# Fallback for unhashable args/kwargs - use repr for comparison
851+
key = (class_id, repr(self.args), repr(self.kwargs))
852+
# Use object.__setattr__ since dataclass is frozen
853+
object.__setattr__(self, "_cached_key", key)
829854

830855
def make_key(self) -> Tuple:
831-
"""Create a hashable key for UDF instance lookup.
856+
"""Return the pre-computed hashable key for UDF instance lookup.
832857
833858
The key uniquely identifies a UDF by its class and constructor arguments.
834859
This ensures that the same class with different constructor args
@@ -837,18 +862,7 @@ def make_key(self) -> Tuple:
837862
Returns:
838863
A hashable tuple that uniquely identifies this UDF configuration.
839864
"""
840-
try:
841-
key = (
842-
id(self.cls),
843-
self.args,
844-
tuple(sorted(self.kwargs.items())),
845-
)
846-
# Verify the key is actually hashable (args may contain lists)
847-
hash(key)
848-
return key
849-
except TypeError:
850-
# Fallback for unhashable args/kwargs - use repr for comparison
851-
return (id(self.cls), repr(self.args), repr(self.kwargs))
865+
return self._cached_key
852866

853867

854868
class _CallableClassUDF:
@@ -1304,6 +1318,7 @@ class DownloadExpr(Expr):
13041318
"""Expression that represents a download operation."""
13051319

13061320
uri_column_name: str
1321+
filesystem: "pyarrow.fs.FileSystem" = None
13071322
data_type: DataType = field(default_factory=lambda: DataType.binary(), init=False)
13081323

13091324
def structurally_equals(self, other: Any) -> bool:
@@ -1448,7 +1463,11 @@ def star() -> StarExpr:
14481463

14491464

14501465
@PublicAPI(stability="alpha")
1451-
def download(uri_column_name: str) -> DownloadExpr:
1466+
def download(
1467+
uri_column_name: str,
1468+
*,
1469+
filesystem: Optional["pyarrow.fs.FileSystem"] = None,
1470+
) -> DownloadExpr:
14521471
"""
14531472
Create a download expression that downloads content from URIs.
14541473
@@ -1458,6 +1477,8 @@ def download(uri_column_name: str) -> DownloadExpr:
14581477
14591478
Args:
14601479
uri_column_name: The name of the column containing URIs to download from
1480+
filesystem: PyArrow filesystem to use for reading remote files.
1481+
If None, the filesystem is auto-detected from the path scheme.
14611482
Returns:
14621483
A DownloadExpr that will download content from the specified URI column
14631484
@@ -1472,7 +1493,7 @@ def download(uri_column_name: str) -> DownloadExpr:
14721493
>>> # Add downloaded bytes column
14731494
>>> ds_with_bytes = ds.with_column("bytes", download("uri"))
14741495
"""
1475-
return DownloadExpr(uri_column_name=uri_column_name)
1496+
return DownloadExpr(uri_column_name=uri_column_name, filesystem=filesystem)
14761497

14771498

14781499
# ──────────────────────────────────────

python/ray/data/tests/test_download_expression.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,39 @@ def test_download_expression_with_pandas_blocks(self, tmp_path):
250250
finally:
251251
ctx.enable_pandas_block = old_enable_pandas_block
252252

253+
def test_download_expression_with_custom_filesystem(self, tmp_path):
254+
import pyarrow.fs as pafs
255+
256+
# 1. Setup paths
257+
subdir = tmp_path / "data"
258+
subdir.mkdir()
259+
260+
file_name = "test_file.txt"
261+
file_path = subdir / file_name
262+
sample_content = b"File content with custom fs"
263+
file_path.write_bytes(sample_content)
264+
265+
# 2. Setup SubTreeFileSystem
266+
# This treats 'subdir' as the root '/'
267+
base_fs = pafs.LocalFileSystem()
268+
custom_fs = pafs.SubTreeFileSystem(str(subdir), base_fs)
269+
270+
# 3. Create Dataset
271+
# Note: We use the relative 'file_name' because the FS is rooted at 'subdir'
272+
ds = ray.data.from_items([{"file_uri": file_name, "file_id": 0}])
273+
274+
# 4. Execute Download
275+
ds_with_downloads = ds.with_column(
276+
"content", download("file_uri", filesystem=custom_fs)
277+
)
278+
279+
# 5. Assertions
280+
results = ds_with_downloads.take_all()
281+
282+
assert len(results) == 1
283+
assert results[0]["content"] == sample_content
284+
assert results[0]["file_id"] == 0
285+
253286

254287
class TestDownloadExpressionErrors:
255288
"""Test error conditions and edge cases for download expressions."""

0 commit comments

Comments (0)