@@ -2525,34 +2525,74 @@ def groupby_agg(
2525
2525
if callable (agg_func ):
2526
2526
agg_func = wrap_udf_function (agg_func )
2527
2527
2528
- if is_multi_by :
2529
- return super ().groupby_agg (
2530
- by = by ,
2531
- is_multi_by = is_multi_by ,
2532
- axis = axis ,
2533
- agg_func = agg_func ,
2534
- agg_args = agg_args ,
2535
- agg_kwargs = agg_kwargs ,
2536
- groupby_kwargs = groupby_kwargs ,
2537
- drop = drop ,
2538
- )
2539
-
2540
- by = by .to_pandas ().squeeze () if isinstance (by , type (self )) else by
2541
-
2542
2528
# since we're going to modify `groupby_kwargs` dict in a `groupby_agg_builder`,
2543
2529
# we want to copy it to not propagate these changes into source dict, in case
2544
2530
# of unsuccessful end of function
2545
2531
groupby_kwargs = groupby_kwargs .copy ()
2546
2532
2547
2533
as_index = groupby_kwargs .get ("as_index" , True )
2534
+ if isinstance (by , type (self )):
2535
+ # `drop` parameter indicates whether or not 'by' data came
2536
+ # from the `self` frame:
2537
+ # True: 'by' data came from the `self`
2538
+ # False: external 'by' data
2539
+ if drop :
2540
+ internal_by = by .columns
2541
+ by = [by ]
2542
+ else :
2543
+ internal_by = []
2544
+ by = [by ]
2545
+ else :
2546
+ if not isinstance (by , list ):
2547
+ by = [by ]
2548
+ internal_by = [o for o in by if isinstance (o , str ) and o in self .columns ]
2549
+ internal_qc = (
2550
+ [self .getitem_column_array (internal_by )] if len (internal_by ) else []
2551
+ )
2552
+
2553
+ by = internal_qc + by [len (internal_by ) :]
2554
+
2555
+ broadcastable_by = [o ._modin_frame for o in by if isinstance (o , type (self ))]
2556
+ not_broadcastable_by = [o for o in by if not isinstance (o , type (self ))]
2548
2557
2549
- def groupby_agg_builder (df ):
2558
+ def groupby_agg_builder (df , by = None , drop = False , partition_idx = None ):
2550
2559
# Set `as_index` to True to track the metadata of the grouping object
2551
2560
# It is used to make sure that between phases we are constructing the
2552
2561
# right index and placing columns in the correct order.
2553
2562
groupby_kwargs ["as_index" ] = True
2554
2563
2555
- def compute_groupby (df ):
2564
+ internal_by_cols = pandas .Index ([])
2565
+ missmatched_cols = pandas .Index ([])
2566
+ if by is not None :
2567
+ internal_by_df = by [internal_by ]
2568
+
2569
+ if isinstance (internal_by_df , pandas .Series ):
2570
+ internal_by_df = internal_by_df .to_frame ()
2571
+
2572
+ missmatched_cols = internal_by_df .columns .difference (df .columns )
2573
+ df = pandas .concat (
2574
+ [df , internal_by_df [missmatched_cols ]],
2575
+ axis = 1 ,
2576
+ copy = False ,
2577
+ )
2578
+ internal_by_cols = internal_by_df .columns
2579
+
2580
+ external_by = by .columns .difference (internal_by )
2581
+ external_by_df = by [external_by ].squeeze (axis = 1 )
2582
+
2583
+ if isinstance (external_by_df , pandas .DataFrame ):
2584
+ external_by_cols = [o for _ , o in external_by_df .iteritems ()]
2585
+ else :
2586
+ external_by_cols = [external_by_df ]
2587
+
2588
+ by = internal_by_cols .tolist () + external_by_cols
2589
+
2590
+ else :
2591
+ by = []
2592
+
2593
+ by += not_broadcastable_by
2594
+
2595
+ def compute_groupby (df , drop = False , partition_idx = 0 ):
2556
2596
grouped_df = df .groupby (by = by , axis = axis , ** groupby_kwargs )
2557
2597
try :
2558
2598
if isinstance (agg_func , dict ):
@@ -2569,17 +2609,91 @@ def compute_groupby(df):
2569
2609
# issues with extracting the index.
2570
2610
except (DataError , TypeError ):
2571
2611
result = pandas .DataFrame (index = grouped_df .size ().index )
2612
+ if isinstance (result , pandas .Series ):
2613
+ result = result .to_frame (
2614
+ result .name if result .name is not None else "__reduced__"
2615
+ )
2616
+
2617
+ result_cols = result .columns
2618
+ result .drop (columns = missmatched_cols , inplace = True , errors = "ignore" )
2619
+
2620
+ if not as_index :
2621
+ keep_index_levels = len (by ) > 1 and any (
2622
+ isinstance (x , pandas .CategoricalDtype )
2623
+ for x in df [internal_by_cols ].dtypes
2624
+ )
2625
+
2626
+ cols_to_insert = (
2627
+ internal_by_cols .intersection (result_cols )
2628
+ if keep_index_levels
2629
+ else internal_by_cols .difference (result_cols )
2630
+ )
2631
+
2632
+ if keep_index_levels :
2633
+ result .drop (
2634
+ columns = cols_to_insert , inplace = True , errors = "ignore"
2635
+ )
2636
+
2637
+ drop = True
2638
+ if partition_idx == 0 :
2639
+ drop = False
2640
+ if not keep_index_levels :
2641
+ lvls_to_drop = [
2642
+ i
2643
+ for i , name in enumerate (result .index .names )
2644
+ if name not in cols_to_insert
2645
+ ]
2646
+ if len (lvls_to_drop ) == result .index .nlevels :
2647
+ drop = True
2648
+ else :
2649
+ result .index = result .index .droplevel (lvls_to_drop )
2650
+
2651
+ if (
2652
+ not isinstance (result .index , pandas .MultiIndex )
2653
+ and result .index .name is None
2654
+ ):
2655
+ drop = True
2656
+
2657
+ result .reset_index (drop = drop , inplace = True )
2658
+
2659
+ new_index_names = [
2660
+ None
2661
+ if isinstance (name , str ) and name .startswith ("__reduced__" )
2662
+ else name
2663
+ for name in result .index .names
2664
+ ]
2665
+
2666
+ cols_to_drop = (
2667
+ result .columns [result .columns .str .match (r"__reduced__.*" , na = False )]
2668
+ if hasattr (result .columns , "str" )
2669
+ else []
2670
+ )
2671
+
2672
+ result .index .names = new_index_names
2673
+
2674
+ # Not dropping columns if result is Series
2675
+ if len (result .columns ) > 1 :
2676
+ result .drop (columns = cols_to_drop , inplace = True )
2677
+
2572
2678
return result
2573
2679
2574
2680
try :
2575
- return compute_groupby (df )
2681
+ return compute_groupby (df , drop , partition_idx )
2576
2682
# This will happen with Arrow buffer read-only errors. We don't want to copy
2577
2683
# all the time, so this will try to fast-path the code first.
2578
2684
except (ValueError , KeyError ):
2579
- return compute_groupby (df .copy ())
2685
+ return compute_groupby (df .copy (), drop , partition_idx )
2580
2686
2581
- new_modin_frame = self ._modin_frame ._apply_full_axis (
2582
- axis , lambda df : groupby_agg_builder (df )
2687
+ apply_indices = list (agg_func .keys ()) if isinstance (agg_func , dict ) else None
2688
+
2689
+ new_modin_frame = self ._modin_frame .broadcast_apply_full_axis (
2690
+ axis = axis ,
2691
+ func = lambda df , by = None , partition_idx = None : groupby_agg_builder (
2692
+ df , by , drop , partition_idx
2693
+ ),
2694
+ other = broadcastable_by ,
2695
+ apply_indices = apply_indices ,
2696
+ enumerate_partitions = True ,
2583
2697
)
2584
2698
result = self .__constructor__ (new_modin_frame )
2585
2699
@@ -2598,14 +2712,7 @@ def compute_groupby(df):
2598
2712
except Exception as e :
2599
2713
raise type (e )("No numeric types to aggregate." )
2600
2714
2601
- # Reset `as_index` because it was edited inplace.
2602
- groupby_kwargs ["as_index" ] = as_index
2603
- if as_index :
2604
- return result
2605
- else :
2606
- if result .index .name is None or result .index .name in result .columns :
2607
- drop = False
2608
- return result .reset_index (drop = not drop )
2715
+ return result
2609
2716
2610
2717
# END Manual Partitioning methods
2611
2718
0 commit comments