|
| 1 | +import itertools |
| 2 | +from collections.abc import Mapping, Sequence |
| 3 | +from dataclasses import dataclass, field |
| 4 | + |
| 5 | +import numpy as np |
| 6 | +import pandas as pd |
| 7 | + |
| 8 | +from xarray.core.groupby import Grouper, Resampler |
| 9 | +from xarray.core.variable import IndexVariable |
| 10 | + |
| 11 | + |
## From toolz
## TODO: move to compat file, add license
def sliding_window(n, seq):
    """A sequence of overlapping subsequences

    >>> list(sliding_window(2, [1, 2, 3, 4]))
    [(1, 2), (2, 3), (3, 4)]

    This function creates a sliding window suitable for transformations like
    sliding means / smoothing

    >>> mean = lambda seq: float(sum(seq)) / len(seq)
    >>> list(map(mean, sliding_window(2, [1, 2, 3, 4])))
    [1.5, 2.5, 3.5]

    Parameters
    ----------
    n : int
        Window length.
    seq : iterable
        Source sequence; consumed lazily.

    Returns
    -------
    iterator of tuple
        Consecutive length-``n`` windows. Yields nothing when ``seq``
        has fewer than ``n`` elements.
    """
    # Make n independent iterators over seq and advance the i-th one by
    # i elements, so zipping them yields consecutive windows. This
    # replaces the opaque `deque(islice(it, i), 0) or it` trick and the
    # redundant function-local imports (itertools is already imported at
    # module level).
    iterators = itertools.tee(seq, n)
    for skip, it in enumerate(iterators):
        for _ in itertools.islice(it, skip):
            pass
    return zip(*iterators)
| 36 | + |
| 37 | + |
def season_to_month_tuple(seasons: Sequence[str]) -> Sequence[Sequence[int]]:
    """Convert season strings to the month numbers they contain.

    Parameters
    ----------
    seasons : sequence of str
        Season abbreviations built from month initials, e.g.
        ``["DJF", "MAM", "JJA", "SON"]``. Together they must cover all
        12 months exactly once, in calendar order (starting from an
        arbitrary month).

    Returns
    -------
    list of list of int
        Month numbers (1-12) per season, e.g.
        ``[[12, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]``.

    Raises
    ------
    ValueError
        If the seasons do not contain exactly 12 months in total.
    """
    # Single letters that unambiguously identify a month.
    easy = {"D": 12, "F": 2, "S": 9, "O": 10, "N": 11}
    # Ambiguous letters (J, M, A) resolved by their 3-letter context:
    # the value is the month number of the MIDDLE letter.
    harder = {"DJF": 1, "FMA": 3, "MAM": 4, "AMJ": 5, "MJJ": 6, "JJA": 7, "JAS": 8}

    months = "".join(seasons)
    if len(months) != 12:
        # Message no longer names SeasonGrouper specifically: this helper
        # also backs SeasonResampler.
        raise ValueError("Custom seasons require exactly 12 months in total.")

    # Pad with the last and first month so every position has a 3-letter
    # window, wrapping around the year boundary.
    padded = months[-1] + months + months[0]

    # A 3-letter window is enough to unambiguously assign the month
    # number of the middle letter.
    month_numbers = []
    for pos in range(1, 13):
        letter = padded[pos]
        if letter in easy:
            month_numbers.append(easy[letter])
        else:
            month_numbers.append(harder[padded[pos - 1 : pos + 2]])

    # Split the flat list of month numbers back into per-season groups.
    season_inds = []
    start = 0
    for season in seasons:
        season_inds.append(month_numbers[start : start + len(season)])
        start += len(season)
    return season_inds
