Merged
2 changes: 1 addition & 1 deletion asv_bench/benchmarks/cohorts.py
@@ -93,7 +93,7 @@ def setup(self, *args, **kwargs):
ret = flox.core._factorize_multiple(
by,
expected_groups=(pd.Index(np.arange(1, 13)), pd.Index(np.arange(1, 25))),
by_is_dask=False,
any_by_dask=False,
reindex=False,
)
# Add one so the rechunk code is simpler and makes sense
60 changes: 43 additions & 17 deletions flox/core.py
@@ -176,6 +176,7 @@ def find_group_cohorts(labels, chunks, merge: bool = True):
axis = range(-labels.ndim, 0)
# Easier to create a dask array and use the .blocks property
array = dask.array.ones(tuple(sum(c) for c in chunks), chunks=chunks)
labels = np.broadcast_to(labels, array.shape[-labels.ndim :])

# Iterate over each block and create a new block of same shape with "chunk number"
shape = tuple(array.blocks.shape[ax] for ax in axis)
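A hedged aside on the np.broadcast_to line added above: broadcasting lets a labels array with size-1 (or fewer) dimensions line up with the trailing dimensions of the data before the blockwise cohort detection. A minimal sketch with made-up shapes, not flox API:

import numpy as np

# labels vary only along the last axis and carry a size-1 leading axis
array_shape = (4, 6)
labels = (np.arange(6) % 3).reshape(1, 6)           # shape (1, 6)

# expand the size-1 axis so there is one label per element of the data;
# broadcast_to returns a read-only view, so no data is copied
labels = np.broadcast_to(labels, array_shape[-labels.ndim:])
print(labels.shape)                                 # (4, 6)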
@@ -479,7 +480,7 @@ def factorize_(
idx, groups = pd.factorize(flat, sort=sort)

found_groups.append(np.array(groups))
factorized.append(idx)
factorized.append(idx.reshape(groupvar.shape))

grp_shape = tuple(len(grp) for grp in found_groups)
ngroups = math.prod(grp_shape)
@@ -489,20 +490,18 @@ def factorize_(
# Restore these after the raveling
nan_by_mask = reduce(np.logical_or, [(f == -1) for f in factorized])
group_idx[nan_by_mask] = -1
group_idx = group_idx.reshape(by[0].shape)
else:
group_idx = factorized[0]

if fastpath:
return group_idx.reshape(by[0].shape), found_groups, grp_shape
return group_idx, found_groups, grp_shape

if np.isscalar(axis) and groupvar.ndim > 1:
# Not reducing along all dimensions of by
# this is OK because for 3D by and axis=(1,2),
# we collapse to a 2D by and axis=-1
offset_group = True
group_idx, size = offset_labels(group_idx.reshape(by[0].shape), ngroups)
group_idx = group_idx.reshape(-1)
else:
size = ngroups
offset_group = False
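A minimal sketch of the multi-grouper factorization idea behind factorize_ (simplified; the real function also handles expected_groups, NaN codes of -1, and the offset_labels path). Values and shapes below are illustrative only:

import numpy as np
import pandas as pd

by1 = np.array(["a", "b", "a", "b", "a", "b"])
by2 = np.array([10, 10, 20, 20, 30, 30])

factorized, found_groups = [], []
for groupvar in (by1, by2):
    idx, groups = pd.factorize(groupvar, sort=True)
    found_groups.append(np.array(groups))
    factorized.append(idx.reshape(groupvar.shape))   # keep the grouper's shape, as in the diff

# collapse the per-grouper codes into one integer label per element,
# i.e. one label for every (by1, by2) combination
grp_shape = tuple(len(g) for g in found_groups)
group_idx = np.ravel_multi_index(factorized, grp_shape)
print(group_idx)                                     # [0 3 1 4 2 5]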
@@ -647,6 +646,8 @@ def chunk_reduce(
else:
nax = by.ndim

assert by.ndim <= array.ndim

final_array_shape = array.shape[:-nax] + (1,) * (nax - 1)
final_groups_shape = (1,) * (nax - 1)

@@ -667,9 +668,17 @@ def chunk_reduce(
)
groups = groups[0]

if isinstance(axis, Sequence):
needs_broadcast = any(
group_idx.shape[ax] != array.shape[ax] and group_idx.shape[ax] == 1
for ax in range(-len(axis), 0)
)
if needs_broadcast:
group_idx = np.broadcast_to(group_idx, array.shape[-by.ndim :])
# always reshape to 1D along group dimensions
newshape = array.shape[: array.ndim - by.ndim] + (math.prod(array.shape[-by.ndim :]),)
array = array.reshape(newshape)
group_idx = group_idx.reshape(-1)

assert group_idx.ndim == 1
empty = np.all(props.nanmask)
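A hedged sketch (with made-up shapes) of the shape bookkeeping added to chunk_reduce above: a size-1 grouper dimension is broadcast up to the array's shape before the grouped axes are flattened to 1D:

import math
import numpy as np

array = np.arange(24, dtype=float).reshape(2, 4, 3)   # grouped over the last two axes
group_idx = np.array([[0, 1, 0]])                      # shape (1, 3): size-1 along axis -2
by_ndim = 2

# broadcast the size-1 axis so every reduced element has a label ...
group_idx = np.broadcast_to(group_idx, array.shape[-by_ndim:])     # (4, 3)

# ... then flatten the grouped axes so the reduction sees 1D labels
newshape = array.shape[: array.ndim - by_ndim] + (math.prod(array.shape[-by_ndim:]),)
array = array.reshape(newshape)                        # (2, 12)
group_idx = group_idx.reshape(-1)                      # (12,)
assert group_idx.ndim == 1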
@@ -1220,7 +1229,9 @@ def dask_groupby_agg(
# chunk numpy arrays like the input array
# This removes an extra rechunk-merge layer that would be
# added otherwise
by = dask.array.from_array(by, chunks=tuple(array.chunks[ax] for ax in range(-by.ndim, 0)))
chunks = tuple(array.chunks[ax] if by.shape[ax] != 1 else (1,) for ax in range(-by.ndim, 0))

by = dask.array.from_array(by, chunks=chunks)
_, (array, by) = dask.array.unify_chunks(array, inds, by, inds[-by.ndim :])

# preprocess the array: for argreductions, this zips the index together with the array block
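A rough illustration of the chunking rule added above, with toy shapes: dimensions of `by` that are size 1 get a single chunk of length 1, and everything else is chunked exactly like the data, so the later unify_chunks call has nothing to rechunk along matching axes. This is a sketch, not flox API:

import numpy as np
import dask.array

array = dask.array.ones((4, 6), chunks=(2, 3))
by = np.zeros((1, 6))                                  # size-1 along the first grouped axis

chunks = tuple(
    array.chunks[ax] if by.shape[ax] != 1 else (1,)
    for ax in range(-by.ndim, 0)
)
by = dask.array.from_array(by, chunks=chunks)
print(by.chunks)                                       # ((1,), (3, 3))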
@@ -1396,7 +1407,7 @@ def dask_groupby_agg(


def _validate_reindex(
reindex: bool | None, func, method: T_Method, expected_groups, by_is_dask: bool
reindex: bool | None, func, method: T_Method, expected_groups, any_by_dask: bool
) -> bool:
if reindex is True:
if _is_arg_reduction(func):
@@ -1414,7 +1425,7 @@ def _validate_reindex(
reindex = False

elif method == "map-reduce":
if expected_groups is None and by_is_dask:
if expected_groups is None and any_by_dask:
reindex = False
else:
reindex = True
@@ -1424,8 +1435,9 @@


def _assert_by_is_aligned(shape, by):
assert all(b.ndim == by[0].ndim for b in by[1:])
for idx, b in enumerate(by):
if shape[-b.ndim :] != b.shape:
if not all(j in [i, 1] for i, j in zip(shape[-b.ndim :], b.shape)):
raise ValueError(
"`array` and `by` arrays must be aligned "
"i.e. array.shape[-by.ndim :] == by.shape. "
@@ -1462,26 +1474,34 @@ def _lazy_factorize_wrapper(*by, **kwargs):
return group_idx


def _factorize_multiple(by, expected_groups, by_is_dask, reindex):
def _factorize_multiple(by, expected_groups, any_by_dask, reindex):
kwargs = dict(
expected_groups=expected_groups,
axis=None, # always None, we offset later if necessary.
fastpath=True,
reindex=reindex,
)
if by_is_dask:
if any_by_dask:
import dask.array

# unifying chunks will make sure all arrays in `by` are dask arrays
# with compatible chunks, even if there was originally a numpy array
inds = tuple(range(by[0].ndim))
chunks, by_ = dask.array.unify_chunks(*itertools.chain(*zip(by, (inds,) * len(by))))

group_idx = dask.array.map_blocks(
_lazy_factorize_wrapper,
*np.broadcast_arrays(*by),
*by_,
chunks=tuple(chunks.values()),
meta=np.array((), dtype=np.int64),
**kwargs,
)
found_groups = tuple(
None if is_duck_dask_array(b) else pd.unique(b.reshape(-1)) for b in by
)
grp_shape = tuple(len(e) for e in expected_groups)
grp_shape = tuple(
len(e) if e is not None else len(f) for e, f in zip(expected_groups, found_groups)
)
else:
group_idx, found_groups, grp_shape = factorize_(by, **kwargs)
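A simplified sketch of the dask branch above: unify_chunks coerces every `by` to a dask array with compatible chunks (the comment in the diff notes this also covers numpy inputs), after which the factorization can be mapped over blocks. Names and shapes below are illustrative:

import numpy as np
import dask.array

by1 = dask.array.from_array(np.array([0, 0, 1, 1, 2, 2]), chunks=3)
by2 = np.array([10, 20, 10, 20, 10, 20])               # plain numpy on purpose

inds = tuple(range(by1.ndim))                          # (0,): shared axis labels
chunkss, (b1, b2) = dask.array.unify_chunks(by1, inds, by2, inds)
print(type(b2).__name__, b2.chunks)                    # both are dask arrays, chunks ((3, 3),)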

@@ -1611,15 +1631,16 @@ def groupby_reduce(

bys = tuple(np.asarray(b) if not is_duck_array(b) else b for b in by)
nby = len(bys)
by_is_dask = any(is_duck_dask_array(b) for b in bys)
by_is_dask = tuple(is_duck_dask_array(b) for b in bys)
any_by_dask = any(by_is_dask)

if method in ["split-reduce", "cohorts"] and by_is_dask:
if method in ["split-reduce", "cohorts"] and any_by_dask:
raise ValueError(f"method={method!r} can only be used when grouping by numpy arrays.")

if method == "split-reduce":
method = "cohorts"

reindex = _validate_reindex(reindex, func, method, expected_groups, by_is_dask)
reindex = _validate_reindex(reindex, func, method, expected_groups, any_by_dask)

if not is_duck_array(array):
array = np.asarray(array)
@@ -1634,6 +1655,11 @@
expected_groups = (None,) * nby

_assert_by_is_aligned(array.shape, bys)
for idx, (expect, is_dask) in enumerate(zip(expected_groups, by_is_dask)):
if is_dask and (reindex or nby > 1) and expect is None:
raise ValueError(
f"`expected_groups` for array {idx} in `by` cannot be None since it is a dask.array."
)

if nby == 1 and not isinstance(expected_groups, tuple):
expected_groups = (np.asarray(expected_groups),)
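To restate the relaxed alignment rule checked by _assert_by_is_aligned above (a hedged paraphrase, not the flox implementation): each trailing dimension of a `by` array must either match the corresponding array dimension or be of size 1 so it can broadcast:

def is_aligned(array_shape, by_shape):
    # True if each `by` dim equals the matching array dim or is 1
    return all(b in (a, 1) for a, b in zip(array_shape[-len(by_shape):], by_shape))

print(is_aligned((10, 4, 6), (4, 6)))   # True  - exact match
print(is_aligned((10, 4, 6), (1, 6)))   # True  - size-1 dim broadcasts
print(is_aligned((10, 4, 6), (3, 6)))   # False - mismatched dim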
@@ -1653,7 +1679,7 @@
)
if factorize_early:
bys, final_groups, grp_shape = _factorize_multiple(
bys, expected_groups, by_is_dask=by_is_dask, reindex=reindex
bys, expected_groups, any_by_dask=any_by_dask, reindex=reindex
)
expected_groups = (pd.RangeIndex(math.prod(grp_shape)),)

@@ -1676,7 +1702,7 @@

# TODO: make sure expected_groups is unique
if nax == 1 and by_.ndim > 1 and expected_groups is None:
if not by_is_dask:
if not any_by_dask:
expected_groups = _get_expected_groups(by_, sort)
else:
# When we reduce along all axes, we are guaranteed to see all
73 changes: 39 additions & 34 deletions flox/xarray.py
@@ -26,20 +26,6 @@
Dims = Union[str, Iterable[Hashable], None]


def _get_input_core_dims(group_names, dim, ds, grouper_dims):
input_core_dims = [[], []]
for g in group_names:
if g in dim:
continue
if g in ds.dims:
input_core_dims[0].extend([g])
if g in grouper_dims:
input_core_dims[1].extend([g])
input_core_dims[0].extend(dim)
input_core_dims[1].extend(dim)
return input_core_dims


def _restore_dim_order(result, obj, by):
def lookup_order(dimension):
if dimension == by.name and by.ndim == 1:
@@ -54,6 +40,26 @@ def lookup_order(dimension):
return result.transpose(*new_order)


def _broadcast_size_one_dims(*arrays, core_dims):
"""Broadcast by adding size-1 dimensions in the right place.

Workaround because apply_ufunc doesn't support this yet.
https://github.com/pydata/xarray/issues/3032#issuecomment-503337637

Specialized to the groupby problem.
"""
array_dims = set(core_dims[0])
broadcasted = [arrays[0]]
for dims, array in zip(core_dims[1:], arrays[1:]):
assert set(dims).issubset(array_dims)
order = [dims.index(d) for d in core_dims[0] if d in dims]
array = array.transpose(*order)
axis = [core_dims[0].index(d) for d in core_dims[0] if d not in dims]
broadcasted.append(np.expand_dims(array, axis))

return broadcasted
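A rough usage sketch of the new _broadcast_size_one_dims idea, with made-up dimension names: a `by` array that lacks one of the data's core dims gets a size-1 axis inserted at the right position, which then broadcasts inside the reduction without a full xr.broadcast copy:

import numpy as np

core_dims = [["y", "x"], ["x"]]        # data core dims, then dims of one `by`
data = np.ones((3, 4))
by = np.arange(4)                      # only has the "x" dimension

dims = core_dims[1]
axis = [core_dims[0].index(d) for d in core_dims[0] if d not in dims]
by = np.expand_dims(by, axis)          # shape (1, 4)
print((data + by).shape)               # (3, 4): the size-1 axis broadcasts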


def xarray_reduce(
obj: T_Dataset | T_DataArray,
*by: T_DataArray | Hashable,
@@ -255,20 +261,11 @@ def xarray_reduce(
elif dim is not None:
dim_tuple = _atleast_1d(dim)
else:
dim_tuple = tuple()
dim_tuple = tuple(grouper_dims)

# broadcast all variables against each other along all dimensions in `by` variables
# don't exclude `dim` because it need not be a dimension in any of the `by` variables!
# in the case where dim is Ellipsis, and by.ndim < obj.ndim
# then we also broadcast `by` to all `obj.dims`
# TODO: avoid this broadcasting
# broadcast to make sure grouper dimensions are present in the array.
exclude_dims = tuple(d for d in ds.dims if d not in grouper_dims and d not in dim_tuple)
ds_broad, *by_broad = xr.broadcast(ds, *by_da, exclude=exclude_dims)

# all members of by_broad have the same dimensions
# so we just pull by_broad[0].dims if dim is None
if not dim_tuple:
dim_tuple = tuple(by_broad[0].dims)
ds_broad = xr.broadcast(ds, *by_da, exclude=exclude_dims)[0]

if any(d not in grouper_dims and d not in obj.dims for d in dim_tuple):
raise ValueError(f"Cannot reduce over absent dimensions {dim}.")
@@ -298,7 +295,7 @@ def wrapper(array, *by, func, skipna, **kwargs):
expected_groups = list(expected_groups)
group_names: tuple[Any, ...] = ()
group_sizes: dict[Any, int] = {}
for idx, (b_, expect, isbin_) in enumerate(zip(by_broad, expected_groups, isbins)):
for idx, (b_, expect, isbin_) in enumerate(zip(by_da, expected_groups, isbins)):
group_name = b_.name if not isbin_ else f"{b_.name}_bins"
group_names += (group_name,)

@@ -326,7 +323,10 @@ def wrapper(array, *by, func, skipna, **kwargs):
# This will never be reached
raise ValueError("expect_index cannot be None")

def wrapper(array, *by, func, skipna, **kwargs):
def wrapper(array, *by, func, skipna, core_dims, **kwargs):

array, *by = _broadcast_size_one_dims(array, *by, core_dims=core_dims)

# Handle skipna here because I need to know dtype to make a good default choice.
# We cannot handle this easily for xarray Datasets in xarray_reduce
if skipna and func in ["all", "any", "count"]:
@@ -374,17 +374,21 @@ def wrapper(array, *by, func, skipna, **kwargs):
if is_missing_dim:
missing_dim[k] = v

input_core_dims = _get_input_core_dims(group_names, dim_tuple, ds_broad, grouper_dims)
input_core_dims += [input_core_dims[-1]] * (nby - 1)
# dim_tuple contains dimensions we are reducing over. These need to be the last
# core dimensions to be synchronized with axis.
input_core_dims = [[d for d in grouper_dims if d not in dim_tuple] + list(dim_tuple)]
input_core_dims += [list(b.dims) for b in by_da]

output_core_dims = [d for d in input_core_dims[0] if d not in dim_tuple]
output_core_dims.extend(group_names)
actual = xr.apply_ufunc(
wrapper,
ds_broad.drop_vars(tuple(missing_dim)).transpose(..., *grouper_dims),
*by_broad,
*by_da,
input_core_dims=input_core_dims,
# for xarray's test_groupby_duplicate_coordinate_labels
exclude_dims=set(dim_tuple),
output_core_dims=[group_names],
output_core_dims=[output_core_dims],
dask="allowed",
dask_gufunc_kwargs=dict(
output_sizes=group_sizes, output_dtypes=[dtype] if dtype is not None else None
@@ -404,6 +408,7 @@ def wrapper(array, *by, func, skipna, **kwargs):
"isbin": isbins,
"finalize_kwargs": finalize_kwargs,
"dtype": dtype,
"core_dims": input_core_dims,
},
)
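A hedged, concrete illustration of the core-dim layout constructed above (dimension and group names are made up): the dimensions being reduced over are placed last in input_core_dims so they line up with axis inside flox, and the output keeps the surviving grouper dims plus one new dim per grouper:

grouper_dims = ["y", "x"]                       # dims spanned by the `by` variables
dim_tuple = ("x",)                              # dims being reduced over
by_dims = [("x",), ("y", "x")]                  # dims of each `by` DataArray
group_names = ("labels1", "labels2_bins")       # hypothetical grouper names

input_core_dims = [[d for d in grouper_dims if d not in dim_tuple] + list(dim_tuple)]
input_core_dims += [list(b) for b in by_dims]
output_core_dims = [d for d in input_core_dims[0] if d not in dim_tuple]
output_core_dims.extend(group_names)

print(input_core_dims)    # [['y', 'x'], ['x'], ['y', 'x']]
print(output_core_dims)   # ['y', 'labels1', 'labels2_bins']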

@@ -413,7 +418,7 @@ def wrapper(array, *by, func, skipna, **kwargs):
if all(d not in ds_broad[var].dims for d in dim_tuple):
actual[var] = ds_broad[var]

for name, expect, by_ in zip(group_names, expected_groups, by_broad):
for name, expect, by_ in zip(group_names, expected_groups, by_da):
# Can't remove this till xarray handles IntervalIndex
if isinstance(expect, pd.IntervalIndex):
expect = expect.to_numpy()
Expand Down Expand Up @@ -443,7 +448,7 @@ def wrapper(array, *by, func, skipna, **kwargs):
template = obj

if actual[var].ndim > 1:
actual[var] = _restore_dim_order(actual[var], template, by_broad[0])
actual[var] = _restore_dim_order(actual[var], template, by_da[0])

if missing_dim:
for k, v in missing_dim.items():