@@ -64,7 +64,7 @@ def _get_expected_groups(by, sort, *, raise_if_dask=True) -> pd.Index | None:
64
64
if raise_if_dask :
65
65
raise ValueError ("Please provide expected_groups if not grouping by a numpy array." )
66
66
return None
67
- flatby = by .ravel ( )
67
+ flatby = by .reshape ( - 1 )
68
68
expected = pd .unique (flatby [~ isnull (flatby )])
69
69
return _convert_expected_groups_to_index ((expected ,), isbin = (False ,), sort = sort )[0 ]
70
70
@@ -175,11 +175,11 @@ def find_group_cohorts(labels, chunks, merge=True, method="cohorts"):
175
175
blocks = np .empty (np .prod (shape ), dtype = object )
176
176
for idx , block in enumerate (array .blocks .ravel ()):
177
177
blocks [idx ] = np .full (tuple (block .shape [ax ] for ax in axis ), idx )
178
- which_chunk = np .block (blocks .reshape (shape ).tolist ()).ravel ( )
178
+ which_chunk = np .block (blocks .reshape (shape ).tolist ()).reshape ( - 1 )
179
179
180
180
# We always drop NaN; np.unique also considers every NaN to be different so
181
181
# it's really important we get rid of them.
182
- raveled = labels .ravel ( )
182
+ raveled = labels .reshape ( - 1 )
183
183
unique_labels = np .unique (raveled [~ isnull (raveled )])
184
184
# these are chunks where a label is present
185
185
label_chunks = {lab : tuple (np .unique (which_chunk [raveled == lab ])) for lab in unique_labels }
@@ -421,19 +421,28 @@ def factorize_(
421
421
factorized = []
422
422
found_groups = []
423
423
for groupvar , expect in zip (by , expected_groups ):
424
- flat = groupvar .ravel ()
425
- if isinstance (expect , pd .IntervalIndex ):
424
+ flat = groupvar .reshape (- 1 )
425
+ if isinstance (expect , pd .RangeIndex ):
426
+ idx = flat
427
+ found_groups .append (np .array (expect ))
428
+ # TODO: fix by using masked integers
429
+ idx [idx > expect [- 1 ]] = - 1
430
+
431
+ elif isinstance (expect , pd .IntervalIndex ):
426
432
# when binning we change expected groups to integers marking the interval
427
433
# this makes the reindexing logic simpler.
428
- if expect is None :
429
- raise ValueError ("Please pass bin edges in expected_groups." )
430
- # TODO: fix for binning
431
- found_groups .append (expect )
432
- # pd.cut with bins = IntervalIndex[datetime64] doesn't work...
434
+ # workaround for https://github.com/pandas-dev/pandas/issues/47614
435
+ # we create breaks and pass that to pd.cut, disallow closed="both" for now.
436
+ if expect .closed == "both" :
437
+ raise NotImplementedError
433
438
if groupvar .dtype .kind == "M" :
434
- expect = np .concatenate ([expect .left .to_numpy (), [expect .right [- 1 ].to_numpy ()]])
439
+ # pd.cut with bins = IntervalIndex[datetime64] doesn't work...
440
+ bins = np .concatenate ([expect .left .to_numpy (), [expect .right [- 1 ].to_numpy ()]])
441
+ else :
442
+ bins = np .concatenate ([expect .left .to_numpy (), [expect .right [- 1 ]]])
435
443
# code is -1 for values outside the bounds of all intervals
436
- idx = pd .cut (flat , bins = expect ).codes .copy ()
444
+ idx = pd .cut (flat , bins = bins , right = expect .closed_right ).codes .copy ()
445
+ found_groups .append (expect )
437
446
else :
438
447
if expect is not None and reindex :
439
448
sorter = np .argsort (expect )
@@ -447,7 +456,7 @@ def factorize_(
447
456
idx = sorter [(idx ,)]
448
457
idx [mask ] = - 1
449
458
else :
450
- idx , groups = pd .factorize (groupvar . ravel () , sort = sort )
459
+ idx , groups = pd .factorize (flat , sort = sort )
451
460
452
461
found_groups .append (np .array (groups ))
453
462
factorized .append (idx )
@@ -473,7 +482,7 @@ def factorize_(
473
482
# we collapse to a 2D by and axis=-1
474
483
offset_group = True
475
484
group_idx , size = offset_labels (group_idx .reshape (by [0 ].shape ), ngroups )
476
- group_idx = group_idx .ravel ( )
485
+ group_idx = group_idx .reshape ( - 1 )
477
486
else :
478
487
size = ngroups
479
488
offset_group = False
@@ -622,7 +631,7 @@ def chunk_reduce(
622
631
# avoid by factorizing again so indices=[2,2,2] is changed to
623
632
# indices=[0,0,0]. This is necessary when combining block results
624
633
# factorize can handle strings etc unlike digitize
625
- group_idx , groups , found_groups_shape , ngroups , size , props = factorize_ (
634
+ group_idx , groups , found_groups_shape , _ , size , props = factorize_ (
626
635
(by ,), axis , expected_groups = (expected_groups ,), reindex = reindex , sort = sort
627
636
)
628
637
groups = groups [0 ]
0 commit comments