66 changes: 49 additions & 17 deletions python/ray/train/v2/_internal/callbacks/datasets.py
@@ -1,18 +1,47 @@
import copy
from typing import Any, Callable, Dict, List, Union
from typing import Dict, List

import ray.train
from ray.data import Dataset
from ray.data import DataIterator, NodeIdStr
from ray.data.context import DataContext
from ray.train.v2._internal.data_integration.interfaces import (
DatasetShardMetadata,
DatasetShardProvider,
GenDataset,
)
from ray.train.v2._internal.execution.callback import WorkerGroupCallback
from ray.train.v2._internal.execution.worker_group.worker_group import (
Worker,
WorkerGroup,
)

# A type representing either a ray.data.Dataset or a function that returns a
# ray.data.Dataset and accepts no arguments.
GenDataset = Union[Dataset, Callable[[], Dataset]]

class RayDatasetShardProvider:
def __init__(
self,
datasets: Dict[str, GenDataset],
data_config: ray.train.DataConfig,
world_size: int,
worker_node_ids: List[NodeIdStr],
):
# Maps (world_rank, dataset_name) to a DataIterator.
self._dataset_iterators: List[Dict[str, DataIterator]] = data_config.configure(
datasets=datasets,
world_size=world_size,
worker_handles=None,
worker_node_ids=worker_node_ids,
)

def get_dataset_shard(self, dataset_info: DatasetShardMetadata) -> DataIterator:
ds_shards_for_rank = self._dataset_iterators[dataset_info.world_rank]
if dataset_info.dataset_name not in ds_shards_for_rank:
raise KeyError(
f"Dataset shard for '{dataset_info.dataset_name}' not found. "
"Please ensure that the dataset is passed through the Trainer `datasets` "
"argument."
)

return ds_shards_for_rank[dataset_info.dataset_name]


class DatasetsSetupCallback(WorkerGroupCallback):
@@ -45,26 +74,29 @@ def get_train_total_resources(
these resources logically from its available pool."""
return scaling_config.total_resources

def before_init_train_context(self, workers: List[Worker]) -> Dict[str, List[Any]]:
# Configure dataset shards
datasets = {k: v() if callable(v) else v for k, v in self._datasets.items()}
node_ids = [worker.metadata.node_id for worker in workers]
# --------------------------
# WorkerGroupCallback
# --------------------------

def before_init_train_context(
self, workers: List[Worker]
) -> Dict[str, List[DatasetShardProvider]]:
world_size = len(workers)
worker_node_ids = [worker.metadata.node_id for worker in workers]

# Notify the DataConfig about the total resources reserved for training.
total_train_resources = self.get_train_total_resources(self._scaling_config)
self._data_config.set_train_total_resources(
total_train_resources.get("CPU", 0), total_train_resources.get("GPU", 0)
)

dataset_shards = self._data_config.configure(
datasets,
world_size=len(workers),
worker_handles=None,
worker_node_ids=node_ids,
dataset_manager = RayDatasetShardProvider(
datasets=self._datasets,
data_config=self._data_config,
world_size=world_size,
worker_node_ids=worker_node_ids,
)
assert len(dataset_shards) == len(workers)

return {"dataset_shards": dataset_shards}
return {"dataset_shard_provider": [dataset_manager] * len(workers)}

def after_worker_group_start(self, worker_group: WorkerGroup):
# Propagate DataContext
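The shape of the value returned by `before_init_train_context` changes with this diff: instead of a per-worker dict of `DataIterator` shards, every worker receives the same shard provider and resolves its own shard by rank and dataset name. A minimal sketch of that lookup, assuming `result` is the dict returned by the callback (the helper function below is illustrative, not part of the PR):

from ray.train.v2._internal.data_integration.interfaces import DatasetShardMetadata


def resolve_shard_for_worker(result: dict, world_rank: int, dataset_name: str):
    # New shape: one shared provider entry per worker, indexed by rank.
    provider = result["dataset_shard_provider"][world_rank]
    # The provider maps (world_rank, dataset_name) to a ray.data.DataIterator.
    return provider.get_dataset_shard(
        DatasetShardMetadata(dataset_name=dataset_name, world_rank=world_rank)
    )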
Empty file.
30 changes: 30 additions & 0 deletions python/ray/train/v2/_internal/data_integration/interfaces.py
@@ -0,0 +1,30 @@
from dataclasses import dataclass
from typing import Callable, Protocol, Union

from ray.data import DataIterator, Dataset

# A type representing either a ray.data.Dataset or a function that returns a
# ray.data.Dataset and accepts no arguments.
GenDataset = Union[Dataset, Callable[[], Dataset]]


@dataclass
class DatasetShardMetadata:
"""Metadata about a dataset shard used for lookup and configuration."""

dataset_name: str
world_rank: int


class DatasetShardProvider(Protocol):
def get_dataset_shard(self, dataset_info: DatasetShardMetadata) -> DataIterator:
"""Get the dataset shard for the given dataset info.
Args:
dataset_info: The metadata of the shard to retrieve,
including the dataset name and worker rank.
Returns:
The :class:`~ray.data.DataIterator` shard for the given dataset info.
Raises:
KeyError: If the dataset shard for the given dataset info is not found.
"""
...
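Because `DatasetShardProvider` is a structural `Protocol`, any object exposing a matching `get_dataset_shard` method satisfies it; no subclassing is required. A minimal dict-backed provider, the kind one might write for a unit test, could look like this (the class name and backing map are hypothetical, not part of the PR):

from typing import Dict, Tuple

from ray.data import DataIterator
from ray.train.v2._internal.data_integration.interfaces import DatasetShardMetadata


class InMemoryShardProvider:
    """Hypothetical provider backed by a pre-built (rank, name) -> iterator map."""

    def __init__(self, shards: Dict[Tuple[int, str], DataIterator]):
        self._shards = shards

    def get_dataset_shard(self, dataset_info: DatasetShardMetadata) -> DataIterator:
        key = (dataset_info.world_rank, dataset_info.dataset_name)
        if key not in self._shards:
            raise KeyError(f"No shard registered for (rank, dataset) {key}.")
        return self._shards[key]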
18 changes: 8 additions & 10 deletions python/ray/train/v2/_internal/execution/context.py
@@ -17,6 +17,10 @@
from ray.train.v2.api.config import RunConfig, ScalingConfig

if TYPE_CHECKING:
from ray.train.v2._internal.data_integration.interfaces import (
DatasetShardMetadata,
DatasetShardProvider,
)
from ray.train.v2._internal.execution.callback import TrainContextCallback
from ray.train.v2._internal.execution.worker_group.thread_runner import ThreadRunner

@@ -92,7 +96,7 @@ class TrainContext:
distributed_context: DistributedContext
execution_context: ExecutionContext
storage_context: StorageContext
dataset_shards: Dict[str, DataIterator]
dataset_shard_provider: "DatasetShardProvider"
checkpoint: Optional[Checkpoint] = None

@_copy_doc(session.get_experiment_name)
@@ -133,27 +137,21 @@ def get_synchronization_actor(self):
def get_checkpoint(self):
return self.checkpoint

def get_dataset_shard(self, dataset_name: str) -> DataIterator:
def get_dataset_shard(self, dataset_info: "DatasetShardMetadata") -> DataIterator:
"""Returns the :class:`ray.data.DataIterator` shard for this worker.

Call :meth:`~ray.data.DataIterator.iter_torch_batches` or
:meth:`~ray.data.DataIterator.to_tf` on this shard to convert it to the
appropriate framework-specific data type.

Args:
dataset_name: Name of the dataset shard.
dataset_info: The shard metadata, including the dataset name and worker rank.
Returns:
The ``DataIterator`` shard with the given name for this worker.
Raises:
KeyError: If the dataset shard with the given name is not found.
"""
try:
return self.dataset_shards[dataset_name]
except KeyError:
raise KeyError(
f"Dataset {dataset_name} not found. Available datasets: "
f"{list(self.dataset_shards.keys())}."
)
return self.dataset_shard_provider.get_dataset_shard(dataset_info)

def get_context_callbacks(self) -> List["TrainContextCallback"]:
return self.execution_context.train_context_callbacks
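With this change, internal callers of `TrainContext.get_dataset_shard` pass a `DatasetShardMetadata` instead of a bare name, and the key-error handling moves into the provider. A hedged sketch of the new call, assuming `train_context` is this worker's already-constructed `TrainContext` and `world_rank` is its rank in the group:

from ray.train.v2._internal.data_integration.interfaces import DatasetShardMetadata

world_rank = 0  # assumed: this worker's rank
shard = train_context.get_dataset_shard(
    DatasetShardMetadata(dataset_name="train", world_rank=world_rank)
)
for batch in shard.iter_batches(batch_size=32):
    ...  # framework-specific conversion would normally happen here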
10 changes: 9 additions & 1 deletion python/ray/train/v2/_internal/execution/train_fn_utils.py
@@ -57,7 +57,15 @@ def get_dataset_shard(self, dataset_name: str) -> DataIterator:
Returns:
The DataIterator shard for this worker.
"""
return get_internal_train_context().get_dataset_shard(dataset_name)
from ray.train.v2._internal.data_integration.interfaces import (
DatasetShardMetadata,
)

dataset_info = DatasetShardMetadata(
dataset_name=dataset_name,
world_rank=get_internal_train_context().get_world_rank(),
)
return get_internal_train_context().get_dataset_shard(dataset_info)

def get_context(self) -> ExternalTrainContext:
return ExternalTrainContext()
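The public training-function API is unchanged by this refactor: users still look up shards by name, and the rank-aware `DatasetShardMetadata` is constructed internally. A minimal user-facing sketch (the dataset name "train" and the training loop are illustrative):

import ray.train


def train_fn(config):
    # Still looked up by name; get_dataset_shard() builds the
    # DatasetShardMetadata from this worker's world rank internally.
    it = ray.train.get_dataset_shard("train")
    for batch in it.iter_batches(batch_size=32):
        ...  # train on the batch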
10 changes: 6 additions & 4 deletions python/ray/train/v2/_internal/execution/worker_group/worker.py
@@ -4,13 +4,12 @@
import socket
from dataclasses import dataclass
from functools import cached_property
from typing import Callable, Dict, List, Optional, TypeVar, Union
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, TypeVar, Union

import ray
import ray._private.ray_constants as ray_constants
from .thread_runner import ThreadRunner
from ray.actor import ActorHandle
from ray.data.iterator import DataIterator
from ray.train import Checkpoint
from ray.train.v2._internal.constants import (
DEFAULT_ENABLE_WORKER_LOGGING,
@@ -40,6 +39,9 @@
from ray.train.v2._internal.util import ObjectRefWrapper
from ray.types import ObjectRef

if TYPE_CHECKING:
from ray.train.v2._internal.data_integration.interfaces import DatasetShardProvider

T = TypeVar("T")

logger = logging.getLogger(__name__)
@@ -192,7 +194,7 @@ def init_train_context(
synchronization_actor: SynchronizationActor,
storage_context: StorageContext,
worker_callbacks: List[Union[WorkerCallback, TrainContextCallback]],
dataset_shards: Dict[str, DataIterator] = None,
dataset_shard_provider: Optional["DatasetShardProvider"] = None,
checkpoint: Optional[Checkpoint] = None,
):
self._callbacks = [c for c in worker_callbacks if isinstance(c, WorkerCallback)]
@@ -211,8 +213,8 @@ def init_train_context(
train_context_callbacks=context_callbacks_to_propagate,
),
storage_context=storage_context,
dataset_shards=dataset_shards or {},
checkpoint=checkpoint,
dataset_shard_provider=dataset_shard_provider,
)
# Configure the train and root logger for the worker processes.
if ray_constants.env_bool(
22 changes: 14 additions & 8 deletions python/ray/train/v2/tests/test_data_integration.py
@@ -7,7 +7,8 @@
from ray.data import DataContext, ExecutionResources
from ray.data._internal.iterator.stream_split_iterator import StreamSplitDataIterator
from ray.data.tests.conftest import restore_data_context # noqa: F401
from ray.train.v2._internal.callbacks import DatasetsSetupCallback
from ray.train.v2._internal.callbacks.datasets import DatasetsSetupCallback
from ray.train.v2._internal.data_integration.interfaces import DatasetShardMetadata
from ray.train.v2._internal.execution.context import TrainRunContext
from ray.train.v2._internal.execution.worker_group.worker_group import (
WorkerGroupContext,
@@ -87,13 +88,18 @@ def test_dataset_setup_callback(ray_start_4_cpus):
data_config=data_config,
scaling_config=scaling_config,
)
dataset_shards = callback.before_init_train_context(worker_group.get_workers())[
"dataset_shards"
]
assert len(dataset_shards) == NUM_WORKERS

processed_train_ds = dataset_shards[0]["train"]
processed_valid_ds = dataset_shards[0]["valid"]
dataset_manager_for_each_worker = callback.before_init_train_context(
worker_group.get_workers()
)["dataset_shard_provider"]
assert len(dataset_manager_for_each_worker) == NUM_WORKERS

dataset_manager = dataset_manager_for_each_worker[0]
processed_train_ds = dataset_manager.get_dataset_shard(
DatasetShardMetadata(dataset_name="train", world_rank=0)
)
processed_valid_ds = dataset_manager.get_dataset_shard(
DatasetShardMetadata(dataset_name="valid", world_rank=0)
)

assert isinstance(processed_train_ds, StreamSplitDataIterator)
assert not isinstance(processed_valid_ds, StreamSplitDataIterator)