From 4a221de460380b3bf9f8c3675345887a6e9fe50c Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 17 Sep 2024 19:41:43 -0600 Subject: [PATCH 1/2] Faster subsetting for cohorts Closes #396 --- flox/core.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/flox/core.py b/flox/core.py index 7e5362e18..1fb91dc25 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1494,8 +1494,9 @@ def _normalize_indexes(array: DaskArray, flatblocks, blkshape) -> tuple: def subset_to_blocks( array: DaskArray, flatblocks: Sequence[int], - blkshape: tuple[int] | None = None, + blkshape: tuple[int, ...] | None = None, reindexer=identity, + chunks_as_array: tuple[int, ...] | None = None, ) -> DaskArray: """ Advanced indexing of .blocks such that we always get a regular array back. @@ -1518,6 +1519,9 @@ def subset_to_blocks( if blkshape is None: blkshape = array.blocks.shape + if chunks_as_array is None: + chunks_as_array = tuple(np.array(c) for c in array.chunks) + index = _normalize_indexes(array, flatblocks, blkshape) if all(not isinstance(i, np.ndarray) and i == slice(None) for i in index): @@ -1531,7 +1535,7 @@ def subset_to_blocks( new_keys = array._key_array[index] squeezed = tuple(np.squeeze(i) if isinstance(i, np.ndarray) else i for i in index) - chunks = tuple(tuple(np.array(c)[i].tolist()) for c, i in zip(array.chunks, squeezed)) + chunks = tuple(tuple(c[i].tolist()) for c, i in zip(chunks_as_array, squeezed)) keys = itertools.product(*(range(len(c)) for c in chunks)) layer: Graph = {(name,) + key: (reindexer, tuple(new_keys[key].tolist())) for key in keys} @@ -1726,6 +1730,7 @@ def dask_groupby_agg( reduced_ = [] groups_ = [] + chunks_as_array = tuple(np.array(c) for c in array.chunks) for blks, cohort in chunks_cohorts.items(): cohort_index = pd.Index(cohort) reindexer = ( @@ -1733,7 +1738,7 @@ def dask_groupby_agg( if do_simple_combine else identity ) - reindexed = subset_to_blocks(intermediate, blks, block_shape, reindexer) + reindexed = subset_to_blocks(intermediate, blks, block_shape, reindexer, chunks_as_array) # now that we have reindexed, we can set reindex=True explicitlly reduced_.append( tree_reduce( From 08d308701e7c79d5075e9f0d1c4bd7bc4bf31b5d Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 17 Sep 2024 19:46:38 -0600 Subject: [PATCH 2/2] tpying --- flox/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flox/core.py b/flox/core.py index 1fb91dc25..91903ded7 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1496,7 +1496,7 @@ def subset_to_blocks( flatblocks: Sequence[int], blkshape: tuple[int, ...] | None = None, reindexer=identity, - chunks_as_array: tuple[int, ...] | None = None, + chunks_as_array: tuple[np.ndarray, ...] | None = None, ) -> DaskArray: """ Advanced indexing of .blocks such that we always get a regular array back.