Commit 6f450ba

Author: Shallow Copy Bot

Add basic extension dtypes support.
Original PR #2039 by ueshin: databricks/koalas#2039
1 parent 4723ddc

File tree: 15 files changed (+1210, -211 lines)

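For context: pandas "extension dtypes" are the nullable, pandas-native dtypes (Int64, boolean, string, and friends) that keep missing values as pd.NA instead of coercing a column to float64 or object. A minimal pandas-only illustration of the behavior this commit begins to carry over to Koalas:

    import pandas as pd

    # Nullable extension dtype: the missing value stays as pd.NA and
    # the integer dtype is kept.
    s = pd.Series([1, 2, None], dtype="Int64")
    print(s.dtype)    # Int64
    print(s.iloc[2])  # <NA>

    # The NumPy-backed default silently upcasts to float64 with NaN.
    t = pd.Series([1, 2, None])
    print(t.dtype)    # float64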

databricks/koalas/base.py

Lines changed: 121 additions & 42 deletions
@@ -25,7 +25,7 @@

 import numpy as np
 import pandas as pd  # noqa: F401
-from pandas.api.types import is_list_like
+from pandas.api.types import is_list_like, pandas_dtype
 from pyspark import sql as spark
 from pyspark.sql import functions as F, Window, Column
 from pyspark.sql.types import (
@@ -51,7 +51,12 @@
 )
 from databricks.koalas.spark import functions as SF
 from databricks.koalas.spark.accessors import SparkIndexOpsMethods
-from databricks.koalas.typedef import as_spark_type, spark_type_to_pandas_dtype
+from databricks.koalas.typedef import (
+    Dtype,
+    as_spark_type,
+    extension_dtypes,
+    spark_type_to_pandas_dtype,
+)
 from databricks.koalas.utils import (
     combine_frames,
     same_anchor,
@@ -191,7 +196,7 @@ def align_diff_index_ops(func, this_index_ops: "IndexOpsMixin", *args) -> "Index
     ).rename(that_series.name)


-def booleanize_null(left_scol, scol, f) -> Column:
+def booleanize_null(scol, f) -> Column:
     """
     Booleanize Null in Spark Column
     """
@@ -205,12 +210,6 @@ def booleanize_null(left_scol, scol, f) -> Column:
         filler = f == Column.__ne__
         scol = F.when(scol.isNull(), filler).otherwise(scol)

-    elif f == Column.__or__:
-        scol = F.when(left_scol.isNull() | scol.isNull(), False).otherwise(scol)
-
-    elif f == Column.__and__:
-        scol = F.when(scol.isNull(), False).otherwise(scol)
-
     return scol
@@ -239,13 +238,23 @@ def wrapper(self, *args):
             # Same DataFrame anchors
             args = [arg.spark.column if isinstance(arg, IndexOpsMixin) else arg for arg in args]
             scol = f(self.spark.column, *args)
-            scol = booleanize_null(self.spark.column, scol, f)
+
+            spark_type = self._internal.spark_frame.select(scol).schema[0].dataType
+            use_extension_dtypes = any(
+                isinstance(col.dtype, extension_dtypes) for col in [self] + cols
+            )
+            dtype = spark_type_to_pandas_dtype(
+                spark_type, use_extension_dtypes=use_extension_dtypes
+            )
+
+            if not isinstance(dtype, extension_dtypes):
+                scol = booleanize_null(scol, f)

             if isinstance(self, Series) or not any(isinstance(col, Series) for col in cols):
-                index_ops = self._with_new_scol(scol)
+                index_ops = self._with_new_scol(scol, dtype=dtype)
             else:
                 kser = next(col for col in cols if isinstance(col, Series))
-                index_ops = kser._with_new_scol(scol)
+                index_ops = kser._with_new_scol(scol, dtype=dtype)
         elif get_option("compute.ops_on_diff_frames"):
             index_ops = align_diff_index_ops(f, self, *args)
         else:
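The operator wrapper now derives the result dtype before building the column, and skips booleanize_null whenever that dtype is a nullable extension dtype, so null comparison results survive as missing values instead of being coerced to False. This mirrors how the two dtype families already behave in plain pandas; a quick illustration (no Spark required):

    import pandas as pd

    # NumPy-backed column: comparing a missing value yields False.
    print((pd.Series([1.0, None]) > 0).tolist())               # [True, False]

    # Nullable extension dtype: the missing value propagates as <NA>.
    print((pd.Series([1, None], dtype="Int64") > 0).tolist())  # [True, <NA>]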
@@ -293,7 +302,7 @@ def _kdf(self) -> DataFrame:
         pass

     @abstractmethod
-    def _with_new_scol(self, scol: spark.Column):
+    def _with_new_scol(self, scol: spark.Column, *, dtype=None):
         pass

     @property
@@ -603,11 +612,63 @@ def rpow_func(left, right):

     # `and`, `or`, `not` cannot be overloaded in Python,
     # so use bitwise operators as boolean operators
-    __and__ = column_op(Column.__and__)
-    __or__ = column_op(Column.__or__)
+    def __and__(self, other) -> Union["Series", "Index"]:
+        if isinstance(self.dtype, extension_dtypes) or (
+            isinstance(other, IndexOpsMixin) and isinstance(other.dtype, extension_dtypes)
+        ):
+
+            def and_func(left, right):
+                if not isinstance(right, spark.Column):
+                    if pd.isna(right):
+                        right = F.lit(None)
+                    else:
+                        right = F.lit(right)
+                return left & right
+
+        else:
+
+            def and_func(left, right):
+                if not isinstance(right, spark.Column):
+                    if pd.isna(right):
+                        right = F.lit(None)
+                    else:
+                        right = F.lit(right)
+                scol = left & right
+                return F.when(scol.isNull(), False).otherwise(scol)
+
+        return column_op(and_func)(self, other)
+
+    def __or__(self, other) -> Union["Series", "Index"]:
+        if isinstance(self.dtype, extension_dtypes) or (
+            isinstance(other, IndexOpsMixin) and isinstance(other.dtype, extension_dtypes)
+        ):
+
+            def or_func(left, right):
+                if not isinstance(right, spark.Column):
+                    if pd.isna(right):
+                        right = F.lit(None)
+                    else:
+                        right = F.lit(right)
+                return left | right
+
+        else:
+
+            def or_func(left, right):
+                if not isinstance(right, spark.Column) and pd.isna(right):
+                    return F.lit(False)
+                else:
+                    scol = left | F.lit(right)
+                    return F.when(left.isNull() | scol.isNull(), False).otherwise(scol)
+
+        return column_op(or_func)(self, other)
+
     __invert__ = column_op(Column.__invert__)
-    __rand__ = column_op(Column.__rand__)
-    __ror__ = column_op(Column.__ror__)
+
+    def __rand__(self, other) -> Union["Series", "Index"]:
+        return self.__and__(other)
+
+    def __ror__(self, other) -> Union["Series", "Index"]:
+        return self.__or__(other)

     def __len__(self):
         return len(self._kdf)
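__and__ and __or__ are reimplemented so that, when either operand carries a nullable extension dtype, nulls are left to propagate, which lines up with pandas' three-valued (Kleene) logic for the boolean dtype; the legacy branch keeps the old null-to-False coercion. The pandas semantics being matched, for reference:

    import pandas as pd

    b = pd.Series([True, False, None], dtype="boolean")

    # Kleene logic: True | NA == True, False | NA == NA,
    #               True & NA == NA,  False & NA == False.
    print((b | pd.NA).tolist())  # [True, <NA>, <NA>]
    print((b & pd.NA).tolist())  # [<NA>, False, <NA>]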
@@ -632,7 +693,7 @@ def __array_ufunc__(self, ufunc: Callable, method: str, *inputs: Any, **kwargs:
         raise NotImplementedError("Koalas objects currently do not support %s." % ufunc)

     @property
-    def dtype(self) -> np.dtype:
+    def dtype(self) -> Dtype:
         """Return the dtype object of the underlying data.

         Examples
@@ -652,7 +713,7 @@ def dtype(self) -> np.dtype:
         >>> s.rename("a").to_frame().set_index("a").index.dtype
         dtype('<M8[ns]')
         """
-        return spark_type_to_pandas_dtype(self.spark.data_type)
+        return self._internal.data_dtypes[0]

     @property
     def empty(self) -> bool:
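dtype is now read from the dtype metadata tracked on the internal frame rather than re-derived from the Spark schema. That matters because the mapping from Spark types back to pandas dtypes stops being one-to-one once extension dtypes exist; for example, a Spark LongType column may back either of these two pandas dtypes:

    import numpy as np
    import pandas as pd

    # Two distinct pandas dtypes that both map to Spark's LongType,
    # so the Spark schema alone cannot tell them apart.
    print(np.dtype("int64"))  # int64 (NumPy-backed)
    print(pd.Int64Dtype())    # Int64 (nullable extension dtype)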
@@ -955,7 +1016,7 @@ def ndim(self) -> int:
         """
         return 1

-    def astype(self, dtype) -> Union["Index", "Series"]:
+    def astype(self, dtype: Union[str, type, Dtype]) -> Union["Index", "Series"]:
         """
         Cast a Koalas object to a specified dtype ``dtype``.

@@ -989,37 +1050,55 @@ def astype(self, dtype) -> Union["Index", "Series"]:
         >>> ser.rename("a").to_frame().set_index("a").index.astype('int64')
         Int64Index([1, 2], dtype='int64', name='a')
         """
+        dtype = pandas_dtype(dtype)
         spark_type = as_spark_type(dtype)
         if not spark_type:
             raise ValueError("Type {} not understood".format(dtype))
         if isinstance(spark_type, BooleanType):
-            if isinstance(self.spark.data_type, StringType):
-                scol = F.when(self.spark.column.isNull(), F.lit(False)).otherwise(
-                    F.length(self.spark.column) > 0
-                )
-            elif isinstance(self.spark.data_type, (FloatType, DoubleType)):
-                scol = F.when(
-                    self.spark.column.isNull() | F.isnan(self.spark.column), F.lit(True)
-                ).otherwise(self.spark.column.cast(spark_type))
+            if isinstance(dtype, extension_dtypes):
+                scol = self.spark.column.cast(spark_type)
             else:
-                scol = F.when(self.spark.column.isNull(), F.lit(False)).otherwise(
-                    self.spark.column.cast(spark_type)
-                )
+                if isinstance(self.spark.data_type, StringType):
+                    scol = F.when(self.spark.column.isNull(), F.lit(False)).otherwise(
+                        F.length(self.spark.column) > 0
+                    )
+                elif isinstance(self.spark.data_type, (FloatType, DoubleType)):
+                    scol = F.when(
+                        self.spark.column.isNull() | F.isnan(self.spark.column), F.lit(True)
+                    ).otherwise(self.spark.column.cast(spark_type))
+                else:
+                    scol = F.when(self.spark.column.isNull(), F.lit(False)).otherwise(
+                        self.spark.column.cast(spark_type)
+                    )
         elif isinstance(spark_type, StringType):
-            if isinstance(self.spark.data_type, NumericType):
-                null_str = str(np.nan)
-            elif isinstance(self.spark.data_type, (DateType, TimestampType)):
-                null_str = str(pd.NaT)
-            else:
-                null_str = str(None)
-            if isinstance(self.spark.data_type, BooleanType):
-                casted = F.when(self.spark.column, "True").otherwise("False")
+            if isinstance(dtype, extension_dtypes):
+                if isinstance(self.spark.data_type, BooleanType):
+                    scol = F.when(
+                        self.spark.column.isNotNull(),
+                        F.when(self.spark.column, "True").otherwise("False"),
+                    )
+                elif isinstance(self.spark.data_type, TimestampType):
+                    # seems like a pandas' bug?
+                    scol = F.when(self.spark.column.isNull(), str(pd.NaT)).otherwise(
+                        self.spark.column.cast(spark_type)
+                    )
+                else:
+                    scol = self.spark.column.cast(spark_type)
             else:
-                casted = self.spark.column.cast(spark_type)
-            scol = F.when(self.spark.column.isNull(), null_str).otherwise(casted)
+                if isinstance(self.spark.data_type, NumericType):
+                    null_str = str(np.nan)
+                elif isinstance(self.spark.data_type, (DateType, TimestampType)):
+                    null_str = str(pd.NaT)
+                else:
+                    null_str = str(None)
+                if isinstance(self.spark.data_type, BooleanType):
+                    casted = F.when(self.spark.column, "True").otherwise("False")
+                else:
+                    casted = self.spark.column.cast(spark_type)
+                scol = F.when(self.spark.column.isNull(), null_str).otherwise(casted)
         else:
             scol = self.spark.column.cast(spark_type)
-        return self._with_new_scol(scol)
+        return self._with_new_scol(scol, dtype=dtype)

     def isin(self, values) -> Union["Series", "Index"]:
         """

databricks/koalas/frame.py

Lines changed: 9 additions & 12 deletions
@@ -795,7 +795,7 @@ def apply_op(kdf, this_column_labels, that_column_labels):
                     .alias(name_like_string(label))
                 )
                 column_labels.append(label)
-            internal = self._internal.with_new_columns(applied, column_labels)
+            internal = self._internal.with_new_columns(applied, column_labels=column_labels)
             return DataFrame(internal)
         else:
             return self._apply_series_op(lambda kser: getattr(kser, op)(other))
@@ -5504,7 +5504,7 @@ def head(self, n: int = 5) -> "DataFrame":
             sdf = self._internal.resolved_copy.spark_frame
             if get_option("compute.ordered_head"):
                 sdf = sdf.orderBy(NATURAL_ORDER_COLUMN_NAME)
-            return DataFrame(self._internal.with_new_sdf(sdf.limit(n)))
+            return DataFrame(self._internal.with_new_sdf(sdf.limit(n), preserve_dtypes=True))

     def pivot_table(
         self, values=None, index=None, columns=None, aggfunc="mean", fill_value=None
@@ -5958,17 +5958,12 @@ def columns(self, columns) -> None:
         else:
             column_label_names = None

-        data_columns = [name_like_string(label) for label in column_labels]
-        data_spark_columns = [
-            self._internal.spark_column_for(label).alias(name)
-            for label, name in zip(self._internal.column_labels, data_columns)
+        ksers = [
+            self._kser_for(label).rename(name)
+            for label, name in zip(self._internal.column_labels, column_labels)
         ]
         self._update_internal_frame(
-            self._internal.with_new_columns(
-                data_spark_columns,
-                column_labels=column_labels,
-                column_label_names=column_label_names,
-            )
+            self._internal.with_new_columns(ksers, column_label_names=column_label_names)
         )

@@ -6434,7 +6429,9 @@ def _sort(
         }
         by = [mapper[(asc, na_position)](scol) for scol, asc in zip(by, ascending)]
         sdf = self._internal.resolved_copy.spark_frame.sort(*(by + [NATURAL_ORDER_COLUMN_NAME]))
-        kdf = DataFrame(self._internal.with_new_sdf(sdf))  # type: ks.DataFrame
+        kdf = DataFrame(
+            self._internal.with_new_sdf(sdf, preserve_dtypes=True)
+        )  # type: ks.DataFrame
         if inplace:
             self._update_internal_frame(kdf._internal)
             return None
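head and _sort merely re-wrap the same columns over a new Spark plan, so they pass preserve_dtypes=True, a flag presumably added to InternalFrame.with_new_sdf elsewhere in this commit (internal.py is not shown above), to keep the tracked pandas dtypes instead of re-deriving them from the Spark schema. A hedged sketch of the intended effect, assuming a running Spark session:

    import databricks.koalas as ks

    # Illustrative only: Int64 is a nullable extension dtype.
    kdf = ks.DataFrame({"a": [1, 2, 3]}).astype("Int64")

    # Because the dtypes are preserved across the plan rebuild, these
    # should still report Int64 rather than falling back to int64.
    print(kdf.head(2).dtypes)
    print(kdf.sort_values("a").dtypes)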

databricks/koalas/indexes/base.py

Lines changed: 3 additions & 1 deletion
@@ -159,7 +159,7 @@ def _internal(self) -> InternalFrame:
     def _column_label(self):
         return self._kdf._internal.index_names[0]

-    def _with_new_scol(self, scol: spark.Column) -> "Index":
+    def _with_new_scol(self, scol: spark.Column, *, dtype=None) -> "Index":
         """
         Copy Koalas Index with the new Spark Column.

@@ -168,8 +168,10 @@ def _with_new_scol(self, scol: spark.Column) -> "Index":
         """
         internal = self._internal.copy(
             index_spark_columns=[scol.alias(SPARK_DEFAULT_INDEX_NAME)],
+            index_dtypes=[dtype],
             column_labels=[],
             data_spark_columns=[],
+            data_dtypes=[],
         )
         return DataFrame(internal).index
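Index._with_new_scol builds a fresh InternalFrame, so the new dtype has to be recorded in index_dtypes (and data_dtypes emptied alongside the cleared data columns) or the metadata would go stale. A sketch of the expected round-trip, again assuming a running Spark session:

    import databricks.koalas as ks

    idx = ks.Index([1, 2, 3])
    # astype routes through _with_new_scol(scol, dtype=...), so the
    # nullable dtype should be recorded on the resulting Index.
    print(idx.astype("Int64").dtype)  # Int64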

databricks/koalas/indexes/multi.py

Lines changed: 1 addition & 1 deletion
@@ -100,7 +100,7 @@ def _column_label(self):
     def __abs__(self):
         raise TypeError("TypeError: cannot perform __abs__ with this index type: MultiIndex")

-    def _with_new_scol(self, scol: spark.Column):
+    def _with_new_scol(self, scol: spark.Column, *, dtype=None):
         raise NotImplementedError("Not supported for type MultiIndex")

     def _align_and_column_op(self, f, *args) -> Index:
