diff --git a/docs/releases.rst b/docs/releases.rst index 03c6c624..c5d3ed9a 100644 --- a/docs/releases.rst +++ b/docs/releases.rst @@ -10,6 +10,8 @@ New Features ~~~~~~~~~~~~ - Added experimental ManifestStore (:pull:`490`). +- Added experimental :py:func:`open_virtual_mfdataset` function (:issue:`345`, :pull:`349`). + By `Tom Nicholas `_. Breaking changes ~~~~~~~~~~~~~~~~ diff --git a/pyproject.toml b/pyproject.toml index e25b7c5d..b6f8a58a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ classifiers = [ requires-python = ">=3.11" dynamic = ["version"] dependencies = [ - "xarray>=2025.1.1", + "xarray>=2025.3.0", "numpy>=2.0.0", "universal-pathlib", "numcodecs>=0.15.1", @@ -122,6 +122,8 @@ dev = [ "pytest-xdist", "ruff", "s3fs", + "lithops", + "dask", ] [project.urls] diff --git a/virtualizarr/__init__.py b/virtualizarr/__init__.py index bd70f834..9d7ef92e 100644 --- a/virtualizarr/__init__.py +++ b/virtualizarr/__init__.py @@ -1,6 +1,6 @@ from virtualizarr.manifests import ChunkManifest, ManifestArray # type: ignore # noqa from virtualizarr.accessor import VirtualiZarrDatasetAccessor # type: ignore # noqa -from virtualizarr.backend import open_virtual_dataset # noqa: F401 +from virtualizarr.backend import open_virtual_dataset, open_virtual_mfdataset # noqa: F401 from importlib.metadata import version as _version diff --git a/virtualizarr/backend.py b/virtualizarr/backend.py index 7e3be787..16901c9c 100644 --- a/virtualizarr/backend.py +++ b/virtualizarr/backend.py @@ -1,14 +1,26 @@ +import os import warnings from collections.abc import Iterable, Mapping +from concurrent.futures import Executor from enum import Enum, auto from pathlib import Path from typing import ( + TYPE_CHECKING, Any, + Callable, + Literal, Optional, + Sequence, + cast, ) -from xarray import Dataset, Index +import xarray as xr +from xarray import DataArray, Dataset, Index, combine_by_coords +from xarray.backends.common import _find_absolute_paths +from xarray.core.types import NestedSequence +from xarray.structure.combine import _infer_concat_order_from_positions, _nested_combine +from virtualizarr.parallel import get_executor from virtualizarr.readers import ( DMRPPVirtualBackend, FITSVirtualBackend, @@ -20,6 +32,14 @@ from virtualizarr.readers.api import VirtualBackend from virtualizarr.utils import _FsspecFSFromFilepath +if TYPE_CHECKING: + from xarray.core.types import ( + CombineAttrsOptions, + CompatOptions, + JoinOptions, + ) + + # TODO add entrypoint to allow external libraries to add to this mapping VIRTUAL_BACKENDS = { "kerchunk": KerchunkVirtualBackend, @@ -197,3 +217,189 @@ def open_virtual_dataset( ) return vds + + +def open_virtual_mfdataset( + paths: str + | os.PathLike + | Sequence[str | os.PathLike] + | "NestedSequence[str | os.PathLike]", + concat_dim: ( + str + | DataArray + | Index + | Sequence[str] + | Sequence[DataArray] + | Sequence[Index] + | None + ) = None, + compat: "CompatOptions" = "no_conflicts", + preprocess: Callable[[Dataset], Dataset] | None = None, + data_vars: Literal["all", "minimal", "different"] | list[str] = "all", + coords="different", + combine: Literal["by_coords", "nested"] = "by_coords", + parallel: Literal["dask", "lithops", False] | Executor = False, + join: "JoinOptions" = "outer", + attrs_file: str | os.PathLike | None = None, + combine_attrs: "CombineAttrsOptions" = "override", + **kwargs, +) -> Dataset: + """ + Open multiple files as a single virtual dataset. 
+ + If combine='by_coords' then the function ``combine_by_coords`` is used to combine + the datasets into one before returning the result, and if combine='nested' then + ``combine_nested`` is used. The filepaths must be structured according to which + combining function is used, the details of which are given in the documentation for + ``combine_by_coords`` and ``combine_nested``. By default ``combine='by_coords'`` + will be used. Global attributes from the ``attrs_file`` are used + for the combined dataset. + + Parameters + ---------- + paths + Same as in xarray.open_mfdataset + concat_dim + Same as in xarray.open_mfdataset + compat + Same as in xarray.open_mfdataset + preprocess + Same as in xarray.open_mfdataset + data_vars + Same as in xarray.open_mfdataset + coords + Same as in xarray.open_mfdataset + combine + Same as in xarray.open_mfdataset + parallel : "dask", "lithops", False, or instance of a subclass of ``concurrent.futures.Executor`` + Specify whether the open and preprocess steps of this function will be + performed in parallel using lithops, dask.delayed, or any executor compatible + with the ``concurrent.futures`` interface, or in serial. + Default is False, which will execute these steps in serial. + join + Same as in xarray.open_mfdataset + attrs_file + Same as in xarray.open_mfdataset + combine_attrs + Same as in xarray.open_mfdataset + **kwargs : optional + Additional arguments passed on to :py:func:`virtualizarr.open_virtual_dataset`. For an + overview of some of the possible options, see the documentation of + :py:func:`virtualizarr.open_virtual_dataset`. + + Returns + ------- + xarray.Dataset + + Notes + ----- + The results of opening each virtual dataset in parallel are sent back to the client process, so must not be too large. + """ + + # TODO this is practically all just copied from xarray.open_mfdataset - an argument for writing a virtualizarr engine for xarray? + + # TODO list kwargs passed to open_virtual_dataset explicitly in docstring? + + paths = cast(NestedSequence[str], _find_absolute_paths(paths)) + + if not paths: + raise OSError("no files to open") + + paths1d: list[str] + if combine == "nested": + if isinstance(concat_dim, str | DataArray) or concat_dim is None: + concat_dim = [concat_dim] # type: ignore[assignment] + + # This creates a flat list which is easier to iterate over, whilst + # encoding the originally-supplied structure as "ids". + # The "ids" are not used at all if combine='by_coords`. + combined_ids_paths = _infer_concat_order_from_positions(paths) + ids, paths1d = ( + list(combined_ids_paths.keys()), + list(combined_ids_paths.values()), + ) + elif concat_dim is not None: + raise ValueError( + "When combine='by_coords', passing a value for `concat_dim` has no " + "effect. 
To manually combine along a specific dimension you should " + "instead specify combine='nested' along with a value for `concat_dim`.", + ) + else: + paths1d = paths # type: ignore[assignment] + + # TODO this refactored preprocess and executor logic should be upstreamed into xarray - see https://github.com/pydata/xarray/pull/9932 + + if preprocess: + # TODO we could reexpress these using functools.partial but then we would hit this lithops bug: https://github.com/lithops-cloud/lithops/issues/1428 + + def _open_and_preprocess(path: str) -> xr.Dataset: + ds = open_virtual_dataset(path, **kwargs) + return preprocess(ds) + + open_func = _open_and_preprocess + else: + + def _open(path: str) -> xr.Dataset: + return open_virtual_dataset(path, **kwargs) + + open_func = _open + + executor = get_executor(parallel=parallel) + with executor() as exec: + # wait for all the workers to finish, and send their resulting virtual datasets back to the client for concatenation there + virtual_datasets = list( + exec.map( + open_func, + paths1d, + ) + ) + + # TODO add file closers + + # Combine all datasets, closing them in case of a ValueError + try: + if combine == "nested": + # Combined nested list by successive concat and merge operations + # along each dimension, using structure given by "ids" + combined_vds = _nested_combine( + virtual_datasets, + concat_dims=concat_dim, + compat=compat, + data_vars=data_vars, + coords=coords, + ids=ids, + join=join, + combine_attrs=combine_attrs, + ) + elif combine == "by_coords": + # Redo ordering from coordinates, ignoring how they were ordered + # previously + combined_vds = combine_by_coords( + virtual_datasets, + compat=compat, + data_vars=data_vars, + coords=coords, + join=join, + combine_attrs=combine_attrs, + ) + else: + raise ValueError( + f"{combine} is an invalid option for the keyword argument ``combine``" + ) + except ValueError: + for vds in virtual_datasets: + vds.close() + raise + + # combined_vds.set_close(partial(_multi_file_closer, closers)) + + # read global attributes from the attrs_file or from the first dataset + if attrs_file is not None: + if isinstance(attrs_file, os.PathLike): + attrs_file = cast(str, os.fspath(attrs_file)) + combined_vds.attrs = virtual_datasets[paths1d.index(attrs_file)].attrs + + # TODO should we just immediately close everything? 
+ # TODO If loadable_variables is eager then we should have already read everything we're ever going to read into memory at this point + + return combined_vds diff --git a/virtualizarr/manifests/store.py b/virtualizarr/manifests/store.py index 8149b4d7..05d5f124 100644 --- a/virtualizarr/manifests/store.py +++ b/virtualizarr/manifests/store.py @@ -43,7 +43,7 @@ from virtualizarr.vendor.zarr.metadata import dict_to_buffer if TYPE_CHECKING: - from obstore.store import ObjectStore + from obstore.store import ObjectStore # type: ignore[import-not-found] StoreDict: TypeAlias = dict[str, ObjectStore] diff --git a/virtualizarr/parallel.py b/virtualizarr/parallel.py new file mode 100644 index 00000000..f067189d --- /dev/null +++ b/virtualizarr/parallel.py @@ -0,0 +1,324 @@ +import inspect +import warnings +from concurrent.futures import Executor, Future +from typing import Any, Callable, Iterable, Iterator, Literal, TypeVar + +# TODO this entire module could ideally be upstreamed into xarray as part of https://github.com/pydata/xarray/pull/9932 +# TODO the DaskDelayedExecutor class could ideally be upstreamed into dask +# TODO lithops should just not require a special wrapper class, see https://github.com/lithops-cloud/lithops/issues/1427 + + +# Type variable for return type +T = TypeVar("T") + + +def get_executor( + parallel: Literal["dask", "lithops"] | Executor | Literal[False], +) -> type[Executor]: + """Get an executor that follows the concurrent.futures.Executor ABC API.""" + + if parallel == "dask": + return DaskDelayedExecutor + elif parallel == "lithops": + return LithopsEagerFunctionExecutor + elif parallel is False: + return SerialExecutor + elif inspect.isclass(parallel) and issubclass(parallel, Executor): + return parallel + else: + raise ValueError( + f"Unrecognized argument to ``parallel``: {parallel}" + "Please supply either ``'dask'``, ``'lithops'``, ``False``, or a concrete subclass of ``concurrent.futures.Executor``." + ) + + +class SerialExecutor(Executor): + """ + A custom Executor that runs tasks sequentially, mimicking the + concurrent.futures.Executor interface. Useful as a default and for debugging. + """ + + def __init__(self): + # Track submitted futures to maintain interface compatibility + self._futures: list[Future] = [] + + def submit(self, fn: Callable[..., T], /, *args: Any, **kwargs: Any) -> Future[T]: + """ + Submit a callable to be executed. + + Unlike parallel executors, this runs the task immediately and sequentially. + + Parameters + ---------- + fn + The callable to execute + args + Positional arguments for the callable + kwargs + Keyword arguments for the callable + + Returns + ------- + A Future representing the result of the execution + """ + # Create a future to maintain interface compatibility + future: Future = Future() + + try: + # Execute the function immediately + result = fn(*args, **kwargs) + + # Set the result of the future + future.set_result(result) + except Exception as e: + # If an exception occurs, set it on the future + future.set_exception(e) + + # Keep track of futures for potential cleanup + self._futures.append(future) + + return future + + def map( + self, + fn: Callable[..., T], + *iterables: Iterable[Any], + timeout: float | None = None, + chunksize: int = 1, + ) -> Iterator[T]: + """ + Execute a function over an iterable sequentially. 
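+
+        Results are produced lazily as the returned iterator is consumed, because this
+        simply delegates to the builtin ``map``.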
+ + Parameters + ---------- + fn + Function to apply to each item + iterables + Iterables to process + timeout + Optional timeout (ignored in serial execution) + + Returns + ------- + Generator of results + """ + return map(fn, *iterables) + + def shutdown(self, wait: bool = True, *, cancel_futures: bool = False) -> None: + """ + Shutdown the executor. + + Parameters + ---------- + wait + Whether to wait for pending futures (always True for serial executor) + """ + # In a serial executor, shutdown is a no-op + pass + + +class DaskDelayedExecutor(Executor): + """ + An Executor that uses dask.delayed for parallel computation. + + This executor mimics the concurrent.futures.Executor interface but uses Dask's delayed computation model. + """ + + def __init__(self): + """Initialize the Dask Delayed Executor.""" + + # Track submitted futures + self._futures: list[Future] = [] + + def submit(self, fn: Callable[..., T], /, *args: Any, **kwargs: Any) -> Future[T]: + """ + Submit a task to be computed with dask.delayed. + + Parameters + ---------- + fn + The callable to execute + args + Positional arguments for the callable + kwargs + Keyword arguments for the callable + + Returns + ------- + A Future representing the result of the execution + """ + import dask # type: ignore[import-untyped] + + # Create a delayed computation + delayed_task = dask.delayed(fn)(*args, **kwargs) + + # Create a concurrent.futures Future to maintain interface compatibility + future: Future = Future() + + try: + # Compute the result + result = delayed_task.compute() + + # Set the result on the future + future.set_result(result) + except Exception as e: + # Set any exception on the future + future.set_exception(e) + + # Track the future + self._futures.append(future) + + return future + + def map( + self, + fn: Callable[..., T], + *iterables: Iterable[Any], + timeout: float | None = None, + chunksize: int = 1, + ) -> Iterator[T]: + """ + Apply a function to an iterable using dask.delayed. + + Parameters + ---------- + fn + Function to apply to each item + iterables + Iterables to process + timeout + Optional timeout (ignored in serial execution) + + Returns + ------- + Generator of results + """ + import dask # type: ignore[import-untyped] + + if timeout is not None: + warnings.warn("Timeout parameter is not directly supported by Dask delayed") + + # Create delayed computations for each item + delayed_tasks = [dask.delayed(fn)(*items) for items in zip(*iterables)] + + # Compute all tasks + return iter(dask.compute(*delayed_tasks)) + + def shutdown(self, wait: bool = True, *, cancel_futures: bool = False) -> None: + """ + Shutdown the executor + + Parameters + ---------- + wait + Whether to wait for pending futures (always True for serial executor)) + """ + # For Dask.delayed, shutdown is essentially a no-op + pass + + +class LithopsEagerFunctionExecutor(Executor): + """ + Lithops-based function executor which follows the concurrent.futures.Executor API. + + Only required because lithops doesn't follow the concurrent.futures.Executor API, see https://github.com/lithops-cloud/lithops/issues/1427. + """ + + def __init__(self, **kwargs): + import lithops # type: ignore[import-untyped] + + # Create Lithops client with optional configuration + self.lithops_client = lithops.FunctionExecutor(**kwargs) + + # Track submitted futures + self._futures: list[Future] = [] + + def submit(self, fn: Callable[..., T], /, *args: Any, **kwargs: Any) -> Future[T]: + """ + Submit a task to be computed using lithops. 
+ + Parameters + ---------- + fn + The callable to execute + args + Positional arguments for the callable + kwargs + Keyword arguments for the callable + + Returns + ------- + A concurrent.futures.Future representing the result of the execution + """ + + # Create a concurrent.futures Future to maintain interface compatibility + future: Future = Future() + + try: + # Submit to Lithops + lithops_future = self.lithops_client.call_async(fn, *args, **kwargs) + + # Add a callback to set the result or exception + def _on_done(lithops_result): + try: + result = lithops_result.result() + future.set_result(result) + except Exception as e: + future.set_exception(e) + + # Register the callback + lithops_future.add_done_callback(_on_done) + except Exception as e: + # If submission fails, set exception immediately + future.set_exception(e) + + # Track the future + self._futures.append(future) + + return future + + def map( + self, + fn: Callable[..., T], + *iterables: Iterable[Any], + timeout: float | None = None, + chunksize: int = 1, + ) -> Iterator[T]: + """ + Apply a function to an iterable using lithops. + + Only needed because lithops.FunctionExecutor.map returns futures, unlike ``concurrent.futures.Executor.map``. + + Parameters + ---------- + fn + Function to apply to each item + iterables + Iterables to process + timeout + Optional timeout (ignored in serial execution) + + Returns + ------- + Generator of results + """ + import lithops # type: ignore[import-untyped] + + fexec = lithops.FunctionExecutor() + + futures = fexec.map(fn, *iterables) + results = fexec.get_result(futures) + + return results + + def shutdown(self, wait: bool = True, *, cancel_futures: bool = False) -> None: + """ + Shutdown the executor. + + Parameters + ---------- + wait + Whether to wait for pending futures. + """ + # Should this call lithops .clean() method? 
+ pass diff --git a/virtualizarr/tests/__init__.py b/virtualizarr/tests/__init__.py index bc1d5f2e..58682e32 100644 --- a/virtualizarr/tests/__init__.py +++ b/virtualizarr/tests/__init__.py @@ -35,11 +35,13 @@ def _importorskip( has_kerchunk, requires_kerchunk = _importorskip("kerchunk") has_fastparquet, requires_fastparquet = _importorskip("fastparquet") has_s3fs, requires_s3fs = _importorskip("s3fs") +has_lithops, requires_lithops = _importorskip("lithops") has_scipy, requires_scipy = _importorskip("scipy") has_tifffile, requires_tifffile = _importorskip("tifffile") has_imagecodecs, requires_imagecodecs = _importorskip("imagecodecs") has_hdf5plugin, requires_hdf5plugin = _importorskip("hdf5plugin") has_zarr_python, requires_zarr_python = _importorskip("zarr") +has_dask, requires_dask = _importorskip("dask") has_obstore, requires_obstore = _importorskip("obstore") parametrize_over_hdf_backends = pytest.mark.parametrize( diff --git a/virtualizarr/tests/test_backend.py b/virtualizarr/tests/test_backend.py index 4d23506c..291fb9f7 100644 --- a/virtualizarr/tests/test_backend.py +++ b/virtualizarr/tests/test_backend.py @@ -1,4 +1,6 @@ +import functools from collections.abc import Mapping +from concurrent.futures import ThreadPoolExecutor from pathlib import Path from unittest.mock import patch @@ -9,7 +11,7 @@ from xarray import Dataset, open_dataset from xarray.core.indexes import Index -from virtualizarr import open_virtual_dataset +from virtualizarr import open_virtual_dataset, open_virtual_mfdataset from virtualizarr.backend import ( FileType, VirtualBackend, @@ -21,8 +23,10 @@ from virtualizarr.tests import ( has_astropy, parametrize_over_hdf_backends, + requires_dask, requires_hdf5plugin, requires_imagecodecs, + requires_lithops, requires_network, requires_s3fs, requires_scipy, @@ -502,3 +506,88 @@ def test_open_dataset_with_scalar(self, hdf5_scalar, hdf_backend): with open_virtual_dataset(hdf5_scalar, backend=hdf_backend) as vds: assert vds.scalar.dims == () assert vds.scalar.attrs == {"scalar": "true"} + + +preprocess_func = functools.partial( + xr.Dataset.rename_vars, + air="nair", +) + + +@requires_hdf5plugin +@requires_imagecodecs +@parametrize_over_hdf_backends +class TestOpenVirtualMFDataset: + @pytest.mark.parametrize("invalid_parallel_kwarg", ["ray", Dataset]) + def test_invalid_parallel_kwarg( + self, netcdf4_files_factory, invalid_parallel_kwarg, hdf_backend + ): + filepath1, filepath2 = netcdf4_files_factory() + + with pytest.raises(ValueError, match="Unrecognized argument"): + open_virtual_mfdataset( + [filepath1, filepath2], + combine="nested", + concat_dim="time", + backend=hdf_backend, + parallel=invalid_parallel_kwarg, + ) + + @pytest.mark.parametrize( + "parallel", + [ + False, + ThreadPoolExecutor, + pytest.param("dask", marks=requires_dask), + pytest.param("lithops", marks=requires_lithops), + ], + ) + @pytest.mark.parametrize( + "preprocess", + [ + None, + preprocess_func, + ], + ) + def test_parallel_open( + self, netcdf4_files_factory, hdf_backend, parallel, preprocess + ): + filepath1, filepath2 = netcdf4_files_factory() + vds1 = open_virtual_dataset(filepath1, backend=hdf_backend) + vds2 = open_virtual_dataset(filepath2, backend=hdf_backend) + + expected_vds = xr.concat([vds1, vds2], dim="time") + if preprocess: + expected_vds = preprocess_func(expected_vds) + + # test combine nested, which doesn't use in-memory indexes + combined_vds = open_virtual_mfdataset( + [filepath1, filepath2], + combine="nested", + concat_dim="time", + backend=hdf_backend, + 
parallel=parallel, + preprocess=preprocess, + ) + xrt.assert_identical(combined_vds, expected_vds) + + # test combine by coords using in-memory indexes + combined_vds = open_virtual_mfdataset( + [filepath1, filepath2], + combine="by_coords", + backend=hdf_backend, + parallel=parallel, + preprocess=preprocess, + ) + xrt.assert_identical(combined_vds, expected_vds) + + # test combine by coords again using in-memory indexes but for a glob + file_glob = Path(filepath1).parent.glob("air*.nc") + combined_vds = open_virtual_mfdataset( + file_glob, + combine="by_coords", + backend=hdf_backend, + parallel=parallel, + preprocess=preprocess, + ) + xrt.assert_identical(combined_vds, expected_vds)
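
A minimal usage sketch of the new function follows. The file names, the ``time`` dimension, and the ``air`` variable are hypothetical; ``dask`` and ``lithops`` are only required for the corresponding ``parallel`` options, and any other executor is supplied as a ``concurrent.futures.Executor`` subclass, as in the tests above.

    from concurrent.futures import ThreadPoolExecutor

    from virtualizarr import open_virtual_mfdataset

    # Hypothetical netCDF files that share a "time" dimension and an "air" variable.
    files = ["air_2020.nc", "air_2021.nc"]

    # Serial by default: each file is opened with open_virtual_dataset, the optional
    # preprocess callable is applied per file, and the results are combined.
    vds = open_virtual_mfdataset(
        files,
        combine="nested",
        concat_dim="time",
        preprocess=lambda ds: ds.rename_vars(air="nair"),
    )

    # Open the files in parallel via dask.delayed or lithops (optional dependencies) ...
    vds = open_virtual_mfdataset(files, combine="by_coords", parallel="dask")

    # ... or with any concurrent.futures.Executor subclass, passed as the class itself.
    vds = open_virtual_mfdataset(files, combine="by_coords", parallel=ThreadPoolExecutor)

    # Extra keyword arguments (e.g. loadable_variables) are forwarded to open_virtual_dataset.
    vds = open_virtual_mfdataset(
        files, combine="by_coords", parallel="lithops", loadable_variables=["time"]
    )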