Skip to content

Commit 8f923c9

Browse files
authored
Make Series.astype(bool) follow the concept of "truthy" and "falsey". (#1431)
Making `Series.astype(bool)` follow the concept of "truthy" and "falsey". ```py >>> kser = ks.Series(["hi", "hi ", " ", " \t", "", None], name="x") >>> kser 0 hi 1 hi 2 3 \t 4 5 None Name: x, dtype: object >>> kser.astype(bool) 0 True 1 True 2 True 3 True 4 False 5 False Name: x, dtype: bool >>> kser.str.strip().astype(bool) 0 True 1 True 2 False 3 False 4 False 5 False Name: x, dtype: bool ``` Resolves #1430.
1 parent 93932bf commit 8f923c9

File tree

3 files changed

+44
-3
lines changed

3 files changed

+44
-3
lines changed

databricks/koalas/series.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,15 @@
3232
from databricks.koalas.typedef import as_python_type
3333
from pyspark import sql as spark
3434
from pyspark.sql import functions as F, Column
35-
from pyspark.sql.types import BooleanType, StructType, LongType, IntegerType
35+
from pyspark.sql.types import (
36+
BooleanType,
37+
DoubleType,
38+
FloatType,
39+
StringType,
40+
StructType,
41+
LongType,
42+
IntegerType,
43+
)
3644
from pyspark.sql.window import Window
3745

3846
from databricks import koalas as ks # For running doctests and reference resolution in PyCharm.
@@ -918,7 +926,20 @@ def astype(self, dtype) -> "Series":
918926
spark_type = as_spark_type(dtype)
919927
if not spark_type:
920928
raise ValueError("Type {} not understood".format(dtype))
921-
return self._with_new_scol(self._scol.cast(spark_type))
929+
if isinstance(spark_type, BooleanType):
930+
if isinstance(self.spark_type, StringType):
931+
scol = F.when(self._scol.isNull(), F.lit(False)).otherwise(F.length(self._scol) > 0)
932+
elif isinstance(self.spark_type, (FloatType, DoubleType)):
933+
scol = F.when(self._scol.isNull() | F.isnan(self._scol), F.lit(True)).otherwise(
934+
self._scol.cast(spark_type)
935+
)
936+
else:
937+
scol = F.when(self._scol.isNull(), F.lit(False)).otherwise(
938+
self._scol.cast(spark_type)
939+
)
940+
else:
941+
scol = self._scol.cast(spark_type)
942+
return self._with_new_scol(scol)
922943

923944
def getField(self, name):
924945
if not isinstance(self.spark_type, StructType):

databricks/koalas/tests/test_series.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1025,6 +1025,26 @@ def test_shift(self):
10251025
def test_astype(self):
10261026
pser = pd.Series([10, 20, 15, 30, 45], name="x")
10271027
kser = ks.Series(pser)
1028+
1029+
self.assert_eq(kser.astype(int), pser.astype(int))
1030+
self.assert_eq(kser.astype(bool), pser.astype(bool))
1031+
1032+
pser = pd.Series([10, 20, 15, 30, 45, None, np.nan], name="x")
1033+
kser = ks.Series(pser)
1034+
1035+
self.assert_eq(kser.astype(bool), pser.astype(bool))
1036+
1037+
pser = pd.Series(["hi", "hi ", " ", " \t", "", None], name="x")
1038+
kser = ks.Series(pser)
1039+
1040+
self.assert_eq(kser.astype(bool), pser.astype(bool))
1041+
self.assert_eq(kser.str.strip().astype(bool), pser.str.strip().astype(bool))
1042+
1043+
pser = pd.Series([True, False, None], name="x")
1044+
kser = ks.Series(pser)
1045+
1046+
self.assert_eq(kser.astype(bool), pser.astype(bool))
1047+
10281048
with self.assertRaisesRegex(ValueError, "Type int63 not understood"):
10291049
kser.astype("int63")
10301050

databricks/koalas/typedef.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ def clean_fun(*args2):
280280
series = kser._with_new_scol(scol=col) # type: 'ks.Series'
281281
all_name_tokens = name_tokens + sorted(kw_name_tokens)
282282
name = "{}({})".format(f.__name__, ", ".join(all_name_tokens))
283-
series = series.astype(return_type).rename(name)
283+
series = series.rename(name)
284284
return series
285285

286286

0 commit comments

Comments
 (0)