Skip to content

Commit 2b444af

Browse files
Illviljan, pre-commit-ci[bot], and andersy005
authored
Add T_DuckArray type hint to Variable.data (#8203)
* Add T_DuckArray
* Add type to variable.data
* [pre-commit.ci] auto fixes from pre-commit.com hooks — for more information, see https://pre-commit.ci
* fixes
* Update variable.py
* chunk renaming
* Update parallelcompat.py
* fix attrs?
* Update alignment.py
* Update test_parallelcompat.py
* Update test_variable.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks — for more information, see https://pre-commit.ci
* Update test_variable.py

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Anderson Banihirwe <[email protected]>
1 parent 3d59258 commit 2b444af

File tree

8 files changed

+66
-46
lines changed

8 files changed

+66
-46
lines changed

xarray/core/alignment.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
if TYPE_CHECKING:
2727
from xarray.core.dataarray import DataArray
2828
from xarray.core.dataset import Dataset
29-
from xarray.core.types import JoinOptions, T_DataArray, T_Dataset
29+
from xarray.core.types import JoinOptions, T_DataArray, T_Dataset, T_DuckArray
3030

3131

3232
def reindex_variables(
@@ -173,7 +173,7 @@ def __init__(
173173

174174
def _normalize_indexes(
175175
self,
176-
indexes: Mapping[Any, Any],
176+
indexes: Mapping[Any, Any | T_DuckArray],
177177
) -> tuple[NormalizedIndexes, NormalizedIndexVars]:
178178
"""Normalize the indexes/indexers used for re-indexing or alignment.
179179
@@ -194,7 +194,7 @@ def _normalize_indexes(
194194
f"Indexer has dimensions {idx.dims} that are different "
195195
f"from that to be indexed along '{k}'"
196196
)
197-
data = as_compatible_data(idx)
197+
data: T_DuckArray = as_compatible_data(idx)
198198
pd_idx = safe_cast_to_index(data)
199199
pd_idx.name = k
200200
if isinstance(pd_idx, pd.MultiIndex):

xarray/core/dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7481,7 +7481,7 @@ def _unary_op(self: T_Dataset, f, *args, **kwargs) -> T_Dataset:
74817481
else:
74827482
variables[k] = f(v, *args, **kwargs)
74837483
if keep_attrs:
7484-
variables[k].attrs = v._attrs
7484+
variables[k]._attrs = v._attrs
74857485
attrs = self._attrs if keep_attrs else None
74867486
return self._replace_with_new_dims(variables, attrs=attrs)
74877487

xarray/core/parallelcompat.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
T_ChunkedArray = TypeVar("T_ChunkedArray")
2626

2727
if TYPE_CHECKING:
28-
from xarray.core.types import T_Chunks, T_NormalizedChunks
28+
from xarray.core.types import T_Chunks, T_DuckArray, T_NormalizedChunks
2929

3030

3131
@functools.lru_cache(maxsize=1)
@@ -257,7 +257,7 @@ def normalize_chunks(
257257

258258
@abstractmethod
259259
def from_array(
260-
self, data: np.ndarray, chunks: T_Chunks, **kwargs
260+
self, data: T_DuckArray | np.typing.ArrayLike, chunks: T_Chunks, **kwargs
261261
) -> T_ChunkedArray:
262262
"""
263263
Create a chunked array from a non-chunked numpy-like array.

xarray/core/types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,10 @@ def copy(
161161
T_DataWithCoords = TypeVar("T_DataWithCoords", bound="DataWithCoords")
162162
T_Alignable = TypeVar("T_Alignable", bound="Alignable")
163163

164+
# Temporary placeholder for indicating an array api compliant type.
165+
# hopefully in the future we can narrow this down more:
166+
T_DuckArray = TypeVar("T_DuckArray", bound=Any)
167+
164168
ScalarOrArray = Union["ArrayLike", np.generic, np.ndarray, "DaskArray"]
165169
DsCompatible = Union["Dataset", "DataArray", "Variable", "GroupBy", "ScalarOrArray"]
166170
DaCompatible = Union["DataArray", "Variable", "DataArrayGroupBy", "ScalarOrArray"]

xarray/core/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@
7373
import pandas as pd
7474

7575
if TYPE_CHECKING:
76-
from xarray.core.types import Dims, ErrorOptionsWithWarn, OrderedDims
76+
from xarray.core.types import Dims, ErrorOptionsWithWarn, OrderedDims, T_DuckArray
7777

7878
K = TypeVar("K")
7979
V = TypeVar("V")
@@ -253,7 +253,7 @@ def is_list_like(value: Any) -> TypeGuard[list | tuple]:
253253
return isinstance(value, (list, tuple))
254254

255255

256-
def is_duck_array(value: Any) -> bool:
256+
def is_duck_array(value: Any) -> TypeGuard[T_DuckArray]:
257257
if isinstance(value, np.ndarray):
258258
return True
259259
return (

xarray/core/variable.py

Lines changed: 45 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from collections.abc import Hashable, Iterable, Mapping, Sequence
99
from datetime import timedelta
1010
from functools import partial
11-
from typing import TYPE_CHECKING, Any, Callable, Literal, NoReturn
11+
from typing import TYPE_CHECKING, Any, Callable, Literal, NoReturn, cast
1212

1313
import numpy as np
1414
import pandas as pd
@@ -66,6 +66,7 @@
6666
PadModeOptions,
6767
PadReflectOptions,
6868
QuantileMethods,
69+
T_DuckArray,
6970
T_Variable,
7071
)
7172

@@ -86,7 +87,7 @@ class MissingDimensionsError(ValueError):
8687
# TODO: move this to an xarray.exceptions module?
8788

8889

89-
def as_variable(obj, name=None) -> Variable | IndexVariable:
90+
def as_variable(obj: T_DuckArray | Any, name=None) -> Variable | IndexVariable:
9091
"""Convert an object into a Variable.
9192
9293
Parameters
@@ -142,7 +143,7 @@ def as_variable(obj, name=None) -> Variable | IndexVariable:
142143
elif isinstance(obj, (set, dict)):
143144
raise TypeError(f"variable {name!r} has invalid type {type(obj)!r}")
144145
elif name is not None:
145-
data = as_compatible_data(obj)
146+
data: T_DuckArray = as_compatible_data(obj)
146147
if data.ndim != 1:
147148
raise MissingDimensionsError(
148149
f"cannot set variable {name!r} with {data.ndim!r}-dimensional data "
@@ -230,7 +231,9 @@ def _possibly_convert_datetime_or_timedelta_index(data):
230231
return data
231232

232233

233-
def as_compatible_data(data, fastpath: bool = False):
234+
def as_compatible_data(
235+
data: T_DuckArray | ArrayLike, fastpath: bool = False
236+
) -> T_DuckArray:
234237
"""Prepare and wrap data to put in a Variable.
235238
236239
- If data does not have the necessary attributes, convert it to ndarray.
@@ -243,7 +246,7 @@ def as_compatible_data(data, fastpath: bool = False):
243246
"""
244247
if fastpath and getattr(data, "ndim", 0) > 0:
245248
# can't use fastpath (yet) for scalars
246-
return _maybe_wrap_data(data)
249+
return cast("T_DuckArray", _maybe_wrap_data(data))
247250

248251
from xarray.core.dataarray import DataArray
249252

@@ -252,7 +255,7 @@ def as_compatible_data(data, fastpath: bool = False):
252255

253256
if isinstance(data, NON_NUMPY_SUPPORTED_ARRAY_TYPES):
254257
data = _possibly_convert_datetime_or_timedelta_index(data)
255-
return _maybe_wrap_data(data)
258+
return cast("T_DuckArray", _maybe_wrap_data(data))
256259

257260
if isinstance(data, tuple):
258261
data = utils.to_0d_object_array(data)
@@ -279,7 +282,7 @@ def as_compatible_data(data, fastpath: bool = False):
279282
if not isinstance(data, np.ndarray) and (
280283
hasattr(data, "__array_function__") or hasattr(data, "__array_namespace__")
281284
):
282-
return data
285+
return cast("T_DuckArray", data)
283286

284287
# validate whether the data is valid data types.
285288
data = np.asarray(data)
@@ -335,7 +338,14 @@ class Variable(AbstractArray, NdimSizeLenMixin, VariableArithmetic):
335338

336339
__slots__ = ("_dims", "_data", "_attrs", "_encoding")
337340

338-
def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False):
341+
def __init__(
342+
self,
343+
dims,
344+
data: T_DuckArray | ArrayLike,
345+
attrs=None,
346+
encoding=None,
347+
fastpath=False,
348+
):
339349
"""
340350
Parameters
341351
----------
@@ -355,9 +365,9 @@ def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False):
355365
Well-behaved code to serialize a Variable should ignore
356366
unrecognized encoding items.
357367
"""
358-
self._data = as_compatible_data(data, fastpath=fastpath)
368+
self._data: T_DuckArray = as_compatible_data(data, fastpath=fastpath)
359369
self._dims = self._parse_dimensions(dims)
360-
self._attrs = None
370+
self._attrs: dict[Any, Any] | None = None
361371
self._encoding = None
362372
if attrs is not None:
363373
self.attrs = attrs
@@ -410,7 +420,7 @@ def _in_memory(self):
410420
)
411421

412422
@property
413-
def data(self) -> Any:
423+
def data(self: T_Variable):
414424
"""
415425
The Variable's data as an array. The underlying array type
416426
(e.g. dask, sparse, pint) is preserved.
@@ -429,12 +439,12 @@ def data(self) -> Any:
429439
return self.values
430440

431441
@data.setter
432-
def data(self, data):
442+
def data(self: T_Variable, data: T_DuckArray | ArrayLike) -> None:
433443
data = as_compatible_data(data)
434-
if data.shape != self.shape:
444+
if data.shape != self.shape: # type: ignore[attr-defined]
435445
raise ValueError(
436446
f"replacement data must match the Variable's shape. "
437-
f"replacement data has shape {data.shape}; Variable has shape {self.shape}"
447+
f"replacement data has shape {data.shape}; Variable has shape {self.shape}" # type: ignore[attr-defined]
438448
)
439449
self._data = data
440450

@@ -996,7 +1006,7 @@ def reset_encoding(self: T_Variable) -> T_Variable:
9961006
return self._replace(encoding={})
9971007

9981008
def copy(
999-
self: T_Variable, deep: bool = True, data: ArrayLike | None = None
1009+
self: T_Variable, deep: bool = True, data: T_DuckArray | ArrayLike | None = None
10001010
) -> T_Variable:
10011011
"""Returns a copy of this object.
10021012
@@ -1058,24 +1068,26 @@ def copy(
10581068
def _copy(
10591069
self: T_Variable,
10601070
deep: bool = True,
1061-
data: ArrayLike | None = None,
1071+
data: T_DuckArray | ArrayLike | None = None,
10621072
memo: dict[int, Any] | None = None,
10631073
) -> T_Variable:
10641074
if data is None:
1065-
ndata = self._data
1075+
data_old = self._data
10661076

1067-
if isinstance(ndata, indexing.MemoryCachedArray):
1077+
if isinstance(data_old, indexing.MemoryCachedArray):
10681078
# don't share caching between copies
1069-
ndata = indexing.MemoryCachedArray(ndata.array)
1079+
ndata = indexing.MemoryCachedArray(data_old.array)
1080+
else:
1081+
ndata = data_old
10701082

10711083
if deep:
10721084
ndata = copy.deepcopy(ndata, memo)
10731085

10741086
else:
10751087
ndata = as_compatible_data(data)
1076-
if self.shape != ndata.shape:
1088+
if self.shape != ndata.shape: # type: ignore[attr-defined]
10771089
raise ValueError(
1078-
f"Data shape {ndata.shape} must match shape of object {self.shape}"
1090+
f"Data shape {ndata.shape} must match shape of object {self.shape}" # type: ignore[attr-defined]
10791091
)
10801092

10811093
attrs = copy.deepcopy(self._attrs, memo) if deep else copy.copy(self._attrs)
@@ -1248,11 +1260,11 @@ def chunk(
12481260
inline_array=inline_array,
12491261
)
12501262

1251-
data = self._data
1252-
if chunkmanager.is_chunked_array(data):
1253-
data = chunkmanager.rechunk(data, chunks) # type: ignore[arg-type]
1263+
data_old = self._data
1264+
if chunkmanager.is_chunked_array(data_old):
1265+
data_chunked = chunkmanager.rechunk(data_old, chunks) # type: ignore[arg-type]
12541266
else:
1255-
if isinstance(data, indexing.ExplicitlyIndexed):
1267+
if isinstance(data_old, indexing.ExplicitlyIndexed):
12561268
# Unambiguously handle array storage backends (like NetCDF4 and h5py)
12571269
# that can't handle general array indexing. For example, in netCDF4 you
12581270
# can do "outer" indexing along two dimensions independent, which works
@@ -1261,20 +1273,22 @@ def chunk(
12611273
# Using OuterIndexer is a pragmatic choice: dask does not yet handle
12621274
# different indexing types in an explicit way:
12631275
# https://github.com/dask/dask/issues/2883
1264-
data = indexing.ImplicitToExplicitIndexingAdapter(
1265-
data, indexing.OuterIndexer
1276+
ndata = indexing.ImplicitToExplicitIndexingAdapter(
1277+
data_old, indexing.OuterIndexer
12661278
)
1279+
else:
1280+
ndata = data_old
12671281

12681282
if utils.is_dict_like(chunks):
1269-
chunks = tuple(chunks.get(n, s) for n, s in enumerate(data.shape))
1283+
chunks = tuple(chunks.get(n, s) for n, s in enumerate(ndata.shape))
12701284

1271-
data = chunkmanager.from_array(
1272-
data,
1285+
data_chunked = chunkmanager.from_array(
1286+
ndata,
12731287
chunks, # type: ignore[arg-type]
12741288
**_from_array_kwargs,
12751289
)
12761290

1277-
return self._replace(data=data)
1291+
return self._replace(data=data_chunked)
12781292

12791293
def to_numpy(self) -> np.ndarray:
12801294
"""Coerces wrapped data to numpy and returns a numpy.ndarray"""

xarray/tests/test_parallelcompat.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
guess_chunkmanager,
1313
list_chunkmanagers,
1414
)
15-
from xarray.core.types import T_Chunks, T_NormalizedChunks
15+
from xarray.core.types import T_Chunks, T_DuckArray, T_NormalizedChunks
1616
from xarray.tests import has_dask, requires_dask
1717

1818

@@ -76,7 +76,7 @@ def normalize_chunks(
7676
return normalize_chunks(chunks, shape, limit, dtype, previous_chunks)
7777

7878
def from_array(
79-
self, data: np.ndarray, chunks: T_Chunks, **kwargs
79+
self, data: T_DuckArray | np.typing.ArrayLike, chunks: T_Chunks, **kwargs
8080
) -> DummyChunkedArray:
8181
from dask import array as da
8282

xarray/tests/test_variable.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from copy import copy, deepcopy
66
from datetime import datetime, timedelta
77
from textwrap import dedent
8+
from typing import Generic
89

910
import numpy as np
1011
import pandas as pd
@@ -26,6 +27,7 @@
2627
VectorizedIndexer,
2728
)
2829
from xarray.core.pycompat import array_type
30+
from xarray.core.types import T_DuckArray
2931
from xarray.core.utils import NDArrayMixin
3032
from xarray.core.variable import as_compatible_data, as_variable
3133
from xarray.tests import (
@@ -2529,7 +2531,7 @@ def test_to_index_variable_copy(self) -> None:
25292531
assert a.dims == ("x",)
25302532

25312533

2532-
class TestAsCompatibleData:
2534+
class TestAsCompatibleData(Generic[T_DuckArray]):
25332535
def test_unchanged_types(self):
25342536
types = (np.asarray, PandasIndexingAdapter, LazilyIndexedArray)
25352537
for t in types:
@@ -2610,17 +2612,17 @@ def test_tz_datetime(self) -> None:
26102612
times_s = times_ns.astype(pd.DatetimeTZDtype("s", tz))
26112613
with warnings.catch_warnings():
26122614
warnings.simplefilter("ignore")
2613-
actual = as_compatible_data(times_s)
2615+
actual: T_DuckArray = as_compatible_data(times_s)
26142616
assert actual.array == times_s
26152617
assert actual.array.dtype == pd.DatetimeTZDtype("ns", tz)
26162618

26172619
series = pd.Series(times_s)
26182620
with warnings.catch_warnings():
26192621
warnings.simplefilter("ignore")
2620-
actual = as_compatible_data(series)
2622+
actual2: T_DuckArray = as_compatible_data(series)
26212623

2622-
np.testing.assert_array_equal(actual, series.values)
2623-
assert actual.dtype == np.dtype("datetime64[ns]")
2624+
np.testing.assert_array_equal(actual2, series.values)
2625+
assert actual2.dtype == np.dtype("datetime64[ns]")
26242626

26252627
def test_full_like(self) -> None:
26262628
# For more thorough tests, see test_variable.py

0 commit comments

Comments (0)