|
4 | 4 | from typing import TYPE_CHECKING, Any, Dict, List, Optional |
5 | 5 |
|
6 | 6 | from ray.data import DataIterator |
| 7 | +from ray.train.v2._internal.data_integration.interfaces import DatasetShardMetadata |
7 | 8 | from ray.train.v2._internal.execution import collective_impl |
8 | 9 | from ray.train.v2._internal.execution.context import ( |
9 | 10 | get_train_context as get_internal_train_context, |
@@ -68,14 +69,11 @@ def get_all_reported_checkpoints(self) -> List["ReportedCheckpoint"]: |
68 | 69 | pass |
69 | 70 |
|
70 | 71 | @abstractmethod |
71 | | - def get_dataset_shard(self, dataset_name: str) -> DataIterator: |
| 72 | + def get_dataset_shard(self, dataset_info: DatasetShardMetadata) -> DataIterator: |
72 | 73 | """Get the dataset shard for this training process. |
73 | 74 |
|
74 | | - This method is used by the public API function :func:`ray.train.get_dataset_shard`. |
75 | | - Users should typically call ``ray.train.get_dataset_shard()`` instead of calling this method directly. |
76 | | -
|
77 | 75 | Args: |
78 | | - dataset_name: The name of the dataset to get the shard for. |
| 76 | + dataset_info: The metadata of the dataset to get the shard for. |
79 | 77 |
|
80 | 78 | Returns: |
81 | 79 | The DataIterator shard for this worker. |
@@ -131,14 +129,8 @@ def report( |
131 | 129 | def get_checkpoint(self): |
132 | 130 | return get_internal_train_context().get_checkpoint() |
133 | 131 |
|
134 | | - def get_dataset_shard(self, dataset_name: str) -> DataIterator: |
135 | | - from ray.train.v2._internal.data_integration.interfaces import ( |
136 | | - DatasetShardMetadata, |
137 | | - ) |
138 | | - |
139 | | - return get_internal_train_context().get_dataset_shard( |
140 | | - DatasetShardMetadata(dataset_name=dataset_name) |
141 | | - ) |
| 132 | + def get_dataset_shard(self, dataset_info: DatasetShardMetadata) -> DataIterator: |
| 133 | + return get_internal_train_context().get_dataset_shard(dataset_info) |
142 | 134 |
|
143 | 135 | def get_context(self) -> DistributedTrainContext: |
144 | 136 | return DistributedTrainContext() |
@@ -182,7 +174,8 @@ def report( |
182 | 174 | def get_checkpoint(self) -> Optional["Checkpoint"]: |
183 | 175 | return self._last_checkpoint |
184 | 176 |
|
185 | | - def get_dataset_shard(self, dataset_name: str) -> DataIterator: |
| 177 | + def get_dataset_shard(self, dataset_info: DatasetShardMetadata) -> DataIterator: |
| 178 | + dataset_name = dataset_info.dataset_name |
186 | 179 | assert ( |
187 | 180 | self._dataset_shards is not None and dataset_name in self._dataset_shards |
188 | 181 | ), f"Dataset shard {dataset_name} not found." |
|
0 commit comments