Skip to content

Commit 54be9b1

Browse files
committed
WIP
1 parent 58308e1 commit 54be9b1

File tree

3 files changed

+71
-24
lines changed

3 files changed

+71
-24
lines changed

xarray/coding/strings.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from xarray.core.utils import module_available
1919
from xarray.core.variable import Variable
2020
from xarray.namedarray.parallelcompat import get_chunked_array_type
21-
from xarray.namedarray.pycompat import is_chunked_array
21+
from xarray.namedarray.pycompat import is_chunked_array, to_numpy
2222

2323
HAS_NUMPY_2_0 = module_available("numpy", minversion="2.0.0.dev0")
2424

@@ -135,7 +135,8 @@ def decode(self, variable, name=None):
135135
if data.dtype == "S1" and dims:
136136
encoding["char_dim_name"] = dims[-1]
137137
dims = dims[:-1]
138-
data = char_to_bytes(data)
138+
# TODO (duck array encoding)
139+
data = char_to_bytes(to_numpy(data))
139140
return Variable(dims, data, attrs, encoding)
140141

141142

xarray/coding/times.py

Lines changed: 42 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,24 @@
2121
unpack_for_encoding,
2222
)
2323
from xarray.core import indexing
24+
from xarray.core.array_api_compat import get_array_namespace
2425
from xarray.core.common import contains_cftime_datetimes, is_np_datetime_like
25-
from xarray.core.duck_array_ops import array_all, array_any, asarray, ravel, reshape
26+
from xarray.core.duck_array_ops import (
27+
array_all,
28+
array_any,
29+
asarray,
30+
astype,
31+
concatenate,
32+
isnull,
33+
ravel,
34+
reshape,
35+
)
2636
from xarray.core.formatting import first_n_items, format_timestamp, last_item
2737
from xarray.core.pdcompat import default_precision_timestamp, timestamp_as_unit
2838
from xarray.core.utils import attempt_import, emit_user_level_warning
2939
from xarray.core.variable import Variable
3040
from xarray.namedarray.parallelcompat import T_ChunkedArray, get_chunked_array_type
31-
from xarray.namedarray.pycompat import is_chunked_array, to_numpy
41+
from xarray.namedarray.pycompat import is_chunked_array, to_duck_array, to_numpy
3242
from xarray.namedarray.utils import is_duck_dask_array
3343

