Skip to content

Commit c972da8

Browse files
authored
Add __bool__ function. (#1526)
Adding `__bool__` function explicitly to raise an error when `DataFrame`, `Series`, and `Index` are used as bool, otherwise they might be used mistakenly.
1 parent 421b80d commit c972da8

File tree

3 files changed

+16
-3
lines changed

3 files changed

+16
-3
lines changed

databricks/koalas/generic.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2140,6 +2140,12 @@ def compute(self):
21402140
"""Alias of `to_pandas()` to mimic dask for easily porting tests."""
21412141
return self.toPandas()
21422142

2143+
def __bool__(self):
2144+
raise ValueError(
2145+
"The truth value of a {0} is ambiguous. "
2146+
"Use a.empty, a.bool(), a.item(), a.any() or a.all().".format(self.__class__.__name__)
2147+
)
2148+
21432149
@staticmethod
21442150
def _count_expr(col: spark.Column, spark_type: DataType) -> spark.Column:
21452151
# Special handle floating point types because Spark's count treats nan as a valid value,

databricks/koalas/indexes.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1804,6 +1804,12 @@ def __xor__(self, other):
18041804
def __len__(self):
18051805
return self.size
18061806

1807+
def __bool__(self):
1808+
raise ValueError(
1809+
"The truth value of a {0} is ambiguous. "
1810+
"Use a.empty, a.bool(), a.item(), a.any() or a.all().".format(self.__class__.__name__)
1811+
)
1812+
18071813

18081814
class MultiIndex(Index):
18091815
"""

databricks/koalas/utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -503,10 +503,11 @@ def name_like_string(name: Union[str, Tuple]) -> str:
503503

504504
def validate_axis(axis=0, none_axis=0):
505505
""" Check the given axis is valid. """
506-
if axis not in (0, 1, "index", "columns", None):
507-
raise ValueError("No axis named {0}".format(axis))
508506
# convert to numeric axis
509-
return {None: none_axis, "index": 0, "columns": 1}.get(axis, axis)
507+
axis = {None: none_axis, "index": 0, "columns": 1}.get(axis, axis)
508+
if axis not in (none_axis, 0, 1):
509+
raise ValueError("No axis named {0}".format(axis))
510+
return axis
510511

511512

512513
def validate_bool_kwarg(value, arg_name):

0 commit comments

Comments
 (0)