Skip to content

Commit a042ae6

Browse files
authored
Review (re)set_index (#6992)
* review reset_index + tests
  Restore old behavior, i.e.,
  - drop the multi-index dimension name (even if drop=False) unless reset_index still returns a multi-index
  - rename the level coordinate to the dimension name if the multi-index is reduced to a single index
  - drop the whole multi-index if its dimension coordinate is given as argument
  Fix IndexVariable -> Variable conversion
* reset_index: fix dropped dimension(s)
* reset_index: fix other tests
* review set_index
  - Convert the coordinates left unindexed from IndexVariable to Variable
  - Keep multi-index coordinates next to each other in a consistent order
* set_index with single index: preserve coord order
* update what's new
1 parent 45c0a11 commit a042ae6

File tree

6 files changed

+157
-49
lines changed

6 files changed

+157
-49
lines changed

doc/whats-new.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,9 @@ Bug fixes
6565
By `András Gunyhó <https://github.com/mgunyho>`_.
6666
- Avoid use of random numbers in `test_weighted.test_weighted_operations_nonequal_coords` (:issue:`6504`, :pull:`6961`).
6767
By `Luke Conibear <https://github.com/lukeconibear>`_.
68+
- Fix multiple regression issues with :py:meth:`Dataset.set_index` and
69+
:py:meth:`Dataset.reset_index` (:pull:`6992`)
70+
By `Benoît Bovy <https://github.com/benbovy>`_.
6871
- Raise a ``UserWarning`` when renaming a coordinate or a dimension creates a
6972
non-indexed dimension coordinate, and suggest that the user create an index
7073
either with ``swap_dims`` or ``set_index`` (:issue:`6607`, :pull:`6999`). By

xarray/core/dataset.py

Lines changed: 81 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -4026,10 +4026,11 @@ def set_index(
40264026
dim_coords = either_dict_or_kwargs(indexes, indexes_kwargs, "set_index")
40274027

40284028
new_indexes: dict[Hashable, Index] = {}
4029-
new_variables: dict[Hashable, IndexVariable] = {}
4030-
maybe_drop_indexes: list[Hashable] = []
4031-
drop_variables: list[Hashable] = []
4029+
new_variables: dict[Hashable, Variable] = {}
4030+
drop_indexes: set[Hashable] = set()
4031+
drop_variables: set[Hashable] = set()
40324032
replace_dims: dict[Hashable, Hashable] = {}
4033+
all_var_names: set[Hashable] = set()
40334034

40344035
for dim, _var_names in dim_coords.items():
40354036
if isinstance(_var_names, str) or not isinstance(_var_names, Sequence):
@@ -4044,16 +4045,19 @@ def set_index(
40444045
+ " variable(s) do not exist"
40454046
)
40464047

4047-
current_coord_names = self.xindexes.get_all_coords(dim, errors="ignore")
4048+
all_var_names.update(var_names)
4049+
drop_variables.update(var_names)
40484050

4049-
# drop any pre-existing index involved
4050-
maybe_drop_indexes += list(current_coord_names) + var_names
4051+
# drop any pre-existing index involved and its corresponding coordinates
4052+
index_coord_names = self.xindexes.get_all_coords(dim, errors="ignore")
4053+
all_index_coord_names = set(index_coord_names)
40514054
for k in var_names:
4052-
maybe_drop_indexes += list(
4055+
all_index_coord_names.update(
40534056
self.xindexes.get_all_coords(k, errors="ignore")
40544057
)
40554058

4056-
drop_variables += var_names
4059+
drop_indexes.update(all_index_coord_names)
4060+
drop_variables.update(all_index_coord_names)
40574061

40584062
if len(var_names) == 1 and (not append or dim not in self._indexes):
40594063
var_name = var_names[0]
@@ -4065,10 +4069,14 @@ def set_index(
40654069
)
40664070
idx = PandasIndex.from_variables({dim: var})
40674071
idx_vars = idx.create_variables({var_name: var})
4072+
4073+
# trick to preserve coordinate order in this case
4074+
if dim in self._coord_names:
4075+
drop_variables.remove(dim)
40684076
else:
40694077
if append:
40704078
current_variables = {
4071-
k: self._variables[k] for k in current_coord_names
4079+
k: self._variables[k] for k in index_coord_names
40724080
}
40734081
else:
40744082
current_variables = {}
@@ -4083,8 +4091,17 @@ def set_index(
40834091
new_indexes.update({k: idx for k in idx_vars})
40844092
new_variables.update(idx_vars)
40854093

4094+
# re-add deindexed coordinates (convert to base variables)
4095+
for k in drop_variables:
4096+
if (
4097+
k not in new_variables
4098+
and k not in all_var_names
4099+
and k in self._coord_names
4100+
):
4101+
new_variables[k] = self._variables[k].to_base_variable()
4102+
40864103
indexes_: dict[Any, Index] = {
4087-
k: v for k, v in self._indexes.items() if k not in maybe_drop_indexes
4104+
k: v for k, v in self._indexes.items() if k not in drop_indexes
40884105
}
40894106
indexes_.update(new_indexes)
40904107

@@ -4099,7 +4116,7 @@ def set_index(
40994116
new_dims = [replace_dims.get(d, d) for d in v.dims]
41004117
variables[k] = v._replace(dims=new_dims)
41014118

4102-
coord_names = self._coord_names - set(drop_variables) | set(new_variables)
4119+
coord_names = self._coord_names - drop_variables | set(new_variables)
41034120

41044121
return self._replace_with_new_dims(
41054122
variables, coord_names=coord_names, indexes=indexes_
@@ -4139,35 +4156,60 @@ def reset_index(
41394156
f"{tuple(invalid_coords)} are not coordinates with an index"
41404157
)
41414158

4142-
drop_indexes: list[Hashable] = []
4143-
drop_variables: list[Hashable] = []
4144-
replaced_indexes: list[PandasMultiIndex] = []
4159+
drop_indexes: set[Hashable] = set()
4160+
drop_variables: set[Hashable] = set()
4161+
seen: set[Index] = set()
41454162
new_indexes: dict[Hashable, Index] = {}
4146-
new_variables: dict[Hashable, IndexVariable] = {}
4163+
new_variables: dict[Hashable, Variable] = {}
4164+
4165+
def drop_or_convert(var_names):
4166+
if drop:
4167+
drop_variables.update(var_names)
4168+
else:
4169+
base_vars = {
4170+
k: self._variables[k].to_base_variable() for k in var_names
4171+
}
4172+
new_variables.update(base_vars)
41474173

41484174
for name in dims_or_levels:
41494175
index = self._indexes[name]
4150-
drop_indexes += list(self.xindexes.get_all_coords(name))
4151-
4152-
if isinstance(index, PandasMultiIndex) and name not in self.dims:
4153-
# special case for pd.MultiIndex (name is an index level):
4154-
# replace by a new index with dropped level(s) instead of just drop the index
4155-
if index not in replaced_indexes:
4156-
level_names = index.index.names
4157-
level_vars = {
4158-
k: self._variables[k]
4159-
for k in level_names
4160-
if k not in dims_or_levels
4161-
}
4162-
if level_vars:
4163-
idx = index.keep_levels(level_vars)
4164-
idx_vars = idx.create_variables(level_vars)
4165-
new_indexes.update({k: idx for k in idx_vars})
4166-
new_variables.update(idx_vars)
4167-
replaced_indexes.append(index)
41684176

4169-
if drop:
4170-
drop_variables.append(name)
4177+
if index in seen:
4178+
continue
4179+
seen.add(index)
4180+
4181+
idx_var_names = set(self.xindexes.get_all_coords(name))
4182+
drop_indexes.update(idx_var_names)
4183+
4184+
if isinstance(index, PandasMultiIndex):
4185+
# special case for pd.MultiIndex
4186+
level_names = index.index.names
4187+
keep_level_vars = {
4188+
k: self._variables[k]
4189+
for k in level_names
4190+
if k not in dims_or_levels
4191+
}
4192+
4193+
if index.dim not in dims_or_levels and keep_level_vars:
4194+
# do not drop the multi-index completely
4195+
# instead replace it by a new (multi-)index with dropped level(s)
4196+
idx = index.keep_levels(keep_level_vars)
4197+
idx_vars = idx.create_variables(keep_level_vars)
4198+
new_indexes.update({k: idx for k in idx_vars})
4199+
new_variables.update(idx_vars)
4200+
if not isinstance(idx, PandasMultiIndex):
4201+
# multi-index reduced to single index
4202+
# backward compatibility: unique level coordinate renamed to dimension
4203+
drop_variables.update(keep_level_vars)
4204+
drop_or_convert(
4205+
[k for k in level_names if k not in keep_level_vars]
4206+
)
4207+
else:
4208+
# always drop the multi-index dimension variable
4209+
drop_variables.add(index.dim)
4210+
drop_or_convert(level_names)
4211+
else:
4212+
drop_or_convert(idx_var_names)
41714213

41724214
indexes = {k: v for k, v in self._indexes.items() if k not in drop_indexes}
41734215
indexes.update(new_indexes)
@@ -4177,9 +4219,11 @@ def reset_index(
41774219
}
41784220
variables.update(new_variables)
41794221

4180-
coord_names = set(new_variables) | self._coord_names
4222+
coord_names = self._coord_names - drop_variables
41814223

4182-
return self._replace(variables, coord_names=coord_names, indexes=indexes)
4224+
return self._replace_with_new_dims(
4225+
variables, coord_names=coord_names, indexes=indexes
4226+
)
41834227

41844228
def reorder_levels(
41854229
self: T_Dataset,

xarray/core/indexes.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -717,8 +717,11 @@ def keep_levels(
717717
level_coords_dtype = {k: self.level_coords_dtype[k] for k in index.names}
718718
return self._replace(index, level_coords_dtype=level_coords_dtype)
719719
else:
720+
# backward compatibility: rename the level coordinate to the dimension name
720721
return PandasIndex(
721-
index, self.dim, coord_dtype=self.level_coords_dtype[index.name]
722+
index.rename(self.dim),
723+
self.dim,
724+
coord_dtype=self.level_coords_dtype[index.name],
722725
)
723726

724727
def reorder_levels(

xarray/tests/test_dataarray.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2007,7 +2007,6 @@ def test_set_index(self) -> None:
20072007
def test_reset_index(self) -> None:
20082008
indexes = [self.mindex.get_level_values(n) for n in self.mindex.names]
20092009
coords = {idx.name: ("x", idx) for idx in indexes}
2010-
coords["x"] = ("x", self.mindex.values)
20112010
expected = DataArray(self.mda.values, coords=coords, dims="x")
20122011

20132012
obj = self.mda.reset_index("x")
@@ -2018,16 +2017,19 @@ def test_reset_index(self) -> None:
20182017
assert len(obj.xindexes) == 0
20192018
obj = self.mda.reset_index(["x", "level_1"])
20202019
assert_identical(obj, expected, check_default_indexes=False)
2021-
assert list(obj.xindexes) == ["level_2"]
2020+
assert len(obj.xindexes) == 0
20222021

2022+
coords = {
2023+
"x": ("x", self.mindex.droplevel("level_1")),
2024+
"level_1": ("x", self.mindex.get_level_values("level_1")),
2025+
}
20232026
expected = DataArray(self.mda.values, coords=coords, dims="x")
20242027
obj = self.mda.reset_index(["level_1"])
20252028
assert_identical(obj, expected, check_default_indexes=False)
2026-
assert list(obj.xindexes) == ["level_2"]
2027-
assert type(obj.xindexes["level_2"]) is PandasIndex
2029+
assert list(obj.xindexes) == ["x"]
2030+
assert type(obj.xindexes["x"]) is PandasIndex
20282031

2029-
coords = {k: v for k, v in coords.items() if k != "x"}
2030-
expected = DataArray(self.mda.values, coords=coords, dims="x")
2032+
expected = DataArray(self.mda.values, dims="x")
20312033
obj = self.mda.reset_index("x", drop=True)
20322034
assert_identical(obj, expected, check_default_indexes=False)
20332035

@@ -2038,14 +2040,16 @@ def test_reset_index(self) -> None:
20382040
# single index
20392041
array = DataArray([1, 2], coords={"x": ["a", "b"]}, dims="x")
20402042
obj = array.reset_index("x")
2041-
assert_identical(obj, array, check_default_indexes=False)
2043+
print(obj.x.variable)
2044+
print(array.x.variable)
2045+
assert_equal(obj.x.variable, array.x.variable.to_base_variable())
20422046
assert len(obj.xindexes) == 0
20432047

20442048
def test_reset_index_keep_attrs(self) -> None:
20452049
coord_1 = DataArray([1, 2], dims=["coord_1"], attrs={"attrs": True})
20462050
da = DataArray([1, 0], [coord_1])
20472051
obj = da.reset_index("coord_1")
2048-
assert_identical(obj, da, check_default_indexes=False)
2052+
assert obj.coord_1.attrs == da.coord_1.attrs
20492053
assert len(obj.xindexes) == 0
20502054

20512055
def test_reorder_levels(self) -> None:

xarray/tests/test_dataset.py

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3237,12 +3237,31 @@ def test_set_index(self) -> None:
32373237
with pytest.raises(ValueError, match=r"dimension mismatch.*"):
32383238
ds.set_index(y="x_var")
32393239

3240+
def test_set_index_deindexed_coords(self) -> None:
3241+
# test de-indexed coordinates are converted to base variable
3242+
# https://github.com/pydata/xarray/issues/6969
3243+
one = ["a", "a", "b", "b"]
3244+
two = [1, 2, 1, 2]
3245+
three = ["c", "c", "d", "d"]
3246+
four = [3, 4, 3, 4]
3247+
3248+
mindex_12 = pd.MultiIndex.from_arrays([one, two], names=["one", "two"])
3249+
mindex_34 = pd.MultiIndex.from_arrays([three, four], names=["three", "four"])
3250+
3251+
ds = xr.Dataset(
3252+
coords={"x": mindex_12, "three": ("x", three), "four": ("x", four)}
3253+
)
3254+
actual = ds.set_index(x=["three", "four"])
3255+
expected = xr.Dataset(
3256+
coords={"x": mindex_34, "one": ("x", one), "two": ("x", two)}
3257+
)
3258+
assert_identical(actual, expected)
3259+
32403260
def test_reset_index(self) -> None:
32413261
ds = create_test_multiindex()
32423262
mindex = ds["x"].to_index()
32433263
indexes = [mindex.get_level_values(n) for n in mindex.names]
32443264
coords = {idx.name: ("x", idx) for idx in indexes}
3245-
coords["x"] = ("x", mindex.values)
32463265
expected = Dataset({}, coords=coords)
32473266

32483267
obj = ds.reset_index("x")
@@ -3257,9 +3276,45 @@ def test_reset_index_keep_attrs(self) -> None:
32573276
coord_1 = DataArray([1, 2], dims=["coord_1"], attrs={"attrs": True})
32583277
ds = Dataset({}, {"coord_1": coord_1})
32593278
obj = ds.reset_index("coord_1")
3260-
assert_identical(obj, ds, check_default_indexes=False)
3279+
assert ds.coord_1.attrs == obj.coord_1.attrs
32613280
assert len(obj.xindexes) == 0
32623281

3282+
def test_reset_index_drop_dims(self) -> None:
3283+
ds = Dataset(coords={"x": [1, 2]})
3284+
reset = ds.reset_index("x", drop=True)
3285+
assert len(reset.dims) == 0
3286+
3287+
@pytest.mark.parametrize(
3288+
"arg,drop,dropped,converted,renamed",
3289+
[
3290+
("foo", False, [], [], {"bar": "x"}),
3291+
("foo", True, ["foo"], [], {"bar": "x"}),
3292+
("x", False, ["x"], ["foo", "bar"], {}),
3293+
("x", True, ["x", "foo", "bar"], [], {}),
3294+
(["foo", "bar"], False, ["x"], ["foo", "bar"], {}),
3295+
(["foo", "bar"], True, ["x", "foo", "bar"], [], {}),
3296+
(["x", "foo"], False, ["x"], ["foo", "bar"], {}),
3297+
(["foo", "x"], True, ["x", "foo", "bar"], [], {}),
3298+
],
3299+
)
3300+
def test_reset_index_drop_convert(
3301+
self, arg, drop, dropped, converted, renamed
3302+
) -> None:
3303+
# regressions https://github.com/pydata/xarray/issues/6946 and
3304+
# https://github.com/pydata/xarray/issues/6989
3305+
# check that multi-index dimension or level coordinates are dropped, converted
3306+
# from IndexVariable to Variable or renamed to dimension as expected
3307+
midx = pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=("foo", "bar"))
3308+
ds = xr.Dataset(coords={"x": midx})
3309+
reset = ds.reset_index(arg, drop=drop)
3310+
3311+
for name in dropped:
3312+
assert name not in reset.variables
3313+
for name in converted:
3314+
assert_identical(reset[name].variable, ds[name].variable.to_base_variable())
3315+
for old_name, new_name in renamed.items():
3316+
assert_identical(ds[old_name].variable, reset[new_name].variable)
3317+
32633318
def test_reorder_levels(self) -> None:
32643319
ds = create_test_multiindex()
32653320
mindex = ds["x"].to_index()

xarray/tests/test_groupby.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -538,7 +538,6 @@ def test_groupby_drops_nans() -> None:
538538
.rename({"xy": "id"})
539539
.to_dataset()
540540
.reset_index("id", drop=True)
541-
.drop_vars(["lon", "lat"])
542541
.assign(id=stacked.id.values)
543542
.dropna("id")
544543
.transpose(*actual2.dims)

0 commit comments

Comments
 (0)