| 66 | + |
| 67 | + |
@dataclass
class SeasonGrouper(Grouper):
    """Allows grouping using a custom definition of seasons.

    Parameters
    ----------
    seasons: sequence of str
        List of strings representing seasons. E.g. ``"JF"`` or ``"JJA"`` etc.
    drop_incomplete: bool
        Whether to drop seasons that are not completely included in the data.
        For example, if a time series starts in Jan-2001, and seasons includes `"DJF"`
        then observations from Jan-2001, and Feb-2001 are ignored in the grouping
        since Dec-2000 isn't present.

    Examples
    --------
    >>> SeasonGrouper(["JF", "MAM", "JJAS", "OND"])
    >>> SeasonGrouper(["DJFM", "AM", "JJA", "SON"])
    """

    # User-provided season strings, e.g. ["DJF", "MAM", "JJA", "SON"].
    seasons: Sequence[str]
    # Month numbers for each season; derived in __post_init__.
    season_inds: Sequence[Sequence[int]] = field(init=False)
    # TODO: accepted and documented, but not yet honoured by factorize().
    drop_incomplete: bool = field(default=True)

    def __post_init__(self):
        self.season_inds = season_to_month_tuple(self.seasons)

    def __repr__(self):
        # Fixed: previously read ``self.grouper.seasons``, which raised
        # AttributeError — this class has no ``grouper`` attribute.
        return f"SeasonGrouper over {self.seasons!r}"

    def factorize(self, group):
        """Assign an integer season code to every element of ``group``.

        Parameters
        ----------
        group : DataArray-like
            Datetime variable being grouped; must support ``.dt.month``.

        Returns
        -------
        codes
            Copy of ``group`` holding the integer code per element
            (-1 for months not covered by any season), renamed "season".
        group_indices : list of np.ndarray
            Positional indices belonging to each season.
        unique_coord : IndexVariable
            The season labels.
        full_index : IndexVariable
            Same object as ``unique_coord``.

        Raises
        ------
        ValueError
            If no element matched any season.
        """
        seasons = self.seasons
        season_inds = self.season_inds

        months = group.dt.month
        # -1 marks elements whose month falls in no season.
        codes_ = np.full(group.shape, -1)
        # One slot per season; every slot is assigned in the loop below.
        # (Was ``[[]] * len(seasons)`` — a shared-list aliasing hazard.)
        group_indices = [None] * len(seasons)

        index = np.arange(group.size)
        for idx, season in enumerate(season_inds):
            mask = months.isin(season)
            codes_[mask] = idx
            group_indices[idx] = index[mask]

        if np.all(codes_ == -1):
            raise ValueError(
                "Failed to group data. Are you grouping by a variable that is all NaN?"
            )
        codes = group.copy(data=codes_).rename("season")
        unique_coord = IndexVariable("season", seasons, attrs=group.attrs)
        full_index = unique_coord
        return codes, group_indices, unique_coord, full_index
