Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions python/ray/data/_internal/iterator/stream_split_iterator.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@scottjlee do you know if this would remove the need for #40116?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it's worth double-checking.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that PR should still be needed. It propagates DataContext from the driver to trainer workers. While this PR propagates DataContext from the trainer workers to the split coordinator actor.

Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,9 @@ def __init__(
dataset.context.execution_options.locality_with_output = locality_hints
logger.info(f"Auto configuring locality_with_output={locality_hints}")

# Set current DataContext.
ray.data.DataContext._set_current(dataset.context)

self._base_dataset = dataset
self._n = n
self._equal = equal
Expand Down
26 changes: 26 additions & 0 deletions python/ray/data/tests/test_streaming_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,32 @@ def f(x):
ds2.take_all()


def test_streaming_split_with_custom_data_context(
    ray_start_10_cpus_shared, restore_data_context
):
    """Verify that a customized ``DataContext`` (both a regular attribute and a
    ``set_config`` entry) is visible inside map tasks when the dataset is
    consumed through ``streaming_split()``.
    """
    custom_block_size = 123 * 1024 * 1024
    ctx = DataContext.get_current()
    ctx.target_max_block_size = custom_block_size
    ctx.set_config("foo", "bar")

    def check_ctx(batch):
        # Runs on worker processes; the driver's context customizations
        # must have been propagated here.
        current = DataContext.get_current()
        assert current.target_max_block_size == custom_block_size
        assert current.get_config("foo") == "bar"
        return batch

    n_splits = 2
    iterators = (
        ray.data.range(10, parallelism=10).map(check_ctx).streaming_split(n_splits)
    )

    @ray.remote
    def drain(split_iter):
        # Exhaust the split; returns None implicitly on success.
        for _ in split_iter.iter_rows():
            pass

    results = ray.get([drain.remote(it) for it in iterators])
    assert results == [None] * n_splits


if __name__ == "__main__":
import sys

Expand Down