diff --git a/flox/xarray.py b/flox/xarray.py index 11cf706d4..e6108d1fd 100644 --- a/flox/xarray.py +++ b/flox/xarray.py @@ -250,6 +250,16 @@ def xarray_reduce( else: ds = obj._to_temp_dataset() + # These will need to be broadcast/reduced as data_vars + reset_non_dim_coords = [ + name + for name in ds._coord_names + if any(dim in ds._variables[name].dims for dim in grouper_dims) + and name not in maybe_drop + and name not in ds._indexes + ] + ds = ds.reset_coords(reset_non_dim_coords) + try: from xarray.indexes import PandasMultiIndex except ImportError: @@ -475,6 +485,7 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs): if all(d not in ds_broad[var].dims for d in dim_tuple): actual[var] = ds_broad[var] + actual = actual.set_coords(reset_non_dim_coords) for newdim in newdims: actual.coords[newdim.name] = newdim.values if newdim.is_scalar else np.array(newdim.values) @@ -524,7 +535,7 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs): obj, xr.Dataset ) # do not re-order dataarrays inside datasets actual[var] = _restore_dim_order( - actual[var], template, by_da[0], no_groupby_reorder=no_groupby_reorder + actual[var].variable, template, by_da[0], no_groupby_reorder=no_groupby_reorder ) if missing_dim: diff --git a/tests/test_xarray.py b/tests/test_xarray.py index 11b2e23cb..66ee62c15 100644 --- a/tests/test_xarray.py +++ b/tests/test_xarray.py @@ -749,3 +749,20 @@ def test_direct_reduction(func): with xr.set_options(use_flox=False): expected = getattr(data.groupby("x", squeeze=False), func)(**kwargs) xr.testing.assert_identical(expected, actual) + + +def test_non_dim_coords_with_core_dim(): + coords = {"a": ("x", [0, 0, 1, 1]), "b": ("y", [0, 0, 1, 1])} + square = xr.DataArray(np.arange(16).reshape(4, 4), coords=coords, dims=["x", "y"]) + actual = xarray_reduce(square, "a", "b", func="mean") + expected = xr.DataArray( + np.array([[2.5, 4.5], [10.5, 12.5]]), + dims=("a", "b"), + coords={"a": [0, 1], "b": [0, 1]}, + ) + xr.testing.assert_identical(actual, expected) + + actual = xarray_reduce(square, "x", "y", func="mean") + expected = square.astype(np.float64).copy() + expected["a"], expected["b"] = xr.broadcast(square.a, square.b) + xr.testing.assert_identical(actual, expected)