diff --git a/xarray/backends/api.py b/xarray/backends/api.py
index c7481e22b59..7e81870d653 100644
--- a/xarray/backends/api.py
+++ b/xarray/backends/api.py
@@ -1,5 +1,6 @@
 import os.path
 import warnings
+from collections.abc import MutableMapping
 from glob import glob
 from io import BytesIO
 from numbers import Number
@@ -345,7 +346,7 @@ def open_dataset(
         If True, decode the 'coordinates' attribute to identify coordinates in
         the resulting dataset.
     engine : {'netcdf4', 'scipy', 'pydap', 'h5netcdf', 'pynio', 'cfgrib', \
-        'pseudonetcdf'}, optional
+        'pseudonetcdf', 'zarr'}, optional
         Engine to use when reading files. If not provided, the default engine
         is chosen based on available dependencies, with a preference for
         'netcdf4'.
@@ -383,6 +384,10 @@ def open_dataset(
         represented using ``np.datetime64[ns]`` objects. If False, always
         decode times to ``np.datetime64[ns]`` objects; if this is not possible
         raise an error.
+    overwrite_encoded_chunks : bool, optional
+        Whether to drop the zarr chunks encoded for each variable when a
+        dataset is loaded with specified chunk sizes (default: False).
+
 
     Returns
     -------
@@ -409,7 +414,9 @@ def open_dataset(
         "pynio",
         "cfgrib",
         "pseudonetcdf",
+        "zarr",
     ]
+
     if engine not in engines:
         raise ValueError(
             "unrecognized engine for open_dataset: {}\n"
@@ -455,36 +462,7 @@ def maybe_decode_store(store, lock=False):
 
         _protect_dataset_variables_inplace(ds, cache)
 
-        if chunks is not None:
-            from dask.base import tokenize
-
-            # if passed an actual file path, augment the token with
-            # the file modification time
-            if isinstance(filename_or_obj, str) and not is_remote_uri(filename_or_obj):
-                mtime = os.path.getmtime(filename_or_obj)
-            else:
-                mtime = None
-            token = tokenize(
-                filename_or_obj,
-                mtime,
-                group,
-                decode_cf,
-                mask_and_scale,
-                decode_times,
-                concat_characters,
-                decode_coords,
-                engine,
-                chunks,
-                drop_variables,
-                use_cftime,
-            )
-            name_prefix = "open_dataset-%s" % token
-            ds2 = ds.chunk(chunks, name_prefix=name_prefix, token=token)
-            ds2._file_obj = ds._file_obj
-        else:
-            ds2 = ds
-
-        return ds2
+        return ds
 
     if isinstance(filename_or_obj, Path):
         filename_or_obj = str(filename_or_obj)
@@ -492,6 +470,17 @@ def maybe_decode_store(store, lock=False):
     if isinstance(filename_or_obj, AbstractDataStore):
         store = filename_or_obj
 
+    if isinstance(filename_or_obj, MutableMapping):
+        if engine == "zarr":
+            # on ZarrStore, mode='r', synchronizer=None, group=None,
+            # consolidated=False.
+            overwrite_encoded_chunks = backend_kwargs.pop("overwrite_encoded_chunks", None)
+            store = backends.ZarrStore.open_group(
+                filename_or_obj,
+                group=group,
+                **backend_kwargs
+            )
+
     elif isinstance(filename_or_obj, str):
         filename_or_obj = _normalize_path(filename_or_obj)
 
@@ -519,7 +508,15 @@ def maybe_decode_store(store, lock=False):
             store = backends.CfGribDataStore(
                 filename_or_obj, lock=lock, **backend_kwargs
             )
-
+        elif engine == "zarr":
+            # on ZarrStore, mode='r', synchronizer=None, group=None,
+            # consolidated=False.
+            overwrite_encoded_chunks = backend_kwargs.pop("overwrite_encoded_chunks", None)
+            store = backends.ZarrStore.open_group(
+                filename_or_obj,
+                group=group,
+                **backend_kwargs
+            )
     else:
         if engine not in [None, "scipy", "h5netcdf"]:
             raise ValueError(
@@ -542,7 +539,55 @@ def maybe_decode_store(store, lock=False):
     if isinstance(filename_or_obj, str):
         ds.encoding["source"] = filename_or_obj
 
-    return ds
+    if chunks is not None:
+        from dask.base import tokenize
+        if engine != "zarr":
+
+            # if passed an actual file path, augment the token with
+            # the file modification time
+            if isinstance(filename_or_obj, str) and not is_remote_uri(filename_or_obj):
+                mtime = os.path.getmtime(filename_or_obj)
+            else:
+                mtime = None
+            token = tokenize(
+                filename_or_obj,
+                mtime,
+                group,
+                decode_cf,
+                mask_and_scale,
+                decode_times,
+                concat_characters,
+                decode_coords,
+                engine,
+                chunks,
+                drop_variables,
+                use_cftime,
+            )
+            name_prefix = "open_dataset-%s" % token
+            ds2 = ds.chunk(chunks, name_prefix=name_prefix, token=token)
+            ds2._file_obj = ds._file_obj
+
+        else:
+
+            # adapted from Dataset.chunk() and taken from open_zarr
+            if not isinstance(chunks, (int, dict)):
+                if chunks != "auto":
+                    raise ValueError(
+                        "chunks must be an int, dict, 'auto', or None. "
+                        "Instead found %s. " % chunks
+                    )
+            if isinstance(chunks, int):
+                chunks = dict.fromkeys(ds.dims, chunks)
+
+            variables = {
+                k: backends.ZarrStore.maybe_chunk(k, v, chunks, overwrite_encoded_chunks)
+                for k, v in ds.variables.items()
+            }
+            ds2 = ds._replace_vars_and_dims(variables)
+            return ds2
+    else:
+        ds2 = ds
+    return ds2
 
 
 def open_dataarray(
diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py
index 973c167911e..c6e4d1b362a 100644
--- a/xarray/backends/zarr.py
+++ b/xarray/backends/zarr.py
@@ -9,6 +9,8 @@
 from ..core.variable import Variable
 from .common import AbstractWritableDataStore, BackendArray, _encode_variable_name
 
+from .api import open_dataset
+
 # need some special secret attributes to tell us the dimensions
 DIMENSION_KEY = "_ARRAY_DIMENSIONS"
 
@@ -356,6 +358,52 @@ def encode_variable(self, variable):
     def encode_attribute(self, a):
         return encode_zarr_attr_value(a)
 
+
+    @staticmethod
+    def get_chunk(name, var, chunks):
+        chunk_spec = dict(zip(var.dims, var.encoding.get("chunks")))
+
+        # Coordinate labels aren't chunked
+        if var.ndim == 1 and var.dims[0] == name:
+            return chunk_spec
+
+        if chunks == "auto":
+            return chunk_spec
+
+        for dim in var.dims:
+            if dim in chunks:
+                spec = chunks[dim]
+                if isinstance(spec, int):
+                    spec = (spec,)
+                if isinstance(spec, (tuple, list)) and chunk_spec[dim]:
+                    if any(s % chunk_spec[dim] for s in spec):
+                        warnings.warn(
+                            "Specified Dask chunks %r would "
+                            "separate Zarr chunk shape %r for "
+                            "dimension %r. This significantly "
+                            "degrades performance. Consider "
+                            "rechunking after loading instead."
+                            % (chunks[dim], chunk_spec[dim], dim),
+                            stacklevel=2,
+                        )
+                chunk_spec[dim] = chunks[dim]
+        return chunk_spec
+
+    @staticmethod
+    def maybe_chunk(name, var, chunks, overwrite_encoded_chunks=False):
+        from dask.base import tokenize
+
+        chunk_spec = ZarrStore.get_chunk(name, var, chunks)
+
+        if (var.ndim > 0) and (chunk_spec is not None):
+            # does this cause any data to be read?
+            token2 = tokenize(name, var._data)
+            name2 = "zarr-%s" % token2
+            var = var.chunk(chunk_spec, name=name2, lock=None)
+            if overwrite_encoded_chunks and var.chunks is not None:
+                var.encoding["chunks"] = tuple(x[0] for x in var.chunks)
+        return var
+
     def store(
         self,
         variables,
@@ -569,126 +613,35 @@ def open_zarr(
     ----------
     http://zarr.readthedocs.io/
     """
-    if "auto_chunk" in kwargs:
-        auto_chunk = kwargs.pop("auto_chunk")
-        if auto_chunk:
-            chunks = "auto"  # maintain backwards compatibility
-        else:
-            chunks = None
-        warnings.warn(
-            "auto_chunk is deprecated. Use chunks='auto' instead.",
-            FutureWarning,
-            stacklevel=2,
-        )
+    warnings.warn(
+        "open_zarr is deprecated. Use open_dataset(engine='zarr') instead.",
+        DeprecationWarning,
+    )
 
     if kwargs:
         raise TypeError(
             "open_zarr() got unexpected keyword arguments " + ",".join(kwargs.keys())
         )
 
-    if not isinstance(chunks, (int, dict)):
-        if chunks != "auto" and chunks is not None:
-            raise ValueError(
-                "chunks must be an int, dict, 'auto', or None. "
-                "Instead found %s. " % chunks
-            )
-
-    if chunks == "auto":
-        try:
-            import dask.array  # noqa
-        except ImportError:
-            chunks = None
-
-    if not decode_cf:
-        mask_and_scale = False
-        decode_times = False
-        concat_characters = False
-        decode_coords = False
-
-    def maybe_decode_store(store, lock=False):
-        ds = conventions.decode_cf(
-            store,
-            mask_and_scale=mask_and_scale,
-            decode_times=decode_times,
-            concat_characters=concat_characters,
-            decode_coords=decode_coords,
-            drop_variables=drop_variables,
-        )
+    backend_kwargs = {
+        "synchronizer": synchronizer,
+        "consolidated": consolidated,
+        "overwrite_encoded_chunks": overwrite_encoded_chunks,
+    }
 
-        # TODO: this is where we would apply caching
-
-        return ds
-
-    # Zarr supports a wide range of access modes, but for now xarray either
-    # reads or writes from a store, never both. For open_zarr, we only read
-    mode = "r"
-    zarr_store = ZarrStore.open_group(
-        store,
-        mode=mode,
-        synchronizer=synchronizer,
+    ds = open_dataset(
+        filename_or_obj=store,
         group=group,
-        consolidated=consolidated,
+        decode_cf=decode_cf,
+        mask_and_scale=mask_and_scale,
+        decode_times=decode_times,
+        concat_characters=concat_characters,
+        decode_coords=decode_coords,
+        engine="zarr",
+        chunks=chunks,
+        drop_variables=drop_variables,
+        backend_kwargs=backend_kwargs,
     )
-    ds = maybe_decode_store(zarr_store)
-
-    # auto chunking needs to be here and not in ZarrStore because variable
-    # chunks do not survive decode_cf
-    # return trivial case
-    if not chunks:
-        return ds
-
-    # adapted from Dataset.Chunk()
-    if isinstance(chunks, int):
-        chunks = dict.fromkeys(ds.dims, chunks)
-
-    if isinstance(chunks, tuple) and len(chunks) == len(ds.dims):
-        chunks = dict(zip(ds.dims, chunks))
-
-    def get_chunk(name, var, chunks):
-        chunk_spec = dict(zip(var.dims, var.encoding.get("chunks")))
-
-        # Coordinate labels aren't chunked
-        if var.ndim == 1 and var.dims[0] == name:
-            return chunk_spec
-
-        if chunks == "auto":
-            return chunk_spec
-
-        for dim in var.dims:
-            if dim in chunks:
-                spec = chunks[dim]
-                if isinstance(spec, int):
-                    spec = (spec,)
-                if isinstance(spec, (tuple, list)) and chunk_spec[dim]:
-                    if any(s % chunk_spec[dim] for s in spec):
-                        warnings.warn(
-                            "Specified Dask chunks %r would "
-                            "separate Zarr chunk shape %r for "
-                            "dimension %r. This significantly "
-                            "degrades performance. Consider "
-                            "rechunking after loading instead."
-                            % (chunks[dim], chunk_spec[dim], dim),
-                            stacklevel=2,
-                        )
-                chunk_spec[dim] = chunks[dim]
-        return chunk_spec
-
-    def maybe_chunk(name, var, chunks):
-        from dask.base import tokenize
-
-        chunk_spec = get_chunk(name, var, chunks)
-
-        if (var.ndim > 0) and (chunk_spec is not None):
-            # does this cause any data to be read?
-            token2 = tokenize(name, var._data)
-            name2 = "zarr-%s" % token2
-            var = var.chunk(chunk_spec, name=name2, lock=None)
-            if overwrite_encoded_chunks and var.chunks is not None:
-                var.encoding["chunks"] = tuple(x[0] for x in var.chunks)
-            return var
-        else:
-            return var
-    variables = {k: maybe_chunk(k, v, chunks) for k, v in ds.variables.items()}
-    return ds._replace_vars_and_dims(variables)
+    return ds
diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py
index 916c29ba7bd..0af34612b08 100644
--- a/xarray/tests/test_backends.py
+++ b/xarray/tests/test_backends.py
@@ -1510,7 +1510,7 @@ def save(self, dataset, store_target, **kwargs):
 
     @contextlib.contextmanager
     def open(self, store_target, **kwargs):
-        with xr.open_zarr(store_target, **kwargs) as ds:
+        with xr.open_dataset(store_target, engine="zarr", chunks="auto", **kwargs) as ds:
             yield ds
 
     @contextlib.contextmanager
@@ -1799,7 +1799,7 @@ def test_write_persistence_modes(self, group):
             ds.to_zarr(store_target, mode="w", group=group)
             ds_to_append.to_zarr(store_target, append_dim="time", group=group)
             original = xr.concat([ds, ds_to_append], dim="time")
-            actual = xr.open_zarr(store_target, group=group)
+            actual = xr.open_dataset(store_target, group=group, engine="zarr", chunks="auto")
             assert_identical(original, actual)
 
     def test_compressor_encoding(self):
@@ -1839,7 +1839,7 @@ def test_append_write(self):
             ds.to_zarr(store_target, mode="w")
             ds_to_append.to_zarr(store_target, append_dim="time")
             original = xr.concat([ds, ds_to_append], dim="time")
-            assert_identical(original, xr.open_zarr(store_target))
+            assert_identical(original, xr.open_dataset(store_target, engine="zarr", chunks="auto"))
 
     @pytest.mark.xfail(reason="Zarr stores can not be appended to")
     def test_append_overwrite_values(self):
@@ -1901,11 +1901,11 @@ def test_check_encoding_is_consistent_after_append(self):
             encoding = {"da": {"compressor": compressor}}
             ds.to_zarr(store_target, mode="w", encoding=encoding)
             ds_to_append.to_zarr(store_target, append_dim="time")
-            actual_ds = xr.open_zarr(store_target)
+            actual_ds = xr.open_dataset(store_target, engine="zarr", chunks="auto")
             actual_encoding = actual_ds["da"].encoding["compressor"]
             assert actual_encoding.get_config() == compressor.get_config()
             assert_identical(
-                xr.open_zarr(store_target).compute(),
+                xr.open_dataset(store_target, engine="zarr", chunks="auto").compute(),
                 xr.concat([ds, ds_to_append], dim="time"),
             )
 
@@ -1920,7 +1920,7 @@ def test_append_with_new_variable(self):
             ds_with_new_var.to_zarr(store_target, mode="a")
             combined = xr.concat([ds, ds_to_append], dim="time")
             combined["new_var"] = ds_with_new_var["new_var"]
-            assert_identical(combined, xr.open_zarr(store_target))
+            assert_identical(combined, xr.open_dataset(store_target, engine="zarr", chunks="auto"))
 
     @requires_dask
     def test_to_zarr_compute_false_roundtrip(self):
diff --git a/xarray/tests/test_distributed.py b/xarray/tests/test_distributed.py
index 8011171d223..07d738f0bc5 100644
--- a/xarray/tests/test_distributed.py
+++ b/xarray/tests/test_distributed.py
@@ -151,7 +151,7 @@ def test_dask_distributed_zarr_integration_test(loop, consolidated, compute):
     )
     if not compute:
        maybe_futures.compute()
-    with xr.open_zarr(filename, **read_kwargs) as restored:
+    with xr.open_dataset(filename, engine="zarr", chunks="auto", **read_kwargs) as restored:
         assert isinstance(restored.var1.data, da.Array)
         computed = restored.compute()
         assert_allclose(original, computed)
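
Usage sketch (not part of the diff): the change above lets a zarr store be read through the generic xr.open_dataset entry point instead of xr.open_zarr. The store path, variable name, and chunk choice below are illustrative assumptions, not values taken from the PR, and dask must be installed for chunks="auto" to produce dask-backed variables.

    import xarray as xr

    # Write a tiny dataset to a local zarr store ("example.zarr" is a
    # hypothetical path), then reopen it through the engine added by this diff.
    ds = xr.Dataset({"a": ("x", [1, 2, 3])})
    ds.to_zarr("example.zarr", mode="w")

    # chunks="auto" asks for dask-backed variables based on the encoded zarr chunks
    reopened = xr.open_dataset("example.zarr", engine="zarr", chunks="auto")
    print(reopened)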