Skip to content

Commit 50ea159

Browse files
authored
Support of repr and deepcopy of recursive arrays (#7112)
* allow recursive attrs in formatting (repr) * test recursion in formatting * support for deepcopy of recursive arrays
1 parent 114bf98 commit 50ea159

File tree

11 files changed

+176
-44
lines changed

11 files changed

+176
-44
lines changed

doc/whats-new.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ Deprecations
3333

3434
Bug fixes
3535
~~~~~~~~~
36+
37+
- Support for recursively defined Arrays. Fixes repr and deepcopy. (:issue:`7111`, :pull:`7112`)
38+
By `Michael Niklas <https://github.com/headtr1ck>`_.
3639
- Fixed :py:meth:`Dataset.transpose` to raise a more informative error. (:issue:`6502`, :pull:`7120`)
3740
By `Patrick Naylor <https://github.com/patrick-naylor>`_
3841

xarray/core/dataarray.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -516,7 +516,7 @@ def _overwrite_indexes(
516516
new_indexes.pop(name)
517517

518518
if rename_dims:
519-
new_variable.dims = [rename_dims.get(d, d) for d in new_variable.dims]
519+
new_variable.dims = tuple(rename_dims.get(d, d) for d in new_variable.dims)
520520

521521
return self._replace(
522522
variable=new_variable, coords=new_coords, indexes=new_indexes
@@ -1169,25 +1169,33 @@ def copy(self: T_DataArray, deep: bool = True, data: Any = None) -> T_DataArray:
11691169
--------
11701170
pandas.DataFrame.copy
11711171
"""
1172-
variable = self.variable.copy(deep=deep, data=data)
1172+
return self._copy(deep=deep, data=data)
1173+
1174+
def _copy(
1175+
self: T_DataArray,
1176+
deep: bool = True,
1177+
data: Any = None,
1178+
memo: dict[int, Any] | None = None,
1179+
) -> T_DataArray:
1180+
variable = self.variable._copy(deep=deep, data=data, memo=memo)
11731181
indexes, index_vars = self.xindexes.copy_indexes(deep=deep)
11741182

11751183
coords = {}
11761184
for k, v in self._coords.items():
11771185
if k in index_vars:
11781186
coords[k] = index_vars[k]
11791187
else:
1180-
coords[k] = v.copy(deep=deep)
1188+
coords[k] = v._copy(deep=deep, memo=memo)
11811189

11821190
return self._replace(variable, coords, indexes=indexes)
11831191

11841192
def __copy__(self: T_DataArray) -> T_DataArray:
1185-
return self.copy(deep=False)
1193+
return self._copy(deep=False)
11861194

1187-
def __deepcopy__(self: T_DataArray, memo=None) -> T_DataArray:
1188-
# memo does nothing but is required for compatibility with
1189-
# copy.deepcopy
1190-
return self.copy(deep=True)
1195+
def __deepcopy__(
1196+
self: T_DataArray, memo: dict[int, Any] | None = None
1197+
) -> T_DataArray:
1198+
return self._copy(deep=True, memo=memo)
11911199

11921200
# mutable objects should not be Hashable
11931201
# https://github.com/python/mypy/issues/4266

xarray/core/dataset.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1221,6 +1221,14 @@ def copy(
12211221
--------
12221222
pandas.DataFrame.copy
12231223
"""
1224+
return self._copy(deep=deep, data=data)
1225+
1226+
def _copy(
1227+
self: T_Dataset,
1228+
deep: bool = False,
1229+
data: Mapping[Any, ArrayLike] | None = None,
1230+
memo: dict[int, Any] | None = None,
1231+
) -> T_Dataset:
12241232
if data is None:
12251233
data = {}
12261234
elif not utils.is_dict_like(data):
@@ -1249,13 +1257,21 @@ def copy(
12491257
if k in index_vars:
12501258
variables[k] = index_vars[k]
12511259
else:
1252-
variables[k] = v.copy(deep=deep, data=data.get(k))
1260+
variables[k] = v._copy(deep=deep, data=data.get(k), memo=memo)
12531261

1254-
attrs = copy.deepcopy(self._attrs) if deep else copy.copy(self._attrs)
1255-
encoding = copy.deepcopy(self._encoding) if deep else copy.copy(self._encoding)
1262+
attrs = copy.deepcopy(self._attrs, memo) if deep else copy.copy(self._attrs)
1263+
encoding = (
1264+
copy.deepcopy(self._encoding, memo) if deep else copy.copy(self._encoding)
1265+
)
12561266

12571267
return self._replace(variables, indexes=indexes, attrs=attrs, encoding=encoding)
12581268

1269+
def __copy__(self: T_Dataset) -> T_Dataset:
1270+
return self._copy(deep=False)
1271+
1272+
def __deepcopy__(self: T_Dataset, memo: dict[int, Any] | None = None) -> T_Dataset:
1273+
return self._copy(deep=True, memo=memo)
1274+
12591275
def as_numpy(self: T_Dataset) -> T_Dataset:
12601276
"""
12611277
Coerces wrapped data and coordinates into numpy arrays, returning a Dataset.
@@ -1332,14 +1348,6 @@ def _construct_dataarray(self, name: Hashable) -> DataArray:
13321348

13331349
return DataArray(variable, coords, name=name, indexes=indexes, fastpath=True)
13341350

1335-
def __copy__(self: T_Dataset) -> T_Dataset:
1336-
return self.copy(deep=False)
1337-
1338-
def __deepcopy__(self: T_Dataset, memo=None) -> T_Dataset:
1339-
# memo does nothing but is required for compatibility with
1340-
# copy.deepcopy
1341-
return self.copy(deep=True)
1342-
13431351
@property
13441352
def _attr_sources(self) -> Iterable[Mapping[Hashable, Any]]:
13451353
"""Places to look-up items for attribute-style access"""

xarray/core/formatting.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from collections import defaultdict
99
from datetime import datetime, timedelta
1010
from itertools import chain, zip_longest
11+
from reprlib import recursive_repr
1112
from typing import Collection, Hashable
1213

1314
import numpy as np
@@ -385,7 +386,6 @@ def _mapping_repr(
385386
expand_option_name="display_expand_data_vars",
386387
)
387388

388-
389389
attrs_repr = functools.partial(
390390
_mapping_repr,
391391
title="Attributes",
@@ -551,6 +551,7 @@ def short_data_repr(array):
551551
return f"[{array.size} values with dtype={array.dtype}]"
552552

553553

554+
@recursive_repr("<recursive array>")
554555
def array_repr(arr):
555556
from .variable import Variable
556557

@@ -592,11 +593,12 @@ def array_repr(arr):
592593
summary.append(unindexed_dims_str)
593594

594595
if arr.attrs:
595-
summary.append(attrs_repr(arr.attrs))
596+
summary.append(attrs_repr(arr.attrs, max_rows=max_rows))
596597

597598
return "\n".join(summary)
598599

599600

601+
@recursive_repr("<recursive Dataset>")
600602
def dataset_repr(ds):
601603
summary = [f"<xarray.{type(ds).__name__}>"]
602604

xarray/core/utils.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -161,16 +161,13 @@ def equivalent(first: T, second: T) -> bool:
161161
# TODO: refactor to avoid circular import
162162
from . import duck_array_ops
163163

164+
if first is second:
165+
return True
164166
if isinstance(first, np.ndarray) or isinstance(second, np.ndarray):
165167
return duck_array_ops.array_equiv(first, second)
166-
elif isinstance(first, list) or isinstance(second, list):
168+
if isinstance(first, list) or isinstance(second, list):
167169
return list_equiv(first, second)
168-
else:
169-
return (
170-
(first is second)
171-
or (first == second)
172-
or (pd.isnull(first) and pd.isnull(second))
173-
)
170+
return (first == second) or (pd.isnull(first) and pd.isnull(second))
174171

175172

176173
def list_equiv(first, second):

xarray/core/variable.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -918,7 +918,9 @@ def encoding(self, value):
918918
except ValueError:
919919
raise ValueError("encoding must be castable to a dictionary")
920920

921-
def copy(self, deep: bool = True, data: ArrayLike | None = None):
921+
def copy(
922+
self: T_Variable, deep: bool = True, data: ArrayLike | None = None
923+
) -> T_Variable:
922924
"""Returns a copy of this object.
923925
924926
If `deep=True`, the data array is loaded into memory and copied onto
@@ -974,6 +976,14 @@ def copy(self, deep: bool = True, data: ArrayLike | None = None):
974976
--------
975977
pandas.DataFrame.copy
976978
"""
979+
return self._copy(deep=deep, data=data)
980+
981+
def _copy(
982+
self: T_Variable,
983+
deep: bool = True,
984+
data: ArrayLike | None = None,
985+
memo: dict[int, Any] | None = None,
986+
) -> T_Variable:
977987
if data is None:
978988
ndata = self._data
979989

@@ -982,7 +992,7 @@ def copy(self, deep: bool = True, data: ArrayLike | None = None):
982992
ndata = indexing.MemoryCachedArray(ndata.array)
983993

984994
if deep:
985-
ndata = copy.deepcopy(ndata)
995+
ndata = copy.deepcopy(ndata, memo)
986996

987997
else:
988998
ndata = as_compatible_data(data)
@@ -993,8 +1003,10 @@ def copy(self, deep: bool = True, data: ArrayLike | None = None):
9931003
)
9941004
)
9951005

996-
attrs = copy.deepcopy(self._attrs) if deep else copy.copy(self._attrs)
997-
encoding = copy.deepcopy(self._encoding) if deep else copy.copy(self._encoding)
1006+
attrs = copy.deepcopy(self._attrs, memo) if deep else copy.copy(self._attrs)
1007+
encoding = (
1008+
copy.deepcopy(self._encoding, memo) if deep else copy.copy(self._encoding)
1009+
)
9981010

9991011
# note: dims is already an immutable tuple
10001012
return self._replace(data=ndata, attrs=attrs, encoding=encoding)
@@ -1016,13 +1028,13 @@ def _replace(
10161028
encoding = copy.copy(self._encoding)
10171029
return type(self)(dims, data, attrs, encoding, fastpath=True)
10181030

1019-
def __copy__(self):
1020-
return self.copy(deep=False)
1031+
def __copy__(self: T_Variable) -> T_Variable:
1032+
return self._copy(deep=False)
10211033

1022-
def __deepcopy__(self, memo=None):
1023-
# memo does nothing but is required for compatibility with
1024-
# copy.deepcopy
1025-
return self.copy(deep=True)
1034+
def __deepcopy__(
1035+
self: T_Variable, memo: dict[int, Any] | None = None
1036+
) -> T_Variable:
1037+
return self._copy(deep=True, memo=memo)
10261038

10271039
# mutable objects should not be hashable
10281040
# https://github.com/python/mypy/issues/4266

xarray/tests/test_concat.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,9 @@ def test_concat_errors(self):
219219
concat([data, data], "new_dim", coords=["not_found"])
220220

221221
with pytest.raises(ValueError, match=r"global attributes not"):
222-
data0, data1 = deepcopy(split_data)
222+
# call deepcopy seperately to get unique attrs
223+
data0 = deepcopy(split_data[0])
224+
data1 = deepcopy(split_data[1])
223225
data1.attrs["foo"] = "bar"
224226
concat([data0, data1], "dim1", compat="identical")
225227
assert_identical(data, concat([data0, data1], "dim1", compat="equals"))

xarray/tests/test_dataarray.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6488,6 +6488,28 @@ def test_deepcopy_obj_array() -> None:
64886488
assert x0.values[0] is not x1.values[0]
64896489

64906490

6491+
def test_deepcopy_recursive() -> None:
6492+
# GH:issue:7111
6493+
6494+
# direct recursion
6495+
da = xr.DataArray([1, 2], dims=["x"])
6496+
da.attrs["other"] = da
6497+
6498+
# TODO: cannot use assert_identical on recursive Vars yet...
6499+
# lets just ensure that deep copy works without RecursionError
6500+
da.copy(deep=True)
6501+
6502+
# indirect recursion
6503+
da2 = xr.DataArray([5, 6], dims=["y"])
6504+
da.attrs["other"] = da2
6505+
da2.attrs["other"] = da
6506+
6507+
# TODO: cannot use assert_identical on recursive Vars yet...
6508+
# lets just ensure that deep copy works without RecursionError
6509+
da.copy(deep=True)
6510+
da2.copy(deep=True)
6511+
6512+
64916513
def test_clip(da: DataArray) -> None:
64926514
with raise_if_dask_computes():
64936515
result = da.clip(min=0.5)

xarray/tests/test_dataset.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6687,6 +6687,28 @@ def test_deepcopy_obj_array() -> None:
66876687
assert x0["foo"].values[0] is not x1["foo"].values[0]
66886688

66896689

6690+
def test_deepcopy_recursive() -> None:
6691+
# GH:issue:7111
6692+
6693+
# direct recursion
6694+
ds = xr.Dataset({"a": (["x"], [1, 2])})
6695+
ds.attrs["other"] = ds
6696+
6697+
# TODO: cannot use assert_identical on recursive Vars yet...
6698+
# lets just ensure that deep copy works without RecursionError
6699+
ds.copy(deep=True)
6700+
6701+
# indirect recursion
6702+
ds2 = xr.Dataset({"b": (["y"], [3, 4])})
6703+
ds.attrs["other"] = ds2
6704+
ds2.attrs["other"] = ds
6705+
6706+
# TODO: cannot use assert_identical on recursive Vars yet...
6707+
# lets just ensure that deep copy works without RecursionError
6708+
ds.copy(deep=True)
6709+
ds2.copy(deep=True)
6710+
6711+
66906712
def test_clip(ds) -> None:
66916713
result = ds.clip(min=0.5)
66926714
assert all((result.min(...) >= 0.5).values())

xarray/tests/test_formatting.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,24 @@ def test_array_repr_variable(self) -> None:
431431
with xr.set_options(display_expand_data=False):
432432
formatting.array_repr(var)
433433

434+
def test_array_repr_recursive(self) -> None:
435+
# GH:issue:7111
436+
437+
# direct recurion
438+
var = xr.Variable("x", [0, 1])
439+
var.attrs["x"] = var
440+
formatting.array_repr(var)
441+
442+
da = xr.DataArray([0, 1], dims=["x"])
443+
da.attrs["x"] = da
444+
formatting.array_repr(da)
445+
446+
# indirect recursion
447+
var.attrs["x"] = da
448+
da.attrs["x"] = var
449+
formatting.array_repr(var)
450+
formatting.array_repr(da)
451+
434452
@requires_dask
435453
def test_array_scalar_format(self) -> None:
436454
# Test numpy scalars:
@@ -615,6 +633,21 @@ def test__mapping_repr(display_max_rows, n_vars, n_attr) -> None:
615633
assert actual == expected
616634

617635

636+
def test__mapping_repr_recursive() -> None:
637+
# GH:issue:7111
638+
639+
# direct recursion
640+
ds = xr.Dataset({"a": [["x"], [1, 2, 3]]})
641+
ds.attrs["ds"] = ds
642+
formatting.dataset_repr(ds)
643+
644+
# indirect recursion
645+
ds2 = xr.Dataset({"b": [["y"], [1, 2, 3]]})
646+
ds.attrs["ds"] = ds2
647+
ds2.attrs["ds"] = ds
648+
formatting.dataset_repr(ds2)
649+
650+
618651
def test__element_formatter(n_elements: int = 100) -> None:
619652
expected = """\
620653
Dimensions without coordinates: dim_0: 3, dim_1: 3, dim_2: 3, dim_3: 3,

0 commit comments

Comments
 (0)