Skip to content

Sort backends #4886

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Feb 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 33 additions & 21 deletions xarray/backends/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,19 @@

from .common import BACKEND_ENTRYPOINTS

STANDARD_BACKENDS_ORDER = ["netcdf4", "h5netcdf", "scipy"]

def remove_duplicates(backend_entrypoints):

def remove_duplicates(pkg_entrypoints):

# sort and group entrypoints by name
backend_entrypoints = sorted(backend_entrypoints, key=lambda ep: ep.name)
backend_entrypoints_grouped = itertools.groupby(
backend_entrypoints, key=lambda ep: ep.name
)
pkg_entrypoints = sorted(pkg_entrypoints, key=lambda ep: ep.name)
pkg_entrypoints_grouped = itertools.groupby(pkg_entrypoints, key=lambda ep: ep.name)
# check if there are multiple entrypoints for the same name
unique_backend_entrypoints = []
for name, matches in backend_entrypoints_grouped:
unique_pkg_entrypoints = []
for name, matches in pkg_entrypoints_grouped:
matches = list(matches)
unique_backend_entrypoints.append(matches[0])
unique_pkg_entrypoints.append(matches[0])
matches_len = len(matches)
if matches_len > 1:
selected_module_name = matches[0].module_name
Expand All @@ -30,7 +30,7 @@ def remove_duplicates(backend_entrypoints):
f"\n {all_module_names}.\n It will be used: {selected_module_name}.",
RuntimeWarning,
)
return unique_backend_entrypoints
return unique_pkg_entrypoints


def detect_parameters(open_dataset):
Expand All @@ -51,13 +51,13 @@ def detect_parameters(open_dataset):
return tuple(parameters_list)


def create_engines_dict(backend_entrypoints):
engines = {}
for backend_ep in backend_entrypoints:
name = backend_ep.name
backend = backend_ep.load()
engines[name] = backend
return engines
def backends_dict_from_pkg(pkg_entrypoints):
backend_entrypoints = {}
for pkg_ep in pkg_entrypoints:
name = pkg_ep.name
backend = pkg_ep.load()
backend_entrypoints[name] = backend
return backend_entrypoints


def set_missing_parameters(backend_entrypoints):
Expand All @@ -67,11 +67,23 @@ def set_missing_parameters(backend_entrypoints):
backend.open_dataset_parameters = detect_parameters(open_dataset)


def build_engines(entrypoints):
def sort_backends(backend_entrypoints):
ordered_backends_entrypoints = {}
for be_name in STANDARD_BACKENDS_ORDER:
if backend_entrypoints.get(be_name, None) is not None:
ordered_backends_entrypoints[be_name] = backend_entrypoints.pop(be_name)
ordered_backends_entrypoints.update(
{name: backend_entrypoints[name] for name in sorted(backend_entrypoints)}
)
return ordered_backends_entrypoints


def build_engines(pkg_entrypoints):
backend_entrypoints = BACKEND_ENTRYPOINTS.copy()
pkg_entrypoints = remove_duplicates(entrypoints)
external_backend_entrypoints = create_engines_dict(pkg_entrypoints)
pkg_entrypoints = remove_duplicates(pkg_entrypoints)
external_backend_entrypoints = backends_dict_from_pkg(pkg_entrypoints)
backend_entrypoints.update(external_backend_entrypoints)
backend_entrypoints = sort_backends(backend_entrypoints)
set_missing_parameters(backend_entrypoints)
engines = {}
for name, backend in backend_entrypoints.items():
Expand All @@ -81,8 +93,8 @@ def build_engines(entrypoints):

@functools.lru_cache(maxsize=1)
def list_engines():
entrypoints = pkg_resources.iter_entry_points("xarray.backends")
return build_engines(entrypoints)
pkg_entrypoints = pkg_resources.iter_entry_points("xarray.backends")
return build_engines(pkg_entrypoints)


def guess_engine(store_spec):
Expand Down
34 changes: 32 additions & 2 deletions xarray/tests/test_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,13 @@ def test_remove_duplicates_warnings(dummy_duplicated_entrypoints):


@mock.patch("pkg_resources.EntryPoint.load", mock.MagicMock(return_value=None))
def test_create_engines_dict():
def test_backends_dict_from_pkg():
specs = [
"engine1 = xarray.tests.test_plugins:backend_1",
"engine2 = xarray.tests.test_plugins:backend_2",
]
entrypoints = [pkg_resources.EntryPoint.parse(spec) for spec in specs]
engines = plugins.create_engines_dict(entrypoints)
engines = plugins.backends_dict_from_pkg(entrypoints)
assert len(engines) == 2
assert engines.keys() == set(("engine1", "engine2"))

Expand Down Expand Up @@ -111,8 +111,38 @@ def test_build_engines():
"cfgrib = xarray.tests.test_plugins:backend_1"
)
backend_entrypoints = plugins.build_engines([dummy_pkg_entrypoint])

assert isinstance(backend_entrypoints["cfgrib"], DummyBackendEntrypoint1)
assert backend_entrypoints["cfgrib"].open_dataset_parameters == (
"filename_or_obj",
"decoder",
)


@mock.patch(
"pkg_resources.EntryPoint.load",
mock.MagicMock(return_value=DummyBackendEntrypoint1),
)
def test_build_engines_sorted():
dummy_pkg_entrypoints = [
pkg_resources.EntryPoint.parse(
"dummy2 = xarray.tests.test_plugins:backend_1",
),
pkg_resources.EntryPoint.parse(
"dummy1 = xarray.tests.test_plugins:backend_1",
),
]
backend_entrypoints = plugins.build_engines(dummy_pkg_entrypoints)
backend_entrypoints = list(backend_entrypoints)

indices = []
for be in plugins.STANDARD_BACKENDS_ORDER:
try:
index = backend_entrypoints.index(be)
backend_entrypoints.pop(index)
indices.append(index)
except ValueError:
pass

assert set(indices) < {0, -1}
assert list(backend_entrypoints) == sorted(backend_entrypoints)