Skip to content

Commit 1a659bc

Browse files
committed
12k->14.5k
1 parent 3290e24 commit 1a659bc

File tree

2 files changed

+63
-13
lines changed

2 files changed

+63
-13
lines changed

xarray/core/indexes.py

Lines changed: 58 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,10 @@ def unstack(self) -> tuple[dict[Hashable, Index], pd.MultiIndex]:
161161
raise NotImplementedError()
162162

163163
def create_variables(
164-
self, variables: Mapping[Any, Variable] | None = None
164+
self,
165+
variables: Mapping[Any, Variable] | None = None,
166+
*,
167+
fastpath=False,
165168
) -> IndexVars:
166169
"""Maybe create new coordinate variables from this index.
167170
@@ -575,13 +578,19 @@ class PandasIndex(Index):
575578

576579
__slots__ = ("index", "dim", "coord_dtype")
577580

578-
def __init__(self, array: Any, dim: Hashable, coord_dtype: Any = None):
579-
# make a shallow copy: cheap and because the index name may be updated
580-
# here or in other constructors (cannot use pd.Index.rename as this
581-
# constructor is also called from PandasMultiIndex)
582-
index = safe_cast_to_index(array).copy()
581+
def __init__(
582+
self, array: Any, dim: Hashable, coord_dtype: Any = None, *, fastpath=False
583+
):
584+
if fastpath:
585+
index = array
586+
else:
587+
index = safe_cast_to_index(array)
583588

584589
if index.name is None:
590+
# make a shallow copy: cheap and because the index name may be updated
591+
# here or in other constructors (cannot use pd.Index.rename as this
592+
# constructor is also called from PandasMultiIndex)
593+
index = index.copy()
585594
index.name = dim
586595

587596
self.index = index
@@ -596,7 +605,7 @@ def _replace(self, index, dim=None, coord_dtype=None):
596605
dim = self.dim
597606
if coord_dtype is None:
598607
coord_dtype = self.coord_dtype
599-
return type(self)(index, dim, coord_dtype)
608+
return type(self)(index, dim, coord_dtype, fastpath=True)
600609

601610
@classmethod
602611
def from_variables(
@@ -641,6 +650,8 @@ def from_variables(
641650

642651
obj = cls(data, dim, coord_dtype=var.dtype)
643652
assert not isinstance(obj.index, pd.MultiIndex)
653+
# Rename safely
654+
obj.index = obj.index.copy()
644655
obj.index.name = name
645656

646657
return obj
@@ -684,7 +695,7 @@ def concat(
684695
return cls(new_pd_index, dim=dim, coord_dtype=coord_dtype)
685696

686697
def create_variables(
687-
self, variables: Mapping[Any, Variable] | None = None
698+
self, variables: Mapping[Any, Variable] | None = None, *, fastpath=False
688699
) -> IndexVars:
689700
from xarray.core.variable import IndexVariable
690701

@@ -701,7 +712,9 @@ def create_variables(
701712
encoding = None
702713

703714
data = PandasIndexingAdapter(self.index, dtype=self.coord_dtype)
704-
var = IndexVariable(self.dim, data, attrs=attrs, encoding=encoding)
715+
var = IndexVariable(
716+
self.dim, data, attrs=attrs, encoding=encoding, fastpath=fastpath
717+
)
705718
return {name: var}
706719

707720
def to_pandas_index(self) -> pd.Index:
@@ -1122,7 +1135,7 @@ def reorder_levels(
11221135
return self._replace(index, level_coords_dtype=level_coords_dtype)
11231136

11241137
def create_variables(
1125-
self, variables: Mapping[Any, Variable] | None = None
1138+
self, variables: Mapping[Any, Variable] | None = None, *, fastpath=False
11261139
) -> IndexVars:
11271140
from xarray.core.variable import IndexVariable
11281141

@@ -1772,6 +1785,37 @@ def check_variables():
17721785
return not not_equal
17731786

17741787

1788+
def _apply_indexes_fast(indexes: Indexes[Index], args: Mapping[Any, Any], func: str):
    """Apply the index method named ``func`` to each index affected by ``args``.

    Parameters
    ----------
    indexes
        Object exposing ``_indexes`` (name -> index) and ``_variables``
        (name -> coordinate variable); both are read directly.
    args
        Mapping from dimension name to the argument passed to ``func``.
        Only the entries whose key is a dimension of an index's coordinate
        variable(s) are forwarded to that index.
    func
        Name of the index method to call (looked up with ``getattr``).

    Returns
    -------
    new_indexes
        Copy of ``indexes._indexes`` in which every index that received
        arguments is replaced by the result of ``func``, or removed when
        ``func`` returned ``None``.
    new_index_variables
        Variables created from the replacement indexes via
        ``create_variables(..., fastpath=True)``.
    """
    # This function avoids the call to indexes.group_by_index
    # which is really slow when repeatedly iterating through
    # an array. However, it fails to return the correct ID for
    # multi-index arrays
    indexes_fast, coords = indexes._indexes, indexes._variables

    # Work on a shallow copy so the mapping being iterated is never mutated.
    new_indexes: dict[Hashable, Index] = dict(indexes_fast)
    new_index_variables: dict[Hashable, Variable] = {}
    for name, index in indexes_fast.items():
        coord = coords[name]
        if hasattr(coord, "_indexes"):
            index_vars = {n: coords[n] for n in coord._indexes}
        else:
            index_vars = {name: coord}
        index_dims = {d for var in index_vars.values() for d in var.dims}
        index_args = {k: v for k, v in args.items() if k in index_dims}

        if not index_args:
            # Index untouched by ``args``: keep it as is.
            continue

        new_index = getattr(index, func)(index_args)
        if new_index is None:
            # ``func`` produced no index: drop it for all of its coordinates.
            for k in index_vars:
                new_indexes.pop(k, None)
            continue

        new_indexes.update(dict.fromkeys(index_vars, new_index))
        # Register the created variables once (the original repeated this
        # update a second time with the same mapping).
        new_index_variables.update(
            new_index.create_variables(index_vars, fastpath=True)
        )
    return new_indexes, new_index_variables
1817+
1818+
17751819
def _apply_indexes(
17761820
indexes: Indexes[Index],
17771821
args: Mapping[Any, Any],
@@ -1800,7 +1844,10 @@ def isel_indexes(
18001844
indexes: Indexes[Index],
18011845
indexers: Mapping[Any, Any],
18021846
) -> tuple[dict[Hashable, Index], dict[Hashable, Variable]]:
1803-
return _apply_indexes(indexes, indexers, "isel")
1847+
if any(isinstance(v, PandasMultiIndex) for v in indexes._indexes.values()):
1848+
return _apply_indexes(indexes, indexers, "isel")
1849+
else:
1850+
return _apply_indexes_fast(indexes, indexers, "isel")
18041851

18051852

18061853
def roll_indexes(

xarray/core/indexing.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1662,10 +1662,13 @@ class PandasIndexingAdapter(ExplicitlyIndexedNDArrayMixin):
16621662

16631663
__slots__ = ("array", "_dtype")
16641664

1665-
def __init__(self, array: pd.Index, dtype: DTypeLike = None):
1665+
def __init__(self, array: pd.Index, dtype: DTypeLike = None, *, fastpath=False):
16661666
from xarray.core.indexes import safe_cast_to_index
16671667

1668-
self.array = safe_cast_to_index(array)
1668+
if fastpath:
1669+
self.array = array
1670+
else:
1671+
self.array = safe_cast_to_index(array)
16691672

16701673
if dtype is None:
16711674
self._dtype = get_valid_numpy_dtype(array)

0 commit comments

Comments
 (0)