
More efficient cohorts. #165


Merged
merged 16 commits on Oct 11, 2022
2 changes: 1 addition & 1 deletion docs/source/implementation.md
@@ -13,7 +13,7 @@ or `xarray_reduce`.

First we describe xarray's current strategy

## `method="split-reduce"`: Xarray's current GroupBy strategy
## Background: Xarray's current GroupBy strategy

Xarray's current strategy is to find all unique group labels, index out each group,
and then apply the reduction operation. Note that this only works if we know the group
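For illustration only (not code from this PR), the "find unique labels, index out each group, reduce" strategy described above boils down to a loop like this pure-numpy sketch; the labels and values are made up:

```python
import numpy as np

# Minimal sketch of the strategy described above: find all unique group
# labels, index out each group, then apply the reduction. Illustrative only.
array = np.arange(12).reshape(1, 12)
labels = np.array(["a", "b", "a", "c", "b", "a", "c", "a", "b", "c", "a", "b"])

results = {}
for group in np.unique(labels):            # 1. find all unique group labels
    subset = array[..., labels == group]   # 2. index out each group
    results[group] = subset.sum(axis=-1)   # 3. apply the reduction

for group, res in results.items():
    print(group, res)                      # a [24], b [24], c [18]
```

The `"cohorts"` strategy this PR optimizes instead groups together labels that tend to occur in the same chunks, so each cohort can be reduced over only the blocks that contain it.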
239 changes: 149 additions & 90 deletions flox/core.py
@@ -137,7 +137,7 @@ def _get_optimal_chunks_for_groups(chunks, labels):


@memoize
def find_group_cohorts(labels, chunks, merge=True, method: T_MethodCohorts = "cohorts"):
def find_group_cohorts(labels, chunks, merge: bool = True):
"""
Finds group labels that occur together aka "cohorts"

@@ -167,9 +167,6 @@ find_group_cohorts(labels, chunks, merge=True, method: T_MethodCohorts = "co
# To do this, we must have values in memory so casting to numpy should be safe
labels = np.asarray(labels)

if method == "split-reduce":
return list(_get_expected_groups(labels, sort=False).to_numpy().reshape(-1, 1))

# Build an array with the shape of labels, but where every element is the "chunk number"
# 1. First subset the array appropriately
axis = range(-labels.ndim, 0)
@@ -195,7 +192,7 @@ def find_group_cohorts(labels, chunks, merge=True, method: T_MethodCohorts = "co
if merge:
# First sort by number of chunks occupied by cohort
sorted_chunks_cohorts = dict(
reversed(sorted(chunks_cohorts.items(), key=lambda kv: len(kv[0])))
sorted(chunks_cohorts.items(), key=lambda kv: len(kv[0]), reverse=True)
)

items = tuple(sorted_chunks_cohorts.items())
Expand All @@ -218,9 +215,15 @@ def find_group_cohorts(labels, chunks, merge=True, method: T_MethodCohorts = "co
merged_cohorts[k1].extend(v2)
merged_keys.append(k2)

return merged_cohorts.values()
# make sure each cohort is sorted after merging
sorted_merged_cohorts = {k: sorted(v) for k, v in merged_cohorts.items()}
# sort by first label in cohort
# This will help when sort=True (default)
# and we have to resort the dask array
return dict(sorted(sorted_merged_cohorts.items(), key=lambda kv: kv[1][0]))

else:
return chunks_cohorts.values()
return chunks_cohorts
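For orientation, a hedged sketch of how the reworked `find_group_cohorts` can be called and what the returned mapping looks like; the labels and chunking are invented, and the exact keys and ordering depend on the merge heuristic above:

```python
import numpy as np
from flox.core import find_group_cohorts

# Hypothetical labels: a repeating 0..11 cycle (think month-of-year),
# chunked 4 elements at a time, so each label always lands in the same
# small set of chunks.
labels = np.tile(np.arange(12), 4)   # 48 labels
chunks = [(4,) * 12]                 # 12 chunks of size 4 along the grouped axis

cohorts = find_group_cohorts(labels, chunks, merge=True)
# With this PR the result is a dict mapping tuples of flat chunk indices to
# the sorted cohort of labels found in those chunks, roughly:
#   {(0, 3, 6, 9): [0, 1, 2, 3],
#    (1, 4, 7, 10): [4, 5, 6, 7],
#    (2, 5, 8, 11): [8, 9, 10, 11]}
# Previously only the cohort lists (the dict values) were returned.
for blocks, cohort in cohorts.items():
    print(blocks, cohort)
```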


def rechunk_for_cohorts(
@@ -1079,6 +1082,63 @@ def _reduce_blockwise(
return result


def subset_to_blocks(
array: DaskArray, flatblocks: Sequence[int], blkshape: tuple[int] | None = None
) -> DaskArray:
"""
Advanced indexing of .blocks such that we always get a regular array back.

Parameters
----------
array : dask.array
flatblocks : flat indices of blocks to extract
blkshape : shape of blocks with which to unravel flatblocks

Returns
-------
dask.array
"""
if blkshape is None:
blkshape = array.blocks.shape

unraveled = np.unravel_index(flatblocks, blkshape)
normalized: list[Union[int, np.ndarray, slice]] = []
for ax, idx in enumerate(unraveled):
i = np.unique(idx).squeeze()
if i.ndim == 0:
normalized.append(i.item())
else:
if np.array_equal(i, np.arange(blkshape[ax])):
normalized.append(slice(None))
elif np.array_equal(i, np.arange(i[0], i[-1] + 1)):
normalized.append(slice(i[0], i[-1] + 1))
else:
normalized.append(i)
full_normalized = (slice(None),) * (array.ndim - len(normalized)) + tuple(normalized)

# has no iterables
noiter = tuple(i if not hasattr(i, "__len__") else slice(None) for i in full_normalized)
# has all iterables
alliter = {
ax: i if hasattr(i, "__len__") else slice(None) for ax, i in enumerate(full_normalized)
}

# apply everything but the iterables
if all(i == slice(None) for i in noiter):
return array

subset = array.blocks[noiter]

for ax, inds in alliter.items():
if isinstance(inds, slice):
continue
idxr = [slice(None, None)] * array.ndim
idxr[ax] = inds
subset = subset.blocks[tuple(idxr)]

return subset
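A hedged usage sketch for the new `subset_to_blocks` helper; the shapes and block indices below are made up, but they show the "regular array" behaviour the docstring promises:

```python
import dask.array as da
from flox.core import subset_to_blocks  # helper added in this PR

# 8x8 array in 2x2 chunks -> a 4x4 grid of blocks (16 flat block indices).
x = da.ones((8, 8), chunks=(2, 2))
print(x.blocks.shape)  # (4, 4)

# Flat block indices 0, 1, 4, 5 unravel to block rows {0, 1} and block
# columns {0, 1}, so the selection collapses to two contiguous slices and
# the result stays a regular (non-ragged) dask array.
sub = subset_to_blocks(x, [0, 1, 4, 5], x.blocks.shape)
print(sub.shape)  # (4, 4): the top-left 2x2 group of blocks
```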


def _extract_unknown_groups(reduced, group_chunks, dtype) -> tuple[DaskArray]:
import dask.array
from dask.highlevelgraph import HighLevelGraph
@@ -1115,6 +1175,7 @@ def dask_groupby_agg(
reindex: bool = False,
engine: T_Engine = "numpy",
sort: bool = True,
chunks_cohorts=None,
) -> tuple[DaskArray, tuple[np.ndarray | DaskArray]]:

import dask.array
@@ -1194,7 +1255,7 @@ def dask_groupby_agg(
partial(
blockwise_method,
axis=axis,
expected_groups=expected_groups,
expected_groups=None if method in ["split-reduce", "cohorts"] else expected_groups,
engine=engine,
sort=sort,
),
@@ -1223,43 +1284,77 @@ def dask_groupby_agg(
expected_groups = _get_expected_groups(by_input, sort=sort)
group_chunks = ((len(expected_groups),) if expected_groups is not None else (np.nan,),)

if method == "map-reduce":
if method in ["map-reduce", "cohorts", "split-reduce"]:
combine: Callable[..., IntermediateDict]
if do_simple_combine:
combine = _simple_combine
else:
combine = partial(_grouped_combine, engine=engine, sort=sort)

# reduced is really a dict mapping reduction name to array
# and "groups" to an array of group labels
# Each chunk of `reduced` is really a dict mapping
# 1. reduction name to array
# 2. "groups" to an array of group labels
# Note: it does not make sense to interpret axis relative to
# shape of intermediate results after the blockwise call
reduced = dask.array.reductions._tree_reduce(
intermediate,
aggregate=partial(
_aggregate,
combine=combine,
agg=agg,
expected_groups=None if split_out > 1 else expected_groups,
fill_value=fill_value,
reindex=reindex,
),
tree_reduce = partial(
dask.array.reductions._tree_reduce,
combine=partial(combine, agg=agg),
name=f"{name}-reduce",
name=f"{name}-reduce-{method}",
dtype=array.dtype,
axis=axis,
keepdims=True,
concatenate=False,
)

if is_duck_dask_array(by_input) and expected_groups is None:
groups = _extract_unknown_groups(reduced, group_chunks=group_chunks, dtype=by.dtype)
else:
if expected_groups is None:
expected_groups_ = _get_expected_groups(by_input, sort=sort)
aggregate = partial(
_aggregate, combine=combine, agg=agg, fill_value=fill_value, reindex=reindex
)
if method == "map-reduce":
reduced = tree_reduce(
intermediate,
aggregate=partial(
aggregate, expected_groups=None if split_out > 1 else expected_groups
),
)
if is_duck_dask_array(by_input) and expected_groups is None:
groups = _extract_unknown_groups(reduced, group_chunks=group_chunks, dtype=by.dtype)
else:
expected_groups_ = expected_groups
groups = (expected_groups_.to_numpy(),)
if expected_groups is None:
expected_groups_ = _get_expected_groups(by_input, sort=sort)
else:
expected_groups_ = expected_groups
groups = (expected_groups_.to_numpy(),)

elif method in ["cohorts", "split-reduce"]:
chunks_cohorts = find_group_cohorts(
by_input, [array.chunks[ax] for ax in axis], merge=True
)
reduced_ = []
groups_ = []
for blks, cohort in chunks_cohorts.items():
subset = subset_to_blocks(intermediate, blks, array.blocks.shape[-len(axis) :])
if do_simple_combine:
# reindex so that reindex can be set to True later
reindexed = dask.array.map_blocks(
reindex_intermediates,
subset,
agg=agg,
unique_groups=cohort,
meta=subset._meta,
)
else:
reindexed = subset

reduced_.append(
tree_reduce(
reindexed,
aggregate=partial(aggregate, expected_groups=cohort, reindex=reindex),
)
)
groups_.append(cohort)

reduced = dask.array.concatenate(reduced_, axis=-1)
groups = (np.concatenate(groups_),)
group_chunks = (tuple(len(cohort) for cohort in groups_),)

elif method == "blockwise":
reduced = intermediate
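Conceptually, the new `"cohorts"` branch reduces each cohort over only the blocks that contain it and then concatenates the per-cohort results along the group axis. A rough pure-numpy analogue of that control flow (illustrative only, not the dask graph flox actually builds):

```python
import numpy as np

# Invented labels/values; `cohorts` stands in for find_group_cohorts output.
labels = np.tile(np.arange(4), 6)
values = np.arange(labels.size, dtype=float)
cohorts = [[0, 1], [2, 3]]

reduced_, groups_ = [], []
for cohort in cohorts:
    # reduce each group belonging to this cohort, ignoring everything else
    reduced_.append(np.stack([values[labels == g].sum() for g in cohort]))
    groups_.extend(cohort)

result = np.concatenate(reduced_)   # one value per group, cohort by cohort
groups = np.array(groups_)
print(groups, result)               # [0 1 2 3] [60. 66. 72. 78.]
```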
@@ -1297,7 +1392,11 @@ def dask_groupby_agg(
nblocks = tuple(len(array.chunks[ax]) for ax in axis)
inchunk = ochunk[:-1] + np.unravel_index(ochunk[-1], nblocks)
else:
inchunk = ochunk[:-1] + (0,) * len(axis) + (ochunk[-1],) * int(split_out > 1)
inchunk = ochunk[:-1] + (0,) * (len(axis) - 1)
if split_out > 1:
inchunk = inchunk + (0,)
inchunk = inchunk + (ochunk[-1],)

layer2[(agg_name, *ochunk)] = (operator.getitem, (reduced.name, *inchunk), agg.name)

result = dask.array.Array(
@@ -1326,6 +1425,9 @@ def _validate_reindex(reindex: bool | None, func, method: T_Method, expected_gro
if method in ["split-reduce", "cohorts"] and reindex is False:
raise NotImplementedError

if method in ["split-reduce", "cohorts"] and reindex is None:
reindex = True

# TODO: Should reindex be a bool-only at this point? Would've been nice but
# None's are relied on after this function as well.
return reindex
@@ -1480,9 +1582,7 @@ def groupby_reduce(
method by first rechunking using ``rechunk_for_cohorts``
(for 1D ``by`` only).
* ``"split-reduce"``:
Break out each group into its own array and then ``"map-reduce"``.
This is implemented by having each group be its own cohort,
and is identical to xarray's default strategy.
Same as "cohorts" and will be removed soon.
engine : {"flox", "numpy", "numba"}, optional
Algorithm to compute the groupby reduction on non-dask arrays and on each dask chunk:
* ``"numpy"``:
@@ -1652,67 +1752,26 @@ def groupby_reduce(

partial_agg = partial(dask_groupby_agg, split_out=split_out, **kwargs)

if method in ["split-reduce", "cohorts"]:
cohorts = find_group_cohorts(
by_, [array.chunks[ax] for ax in axis_], merge=True, method=method
)

results_ = []
groups_ = []
for cohort in cohorts:
cohort = sorted(cohort)
# equivalent of xarray.DataArray.where(mask, drop=True)
mask = np.isin(by_, cohort)
indexer = [np.unique(v) for v in np.nonzero(mask)]
array_subset = array
for ax, idxr in zip(range(-by_.ndim, 0), indexer):
array_subset = np.take(array_subset, idxr, axis=ax)
numblocks = math.prod([len(array_subset.chunks[ax]) for ax in axis_])

# get final result for these groups
r, *g = partial_agg(
array_subset,
by_[np.ix_(*indexer)],
expected_groups=pd.Index(cohort),
# First deep copy because we might be doing blockwise,
# which sets agg.finalize=None, then map-reduce (GH102)
agg=copy.deepcopy(agg),
# reindex to expected_groups at the blockwise step.
# this approach avoids replacing non-cohort members with
# np.nan or some other sentinel value, and preserves dtypes
reindex=True,
# sort controls the final output order so apply that at the end
sort=False,
# if only a single block along axis, we can just work blockwise
# inspired by https://github.com/dask/dask/issues/8361
method="blockwise" if numblocks == 1 and nax == by_.ndim else "map-reduce",
)
results_.append(r)
groups_.append(cohort)
if method == "blockwise" and by_.ndim == 1:
array = rechunk_for_blockwise(array, axis=-1, labels=by_)

# concatenate results together,
# sort to make sure we match expected output
groups = (np.hstack(groups_),)
result = np.concatenate(results_, axis=-1)
else:
if method == "blockwise" and by_.ndim == 1:
array = rechunk_for_blockwise(array, axis=-1, labels=by_)

result, groups = partial_agg(
array,
by_,
expected_groups=None if method == "blockwise" else expected_groups,
agg=agg,
reindex=reindex,
method=method,
sort=sort,
)
result, groups = partial_agg(
array,
by_,
expected_groups=None if method == "blockwise" else expected_groups,
agg=agg,
reindex=reindex,
method=method,
sort=sort,
)

if sort and method != "map-reduce":
assert len(groups) == 1
sorted_idx = np.argsort(groups[0])
result = result[..., sorted_idx]
groups = (groups[0][sorted_idx],)
# This optimization helps specifically with resampling
if not (sorted_idx[:-1] <= sorted_idx[1:]).all():
result = result[..., sorted_idx]
groups = (groups[0][sorted_idx],)
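The intent of the check above: if the group labels already come out in ascending order (the common case for resampling), `argsort` is the identity permutation and the final fancy-indexing pass can be skipped. A small illustration with made-up group labels:

```python
import numpy as np

groups = np.array([3, 7, 10, 12])                   # already ascending
sorted_idx = np.argsort(groups)
print((sorted_idx[:-1] <= sorted_idx[1:]).all())    # True  -> skip the gather

groups = np.array([10, 3, 12, 7])                   # out of order
sorted_idx = np.argsort(groups)
print((sorted_idx[:-1] <= sorted_idx[1:]).all())    # False -> apply sorted_idx
```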

if factorize_early:
# nan group labels are factorized to -1, and preserved
4 changes: 2 additions & 2 deletions flox/visualize.py
@@ -136,10 +136,10 @@ def visualize_cohorts_2d(by, array, method="cohorts"):
print("finding cohorts...")
before_merged = find_group_cohorts(
by, [array.chunks[ax] for ax in range(-by.ndim, 0)], merge=False, method=method
)
).values()
merged = find_group_cohorts(
by, [array.chunks[ax] for ax in range(-by.ndim, 0)], merge=True, method=method
)
).values()
print("finished cohorts...")

xticks = np.cumsum(array.chunks[-1])
4 changes: 1 addition & 3 deletions flox/xarray.py
@@ -126,9 +126,7 @@ def xarray_reduce(
method by first rechunking using ``rechunk_for_cohorts``
(for 1D ``by`` only).
* ``"split-reduce"``:
Break out each group into its own array and then ``"map-reduce"``.
This is implemented by having each group be its own cohort,
and is identical to xarray's default strategy.
Same as "cohorts" and will be removed soon.
engine : {"flox", "numpy", "numba"}, optional
Algorithm to compute the groupby reduction on non-dask arrays and on each dask chunk:
* ``"numpy"``: