Skip to content

Commit 86c1cc0

Browse files
committed
Add prototype groupers
1 parent 920cc53 commit 86c1cc0

File tree

3 files changed

+243
-0
lines changed

3 files changed

+243
-0
lines changed

xarray/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from xarray.coding.cftimeindex import CFTimeIndex
1515
from xarray.coding.frequencies import infer_freq
1616
from xarray.conventions import SerializationWarning, decode_cf
17+
from xarray.core import groupers
1718
from xarray.core.alignment import align, broadcast
1819
from xarray.core.combine import combine_by_coords, combine_nested
1920
from xarray.core.common import ALL_DIMS, full_like, ones_like, zeros_like
@@ -54,6 +55,7 @@
5455
# `mypy --strict` running in projects that import xarray.
5556
__all__ = (
5657
# Sub-packages
58+
"groupers",
5759
"testing",
5860
"tutorial",
5961
# Top-level functions

xarray/core/groupers.py

Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
import itertools
2+
from collections.abc import Mapping, Sequence
3+
from dataclasses import dataclass, field
4+
5+
import numpy as np
6+
import pandas as pd
7+
8+
from xarray.core.groupby import Grouper, Resampler
9+
from xarray.core.variable import IndexVariable
10+
11+
12+
## From toolz
13+
## TODO: move to compat file, add license
14+
def sliding_window(n, seq):
    """Yield overlapping length-``n`` tuples taken from ``seq``.

    >>> list(sliding_window(2, [1, 2, 3, 4]))
    [(1, 2), (2, 3), (3, 4)]

    Handy for windowed transformations such as a moving average:

    >>> mean = lambda seq: float(sum(seq)) / len(seq)
    >>> list(map(mean, sliding_window(2, [1, 2, 3, 4])))
    [1.5, 2.5, 3.5]

    If ``seq`` has fewer than ``n`` items nothing is yielded.
    """
    import collections
    import itertools

    staggered = []
    for offset, branch in enumerate(itertools.tee(seq, n)):
        # Discard the first ``offset`` items of this copy; a deque with
        # maxlen 0 consumes the slice without storing anything.
        collections.deque(itertools.islice(branch, offset), 0)
        staggered.append(branch)

    # Zipping the staggered copies lines up item i, i + 1, ..., i + n - 1.
    return zip(*staggered)
36+
37+
38+
def season_to_month_tuple(seasons: Sequence[str]) -> Sequence[Sequence[int]]:
    """Translate season strings of month initials into month numbers.

    Parameters
    ----------
    seasons : sequence of str
        Consecutive seasons written with month initials, e.g.
        ``["DJF", "MAM", "JJA", "SON"]``. Taken together they must contain
        exactly 12 letters.

    Returns
    -------
    list of list of int
        One list per season holding the month numbers (1-12) in the order
        the letters appear in that season.

    Raises
    ------
    ValueError
        If the seasons do not add up to exactly 12 letters.
    """
    # Initials that identify a single month on their own.
    unambiguous = {"D": 12, "F": 2, "S": 9, "O": 10, "N": 11}
    # Ambiguous initials (J, M, A) are resolved by looking at the letter
    # before and after: the key is (previous, this, next), the value is
    # the month number of the middle letter.
    by_context = {
        "DJF": 1,
        "FMA": 3,
        "MAM": 4,
        "AMJ": 5,
        "MJJ": 6,
        "JJA": 7,
        "JAS": 8,
    }

    if len("".join(seasons)) != 12:
        raise ValueError("SeasonGrouper requires exactly 12 months in total.")

    # Wrap around so the first and last seasons also have neighbours.
    wrapped = [seasons[-1], *seasons, seasons[0]]

    result = []
    for before, current, after in zip(wrapped, wrapped[1:], wrapped[2:]):
        # Pad the season with the adjacent months so that every letter of
        # ``current`` sits in the middle of a 3-letter window.
        padded = "".join([before[-1], *current, after[0]])
        months = []
        for left, letter, right in zip(padded, padded[1:], padded[2:]):
            if letter in unambiguous:
                months.append(unambiguous[letter])
            else:
                months.append(by_context["".join([left, letter, right])])
        result.append(months)
    return result
66+
67+
68+
@dataclass
class SeasonGrouper(Grouper):
    """Allows grouping using a custom definition of seasons.

    Parameters
    ----------
    seasons: sequence of str
        List of strings representing seasons. E.g. ``"JF"`` or ``"JJA"`` etc.
    drop_incomplete: bool
        Whether to drop seasons that are not completely included in the data.
        For example, if a time series starts in Jan-2001, and seasons includes `"DJF"`
        then observations from Jan-2001, and Feb-2001 are ignored in the grouping
        since Dec-2000 isn't present.
        NOTE: this flag is stored but ``factorize`` does not read it yet.

    Examples
    --------
    >>> SeasonGrouper(["JF", "MAM", "JJAS", "OND"])
    >>> SeasonGrouper(["DJFM", "AM", "JJA", "SON"])
    """

    seasons: Sequence[str]
    # Month numbers for each season, derived from ``seasons`` in __post_init__.
    season_inds: Sequence[Sequence[int]] = field(init=False)
    drop_incomplete: bool = field(default=True)

    def __post_init__(self):
        """Derive the month numbers for every season string."""
        self.season_inds = season_to_month_tuple(self.seasons)

    def __repr__(self):
        # The seasons live directly on this object; there is no ``grouper``
        # attribute, so going through one raised AttributeError.
        return f"SeasonGrouper over {self.seasons!r}"

    def factorize(self, group):
        """Assign every element of ``group`` to the season containing its month.

        Parameters
        ----------
        group
            Datetime-like array exposing ``.dt.month``, ``.shape``, ``.size``,
            ``.attrs`` and ``.copy(data=...)``.

        Returns
        -------
        codes
            Copy of ``group`` renamed to ``"season"`` whose values are the
            integer position of the matching season in ``seasons`` (``-1``
            where no season matched).
        group_indices
            One array of integer positions per season.
        unique_coord
            ``IndexVariable`` named ``"season"`` holding the season strings.
        full_index
            The same object as ``unique_coord``.

        Raises
        ------
        ValueError
            If no element could be assigned to any season.
        """
        seasons = self.seasons
        season_inds = self.season_inds

        months = group.dt.month
        codes_ = np.full(group.shape, -1)
        # Independent placeholders (no shared list); each slot is replaced
        # below with the positions belonging to that season.
        group_indices = [[] for _ in seasons]

        index = np.arange(group.size)
        for idx, season in enumerate(season_inds):
            mask = months.isin(season)
            codes_[mask] = idx
            group_indices[idx] = index[mask]

        if np.all(codes_ == -1):
            raise ValueError(
                "Failed to group data. Are you grouping by a variable that is all NaN?"
            )
        codes = group.copy(data=codes_).rename("season")
        unique_coord = IndexVariable("season", seasons, attrs=group.attrs)
        full_index = unique_coord
        return codes, group_indices, unique_coord, full_index
120+
121+
122+
@dataclass
class SeasonResampler(Resampler):
    """Allows resampling using a custom definition of seasons.

    Unlike a plain season grouping, ``factorize`` produces one group per
    (year, season) occurrence found in the data, labelled by the first day
    of that season's first month.

    Parameters
    ----------
    seasons: sequence of str
        List of strings representing seasons. E.g. ``"JF"`` or ``"JJA"`` etc.
        They are converted to month numbers with ``season_to_month_tuple``.

    Examples
    --------
    >>> SeasonResampler(["JF", "MAM", "JJAS", "OND"])
    >>> SeasonResampler(["DJFM", "AM", "JJA", "SON"])
    """

    seasons: Sequence[str]
    # drop_incomplete: bool = field(default=True) # TODO:
    # Month numbers per season, in the same order as ``seasons``.
    season_inds: Sequence[Sequence[int]] = field(init=False)
    # Mapping from season string to its month numbers.
    season_tuples: Mapping[str, Sequence[int]] = field(init=False)

    def __post_init__(self):
        """Derive ``season_inds`` and ``season_tuples`` from ``seasons``."""
        self.season_inds = season_to_month_tuple(self.seasons)
        self.season_tuples = dict(zip(self.seasons, self.season_inds))

    def factorize(self, group):
        """Split a 1-D datetime array into consecutive (year, season) groups.

        Parameters
        ----------
        group
            One-dimensional datetime-like array exposing ``.dt.year`` and
            ``.dt.month``.
            NOTE(review): groups are turned into contiguous slices below, so
            this presumably expects ``group`` to be sorted in time - confirm.

        Returns
        -------
        codes
            Copy of ``group`` whose data are integer group codes; elements of
            an incomplete first or last season get ``-1``.
        group_indices
            List of slices, one per retained season occurrence.
        unique_coord_var
            ``IndexVariable`` named ``group.name`` holding the timestamp of
            each retained season occurrence.
        full_index
            ``DatetimeIndex`` of every season start expected between the
            first and last retained season.

        Raises
        ------
        AssertionError
            If ``group`` is not one-dimensional.
        ValueError
            If the seasons found in the data do not match ``full_index``.
        """
        assert group.ndim == 1

        seasons = self.seasons
        season_inds = self.season_inds
        season_tuples = self.season_tuples

        # Width of the fixed-size unicode dtype that stores season labels.
        nstr = max(len(s) for s in seasons)
        year = group.dt.year.astype(int)
        month = group.dt.month.astype(int)
        season_label = np.full(group.shape, "", dtype=f"U{nstr}")

        # offset years for seasons with December and January
        for season_str, season_ind in zip(seasons, season_inds):
            season_label[month.isin(season_ind)] = season_str
            if "DJ" in season_str:
                # Months that follow December within this season are given
                # the previous year, so that the whole season shares the
                # year label of its December.
                after_dec = season_ind[season_str.index("D") + 1 :]
                year[month.isin(after_dec)] -= 1

        # One row per element, indexed by its (year, season) label; the
        # "index" column records the element's integer position.
        frame = pd.DataFrame(
            data={"index": np.arange(group.size), "month": month},
            index=pd.MultiIndex.from_arrays(
                [year.data, season_label], names=["year", "season"]
            ),
        )

        series = frame["index"]
        # sort=False keeps the groups in order of first appearance.
        g = series.groupby(["year", "season"], sort=False)
        first_items = g.first()
        counts = g.count()

        # these are the seasons that are present
        # (each labelled by day 1 of the season's first month in its year)
        unique_coord = pd.DatetimeIndex(
            [
                pd.Timestamp(year=year, month=season_tuples[season][0], day=1)
                for year, season in first_items.index
            ]
        )

        # Position of the first element of each group; consecutive
        # positions delimit the slice belonging to each group.
        sbins = first_items.values.astype(int)
        group_indices = [slice(i, j) for i, j in zip(sbins[:-1], sbins[1:])]
        group_indices += [slice(sbins[-1], None)]

        # Make sure the first and last timestamps
        # are for the correct months,if not we have incomplete seasons
        unique_codes = np.arange(len(unique_coord))
        # idx=0 checks the first element against the first month of its
        # season, idx=-1 the last element against the last month; the
        # paired slicer drops the first or last group respectively.
        for idx, slicer in zip([0, -1], (slice(1, None), slice(-1))):
            stamp_year, stamp_season = frame.index[idx]
            code = seasons.index(stamp_season)
            stamp_month = season_inds[code][idx]
            if stamp_month != month[idx].item():
                # we have an incomplete season!
                group_indices = group_indices[slicer]
                unique_coord = unique_coord[slicer]
                if idx == 0:
                    # Shift so the first retained group gets code 0.
                    unique_codes -= 1
                # Elements of the dropped group are coded -1.
                unique_codes[idx] = -1

        # all years and seasons
        complete_index = pd.DatetimeIndex(
            # This sorted call is a hack. It's hard to figure out how
            # to start the iteration
            sorted(
                [
                    pd.Timestamp(f"{y}-{m}-01")
                    for y, m in itertools.product(
                        range(year[0].item(), year[-1].item() + 1),
                        [s[0] for s in season_inds],
                    )
                ]
            )
        )
        # only keep that included in data
        range_ = complete_index.get_indexer(unique_coord[[0, -1]])
        full_index = complete_index[slice(range_[0], range_[-1] + 1)]
        # check that there are no "missing" seasons in the middle
        # print(full_index, unique_coord)
        if not full_index.equals(unique_coord):
            raise ValueError("Are there seasons missing in the middle of the dataset?")

        # Expand the per-group codes to one code per element.
        codes = group.copy(data=np.repeat(unique_codes, counts))
        unique_coord_var = IndexVariable(group.name, unique_coord, group.attrs)

        return codes, group_indices, unique_coord_var, full_index

xarray/tests/test_groupby.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2424,3 +2424,20 @@ def test_groupby_math_auto_chunk():
24242424
)
24252425
actual = da.chunk(x=1, y=2).groupby("label") - sub
24262426
assert actual.chunksizes == {"x": (1, 1, 1), "y": (2, 1)}
2427+
2428+
2429+
def test_season_to_month_tuple():
    """Season strings are converted to the expected month numbers."""
    from xarray.core.groupers import season_to_month_tuple

    # (seasons, expected month numbers per season)
    calendar_year = (
        ["JF", "MAM", "JJAS", "OND"],
        [[1, 2], [3, 4, 5], [6, 7, 8, 9], [10, 11, 12]],
    )
    # Same 12 months, but the first season wraps around the year boundary.
    wrapping_year = (
        ["DJFM", "AM", "JJAS", "ON"],
        [[12, 1, 2, 3], [4, 5], [6, 7, 8, 9], [10, 11]],
    )

    for seasons, expected in (calendar_year, wrapping_year):
        assert season_to_month_tuple(seasons) == expected

0 commit comments

Comments
 (0)