Skip to content
2 changes: 1 addition & 1 deletion python/ray/air/tests/test_new_dataset_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ class MyTrainer(DataParallelTrainer):
def __init__(self, **kwargs):
def train_loop_fn():
train_ds = train.get_dataset_shard("train")
new_execution_options = train_ds._base_dataset.context.execution_options
new_execution_options = train_ds.get_context().execution_options
if original_execution_options.is_resource_limits_default():
# If the original resource limits are default, the new resource
# limits should be the default as well.
Expand Down
33 changes: 24 additions & 9 deletions python/ray/data/_internal/iterator/stream_split_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,18 +49,14 @@ def create(
),
).remote(base_dataset, n, locality_hints)

return [
StreamSplitDataIterator(base_dataset, coord_actor, i, n) for i in range(n)
]
return [StreamSplitDataIterator(coord_actor, i, n) for i in range(n)]

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to confirm -- this StreamSplitDataIterator was being serialized 2x?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, each worker only gets one iterator so it's only serialized once. The other copy of the dataset object came from the TrainRunContext which was removed in #61953


def __init__(
self,
base_dataset: "Dataset",
coord_actor: ray.actor.ActorHandle,
output_split_idx: int,
world_size: int,
):
self._base_dataset = base_dataset
self._coord_actor = coord_actor
self._output_split_idx = output_split_idx
self._world_size = world_size
Expand Down Expand Up @@ -100,7 +96,6 @@ def gen_blocks() -> Iterator[RefBundle]:
schema=block_ref_and_md.schema,
)

self._base_dataset._plan._run_index += 1
# Return None for executor since StreamSplitDataIterator has its own
# mechanism for reporting prefetched bytes via SplitCoordinator.
return gen_blocks(), self._iter_stats, False, None
Expand All @@ -119,17 +114,17 @@ def stats(self) -> str:

def schema(self) -> Union[type, "pyarrow.lib.Schema"]:
Comment thread
JasonLi1909 marked this conversation as resolved.
Outdated
"""Implements DataIterator."""
return self._base_dataset.schema()
return ray.get(self._coord_actor.get_dataset_schema.remote())
Comment thread
JasonLi1909 marked this conversation as resolved.
Outdated
Comment thread
cursor[bot] marked this conversation as resolved.
Comment thread
cursor[bot] marked this conversation as resolved.

def get_context(self) -> DataContext:
return self._base_dataset.context
return ray.get(self._coord_actor.get_dataset_context.remote())

def world_size(self) -> int:
"""Returns the number of splits total."""
return self._world_size

def _get_dataset_tag(self):
return f"{self._base_dataset.get_dataset_id()}_split_{self._output_split_idx}"
return ray.get(self._coord_actor.get_dataset_tag.remote(self._output_split_idx))


