diff --git a/xarray/backends/api.py b/xarray/backends/api.py
index 10526214a08..92f3590f02c 100644
--- a/xarray/backends/api.py
+++ b/xarray/backends/api.py
@@ -28,7 +28,7 @@
 )
 from ..core.dataarray import DataArray
 from ..core.dataset import Dataset, _get_chunk, _maybe_chunk
-from ..core.utils import close_on_error, is_grib_path, is_remote_uri
+from ..core.utils import close_on_error, is_grib_path, is_remote_uri, read_magic_number
 from .common import AbstractDataStore, ArrayWriter
 from .locks import _get_scheduler

@@ -121,17 +121,7 @@ def _get_default_engine_netcdf():


 def _get_engine_from_magic_number(filename_or_obj):
-    # check byte header to determine file type
-    if isinstance(filename_or_obj, bytes):
-        magic_number = filename_or_obj[:8]
-    else:
-        if filename_or_obj.tell() != 0:
-            raise ValueError(
-                "file-like object read/write pointer not at zero "
-                "please close and reopen, or use a context manager"
-            )
-        magic_number = filename_or_obj.read(8)
-        filename_or_obj.seek(0)
+    magic_number = read_magic_number(filename_or_obj)

     if magic_number.startswith(b"CDF"):
         engine = "scipy"
@@ -139,6 +129,7 @@ def _get_engine_from_magic_number(filename_or_obj):
         engine = "h5netcdf"
     else:
         raise ValueError(
+            "cannot guess the engine, "
             f"{magic_number} is not the signature of any supported file format "
             "did you mean to pass a string for a path instead?"
         )
diff --git a/xarray/backends/apiv2.py b/xarray/backends/apiv2.py
index ab065027da1..d1974c7f3f8 100644
--- a/xarray/backends/apiv2.py
+++ b/xarray/backends/apiv2.py
@@ -4,11 +4,7 @@
 from ..core.dataset import _get_chunk, _maybe_chunk
 from ..core.utils import is_remote_uri
 from . import plugins
-from .api import (
-    _autodetect_engine,
-    _get_backend_cls,
-    _protect_dataset_variables_inplace,
-)
+from .api import _get_backend_cls, _protect_dataset_variables_inplace


 def _get_mtime(filename_or_obj):
@@ -245,7 +241,7 @@ def open_dataset(
         backend_kwargs = {}

     if engine is None:
-        engine = _autodetect_engine(filename_or_obj)
+        engine = plugins.guess_engine(filename_or_obj)

     engines = plugins.list_engines()
     backend = _get_backend_cls(engine, engines=engines)
diff --git a/xarray/backends/cfgrib_.py b/xarray/backends/cfgrib_.py
index 4559df4cb74..67ad63d682f 100644
--- a/xarray/backends/cfgrib_.py
+++ b/xarray/backends/cfgrib_.py
@@ -1,3 +1,5 @@
+import os
+
 import numpy as np

 from ..core import indexing
@@ -73,6 +75,14 @@ def get_encoding(self):
         return encoding


+def guess_can_open_cfgrib(store_spec):
+    try:
+        _, ext = os.path.splitext(store_spec)
+    except TypeError:
+        return False
+    return ext in {".grib", ".grib2", ".grb", ".grb2"}
+
+
 def open_backend_dataset_cfgrib(
     filename_or_obj,
     *,
@@ -117,4 +127,6 @@ def open_backend_dataset_cfgrib(
     return ds


-cfgrib_backend = BackendEntrypoint(open_dataset=open_backend_dataset_cfgrib)
+cfgrib_backend = BackendEntrypoint(
+    open_dataset=open_backend_dataset_cfgrib, guess_can_open=guess_can_open_cfgrib
+)
diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py
index d31f3da2a82..c39db2ae85d 100644
--- a/xarray/backends/h5netcdf_.py
+++ b/xarray/backends/h5netcdf_.py
@@ -1,10 +1,12 @@
 import functools
+import io
+import os
 from distutils.version import LooseVersion

 import numpy as np

 from ..core import indexing
-from ..core.utils import FrozenDict, is_remote_uri
+from ..core.utils import FrozenDict, is_remote_uri, read_magic_number
 from ..core.variable import Variable
 from .common import WritableCFDataStore, find_root_and_group
 from .file_manager import CachingFileManager, DummyFileManager
@@ -128,19 +130,12 @@ def open(
                 "can't open netCDF4/HDF5 as bytes "
                 "try passing a path or file-like object"
             )
-        elif hasattr(filename, "tell"):
-            if filename.tell() != 0:
+        elif isinstance(filename, io.IOBase):
+            magic_number = read_magic_number(filename)
+            if not magic_number.startswith(b"\211HDF\r\n\032\n"):
                 raise ValueError(
-                    "file-like object read/write pointer not at zero "
-                    "please close and reopen, or use a context manager"
+                    f"{magic_number} is not the signature of a valid netCDF file"
                 )
-            else:
-                magic_number = filename.read(8)
-                filename.seek(0)
-                if not magic_number.startswith(b"\211HDF\r\n\032\n"):
-                    raise ValueError(
-                        f"{magic_number} is not the signature of a valid netCDF file"
-                    )

         if format not in [None, "NETCDF4"]:
             raise ValueError("invalid format for h5netcdf backend")
@@ -325,6 +320,20 @@ def close(self, **kwargs):
         self._manager.close(**kwargs)


+def guess_can_open_h5netcdf(store_spec):
+    try:
+        return read_magic_number(store_spec).startswith(b"\211HDF\r\n\032\n")
+    except TypeError:
+        pass
+
+    try:
+        _, ext = os.path.splitext(store_spec)
+    except TypeError:
+        return False
+
+    return ext in {".nc", ".nc4", ".cdf"}
+
+
 def open_backend_dataset_h5netcdf(
     filename_or_obj,
     *,
@@ -364,4 +373,6 @@ def open_backend_dataset_h5netcdf(
     return ds


-h5netcdf_backend = BackendEntrypoint(open_dataset=open_backend_dataset_h5netcdf)
+h5netcdf_backend = BackendEntrypoint(
+    open_dataset=open_backend_dataset_h5netcdf, guess_can_open=guess_can_open_h5netcdf
+)
diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py
index dac56bf4ae3..8b6fa1e17b7 100644
--- a/xarray/backends/netCDF4_.py
+++ b/xarray/backends/netCDF4_.py
@@ -505,6 +505,16 @@ def close(self, **kwargs):
         self._manager.close(**kwargs)


+def guess_can_open_netcdf4(store_spec):
+    if isinstance(store_spec, str) and is_remote_uri(store_spec):
+        return True
+    try:
+        _, ext = os.path.splitext(store_spec)
+    except TypeError:
+        return False
+    return ext in {".nc", ".nc4", ".cdf"}
+
+
 def open_backend_dataset_netcdf4(
     filename_or_obj,
     mask_and_scale=True,
@@ -550,4 +560,6 @@ def open_backend_dataset_netcdf4(
     return ds


-netcdf4_backend = BackendEntrypoint(open_dataset=open_backend_dataset_netcdf4)
+netcdf4_backend = BackendEntrypoint(
+    open_dataset=open_backend_dataset_netcdf4, guess_can_open=guess_can_open_netcdf4
+)
diff --git a/xarray/backends/plugins.py b/xarray/backends/plugins.py
index 9e78a3b71e8..75d03aa3a64 100644
--- a/xarray/backends/plugins.py
+++ b/xarray/backends/plugins.py
@@ -1,5 +1,6 @@
 import inspect
 import itertools
+import logging
 import warnings
 from functools import lru_cache

@@ -7,11 +8,12 @@


 class BackendEntrypoint:
-    __slots__ = ("open_dataset", "open_dataset_parameters")
+    __slots__ = ("guess_can_open", "open_dataset", "open_dataset_parameters")

-    def __init__(self, open_dataset, open_dataset_parameters=None):
+    def __init__(self, open_dataset, open_dataset_parameters=None, guess_can_open=None):
         self.open_dataset = open_dataset
         self.open_dataset_parameters = open_dataset_parameters
+        self.guess_can_open = guess_can_open


 def remove_duplicates(backend_entrypoints):
@@ -76,3 +78,21 @@ def list_engines():
     engines = create_engines_dict(backend_entrypoints)
     set_missing_parameters(engines)
     return engines
+
+
+def guess_engine(store_spec):
+    engines = list_engines()
+
+    # use the pre-defined selection order for netCDF files
+    for engine in ["netcdf4", "h5netcdf", "scipy"]:
+        if engine in engines and engines[engine].guess_can_open(store_spec):
+            return engine
+
+    for engine, backend in engines.items():
+        try:
+            if backend.guess_can_open and backend.guess_can_open(store_spec):
+                return engine
+        except Exception:
+            logging.exception(f"{engine!r} fails while guessing")
+
+    raise ValueError("cannot guess the engine, try passing one explicitly")
diff --git a/xarray/backends/pydap_.py b/xarray/backends/pydap_.py
index 1a8a23687f7..5f2762dfa15 100644
--- a/xarray/backends/pydap_.py
+++ b/xarray/backends/pydap_.py
@@ -2,7 +2,7 @@

 from ..core import indexing
 from ..core.pycompat import integer_types
-from ..core.utils import Frozen, FrozenDict, close_on_error, is_dict_like
+from ..core.utils import Frozen, FrozenDict, close_on_error, is_dict_like, is_remote_uri
 from ..core.variable import Variable
 from .common import AbstractDataStore, BackendArray, robust_getitem
 from .plugins import BackendEntrypoint
@@ -96,6 +96,10 @@ def get_dimensions(self):
         return Frozen(self.ds.dimensions)


+def guess_can_open_pydap(store_spec):
+    return isinstance(store_spec, str) and is_remote_uri(store_spec)
+
+
 def open_backend_dataset_pydap(
     filename_or_obj,
     mask_and_scale=True,
@@ -127,4 +131,6 @@ def open_backend_dataset_pydap(
     return ds


-pydap_backend = BackendEntrypoint(open_dataset=open_backend_dataset_pydap)
+pydap_backend = BackendEntrypoint(
+    open_dataset=open_backend_dataset_pydap, guess_can_open=guess_can_open_pydap
+)
diff --git a/xarray/backends/scipy_.py b/xarray/backends/scipy_.py
index 6f3b74238fc..1f26812342e 100644
--- a/xarray/backends/scipy_.py
+++ b/xarray/backends/scipy_.py
@@ -1,9 +1,10 @@
-from io import BytesIO
+import io
+import os

 import numpy as np

 from ..core.indexing import NumpyIndexingAdapter
-from ..core.utils import Frozen, FrozenDict, close_on_error
+from ..core.utils import Frozen, FrozenDict, close_on_error, read_magic_number
 from ..core.variable import Variable
 from .common import BackendArray, WritableCFDataStore
 from .file_manager import CachingFileManager, DummyFileManager
@@ -78,7 +79,7 @@ def _open_scipy_netcdf(filename, mode, mmap, version):

     if isinstance(filename, bytes) and filename.startswith(b"CDF"):
         # it's a NetCDF3 bytestring
-        filename = BytesIO(filename)
+        filename = io.BytesIO(filename)

     try:
         return scipy.io.netcdf_file(filename, mode=mode, mmap=mmap, version=version)
@@ -222,6 +223,19 @@ def close(self):
         self._manager.close()


+def guess_can_open_scipy(store_spec):
+    try:
+        return read_magic_number(store_spec).startswith(b"CDF")
+    except TypeError:
+        pass
+
+    try:
+        _, ext = os.path.splitext(store_spec)
+    except TypeError:
+        return False
+    return ext in {".nc", ".nc4", ".cdf", ".gz"}
+
+
 def open_backend_dataset_scipy(
     filename_or_obj,
     mask_and_scale=True,
@@ -255,4 +269,6 @@ def open_backend_dataset_scipy(
     return ds


-scipy_backend = BackendEntrypoint(open_dataset=open_backend_dataset_scipy)
+scipy_backend = BackendEntrypoint(
+    open_dataset=open_backend_dataset_scipy, guess_can_open=guess_can_open_scipy
+)
diff --git a/xarray/backends/store.py b/xarray/backends/store.py
index d51a835f467..1e1edab555d 100644
--- a/xarray/backends/store.py
+++ b/xarray/backends/store.py
@@ -1,8 +1,13 @@
 from .. import conventions
 from ..core.dataset import Dataset
+from .common import AbstractDataStore
 from .plugins import BackendEntrypoint


+def guess_can_open_store(store_spec):
+    return isinstance(store_spec, AbstractDataStore)
+
+
 def open_backend_dataset_store(
     store,
     *,
@@ -38,4 +43,6 @@ def open_backend_dataset_store(
     return ds


-store_backend = BackendEntrypoint(open_dataset=open_backend_dataset_store)
+store_backend = BackendEntrypoint(
+    open_dataset=open_backend_dataset_store, guess_can_open=guess_can_open_store
+)
diff --git a/xarray/core/utils.py b/xarray/core/utils.py
index 1a98b24b9b7..093b30d088d 100644
--- a/xarray/core/utils.py
+++ b/xarray/core/utils.py
@@ -2,6 +2,7 @@
 """
 import contextlib
 import functools
+import io
 import itertools
 import os.path
 import re
@@ -603,6 +604,23 @@ def is_remote_uri(path: str) -> bool:
     return bool(re.search(r"^https?\://", path))


+def read_magic_number(filename_or_obj, count=8):
+    # check byte header to determine file type
+    if isinstance(filename_or_obj, bytes):
+        magic_number = filename_or_obj[:count]
+    elif isinstance(filename_or_obj, io.IOBase):
+        if filename_or_obj.tell() != 0:
+            raise ValueError(
+                "file-like object read/write pointer not at the start of the file, "
+                "please close and reopen, or use a context manager"
+            )
+        magic_number = filename_or_obj.read(count)
+        filename_or_obj.seek(0)
+    else:
+        raise TypeError(f"cannot read the magic number from {type(filename_or_obj)}")
+    return magic_number
+
+
 def is_grib_path(path: str) -> bool:
     _, ext = os.path.splitext(path)
     return ext in [".grib", ".grb", ".grib2", ".grb2"]
diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py
index 5e8758419ac..1ddc16c52e4 100644
--- a/xarray/tests/test_backends.py
+++ b/xarray/tests/test_backends.py
@@ -2623,7 +2623,7 @@ def test_open_badbytes(self):
         with raises_regex(ValueError, "HDF5 as bytes"):
             with open_dataset(b"\211HDF\r\n\032\n", engine="h5netcdf"):
                 pass
-        with raises_regex(ValueError, "not the signature of any supported file"):
+        with raises_regex(ValueError, "cannot guess the engine"):
             with open_dataset(b"garbage"):
                 pass
         with raises_regex(ValueError, "can only read bytes"):
@@ -2636,7 +2636,7 @@ def test_open_twice(self):
         expected = create_test_data()
         expected.attrs["foo"] = "bar"
-        with raises_regex(ValueError, "read/write pointer not at zero"):
+        with raises_regex(ValueError, "read/write pointer not at the start"):
             with create_tmp_file() as tmp_file:
                 expected.to_netcdf(tmp_file, engine="h5netcdf")
                 with open(tmp_file, "rb") as f:
@@ -2669,7 +2669,7 @@ def test_open_fileobj(self):
                     open_dataset(f, engine="scipy")

                 f.seek(8)
-                with raises_regex(ValueError, "read/write pointer not at zero"):
+                with raises_regex(ValueError, "read/write pointer not at the start"):
                     open_dataset(f)
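
For context, a minimal sketch of how a downstream backend could plug into the guess_can_open hook introduced above. The backend name "mybackend", the ".myext" extension, and the function bodies are hypothetical; they only mirror the pattern used by the built-in backends in this patch.

import os

from xarray.backends.plugins import BackendEntrypoint


def open_backend_dataset_mybackend(filename_or_obj, *, drop_variables=None):
    # a real backend would parse the input and return an xarray.Dataset here
    raise NotImplementedError


def guess_can_open_mybackend(store_spec):
    # same pattern as the built-in backends: inputs that are not path-like
    # make os.path.splitext raise TypeError, which we treat as "cannot open"
    try:
        _, ext = os.path.splitext(store_spec)
    except TypeError:
        return False
    return ext == ".myext"


mybackend_backend = BackendEntrypoint(
    open_dataset=open_backend_dataset_mybackend,
    guess_can_open=guess_can_open_mybackend,
)

With such an entrypoint registered, plugins.guess_engine(store_spec) first tries the netCDF backends in their fixed priority order, then asks every remaining backend's guess_can_open, logging (rather than raising) any exception a guesser throws, and finally raises "cannot guess the engine" if nothing matched.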