@@ -1494,8 +1494,9 @@ def _normalize_indexes(array: DaskArray, flatblocks, blkshape) -> tuple:
14941494def subset_to_blocks (
14951495 array : DaskArray ,
14961496 flatblocks : Sequence [int ],
1497- blkshape : tuple [int ] | None = None ,
1497+ blkshape : tuple [int , ... ] | None = None ,
14981498 reindexer = identity ,
1499+ chunks_as_array : tuple [np .ndarray , ...] | None = None ,
14991500) -> DaskArray :
15001501 """
15011502 Advanced indexing of .blocks such that we always get a regular array back.
@@ -1518,6 +1519,9 @@ def subset_to_blocks(
15181519 if blkshape is None :
15191520 blkshape = array .blocks .shape
15201521
1522+ if chunks_as_array is None :
1523+ chunks_as_array = tuple (np .array (c ) for c in array .chunks )
1524+
15211525 index = _normalize_indexes (array , flatblocks , blkshape )
15221526
15231527 if all (not isinstance (i , np .ndarray ) and i == slice (None ) for i in index ):
@@ -1531,7 +1535,7 @@ def subset_to_blocks(
15311535 new_keys = array ._key_array [index ]
15321536
15331537 squeezed = tuple (np .squeeze (i ) if isinstance (i , np .ndarray ) else i for i in index )
1534- chunks = tuple (tuple (np . array ( c ) [i ].tolist ()) for c , i in zip (array . chunks , squeezed ))
1538+ chunks = tuple (tuple (c [i ].tolist ()) for c , i in zip (chunks_as_array , squeezed ))
15351539
15361540 keys = itertools .product (* (range (len (c )) for c in chunks ))
15371541 layer : Graph = {(name ,) + key : (reindexer , tuple (new_keys [key ].tolist ())) for key in keys }
@@ -1726,14 +1730,15 @@ def dask_groupby_agg(
17261730
17271731 reduced_ = []
17281732 groups_ = []
1733+ chunks_as_array = tuple (np .array (c ) for c in array .chunks )
17291734 for blks , cohort in chunks_cohorts .items ():
17301735 cohort_index = pd .Index (cohort )
17311736 reindexer = (
17321737 partial (reindex_intermediates , agg = agg , unique_groups = cohort_index )
17331738 if do_simple_combine
17341739 else identity
17351740 )
1736- reindexed = subset_to_blocks (intermediate , blks , block_shape , reindexer )
1741+ reindexed = subset_to_blocks (intermediate , blks , block_shape , reindexer , chunks_as_array )
17371742 # now that we have reindexed, we can set reindex=True explicitlly
17381743 reduced_ .append (
17391744 tree_reduce (
0 commit comments