@@ -39,11 +39,10 @@ class providing the base-class of operations.
39
39
from pandas .errors import AbstractMethodError
40
40
from pandas .util ._decorators import Appender , Substitution , cache_readonly
41
41
42
- from pandas .core .dtypes .cast import maybe_downcast_to_dtype
42
+ from pandas .core .dtypes .cast import maybe_cast_result
43
43
from pandas .core .dtypes .common import (
44
44
ensure_float ,
45
45
is_datetime64_dtype ,
46
- is_extension_array_dtype ,
47
46
is_integer_dtype ,
48
47
is_numeric_dtype ,
49
48
is_object_dtype ,
@@ -53,7 +52,7 @@ class providing the base-class of operations.
53
52
54
53
from pandas .core import nanops
55
54
import pandas .core .algorithms as algorithms
56
- from pandas .core .arrays import Categorical , DatetimeArray , try_cast_to_ea
55
+ from pandas .core .arrays import Categorical , DatetimeArray
57
56
from pandas .core .base import DataError , PandasObject , SelectionMixin
58
57
import pandas .core .common as com
59
58
from pandas .core .frame import DataFrame
@@ -792,36 +791,6 @@ def _cumcount_array(self, ascending: bool = True):
792
791
rev [sorter ] = np .arange (count , dtype = np .intp )
793
792
return out [rev ].astype (np .int64 , copy = False )
794
793
795
- def _try_cast (self , result , obj , numeric_only : bool = False ):
796
- """
797
- Try to cast the result to our obj original type,
798
- we may have roundtripped through object in the mean-time.
799
-
800
- If numeric_only is True, then only try to cast numerics
801
- and not datetimelikes.
802
-
803
- """
804
- if obj .ndim > 1 :
805
- dtype = obj ._values .dtype
806
- else :
807
- dtype = obj .dtype
808
-
809
- if not is_scalar (result ):
810
- if is_extension_array_dtype (dtype ) and dtype .kind != "M" :
811
- # The function can return something of any type, so check
812
- # if the type is compatible with the calling EA.
813
- # datetime64tz is handled correctly in agg_series,
814
- # so is excluded here.
815
-
816
- if len (result ) and isinstance (result [0 ], dtype .type ):
817
- cls = dtype .construct_array_type ()
818
- result = try_cast_to_ea (cls , result , dtype = dtype )
819
-
820
- elif numeric_only and is_numeric_dtype (dtype ) or not numeric_only :
821
- result = maybe_downcast_to_dtype (result , dtype )
822
-
823
- return result
824
-
825
794
def _transform_should_cast (self , func_nm : str ) -> bool :
826
795
"""
827
796
Parameters
@@ -852,7 +821,7 @@ def _cython_transform(self, how: str, numeric_only: bool = True, **kwargs):
852
821
continue
853
822
854
823
if self ._transform_should_cast (how ):
855
- result = self . _try_cast (result , obj )
824
+ result = maybe_cast_result (result , obj , how = how )
856
825
857
826
key = base .OutputKey (label = name , position = idx )
858
827
output [key ] = result
@@ -895,12 +864,12 @@ def _cython_agg_general(
895
864
assert len (agg_names ) == result .shape [1 ]
896
865
for result_column , result_name in zip (result .T , agg_names ):
897
866
key = base .OutputKey (label = result_name , position = idx )
898
- output [key ] = self . _try_cast (result_column , obj )
867
+ output [key ] = maybe_cast_result (result_column , obj , how = how )
899
868
idx += 1
900
869
else :
901
870
assert result .ndim == 1
902
871
key = base .OutputKey (label = name , position = idx )
903
- output [key ] = self . _try_cast (result , obj )
872
+ output [key ] = maybe_cast_result (result , obj , how = how )
904
873
idx += 1
905
874
906
875
if len (output ) == 0 :
@@ -929,7 +898,7 @@ def _python_agg_general(self, func, *args, **kwargs):
929
898
930
899
assert result is not None
931
900
key = base .OutputKey (label = name , position = idx )
932
- output [key ] = self . _try_cast (result , obj , numeric_only = True )
901
+ output [key ] = maybe_cast_result (result , obj , numeric_only = True )
933
902
934
903
if len (output ) == 0 :
935
904
return self ._python_apply_general (f )
@@ -944,7 +913,7 @@ def _python_agg_general(self, func, *args, **kwargs):
944
913
if is_numeric_dtype (values .dtype ):
945
914
values = ensure_float (values )
946
915
947
- output [key ] = self . _try_cast (values [mask ], result )
916
+ output [key ] = maybe_cast_result (values [mask ], result )
948
917
949
918
return self ._wrap_aggregated_output (output )
950
919
0 commit comments