-
Notifications
You must be signed in to change notification settings - Fork 7.3k
[Data] Distributed reads for from_huggingface #42599
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
1c5f7d0
ab001e5
dbad83d
027b88f
6220aff
dddf9d4
edd4c16
572202b
a5cb130
8d47704
1e72361
aed5369
06a00fd
76f69c4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2358,13 +2358,18 @@ def from_spark( | |
|
|
||
| @PublicAPI | ||
| def from_huggingface( | ||
| dataset: Union["datasets.Dataset", "datasets.IterableDataset"], | ||
| dataset: Union["datasets.Dataset", "datasets.IterableDataset"], parallelism=-1 | ||
|
||
| ) -> Union[MaterializedDataset, Dataset]: | ||
| """Create a :class:`~ray.data.MaterializedDataset` from a | ||
| `Hugging Face Datasets Dataset <https://huggingface.co/docs/datasets/package_reference/main_classes#datasets.Dataset/>`_ | ||
| or a :class:`~ray.data.Dataset` from a `Hugging Face Datasets IterableDataset <https://huggingface.co/docs/datasets/package_reference/main_classes#datasets.IterableDataset/>`_. | ||
| For an `IterableDataset`, we use a streaming implementation to read data. | ||
|
|
||
| If the dataset is a public Hugging Face Dataset that is hosted on the Hugging Face Hub and | ||
| no transformations have been applied, then the `hosted parquet files <https://huggingface.co/docs/datasets-server/parquet#list-parquet-files>`_ | ||
| will be passed to :meth:`~ray.data.read_parquet` to perform a distributed read. In all | ||
| other cases, the data is read on a single node. | ||
|
|
||
| Example: | ||
|
|
||
| .. | ||
|
|
@@ -2403,18 +2408,36 @@ def from_huggingface( | |
| `DatasetDict <https://huggingface.co/docs/datasets/package_reference/main_classes#datasets.DatasetDict/>`_ | ||
| and `IterableDatasetDict <https://huggingface.co/docs/datasets/package_reference/main_classes#datasets.IterableDatasetDict/>`_ | ||
| are not supported. | ||
| parallelism: The amount of parallelism to use for the dataset if applicable (i.e. | ||
| if the dataset is a public Hugging Face Dataset without transforms applied). | ||
| Defaults to -1, which automatically determines the optimal parallelism for your | ||
| configuration. You should not need to manually set this value in most cases. | ||
| For details on how the parallelism is automatically determined and guidance | ||
| on how to tune it, see :ref:`Tuning read parallelism | ||
| <read_parallelism>`. Parallelism is upper bounded by the total number of | ||
| records in all the parquet files. | ||
|
|
||
| Returns: | ||
| A :class:`~ray.data.Dataset` holding rows from the `Hugging Face Datasets Dataset`_. | ||
| """ # noqa: E501 | ||
| import datasets | ||
|
|
||
| if isinstance(dataset, datasets.IterableDataset): | ||
| # HuggingFaceDatasource should not be imported at top level, because | ||
| # we only want the Hugging Face datasets package to be imported | ||
| # if Hugging Face Datasets are used. | ||
| from ray.data.datasource.huggingface_datasource import HuggingFaceDatasource | ||
| from ray.data.datasource.huggingface_datasource import HuggingFaceDatasource | ||
|
|
||
| if isinstance(dataset, (datasets.IterableDataset, datasets.Dataset)): | ||
| # Attempt to read data via Hugging Face Hub parquet files. If the | ||
| # returned list of files is empty, attempt read via other methods. | ||
| file_urls = HuggingFaceDatasource.list_parquet_urls_from_dataset(dataset) | ||
| if len(file_urls) > 0: | ||
| # If file urls are returned, the parquet files are available via API | ||
| # TODO: Add support for reading from http filesystem in FileBasedDatasource | ||
| # GH Issue: https://github.com/ray-project/ray/issues/42706 | ||
| import fsspec.implementations.http | ||
|
|
||
| http = fsspec.implementations.http.HTTPFileSystem() | ||
| return read_parquet(file_urls, parallelism=parallelism, filesystem=http) | ||
|
|
||
| if isinstance(dataset, datasets.IterableDataset): | ||
| # For an IterableDataset, we can use a streaming implementation to read data. | ||
| return read_datasource(HuggingFaceDatasource(dataset=dataset)) | ||
| if isinstance(dataset, datasets.Dataset): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1319,6 +1319,8 @@ def test_from_arrow_refs_e2e(ray_start_regular_shared, enable_optimizer): | |
| def test_from_huggingface_e2e(ray_start_regular_shared, enable_optimizer): | ||
| import datasets | ||
|
|
||
| from ray.data.tests.test_huggingface import hfds_assert_equals | ||
|
|
||
| data = datasets.load_dataset("tweet_eval", "emotion") | ||
| assert isinstance(data, datasets.DatasetDict) | ||
| ray_datasets = { | ||
|
|
@@ -1333,20 +1335,27 @@ def test_from_huggingface_e2e(ray_start_regular_shared, enable_optimizer): | |
| # needed for checking operator usage below. | ||
| assert len(ds.take_all()) > 0 | ||
| # Check that metadata fetch is included in stats; | ||
| # the underlying implementation uses the `FromArrow` operator. | ||
| assert "FromArrow" in ds.stats() | ||
| assert ds._plan._logical_plan.dag.name == "FromArrow" | ||
| assert ray.get(ray_datasets[ds_key].to_arrow_refs())[0].equals( | ||
| data[ds_key].data.table | ||
| ) | ||
| _check_usage_record(["FromArrow"]) | ||
|
|
||
| ray_dataset = ray.data.from_huggingface(data["train"]) | ||
| assert isinstance(ray_dataset, ray.data.Dataset) | ||
| assert len(ray_dataset.take_all()) > 0 | ||
| assert "FromArrow" in ray_dataset.stats() | ||
| assert ray_dataset._plan._logical_plan.dag.name == "FromArrow" | ||
| assert ray.get(ray_dataset.to_arrow_refs())[0].equals(data["train"].data.table) | ||
| # the underlying implementation uses the `ReadParquet` operator | ||
| # as this is an un-transformed public dataset. | ||
| assert "ReadParquet" in ds.stats() | ||
| assert ds._plan._logical_plan.dag.name == "ReadParquet" | ||
|
||
| # use sort by 'text' to match order of rows | ||
| hfds_assert_equals(data[ds_key], ds) | ||
| _check_usage_record(["ReadParquet"]) | ||
|
|
||
| # test transformed public dataset for fallback behavior | ||
| base_hf_dataset = data["train"] | ||
| hf_dataset_split = base_hf_dataset.train_test_split(test_size=0.2) | ||
| ray_dataset_split_train = ray.data.from_huggingface(hf_dataset_split["train"]) | ||
| assert isinstance(ray_dataset_split_train, ray.data.Dataset) | ||
| # `ds.take_all()` triggers execution with new backend, which is | ||
| # needed for checking operator usage below. | ||
| assert len(ray_dataset_split_train.take_all()) > 0 | ||
| # Check that metadata fetch is included in stats; | ||
| # the underlying implementation uses the `FromArrow` operator. | ||
| assert "FromArrow" in ray_dataset_split_train.stats() | ||
| assert ray_dataset_split_train._plan._logical_plan.dag.name == "FromArrow" | ||
| assert ray_dataset_split_train.count() == hf_dataset_split["train"].num_rows | ||
| _check_usage_record(["FromArrow"]) | ||
|
|
||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.