Commit 04b433a

Merge branch 'main' into grouper-public-api
2 parents: 2bf596c + 5f24079

27 files changed: +808 -418 lines

doc/whats-new.rst
Lines changed: 11 additions & 1 deletion

@@ -24,7 +24,11 @@ New Features
 ~~~~~~~~~~~~

 - New "random" method for converting to and from 360_day calendars (:pull:`8603`).
   By `Pascal Bourgault <https://github.com/aulemahal>`_.
-
+- Xarray now makes a best attempt not to coerce :py:class:`pandas.api.extensions.ExtensionArray` to a numpy array
+  by supporting 1D `ExtensionArray` objects internally where possible. Thus, `Dataset`s initialized with a `pd.Categorical`,
+  for example, will retain the object. However, one cannot then perform operations that are not possible on the
+  `ExtensionArray`, such as broadcasting.
+  By `Ilan Gold <https://github.com/ilan-gold>`_.

 Breaking changes
 ~~~~~~~~~~~~~~~~

@@ -36,6 +40,12 @@ Bug fixes

 Internal Changes
 ~~~~~~~~~~~~~~~~
+- Migrates ``formatting_html`` functionality for `DataTree` into ``xarray/core`` (:pull:`8930`)
+  By `Eni Awowale <https://github.com/eni-awowale>`_, `Julia Signell <https://github.com/jsignell>`_
+  and `Tom Nicholas <https://github.com/TomNicholas>`_.
+- Migrates ``datatree_mapping`` functionality into ``xarray/core`` (:pull:`8948`)
+  By `Matt Savoie <https://github.com/flamingbear>`_, `Owen Littlejohns
+  <https://github.com/owenlittlejohns>`_ and `Tom Nicholas <https://github.com/TomNicholas>`_.


 .. _whats-new.2024.03.0:
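
As a quick, hedged illustration of the new entry above (a sketch assuming this branch of xarray plus pandas; names are illustrative only):

    import pandas as pd
    import xarray as xr

    # A Dataset built from a pandas Categorical should now keep the extension
    # array instead of coercing it to a numpy object array.
    df = pd.DataFrame({"fruit": pd.Categorical(["apple", "pear", "apple"])})
    ds = xr.Dataset.from_dataframe(df)

    print(ds["fruit"].dtype)  # expected: the pandas CategoricalDtype ("category")

    # Operations the ExtensionArray itself cannot support (e.g. broadcasting the
    # variable against a new dimension) are expected to raise rather than coerce.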

properties/test_pandas_roundtrip.py
Lines changed: 3 additions & 1 deletion

@@ -17,7 +17,9 @@
 from hypothesis import given  # isort:skip

 numeric_dtypes = st.one_of(
-    npst.unsigned_integer_dtypes(), npst.integer_dtypes(), npst.floating_dtypes()
+    npst.unsigned_integer_dtypes(endianness="="),
+    npst.integer_dtypes(endianness="="),
+    npst.floating_dtypes(endianness="="),
 )

 numeric_series = numeric_dtypes.flatmap(lambda dt: pdst.series(dtype=dt))
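
For context, a rough sketch (not part of the test file) of what the tightened strategies draw; restricting to native byte order (endianness="=") avoids generating big-endian dtypes, which can behave inconsistently when round-tripped through pandas:

    import hypothesis.extra.numpy as npst
    import hypothesis.strategies as st

    # Mirrors the change above: only native-endian numeric dtypes are generated,
    # so examples such as ">i4" (big-endian int32) never appear.
    native_numeric_dtypes = st.one_of(
        npst.unsigned_integer_dtypes(endianness="="),
        npst.integer_dtypes(endianness="="),
        npst.floating_dtypes(endianness="="),
    )

    print(native_numeric_dtypes.example())  # e.g. dtype('uint16')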

pyproject.toml
Lines changed: 8 additions & 6 deletions

@@ -33,6 +33,7 @@ accel = ["scipy", "bottleneck", "numbagg", "flox", "opt_einsum"]
 complete = ["xarray[accel,io,parallel,viz,dev]"]
 dev = [
   "hypothesis",
+  "mypy",
   "pre-commit",
   "pytest",
   "pytest-cov",

@@ -86,8 +87,8 @@ exclude_lines = ["pragma: no cover", "if TYPE_CHECKING"]
 [tool.mypy]
 enable_error_code = "redundant-self"
 exclude = [
-  'xarray/util/generate_.*\.py',
-  'xarray/datatree_/.*\.py',
+  'xarray/util/generate_.*\.py',
+  'xarray/datatree_/.*\.py',
 ]
 files = "xarray"
 show_error_codes = true

@@ -98,8 +99,8 @@ warn_unused_ignores = true

 # Ignore mypy errors for modules imported from datatree_.
 [[tool.mypy.overrides]]
-module = "xarray.datatree_.*"
 ignore_errors = true
+module = "xarray.datatree_.*"

 # Much of the numerical computing stack doesn't have type annotations yet.
 [[tool.mypy.overrides]]

@@ -129,6 +130,7 @@ module = [
   "opt_einsum.*",
   "pandas.*",
   "pooch.*",
+  "pyarrow.*",
   "pydap.*",
   "pytest.*",
   "scipy.*",

@@ -255,6 +257,9 @@ target-version = "py39"
 # E402: module level import not at top of file
 # E501: line too long - let black worry about that
 # E731: do not assign a lambda expression, use a def
+extend-safe-fixes = [
+  "TID252",  # absolute imports
+]
 ignore = [
   "E402",
   "E501",

@@ -268,9 +273,6 @@ select = [
   "I",   # isort
   "UP",  # Pyupgrade
 ]
-extend-safe-fixes = [
-  "TID252",  # absolute imports
-]

 [tool.ruff.lint.per-file-ignores]
 # don't enforce absolute imports

xarray/core/dataset.py
Lines changed: 47 additions & 13 deletions

@@ -24,6 +24,7 @@
 from typing import IO, TYPE_CHECKING, Any, Callable, Generic, Literal, cast, overload

 import numpy as np
+from pandas.api.types import is_extension_array_dtype

 # remove once numpy 2.0 is the oldest supported version
 try:

@@ -6853,10 +6854,13 @@ def reduce(
                 if (
                     # Some reduction functions (e.g. std, var) need to run on variables
                     # that don't have the reduce dims: PR5393
-                    not reduce_dims
-                    or not numeric_only
-                    or np.issubdtype(var.dtype, np.number)
-                    or (var.dtype == np.bool_)
+                    not is_extension_array_dtype(var.dtype)
+                    and (
+                        not reduce_dims
+                        or not numeric_only
+                        or np.issubdtype(var.dtype, np.number)
+                        or (var.dtype == np.bool_)
+                    )
                 ):
                     # prefer to aggregate over axis=None rather than
                     # axis=(0, 1) if they will be equivalent, because

@@ -7169,13 +7173,37 @@ def to_pandas(self) -> pd.Series | pd.DataFrame:
         )

     def _to_dataframe(self, ordered_dims: Mapping[Any, int]):
-        columns = [k for k in self.variables if k not in self.dims]
+        columns_in_order = [k for k in self.variables if k not in self.dims]
+        non_extension_array_columns = [
+            k
+            for k in columns_in_order
+            if not is_extension_array_dtype(self.variables[k].data)
+        ]
+        extension_array_columns = [
+            k
+            for k in columns_in_order
+            if is_extension_array_dtype(self.variables[k].data)
+        ]
         data = [
             self._variables[k].set_dims(ordered_dims).values.reshape(-1)
-            for k in columns
+            for k in non_extension_array_columns
         ]
         index = self.coords.to_index([*ordered_dims])
-        return pd.DataFrame(dict(zip(columns, data)), index=index)
+        broadcasted_df = pd.DataFrame(
+            dict(zip(non_extension_array_columns, data)), index=index
+        )
+        for extension_array_column in extension_array_columns:
+            extension_array = self.variables[extension_array_column].data.array
+            index = self[self.variables[extension_array_column].dims[0]].data
+            extension_array_df = pd.DataFrame(
+                {extension_array_column: extension_array},
+                index=self[self.variables[extension_array_column].dims[0]].data,
+            )
+            extension_array_df.index.name = self.variables[extension_array_column].dims[
+                0
+            ]
+            broadcasted_df = broadcasted_df.join(extension_array_df)
+        return broadcasted_df[columns_in_order]

     def to_dataframe(self, dim_order: Sequence[Hashable] | None = None) -> pd.DataFrame:
         """Convert this dataset into a pandas.DataFrame.

@@ -7322,11 +7350,13 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self:
                 "cannot convert a DataFrame with a non-unique MultiIndex into xarray"
             )

-        # Cast to a NumPy array first, in case the Series is a pandas Extension
-        # array (which doesn't have a valid NumPy dtype)
-        # TODO: allow users to control how this casting happens, e.g., by
-        # forwarding arguments to pandas.Series.to_numpy?
-        arrays = [(k, np.asarray(v)) for k, v in dataframe.items()]
+        arrays = []
+        extension_arrays = []
+        for k, v in dataframe.items():
+            if not is_extension_array_dtype(v):
+                arrays.append((k, np.asarray(v)))
+            else:
+                extension_arrays.append((k, v))

         indexes: dict[Hashable, Index] = {}
         index_vars: dict[Hashable, Variable] = {}

@@ -7340,6 +7370,8 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self:
                 xr_idx = PandasIndex(lev, dim)
                 indexes[dim] = xr_idx
                 index_vars.update(xr_idx.create_variables())
+            arrays += [(k, np.asarray(v)) for k, v in extension_arrays]
+            extension_arrays = []
         else:
             index_name = idx.name if idx.name is not None else "index"
             dims = (index_name,)

@@ -7353,7 +7385,9 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self:
             obj._set_sparse_data_from_dataframe(idx, arrays, dims)
         else:
             obj._set_numpy_data_from_dataframe(idx, arrays, dims)
-        return obj
+        for name, extension_array in extension_arrays:
+            obj[name] = (dims, extension_array)
+        return obj[dataframe.columns] if len(dataframe.columns) else obj

     def to_dask_dataframe(
         self, dim_order: Sequence[Hashable] | None = None, set_index: bool = False
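Taken together, the `from_dataframe` and `_to_dataframe` changes aim to let extension-backed columns survive a round trip. A hedged sketch of the intended behaviour (illustrative data, not taken from the test suite):

    import pandas as pd
    import xarray as xr

    df = pd.DataFrame(
        {"rank": [1, 2, 3], "grade": pd.Categorical(["a", "b", "a"])},
        index=pd.Index(["x", "y", "z"], name="label"),
    )

    ds = xr.Dataset.from_dataframe(df)
    # "grade" should be stored as a 1D extension array; "rank" still goes
    # through numpy as before.

    roundtripped = ds.to_dataframe()
    # Column order is preserved and "grade" is expected to come back as a
    # pandas Categorical rather than an object column.
    print(roundtripped.dtypes)
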
xarray/core/datatree.py
Lines changed: 8 additions & 8 deletions

@@ -18,6 +18,14 @@
 from xarray.core.coordinates import DatasetCoordinates
 from xarray.core.dataarray import DataArray
 from xarray.core.dataset import Dataset, DataVariables
+from xarray.core.datatree_mapping import (
+    TreeIsomorphismError,
+    check_isomorphic,
+    map_over_subtree,
+)
+from xarray.core.formatting_html import (
+    datatree_repr as datatree_repr_html,
+)
 from xarray.core.indexes import Index, Indexes
 from xarray.core.merge import dataset_update_method
 from xarray.core.options import OPTIONS as XR_OPTS

@@ -33,14 +41,6 @@
 from xarray.core.variable import Variable
 from xarray.datatree_.datatree.common import TreeAttrAccessMixin
 from xarray.datatree_.datatree.formatting import datatree_repr
-from xarray.datatree_.datatree.formatting_html import (
-    datatree_repr as datatree_repr_html,
-)
-from xarray.datatree_.datatree.mapping import (
-    TreeIsomorphismError,
-    check_isomorphic,
-    map_over_subtree,
-)
 from xarray.datatree_.datatree.ops import (
     DataTreeArithmeticMixin,
     MappedDatasetMethodsMixin,

xarray/datatree_/datatree/mapping.py renamed to xarray/core/datatree_mapping.py
Lines changed: 16 additions & 17 deletions

@@ -4,10 +4,9 @@
 import sys
 from itertools import repeat
 from textwrap import dedent
-from typing import TYPE_CHECKING, Callable, Tuple
+from typing import TYPE_CHECKING, Callable

 from xarray import DataArray, Dataset
-
 from xarray.core.iterators import LevelOrderIter
 from xarray.core.treenode import NodePath, TreeNode


@@ -84,14 +83,13 @@ def diff_treestructure(a: DataTree, b: DataTree, require_names_equal: bool) -> str:
     for node_a, node_b in zip(LevelOrderIter(a), LevelOrderIter(b)):
         path_a, path_b = node_a.path, node_b.path

-        if require_names_equal:
-            if node_a.name != node_b.name:
-                diff = dedent(
-                    f"""\
+        if require_names_equal and node_a.name != node_b.name:
+            diff = dedent(
+                f"""\
                 Node '{path_a}' in the left object has name '{node_a.name}'
                 Node '{path_b}' in the right object has name '{node_b.name}'"""
-                )
-                return diff
+            )
+            return diff

         if len(node_a.children) != len(node_b.children):
             diff = dedent(

@@ -125,7 +123,7 @@ def map_over_subtree(func: Callable) -> Callable:
     func : callable
         Function to apply to datasets with signature:

-        `func(*args, **kwargs) -> Union[Dataset, Iterable[Dataset]]`.
+        `func(*args, **kwargs) -> Union[DataTree, Iterable[DataTree]]`.

        (i.e. func must accept at least one Dataset and return at least one Dataset.)
        Function will not be applied to any nodes without datasets.

@@ -154,7 +152,7 @@ def map_over_subtree(func: Callable) -> Callable:
     # TODO inspect function to work out immediately if the wrong number of arguments were passed for it?

     @functools.wraps(func)
-    def _map_over_subtree(*args, **kwargs) -> DataTree | Tuple[DataTree, ...]:
+    def _map_over_subtree(*args, **kwargs) -> DataTree | tuple[DataTree, ...]:
         """Internal function which maps func over every node in tree, returning a tree of the results."""
         from xarray.core.datatree import DataTree

@@ -259,19 +257,18 @@ def _map_over_subtree(*args, **kwargs) -> DataTree | Tuple[DataTree, ...]:
     return _map_over_subtree


-def _handle_errors_with_path_context(path):
+def _handle_errors_with_path_context(path: str):
     """Wraps given function so that if it fails it also raises path to node on which it failed."""

     def decorator(func):
         def wrapper(*args, **kwargs):
             try:
                 return func(*args, **kwargs)
             except Exception as e:
-                if sys.version_info >= (3, 11):
-                    # Add the context information to the error message
-                    e.add_note(
-                        f"Raised whilst mapping function over node with path {path}"
-                    )
+                # Add the context information to the error message
+                add_note(
+                    e, f"Raised whilst mapping function over node with path {path}"
+                )
                 raise

         return wrapper

@@ -287,7 +284,9 @@ def add_note(err: BaseException, msg: str) -> None:
         err.add_note(msg)


-def _check_single_set_return_values(path_to_node, obj):
+def _check_single_set_return_values(
+    path_to_node: str, obj: Dataset | DataArray | tuple[Dataset | DataArray]
+):
     """Check types returned from single evaluation of func, and return number of return values received from func."""
     if isinstance(obj, (Dataset, DataArray)):
         return 1
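
The error-handling rewrite above funnels everything through the module's `add_note` helper, whose 3.11 branch is visible in the final hunk. A self-contained sketch of that pattern, assuming the helper falls back to stashing the message on `err.__notes__` for older interpreters (the exact fallback is not shown in this diff):

    import sys


    def add_note(err: BaseException, msg: str) -> None:
        """Attach a note to an exception, tolerating interpreters before 3.11."""
        if sys.version_info >= (3, 11):
            err.add_note(msg)
        else:
            # BaseException.add_note is new in 3.11; store the note on the same
            # attribute so nothing is silently dropped on older versions.
            err.__notes__ = getattr(err, "__notes__", []) + [msg]


    def _handle_errors_with_path_context(path: str):
        """Wrap a function so failures also report the tree path being processed."""

        def decorator(func):
            def wrapper(*args, **kwargs):
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    add_note(e, f"Raised whilst mapping function over node with path {path}")
                    raise

            return wrapper

        return decorator

Wrapping a node-level function this way makes any traceback carry the offending node's path, which is how `_map_over_subtree` reports failures per node.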

xarray/core/duck_array_ops.py
Lines changed: 15 additions & 4 deletions

@@ -32,6 +32,7 @@
 from numpy import concatenate as _concatenate
 from numpy.lib.stride_tricks import sliding_window_view  # noqa
 from packaging.version import Version
+from pandas.api.types import is_extension_array_dtype

 from xarray.core import dask_array_ops, dtypes, nputils
 from xarray.core.options import OPTIONS

@@ -156,7 +157,7 @@ def isnull(data):
         return full_like(data, dtype=bool, fill_value=False)
     else:
         # at this point, array should have dtype=object
-        if isinstance(data, np.ndarray):
+        if isinstance(data, np.ndarray) or is_extension_array_dtype(data):
            return pandas_isnull(data)
        else:
            # Not reachable yet, but intended for use with other duck array

@@ -221,9 +222,19 @@ def asarray(data, xp=np):

 def as_shared_dtype(scalars_or_arrays, xp=np):
     """Cast arrays to a shared dtype using xarray's type promotion rules."""
-    array_type_cupy = array_type("cupy")
-    if array_type_cupy and any(
-        isinstance(x, array_type_cupy) for x in scalars_or_arrays
+    if any(is_extension_array_dtype(x) for x in scalars_or_arrays):
+        extension_array_types = [
+            x.dtype for x in scalars_or_arrays if is_extension_array_dtype(x)
+        ]
+        if len(extension_array_types) == len(scalars_or_arrays) and all(
+            isinstance(x, type(extension_array_types[0])) for x in extension_array_types
+        ):
+            return scalars_or_arrays
+        raise ValueError(
+            f"Cannot cast arrays to shared type, found array types {[x.dtype for x in scalars_or_arrays]}"
+        )
+    elif (array_type_cupy := array_type("cupy")) and any(
+        isinstance(x, array_type_cupy) for x in scalars_or_arrays
     ):
         import cupy as cp
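A hedged illustration of the new early exit in `as_shared_dtype` (the internal helper is imported here purely for demonstration): a list made up entirely of arrays sharing one extension dtype passes through untouched, while mixing an extension array with anything else raises instead of silently densifying.

    import numpy as np
    import pandas as pd

    from xarray.core.duck_array_ops import as_shared_dtype  # internal helper

    cat_a = pd.Categorical(["x", "y"])
    cat_b = pd.Categorical(["y", "z"])

    # Same extension dtype class everywhere: returned unchanged, no casting.
    print(as_shared_dtype([cat_a, cat_b]))

    # Extension array mixed with a plain numpy array: no shared dtype exists
    # without densifying, so a ValueError is expected.
    try:
        as_shared_dtype([cat_a, np.array([1, 2])])
    except ValueError as err:
        print(err)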