@@ -4026,10 +4026,11 @@ def set_index(
4026
4026
dim_coords = either_dict_or_kwargs (indexes , indexes_kwargs , "set_index" )
4027
4027
4028
4028
new_indexes : dict [Hashable , Index ] = {}
4029
- new_variables : dict [Hashable , IndexVariable ] = {}
4030
- maybe_drop_indexes : list [Hashable ] = []
4031
- drop_variables : list [Hashable ] = []
4029
+ new_variables : dict [Hashable , Variable ] = {}
4030
+ drop_indexes : set [Hashable ] = set ()
4031
+ drop_variables : set [Hashable ] = set ()
4032
4032
replace_dims : dict [Hashable , Hashable ] = {}
4033
+ all_var_names : set [Hashable ] = set ()
4033
4034
4034
4035
for dim , _var_names in dim_coords .items ():
4035
4036
if isinstance (_var_names , str ) or not isinstance (_var_names , Sequence ):
@@ -4044,16 +4045,19 @@ def set_index(
4044
4045
+ " variable(s) do not exist"
4045
4046
)
4046
4047
4047
- current_coord_names = self .xindexes .get_all_coords (dim , errors = "ignore" )
4048
+ all_var_names .update (var_names )
4049
+ drop_variables .update (var_names )
4048
4050
4049
- # drop any pre-existing index involved
4050
- maybe_drop_indexes += list (current_coord_names ) + var_names
4051
+ # drop any pre-existing index involved and its corresponding coordinates
4052
+ index_coord_names = self .xindexes .get_all_coords (dim , errors = "ignore" )
4053
+ all_index_coord_names = set (index_coord_names )
4051
4054
for k in var_names :
4052
- maybe_drop_indexes += list (
4055
+ all_index_coord_names . update (
4053
4056
self .xindexes .get_all_coords (k , errors = "ignore" )
4054
4057
)
4055
4058
4056
- drop_variables += var_names
4059
+ drop_indexes .update (all_index_coord_names )
4060
+ drop_variables .update (all_index_coord_names )
4057
4061
4058
4062
if len (var_names ) == 1 and (not append or dim not in self ._indexes ):
4059
4063
var_name = var_names [0 ]
@@ -4065,10 +4069,14 @@ def set_index(
4065
4069
)
4066
4070
idx = PandasIndex .from_variables ({dim : var })
4067
4071
idx_vars = idx .create_variables ({var_name : var })
4072
+
4073
+ # trick to preserve coordinate order in this case
4074
+ if dim in self ._coord_names :
4075
+ drop_variables .remove (dim )
4068
4076
else :
4069
4077
if append :
4070
4078
current_variables = {
4071
- k : self ._variables [k ] for k in current_coord_names
4079
+ k : self ._variables [k ] for k in index_coord_names
4072
4080
}
4073
4081
else :
4074
4082
current_variables = {}
@@ -4083,8 +4091,17 @@ def set_index(
4083
4091
new_indexes .update ({k : idx for k in idx_vars })
4084
4092
new_variables .update (idx_vars )
4085
4093
4094
+ # re-add deindexed coordinates (convert to base variables)
4095
+ for k in drop_variables :
4096
+ if (
4097
+ k not in new_variables
4098
+ and k not in all_var_names
4099
+ and k in self ._coord_names
4100
+ ):
4101
+ new_variables [k ] = self ._variables [k ].to_base_variable ()
4102
+
4086
4103
indexes_ : dict [Any , Index ] = {
4087
- k : v for k , v in self ._indexes .items () if k not in maybe_drop_indexes
4104
+ k : v for k , v in self ._indexes .items () if k not in drop_indexes
4088
4105
}
4089
4106
indexes_ .update (new_indexes )
4090
4107
@@ -4099,7 +4116,7 @@ def set_index(
4099
4116
new_dims = [replace_dims .get (d , d ) for d in v .dims ]
4100
4117
variables [k ] = v ._replace (dims = new_dims )
4101
4118
4102
- coord_names = self ._coord_names - set ( drop_variables ) | set (new_variables )
4119
+ coord_names = self ._coord_names - drop_variables | set (new_variables )
4103
4120
4104
4121
return self ._replace_with_new_dims (
4105
4122
variables , coord_names = coord_names , indexes = indexes_
@@ -4139,35 +4156,60 @@ def reset_index(
4139
4156
f"{ tuple (invalid_coords )} are not coordinates with an index"
4140
4157
)
4141
4158
4142
- drop_indexes : list [Hashable ] = []
4143
- drop_variables : list [Hashable ] = []
4144
- replaced_indexes : list [ PandasMultiIndex ] = []
4159
+ drop_indexes : set [Hashable ] = set ()
4160
+ drop_variables : set [Hashable ] = set ()
4161
+ seen : set [ Index ] = set ()
4145
4162
new_indexes : dict [Hashable , Index ] = {}
4146
- new_variables : dict [Hashable , IndexVariable ] = {}
4163
+ new_variables : dict [Hashable , Variable ] = {}
4164
+
4165
+ def drop_or_convert (var_names ):
4166
+ if drop :
4167
+ drop_variables .update (var_names )
4168
+ else :
4169
+ base_vars = {
4170
+ k : self ._variables [k ].to_base_variable () for k in var_names
4171
+ }
4172
+ new_variables .update (base_vars )
4147
4173
4148
4174
for name in dims_or_levels :
4149
4175
index = self ._indexes [name ]
4150
- drop_indexes += list (self .xindexes .get_all_coords (name ))
4151
-
4152
- if isinstance (index , PandasMultiIndex ) and name not in self .dims :
4153
- # special case for pd.MultiIndex (name is an index level):
4154
- # replace by a new index with dropped level(s) instead of just drop the index
4155
- if index not in replaced_indexes :
4156
- level_names = index .index .names
4157
- level_vars = {
4158
- k : self ._variables [k ]
4159
- for k in level_names
4160
- if k not in dims_or_levels
4161
- }
4162
- if level_vars :
4163
- idx = index .keep_levels (level_vars )
4164
- idx_vars = idx .create_variables (level_vars )
4165
- new_indexes .update ({k : idx for k in idx_vars })
4166
- new_variables .update (idx_vars )
4167
- replaced_indexes .append (index )
4168
4176
4169
- if drop :
4170
- drop_variables .append (name )
4177
+ if index in seen :
4178
+ continue
4179
+ seen .add (index )
4180
+
4181
+ idx_var_names = set (self .xindexes .get_all_coords (name ))
4182
+ drop_indexes .update (idx_var_names )
4183
+
4184
+ if isinstance (index , PandasMultiIndex ):
4185
+ # special case for pd.MultiIndex
4186
+ level_names = index .index .names
4187
+ keep_level_vars = {
4188
+ k : self ._variables [k ]
4189
+ for k in level_names
4190
+ if k not in dims_or_levels
4191
+ }
4192
+
4193
+ if index .dim not in dims_or_levels and keep_level_vars :
4194
+ # do not drop the multi-index completely
4195
+ # instead replace it by a new (multi-)index with dropped level(s)
4196
+ idx = index .keep_levels (keep_level_vars )
4197
+ idx_vars = idx .create_variables (keep_level_vars )
4198
+ new_indexes .update ({k : idx for k in idx_vars })
4199
+ new_variables .update (idx_vars )
4200
+ if not isinstance (idx , PandasMultiIndex ):
4201
+ # multi-index reduced to single index
4202
+ # backward compatibility: unique level coordinate renamed to dimension
4203
+ drop_variables .update (keep_level_vars )
4204
+ drop_or_convert (
4205
+ [k for k in level_names if k not in keep_level_vars ]
4206
+ )
4207
+ else :
4208
+ # always drop the multi-index dimension variable
4209
+ drop_variables .add (index .dim )
4210
+ drop_or_convert (level_names )
4211
+ else :
4212
+ drop_or_convert (idx_var_names )
4171
4213
4172
4214
indexes = {k : v for k , v in self ._indexes .items () if k not in drop_indexes }
4173
4215
indexes .update (new_indexes )
@@ -4177,9 +4219,11 @@ def reset_index(
4177
4219
}
4178
4220
variables .update (new_variables )
4179
4221
4180
- coord_names = set ( new_variables ) | self ._coord_names
4222
+ coord_names = self ._coord_names - drop_variables
4181
4223
4182
- return self ._replace (variables , coord_names = coord_names , indexes = indexes )
4224
+ return self ._replace_with_new_dims (
4225
+ variables , coord_names = coord_names , indexes = indexes
4226
+ )
4183
4227
4184
4228
def reorder_levels (
4185
4229
self : T_Dataset ,
0 commit comments