2525
2626import numpy as np
2727import pandas as pd # noqa: F401
28- from pandas .api .types import is_list_like
28+ from pandas .api .types import is_list_like , pandas_dtype
2929from pyspark import sql as spark
3030from pyspark .sql import functions as F , Window , Column
3131from pyspark .sql .types import (
5151)
5252from databricks .koalas .spark import functions as SF
5353from databricks .koalas .spark .accessors import SparkIndexOpsMethods
54- from databricks .koalas .typedef import as_spark_type , spark_type_to_pandas_dtype
54+ from databricks .koalas .typedef import (
55+ Dtype ,
56+ as_spark_type ,
57+ extension_dtypes ,
58+ spark_type_to_pandas_dtype ,
59+ )
5560from databricks .koalas .utils import (
5661 combine_frames ,
5762 same_anchor ,
@@ -191,7 +196,7 @@ def align_diff_index_ops(func, this_index_ops: "IndexOpsMixin", *args) -> "Index
191196 ).rename (that_series .name )
192197
193198
194- def booleanize_null (left_scol , scol , f ) -> Column :
199+ def booleanize_null (scol , f ) -> Column :
195200 """
196201 Booleanize Null in Spark Column
197202 """
@@ -205,12 +210,6 @@ def booleanize_null(left_scol, scol, f) -> Column:
205210 filler = f == Column .__ne__
206211 scol = F .when (scol .isNull (), filler ).otherwise (scol )
207212
208- elif f == Column .__or__ :
209- scol = F .when (left_scol .isNull () | scol .isNull (), False ).otherwise (scol )
210-
211- elif f == Column .__and__ :
212- scol = F .when (scol .isNull (), False ).otherwise (scol )
213-
214213 return scol
215214
216215
@@ -239,13 +238,23 @@ def wrapper(self, *args):
239238 # Same DataFrame anchors
240239 args = [arg .spark .column if isinstance (arg , IndexOpsMixin ) else arg for arg in args ]
241240 scol = f (self .spark .column , * args )
242- scol = booleanize_null (self .spark .column , scol , f )
241+
242+ spark_type = self ._internal .spark_frame .select (scol ).schema [0 ].dataType
243+ use_extension_dtypes = any (
244+ isinstance (col .dtype , extension_dtypes ) for col in [self ] + cols
245+ )
246+ dtype = spark_type_to_pandas_dtype (
247+ spark_type , use_extension_dtypes = use_extension_dtypes
248+ )
249+
250+ if not isinstance (dtype , extension_dtypes ):
251+ scol = booleanize_null (scol , f )
243252
244253 if isinstance (self , Series ) or not any (isinstance (col , Series ) for col in cols ):
245- index_ops = self ._with_new_scol (scol )
254+ index_ops = self ._with_new_scol (scol , dtype = dtype )
246255 else :
247256 kser = next (col for col in cols if isinstance (col , Series ))
248- index_ops = kser ._with_new_scol (scol )
257+ index_ops = kser ._with_new_scol (scol , dtype = dtype )
249258 elif get_option ("compute.ops_on_diff_frames" ):
250259 index_ops = align_diff_index_ops (f , self , * args )
251260 else :
@@ -293,7 +302,7 @@ def _kdf(self) -> DataFrame:
293302 pass
294303
@abstractmethod
def _with_new_scol(self, scol: spark.Column, *, dtype=None):
    """
    Build a new object of the same kind backed by ``scol``.

    Subclasses return a new Series/Index anchored on the given Spark
    column; ``dtype``, when provided, records the pandas dtype of the
    result (used e.g. by ``astype`` and the boolean operators).
    """
    pass
298307
299308 @property
@@ -603,11 +612,63 @@ def rpow_func(left, right):
603612
604613 # `and`, `or`, `not` cannot be overloaded in Python,
605614 # so use bitwise operators as boolean operators
606- __and__ = column_op (Column .__and__ )
607- __or__ = column_op (Column .__or__ )
def __and__(self, other) -> Union["Series", "Index"]:
    """
    Element-wise logical AND (the ``&`` operator).

    If either operand has a pandas extension dtype, NULLs propagate into
    the result; otherwise any NULL in the combined column is replaced
    with False.
    """
    involves_extension_dtype = isinstance(self.dtype, extension_dtypes) or (
        isinstance(other, IndexOpsMixin) and isinstance(other.dtype, extension_dtypes)
    )

    def _lift(operand):
        # Turn a plain Python operand into a Spark column; NA becomes NULL.
        if isinstance(operand, spark.Column):
            return operand
        return F.lit(None) if pd.isna(operand) else F.lit(operand)

    if involves_extension_dtype:

        def and_func(left, right):
            return left & _lift(right)

    else:

        def and_func(left, right):
            combined = left & _lift(right)
            # Non-nullable boolean semantics: NULL results become False.
            return F.when(combined.isNull(), False).otherwise(combined)

    return column_op(and_func)(self, other)