| 120 | + |
| 121 | + |
@dataclass
class SeasonResampler(Resampler):
    """Allows grouping using a custom definition of seasons.

    Each (year, season) combination becomes one resampling bin, labelled
    by a timestamp on the first day of the season's first month.

    Parameters
    ----------
    seasons : sequence of str
        Season abbreviations, e.g. ``"JF"`` or ``"JJA"``; together they
        must cover all 12 months.

    Examples
    --------
    >>> SeasonResampler(["JF", "MAM", "JJAS", "OND"])
    >>> SeasonResampler(["DJFM", "AM", "JJA", "SON"])
    """

    seasons: Sequence[str]
    # drop_incomplete: bool = field(default=True)  # TODO: not implemented yet
    # Month numbers per season, derived in __post_init__.
    season_inds: Sequence[Sequence[int]] = field(init=False)
    # season string -> month numbers lookup, derived in __post_init__.
    season_tuples: Mapping[str, Sequence[int]] = field(init=False)

    def __post_init__(self):
        self.season_inds = season_to_month_tuple(self.seasons)
        self.season_tuples = dict(zip(self.seasons, self.season_inds))

    def factorize(self, group):
        """Bin ``group`` into consecutive (year, season) groups.

        Parameters
        ----------
        group : DataArray-like
            One-dimensional datetime variable; must support ``.dt.year``
            and ``.dt.month`` and be in time order
            (the slice-based indices below assume contiguous bins).

        Returns
        -------
        codes
            Copy of ``group`` holding an integer bin code per element
            (-1 for elements belonging to a trimmed incomplete
            first/last season).
        group_indices : list of slice
            Positional slice for each bin, in time order.
        unique_coord_var : IndexVariable
            Timestamp label for each observed bin.
        full_index : pd.DatetimeIndex
            All expected bins between the first and last observed bin.

        Raises
        ------
        ValueError
            If a season is missing in the middle of the dataset.
        """
        assert group.ndim == 1

        seasons = self.seasons
        season_inds = self.season_inds
        season_tuples = self.season_tuples

        # Fixed-width string labels, wide enough for the longest season name.
        nstr = max(len(s) for s in seasons)
        year = group.dt.year.astype(int)
        month = group.dt.month.astype(int)
        season_label = np.full(group.shape, "", dtype=f"U{nstr}")

        # Offset years for seasons containing December and January:
        # such a season spans the year boundary, so the months after its
        # December are relabelled (in place) with December's year, keeping
        # the whole season in one (year, season) bin.
        for season_str, season_ind in zip(seasons, season_inds):
            season_label[month.isin(season_ind)] = season_str
            if "DJ" in season_str:
                after_dec = season_ind[season_str.index("D") + 1 :]
                year[month.isin(after_dec)] -= 1

        # Positional index of every element, keyed by (year, season).
        frame = pd.DataFrame(
            data={"index": np.arange(group.size), "month": month},
            index=pd.MultiIndex.from_arrays(
                [year.data, season_label], names=["year", "season"]
            ),
        )

        series = frame["index"]
        # sort=False keeps bins in order of first appearance, i.e. time order.
        g = series.groupby(["year", "season"], sort=False)
        first_items = g.first()
        counts = g.count()

        # these are the seasons that are present:
        # one timestamp per observed bin, on day 1 of the season's first month.
        unique_coord = pd.DatetimeIndex(
            [
                pd.Timestamp(year=year, month=season_tuples[season][0], day=1)
                for year, season in first_items.index
            ]
        )

        # Bin boundaries: each bin runs from its first element up to the
        # next bin's first element; the last bin runs to the end.
        sbins = first_items.values.astype(int)
        group_indices = [slice(i, j) for i, j in zip(sbins[:-1], sbins[1:])]
        group_indices += [slice(sbins[-1], None)]

        # Make sure the first and last timestamps are for the correct
        # months; if not, we have incomplete seasons at the edges and trim
        # them off (their elements get code -1 via unique_codes below).
        unique_codes = np.arange(len(unique_coord))
        for idx, slicer in zip([0, -1], (slice(1, None), slice(-1))):
            stamp_year, stamp_season = frame.index[idx]
            code = seasons.index(stamp_season)
            stamp_month = season_inds[code][idx]
            if stamp_month != month[idx].item():
                # we have an incomplete season!
                group_indices = group_indices[slicer]
                unique_coord = unique_coord[slicer]
                if idx == 0:
                    # shift every code down so the first kept bin is 0
                    unique_codes -= 1
                unique_codes[idx] = -1

        # all years and seasons spanned by the data
        complete_index = pd.DatetimeIndex(
            # This sorted call is a hack. It's hard to figure out how
            # to start the iteration
            sorted(
                [
                    pd.Timestamp(f"{y}-{m}-01")
                    for y, m in itertools.product(
                        range(year[0].item(), year[-1].item() + 1),
                        [s[0] for s in season_inds],
                    )
                ]
            )
        )
        # only keep the part included in the data
        range_ = complete_index.get_indexer(unique_coord[[0, -1]])
        full_index = complete_index[slice(range_[0], range_[-1] + 1)]
        # check that there are no "missing" seasons in the middle
        if not full_index.equals(unique_coord):
            raise ValueError("Are there seasons missing in the middle of the dataset?")

        # Expand per-bin codes to per-element codes; relies on counts
        # being in the same (time) order as unique_codes (sort=False above).
        codes = group.copy(data=np.repeat(unique_codes, counts))
        unique_coord_var = IndexVariable(group.name, unique_coord, group.attrs)

        return codes, group_indices, unique_coord_var, full_index
0 commit comments