diff --git a/ci/lint/pydoclint-baseline.txt b/ci/lint/pydoclint-baseline.txt index dda36bc5155b..de5c705d5a74 100644 --- a/ci/lint/pydoclint-baseline.txt +++ b/ci/lint/pydoclint-baseline.txt @@ -1369,8 +1369,6 @@ python/ray/data/read_api.py DOC101: Function `read_text`: Docstring contains fewer arguments than in function signature. DOC103: Function `read_text`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [drop_empty_lines: bool]. DOC103: Function `read_numpy`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [**numpy_load_args: ]. Arguments in the docstring but not in the function signature: [numpy_load_args: ]. - DOC104: Function `read_binary_files`: Arguments are the same in the docstring and the function signature, but are in a different order. 
- DOC105: Function `read_binary_files`: Argument names match, but type hints in these args do not match: paths, include_paths, filesystem, parallelism, ray_remote_args, arrow_open_stream_args, meta_provider, partition_filter, partitioning, ignore_missing_paths, shuffle, file_extensions, concurrency, override_num_blocks -------------------- python/ray/data/tests/test_split.py DOC106: Function `assert_split_assignment`: The option `--arg-type-hints-in-signature` is `True` but there are no argument type hints in the signature diff --git a/python/ray/data/_internal/util.py b/python/ray/data/_internal/util.py index 3db54be3d547..77b6b75f4e9b 100644 --- a/python/ray/data/_internal/util.py +++ b/python/ray/data/_internal/util.py @@ -16,6 +16,7 @@ TYPE_CHECKING, Any, Callable, + Dict, Generator, Iterable, Iterator, @@ -1752,3 +1753,30 @@ def rows_same(actual: pd.DataFrame, expected: pd.DataFrame) -> bool: expected_items_counts = Counter(frozenset(row.items()) for row in expected_rows) return actual_items_counts == expected_items_counts + + +def merge_resources_to_ray_remote_args( + num_cpus: Optional[float], + num_gpus: Optional[float], + memory: Optional[float], + ray_remote_args: Dict[str, Any], +) -> Dict[str, Any]: + """Convert the given resources to Ray remote args. + + Args: + num_cpus: The number of CPUs to be added to the Ray remote args. + num_gpus: The number of GPUs to be added to the Ray remote args. + memory: The memory to be added to the Ray remote args. + ray_remote_args: The Ray remote args to be merged. + + Returns: + The converted arguments. 
+ """ + ray_remote_args = ray_remote_args.copy() + if num_cpus is not None: + ray_remote_args["num_cpus"] = num_cpus + if num_gpus is not None: + ray_remote_args["num_gpus"] = num_gpus + if memory is not None: + ray_remote_args["memory"] = memory + return ray_remote_args diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 59391f36ce3f..5abd0ba21f5a 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -95,6 +95,7 @@ ConsumptionAPI, _validate_rows_per_file_args, get_compute_strategy, + merge_resources_to_ray_remote_args, ) from ray.data.aggregate import AggregateFn, Max, Mean, Min, Std, Sum, Unique from ray.data.block import ( @@ -403,14 +404,12 @@ def parse_filename(row: Dict[str, Any]) -> Dict[str, Any]: concurrency=concurrency, ) - if num_cpus is not None: - ray_remote_args["num_cpus"] = num_cpus - - if num_gpus is not None: - ray_remote_args["num_gpus"] = num_gpus - - if memory is not None: - ray_remote_args["memory"] = memory + ray_remote_args = merge_resources_to_ray_remote_args( + num_cpus, + num_gpus, + memory, + ray_remote_args, + ) plan = self._plan.copy() map_op = MapRows( @@ -1368,14 +1367,12 @@ def duplicate_row(row: Dict[str, Any]) -> List[Dict[str, Any]]: concurrency=concurrency, ) - if num_cpus is not None: - ray_remote_args["num_cpus"] = num_cpus - - if num_gpus is not None: - ray_remote_args["num_gpus"] = num_gpus - - if memory is not None: - ray_remote_args["memory"] = memory + ray_remote_args = merge_resources_to_ray_remote_args( + num_cpus, + num_gpus, + memory, + ray_remote_args, + ) plan = self._plan.copy() op = FlatMap( @@ -1403,6 +1400,9 @@ def filter( fn_kwargs: Optional[Dict[str, Any]] = None, fn_constructor_args: Optional[Iterable[Any]] = None, fn_constructor_kwargs: Optional[Dict[str, Any]] = None, + num_cpus: Optional[float] = None, + num_gpus: Optional[float] = None, + memory: Optional[float] = None, concurrency: Optional[Union[int, Tuple[int, int], Tuple[int, int, int]]] = None, 
ray_remote_args_fn: Optional[Callable[[], Dict[str, Any]]] = None, **ray_remote_args, @@ -1444,6 +1444,11 @@ def filter( This can only be provided if ``fn`` is a callable class. These arguments are top-level arguments in the underlying Ray actor construction task. compute: This argument is deprecated. Use ``concurrency`` argument. + num_cpus: The number of CPUs to reserve for each parallel map worker. + num_gpus: The number of GPUs to reserve for each parallel map worker. For + example, specify `num_gpus=1` to request 1 GPU for each parallel map + worker. + memory: The heap memory in bytes to reserve for each parallel map worker. concurrency: The semantics of this argument depend on the type of ``fn``: * If ``fn`` is a function and ``concurrency`` isn't set (default), the @@ -1518,6 +1523,12 @@ def filter( f"{type(fn).__name__} instead." ) + ray_remote_args = merge_resources_to_ray_remote_args( + num_cpus, + num_gpus, + memory, + ray_remote_args, + ) plan = self._plan.copy() op = Filter( input_op=self._logical_plan.dag, diff --git a/python/ray/data/read_api.py b/python/ray/data/read_api.py index 578a706c3d2b..46a81e5bc9db 100644 --- a/python/ray/data/read_api.py +++ b/python/ray/data/read_api.py @@ -71,6 +71,7 @@ from ray.data._internal.util import ( _autodetect_parallelism, get_table_block_metadata_schema, + merge_resources_to_ray_remote_args, ndarray_to_block, pandas_df_to_arrow_block, ) @@ -355,6 +356,9 @@ def read_datasource( datasource: Datasource, *, parallelism: int = -1, + num_cpus: Optional[float] = None, + num_gpus: Optional[float] = None, + memory: Optional[float] = None, ray_remote_args: Dict[str, Any] = None, concurrency: Optional[int] = None, override_num_blocks: Optional[int] = None, @@ -365,6 +369,11 @@ def read_datasource( Args: datasource: The :class:`~ray.data.Datasource` to read data from. parallelism: This argument is deprecated. Use ``override_num_blocks`` argument. + num_cpus: The number of CPUs to reserve for each parallel read worker. 
+ num_gpus: The number of GPUs to reserve for each parallel read worker. For + example, specify `num_gpus=1` to request 1 GPU for each parallel read + worker. + memory: The heap memory in bytes to reserve for each parallel read worker. ray_remote_args: kwargs passed to :func:`ray.remote` in the read tasks. concurrency: The maximum number of Ray tasks to run concurrently. Set this to control number of tasks to run concurrently. This doesn't change the @@ -396,6 +405,13 @@ def read_datasource( if "scheduling_strategy" not in ray_remote_args: ray_remote_args["scheduling_strategy"] = ctx.scheduling_strategy + ray_remote_args = merge_resources_to_ray_remote_args( + num_cpus, + num_gpus, + memory, + ray_remote_args, + ) + datasource_or_legacy_reader = _get_datasource_or_legacy_reader( datasource, ctx, @@ -453,6 +469,9 @@ def read_audio( shuffle: Union[Literal["files"], None] = None, concurrency: Optional[int] = None, override_num_blocks: Optional[int] = None, + num_cpus: Optional[float] = None, + num_gpus: Optional[float] = None, + memory: Optional[float] = None, ray_remote_args: Optional[Dict[str, Any]] = None, ): """Creates a :class:`~ray.data.Dataset` from audio files. @@ -502,6 +521,11 @@ def read_audio( By default, the number of output blocks is dynamically decided based on input data size and available resources. You shouldn't manually set this value in most cases. + num_cpus: The number of CPUs to reserve for each parallel read worker. + num_gpus: The number of GPUs to reserve for each parallel read worker. For + example, specify `num_gpus=1` to request 1 GPU for each parallel read + worker. + memory: The heap memory in bytes to reserve for each parallel read worker. ray_remote_args: kwargs passed to :meth:`~ray.remote` in the read tasks. 
Returns: @@ -523,6 +547,9 @@ def read_audio( return read_datasource( datasource, ray_remote_args=ray_remote_args, + num_cpus=num_cpus, + num_gpus=num_gpus, + memory=memory, concurrency=concurrency, override_num_blocks=override_num_blocks, ) @@ -543,6 +570,9 @@ def read_videos( shuffle: Union[Literal["files"], None] = None, concurrency: Optional[int] = None, override_num_blocks: Optional[int] = None, + num_cpus: Optional[float] = None, + num_gpus: Optional[float] = None, + memory: Optional[float] = None, ray_remote_args: Optional[Dict[str, Any]] = None, ): """Creates a :class:`~ray.data.Dataset` from video files. @@ -592,6 +622,11 @@ def read_videos( total number of tasks run or the total number of output blocks. By default, concurrency is dynamically decided based on the available resources. ray_remote_args: kwargs passed to :meth:`~ray.remote` in the read tasks. + num_cpus: The number of CPUs to reserve for each parallel read worker. + num_gpus: The number of GPUs to reserve for each parallel read worker. For + example, specify `num_gpus=1` to request 1 GPU for each parallel read + worker. + memory: The heap memory in bytes to reserve for each parallel read worker. Returns: A :class:`~ray.data.Dataset` containing video frames from the video files. @@ -612,6 +647,9 @@ def read_videos( return read_datasource( datasource, ray_remote_args=ray_remote_args, + num_cpus=num_cpus, + num_gpus=num_gpus, + memory=memory, concurrency=concurrency, override_num_blocks=override_num_blocks, ) @@ -626,6 +664,9 @@ def read_mongo( pipeline: Optional[List[Dict]] = None, schema: Optional["pymongoarrow.api.Schema"] = None, parallelism: int = -1, + num_cpus: Optional[float] = None, + num_gpus: Optional[float] = None, + memory: Optional[float] = None, ray_remote_args: Dict[str, Any] = None, concurrency: Optional[int] = None, override_num_blocks: Optional[int] = None, @@ -679,6 +720,11 @@ def read_mongo( schema: The schema used to read the collection. 
If None, it'll be inferred from the results of pipeline. parallelism: This argument is deprecated. Use ``override_num_blocks`` argument. + num_cpus: The number of CPUs to reserve for each parallel read worker. + num_gpus: The number of GPUs to reserve for each parallel read worker. For + example, specify `num_gpus=1` to request 1 GPU for each parallel read + worker. + memory: The heap memory in bytes to reserve for each parallel read worker. ray_remote_args: kwargs passed to :func:`ray.remote` in the read tasks. concurrency: The maximum number of Ray tasks to run concurrently. Set this to control number of tasks to run concurrently. This doesn't change the @@ -710,6 +756,9 @@ def read_mongo( ) return read_datasource( datasource, + num_cpus=num_cpus, + num_gpus=num_gpus, + memory=memory, parallelism=parallelism, ray_remote_args=ray_remote_args, concurrency=concurrency, @@ -724,6 +773,9 @@ def read_bigquery( query: Optional[str] = None, *, parallelism: int = -1, + num_cpus: Optional[float] = None, + num_gpus: Optional[float] = None, + memory: Optional[float] = None, ray_remote_args: Dict[str, Any] = None, concurrency: Optional[int] = None, override_num_blocks: Optional[int] = None, @@ -768,6 +820,11 @@ def read_bigquery( dataset: The name of the dataset hosted in BigQuery in the format of ``dataset_id.table_id``. Both the dataset_id and table_id must exist otherwise an exception will be raised. parallelism: This argument is deprecated. Use ``override_num_blocks`` argument. + num_cpus: The number of CPUs to reserve for each parallel read worker. + num_gpus: The number of GPUs to reserve for each parallel read worker. For + example, specify `num_gpus=1` to request 1 GPU for each parallel read + worker. + memory: The heap memory in bytes to reserve for each parallel read worker. ray_remote_args: kwargs passed to :func:`ray.remote` in the read tasks. concurrency: The maximum number of Ray tasks to run concurrently. Set this to control number of tasks to run concurrently. 
This doesn't change the @@ -785,6 +842,9 @@ def read_bigquery( datasource = BigQueryDatasource(project_id=project_id, dataset=dataset, query=query) return read_datasource( datasource, + num_cpus=num_cpus, + num_gpus=num_gpus, + memory=memory, parallelism=parallelism, ray_remote_args=ray_remote_args, concurrency=concurrency, @@ -799,6 +859,9 @@ def read_parquet( filesystem: Optional["pyarrow.fs.FileSystem"] = None, columns: Optional[List[str]] = None, parallelism: int = -1, + num_cpus: Optional[float] = None, + num_gpus: Optional[float] = None, + memory: Optional[float] = None, ray_remote_args: Dict[str, Any] = None, tensor_column_schema: Optional[Dict[str, Tuple[np.dtype, Tuple[int, ...]]]] = None, meta_provider: Optional[FileMetadataProvider] = None, @@ -896,6 +959,11 @@ def read_parquet( columns: A list of column names to read. Only the specified columns are read during the file scan. parallelism: This argument is deprecated. Use ``override_num_blocks`` argument. + num_cpus: The number of CPUs to reserve for each parallel read worker. + num_gpus: The number of GPUs to reserve for each parallel read worker. For + example, specify `num_gpus=1` to request 1 GPU for each parallel read + worker. + memory: The heap memory in bytes to reserve for each parallel read worker. ray_remote_args: kwargs passed to :func:`ray.remote` in the read tasks. 
tensor_column_schema: A dict of column name to PyArrow dtype and shape mappings for converting a Parquet column containing serialized @@ -962,6 +1030,9 @@ def read_parquet( ) return read_datasource( datasource, + num_cpus=num_cpus, + num_gpus=num_gpus, + memory=memory, parallelism=parallelism, ray_remote_args=ray_remote_args, concurrency=concurrency, @@ -975,6 +1046,9 @@ def read_images( *, filesystem: Optional["pyarrow.fs.FileSystem"] = None, parallelism: int = -1, + num_cpus: Optional[float] = None, + num_gpus: Optional[float] = None, + memory: Optional[float] = None, meta_provider: Optional[BaseFileMetadataProvider] = None, ray_remote_args: Dict[str, Any] = None, arrow_open_file_args: Optional[Dict[str, Any]] = None, @@ -1048,6 +1122,11 @@ class string the filesystem is automatically selected based on the scheme of the paths. For example, if the path begins with ``s3://``, the `S3FileSystem` is used. parallelism: This argument is deprecated. Use ``override_num_blocks`` argument. + num_cpus: The number of CPUs to reserve for each parallel read worker. + num_gpus: The number of GPUs to reserve for each parallel read worker. For + example, specify `num_gpus=1` to request 1 GPU for each parallel read + worker. + memory: The heap memory in bytes to reserve for each parallel read worker. meta_provider: [Deprecated] A :ref:`file metadata provider `. Custom metadata providers may be able to resolve file metadata more quickly and/or accurately. In most cases, you do not need to set this. 
If ``None``, @@ -1119,6 +1198,9 @@ class string ) return read_datasource( datasource, + num_cpus=num_cpus, + num_gpus=num_gpus, + memory=memory, parallelism=parallelism, ray_remote_args=ray_remote_args, concurrency=concurrency, @@ -1133,6 +1215,9 @@ def read_parquet_bulk( filesystem: Optional["pyarrow.fs.FileSystem"] = None, columns: Optional[List[str]] = None, parallelism: int = -1, + num_cpus: Optional[float] = None, + num_gpus: Optional[float] = None, + memory: Optional[float] = None, ray_remote_args: Dict[str, Any] = None, arrow_open_file_args: Optional[Dict[str, Any]] = None, tensor_column_schema: Optional[Dict[str, Tuple[np.dtype, Tuple[int, ...]]]] = None, @@ -1184,6 +1269,11 @@ def read_parquet_bulk( columns: A list of column names to read. Only the specified columns are read during the file scan. parallelism: This argument is deprecated. Use ``override_num_blocks`` argument. + num_cpus: The number of CPUs to reserve for each parallel read worker. + num_gpus: The number of GPUs to reserve for each parallel read worker. For + example, specify `num_gpus=1` to request 1 GPU for each parallel read + worker. + memory: The heap memory in bytes to reserve for each parallel read worker. ray_remote_args: kwargs passed to :func:`ray.remote` in the read tasks. arrow_open_file_args: kwargs passed to `pyarrow.fs.FileSystem.open_input_file `_. + num_cpus: The number of CPUs to reserve for each parallel read worker. + num_gpus: The number of GPUs to reserve for each parallel read worker. For + example, specify `num_gpus=1` to request 1 GPU for each parallel read + worker. + memory: The heap memory in bytes to reserve for each parallel read worker. ray_remote_args: kwargs passed to :func:`ray.remote` in the read tasks. concurrency: The maximum number of Ray tasks to run concurrently. Set this to control number of tasks to run concurrently. 
This doesn't change the @@ -2722,6 +2919,9 @@ def read_hudi( return read_datasource( datasource=datasource, ray_remote_args=ray_remote_args, + num_cpus=num_cpus, + num_gpus=num_gpus, + memory=memory, concurrency=concurrency, override_num_blocks=override_num_blocks, ) @@ -3180,6 +3380,9 @@ def read_delta_sharing_tables( timestamp: Optional[str] = None, json_predicate_hints: Optional[str] = None, ray_remote_args: Optional[Dict[str, Any]] = None, + num_cpus: Optional[float] = None, + num_gpus: Optional[float] = None, + memory: Optional[float] = None, concurrency: Optional[int] = None, override_num_blocks: Optional[int] = None, ) -> Dataset: @@ -3223,6 +3426,11 @@ def read_delta_sharing_tables( details, see: https://github.com/delta-io/delta-sharing/blob/main/PROTOCOL.md#json-predicates-for-filtering. ray_remote_args: kwargs passed to :func:`ray.remote` in the read tasks. + num_cpus: The number of CPUs to reserve for each parallel read worker. + num_gpus: The number of GPUs to reserve for each parallel read worker. For + example, specify `num_gpus=1` to request 1 GPU for each parallel read + worker. + memory: The heap memory in bytes to reserve for each parallel read worker. concurrency: The maximum number of Ray tasks to run concurrently. Set this to control the number of tasks to run concurrently. This doesn't change the total number of tasks run or the total number of output blocks. By default, @@ -3252,6 +3460,9 @@ def read_delta_sharing_tables( return ray.data.read_datasource( datasource=datasource, ray_remote_args=ray_remote_args, + num_cpus=num_cpus, + num_gpus=num_gpus, + memory=memory, concurrency=concurrency, override_num_blocks=override_num_blocks, ) @@ -3534,6 +3745,7 @@ def from_torch( dataset: A `Torch Dataset`_. local_read: If ``True``, perform the read as a local read. + Returns: A :class:`~ray.data.Dataset` containing the Torch dataset samples. 
""" # noqa: E501 @@ -3571,6 +3783,9 @@ def read_iceberg( scan_kwargs: Optional[Dict[str, str]] = None, catalog_kwargs: Optional[Dict[str, str]] = None, ray_remote_args: Optional[Dict[str, Any]] = None, + num_cpus: Optional[float] = None, + num_gpus: Optional[float] = None, + memory: Optional[float] = None, override_num_blocks: Optional[int] = None, ) -> Dataset: """Create a :class:`~ray.data.Dataset` from an Iceberg table. @@ -3615,6 +3830,11 @@ def read_iceberg( #pyiceberg.catalog.load_catalog>`_. ray_remote_args: Optional arguments to pass to :func:`ray.remote` in the read tasks. + num_cpus: The number of CPUs to reserve for each parallel read worker. + num_gpus: The number of GPUs to reserve for each parallel read worker. For + example, specify `num_gpus=1` to request 1 GPU for each parallel read + worker. + memory: The heap memory in bytes to reserve for each parallel read worker. override_num_blocks: Override the number of output blocks from all read tasks. By default, the number of output blocks is dynamically decided based on input data size and available resources, and capped at the number of @@ -3638,6 +3858,9 @@ def read_iceberg( dataset = read_datasource( datasource=datasource, parallelism=parallelism, + num_cpus=num_cpus, + num_gpus=num_gpus, + memory=memory, override_num_blocks=override_num_blocks, ray_remote_args=ray_remote_args, ) @@ -3654,6 +3877,9 @@ def read_lance( storage_options: Optional[Dict[str, str]] = None, scanner_options: Optional[Dict[str, Any]] = None, ray_remote_args: Optional[Dict[str, Any]] = None, + num_cpus: Optional[float] = None, + num_gpus: Optional[float] = None, + memory: Optional[float] = None, concurrency: Optional[int] = None, override_num_blocks: Optional[int] = None, ) -> Dataset: @@ -3684,6 +3910,11 @@ def read_lance( see `LanceDB API doc `_ ray_remote_args: kwargs passed to :func:`ray.remote` in the read tasks. + num_cpus: The number of CPUs to reserve for each parallel read worker. 
+ num_gpus: The number of GPUs to reserve for each parallel read worker. For + example, specify `num_gpus=1` to request 1 GPU for each parallel read + worker. + memory: The heap memory in bytes to reserve for each parallel read worker. concurrency: The maximum number of Ray tasks to run concurrently. Set this to control number of tasks to run concurrently. This doesn't change the total number of tasks run or the total number of output blocks. By default, @@ -3707,6 +3938,9 @@ def read_lance( return read_datasource( datasource=datasource, ray_remote_args=ray_remote_args, + num_cpus=num_cpus, + num_gpus=num_gpus, + memory=memory, concurrency=concurrency, override_num_blocks=override_num_blocks, ) @@ -3723,6 +3957,9 @@ def read_clickhouse( client_settings: Optional[Dict[str, Any]] = None, client_kwargs: Optional[Dict[str, Any]] = None, ray_remote_args: Optional[Dict[str, Any]] = None, + num_cpus: Optional[float] = None, + num_gpus: Optional[float] = None, + memory: Optional[float] = None, concurrency: Optional[int] = None, override_num_blocks: Optional[int] = None, ) -> Dataset: @@ -3764,6 +4001,11 @@ def read_clickhouse( client_kwargs: Optional additional arguments to pass to the ClickHouse client. For more information, see `ClickHouse Core Settings `_. ray_remote_args: kwargs passed to :func:`ray.remote` in the read tasks. + num_cpus: The number of CPUs to reserve for each parallel read worker. + num_gpus: The number of GPUs to reserve for each parallel read worker. For + example, specify `num_gpus=1` to request 1 GPU for each parallel read + worker. + memory: The heap memory in bytes to reserve for each parallel read worker. concurrency: The maximum number of Ray tasks to run concurrently. Set this to control number of tasks to run concurrently. This doesn't change the total number of tasks run or the total number of output blocks. 
By default, @@ -3789,6 +4031,9 @@ def read_clickhouse( return read_datasource( datasource=datasource, ray_remote_args=ray_remote_args, + num_cpus=num_cpus, + num_gpus=num_gpus, + memory=memory, concurrency=concurrency, override_num_blocks=override_num_blocks, ) @@ -3886,6 +4131,9 @@ def read_delta( filesystem: Optional["pyarrow.fs.FileSystem"] = None, columns: Optional[List[str]] = None, parallelism: int = -1, + num_cpus: Optional[float] = None, + num_gpus: Optional[float] = None, + memory: Optional[float] = None, ray_remote_args: Optional[Dict[str, Any]] = None, meta_provider: Optional[FileMetadataProvider] = None, partition_filter: Optional[PathPartitionFilter] = None, @@ -3917,6 +4165,11 @@ def read_delta( columns: A list of column names to read. Only the specified columns are read during the file scan. parallelism: This argument is deprecated. Use ``override_num_blocks`` argument. + num_cpus: The number of CPUs to reserve for each parallel read worker. + num_gpus: The number of GPUs to reserve for each parallel read worker. For + example, specify `num_gpus=1` to request 1 GPU for each parallel read + worker. + memory: The heap memory in bytes to reserve for each parallel read worker. ray_remote_args: kwargs passed to :meth:`~ray.remote` in the read tasks. meta_provider: A :ref:`file metadata provider `. 
Custom metadata providers may be able to resolve file metadata more quickly and/or diff --git a/python/ray/data/tests/test_util.py b/python/ray/data/tests/test_util.py index 03d74268ea03..557925fa6168 100644 --- a/python/ray/data/tests/test_util.py +++ b/python/ray/data/tests/test_util.py @@ -26,6 +26,7 @@ _check_pyarrow_version, find_partition_index, iterate_with_retry, + merge_resources_to_ray_remote_args, rows_same, ) from ray.data.tests.conftest import * # noqa: F401, F403 @@ -345,6 +346,21 @@ def test_find_partition_index_duplicates_descending(): assert find_partition_index(table, (3,), sort_key) == 0 +def test_merge_resources_to_ray_remote_args(): + ray_remote_args = {} + ray_remote_args = merge_resources_to_ray_remote_args(1, 1, 1, ray_remote_args) + assert ray_remote_args == {"num_cpus": 1, "num_gpus": 1, "memory": 1} + + ray_remote_args = {"other_resource": 1} + ray_remote_args = merge_resources_to_ray_remote_args(1, 1, 1, ray_remote_args) + assert ray_remote_args == { + "num_cpus": 1, + "num_gpus": 1, + "memory": 1, + "other_resource": 1, + } + + @pytest.mark.parametrize( "actual, expected, expected_equal", [