@@ -1164,7 +1164,7 @@ def subset_to_blocks(
11641164 return dask .array .Array (graph , name , chunks , meta = array )
11651165
11661166
1167- def _extract_unknown_groups (reduced , group_chunks , dtype ) -> tuple [DaskArray ]:
1167+ def _extract_unknown_groups (reduced , dtype ) -> tuple [DaskArray ]:
11681168 import dask .array
11691169 from dask .highlevelgraph import HighLevelGraph
11701170
@@ -1180,7 +1180,7 @@ def _extract_unknown_groups(reduced, group_chunks, dtype) -> tuple[DaskArray]:
11801180 dask .array .Array (
11811181 HighLevelGraph .from_collections (groups_token , layer , dependencies = [reduced ]),
11821182 groups_token ,
1183- chunks = group_chunks ,
1183+ chunks = (( np . nan ,),) ,
11841184 meta = np .array ([], dtype = dtype ),
11851185 ),
11861186 )
@@ -1293,14 +1293,7 @@ def dask_groupby_agg(
12931293 name = f"{ name } -chunk-{ token } " ,
12941294 )
12951295
1296- if expected_groups is None :
1297- if is_duck_dask_array (by_input ):
1298- expected_groups = None
1299- else :
1300- expected_groups = _get_expected_groups (by_input , sort = sort )
1301- group_chunks : tuple [tuple [Union [int , float ], ...]] = (
1302- (len (expected_groups ),) if expected_groups is not None else (np .nan ,),
1303- )
1296+ group_chunks : tuple [tuple [Union [int , float ], ...]]
13041297
13051298 if method in ["map-reduce" , "cohorts" ]:
13061299 combine : Callable [..., IntermediateDict ]
@@ -1333,13 +1326,13 @@ def dask_groupby_agg(
13331326 aggregate = partial (aggregate , expected_groups = expected_groups , reindex = reindex ),
13341327 )
13351328 if is_duck_dask_array (by_input ) and expected_groups is None :
1336- groups = _extract_unknown_groups (reduced , group_chunks = group_chunks , dtype = by .dtype )
1329+ groups = _extract_unknown_groups (reduced , dtype = by .dtype )
1330+ group_chunks = ((np .nan ,),)
13371331 else :
13381332 if expected_groups is None :
1339- expected_groups_ = _get_expected_groups (by_input , sort = sort )
1340- else :
1341- expected_groups_ = expected_groups
1342- groups = (expected_groups_ .to_numpy (),)
1333+ expected_groups = _get_expected_groups (by_input , sort = sort )
1334+ groups = (expected_groups .to_numpy (),)
1335+ group_chunks = ((len (expected_groups ),),)
13431336
13441337 elif method == "cohorts" :
13451338 chunks_cohorts = find_group_cohorts (
0 commit comments