Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions python/ray/train/v2/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,23 @@ py_test(
],
)

# Test target for tests/test_dataset_manager.py.
py_test(
name = "test_dataset_manager",
size = "medium",
srcs = ["tests/test_dataset_manager.py"],
# The test process runs with RAY_TRAIN_V2_ENABLED=1 in its environment.
env = {"RAY_TRAIN_V2_ENABLED": "1"},
tags = [
"data_integration",
# NOTE(review): "exclusive" is a tag Bazel itself interprets (the test is
# not run concurrently with other tests) - confirm that cost is intended.
"exclusive",
"team:ml",
"train_v2",
],
deps = [
":conftest",
"//:ray_lib",
],
)

py_test(
name = "test_env_callbacks",
size = "small",
Expand Down
139 changes: 66 additions & 73 deletions python/ray/train/v2/_internal/callbacks/datasets.py
Original file line number Diff line number Diff line change
@@ -1,53 +1,89 @@
import copy
import logging
from typing import TYPE_CHECKING, Dict, List
from typing import TYPE_CHECKING, Dict, List, Optional

import ray
import ray.train
from ray.actor import ActorHandle
from ray.exceptions import GetTimeoutError
from ray.train.v2._internal.data_integration.interfaces import (
DatasetShardMetadata,
DatasetShardProvider,
GenDataset,
)
from ray.train.v2._internal.execution.callback import (
ControllerCallback,
WorkerGroupCallback,
)
from ray.train.v2._internal.execution.callback import WorkerGroupCallback
from ray.train.v2._internal.execution.context import TrainRunContext
from ray.train.v2._internal.execution.worker_group.worker_group import (
Worker,
WorkerGroup,
WorkerGroupContext,
)
from ray.types import ObjectRef
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy

if TYPE_CHECKING:
from ray.data import DataIterator, Dataset
from ray.data import DataIterator, Dataset, NodeIdStr
from ray.data.context import DataContext

logger = logging.getLogger(__name__)


# NOTE(review): this span is a rendered diff hunk without +/- markers, so the
# pre-change and post-change versions of several statements appear back to
# back. Comments below mark which side each such line belongs to (derived
# from the hunk header's old/new line counts).
class RayDatasetShardProvider:
"""A shard provider that Train workers use to access a DataIterator for a dataset."""
# Post-change constructor: receives the raw datasets plus the configuration
# needed to split them, instead of pre-built per-rank iterators.
def __init__(
self,
datasets: Dict[str, GenDataset],
data_config: ray.train.DataConfig,
data_context: "DataContext",
world_size: int,
worker_node_ids: List["NodeIdStr"],
):
# DatasetManager is imported when the constructor runs rather than at
# module import time.
from ray.train.v2._internal.data_integration.dataset_manager import (
DatasetManager,
)

# Pre-change constructor (removed side of the diff): it simply stored the
# already-configured iterators for one rank.
def __init__(self, ds_iterators: Dict[str, "DataIterator"]):
# Maps dataset_name to a DataIterator.
self._dataset_iterators = ds_iterators
# Post-change constructor body continues here.
# Only the dataset names are kept locally; they are used to validate
# requests in get_dataset_shard.
self._dataset_names = set(datasets)
# Wrap DatasetManager as a Ray actor that reserves no CPUs and is pinned
# (soft=False) to the node id reported by the current runtime context,
# i.e. the node of the process constructing this provider.
self._dataset_manager = (
ray.remote(DatasetManager)
.options(
num_cpus=0,
scheduling_strategy=NodeAffinitySchedulingStrategy(
ray.get_runtime_context().get_node_id(), soft=False
),
)
.remote(
datasets=datasets,
data_config=data_config,
data_context=data_context,
world_size=world_size,
worker_node_ids=worker_node_ids,
)
)
# Shards already fetched from the manager, keyed by dataset name, so each
# dataset triggers at most one remote fetch per provider instance.
self._cached_dataset_shards: Dict[str, "DataIterator"] = {}

# Returns the DataIterator for dataset_info.dataset_name.
# Raises KeyError when that name was not among the configured datasets.
def get_dataset_shard(self, dataset_info: DatasetShardMetadata) -> "DataIterator":
# Pre-change membership check (removed side).
if dataset_info.dataset_name not in self._dataset_iterators:
# Post-change: bind the name once and validate against the known names.
dataset_name = dataset_info.dataset_name
if dataset_name not in self._dataset_names:
raise KeyError(
# Pre-change message line (removed side), then its post-change replacement.
f"Dataset shard for '{dataset_info.dataset_name}' not found. "
f"Dataset shard for '{dataset_name}' not found. "
"Please ensure that the dataset is passed through the Trainer `datasets` "
"argument."
)

# Pre-change return (removed side): direct lookup in the stored iterators.
return self._dataset_iterators[dataset_info.dataset_name]
# Post-change: on the first request for a dataset, block on the manager
# actor's get_dataset_shard call and remember the result.
if dataset_name not in self._cached_dataset_shards:
self._cached_dataset_shards[dataset_name] = ray.get(
self._dataset_manager.get_dataset_shard.remote(dataset_info)
)

return self._cached_dataset_shards[dataset_name]

def shutdown_data_executors(self) -> None:
"""
Attempts to eagerly shutdown the data executors for datasets, freeing resources allocated to data execution.
"""
# The remote call is submitted without waiting on its result; only a
# failure to submit the call is caught here, and it is logged at debug
# level rather than re-raised.
try:
self._dataset_manager.shutdown_data_executors.remote()
except Exception:
logger.debug("Failed to invoke remote cleanup of Dataset Manager.")


class DatasetsCallback(WorkerGroupCallback, ControllerCallback):
class DatasetsCallback(WorkerGroupCallback):
"""A callback for managing Ray Datasets for the worker group."""

def __init__(
Expand All @@ -58,8 +94,7 @@ def __init__(
self._datasets = datasets
self._data_config = copy.deepcopy(train_run_context.dataset_config)
self._scaling_config = train_run_context.scaling_config
self._coordinator_actors: List[ActorHandle] = []
self._shutdown_refs: List[ObjectRef] = []
self._dataset_shard_provider: Optional[RayDatasetShardProvider] = None

# Capture the current DataContext to propagate it to
# the Train workers later.
Expand All @@ -86,34 +121,6 @@ def get_train_total_resources(
return {}
return scaling_config.total_resources

# NOTE(review): this method is on the removed side of the diff hunk (the hunk
# header shrinks this region from 34 lines to 6); it documents the pre-change
# behavior only.
def _get_coordinator_actors(
self, ds_iterators_per_rank: List[Dict[str, "DataIterator"]]
) -> List[ActorHandle]:
"""
Returns a list of each unique SplitCoordinator actor handle given the iterators per rank.
These handles will later be used to call shutdown on the actors.
"""
# Imported inside the method; the isinstance check below needs the concrete
# iterator class.
from ray.data._internal.iterator.stream_split_iterator import (
StreamSplitDataIterator,
)

# Note: Currently, we only need to check rank 0 for split iterators.
# In the future, if datasets can be split across only a subset of ranks,
# we may need to process all ranks.
# Indexing [0] means an empty per-rank list raises IndexError here.
rank_0_iterators = ds_iterators_per_rank[0]
# Reads the private _coord_actor attribute of every StreamSplitDataIterator;
# iterators of any other type are ignored.
# NOTE(review): there is no explicit de-duplication, so "unique" in the
# docstring presumably relies on rank 0 holding one iterator per dataset -
# confirm.
coord_actors = [
iterator._coord_actor
for iterator in rank_0_iterators.values()
if isinstance(iterator, StreamSplitDataIterator)
]
return coord_actors

# NOTE(review): this method is on the removed side of the diff hunk; it
# documents the pre-change behavior only.
def _shutdown_data_executors(self):
"""Eagerly shutdown the data executors of the split coordinator actors."""
# Submits shutdown_executor on every stored coordinator actor without
# waiting, and keeps the returned refs (replacing any previous list) so
# before_controller_shutdown can wait on them via self._shutdown_refs.
self._shutdown_refs = [
coord.shutdown_executor.remote() for coord in self._coordinator_actors
]

# --------------------------
# WorkerGroupCallback
# --------------------------
Expand All @@ -123,29 +130,23 @@ def before_init_train_context(
) -> Dict[str, List[DatasetShardProvider]]:
world_size = len(workers)
worker_node_ids = [worker.metadata.node_id for worker in workers]
datasets = {k: v() if callable(v) else v for k, v in self._datasets.items()}

# TODO: Move this to the constructor.
# Notify the DataConfig about the total resources reserved for training.
total_train_resources = self.get_train_total_resources(self._scaling_config)
self._data_config.set_train_total_resources(
total_train_resources.get("CPU", 0), total_train_resources.get("GPU", 0)
)

datasets = {k: v() if callable(v) else v for k, v in self._datasets.items()}
ds_iterators_per_rank = self._data_config.configure(
self._dataset_shard_provider = RayDatasetShardProvider(
datasets=datasets,
data_config=self._data_config,
data_context=self._data_context,
world_size=world_size,
worker_handles=None,
worker_node_ids=worker_node_ids,
)
assert len(ds_iterators_per_rank) == world_size

self._coordinator_actors = self._get_coordinator_actors(ds_iterators_per_rank)

shard_providers_per_rank = [
RayDatasetShardProvider(ds_iterators=ds_iterators_per_rank[rank])
for rank in range(world_size)
]
return {"dataset_shard_provider": shard_providers_per_rank}
return {"dataset_shard_provider": [self._dataset_shard_provider] * world_size}

def after_worker_group_start(self, worker_group: WorkerGroup):
# Propagate DataContext
Expand All @@ -162,21 +163,13 @@ def _propagate_data_context(ctx: "DataContext"):
# Hook invoked after the worker group shuts down; releases Ray Data execution
# resources. worker_group_context is not used by either version of the body.
# NOTE(review): this is a diff hunk without +/- markers, so both the old and
# the new body are shown below.
def after_worker_group_shutdown(
self, worker_group_context: WorkerGroupContext
) -> None:
# Pre-change body (removed side): delegate to the callback's own helper.
self._shutdown_data_executors()
# Post-change body (added side): delegate to the shard provider. The
# attribute starts as None and is only assigned in
# before_init_train_context, hence the guard.
shard_provider = self._dataset_shard_provider
if shard_provider:
shard_provider.shutdown_data_executors()

# Hook invoked after the worker group is aborted; performs the same data
# executor cleanup as after_worker_group_shutdown.
# NOTE(review): this is a diff hunk without +/- markers. The old abort body and
# the entire ControllerCallback section that follows it are on the removed
# side; the new abort body is the last three lines of this span.
def after_worker_group_abort(
self, worker_group_context: WorkerGroupContext
) -> None:
# Pre-change body (removed side).
self._shutdown_data_executors()

# --------------------------
# ControllerCallback
# --------------------------

# Removed side: waited up to 5 seconds on the refs stored in
# self._shutdown_refs. A timeout is logged as an error and any other
# exception is logged with its traceback; neither is re-raised.
async def before_controller_shutdown(self):
try:
ray.get(self._shutdown_refs, timeout=5)
except GetTimeoutError:
logger.error("Ray Data executor shutdown task timed out after 5 seconds.")
except Exception:
logger.exception("Failed to gracefully terminate Ray Data executors.")
# Post-change body of after_worker_group_abort (added side): delegate to the
# shard provider if one has been created.
shard_provider = self._dataset_shard_provider
if shard_provider:
shard_provider.shutdown_data_executors()
Loading
Loading