Skip to content

Commit 9ed7505

Browse files
authored
Ensure inherited methods return UXarray objects (#1172)
* add tests for inheritance and update where * use cleaner approach to wrap * add concat method, add dataset sel * clean up
1 parent da45ee7 commit 9ed7505

File tree

6 files changed

+333
-54
lines changed

6 files changed

+333
-54
lines changed

test/test_inheritance.py

Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
import uxarray as ux
2+
import pytest
3+
import numpy as np
4+
from pathlib import Path
5+
import xarray as xr
6+
7+
@pytest.fixture
8+
def uxds_fixture():
9+
"""Fixture to load test dataset."""
10+
current_path = Path(__file__).resolve().parent
11+
quad_hex_grid_path = current_path / 'meshfiles' / "ugrid" / "quad-hexagon" / 'grid.nc'
12+
quad_hex_data_path = current_path / 'meshfiles' / "ugrid" / "quad-hexagon" / 'data.nc'
13+
14+
# Load the dataset
15+
ds = ux.open_dataset(quad_hex_grid_path, quad_hex_data_path)
16+
17+
# Add a dummy coordinate
18+
if 'n_face' in ds.dims:
19+
n_face_size = ds.dims['n_face']
20+
ds = ds.assign_coords(face_id=('n_face', np.arange(n_face_size)))
21+
22+
return ds
23+
24+
def test_isel(uxds_fixture):
25+
"""Test that isel method preserves UxDataset type."""
26+
result = uxds_fixture.isel(n_face=[1, 2])
27+
assert isinstance(result, ux.UxDataset)
28+
assert hasattr(result, 'uxgrid')
29+
30+
def test_where(uxds_fixture):
31+
"""Test that where method preserves UxDataset type."""
32+
result = uxds_fixture.where(uxds_fixture['t2m'] > uxds_fixture['t2m'].min())
33+
assert isinstance(result, ux.UxDataset)
34+
assert hasattr(result, 'uxgrid')
35+
36+
def test_assign(uxds_fixture):
37+
"""Test that assign method preserves UxDataset type."""
38+
# Create a new variable based on t2m
39+
new_var = xr.DataArray(
40+
np.ones_like(uxds_fixture['t2m']),
41+
dims=uxds_fixture['t2m'].dims
42+
)
43+
result = uxds_fixture.assign(new_var=new_var)
44+
assert isinstance(result, ux.UxDataset)
45+
assert hasattr(result, 'uxgrid')
46+
assert result.uxgrid is uxds_fixture.uxgrid
47+
assert 'new_var' in result.data_vars
48+
49+
50+
def test_drop_vars(uxds_fixture):
51+
"""Test that drop_vars method preserves UxDataset type."""
52+
# Create a copy with a new variable so we can drop it
53+
ds_copy = uxds_fixture.copy(deep=True)
54+
ds_copy['t2m_copy'] = ds_copy['t2m'].copy()
55+
result = ds_copy.drop_vars('t2m_copy')
56+
assert isinstance(result, ux.UxDataset)
57+
assert hasattr(result, 'uxgrid')
58+
assert result.uxgrid is ds_copy.uxgrid
59+
assert 't2m_copy' not in result.data_vars
60+
61+
def test_transpose(uxds_fixture):
62+
"""Test that transpose method preserves UxDataset type."""
63+
# Get all dimensions
64+
dims = list(uxds_fixture.dims)
65+
if len(dims) > 1:
66+
# Reverse dimensions for transpose
67+
reversed_dims = dims.copy()
68+
reversed_dims.reverse()
69+
result = uxds_fixture.transpose(*reversed_dims)
70+
assert isinstance(result, ux.UxDataset)
71+
assert hasattr(result, 'uxgrid')
72+
assert result.uxgrid is uxds_fixture.uxgrid
73+
74+
def test_fillna(uxds_fixture):
75+
"""Test that fillna method preserves UxDataset type."""
76+
# Create a copy with some NaN values in t2m
77+
ds_with_nans = uxds_fixture.copy(deep=True)
78+
t2m_data = ds_with_nans['t2m'].values
79+
if t2m_data.size > 0:
80+
t2m_data.ravel()[0:2] = np.nan
81+
result = ds_with_nans.fillna(0)
82+
assert isinstance(result, ux.UxDataset)
83+
assert hasattr(result, 'uxgrid')
84+
# Verify NaNs were filled
85+
assert not np.isnan(result['t2m'].values).any()
86+
87+
def test_rename(uxds_fixture):
88+
"""Test that rename method preserves UxDataset type."""
89+
result = uxds_fixture.rename({'t2m': 't2m_renamed'})
90+
assert isinstance(result, ux.UxDataset)
91+
assert hasattr(result, 'uxgrid')
92+
assert result.uxgrid is uxds_fixture.uxgrid
93+
assert 't2m_renamed' in result.data_vars
94+
assert 't2m' not in result.data_vars
95+
96+
def test_to_array(uxds_fixture):
97+
"""Test that to_array method preserves UxDataArray type for its result."""
98+
# Create a dataset with multiple variables to test to_array
99+
ds_multi = uxds_fixture.copy(deep=True)
100+
ds_multi['t2m_celsius'] = ds_multi['t2m']
101+
ds_multi['t2m_kelvin'] = ds_multi['t2m'] + 273.15
102+
103+
result = ds_multi.to_array()
104+
assert isinstance(result, ux.UxDataArray)
105+
assert hasattr(result, 'uxgrid')
106+
assert result.uxgrid == uxds_fixture.uxgrid
107+
108+
def test_arithmetic_operations(uxds_fixture):
109+
"""Test arithmetic operations preserve UxDataset type."""
110+
# Test addition
111+
result = uxds_fixture['t2m'] + 1
112+
assert isinstance(result, ux.UxDataArray)
113+
assert hasattr(result, 'uxgrid')
114+
115+
# Test dataset-level operations
116+
result = uxds_fixture * 2
117+
assert isinstance(result, ux.UxDataset)
118+
assert hasattr(result, 'uxgrid')
119+
120+
# Test more complex operations
121+
result = uxds_fixture.copy(deep=True)
122+
result['t2m_squared'] = uxds_fixture['t2m'] ** 2
123+
assert isinstance(result, ux.UxDataset)
124+
assert hasattr(result, 'uxgrid')
125+
assert 't2m_squared' in result.data_vars
126+
127+
def test_reduction_methods(uxds_fixture):
128+
"""Test reduction methods preserve UxDataset type when dimensions remain."""
129+
if len(uxds_fixture.dims) > 1:
130+
# Get a dimension to reduce over
131+
dim_to_reduce = list(uxds_fixture.dims)[0]
132+
133+
# Test mean
134+
result = uxds_fixture.mean(dim=dim_to_reduce)
135+
assert isinstance(result, ux.UxDataset)
136+
assert hasattr(result, 'uxgrid')
137+
138+
# Test sum on specific variable
139+
result = uxds_fixture['t2m'].sum(dim=dim_to_reduce)
140+
assert isinstance(result, ux.UxDataArray)
141+
assert hasattr(result, 'uxgrid')
142+
143+
def test_groupby(uxds_fixture):
144+
"""Test that groupby operations preserve UxDataset type."""
145+
# Use face_id coordinate for grouping
146+
if 'face_id' in uxds_fixture.coords:
147+
# Create a discrete grouping variable
148+
grouper = uxds_fixture['face_id'] % 2 # Group by even/odd
149+
uxds_fixture = uxds_fixture.assign_coords(parity=grouper)
150+
groups = uxds_fixture.groupby('parity')
151+
result = groups.mean()
152+
assert isinstance(result, ux.UxDataset)
153+
assert hasattr(result, 'uxgrid')
154+
155+
156+
def test_assign_coords(uxds_fixture):
157+
"""Test that assign_coords preserves UxDataset type."""
158+
if 'n_face' in uxds_fixture.dims:
159+
dim = 'n_face'
160+
size = uxds_fixture.dims[dim]
161+
# Create a coordinate that's different from face_id
162+
new_coord = xr.DataArray(np.arange(size) * 10, dims=[dim])
163+
result = uxds_fixture.assign_coords(scaled_id=new_coord)
164+
assert isinstance(result, ux.UxDataset)
165+
assert hasattr(result, 'uxgrid')
166+
assert 'scaled_id' in result.coords
167+
168+
169+
def test_expand_dims(uxds_fixture):
170+
"""Test that expand_dims preserves UxDataset type."""
171+
result = uxds_fixture.expand_dims(dim='time')
172+
assert isinstance(result, ux.UxDataset)
173+
assert hasattr(result, 'uxgrid')
174+
assert 'time' in result.dims
175+
assert result.dims['time'] == 1
176+
177+
# Verify data variable shape was updated correctly
178+
assert result['t2m'].shape[0] == 1
179+
180+
def test_method_chaining(uxds_fixture):
181+
"""Test that methods can be chained while preserving UxDataset type."""
182+
# Chain multiple operations
183+
result = (uxds_fixture
184+
.assign(t2m_kelvin=uxds_fixture['t2m'] + 273.15)
185+
.rename({'t2m': 't2m_celsius'})
186+
.fillna(0))
187+
assert isinstance(result, ux.UxDataset)
188+
assert hasattr(result, 'uxgrid')
189+
assert 't2m_celsius' in result.data_vars
190+
assert 't2m_kelvin' in result.data_vars
191+
192+
193+
def test_stack_unstack(uxds_fixture):
194+
"""Test that stack and unstack preserve UxDataset type."""
195+
if len(uxds_fixture.dims) >= 2:
196+
# Get two dimensions to stack
197+
dims = list(uxds_fixture.dims)[:2]
198+
# Stack the dimensions
199+
stacked_name = f"{dims[0]}_{dims[1]}"
200+
stacked = uxds_fixture.stack({stacked_name: dims})
201+
assert isinstance(stacked, ux.UxDataset)
202+
assert hasattr(stacked, 'uxgrid')
203+
204+
# Unstack them
205+
unstacked = stacked.unstack(stacked_name)
206+
assert isinstance(unstacked, ux.UxDataset)
207+
assert hasattr(unstacked, 'uxgrid')
208+
209+
210+
def test_sortby(uxds_fixture):
211+
"""Test that sortby preserves UxDataset type."""
212+
if 'face_id' in uxds_fixture.coords:
213+
# Create a reverse sorted coordinate
214+
size = len(uxds_fixture.face_id)
215+
uxds_fixture = uxds_fixture.assign_coords(reverse_id=('n_face', np.arange(size)[::-1]))
216+
217+
# Sort by this new coordinate
218+
result = uxds_fixture.sortby('reverse_id')
219+
assert isinstance(result, ux.UxDataset)
220+
assert hasattr(result, 'uxgrid')
221+
# Verify sorting changed the order
222+
assert np.array_equal(result.face_id.values, np.sort(uxds_fixture.face_id.values)[::-1])
223+
224+
def test_shift(uxds_fixture):
225+
"""Test that shift preserves UxDataset type."""
226+
if 'n_face' in uxds_fixture.dims:
227+
result = uxds_fixture.shift(n_face=1)
228+
assert isinstance(result, ux.UxDataset)
229+
assert hasattr(result, 'uxgrid')
230+
# Verify data has shifted (first element now NaN)
231+
assert np.isnan(result['t2m'].isel(n_face=0).values.item())

uxarray/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .core.api import open_grid, open_dataset, open_mfdataset
1+
from .core.api import open_grid, open_dataset, open_mfdataset, concat
22

33
from .core.dataset import UxDataset
44
from .core.dataarray import UxDataArray
@@ -23,6 +23,7 @@
2323
"open_grid",
2424
"open_dataset",
2525
"open_mfdataset",
26+
"concat",
2627
"UxDataset",
2728
"UxDataArray",
2829
"INT_DTYPE",

uxarray/core/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from .dataset import UxDataset
44
from .dataarray import UxDataArray
55

6+
67
__all__ = (
78
"open_grid",
89
"open_dataset",

uxarray/core/api.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,23 @@
1-
"""UXarray dataset module."""
2-
31
import os
42
import numpy as np
53
import xarray as xr
64

7-
from typing import Any, Dict, Optional, Union
5+
from collections.abc import Hashable, Iterable
6+
from typing import Any, Dict, Optional, Union, TYPE_CHECKING
87

98
from uxarray.grid import Grid
109
from uxarray.core.dataset import UxDataset
1110
from uxarray.core.utils import _map_dims_to_ugrid
1211

1312
from warnings import warn
1413

14+
if TYPE_CHECKING:
15+
from xarray.core.types import (
16+
ConcatOptions,
17+
)
18+
19+
T_DataVars = Union[ConcatOptions, Iterable[Hashable]]
20+
1521

1622
def open_grid(
1723
grid_filename_or_obj: Union[
@@ -291,3 +297,30 @@ def open_mfdataset(
291297
uxds = UxDataset(ds, uxgrid=uxgrid, source_datasets=str(paths))
292298

293299
return uxds
300+
301+
302+
def concat(objs, *args, **kwargs):
303+
# Ensure there is at least one object to concat.
304+
if not objs:
305+
raise ValueError("No objects provided for concatenation.")
306+
307+
ref_uxgrid = getattr(objs[0], "uxgrid", None)
308+
if ref_uxgrid is None:
309+
raise AttributeError("The first object does not have a 'uxgrid' attribute.")
310+
311+
ref_id = id(ref_uxgrid)
312+
313+
for i, obj in enumerate(objs):
314+
uxgrid = getattr(obj, "uxgrid", None)
315+
if uxgrid is None:
316+
raise AttributeError(
317+
f"Object at index {i} does not have a 'uxgrid' attribute."
318+
)
319+
if id(uxgrid) != ref_id:
320+
raise ValueError(f"Object at index {i} has a different 'uxgrid' attribute.")
321+
322+
res = xr.concat(objs, *args, **kwargs)
323+
return UxDataset(res, uxgrid=uxgrid)
324+
325+
326+
concat.__doc__ = xr.concat.__doc__

uxarray/core/dataarray.py

Lines changed: 31 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,38 @@
11
from __future__ import annotations
22

3+
import warnings
4+
from html import escape
5+
from typing import Hashable, Literal, Optional, TYPE_CHECKING, Any
6+
from warnings import warn
37

4-
import xarray as xr
8+
import cartopy.crs as ccrs
59
import numpy as np
6-
7-
8-
from typing import TYPE_CHECKING, Optional, Hashable, Literal
10+
import xarray as xr
911

1012
import uxarray
11-
from uxarray.formatting_html import array_repr
12-
13-
from html import escape
14-
15-
from xarray.core.options import OPTIONS
16-
17-
from uxarray.grid import Grid
18-
import uxarray.core.dataset
19-
from uxarray.grid.dual import construct_dual
20-
from uxarray.grid.validation import _check_duplicate_nodes_indices
21-
22-
if TYPE_CHECKING:
23-
from uxarray.core.dataset import UxDataset
24-
25-
from xarray.core.utils import UncachedAccessor
26-
27-
13+
from uxarray.core.aggregation import _uxda_grid_aggregate
2814
from uxarray.core.gradient import (
29-
_calculate_grad_on_edge_from_faces,
3015
_calculate_edge_face_difference,
3116
_calculate_edge_node_difference,
17+
_calculate_grad_on_edge_from_faces,
3218
)
33-
34-
from uxarray.plot.accessor import UxDataArrayPlotAccessor
35-
from uxarray.subset import DataArraySubsetAccessor
36-
from uxarray.remap import UxDataArrayRemapAccessor
37-
from uxarray.cross_sections import UxDataArrayCrossSectionAccessor
38-
from uxarray.core.aggregation import _uxda_grid_aggregate
3919
from uxarray.core.utils import _map_dims_to_ugrid
4020
from uxarray.core.zonal import _compute_non_conservative_zonal_mean
21+
from uxarray.cross_sections import UxDataArrayCrossSectionAccessor
22+
from uxarray.formatting_html import array_repr
23+
from uxarray.grid import Grid
24+
from uxarray.grid.dual import construct_dual
25+
from uxarray.grid.validation import _check_duplicate_nodes_indices
26+
from uxarray.plot.accessor import UxDataArrayPlotAccessor
27+
from uxarray.remap import UxDataArrayRemapAccessor
28+
from uxarray.subset import DataArraySubsetAccessor
4129

42-
import warnings
43-
from warnings import warn
30+
from xarray.core.options import OPTIONS
31+
from xarray.core.utils import UncachedAccessor
32+
from xarray.core import dtypes
4433

45-
import cartopy.crs as ccrs
34+
if TYPE_CHECKING:
35+
from uxarray.core.dataset import UxDataset
4636

4737

4838
class UxDataArray(xr.DataArray):
@@ -1333,3 +1323,13 @@ def get_dual(self):
13331323
uxda = uxarray.UxDataArray(uxgrid=dual, data=data, dims=dims, name=self.name)
13341324

13351325
return uxda
1326+
1327+
def where(self, cond: Any, other: Any = dtypes.NA, drop: bool = False):
1328+
return UxDataArray(super().where(cond, other, drop), uxgrid=self.uxgrid)
1329+
1330+
where.__doc__ = xr.DataArray.where.__doc__
1331+
1332+
def fillna(self, value: Any):
1333+
return UxDataArray(super().fillna(value), uxgrid=self.uxgrid)
1334+
1335+
fillna.__doc__ = xr.DataArray.fillna.__doc__

0 commit comments

Comments
 (0)