Skip to content

Alignment of n-dimensional indexes with partially excluded dims #10293

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
May 16, 2025
6 changes: 5 additions & 1 deletion doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@ v2025.05.0 (unreleased)

New Features
~~~~~~~~~~~~

- Allow an Xarray index that uses multiple dimensions checking equality with another
index for only a subset of those dimensions (i.e., ignoring the dimensions
that are excluded from alignment).
(:issue:`10243`, :pull:`10293`)
By `Benoit Bovy <https://github.com/benbovy>`_.

Breaking changes
~~~~~~~~~~~~~~~~
Expand Down
30 changes: 27 additions & 3 deletions xarray/core/coordinate_transform.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from collections.abc import Hashable, Iterable, Mapping
from typing import Any
from typing import Any, overload

import numpy as np

Expand Down Expand Up @@ -64,8 +66,30 @@ def reverse(self, coord_labels: dict[Hashable, Any]) -> dict[str, Any]:
"""
raise NotImplementedError

def equals(self, other: "CoordinateTransform") -> bool:
"""Check equality with another CoordinateTransform of the same kind."""
@overload
def equals(self, other: CoordinateTransform) -> bool: ...

@overload
def equals(
self, other: CoordinateTransform, *, exclude: frozenset[Hashable] | None = None
) -> bool: ...

def equals(self, other: CoordinateTransform, **kwargs) -> bool:
"""Check equality with another CoordinateTransform of the same kind.

Parameters
----------
other : CoordinateTransform
The other Index object to compare with this object.
exclude : frozenset of hashable, optional
Dimensions excluded from checking. It is None by default, (i.e.,
when this method is not called in the context of alignment). For a
n-dimensional transform this option allows a CoordinateTransform to
optionally ignore any dimension in ``exclude`` when comparing
``self`` with ``other``. For a 1-dimensional transform this kwarg
can be safely ignored, as this method is not called when all of the
transform's dimensions are also excluded from alignment.
"""
raise NotImplementedError

def generate_coords(
Expand Down
69 changes: 62 additions & 7 deletions xarray/core/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

import collections.abc
import copy
import inspect
from collections import defaultdict
from collections.abc import Hashable, Iterable, Iterator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast
from collections.abc import Callable, Hashable, Iterable, Iterator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast, overload

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -348,7 +349,15 @@ def reindex_like(self, other: Self) -> dict[Hashable, Any]:
"""
raise NotImplementedError(f"{self!r} doesn't support re-indexing labels")

def equals(self, other: Index) -> bool:
@overload
def equals(self, other: Index) -> bool: ...

@overload
def equals(
self, other: Index, *, exclude: frozenset[Hashable] | None = None
) -> bool: ...

def equals(self, other: Index, **kwargs) -> bool:
"""Compare this index with another index of the same type.

Implementation is optional but required in order to support alignment.
Expand All @@ -357,11 +366,22 @@ def equals(self, other: Index) -> bool:
----------
other : Index
The other Index object to compare with this object.
exclude : frozenset of hashable, optional
Dimensions excluded from checking. It is None by default, (i.e.,
when this method is not called in the context of alignment). For a
n-dimensional index this option allows an Index to optionally ignore
any dimension in ``exclude`` when comparing ``self`` with ``other``.
For a 1-dimensional index this kwarg can be safely ignored, as this
method is not called when all of the index's dimensions are also
excluded from alignment (note: the index's dimensions correspond to
the union of the dimensions of all coordinate variables associated
with this index).

Returns
-------
is_equal : bool
``True`` if the indexes are equal, ``False`` otherwise.

"""
raise NotImplementedError()

Expand Down Expand Up @@ -863,7 +883,7 @@ def sel(

return IndexSelResult({self.dim: indexer})

def equals(self, other: Index):
def equals(self, other: Index, *, exclude: frozenset[Hashable] | None = None):
if not isinstance(other, PandasIndex):
return False
return self.index.equals(other.index) and self.dim == other.dim
Expand Down Expand Up @@ -1542,10 +1562,12 @@ def sel(

return IndexSelResult(results)

def equals(self, other: Index) -> bool:
def equals(
self, other: Index, *, exclude: frozenset[Hashable] | None = None
) -> bool:
if not isinstance(other, CoordinateTransformIndex):
return False
return self.transform.equals(other.transform)
return self.transform.equals(other.transform, exclude=exclude)

def rename(
self,
Expand Down Expand Up @@ -1925,6 +1947,36 @@ def default_indexes(
return indexes


def _wrap_index_equals(
index: Index,
) -> Callable[[Index, frozenset[Hashable]], bool]:
# TODO: remove this Index.equals() wrapper (backward compatibility)

sig = inspect.signature(index.equals)

if len(sig.parameters) == 1:
index_cls_name = type(index).__module__ + "." + type(index).__qualname__
emit_user_level_warning(
f"the signature ``{index_cls_name}.equals(self, other)`` is deprecated. "
f"Please update it to "
f"``{index_cls_name}.equals(self, other, *, exclude=None)`` "
"or kindly ask the maintainers of ``{index_cls_name}`` to do it. "
"See documentation of xarray.Index.equals() for more info.",
FutureWarning,
)
exclude_kwarg = False
else:
exclude_kwarg = True

def equals_wrapper(other: Index, exclude: frozenset[Hashable]) -> bool:
if exclude_kwarg:
return index.equals(other, exclude=exclude)
else:
return index.equals(other)

return equals_wrapper


def indexes_equal(
index: Index,
other_index: Index,
Expand Down Expand Up @@ -1966,6 +2018,7 @@ def indexes_equal(

def indexes_all_equal(
elements: Sequence[tuple[Index, dict[Hashable, Variable]]],
exclude_dims: frozenset[Hashable],
) -> bool:
"""Check if indexes are all equal.

Expand All @@ -1990,9 +2043,11 @@ def check_variables():

same_type = all(type(indexes[0]) is type(other_idx) for other_idx in indexes[1:])
if same_type:
index_equals_func = _wrap_index_equals(indexes[0])
try:
not_equal = any(
not indexes[0].equals(other_idx) for other_idx in indexes[1:]
not index_equals_func(other_idx, exclude_dims)
for other_idx in indexes[1:]
)
except NotImplementedError:
not_equal = check_variables()
Expand Down
4 changes: 3 additions & 1 deletion xarray/indexes/range_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@ def reverse(self, coord_labels: dict[Hashable, Any]) -> dict[str, Any]:
positions = (labels - self.start) / self.step
return {self.dim: positions}

def equals(self, other: CoordinateTransform) -> bool:
def equals(
self, other: CoordinateTransform, exclude: frozenset[Hashable] | None = None
) -> bool:
if not isinstance(other, RangeCoordinateTransform):
return False

Expand Down
Loading
Loading