@@ -622,9 +622,18 @@ def corr(self, **kwargs):
622
622
return self ._default_to_pandas (lambda df : df .corr (** kwargs ))
623
623
624
624
def fillna (self , ** kwargs ):
625
- result = self ._apply_agg_function (
626
- lambda df : df .fillna (** kwargs ), overwrite_groupby_kwargs = {"as_index" : True }
625
+ new_groupby_kwargs = self ._kwargs .copy ()
626
+ new_groupby_kwargs ["as_index" ] = True
627
+ work_object = type (self )(
628
+ df = self ._df ,
629
+ by = self ._by ,
630
+ axis = self ._axis ,
631
+ idx_name = self ._idx_name ,
632
+ drop = self ._drop ,
633
+ squeeze = self ._squeeze ,
634
+ ** new_groupby_kwargs ,
627
635
)
636
+ result = work_object ._apply_agg_function (lambda df : df .fillna (** kwargs ))
628
637
# pandas does not name the index on fillna
629
638
result ._query_compiler .set_index_name (None )
630
639
return result
@@ -894,7 +903,7 @@ def _wrap_aggregation(
894
903
return result .squeeze ()
895
904
return result
896
905
897
- def _apply_agg_function (self , f , overwrite_groupby_kwargs = None , * args , ** kwargs ):
906
+ def _apply_agg_function (self , f , * args , ** kwargs ):
898
907
"""
899
908
Perform aggregation and combine stages based on a given function.
900
909
@@ -904,8 +913,6 @@ def _apply_agg_function(self, f, overwrite_groupby_kwargs=None, *args, **kwargs)
904
913
----------
905
914
f: callable
906
915
The function to apply to each group.
907
- overwrite_groupby_kwargs: dict (optional),
908
- GroupBy kwargs to overwrite.
909
916
910
917
Returns
911
918
-------
@@ -914,17 +921,15 @@ def _apply_agg_function(self, f, overwrite_groupby_kwargs=None, *args, **kwargs)
914
921
assert callable (f ) or isinstance (
915
922
f , dict
916
923
), "'{0}' object is not callable and not a dict" .format (type (f ))
917
- groupby_kwargs = self ._kwargs .copy ()
918
- if overwrite_groupby_kwargs is not None :
919
- groupby_kwargs .update (overwrite_groupby_kwargs )
924
+
920
925
new_manager = self ._query_compiler .groupby_agg (
921
926
by = self ._by ,
922
927
is_multi_by = self ._is_multi_by ,
923
928
axis = self ._axis ,
924
929
agg_func = f ,
925
930
agg_args = args ,
926
931
agg_kwargs = kwargs ,
927
- groupby_kwargs = groupby_kwargs ,
932
+ groupby_kwargs = self . _kwargs ,
928
933
drop = self ._drop ,
929
934
)
930
935
if self ._idx_name is not None and self ._as_index :
0 commit comments