8
8
from collections .abc import Hashable , Iterable , Mapping , Sequence
9
9
from datetime import timedelta
10
10
from functools import partial
11
- from typing import TYPE_CHECKING , Any , Callable , Literal , NoReturn
11
+ from typing import TYPE_CHECKING , Any , Callable , Literal , NoReturn , cast
12
12
13
13
import numpy as np
14
14
import pandas as pd
66
66
PadModeOptions ,
67
67
PadReflectOptions ,
68
68
QuantileMethods ,
69
+ T_DuckArray ,
69
70
T_Variable ,
70
71
)
71
72
@@ -86,7 +87,7 @@ class MissingDimensionsError(ValueError):
86
87
# TODO: move this to an xarray.exceptions module?
87
88
88
89
89
- def as_variable (obj , name = None ) -> Variable | IndexVariable :
90
+ def as_variable (obj : T_DuckArray | Any , name = None ) -> Variable | IndexVariable :
90
91
"""Convert an object into a Variable.
91
92
92
93
Parameters
@@ -142,7 +143,7 @@ def as_variable(obj, name=None) -> Variable | IndexVariable:
142
143
elif isinstance (obj , (set , dict )):
143
144
raise TypeError (f"variable { name !r} has invalid type { type (obj )!r} " )
144
145
elif name is not None :
145
- data = as_compatible_data (obj )
146
+ data : T_DuckArray = as_compatible_data (obj )
146
147
if data .ndim != 1 :
147
148
raise MissingDimensionsError (
148
149
f"cannot set variable { name !r} with { data .ndim !r} -dimensional data "
@@ -230,7 +231,9 @@ def _possibly_convert_datetime_or_timedelta_index(data):
230
231
return data
231
232
232
233
233
- def as_compatible_data (data , fastpath : bool = False ):
234
+ def as_compatible_data (
235
+ data : T_DuckArray | ArrayLike , fastpath : bool = False
236
+ ) -> T_DuckArray :
234
237
"""Prepare and wrap data to put in a Variable.
235
238
236
239
- If data does not have the necessary attributes, convert it to ndarray.
@@ -243,7 +246,7 @@ def as_compatible_data(data, fastpath: bool = False):
243
246
"""
244
247
if fastpath and getattr (data , "ndim" , 0 ) > 0 :
245
248
# can't use fastpath (yet) for scalars
246
- return _maybe_wrap_data (data )
249
+ return cast ( "T_DuckArray" , _maybe_wrap_data (data ) )
247
250
248
251
from xarray .core .dataarray import DataArray
249
252
@@ -252,7 +255,7 @@ def as_compatible_data(data, fastpath: bool = False):
252
255
253
256
if isinstance (data , NON_NUMPY_SUPPORTED_ARRAY_TYPES ):
254
257
data = _possibly_convert_datetime_or_timedelta_index (data )
255
- return _maybe_wrap_data (data )
258
+ return cast ( "T_DuckArray" , _maybe_wrap_data (data ) )
256
259
257
260
if isinstance (data , tuple ):
258
261
data = utils .to_0d_object_array (data )
@@ -279,7 +282,7 @@ def as_compatible_data(data, fastpath: bool = False):
279
282
if not isinstance (data , np .ndarray ) and (
280
283
hasattr (data , "__array_function__" ) or hasattr (data , "__array_namespace__" )
281
284
):
282
- return data
285
+ return cast ( "T_DuckArray" , data )
283
286
284
287
# validate whether the data is valid data types.
285
288
data = np .asarray (data )
@@ -335,7 +338,14 @@ class Variable(AbstractArray, NdimSizeLenMixin, VariableArithmetic):
335
338
336
339
__slots__ = ("_dims" , "_data" , "_attrs" , "_encoding" )
337
340
338
- def __init__ (self , dims , data , attrs = None , encoding = None , fastpath = False ):
341
+ def __init__ (
342
+ self ,
343
+ dims ,
344
+ data : T_DuckArray | ArrayLike ,
345
+ attrs = None ,
346
+ encoding = None ,
347
+ fastpath = False ,
348
+ ):
339
349
"""
340
350
Parameters
341
351
----------
@@ -355,9 +365,9 @@ def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False):
355
365
Well-behaved code to serialize a Variable should ignore
356
366
unrecognized encoding items.
357
367
"""
358
- self ._data = as_compatible_data (data , fastpath = fastpath )
368
+ self ._data : T_DuckArray = as_compatible_data (data , fastpath = fastpath )
359
369
self ._dims = self ._parse_dimensions (dims )
360
- self ._attrs = None
370
+ self ._attrs : dict [ Any , Any ] | None = None
361
371
self ._encoding = None
362
372
if attrs is not None :
363
373
self .attrs = attrs
@@ -410,7 +420,7 @@ def _in_memory(self):
410
420
)
411
421
412
422
@property
413
- def data (self ) -> Any :
423
+ def data (self : T_Variable ) :
414
424
"""
415
425
The Variable's data as an array. The underlying array type
416
426
(e.g. dask, sparse, pint) is preserved.
@@ -429,12 +439,12 @@ def data(self) -> Any:
429
439
return self .values
430
440
431
441
@data .setter
432
- def data (self , data ) :
442
+ def data (self : T_Variable , data : T_DuckArray | ArrayLike ) -> None :
433
443
data = as_compatible_data (data )
434
- if data .shape != self .shape :
444
+ if data .shape != self .shape : # type: ignore[attr-defined]
435
445
raise ValueError (
436
446
f"replacement data must match the Variable's shape. "
437
- f"replacement data has shape { data .shape } ; Variable has shape { self .shape } "
447
+ f"replacement data has shape { data .shape } ; Variable has shape { self .shape } " # type: ignore[attr-defined]
438
448
)
439
449
self ._data = data
440
450
@@ -996,7 +1006,7 @@ def reset_encoding(self: T_Variable) -> T_Variable:
996
1006
return self ._replace (encoding = {})
997
1007
998
1008
def copy (
999
- self : T_Variable , deep : bool = True , data : ArrayLike | None = None
1009
+ self : T_Variable , deep : bool = True , data : T_DuckArray | ArrayLike | None = None
1000
1010
) -> T_Variable :
1001
1011
"""Returns a copy of this object.
1002
1012
@@ -1058,24 +1068,26 @@ def copy(
1058
1068
def _copy (
1059
1069
self : T_Variable ,
1060
1070
deep : bool = True ,
1061
- data : ArrayLike | None = None ,
1071
+ data : T_DuckArray | ArrayLike | None = None ,
1062
1072
memo : dict [int , Any ] | None = None ,
1063
1073
) -> T_Variable :
1064
1074
if data is None :
1065
- ndata = self ._data
1075
+ data_old = self ._data
1066
1076
1067
- if isinstance (ndata , indexing .MemoryCachedArray ):
1077
+ if isinstance (data_old , indexing .MemoryCachedArray ):
1068
1078
# don't share caching between copies
1069
- ndata = indexing .MemoryCachedArray (ndata .array )
1079
+ ndata = indexing .MemoryCachedArray (data_old .array )
1080
+ else :
1081
+ ndata = data_old
1070
1082
1071
1083
if deep :
1072
1084
ndata = copy .deepcopy (ndata , memo )
1073
1085
1074
1086
else :
1075
1087
ndata = as_compatible_data (data )
1076
- if self .shape != ndata .shape :
1088
+ if self .shape != ndata .shape : # type: ignore[attr-defined]
1077
1089
raise ValueError (
1078
- f"Data shape { ndata .shape } must match shape of object { self .shape } "
1090
+ f"Data shape { ndata .shape } must match shape of object { self .shape } " # type: ignore[attr-defined]
1079
1091
)
1080
1092
1081
1093
attrs = copy .deepcopy (self ._attrs , memo ) if deep else copy .copy (self ._attrs )
@@ -1248,11 +1260,11 @@ def chunk(
1248
1260
inline_array = inline_array ,
1249
1261
)
1250
1262
1251
- data = self ._data
1252
- if chunkmanager .is_chunked_array (data ):
1253
- data = chunkmanager .rechunk (data , chunks ) # type: ignore[arg-type]
1263
+ data_old = self ._data
1264
+ if chunkmanager .is_chunked_array (data_old ):
1265
+ data_chunked = chunkmanager .rechunk (data_old , chunks ) # type: ignore[arg-type]
1254
1266
else :
1255
- if isinstance (data , indexing .ExplicitlyIndexed ):
1267
+ if isinstance (data_old , indexing .ExplicitlyIndexed ):
1256
1268
# Unambiguously handle array storage backends (like NetCDF4 and h5py)
1257
1269
# that can't handle general array indexing. For example, in netCDF4 you
1258
1270
# can do "outer" indexing along two dimensions independent, which works
@@ -1261,20 +1273,22 @@ def chunk(
1261
1273
# Using OuterIndexer is a pragmatic choice: dask does not yet handle
1262
1274
# different indexing types in an explicit way:
1263
1275
# https://github.com/dask/dask/issues/2883
1264
- data = indexing .ImplicitToExplicitIndexingAdapter (
1265
- data , indexing .OuterIndexer
1276
+ ndata = indexing .ImplicitToExplicitIndexingAdapter (
1277
+ data_old , indexing .OuterIndexer
1266
1278
)
1279
+ else :
1280
+ ndata = data_old
1267
1281
1268
1282
if utils .is_dict_like (chunks ):
1269
- chunks = tuple (chunks .get (n , s ) for n , s in enumerate (data .shape ))
1283
+ chunks = tuple (chunks .get (n , s ) for n , s in enumerate (ndata .shape ))
1270
1284
1271
- data = chunkmanager .from_array (
1272
- data ,
1285
+ data_chunked = chunkmanager .from_array (
1286
+ ndata ,
1273
1287
chunks , # type: ignore[arg-type]
1274
1288
** _from_array_kwargs ,
1275
1289
)
1276
1290
1277
- return self ._replace (data = data )
1291
+ return self ._replace (data = data_chunked )
1278
1292
1279
1293
def to_numpy (self ) -> np .ndarray :
1280
1294
"""Coerces wrapped data to numpy and returns a numpy.ndarray"""
0 commit comments