
Commit 90393df

tomwhite, pre-commit-ci[bot], and dcherian authored
Initial minimal working Cubed example for "map-reduce" (#352)
* Initial minimal working Cubed example for "map-reduce"
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Fix misspelled `aggegrate_func`
* Update flox/core.py
  Co-authored-by: Deepak Cherian <[email protected]>
* Expand to ALL_FUNCS
* Use `_finalize_results` directly
* Add test for nan values
* Removed unused dtype from test
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Move example notebook to a gist: https://gist.github.com/tomwhite/2d637d2581b44468da5b7e29c30c0c49
* Add CubedArray type
* Add Cubed to CI
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Make mypy happy
* Make mypy happy (again)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Deepak Cherian <[email protected]>
1 parent 603ad2c · commit 90393df

File tree

6 files changed: +203 −5 lines changed

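In practical terms, this commit lets `groupby_reduce` accept a `cubed.Array` directly. A minimal sketch of the enabled usage, assuming `cubed>=0.14.2` is installed (the shape, chunks, and labels mirror the new test added below):

```python
# Minimal sketch of the usage this commit enables; assumes cubed>=0.14.2
# and this version of flox are installed.
import numpy as np
import cubed
from flox.core import groupby_reduce

array = cubed.array_api.ones((12,), chunks=(3,))  # a chunked cubed array
labels = np.array([0, 0, 2, 2, 2, 1, 1, 2, 2, 1, 1, 0])

# Only method="map-reduce" is implemented for cubed arrays in this commit.
result, groups = groupby_reduce(
    array, labels, func="sum", expected_groups=[0, 1, 2], method="map-reduce"
)
print(result.compute())  # expected: [3. 4. 5.] — group sizes, since the array is all ones
```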

ci/environment.yml

Lines changed: 1 addition & 0 deletions

```diff
@@ -6,6 +6,7 @@ dependencies:
   - cachey
   - cftime
   - codecov
+  - cubed>=0.14.2
   - dask-core
   - pandas
   - numpy>=1.22
```

flox/core.py

Lines changed: 133 additions & 3 deletions

```diff
@@ -38,7 +38,9 @@
 )
 from .cache import memoize
 from .xrutils import (
+    is_chunked_array,
     is_duck_array,
+    is_duck_cubed_array,
     is_duck_dask_array,
     isnull,
     module_available,
@@ -63,10 +65,11 @@
     except (ModuleNotFoundError, ImportError):
         Unpack: Any  # type: ignore[no-redef]

+    import cubed.Array as CubedArray
     import dask.array.Array as DaskArray
     from dask.typing import Graph

-    T_DuckArray = Union[np.ndarray, DaskArray]  # Any ?
+    T_DuckArray = Union[np.ndarray, DaskArray, CubedArray]  # Any ?
     T_By = T_DuckArray
     T_Bys = tuple[T_By, ...]
     T_ExpectIndex = pd.Index
@@ -95,7 +98,7 @@


 IntermediateDict = dict[Union[str, Callable], Any]
-FinalResultsDict = dict[str, Union["DaskArray", np.ndarray]]
+FinalResultsDict = dict[str, Union["DaskArray", "CubedArray", np.ndarray]]
 FactorProps = namedtuple("FactorProps", "offset_group nan_sentinel nanmask")

 # This dummy axis is inserted using np.expand_dims
@@ -1718,6 +1721,109 @@ def dask_groupby_agg(
     return (result, groups)


+def cubed_groupby_agg(
+    array: CubedArray,
+    by: T_By,
+    agg: Aggregation,
+    expected_groups: pd.Index | None,
+    axis: T_Axes = (),
+    fill_value: Any = None,
+    method: T_Method = "map-reduce",
+    reindex: bool = False,
+    engine: T_Engine = "numpy",
+    sort: bool = True,
+    chunks_cohorts=None,
+) -> tuple[CubedArray, tuple[np.ndarray | CubedArray]]:
+    import cubed
+    import cubed.core.groupby
+
+    # I think _tree_reduce expects this
+    assert isinstance(axis, Sequence)
+    assert all(ax >= 0 for ax in axis)
+
+    inds = tuple(range(array.ndim))
+
+    by_input = by
+
+    # Unifying chunks is necessary for argreductions.
+    # We need to rechunk before zipping up with the index
+    # let's always do it anyway
+    if not is_chunked_array(by):
+        # chunk numpy arrays like the input array
+        chunks = tuple(array.chunks[ax] if by.shape[ax] != 1 else (1,) for ax in range(-by.ndim, 0))
+
+        by = cubed.from_array(by, chunks=chunks, spec=array.spec)
+    _, (array, by) = cubed.core.unify_chunks(array, inds, by, inds[-by.ndim :])
+
+    # Cubed's groupby_reduction handles the generation of "intermediates", and the
+    # "map-reduce" combination step, so we don't have to do that here.
+    # Only the equivalent of "_simple_combine" is supported, there is no
+    # support for "_grouped_combine".
+    labels_are_unknown = is_chunked_array(by_input) and expected_groups is None
+    do_simple_combine = not _is_arg_reduction(agg) and not labels_are_unknown
+
+    assert do_simple_combine
+    assert method == "map-reduce"
+    assert expected_groups is not None
+    assert reindex is True
+    assert len(axis) == 1  # one axis/grouping
+
+    def _groupby_func(a, by, axis, intermediate_dtype, num_groups):
+        blockwise_method = partial(
+            _get_chunk_reduction(agg.reduction_type),
+            func=agg.chunk,
+            fill_value=agg.fill_value["intermediate"],
+            dtype=agg.dtype["intermediate"],
+            reindex=reindex,
+            user_dtype=agg.dtype["user"],
+            axis=axis,
+            expected_groups=expected_groups,
+            engine=engine,
+            sort=sort,
+        )
+        out = blockwise_method(a, by)
+        # Convert dict to one that cubed understands, dropping groups since they are
+        # known, and the same for every block.
+        return {f"f{idx}": intermediate for idx, intermediate in enumerate(out["intermediates"])}
+
+    def _groupby_combine(a, axis, dummy_axis, dtype, keepdims):
+        # this is similar to _simple_combine, except the dummy axis and concatenation is handled by cubed
+        # only combine over the dummy axis, to preserve grouping along 'axis'
+        dtype = dict(dtype)
+        out = {}
+        for idx, combine in enumerate(agg.simple_combine):
+            field = f"f{idx}"
+            out[field] = combine(a[field], axis=dummy_axis, keepdims=keepdims)
+        return out
+
+    def _groupby_aggregate(a):
+        # Convert cubed dict to one that _finalize_results works with
+        results = {"groups": expected_groups, "intermediates": a.values()}
+        out = _finalize_results(results, agg, axis, expected_groups, fill_value, reindex)
+        return out[agg.name]
+
+    # convert list of dtypes to a structured dtype for cubed
+    intermediate_dtype = [(f"f{i}", dtype) for i, dtype in enumerate(agg.dtype["intermediate"])]
+    dtype = agg.dtype["final"]
+    num_groups = len(expected_groups)
+
+    result = cubed.core.groupby.groupby_reduction(
+        array,
+        by,
+        func=_groupby_func,
+        combine_func=_groupby_combine,
+        aggregate_func=_groupby_aggregate,
+        axis=axis,
+        intermediate_dtype=intermediate_dtype,
+        dtype=dtype,
+        num_groups=num_groups,
+    )
+
+    groups = (expected_groups.to_numpy(),)
+
+    return (result, groups)
+
+
 def _collapse_blocks_along_axes(reduced: DaskArray, axis: T_Axes, group_chunks) -> DaskArray:
     import dask.array
     from dask.highlevelgraph import HighLevelGraph
@@ -2240,6 +2346,7 @@ def groupby_reduce(
     nax = len(axis_)

     has_dask = is_duck_dask_array(array) or is_duck_dask_array(by_)
+    has_cubed = is_duck_cubed_array(array) or is_duck_cubed_array(by_)

     if _is_first_last_reduction(func):
         if has_dask and nax != 1:
@@ -2302,7 +2409,30 @@
     kwargs["engine"] = _choose_engine(by_, agg) if engine is None else engine

     groups: tuple[np.ndarray | DaskArray, ...]
-    if not has_dask:
+    if has_cubed:
+        if method is None:
+            method = "map-reduce"
+
+        if method != "map-reduce":
+            raise NotImplementedError(
+                "Reduction for Cubed arrays is only implemented for method 'map-reduce'."
+            )
+
+        partial_agg = partial(cubed_groupby_agg, **kwargs)
+
+        result, groups = partial_agg(
+            array,
+            by_,
+            expected_groups=expected_,
+            agg=agg,
+            reindex=reindex,
+            method=method,
+            sort=sort,
+        )
+
+        return (result, groups)
+
+    elif not has_dask:
         results = _reduce_blockwise(
             array, by_, agg, expected_groups=expected_, reindex=reindex, sort=sort, **kwargs
         )
```
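
For orientation, here is a toy NumPy-only sketch of the three-phase map/combine/aggregate contract that `cubed.core.groupby.groupby_reduction` orchestrates and that `cubed_groupby_agg` plugs into via `func`, `combine_func`, and `aggregate_func`. The helper names below are illustrative, not flox or cubed APIs:

```python
# Toy NumPy sketch of the map/combine/aggregate split used above, for a
# grouped sum. Helper names are illustrative, not part of flox or cubed.
import numpy as np

num_groups = 3

def map_phase(block, labels):
    # per-block partial result, reindexed to all known groups ("intermediates")
    return np.bincount(labels, weights=block, minlength=num_groups)

def combine_phase(partials):
    # like flox's _simple_combine: reduce over the stacked block axis only,
    # preserving the group axis
    return np.stack(partials).sum(axis=0)

def aggregate_phase(combined):
    # like flox's _finalize_results: turn intermediates into the final answer
    return combined

array = np.ones(12)
labels = np.array([0, 0, 2, 2, 2, 1, 1, 2, 2, 1, 1, 0])

# split into blocks of 3, as if they were cubed chunks
partials = [map_phase(array[i : i + 3], labels[i : i + 3]) for i in range(0, 12, 3)]
print(aggregate_phase(combine_phase(partials)))  # [3. 4. 5.]
```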

flox/xrutils.py

Lines changed: 18 additions & 2 deletions

```diff
@@ -37,11 +37,18 @@ def is_duck_array(value: Any) -> bool:
         hasattr(value, "ndim")
         and hasattr(value, "shape")
         and hasattr(value, "dtype")
-        and hasattr(value, "__array_function__")
-        and hasattr(value, "__array_ufunc__")
+        and (
+            (hasattr(value, "__array_function__") and hasattr(value, "__array_ufunc__"))
+            or hasattr(value, "__array_namespace__")
+        )
     )


+def is_chunked_array(x) -> bool:
+    """True if dask or cubed"""
+    return is_duck_dask_array(x) or (is_duck_array(x) and hasattr(x, "chunks"))
+
+
 def is_dask_collection(x):
     try:
         import dask
@@ -56,6 +63,15 @@ def is_duck_dask_array(x):
     return is_duck_array(x) and is_dask_collection(x)


+def is_duck_cubed_array(x):
+    try:
+        import cubed
+
+        return is_duck_array(x) and isinstance(x, cubed.Array)
+    except ImportError:
+        return False
+
+
 class ReprObject:
     """Object that prints as the given value, for use with sentinel values."""
```
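
A quick sketch of how the new and updated predicates classify arrays, assuming numpy, dask, and cubed are all installed. The `__array_namespace__` branch added above is what lets cubed's array-API arrays pass `is_duck_array`:

```python
# Sketch: how the new predicates classify numpy, dask, and cubed arrays.
# Assumes numpy, dask, and cubed are all installed.
import numpy as np
import cubed
import dask.array as da

from flox.xrutils import is_chunked_array, is_duck_cubed_array

np_arr = np.ones(4)
dask_arr = da.ones((4,), chunks=(2,))
cubed_arr = cubed.array_api.ones((4,), chunks=(2,))

print(is_chunked_array(np_arr))        # False: no .chunks attribute
print(is_chunked_array(dask_arr))      # True, via is_duck_dask_array
print(is_chunked_array(cubed_arr))     # True: duck array with .chunks
print(is_duck_cubed_array(dask_arr))   # False: not a cubed.Array
print(is_duck_cubed_array(cubed_arr))  # True
```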

pyproject.toml

Lines changed: 1 addition & 0 deletions

```diff
@@ -121,6 +121,7 @@ module=[
     "asv_runner.*",
     "cachey",
     "cftime",
+    "cubed.*",
     "dask.*",
     "importlib_metadata",
     "numba",
```

tests/__init__.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -46,6 +46,7 @@ def LooseVersion(vstring):


 has_cftime, requires_cftime = _importorskip("cftime")
+has_cubed, requires_cubed = _importorskip("cubed")
 has_dask, requires_dask = _importorskip("dask")
 has_numba, requires_numba = _importorskip("numba")
 has_numbagg, requires_numbagg = _importorskip("numbagg")
```

tests/test_core.py

Lines changed: 49 additions & 0 deletions

```diff
@@ -36,8 +36,10 @@
     SCIPY_STATS_FUNCS,
     assert_equal,
     assert_equal_tuple,
+    has_cubed,
     has_dask,
     raise_if_dask_computes,
+    requires_cubed,
     requires_dask,
 )

@@ -61,6 +63,10 @@ def dask_array_ones(*args):
     return None


+if has_cubed:
+    import cubed
+
+
 DEFAULT_QUANTILE = 0.9

 if TYPE_CHECKING:
@@ -477,6 +483,49 @@ def test_groupby_agg_dask(func, shape, array_chunks, group_chunks, add_nan, dtyp
     assert_equal(expected, actual)


+@requires_cubed
+@pytest.mark.parametrize("reindex", [True])
+@pytest.mark.parametrize("func", ALL_FUNCS)
+@pytest.mark.parametrize("add_nan", [False, True])
+@pytest.mark.parametrize(
+    "shape, array_chunks, group_chunks",
+    [
+        ((12,), (3,), 3),  # form 1
+    ],
+)
+def test_groupby_agg_cubed(func, shape, array_chunks, group_chunks, add_nan, engine, reindex):
+    """Tests groupby_reduce with cubed arrays against groupby_reduce with numpy arrays"""
+
+    if func in ["first", "last"] or func in BLOCKWISE_FUNCS:
+        pytest.skip()
+
+    if "arg" in func and (engine in ["flox", "numbagg"] or reindex):
+        pytest.skip()
+
+    array = cubed.array_api.ones(shape, chunks=array_chunks)
+
+    labels = np.array([0, 0, 2, 2, 2, 1, 1, 2, 2, 1, 1, 0])
+    if add_nan:
+        labels = labels.astype(float)
+        labels[:3] = np.nan  # entire block is NaN when group_chunks=3
+        labels[-2:] = np.nan
+
+    kwargs = dict(
+        func=func,
+        expected_groups=[0, 1, 2],
+        fill_value=False if func in ["all", "any"] else 123,
+        reindex=reindex,
+    )
+
+    expected, _ = groupby_reduce(array.compute(), labels, engine="numpy", **kwargs)
+    actual, _ = groupby_reduce(array.compute(), labels, engine=engine, **kwargs)
+    assert_equal(actual, expected)
+
+    # TODO: raise_if_cubed_computes
+    actual, _ = groupby_reduce(array, labels, engine=engine, **kwargs)
+    assert_equal(expected, actual)
+
+
 def test_numpy_reduce_axis_subset(engine):
     # TODO: add NaNs
     by = labels2d
```
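
The `add_nan` cases rely on flox treating NaN labels as belonging to no group, with expected groups that end up empty taking `fill_value`. A NumPy-only sketch of that behavior, under those assumptions:

```python
# NumPy-only sketch of the add_nan case the new test exercises.
# Assumes flox's behavior that NaN labels are dropped and expected
# groups with no members take fill_value.
import numpy as np
from flox.core import groupby_reduce

array = np.ones(12)
labels = np.array([0, 0, 2, 2, 2, 1, 1, 2, 2, 1, 1, 0], dtype=float)
labels[:3] = np.nan   # first block contributes nothing
labels[-2:] = np.nan  # last two elements contribute nothing

result, groups = groupby_reduce(
    array, labels, func="sum", expected_groups=[0, 1, 2], fill_value=123
)
# group 0 has no remaining members, so it takes fill_value:
print(result)  # expected roughly: [123., 3., 4.]
```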