@ray.remote(num_cpus=0)
Expand Down Expand Up @@ -157,6 +152,8 @@ def __init__(
self._n = n
self._locality_hints = locality_hints
self._lock = threading.RLock()
self._dataset_state_lock = threading.Lock()
self._schema = None
self._current_executor = None

# Guarded by self._lock.
Expand All @@ -175,6 +172,24 @@ def __init__(
# Store the error raised from the `gen_epoch` call.
self._gen_epoch_error: Optional[Exception] = None

def get_dataset_context(self) -> DataContext:
return self._data_context

def get_dataset_tag(self, output_split_idx: int) -> str:
return f"{self._base_dataset.get_dataset_id()}_split_{output_split_idx}"
Comment thread
cursor[bot] marked this conversation as resolved.

def get_dataset_schema(self):
with self._dataset_state_lock:
if self._current_executor is not None and self._current_executor.is_alive():

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I read the PR description but still a bit confused -- Are you implying there are scenarios where the schema is None, but current executor is not None, and that's why we need this guard? Like, can the executor be running when there is no schema produced yet? I think adding a comment about the if guard would be very helpful for future readers

@JasonLi1909 JasonLi1909 Mar 27, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self._schema is for caching. If get_dataset_schema was not called previously, then it is very possible for it to be empty when called during execution. This guard is to primary to prevent two executions on the same dataset- which can lead to deadlock. Will update pr description to be more clear

raise RuntimeError(
"Cannot call schema() during active dataset execution. "
Comment on lines +185 to +186

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you try running this without raising the error to see what happens? and add the output to the PR description appendix

@JasonLi1909 JasonLi1909 Mar 27, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ran without the safe guard, it can lead to a hang because the second executor will try to schedule tasks but won't have any resources available, added to PR description.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for checking

"Call schema() before or after iterating over the dataset, or call "
"schema() directly on the source Dataset object."
)
if self._schema is None:
self._schema = self._base_dataset.schema()
Comment thread
JasonLi1909 marked this conversation as resolved.
Outdated
return self._schema
Comment thread
cursor[bot] marked this conversation as resolved.
Comment thread
cursor[bot] marked this conversation as resolved.

def stats(self) -> DatasetStats:
"""Returns stats from the base dataset."""
if self._current_executor:
Expand Down
89 changes: 89 additions & 0 deletions python/ray/data/tests/test_streaming_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,95 @@ def consume(self, split):
assert res == ["ok"] * num_splits


def test_streaming_split_schema_before_execution(ray_start_10_cpus_shared):
"""Test schema retrieval from splits before execution starts."""
ds = ray.data.range(20, override_num_blocks=20)
i1, i2 = ds.streaming_split(2, equal=True)

schema1 = i1.schema()
schema2 = i2.schema()

assert schema1 is not None
assert "id" in schema1.names
assert schema1 == schema2


def test_streaming_split_schema_during_execution(ray_start_10_cpus_shared):
"""Test schema retrieval from splits during execution."""
from ray._common.test_utils import SignalActor

# Use two signals to coordinate: `started` confirms the executor is running,
# `blocker` keeps map tasks alive so the executor stays active.
started = SignalActor.remote()
blocker = SignalActor.remote()

def blocking_fn(row):
ray.get(started.send.remote())
ray.get(blocker.wait.remote())
return row

ds = ray.data.range(20, override_num_blocks=20).map(blocking_fn)
i1, i2 = ds.streaming_split(2, equal=True)

@ray.remote
def consume(x):
for _ in x.iter_rows():
pass

# Start consumers — this triggers the executor on the coordinator.
refs = [consume.remote(i1), consume.remote(i2)]

# Wait until a map task has started, guaranteeing the executor is alive.
ray.get(started.wait.remote())

# schema() should raise because execution is active.
with pytest.raises(ray.exceptions.RayTaskError, match="Cannot call schema()"):
i1.schema()

# Unblock map tasks so consumers can finish.
ray.get(blocker.send.remote())
ray.get(refs)


def test_streaming_split_schema_after_execution(ray_start_10_cpus_shared):
"""Test schema retrieval after execution completes."""
ds = ray.data.range(20, override_num_blocks=20)
i1, i2 = ds.streaming_split(2, equal=True)

@ray.remote
def consume(x):
for _ in x.iter_rows():
pass

# Run a full epoch to completion.
ray.get([consume.remote(i1), consume.remote(i2)])

# schema() should work after execution finishes.
schema = i1.schema()
assert schema is not None
assert "id" in schema.names


def test_streaming_split_context(ray_start_10_cpus_shared):
"""Test that get_context() returns a valid DataContext from the coordinator."""
ds = ray.data.range(10)
i1, i2 = ds.streaming_split(2, equal=True)

ctx = i1.get_context()
assert isinstance(ctx, ray.data.DataContext)


def test_streaming_split_dataset_tag(ray_start_10_cpus_shared):
"""Test that _get_dataset_tag() returns correct tags from the coordinator."""
ds = ray.data.range(10)
i1, i2 = ds.streaming_split(2, equal=True)

tag1 = i1._get_dataset_tag()
tag2 = i2._get_dataset_tag()
assert "_split_0" in tag1
assert "_split_1" in tag2


def test_configure_spread_e2e(ray_start_10_cpus_shared, restore_data_context):
from ray import remote_function

Expand Down
12 changes: 6 additions & 6 deletions python/ray/train/v2/tests/test_data_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,11 @@ def test_datasets_callback(ray_start_4_cpus):
# Under the V2 cluster autoscaler (default), the scaling policy registers training resources
# with the AutoscalingCoordinator, so exclude_resources should not be set.
assert (
processed_train_ds._base_dataset.context.execution_options.exclude_resources
processed_train_ds.get_context().execution_options.exclude_resources
== ExecutionResources.zero()
)
assert (
processed_valid_ds._base_dataset.context.execution_options.exclude_resources
processed_valid_ds.get_context().execution_options.exclude_resources
== ExecutionResources.zero()
)

Expand Down Expand Up @@ -666,11 +666,11 @@ def test_datasets_callback_v1_uses_exclude_resources(ray_start_4_cpus, monkeypat

# Under the V1 cluster autoscaler, exclude_resources should be set with training resources.
assert (
processed_train_ds._base_dataset.context.execution_options.exclude_resources
processed_train_ds.get_context().execution_options.exclude_resources
== ExecutionResources(cpu=NUM_WORKERS, gpu=NUM_WORKERS)
)
assert (
processed_valid_ds._base_dataset.context.execution_options.exclude_resources
processed_valid_ds.get_context().execution_options.exclude_resources
== ExecutionResources(cpu=NUM_WORKERS, gpu=NUM_WORKERS)
)

Expand Down Expand Up @@ -729,11 +729,11 @@ def test_v2_no_negative_exclude_resources(ray_start_4_cpus):
# Under the V2 cluster autoscaler (default), exclude_resources should be
# zero regardless of how many training resources are reserved.
assert (
processed_train_ds._base_dataset.context.execution_options.exclude_resources
processed_train_ds.get_context().execution_options.exclude_resources
== ExecutionResources.zero()
)
assert (
processed_valid_ds._base_dataset.context.execution_options.exclude_resources
processed_valid_ds.get_context().execution_options.exclude_resources
== ExecutionResources.zero()
)

Expand Down
Loading