Skip to content

Commit 2ce0639

Browse files
committed
Various small fixes
1 parent eb50f50 commit 2ce0639

File tree

7 files changed

+29
-17
lines changed

7 files changed

+29
-17
lines changed

xarray/core/coordinates.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def _update_coords(self, coords):
193193

194194
self._data._variables = variables
195195
self._data._coord_names.update(new_coord_names)
196-
self._data._dims = dict(dims)
196+
self._data._dims = dims
197197
self._data._indexes = None
198198

199199
def __delitem__(self, key):

xarray/core/dataarray.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def _infer_coords_and_dims(shape, coords, dims):
6666
for dim, coord in zip(dims, coords):
6767
var = as_variable(coord, name=dim)
6868
var.dims = (dim,)
69-
new_coords[dim] = var
69+
new_coords[dim] = var.to_index_variable()
7070

7171
sizes = dict(zip(dims, shape))
7272
for k, v in new_coords.items():

xarray/core/dataset.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def calculate_dimensions(variables):
101101
Returns dictionary mapping from dimension names to sizes. Raises ValueError
102102
if any of the dimension sizes conflict.
103103
"""
104-
dims = OrderedDict()
104+
dims = {}
105105
last_used = {}
106106
scalar_vars = set(k for k, v in variables.items() if not v.dims)
107107
for k, var in variables.items():
@@ -693,7 +693,7 @@ def _construct_direct(cls, variables, coord_names, dims, attrs=None,
693693

694694
@classmethod
695695
def _from_vars_and_coord_names(cls, variables, coord_names, attrs=None):
696-
dims = dict(calculate_dimensions(variables))
696+
dims = calculate_dimensions(variables)
697697
return cls._construct_direct(variables, coord_names, dims, attrs)
698698

699699
# TODO(shoyer): renable type checking on this signature when pytype has a
@@ -754,18 +754,20 @@ def _replace_with_new_dims( # type: ignore
754754
coord_names: set = None,
755755
attrs: 'Optional[OrderedDict]' = __default,
756756
indexes: 'Optional[OrderedDict[Any, pd.Index]]' = __default,
757+
encoding: Optional[dict] = __default,
757758
inplace: bool = False,
758759
) -> T:
759760
"""Replace variables with recalculated dimensions."""
760-
dims = dict(calculate_dimensions(variables))
761+
dims = calculate_dimensions(variables)
761762
return self._replace(
762-
variables, coord_names, dims, attrs, indexes, inplace=inplace)
763+
variables, coord_names, dims, attrs, indexes, encoding,
764+
inplace=inplace)
763765

764766
def _replace_vars_and_dims( # type: ignore
765767
self: T,
766768
variables: 'OrderedDict[Any, Variable]' = None,
767769
coord_names: set = None,
768-
dims: 'OrderedDict[Any, int]' = None,
770+
dims: Dict[Any, int] = None,
769771
attrs: 'Optional[OrderedDict]' = __default,
770772
inplace: bool = False,
771773
) -> T:
@@ -1081,6 +1083,7 @@ def __delitem__(self, key):
10811083
"""
10821084
del self._variables[key]
10831085
self._coord_names.discard(key)
1086+
self._dims = calculate_dimensions(self._variables)
10841087

10851088
# mutable objects should not be hashable
10861089
# https://github.com/python/mypy/issues/4266
@@ -2463,7 +2466,7 @@ def expand_dims(self, dim=None, axis=None, **dim_kwargs):
24632466
else:
24642467
# If dims includes a label of a non-dimension coordinate,
24652468
# it will be promoted to a 1D coordinate with a single value.
2466-
variables[k] = v.set_dims(k)
2469+
variables[k] = v.set_dims(k).to_index_variable()
24672470

24682471
new_dims = self._dims.copy()
24692472
new_dims.update(dim)
@@ -3548,12 +3551,15 @@ def from_dict(cls, d):
35483551
def _unary_op(f, keep_attrs=False):
35493552
@functools.wraps(f)
35503553
def func(self, *args, **kwargs):
3551-
ds = self.coords.to_dataset()
3552-
for k in self.data_vars:
3553-
ds._variables[k] = f(self._variables[k], *args, **kwargs)
3554-
if keep_attrs:
3555-
ds._attrs = self._attrs
3556-
return ds
3554+
variables = OrderedDict()
3555+
for k, v in self._variables.items():
3556+
if k in self._coord_names:
3557+
variables[k] = v
3558+
else:
3559+
variables[k] = f(v, *args, **kwargs)
3560+
attrs = self._attrs if keep_attrs else None
3561+
return self._replace_with_new_dims(
3562+
variables, attrs=attrs, encoding=None)
35573563

35583564
return func
35593565

xarray/core/merge.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -467,7 +467,7 @@ def merge_core(objs,
467467
'coordinates or not in the merged result: %s'
468468
% ambiguous_coords)
469469

470-
return variables, coord_names, dict(dims)
470+
return variables, coord_names, dims
471471

472472

473473
def merge(objects, compat='no_conflicts', join='outer'):

xarray/testing.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,9 @@ def _assert_dataset_invariants(ds: Dataset):
230230

231231
assert type(ds._dims) is dict, ds._dims
232232
assert all(isinstance(v, int) for v in ds._dims.values()), ds._dims
233-
var_dims = set.union(*[set(v.dims) for v in ds._variables.values()])
233+
var_dims = set() # type: set
234+
for v in ds._variables.values():
235+
var_dims.update(v.dims)
234236
assert ds._dims.keys() == var_dims, (set(ds._dims), var_dims)
235237
assert all(ds._dims[k] == v.sizes[k]
236238
for v in ds._variables.values()

xarray/tests/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,6 @@ def source_ndarray(array):
183183

184184
# Internal versions of xarray's test functions that validate additional
185185
# invariants
186-
# TODO: add more invariant checks.
187186

188187
def assert_equal(a, b, *, check_invariants=True):
189188
xarray.testing.assert_equal(a, b)

xarray/tests/test_dataset.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2677,6 +2677,11 @@ def test_delitem(self):
26772677
assert set(data.variables) == all_items - set(['var1', 'numbers'])
26782678
assert 'numbers' not in data.coords
26792679

2680+
expected = Dataset()
2681+
actual = Dataset({'y': ('x', [1, 2])})
2682+
del actual['y']
2683+
assert_identical(expected, actual)
2684+
26802685
def test_squeeze(self):
26812686
data = Dataset({'foo': (['x', 'y', 'z'], [[[1], [2]]])})
26822687
for args in [[], [['x']], [['x', 'z']]]:

0 commit comments

Comments
 (0)