640+
def __or__(self, other) -> Union["Series", "Index"]:
    """
    Element-wise logical OR (the ``|`` operator).

    If either operand has a pandas extension dtype, NULLs propagate into
    the result; otherwise NULL on either side of the OR yields False,
    and a plain NA operand short-circuits to an all-False column.
    """
    involves_extension_dtype = isinstance(self.dtype, extension_dtypes) or (
        isinstance(other, IndexOpsMixin) and isinstance(other.dtype, extension_dtypes)
    )

    if involves_extension_dtype:

        def or_func(left, right):
            if not isinstance(right, spark.Column):
                right = F.lit(None) if pd.isna(right) else F.lit(right)
            return left | right

    else:

        def or_func(left, right):
            # A scalar NA on the right makes the whole result False.
            if not isinstance(right, spark.Column) and pd.isna(right):
                return F.lit(False)
            combined = left | F.lit(right)
            # NULL on the left or in the combined result becomes False.
            return F.when(left.isNull() | combined.isNull(), False).otherwise(combined)

    return column_op(or_func)(self, other)
664+
# Logical NOT: delegate straight to Spark's Column.__invert__ (the ~ operator).
__invert__ = column_op(Column.__invert__)
609- __rand__ = column_op (Column .__rand__ )
610- __ror__ = column_op (Column .__ror__ )
666+
def __rand__(self, other) -> Union["Series", "Index"]:
    """Reflected ``&``: AND is symmetric here, so reuse ``__and__``."""
    return self.__and__(other)
669+
def __ror__(self, other) -> Union["Series", "Index"]:
    """Reflected ``|``: OR is symmetric here, so reuse ``__or__``."""
    return self.__or__(other)
611672
def __len__(self):
    """Length of this Series/Index: the row count of the anchoring frame."""
    return len(self._kdf)
@@ -632,7 +693,7 @@ def __array_ufunc__(self, ufunc: Callable, method: str, *inputs: Any, **kwargs:
632693 raise NotImplementedError ("Koalas objects currently do not support %s." % ufunc )
633694
@property
def dtype(self) -> Dtype:
    """
    Return the dtype object of the underlying data.

    The dtype is read from the internal frame's recorded data dtypes
    (first and only data column of this Series/Index) rather than being
    re-derived from the Spark schema, so extension dtypes are preserved.
    """
    return self._internal.data_dtypes[0]
656717
657718 @property
658719 def empty (self ) -> bool :
@@ -955,7 +1016,7 @@ def ndim(self) -> int:
9551016 """
9561017 return 1
9571018
def astype(self, dtype: Union[str, type, Dtype]) -> Union["Index", "Series"]:
    """
    Cast a Koalas object to a specified dtype ``dtype``.

    Parameters
    ----------
    dtype : str, type, or pandas dtype
        Target dtype; anything accepted by ``pandas_dtype``, including
        pandas extension dtypes. Extension targets keep NULLs as missing
        values, while plain bool/str targets fill NULLs pandas-style.

    Returns
    -------
    Index or Series
        A new object of the same kind with values cast to ``dtype``.

    Raises
    ------
    ValueError
        If ``dtype`` has no Spark type equivalent.
    """
    dtype = pandas_dtype(dtype)
    spark_type = as_spark_type(dtype)
    if not spark_type:
        raise ValueError("Type {} not understood".format(dtype))

    col = self.spark.column
    src_type = self.spark.data_type

    if isinstance(spark_type, BooleanType):
        if isinstance(dtype, extension_dtypes):
            # Nullable boolean target: a plain cast keeps NULLs missing.
            new_col = col.cast(spark_type)
        elif isinstance(src_type, StringType):
            # Strings are truthy when non-empty; NULL maps to False.
            new_col = F.when(col.isNull(), F.lit(False)).otherwise(F.length(col) > 0)
        elif isinstance(src_type, (FloatType, DoubleType)):
            # NULL/NaN floats map to True; everything else is a plain cast.
            new_col = F.when(col.isNull() | F.isnan(col), F.lit(True)).otherwise(
                col.cast(spark_type)
            )
        else:
            # Other sources: NULL maps to False, otherwise cast.
            new_col = F.when(col.isNull(), F.lit(False)).otherwise(col.cast(spark_type))
    elif isinstance(spark_type, StringType):
        if isinstance(dtype, extension_dtypes):
            if isinstance(src_type, BooleanType):
                # Render booleans as "True"/"False", keeping NULLs missing.
                new_col = F.when(
                    col.isNotNull(),
                    F.when(col, "True").otherwise("False"),
                )
            elif isinstance(src_type, TimestampType):
                # seems like a pandas' bug?
                new_col = F.when(col.isNull(), str(pd.NaT)).otherwise(
                    col.cast(spark_type)
                )
            else:
                new_col = col.cast(spark_type)
        else:
            # Pick the textual NA representation pandas would produce.
            if isinstance(src_type, NumericType):
                na_repr = str(np.nan)
            elif isinstance(src_type, (DateType, TimestampType)):
                na_repr = str(pd.NaT)
            else:
                na_repr = str(None)
            if isinstance(src_type, BooleanType):
                casted = F.when(col, "True").otherwise("False")
            else:
                casted = col.cast(spark_type)
            new_col = F.when(col.isNull(), na_repr).otherwise(casted)
    else:
        new_col = col.cast(spark_type)
    return self._with_new_scol(new_col, dtype=dtype)
10231102
10241103 def isin (self , values ) -> Union ["Series" , "Index" ]:
10251104 """
0 commit comments