3444
try:
@@ -100,7 +110,7 @@ def _is_numpy_compatible_time_range(times):
100110
if is_np_datetime_like(times.dtype):
101111
return True
102112
# times array contains cftime objects
103-
times = np.asarray(times)
113+
times = to_duck_array(times)
104114
tmin = times.min()
105115
tmax = times.max()
106116
try:
@@ -309,8 +319,9 @@ def _decode_cf_datetime_dtype(
309319
# successfully. Otherwise, tracebacks end up swallowed by
310320
# Dataset.__repr__ when users try to view their lazily decoded array.
311321
values = indexing.ImplicitToExplicitIndexingAdapter(indexing.as_indexable(data))
312-
example_value = np.concatenate(
313-
[to_numpy(first_n_items(values, 1) or [0]), to_numpy(last_item(values) or [0])]
322+
zero = asarray([0], xp=get_array_namespace(values))
323+
example_value = concatenate(
324+
[first_n_items(values, 1) or zero, last_item(values) or zero]
314325
)
315326

316327
try:
@@ -342,7 +353,13 @@ def _decode_datetime_with_cftime(
342353
cftime = attempt_import("cftime")
343354
if num_dates.size > 0:
344355
return np.asarray(
345-
cftime.num2date(num_dates, units, calendar, only_use_cftime_datetimes=True)
356+
cftime.num2date(
357+
# cftime uses Cython so we must convert to numpy here.
358+
to_numpy(num_dates),
359+
units,
360+
calendar,
361+
only_use_cftime_datetimes=True,
362+
)
346363
)
347364
else:
348365
return np.array([], dtype=object)
@@ -357,7 +374,7 @@ def _check_date_for_units_since_refdate(
357374
f"Value {date} can't be represented as Datetime/Timedelta."
358375
)
359376
delta = date * np.timedelta64(1, unit)
360-
if not np.isnan(delta):
377+
if not isnull(delta):
361378
# this will raise on dtype overflow for integer dtypes
362379
if date.dtype.kind in "u" and not np.int64(delta) == date:
363380
raise OutOfBoundsTimedelta(
@@ -381,7 +398,7 @@ def _check_timedelta_range(value, data_unit, time_unit):
381398
"ignore", "invalid value encountered in multiply", RuntimeWarning
382399
)
383400
delta = value * np.timedelta64(1, data_unit)
384-
if not np.isnan(delta):
401+
if not isnull(delta):
385402
# this will raise on dtype overflow for integer dtypes
386403
if value.dtype.kind in "u" and not np.int64(delta) == value:
387404
raise OutOfBoundsTimedelta(
@@ -449,9 +466,9 @@ def _decode_datetime_with_pandas(
449466
# respectively. See https://github.com/pandas-dev/pandas/issues/56996 for
450467
# more details.
451468
if flat_num_dates.dtype.kind == "i":
452-
flat_num_dates = flat_num_dates.astype(np.int64)
469+
flat_num_dates = astype(flat_num_dates, np.int64)
453470
elif flat_num_dates.dtype.kind == "u":
454-
flat_num_dates = flat_num_dates.astype(np.uint64)
471+
flat_num_dates = astype(flat_num_dates, np.uint64)
455472

456473
try:
457474
time_unit, ref_date = _unpack_time_unit_and_ref_date(units)
@@ -483,9 +500,9 @@ def _decode_datetime_with_pandas(
483500
# overflow when converting to np.int64 would not be representable with a
484501
# timedelta64 value, and therefore would raise an error in the lines above.
485502
if flat_num_dates.dtype.kind in "iu":
486-
flat_num_dates = flat_num_dates.astype(np.int64)
503+
flat_num_dates = astype(flat_num_dates, np.int64)
487504
elif flat_num_dates.dtype.kind in "f":
488-
flat_num_dates = flat_num_dates.astype(np.float64)
505+
flat_num_dates = astype(flat_num_dates, np.float64)
489506

490507
timedeltas = _numbers_to_timedelta(
491508
flat_num_dates, time_unit, ref_date.unit, "datetime"
@@ -528,8 +545,12 @@ def decode_cf_datetime(
528545
)
529546
except (KeyError, OutOfBoundsDatetime, OutOfBoundsTimedelta, OverflowError):
530547
dates = _decode_datetime_with_cftime(
531-
flat_num_dates.astype(float), units, calendar
548+
astype(flat_num_dates, float), units, calendar
532549
)
550+
# This conversion to numpy is only needed for nanarg* below.
551+
# TODO: explore removing it.
552+
# Note that `dates` is already a numpy object array of cftime objects.
553+
num_dates = to_numpy(num_dates)
533554
# retrieve cftype
534555
dates_min = dates[np.nanargmin(num_dates)]
535556
dates_max = dates[np.nanargmax(num_dates)]
@@ -586,16 +607,16 @@ def _numbers_to_timedelta(
586607
"""Transform numbers to np.timedelta64."""
587608
# keep NaT/nan mask
588609
if flat_num.dtype.kind == "f":
589-
nan = np.asarray(np.isnan(flat_num))
610+
nan = isnull(flat_num)
590611
elif flat_num.dtype.kind == "i":
591-
nan = np.asarray(flat_num == np.iinfo(np.int64).min)
612+
nan = flat_num == np.iinfo(np.int64).min
592613

593614
# in case we need to change the unit, we fix the numbers here
594615
# this should be safe, as errors would have been raised above
595616
ns_time_unit = _NS_PER_TIME_DELTA[time_unit]
596617
ns_ref_date_unit = _NS_PER_TIME_DELTA[ref_unit]
597618
if ns_time_unit > ns_ref_date_unit:
598-
flat_num = np.asarray(flat_num * np.int64(ns_time_unit / ns_ref_date_unit))
619+
flat_num = flat_num * np.int64(ns_time_unit / ns_ref_date_unit)
599620
time_unit = ref_unit
600621

601622
# estimate fitting resolution for floating point values
@@ -618,12 +639,12 @@ def _numbers_to_timedelta(
618639
# to prevent casting NaN to int
619640
with warnings.catch_warnings():
620641
warnings.simplefilter("ignore", RuntimeWarning)
621-
flat_num = flat_num.astype(np.int64)
622-
if nan.any():
642+
flat_num = astype(flat_num, np.int64)
643+
if array_any(nan):
623644
flat_num[nan] = np.iinfo(np.int64).min
624645

625646
# cast to wanted type
626-
return flat_num.astype(f"timedelta64[{time_unit}]")
647+
return astype(flat_num, f"timedelta64[{time_unit}]")
627648

628649

629650
def decode_cf_timedelta(
@@ -712,8 +733,8 @@ def infer_datetime_units(dates) -> str:
712733
'hours', 'minutes' or 'seconds' (the first one that can evenly divide all
713734
unique time deltas in `dates`)
714735
"""
715-
dates = ravel(np.asarray(dates))
716-
if np.issubdtype(np.asarray(dates).dtype, "datetime64"):
736+
dates = ravel(to_duck_array(dates))
737+
if np.issubdtype(dates.dtype, "datetime64"):
717738
dates = to_datetime_unboxed(dates)
718739
dates = dates[pd.notnull(dates)]
719740
reference_date = dates[0] if len(dates) > 0 else "1970-01-01"

xarray/tests/namespace.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,30 @@
1-
from xarray.core import duck_array_ops
1+
import numpy as np
2+
3+
from xarray.core import array_api_compat, duck_array_ops
24

35

46
def reshape(array, shape, **kwargs):
57
return type(array)(duck_array_ops.reshape(array.array, shape=shape, **kwargs))
8+
9+
10+
def concatenate(arrays, axis):
11+
return type(arrays[0])(
12+
duck_array_ops.concatenate([a.array for a in arrays], axis=axis)
13+
)
14+
15+
16+
def result_type(*arrays_and_dtypes):
17+
parsed = [a.array if hasattr(a, "array") else a for a in arrays_and_dtypes]
18+
return array_api_compat.result_type(*parsed, xp=np)
19+
20+
21+
def astype(array, dtype, **kwargs):
22+
return type(array)(duck_array_ops.astype(array.array, dtype=dtype, **kwargs))
23+
24+
25+
def isnan(array):
26+
return type(array)(duck_array_ops.isnull(array.array))
27+
28+
29+
def any(array, *args, **kwargs): # TODO: keepdims
30+
return duck_array_ops.array_any(array.array, *args, **kwargs)

0 commit comments

Comments
 (0)