Skip to content

Commit 3f0049f

Browse files
crusaderky authored and Joe Hamman committed
Speed up isel and __getitem__ (#3375)
* Variable.isel cleanup/speedup
* Dataset.isel code cleanup
* Speed up isel
* What's New
* Better error checks
* Speedup
* type annotations
* Update doc/whats-new.rst
  Co-Authored-By: Maximilian Roos <[email protected]>
* What's New
* What's New
* Always shallow-copy variables
1 parent 132733a commit 3f0049f

File tree

4 files changed

+93
-63
lines changed

4 files changed

+93
-63
lines changed

doc/whats-new.rst

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,21 +42,27 @@ Breaking changes
4242

4343
(:issue:`3222`, :issue:`3293`, :issue:`3340`, :issue:`3346`, :issue:`3358`).
4444
By `Guido Imperiale <https://github.com/crusaderky>`_.
45+
- Dropped the 'drop=False' optional parameter from :meth:`Variable.isel`.
46+
It was unused and doesn't make sense for a Variable.
47+
(:pull:`3375`) by `Guido Imperiale <https://github.com/crusaderky>`_.
4548

4649
New functions/methods
4750
~~~~~~~~~~~~~~~~~~~~~
4851

4952
Enhancements
5053
~~~~~~~~~~~~
5154

52-
- Add a repr for :py:class:`~xarray.core.GroupBy` objects (:issue:`3344`).
55+
- Add a repr for :py:class:`~xarray.core.GroupBy` objects.
5356
Example::
5457

5558
>>> da.groupby("time.season")
5659
DataArrayGroupBy, grouped over 'season'
5760
4 groups with labels 'DJF', 'JJA', 'MAM', 'SON'
5861

59-
By `Deepak Cherian <https://github.com/dcherian>`_.
62+
(:issue:`3344`) by `Deepak Cherian <https://github.com/dcherian>`_.
63+
- Speed up :meth:`Dataset.isel` up to 33% and :meth:`DataArray.isel` up to 25% for small
64+
arrays (:issue:`2799`, :pull:`3375`) by
65+
`Guido Imperiale <https://github.com/crusaderky>`_.
6066

6167
Bug fixes
6268
~~~~~~~~~

xarray/core/dataset.py

Lines changed: 59 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1745,8 +1745,8 @@ def maybe_chunk(name, var, chunks):
17451745
return self._replace(variables)
17461746

17471747
def _validate_indexers(
1748-
self, indexers: Mapping
1749-
) -> List[Tuple[Any, Union[slice, Variable]]]:
1748+
self, indexers: Mapping[Hashable, Any]
1749+
) -> Iterator[Tuple[Hashable, Union[int, slice, np.ndarray, Variable]]]:
17501750
""" Here we make sure
17511751
+ indexer has a valid keys
17521752
+ indexer is in a valid data type
@@ -1755,50 +1755,61 @@ def _validate_indexers(
17551755
"""
17561756
from .dataarray import DataArray
17571757

1758-
invalid = [k for k in indexers if k not in self.dims]
1758+
invalid = indexers.keys() - self.dims.keys()
17591759
if invalid:
17601760
raise ValueError("dimensions %r do not exist" % invalid)
17611761

17621762
# all indexers should be int, slice, np.ndarrays, or Variable
1763-
indexers_list: List[Tuple[Any, Union[slice, Variable]]] = []
17641763
for k, v in indexers.items():
1765-
if isinstance(v, slice):
1766-
indexers_list.append((k, v))
1767-
continue
1768-
1769-
if isinstance(v, Variable):
1770-
pass
1764+
if isinstance(v, (int, slice, Variable)):
1765+
yield k, v
17711766
elif isinstance(v, DataArray):
1772-
v = v.variable
1767+
yield k, v.variable
17731768
elif isinstance(v, tuple):
1774-
v = as_variable(v)
1769+
yield k, as_variable(v)
17751770
elif isinstance(v, Dataset):
17761771
raise TypeError("cannot use a Dataset as an indexer")
17771772
elif isinstance(v, Sequence) and len(v) == 0:
1778-
v = Variable((k,), np.zeros((0,), dtype="int64"))
1773+
yield k, np.empty((0,), dtype="int64")
17791774
else:
17801775
v = np.asarray(v)
17811776

1782-
if v.dtype.kind == "U" or v.dtype.kind == "S":
1777+
if v.dtype.kind in "US":
17831778
index = self.indexes[k]
17841779
if isinstance(index, pd.DatetimeIndex):
17851780
v = v.astype("datetime64[ns]")
17861781
elif isinstance(index, xr.CFTimeIndex):
17871782
v = _parse_array_of_cftime_strings(v, index.date_type)
17881783

1789-
if v.ndim == 0:
1790-
v = Variable((), v)
1791-
elif v.ndim == 1:
1792-
v = Variable((k,), v)
1793-
else:
1784+
if v.ndim > 1:
17941785
raise IndexError(
17951786
"Unlabeled multi-dimensional array cannot be "
17961787
"used for indexing: {}".format(k)
17971788
)
1789+
yield k, v
17981790

