From 584b14ca29ebda8b75df8c8861a487391de734f4 Mon Sep 17 00:00:00 2001 From: JasonLi1909 Date: Mon, 9 Mar 2026 16:10:25 -0700 Subject: [PATCH 01/11] [Data] Add get_dataset_schema to SplitCoordinator Move schema() resolution from StreamSplitDataIterator to the SplitCoordinator actor, which already holds the dataset. This avoids accessing _base_dataset directly from the iterator for schema calls, and adds thread-safe caching with a guard against schema resolution during active execution. Co-Authored-By: Claude Opus 4.6 Signed-off-by: JasonLi1909 --- .../iterator/stream_split_iterator.py | 16 +++++- .../data/tests/test_streaming_integration.py | 56 +++++++++++++------ 2 files changed, 54 insertions(+), 18 deletions(-) diff --git a/python/ray/data/_internal/iterator/stream_split_iterator.py b/python/ray/data/_internal/iterator/stream_split_iterator.py index f94efe2e71b9..5f9bca7f30df 100644 --- a/python/ray/data/_internal/iterator/stream_split_iterator.py +++ b/python/ray/data/_internal/iterator/stream_split_iterator.py @@ -110,7 +110,7 @@ def stats(self) -> str: def schema(self) -> Union[type, "pyarrow.lib.Schema"]: """Implements DataIterator.""" - return self._base_dataset.schema() + return ray.get(self._coord_actor.get_dataset_schema.remote()) def get_context(self) -> DataContext: return self._base_dataset.context @@ -152,6 +152,8 @@ def __init__( self._n = n self._locality_hints = locality_hints self._lock = threading.RLock() + self._dataset_state_lock = threading.Lock() + self._schema = None self._executor = None # Guarded by self._lock. @@ -175,6 +177,18 @@ def gen_epochs(): # Store the error raised from the `gen_epoch` call. self._gen_epoch_error: Optional[Exception] = None + def get_dataset_schema(self): + with self._dataset_state_lock: + if self._executor is not None: + raise RuntimeError( + "Cannot call schema() during active dataset execution. " + "Call schema() before iterating over the dataset, or call " + "schema() directly on the source Dataset object." + ) + if self._schema is None: + self._schema = self._base_dataset.schema() + return self._schema + def stats(self) -> DatasetStats: """Returns stats from the base dataset.""" if self._executor: diff --git a/python/ray/data/tests/test_streaming_integration.py b/python/ray/data/tests/test_streaming_integration.py index c494f0f7f77d..ce5b45605c7f 100644 --- a/python/ray/data/tests/test_streaming_integration.py +++ b/python/ray/data/tests/test_streaming_integration.py @@ -11,7 +11,7 @@ import ray from ray import cloudpickle from ray._common.test_utils import wait_for_condition -from ray.data._internal.execution.interfaces import ExecutionResources, RefBundle +from ray.data._internal.execution.interfaces import RefBundle from ray.data._internal.execution.operators.base_physical_operator import ( AllToAllOperator, ) @@ -387,24 +387,46 @@ def consume(self, split): assert res == ["ok"] * num_splits -@pytest.mark.skip( - reason="Incomplete implementation of _validate_dag causes other errors, so we " - "remove DAG validation for now; see https://github.com/ray-project/ray/pull/37829" -) -def test_e2e_option_propagation(ray_start_10_cpus_shared, restore_data_context): - def run(): - ray.data.range(5, override_num_blocks=5).map( - lambda x: x, compute=ray.data.ActorPoolStrategy(size=2) - ).take_all() +def test_streaming_split_schema_before_execution(ray_start_10_cpus_shared): + """Test schema retrieval from splits before execution starts.""" + ds = ray.data.range(20, override_num_blocks=20) + i1, i2 = ds.streaming_split(2, equal=True) - DataContext.get_current().execution_options.resource_limits = ExecutionResources() - run() + import concurrent.futures - DataContext.get_current().execution_options.resource_limits = ( - DataContext.get_current().execution_options.resource_limits.copy(cpu=1) - ) - with pytest.raises(ValueError): - run() + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor: + f1 = executor.submit(i1.schema) + f2 = executor.submit(i2.schema) + schema1 = f1.result() + schema2 = f2.result() + + assert schema1 is not None + assert "id" in schema1.names + assert schema1 == schema2 + + +def test_streaming_split_schema_during_execution(ray_start_10_cpus_shared): + """Test schema retrieval from splits during execution.""" + ds = ray.data.range(20, override_num_blocks=20) + i1, i2 = ds.streaming_split(2, equal=True) + + @ray.remote + def consume(x): + for _ in x.iter_rows(): + pass + + # Start iteration on both splits. + refs = [consume.remote(i1), consume.remote(i2)] + + # Give iteration time to start and create the executor. + time.sleep(2) + + # schema() should raise because execution is active. + with pytest.raises(ray.exceptions.RayTaskError, match="Cannot call schema()"): + i1.schema() + + # Let consumers finish. + ray.get(refs) def test_configure_spread_e2e(ray_start_10_cpus_shared, restore_data_context): From f83c3640b379498ea3e9b69198b7d27b3eb7cd31 Mon Sep 17 00:00:00 2001 From: JasonLi1909 Date: Mon, 9 Mar 2026 16:38:27 -0700 Subject: [PATCH 02/11] [Data] Remove _base_dataset from StreamSplitDataIterator Move remaining _base_dataset usages (get_context, _get_dataset_tag) to SplitCoordinator and remove the redundant client-side _run_index increment. This fully decouples StreamSplitDataIterator from direct dataset access. Co-Authored-By: Claude Opus 4.6 --- .../iterator/stream_split_iterator.py | 16 ++++++++------- .../data/tests/test_streaming_integration.py | 20 +++++++++++++++++++ 2 files changed, 29 insertions(+), 7 deletions(-) diff --git a/python/ray/data/_internal/iterator/stream_split_iterator.py b/python/ray/data/_internal/iterator/stream_split_iterator.py index 5f9bca7f30df..b78915337be8 100644 --- a/python/ray/data/_internal/iterator/stream_split_iterator.py +++ b/python/ray/data/_internal/iterator/stream_split_iterator.py @@ -53,18 +53,14 @@ def create( ), ).remote(_DatasetWrapper(base_dataset), n, locality_hints) - return [ - StreamSplitDataIterator(base_dataset, coord_actor, i, n) for i in range(n) - ] + return [StreamSplitDataIterator(coord_actor, i, n) for i in range(n)] def __init__( self, - base_dataset: "Dataset", coord_actor: ray.actor.ActorHandle, output_split_idx: int, world_size: int, ): - self._base_dataset = base_dataset self._coord_actor = coord_actor self._output_split_idx = output_split_idx self._world_size = world_size @@ -113,14 +109,14 @@ def schema(self) -> Union[type, "pyarrow.lib.Schema"]: return ray.get(self._coord_actor.get_dataset_schema.remote()) def get_context(self) -> DataContext: - return self._base_dataset.context + return ray.get(self._coord_actor.get_dataset_context.remote()) def world_size(self) -> int: """Returns the number of splits total.""" return self._world_size def _get_dataset_tag(self): - return f"{self._base_dataset.get_dataset_id()}_split_{self._output_split_idx}" + return ray.get(self._coord_actor.get_dataset_tag.remote(self._output_split_idx)) @ray.remote(num_cpus=0) @@ -177,6 +173,12 @@ def gen_epochs(): # Store the error raised from the `gen_epoch` call. self._gen_epoch_error: Optional[Exception] = None + def get_dataset_context(self) -> DataContext: + return self._data_context + + def get_dataset_tag(self, output_split_idx: int) -> str: + return f"{self._base_dataset.get_dataset_id()}_split_{output_split_idx}" + def get_dataset_schema(self): with self._dataset_state_lock: if self._executor is not None: diff --git a/python/ray/data/tests/test_streaming_integration.py b/python/ray/data/tests/test_streaming_integration.py index ce5b45605c7f..3ffdc995f1fd 100644 --- a/python/ray/data/tests/test_streaming_integration.py +++ b/python/ray/data/tests/test_streaming_integration.py @@ -429,6 +429,26 @@ def consume(x): ray.get(refs) +def test_streaming_split_context(ray_start_10_cpus_shared): + """Test that get_context() returns a valid DataContext from the coordinator.""" + ds = ray.data.range(10) + i1, i2 = ds.streaming_split(2, equal=True) + + ctx = i1.get_context() + assert isinstance(ctx, ray.data.DataContext) + + +def test_streaming_split_dataset_tag(ray_start_10_cpus_shared): + """Test that _get_dataset_tag() returns correct tags from the coordinator.""" + ds = ray.data.range(10) + i1, i2 = ds.streaming_split(2, equal=True) + + tag1 = i1._get_dataset_tag() + tag2 = i2._get_dataset_tag() + assert "_split_0" in tag1 + assert "_split_1" in tag2 + + def test_configure_spread_e2e(ray_start_10_cpus_shared, restore_data_context): from ray import remote_function From 993cc04ae78a21bb4c8fa983ff9a08da7f7c1478 Mon Sep 17 00:00:00 2001 From: JasonLi1909 Date: Sun, 22 Mar 2026 01:37:26 -0700 Subject: [PATCH 03/11] updated get_dataset_schema condition to succeed after schema execution and updated tests Signed-off-by: JasonLi1909 --- .../iterator/stream_split_iterator.py | 4 +- .../data/tests/test_streaming_integration.py | 41 +++++++++++-------- 2 files changed, 25 insertions(+), 20 deletions(-) diff --git a/python/ray/data/_internal/iterator/stream_split_iterator.py b/python/ray/data/_internal/iterator/stream_split_iterator.py index b78915337be8..14c8ea6005f0 100644 --- a/python/ray/data/_internal/iterator/stream_split_iterator.py +++ b/python/ray/data/_internal/iterator/stream_split_iterator.py @@ -181,10 +181,10 @@ def get_dataset_tag(self, output_split_idx: int) -> str: def get_dataset_schema(self): with self._dataset_state_lock: - if self._executor is not None: + if self._executor is not None and self._executor.is_alive(): raise RuntimeError( "Cannot call schema() during active dataset execution. " - "Call schema() before iterating over the dataset, or call " + "Call schema() before or after iterating over the dataset, or call " "schema() directly on the source Dataset object." ) if self._schema is None: diff --git a/python/ray/data/tests/test_streaming_integration.py b/python/ray/data/tests/test_streaming_integration.py index 3ffdc995f1fd..1611fed08624 100644 --- a/python/ray/data/tests/test_streaming_integration.py +++ b/python/ray/data/tests/test_streaming_integration.py @@ -392,13 +392,8 @@ def test_streaming_split_schema_before_execution(ray_start_10_cpus_shared): ds = ray.data.range(20, override_num_blocks=20) i1, i2 = ds.streaming_split(2, equal=True) - import concurrent.futures - - with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor: - f1 = executor.submit(i1.schema) - f2 = executor.submit(i2.schema) - schema1 = f1.result() - schema2 = f2.result() + schema1 = i1.schema() + schema2 = i2.schema() assert schema1 is not None assert "id" in schema1.names @@ -411,22 +406,32 @@ def test_streaming_split_schema_during_execution(ray_start_10_cpus_shared): i1, i2 = ds.streaming_split(2, equal=True) @ray.remote - def consume(x): + def consume_and_check_schema(x): for _ in x.iter_rows(): - pass + with pytest.raises(RuntimeError, match="Cannot call schema()"): + x.schema() + break - # Start iteration on both splits. - refs = [consume.remote(i1), consume.remote(i2)] + ray.get([consume_and_check_schema.remote(i1), consume_and_check_schema.remote(i2)]) - # Give iteration time to start and create the executor. - time.sleep(2) - # schema() should raise because execution is active. - with pytest.raises(ray.exceptions.RayTaskError, match="Cannot call schema()"): - i1.schema() +def test_streaming_split_schema_after_execution(ray_start_10_cpus_shared): + """Test schema retrieval after execution completes.""" + ds = ray.data.range(20, override_num_blocks=20) + i1, i2 = ds.streaming_split(2, equal=True) + + @ray.remote + def consume(x): + for _ in x.iter_rows(): + pass + + # Run a full epoch to completion. + ray.get([consume.remote(i1), consume.remote(i2)]) - # Let consumers finish. - ray.get(refs) + # schema() should work after execution finishes. + schema = i1.schema() + assert schema is not None + assert "id" in schema.names def test_streaming_split_context(ray_start_10_cpus_shared): From eaadbf1387ce649f951bb9d849e23b95b721e10c Mon Sep 17 00:00:00 2001 From: JasonLi1909 Date: Thu, 26 Mar 2026 13:13:59 -0700 Subject: [PATCH 04/11] fix tests Signed-off-by: JasonLi1909 --- .../train/v2/tests/test_data_integration.py | 398 +++++++++++++++++- 1 file changed, 394 insertions(+), 4 deletions(-) diff --git a/python/ray/train/v2/tests/test_data_integration.py b/python/ray/train/v2/tests/test_data_integration.py index 11db7bd62807..26e66ac0f94a 100644 --- a/python/ray/train/v2/tests/test_data_integration.py +++ b/python/ray/train/v2/tests/test_data_integration.py @@ -1,3 +1,4 @@ +import asyncio from unittest.mock import MagicMock import pytest @@ -131,12 +132,12 @@ def test_dataset_setup_callback(ray_start_4_cpus): # The callback should have excluded the resources reserved for training. assert ( - processed_train_ds._base_dataset.context.execution_options.exclude_resources - == ExecutionResources(cpu=NUM_WORKERS, gpu=NUM_WORKERS) + processed_train_ds.get_context().execution_options.exclude_resources + == ExecutionResources.zero() ) assert ( - processed_valid_ds._base_dataset.context.execution_options.exclude_resources - == ExecutionResources(cpu=NUM_WORKERS, gpu=NUM_WORKERS) + processed_valid_ds.get_context().execution_options.exclude_resources + == ExecutionResources.zero() ) @@ -293,6 +294,395 @@ def check_resource_limits(config): trainer.fit() +def test_per_dataset_execution_options_single(ray_start_4_cpus): + """Test that a single ExecutionOptions object applies to all datasets.""" + NUM_ROWS = 100 + NUM_WORKERS = 2 + + train_ds = ray.data.range(NUM_ROWS) + val_ds = ray.data.range(NUM_ROWS) + + # Create execution options with specific settings + execution_options = ExecutionOptions() + execution_options.preserve_order = True + execution_options.verbose_progress = True + + data_config = ray.train.DataConfig(execution_options=execution_options) + + def train_fn(): + train_shard = ray.train.get_dataset_shard("train") + val_shard = ray.train.get_dataset_shard("val") + + # Verify both datasets have the same execution options + assert train_shard.get_context().execution_options.preserve_order is True + assert train_shard.get_context().execution_options.verbose_progress is True + assert val_shard.get_context().execution_options.preserve_order is True + assert val_shard.get_context().execution_options.verbose_progress is True + + trainer = DataParallelTrainer( + train_fn, + datasets={"train": train_ds, "val": val_ds}, + dataset_config=data_config, + scaling_config=ray.train.ScalingConfig(num_workers=NUM_WORKERS), + ) + trainer.fit() + + +def test_per_dataset_execution_options_dict(ray_start_4_cpus): + """Test that a dict of ExecutionOptions maps to specific datasets, and datasets not in the dict get default ingest options. Also tests resource limits.""" + NUM_ROWS = 100 + NUM_WORKERS = 2 + + train_ds = ray.data.range(NUM_ROWS) + val_ds = ray.data.range(NUM_ROWS) + test_ds = ray.data.range(NUM_ROWS) + test_ds_2 = ray.data.range(NUM_ROWS) + + # Create different execution options for different datasets + train_options = ExecutionOptions() + train_options.preserve_order = True + train_options.verbose_progress = True + train_options.resource_limits = train_options.resource_limits.copy(cpu=4, gpu=2) + + val_options = ExecutionOptions() + val_options.preserve_order = False + val_options.verbose_progress = False + val_options.resource_limits = val_options.resource_limits.copy(cpu=2, gpu=1) + + execution_options_dict = { + "train": train_options, + "val": val_options, + } + + data_config = ray.train.DataConfig(execution_options=execution_options_dict) + + def train_fn(): + train_shard = ray.train.get_dataset_shard("train") + val_shard = ray.train.get_dataset_shard("val") + test_shard = ray.train.get_dataset_shard("test") + test_shard_2 = ray.train.get_dataset_shard("test_2") + + # Verify each dataset in the dict gets its specific options + assert train_shard.get_context().execution_options.preserve_order is True + assert train_shard.get_context().execution_options.verbose_progress is True + assert val_shard.get_context().execution_options.preserve_order is False + assert val_shard.get_context().execution_options.verbose_progress is False + + # Verify resource limits + assert train_shard.get_context().execution_options.resource_limits.cpu == 4 + assert train_shard.get_context().execution_options.resource_limits.gpu == 2 + assert val_shard.get_context().execution_options.resource_limits.cpu == 2 + assert val_shard.get_context().execution_options.resource_limits.gpu == 1 + + # Verify dataset not in the dict gets default options + assert ( + test_shard.get_context().execution_options.preserve_order + == test_shard_2.get_context().execution_options.preserve_order + ) + assert ( + test_shard.get_context().execution_options.verbose_progress + == test_shard_2.get_context().execution_options.verbose_progress + ) + assert ( + test_shard.get_context().execution_options.resource_limits.cpu + == test_shard_2.get_context().execution_options.resource_limits.cpu + ) + assert ( + test_shard.get_context().execution_options.resource_limits.gpu + == test_shard_2.get_context().execution_options.resource_limits.gpu + ) + + trainer = DataParallelTrainer( + train_fn, + datasets={ + "train": train_ds, + "val": val_ds, + "test": test_ds, + "test_2": test_ds_2, + }, + dataset_config=data_config, + scaling_config=ray.train.ScalingConfig(num_workers=NUM_WORKERS), + ) + trainer.fit() + + +def test_exclude_train_resources_applies_to_each_dataset(ray_start_4_cpus): + """Test that user-defined per-dataset exclude_resources are preserved. + Under the V2 cluster autoscaler (default), training resources are NOT added + to exclude_resources (they are handled by the AutoscalingCoordinator), so + only the user-defined values should appear.""" + NUM_ROWS = 100 + NUM_WORKERS = 2 + + # Create different execution options for different datasets + train_options = ExecutionOptions() + train_options.exclude_resources = train_options.exclude_resources.copy(cpu=2, gpu=1) + + test_options = ExecutionOptions() + test_options.exclude_resources = test_options.exclude_resources.copy(cpu=1, gpu=0) + + # val dataset not in dict, should get default options + execution_options_dict = { + "train": train_options, + "test": test_options, + } + data_config = ray.train.DataConfig(execution_options=execution_options_dict) + + def train_fn(): + # Under the V2 cluster autoscaler, only user-defined exclude_resources + # should be present. Training resources are NOT added to exclude_resources. + + # Check train dataset — only user-defined exclude_resources + train_ds = ray.train.get_dataset_shard("train") + train_exec_options = train_ds.get_context().execution_options + assert train_exec_options.is_resource_limits_default() + assert train_exec_options.exclude_resources.cpu == 2 + assert train_exec_options.exclude_resources.gpu == 1 + + # Check test dataset — only user-defined exclude_resources + test_ds = ray.train.get_dataset_shard("test") + test_exec_options = test_ds.get_context().execution_options + assert test_exec_options.is_resource_limits_default() + assert test_exec_options.exclude_resources.cpu == 1 + assert test_exec_options.exclude_resources.gpu == 0 + + # Check val dataset — no user-defined exclude_resources, so zero + val_ds = ray.train.get_dataset_shard("val") + val_exec_options = val_ds.get_context().execution_options + assert val_exec_options.is_resource_limits_default() + default_options = ray.train.DataConfig.default_ingest_options() + assert ( + val_exec_options.exclude_resources.cpu + == default_options.exclude_resources.cpu + ) + assert ( + val_exec_options.exclude_resources.gpu + == default_options.exclude_resources.gpu + ) + + trainer = DataParallelTrainer( + train_fn, + datasets={ + "train": ray.data.range(NUM_ROWS), + "test": ray.data.range(NUM_ROWS), + "val": ray.data.range(NUM_ROWS), + }, + dataset_config=data_config, + scaling_config=ray.train.ScalingConfig(num_workers=NUM_WORKERS), + ) + trainer.fit() + + +def test_datasets_callback_v1_uses_exclude_resources(ray_start_4_cpus, monkeypatch): + """Under the V1 cluster autoscaler, exclude_resources should still be set by DataConfig.""" + monkeypatch.setenv("RAY_DATA_CLUSTER_AUTOSCALER", "V1") + + NUM_WORKERS = 2 + + train_ds = ray.data.range(1000) + valid_ds = ray.data.range(1000) + + data_config = ray.train.DataConfig(datasets_to_split=["train"]) + scaling_config = ray.train.ScalingConfig( + num_workers=NUM_WORKERS, use_gpu=True, resources_per_worker={"CPU": 1, "GPU": 1} + ) + + worker_group_context = WorkerGroupContext( + run_attempt_id="attempt_1", + train_fn_ref=DummyObjectRefWrapper(lambda: None), + num_workers=scaling_config.num_workers, + resources_per_worker=scaling_config.resources_per_worker, + ) + train_run_context = create_dummy_run_context( + dataset_config=data_config, + scaling_config=scaling_config, + ) + worker_group = DummyWorkerGroup( + train_run_context=train_run_context, + worker_group_context=worker_group_context, + ) + worker_group._start() + + callback = DatasetsSetupCallback( + train_run_context=train_run_context, + datasets={"train": train_ds, "valid": valid_ds}, + ) + dataset_manager_for_each_worker = callback.before_init_train_context( + worker_group.get_workers() + )["dataset_shard_provider"] + + dataset_manager = dataset_manager_for_each_worker[0] + processed_train_ds = dataset_manager.get_dataset_shard( + DatasetShardMetadata(dataset_name="train") + ) + processed_valid_ds = dataset_manager.get_dataset_shard( + DatasetShardMetadata(dataset_name="valid") + ) + + # Under the V1 cluster autoscaler, exclude_resources should be set with training resources. + assert ( + processed_train_ds.get_context().execution_options.exclude_resources + == ExecutionResources(cpu=NUM_WORKERS, gpu=NUM_WORKERS) + ) + assert ( + processed_valid_ds.get_context().execution_options.exclude_resources + == ExecutionResources(cpu=NUM_WORKERS, gpu=NUM_WORKERS) + ) + + +def test_v2_no_negative_exclude_resources(ray_start_4_cpus): + """Regression test: under the V2 cluster autoscaler, exclude_resources is not set, + so the scenario that previously caused negative global limits (small cluster, + multiple datasets, large training reservation) no longer fails. + + Before the fix, with 4 CPUs, 2 datasets (2 executors), and 3 CPUs for training: + each executor gets 4 // 2 = 2 CPUs, minus 3 exclude_resources = -1 CPU -> assertion error. + """ + NUM_WORKERS = 3 + + train_ds = ray.data.range(100) + valid_ds = ray.data.range(100) + + data_config = ray.train.DataConfig(datasets_to_split=["train"]) + # 3 workers * 1 CPU each = 3 CPUs for training, leaving 1 CPU for data. + # With 2 datasets, each data executor gets 4 // 2 = 2 CPUs from coordinator. + # If exclude_resources were set to 3, that would give 2 - 3 = -1 -> crash. + scaling_config = ray.train.ScalingConfig(num_workers=NUM_WORKERS) + + worker_group_context = WorkerGroupContext( + run_attempt_id="attempt_1", + train_fn_ref=DummyObjectRefWrapper(lambda: None), + num_workers=scaling_config.num_workers, + resources_per_worker=scaling_config.resources_per_worker, + ) + train_run_context = create_dummy_run_context( + dataset_config=data_config, + scaling_config=scaling_config, + ) + worker_group = DummyWorkerGroup( + train_run_context=train_run_context, + worker_group_context=worker_group_context, + ) + worker_group._start() + + callback = DatasetsSetupCallback( + train_run_context=train_run_context, + datasets={"train": train_ds, "valid": valid_ds}, + ) + dataset_manager_for_each_worker = callback.before_init_train_context( + worker_group.get_workers() + )["dataset_shard_provider"] + + dataset_manager = dataset_manager_for_each_worker[0] + processed_train_ds = dataset_manager.get_dataset_shard( + DatasetShardMetadata(dataset_name="train") + ) + processed_valid_ds = dataset_manager.get_dataset_shard( + DatasetShardMetadata(dataset_name="valid") + ) + + # Under the V2 cluster autoscaler (default), exclude_resources should be + # zero regardless of how many training resources are reserved. + assert ( + processed_train_ds.get_context().execution_options.exclude_resources + == ExecutionResources.zero() + ) + assert ( + processed_valid_ds.get_context().execution_options.exclude_resources + == ExecutionResources.zero() + ) + + +def test_fixed_scaling_policy_coordinator_lifecycle(): + """Test that FixedScalingPolicy registers training resources with the + AutoscalingCoordinator on start, periodically re-requests to keep + the reservation alive, and cancels on shutdown/abort.""" + from unittest.mock import patch + + from freezegun import freeze_time + + from ray.data._internal.cluster_autoscaler.default_autoscaling_coordinator import ( + ResourceRequestPriority, + ) + from ray.train.v2._internal.execution.scaling_policy import ( + AUTOSCALING_REQUESTS_EXPIRE_TIME_S, + AUTOSCALING_REQUESTS_INTERVAL_S, + ) + from ray.train.v2._internal.execution.scaling_policy.fixed import ( + FixedScalingPolicy, + ) + + resources_per_worker = {"CPU": 4, "GPU": 1} + num_workers = 2 + scaling_config = ray.train.ScalingConfig( + num_workers=num_workers, + use_gpu=True, + resources_per_worker=resources_per_worker, + ) + + mock_coordinator = MagicMock() + expected_request_kwargs = dict( + requester_id="train-test-run-123", + resources=[resources_per_worker] * num_workers, + expire_after_s=AUTOSCALING_REQUESTS_EXPIRE_TIME_S, + priority=ResourceRequestPriority.HIGH, + ) + + with patch( + "ray.get", + side_effect=lambda x, **_: x, + ): + policy = FixedScalingPolicy(scaling_config) + # Inject mock coordinator + policy.__dict__["_autoscaling_coordinator"] = mock_coordinator + + with freeze_time() as frozen_time: + # Simulate controller start + mock_run_context = MagicMock() + mock_run_context.run_id = "test-run-123" + policy.after_controller_start(mock_run_context) + + assert policy._requester_id == "train-test-run-123" + + # Verify request_resources was called with the correct arguments + mock_coordinator.request_resources.remote.assert_called_once_with( + **expected_request_kwargs + ) + + # Calling make_decision immediately should NOT re-request (interval not passed) + worker_group_state = MagicMock() + worker_group_status = MagicMock() + policy.make_decision_for_running_worker_group( + worker_group_state=worker_group_state, + worker_group_status=worker_group_status, + ) + assert mock_coordinator.request_resources.remote.call_count == 1 + + # Advance past the interval — should re-request + frozen_time.tick(AUTOSCALING_REQUESTS_INTERVAL_S) + policy.make_decision_for_running_worker_group( + worker_group_state=worker_group_state, + worker_group_status=worker_group_status, + ) + assert mock_coordinator.request_resources.remote.call_count == 2 + mock_coordinator.request_resources.remote.assert_called_with( + **expected_request_kwargs + ) + + # Simulate controller shutdown + asyncio.run(policy.before_controller_shutdown()) + mock_coordinator.cancel_request.remote.assert_called_once_with( + requester_id="train-test-run-123", + ) + + # Reset and test abort path + mock_coordinator.cancel_request.remote.reset_mock() + policy.before_controller_abort() + mock_coordinator.cancel_request.remote.assert_called_once_with( + requester_id="train-test-run-123", + ) + + if __name__ == "__main__": import sys From b94271f7211c7f8a0ccaf4440bdb4b6fa566b9d7 Mon Sep 17 00:00:00 2001 From: JasonLi1909 Date: Thu, 26 Mar 2026 13:35:29 -0700 Subject: [PATCH 05/11] test fix Signed-off-by: JasonLi1909 --- .../data/tests/test_streaming_integration.py | 34 +++++++++++++++---- 1 file changed, 28 insertions(+), 6 deletions(-) diff --git a/python/ray/data/tests/test_streaming_integration.py b/python/ray/data/tests/test_streaming_integration.py index 1611fed08624..cbb515aee3b6 100644 --- a/python/ray/data/tests/test_streaming_integration.py +++ b/python/ray/data/tests/test_streaming_integration.py @@ -402,17 +402,39 @@ def test_streaming_split_schema_before_execution(ray_start_10_cpus_shared): def test_streaming_split_schema_during_execution(ray_start_10_cpus_shared): """Test schema retrieval from splits during execution.""" - ds = ray.data.range(20, override_num_blocks=20) + from ray._common.test_utils import SignalActor + + # Use two signals to coordinate: `started` confirms the executor is running, + # `blocker` keeps map tasks alive so the executor stays active. + started = SignalActor.remote() + blocker = SignalActor.remote() + + def blocking_fn(row): + ray.get(started.send.remote()) + ray.get(blocker.wait.remote()) + return row + + ds = ray.data.range(20, override_num_blocks=20).map(blocking_fn) i1, i2 = ds.streaming_split(2, equal=True) @ray.remote - def consume_and_check_schema(x): + def consume(x): for _ in x.iter_rows(): - with pytest.raises(RuntimeError, match="Cannot call schema()"): - x.schema() - break + pass + + # Start consumers — this triggers the executor on the coordinator. + refs = [consume.remote(i1), consume.remote(i2)] + + # Wait until a map task has started, guaranteeing the executor is alive. + ray.get(started.wait.remote()) + + # schema() should raise because execution is active. + with pytest.raises(ray.exceptions.RayTaskError, match="Cannot call schema()"): + i1.schema() - ray.get([consume_and_check_schema.remote(i1), consume_and_check_schema.remote(i2)]) + # Unblock map tasks so consumers can finish. + ray.get(blocker.send.remote()) + ray.get(refs) def test_streaming_split_schema_after_execution(ray_start_10_cpus_shared): From d31cd12f10dbfe97a022d204bd2fe99383880d63 Mon Sep 17 00:00:00 2001 From: JasonLi1909 Date: Thu, 26 Mar 2026 13:56:50 -0700 Subject: [PATCH 06/11] use current executor Signed-off-by: JasonLi1909 --- python/ray/data/_internal/iterator/stream_split_iterator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/ray/data/_internal/iterator/stream_split_iterator.py b/python/ray/data/_internal/iterator/stream_split_iterator.py index dd8034e39a5a..2337321ca4e2 100644 --- a/python/ray/data/_internal/iterator/stream_split_iterator.py +++ b/python/ray/data/_internal/iterator/stream_split_iterator.py @@ -156,7 +156,7 @@ def __init__( self._dataset_state_lock = threading.Lock() self._schema = None self._current_executor = None - + # Guarded by self._lock. self._next_bundle: Dict[int, RefBundle] = {} self._unfinished_clients_in_epoch = n @@ -181,7 +181,7 @@ def get_dataset_tag(self, output_split_idx: int) -> str: def get_dataset_schema(self): with self._dataset_state_lock: - if self._executor is not None and self._executor.is_alive(): + if self._current_executor is not None and self._current_executor.is_alive(): raise RuntimeError( "Cannot call schema() during active dataset execution. " "Call schema() before or after iterating over the dataset, or call " From be1dc9f685dc06ca6b456e28c79ae64f78b27736 Mon Sep 17 00:00:00 2001 From: JasonLi1909 Date: Thu, 26 Mar 2026 13:59:09 -0700 Subject: [PATCH 07/11] remove index counting Signed-off-by: JasonLi1909 --- python/ray/data/_internal/iterator/stream_split_iterator.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/ray/data/_internal/iterator/stream_split_iterator.py b/python/ray/data/_internal/iterator/stream_split_iterator.py index 2337321ca4e2..38a3cbc63cd1 100644 --- a/python/ray/data/_internal/iterator/stream_split_iterator.py +++ b/python/ray/data/_internal/iterator/stream_split_iterator.py @@ -96,7 +96,6 @@ def gen_blocks() -> Iterator[RefBundle]: schema=block_ref_and_md.schema, ) - self._base_dataset._plan._run_index += 1 # Return None for executor since StreamSplitDataIterator has its own # mechanism for reporting prefetched bytes via SplitCoordinator. return gen_blocks(), self._iter_stats, False, None From 6c70a9576a079c466fbb73e0908d9e36e6b94551 Mon Sep 17 00:00:00 2001 From: JasonLi1909 Date: Thu, 26 Mar 2026 16:03:58 -0700 Subject: [PATCH 08/11] test fix 2 Signed-off-by: JasonLi1909 --- python/ray/air/tests/test_new_dataset_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/air/tests/test_new_dataset_config.py b/python/ray/air/tests/test_new_dataset_config.py index a4003d24b6f5..0506b512b1cf 100644 --- a/python/ray/air/tests/test_new_dataset_config.py +++ b/python/ray/air/tests/test_new_dataset_config.py @@ -285,7 +285,7 @@ class MyTrainer(DataParallelTrainer): def __init__(self, **kwargs): def train_loop_fn(): train_ds = train.get_dataset_shard("train") - new_execution_options = train_ds._base_dataset.context.execution_options + new_execution_options = train_ds.get_context().execution_options if original_execution_options.is_resource_limits_default(): # If the original resource limits are default, the new resource # limits should be the default as well. From e4076748006eb37acb469cf4b37d93dd32cc400e Mon Sep 17 00:00:00 2001 From: Jason Li <57246540+JasonLi1909@users.noreply.github.com> Date: Thu, 26 Mar 2026 19:36:46 -0700 Subject: [PATCH 09/11] Apply suggestion from @justinvyu Co-authored-by: Justin Yu Signed-off-by: Jason Li <57246540+JasonLi1909@users.noreply.github.com> --- python/ray/data/_internal/iterator/stream_split_iterator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/data/_internal/iterator/stream_split_iterator.py b/python/ray/data/_internal/iterator/stream_split_iterator.py index 38a3cbc63cd1..435457413c47 100644 --- a/python/ray/data/_internal/iterator/stream_split_iterator.py +++ b/python/ray/data/_internal/iterator/stream_split_iterator.py @@ -112,7 +112,7 @@ def stats(self) -> str: ) return summary.to_string() - def schema(self) -> Union[type, "pyarrow.lib.Schema"]: + def schema(self) -> ray.data.Schema: """Implements DataIterator.""" return ray.get(self._coord_actor.get_dataset_schema.remote()) From 017c5399ddc67d8addc0563a4331679bc2856652 Mon Sep 17 00:00:00 2001 From: JasonLi1909 Date: Thu, 26 Mar 2026 19:45:36 -0700 Subject: [PATCH 10/11] use schema cache if available Signed-off-by: JasonLi1909 --- python/ray/data/_internal/iterator/stream_split_iterator.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/ray/data/_internal/iterator/stream_split_iterator.py b/python/ray/data/_internal/iterator/stream_split_iterator.py index 435457413c47..c45bf99a89e5 100644 --- a/python/ray/data/_internal/iterator/stream_split_iterator.py +++ b/python/ray/data/_internal/iterator/stream_split_iterator.py @@ -180,14 +180,15 @@ def get_dataset_tag(self, output_split_idx: int) -> str: def get_dataset_schema(self): with self._dataset_state_lock: + if self._schema is not None: + return self._schema if self._current_executor is not None and self._current_executor.is_alive(): raise RuntimeError( "Cannot call schema() during active dataset execution. " "Call schema() before or after iterating over the dataset, or call " "schema() directly on the source Dataset object." ) - if self._schema is None: - self._schema = self._base_dataset.schema() + self._schema = self._base_dataset.schema() return self._schema def stats(self) -> DatasetStats: From 41570d2ad39103c65e6fd17450f535a48b04ebdf Mon Sep 17 00:00:00 2001 From: JasonLi1909 Date: Thu, 26 Mar 2026 23:04:37 -0700 Subject: [PATCH 11/11] type fix Signed-off-by: JasonLi1909 --- .../ray/data/_internal/iterator/stream_split_iterator.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/python/ray/data/_internal/iterator/stream_split_iterator.py b/python/ray/data/_internal/iterator/stream_split_iterator.py index c45bf99a89e5..7c6f28692ce2 100644 --- a/python/ray/data/_internal/iterator/stream_split_iterator.py +++ b/python/ray/data/_internal/iterator/stream_split_iterator.py @@ -1,7 +1,7 @@ import logging import threading import time -from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Tuple import ray from ray.data._internal.execution.interfaces import ( @@ -17,9 +17,8 @@ from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy if TYPE_CHECKING: - import pyarrow - from ray.data.dataset import Dataset + from ray.data.dataset import Dataset, Schema logger = logging.getLogger(__name__) @@ -112,7 +111,7 @@ def stats(self) -> str: ) return summary.to_string() - def schema(self) -> ray.data.Schema: + def schema(self) -> Optional["Schema"]: """Implements DataIterator.""" return ray.get(self._coord_actor.get_dataset_schema.remote())