Skip to content

Support non-str Hashables in DataArray #8559

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jan 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ Deprecations
Bug fixes
~~~~~~~~~

- Support non-string hashable dimensions in :py:class:`xarray.DataArray` (:issue:`8546`, :pull:`8559`).
By `Michael Niklas <https://github.com/headtr1ck>`_.
- Reverse index output of bottleneck's rolling move_argmax/move_argmin functions (:issue:`8541`, :pull:`8552`).
By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_.
- Vendor `SerializableLock` from dask and use as default lock for netcdf4 backends (:issue:`8442`, :pull:`8571`).
Expand Down
59 changes: 26 additions & 33 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
Generic,
Literal,
NoReturn,
TypeVar,
Union,
overload,
)

Expand Down Expand Up @@ -61,6 +63,7 @@
ReprObject,
_default,
either_dict_or_kwargs,
hashable,
)
from xarray.core.variable import (
IndexVariable,
Expand All @@ -73,23 +76,11 @@
from xarray.util.deprecation_helpers import _deprecate_positional_args, deprecate_dims

if TYPE_CHECKING:
from typing import TypeVar, Union

from dask.dataframe import DataFrame as DaskDataFrame
from dask.delayed import Delayed
from iris.cube import Cube as iris_Cube
from numpy.typing import ArrayLike

try:
from dask.dataframe import DataFrame as DaskDataFrame
except ImportError:
DaskDataFrame = None
try:
from dask.delayed import Delayed
except ImportError:
Delayed = None # type: ignore[misc,assignment]
try:
from iris.cube import Cube as iris_Cube
except ImportError:
iris_Cube = None

from xarray.backends import ZarrStore
from xarray.backends.api import T_NetcdfEngine, T_NetcdfTypes
from xarray.core.groupby import DataArrayGroupBy
Expand Down Expand Up @@ -140,7 +131,9 @@ def _check_coords_dims(shape, coords, dim):


def _infer_coords_and_dims(
shape, coords, dims
shape: tuple[int, ...],
coords: Sequence[Sequence | pd.Index | DataArray] | Mapping | None,
dims: str | Iterable[Hashable] | None,
) -> tuple[Mapping[Hashable, Any], tuple[Hashable, ...]]:
"""All the logic for creating a new DataArray"""

Expand All @@ -157,8 +150,7 @@ def _infer_coords_and_dims(

if isinstance(dims, str):
dims = (dims,)

if dims is None:
elif dims is None:
dims = [f"dim_{n}" for n in range(len(shape))]
if coords is not None and len(coords) == len(shape):
# try to infer dimensions from coords
Expand All @@ -168,16 +160,15 @@ def _infer_coords_and_dims(
for n, (dim, coord) in enumerate(zip(dims, coords)):
coord = as_variable(coord, name=dims[n]).to_index_variable()
dims[n] = coord.name
dims = tuple(dims)
elif len(dims) != len(shape):
dims_tuple = tuple(dims)
if len(dims_tuple) != len(shape):
raise ValueError(
"different number of dimensions on data "
f"and dims: {len(shape)} vs {len(dims)}"
f"and dims: {len(shape)} vs {len(dims_tuple)}"
)
else:
for d in dims:
if not isinstance(d, str):
raise TypeError(f"dimension {d} is not a string")
for d in dims_tuple:
if not hashable(d):
raise TypeError(f"Dimension {d} is not hashable")

new_coords: Mapping[Hashable, Any]

Expand All @@ -189,17 +180,21 @@ def _infer_coords_and_dims(
for k, v in coords.items():
new_coords[k] = as_variable(v, name=k)
elif coords is not None:
for dim, coord in zip(dims, coords):
for dim, coord in zip(dims_tuple, coords):
var = as_variable(coord, name=dim)
var.dims = (dim,)
new_coords[dim] = var.to_index_variable()

_check_coords_dims(shape, new_coords, dims)
_check_coords_dims(shape, new_coords, dims_tuple)

return new_coords, dims
return new_coords, dims_tuple


def _check_data_shape(data, coords, dims):
def _check_data_shape(
data: Any,
coords: Sequence[Sequence | pd.Index | DataArray] | Mapping | None,
dims: str | Iterable[Hashable] | None,
) -> Any:
if data is dtypes.NA:
data = np.nan
if coords is not None and utils.is_scalar(data, include_0d=False):
Expand Down Expand Up @@ -405,10 +400,8 @@ class DataArray(
def __init__(
self,
data: Any = dtypes.NA,
coords: Sequence[Sequence[Any] | pd.Index | DataArray]
| Mapping[Any, Any]
| None = None,
dims: Hashable | Sequence[Hashable] | None = None,
coords: Sequence[Sequence | pd.Index | DataArray] | Mapping | None = None,
dims: str | Iterable[Hashable] | None = None,
name: Hashable | None = None,
attrs: Mapping | None = None,
# internal parameters
Expand Down
11 changes: 2 additions & 9 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@
from xarray.util.deprecation_helpers import _deprecate_positional_args

if TYPE_CHECKING:
from dask.dataframe import DataFrame as DaskDataFrame
from dask.delayed import Delayed
from numpy.typing import ArrayLike

from xarray.backends import AbstractDataStore, ZarrStore
Expand Down Expand Up @@ -164,15 +166,6 @@
)
from xarray.core.weighted import DatasetWeighted

try:
from dask.delayed import Delayed
except ImportError:
Delayed = None # type: ignore[misc,assignment]
try:
from dask.dataframe import DataFrame as DaskDataFrame
except ImportError:
DaskDataFrame = None


# list of attributes of pd.DatetimeIndex that are ndarrays of time info
_DATETIMEINDEX_COMPONENTS = [
Expand Down
11 changes: 6 additions & 5 deletions xarray/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,11 @@ def _importorskip(
raise ImportError("Minimum version not satisfied")
except ImportError:
has = False
func = pytest.mark.skipif(not has, reason=f"requires {modname}")

reason = f"requires {modname}"
if minversion is not None:
reason += f">={minversion}"
func = pytest.mark.skipif(not has, reason=reason)
return has, func


Expand Down Expand Up @@ -122,10 +126,7 @@ def _importorskip(
not has_pandas_version_two, reason="requires pandas 2.0.0"
)
has_numpy_array_api, requires_numpy_array_api = _importorskip("numpy", "1.26.0")
has_h5netcdf_ros3 = _importorskip("h5netcdf", "1.3.0")
requires_h5netcdf_ros3 = pytest.mark.skipif(
not has_h5netcdf_ros3[0], reason="requires h5netcdf 1.3.0"
)
has_h5netcdf_ros3, requires_h5netcdf_ros3 = _importorskip("h5netcdf", "1.3.0")

has_netCDF4_1_6_2_or_above, requires_netCDF4_1_6_2_or_above = _importorskip(
"netCDF4", "1.6.2"
Expand Down
4 changes: 2 additions & 2 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,8 +401,8 @@ def test_constructor_invalid(self) -> None:
with pytest.raises(ValueError, match=r"not a subset of the .* dim"):
DataArray(data, {"x": [0, 1, 2]})

with pytest.raises(TypeError, match=r"is not a string"):
DataArray(data, dims=["x", None])
with pytest.raises(TypeError, match=r"is not hashable"):
DataArray(data, dims=["x", []]) # type: ignore[list-item]

with pytest.raises(ValueError, match=r"conflicting sizes for dim"):
DataArray([1, 2, 3], coords=[("x", [0, 1])])
Expand Down
53 changes: 53 additions & 0 deletions xarray/tests/test_hashable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from __future__ import annotations

from enum import Enum
from typing import TYPE_CHECKING, Union

import pytest

from xarray import DataArray, Dataset, Variable

if TYPE_CHECKING:
from xarray.core.types import TypeAlias

DimT: TypeAlias = Union[int, tuple, "DEnum", "CustomHashable"]


class DEnum(Enum):
    """Single-member enum used as a non-string hashable dimension label."""

    dim = "dim"


class CustomHashable:
    """Minimal user-defined hashable object.

    Only ``__hash__`` is overridden (equality stays identity-based, which
    keeps the class hashable by default rules); the hash value is the
    integer given at construction time, also exposed as ``a``.
    """

    def __init__(self, a: int) -> None:
        self.a = a

    def __hash__(self) -> int:
        return self.a


# Shared parametrization: representative non-string hashable objects that
# should all be accepted as dimension names (int, tuple, Enum member, and a
# custom object that only implements __hash__).
parametrize_dim = pytest.mark.parametrize(
    "dim",
    [
        pytest.param(5, id="int"),
        pytest.param(("a", "b"), id="tuple"),
        pytest.param(DEnum.dim, id="enum"),
        pytest.param(CustomHashable(3), id="HashableObject"),
    ],
)


@parametrize_dim
def test_hashable_dims(dim: DimT) -> None:
    """Non-str hashable dims are accepted by the core constructors."""
    var = Variable([dim], [1, 2, 3])
    arr = DataArray([1, 2, 3], dims=[dim])
    Dataset({"a": ([dim], [1, 2, 3])})

    # alternative constructors, built from the objects created above
    DataArray(var)
    Dataset({"a": var})
    Dataset({"a": arr})


@parametrize_dim
def test_dataset_variable_hashable_names(dim: DimT) -> None:
    """Non-str hashable objects are also valid as Dataset variable names."""
    Dataset({dim: ("x", [1, 2, 3])})