1799-
indexers_list.append((k, v))
1800-
1801-
return indexers_list
1791+
def _validate_interp_indexers(
1792+
self, indexers: Mapping[Hashable, Any]
1793+
) -> Iterator[Tuple[Hashable, Variable]]:
1794+
"""Variant of _validate_indexers to be used for interpolation
1795+
"""
1796+
for k, v in self._validate_indexers(indexers):
1797+
if isinstance(v, Variable):
1798+
if v.ndim == 1:
1799+
yield k, v.to_index_variable()
1800+
else:
1801+
yield k, v
1802+
elif isinstance(v, int):
1803+
yield k, Variable((), v)
1804+
elif isinstance(v, np.ndarray):
1805+
if v.ndim == 0:
1806+
yield k, Variable((), v)
1807+
elif v.ndim == 1:
1808+
yield k, IndexVariable((k,), v)
1809+
else:
1810+
raise AssertionError() # Already tested by _validate_indexers
1811+
else:
1812+
raise TypeError(type(v))
18021813

18031814
def _get_indexers_coords_and_indexes(self, indexers):
18041815
"""Extract coordinates and indexes from indexers.
@@ -1885,10 +1896,10 @@ def isel(
18851896
Dataset.sel
18861897
DataArray.isel
18871898
"""
1888-
18891899
indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "isel")
1890-
1891-
indexers_list = self._validate_indexers(indexers)
1900+
# Note: we need to preserve the original indexers variable in order to merge the
1901+
# coords below
1902+
indexers_list = list(self._validate_indexers(indexers))
18921903

18931904
variables = OrderedDict() # type: OrderedDict[Hashable, Variable]
18941905
indexes = OrderedDict() # type: OrderedDict[Hashable, pd.Index]
@@ -1904,19 +1915,21 @@ def isel(
19041915
)
19051916
if new_index is not None:
19061917
indexes[name] = new_index
1907-
else:
1918+
elif var_indexers:
19081919
new_var = var.isel(indexers=var_indexers)
1920+
else:
1921+
new_var = var.copy(deep=False)
19091922

19101923
variables[name] = new_var
19111924

1912-
coord_names = set(variables).intersection(self._coord_names)
1925+
coord_names = self._coord_names & variables.keys()
19131926
selected = self._replace_with_new_dims(variables, coord_names, indexes)
19141927

19151928
# Extract coordinates from indexers
19161929
coord_vars, new_indexes = selected._get_indexers_coords_and_indexes(indexers)
19171930
variables.update(coord_vars)
19181931
indexes.update(new_indexes)
1919-
coord_names = set(variables).intersection(self._coord_names).union(coord_vars)
1932+
coord_names = self._coord_names & variables.keys() | coord_vars.keys()
19201933
return self._replace_with_new_dims(variables, coord_names, indexes=indexes)
19211934

19221935
def sel(
@@ -2478,11 +2491,9 @@ def interp(
24782491

24792492
if kwargs is None:
24802493
kwargs = {}
2494+
24812495
coords = either_dict_or_kwargs(coords, coords_kwargs, "interp")
2482-
indexers = OrderedDict(
2483-
(k, v.to_index_variable() if isinstance(v, Variable) and v.ndim == 1 else v)
2484-
for k, v in self._validate_indexers(coords)
2485-
)
2496+
indexers = OrderedDict(self._validate_interp_indexers(coords))
24862497

24872498
obj = self if assume_sorted else self.sortby([k for k in coords])
24882499

@@ -2507,26 +2518,25 @@ def _validate_interp_indexer(x, new_x):
25072518
"strings or datetimes. "
25082519
"Instead got\n{}".format(new_x)
25092520
)
2510-
else:
2511-
return (x, new_x)
2521+
return x, new_x
25122522

25132523
variables = OrderedDict() # type: OrderedDict[Hashable, Variable]
25142524
for name, var in obj._variables.items():
2515-
if name not in indexers:
2516-
if var.dtype.kind in "uifc":
2517-
var_indexers = {
2518-
k: _validate_interp_indexer(maybe_variable(obj, k), v)
2519-
for k, v in indexers.items()
2520-
if k in var.dims
2521-
}
2522-
variables[name] = missing.interp(
2523-
var, var_indexers, method, **kwargs
2524-
)
2525-
elif all(d not in indexers for d in var.dims):
2526-
# keep unrelated object array
2527-
variables[name] = var
2525+
if name in indexers:
2526+
continue
2527+
2528+
if var.dtype.kind in "uifc":
2529+
var_indexers = {
2530+
k: _validate_interp_indexer(maybe_variable(obj, k), v)
2531+
for k, v in indexers.items()
2532+
if k in var.dims
2533+
}
2534+
variables[name] = missing.interp(var, var_indexers, method, **kwargs)
2535+
elif all(d not in indexers for d in var.dims):
2536+
# keep unrelated object array
2537+
variables[name] = var
25282538

2529-
coord_names = set(variables).intersection(obj._coord_names)
2539+
coord_names = obj._coord_names & variables.keys()
25302540
indexes = OrderedDict(
25312541
(k, v) for k, v in obj.indexes.items() if k not in indexers
25322542
)
@@ -2546,7 +2556,7 @@ def _validate_interp_indexer(x, new_x):
25462556
variables.update(coord_vars)
25472557
indexes.update(new_indexes)
25482558

2549-
coord_names = set(variables).intersection(obj._coord_names).union(coord_vars)
2559+
coord_names = obj._coord_names & variables.keys() | coord_vars.keys()
25502560
return self._replace_with_new_dims(variables, coord_names, indexes=indexes)
25512561

25522562
def interp_like(

xarray/core/indexes.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from collections import OrderedDict
33
from typing import Any, Hashable, Iterable, Mapping, Optional, Tuple, Union
44

5+
import numpy as np
56
import pandas as pd
67

78
from . import formatting
@@ -63,7 +64,7 @@ def isel_variable_and_index(
6364
name: Hashable,
6465
variable: Variable,
6566
index: pd.Index,
66-
indexers: Mapping[Any, Union[slice, Variable]],
67+
indexers: Mapping[Hashable, Union[int, slice, np.ndarray, Variable]],
6768
) -> Tuple[Variable, Optional[pd.Index]]:
6869
"""Index a Variable and pandas.Index together."""
6970
if not indexers:

xarray/core/variable.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from collections import OrderedDict, defaultdict
44
from datetime import timedelta
55
from distutils.version import LooseVersion
6-
from typing import Any, Hashable, Mapping, Union
6+
from typing import Any, Hashable, Mapping, Union, TypeVar
77

88
import numpy as np
99
import pandas as pd
@@ -41,6 +41,18 @@
4141
# https://github.com/python/mypy/issues/224
4242
BASIC_INDEXING_TYPES = integer_types + (slice,) # type: ignore
4343

44+
VariableType = TypeVar("VariableType", bound="Variable")
45+
"""Type annotation to be used when methods of Variable return self or a copy of self.
46+
When called from an instance of a subclass, e.g. IndexVariable, mypy identifies the
47+
output as an instance of the subclass.
48+
49+
Usage::
50+
51+
class Variable:
52+
def f(self: VariableType, ...) -> VariableType:
53+
...
54+
"""
55+
4456

4557
class MissingDimensionsError(ValueError):
4658
"""Error class used when we can't safely guess a dimension name.
@@ -663,8 +675,8 @@ def _broadcast_indexes_vectorized(self, key):
663675

664676
return out_dims, VectorizedIndexer(tuple(out_key)), new_order
665677

666-
def __getitem__(self, key):
667-
"""Return a new Array object whose contents are consistent with
678+
def __getitem__(self: VariableType, key) -> VariableType:
679+
"""Return a new Variable object whose contents are consistent with
668680
getting the provided key from the underlying data.
669681
670682
NB. __getitem__ and __setitem__ implement xarray-style indexing,
@@ -682,7 +694,7 @@ def __getitem__(self, key):
682694
data = duck_array_ops.moveaxis(data, range(len(new_order)), new_order)
683695
return self._finalize_indexing_result(dims, data)
684696

685-
def _finalize_indexing_result(self, dims, data):
697+
def _finalize_indexing_result(self: VariableType, dims, data) -> VariableType:
686698
"""Used by IndexVariable to return IndexVariable objects when possible.
687699
"""
688700
return type(self)(dims, data, self._attrs, self._encoding, fastpath=True)
@@ -957,7 +969,11 @@ def chunk(self, chunks=None, name=None, lock=False):
957969

958970
return type(self)(self.dims, data, self._attrs, self._encoding, fastpath=True)
959971

960-
def isel(self, indexers=None, drop=False, **indexers_kwargs):
972+
def isel(
973+
self: VariableType,
974+
indexers: Mapping[Hashable, Any] = None,
975+
**indexers_kwargs: Any
976+
) -> VariableType:
961977
"""Return a new array indexed along the specified dimension(s).
962978
963979
Parameters
@@ -976,15 +992,12 @@ def isel(self, indexers=None, drop=False, **indexers_kwargs):
976992
"""
977993
indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "isel")
978994

979-
invalid = [k for k in indexers if k not in self.dims]
995+
invalid = indexers.keys() - set(self.dims)
980996
if invalid:
981997
raise ValueError("dimensions %r do not exist" % invalid)
982998

983-
key = [slice(None)] * self.ndim
984-
for i, dim in enumerate(self.dims):
985-
if dim in indexers:
986-
key[i] = indexers[dim]
987-
return self[tuple(key)]
999+
key = tuple(indexers.get(dim, slice(None)) for dim in self.dims)
1000+
return self[key]
9881001

9891002
def squeeze(self, dim=None):
9901003
"""Return a new object with squeezed data.

0 commit comments

Comments
 (0)