Skip to content

Commit bc6aab8

Browse files
authored
Modify input check for indexers to support non-string column names. (#1808)
Enables column indexing for non-string column names, which is currently not supported. ```py >>> kdf = ks.DataFrame({0: [1, 2, 3, 4, 5, 6, 7, 8, 9], 1: [4, 5, 6, 3, 2, 1, 0, 0, 0]}, index=[0, 1, 3, 5, 6, 8, 9, 9, 9]) >>> kdf.loc[:, 0] Traceback (most recent call last): ... AssertionError >>> kdf[0] Traceback (most recent call last): ... NotImplementedError: 0 ```
1 parent 2ce20ac commit bc6aab8

File tree

5 files changed

+237
-133
lines changed

5 files changed

+237
-133
lines changed

databricks/koalas/frame.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -64,16 +64,17 @@
6464
from databricks.koalas.spark import functions as SF
6565
from databricks.koalas.spark.accessors import SparkFrameMethods, CachedSparkFrameMethods
6666
from databricks.koalas.utils import (
67-
validate_arguments_and_invoke_function,
6867
align_diff_frames,
69-
validate_bool_kwarg,
7068
column_labels_level,
69+
default_session,
70+
is_name_like_value,
7171
name_like_string,
7272
same_anchor,
7373
scol_for,
74+
validate_arguments_and_invoke_function,
7475
validate_axis,
76+
validate_bool_kwarg,
7577
verify_temp_column_name,
76-
default_session,
7778
)
7879
from databricks.koalas.generic import Frame
7980
from databricks.koalas.internal import (
@@ -10351,18 +10352,18 @@ def __getitem__(self, key):
1035110352

1035210353
if key is None:
1035310354
raise KeyError("none key")
10354-
if isinstance(key, Series):
10355+
elif isinstance(key, Series):
1035510356
return self.loc[key.astype(bool)]
10356-
elif isinstance(key, (str, tuple)):
10357-
return self.loc[:, key]
10358-
elif is_list_like(key):
10359-
return self.loc[:, list(key)]
1036010357
elif isinstance(key, slice):
1036110358
if any(type(n) == int or None for n in [key.start, key.stop]):
1036210359
# Seems like pandas Frame always uses int as positional search when slicing
1036310360
# with ints.
1036410361
return self.iloc[key]
1036510362
return self.loc[key]
10363+
elif is_name_like_value(key):
10364+
return self.loc[:, key]
10365+
elif is_list_like(key):
10366+
return self.loc[:, list(key)]
1036610367
raise NotImplementedError(key)
1036710368

1036810369
def __setitem__(self, key, value):

databricks/koalas/indexing.py

Lines changed: 60 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,13 @@
3737
)
3838
from databricks.koalas.exceptions import SparkPandasIndexingError, SparkPandasNotImplementedError
3939
from databricks.koalas.utils import (
40+
is_name_like_tuple,
41+
is_name_like_value,
4042
lazy_property,
4143
name_like_string,
42-
verify_temp_column_name,
4344
same_anchor,
4445
scol_for,
46+
verify_temp_column_name,
4547
)
4648

4749
if TYPE_CHECKING:
@@ -134,19 +136,18 @@ def __getitem__(self, key):
134136
col_sel = self._kdf_or_kser._column_label
135137

136138
if len(self._internal.index_map) == 1:
137-
if is_list_like(row_sel):
139+
if not is_name_like_value(row_sel, allow_none=False, allow_tuple=False):
138140
raise ValueError("At based indexing on a single index can only have a single value")
139141
row_sel = (row_sel,)
140-
elif not isinstance(row_sel, tuple):
141-
raise ValueError("At based indexing on multi-index can only have tuple values")
142-
if not (
143-
col_sel is None
144-
or isinstance(col_sel, str)
145-
or (isinstance(col_sel, tuple) and all(isinstance(col, str) for col in col_sel))
146-
):
147-
raise ValueError("At based indexing on multi-index can only have tuple values")
148-
if isinstance(col_sel, str):
149-
col_sel = (col_sel,)
142+
else:
143+
if not is_name_like_tuple(row_sel, allow_none=False):
144+
raise ValueError("At based indexing on multi-index can only have tuple values")
145+
146+
if col_sel is not None:
147+
if not is_name_like_value(col_sel, allow_none=False):
148+
raise ValueError("At based indexing on multi-index can only have tuple values")
149+
if not is_name_like_tuple(col_sel):
150+
col_sel = (col_sel,)
150151

151152
cond = reduce(
152153
lambda x, y: x & y,
@@ -262,18 +263,16 @@ def _select_rows(
262263
# If slice is None - select everything, so nothing to do
263264
return None, None, None
264265
return self._select_rows_by_slice(rows_sel)
265-
elif isinstance(rows_sel, (str, tuple)):
266+
elif isinstance(rows_sel, tuple):
266267
return self._select_rows_else(rows_sel)
267-
elif isinstance(rows_sel, Iterable):
268+
elif is_list_like(rows_sel):
268269
return self._select_rows_by_iterable(rows_sel)
269270
else:
270271
return self._select_rows_else(rows_sel)
271272

272273
def _select_cols(
273-
self, cols_sel: Any, missing_keys: Optional[List[Tuple[str, ...]]] = None
274-
) -> Tuple[
275-
List[Tuple[str, ...]], Optional[List[spark.Column]], bool, Optional[Tuple[str, ...]]
276-
]:
274+
self, cols_sel: Any, missing_keys: Optional[List[Tuple]] = None
275+
) -> Tuple[List[Tuple], Optional[List[spark.Column]], bool, Optional[Tuple]]:
277276
"""
278277
Dispatch the logic for select columns to more specific methods by `cols_sel` argument types.
279278
@@ -307,9 +306,9 @@ def _select_cols(
307306
data_spark_columns = self._internal.data_spark_columns
308307
return column_labels, data_spark_columns, False, None
309308
return self._select_cols_by_slice(cols_sel, missing_keys)
310-
elif isinstance(cols_sel, (str, tuple)):
309+
elif isinstance(cols_sel, tuple):
311310
return self._select_cols_else(cols_sel, missing_keys)
312-
elif isinstance(cols_sel, Iterable):
311+
elif is_list_like(cols_sel):
313312
return self._select_cols_by_iterable(cols_sel, missing_keys)
314313
else:
315314
return self._select_cols_else(cols_sel, missing_keys)
@@ -355,46 +354,36 @@ def _select_rows_else(
355354

356355
@abstractmethod
357356
def _select_cols_by_series(
358-
self, cols_sel: "Series", missing_keys: Optional[List[Tuple[str, ...]]]
359-
) -> Tuple[
360-
List[Tuple[str, ...]], Optional[List[spark.Column]], bool, Optional[Tuple[str, ...]]
361-
]:
357+
self, cols_sel: "Series", missing_keys: Optional[List[Tuple]]
358+
) -> Tuple[List[Tuple], Optional[List[spark.Column]], bool, Optional[Tuple]]:
362359
""" Select columns by `Series` type key. """
363360
pass
364361

365362
@abstractmethod
366363
def _select_cols_by_spark_column(
367-
self, cols_sel: spark.Column, missing_keys: Optional[List[Tuple[str, ...]]]
368-
) -> Tuple[
369-
List[Tuple[str, ...]], Optional[List[spark.Column]], bool, Optional[Tuple[str, ...]]
370-
]:
364+
self, cols_sel: spark.Column, missing_keys: Optional[List[Tuple]]
365+
) -> Tuple[List[Tuple], Optional[List[spark.Column]], bool, Optional[Tuple]]:
371366
""" Select columns by Spark `Column` type key. """
372367
pass
373368

374369
@abstractmethod
375370
def _select_cols_by_slice(
376-
self, cols_sel: slice, missing_keys: Optional[List[Tuple[str, ...]]]
377-
) -> Tuple[
378-
List[Tuple[str, ...]], Optional[List[spark.Column]], bool, Optional[Tuple[str, ...]]
379-
]:
371+
self, cols_sel: slice, missing_keys: Optional[List[Tuple]]
372+
) -> Tuple[List[Tuple], Optional[List[spark.Column]], bool, Optional[Tuple]]:
380373
""" Select columns by `slice` type key. """
381374
pass
382375

383376
@abstractmethod
384377
def _select_cols_by_iterable(
385-
self, cols_sel: Iterable, missing_keys: Optional[List[Tuple[str, ...]]]
386-
) -> Tuple[
387-
List[Tuple[str, ...]], Optional[List[spark.Column]], bool, Optional[Tuple[str, ...]]
388-
]:
378+
self, cols_sel: Iterable, missing_keys: Optional[List[Tuple]]
379+
) -> Tuple[List[Tuple], Optional[List[spark.Column]], bool, Optional[Tuple]]:
389380
""" Select columns by `Iterable` type key. """
390381
pass
391382

392383
@abstractmethod
393384
def _select_cols_else(
394-
self, cols_sel: Any, missing_keys: Optional[List[Tuple[str, ...]]]
395-
) -> Tuple[
396-
List[Tuple[str, ...]], Optional[List[spark.Column]], bool, Optional[Tuple[str, ...]]
397-
]:
385+
self, cols_sel: Any, missing_keys: Optional[List[Tuple]]
386+
) -> Tuple[List[Tuple], Optional[List[spark.Column]], bool, Optional[Tuple]]:
398387
""" Select columns by other type key. """
399388
pass
400389

@@ -684,7 +673,7 @@ def __setitem__(self, key, value):
684673

685674
column_labels = self._internal.column_labels.copy()
686675
for label in missing_keys:
687-
if isinstance(label, str):
676+
if not is_name_like_tuple(label):
688677
label = (label,)
689678
if len(label) < self._internal.column_labels_level:
690679
label = tuple(
@@ -1073,9 +1062,7 @@ def _select_rows_else(
10731062

10741063
def _get_from_multiindex_column(
10751064
self, key, missing_keys, labels=None, recursed=0
1076-
) -> Tuple[
1077-
List[Tuple[str, ...]], Optional[List[spark.Column]], bool, Optional[Tuple[str, ...]]
1078-
]:
1065+
) -> Tuple[List[Tuple], Optional[List[spark.Column]], bool, Optional[Tuple]]:
10791066
""" Select columns from multi-index columns. """
10801067
assert isinstance(key, tuple)
10811068
if labels is None:
@@ -1121,40 +1108,32 @@ def _get_from_multiindex_column(
11211108
return column_labels, data_spark_columns, returns_series, series_name
11221109

11231110
def _select_cols_by_series(
1124-
self, cols_sel: "Series", missing_keys: Optional[List[Tuple[str, ...]]]
1125-
) -> Tuple[
1126-
List[Tuple[str, ...]], Optional[List[spark.Column]], bool, Optional[Tuple[str, ...]]
1127-
]:
1111+
self, cols_sel: "Series", missing_keys: Optional[List[Tuple]]
1112+
) -> Tuple[List[Tuple], Optional[List[spark.Column]], bool, Optional[Tuple]]:
11281113
column_labels = [cols_sel._column_label]
11291114
data_spark_columns = [cols_sel.spark.column]
11301115
return column_labels, data_spark_columns, True, None
11311116

11321117
def _select_cols_by_spark_column(
1133-
self, cols_sel: spark.Column, missing_keys: Optional[List[Tuple[str, ...]]]
1134-
) -> Tuple[
1135-
List[Tuple[str, ...]], Optional[List[spark.Column]], bool, Optional[Tuple[str, ...]]
1136-
]:
1118+
self, cols_sel: spark.Column, missing_keys: Optional[List[Tuple]]
1119+
) -> Tuple[List[Tuple], Optional[List[spark.Column]], bool, Optional[Tuple]]:
11371120
column_labels = [
11381121
(self._internal.spark_frame.select(cols_sel).columns[0],)
1139-
] # type: List[Tuple[str, ...]]
1122+
] # type: List[Tuple]
11401123
data_spark_columns = [cols_sel]
11411124
return column_labels, data_spark_columns, True, None
11421125

11431126
def _select_cols_by_slice(
1144-
self, cols_sel: slice, missing_keys: Optional[List[Tuple[str, ...]]]
1145-
) -> Tuple[
1146-
List[Tuple[str, ...]], Optional[List[spark.Column]], bool, Optional[Tuple[str, ...]]
1147-
]:
1127+
self, cols_sel: slice, missing_keys: Optional[List[Tuple]]
1128+
) -> Tuple[List[Tuple], Optional[List[spark.Column]], bool, Optional[Tuple]]:
11481129
start, stop = self._kdf_or_kser.columns.slice_locs(start=cols_sel.start, end=cols_sel.stop)
11491130
column_labels = self._internal.column_labels[start:stop]
11501131
data_spark_columns = self._internal.data_spark_columns[start:stop]
11511132
return column_labels, data_spark_columns, False, None
11521133

11531134
def _select_cols_by_iterable(
1154-
self, cols_sel: Iterable, missing_keys: Optional[List[Tuple[str, ...]]]
1155-
) -> Tuple[
1156-
List[Tuple[str, ...]], Optional[List[spark.Column]], bool, Optional[Tuple[str, ...]]
1157-
]:
1135+
self, cols_sel: Iterable, missing_keys: Optional[List[Tuple]]
1136+
) -> Tuple[List[Tuple], Optional[List[spark.Column]], bool, Optional[Tuple]]:
11581137
from databricks.koalas.series import Series
11591138

11601139
if all(isinstance(key, Series) for key in cols_sel):
@@ -1165,10 +1144,14 @@ def _select_cols_by_iterable(
11651144
(self._internal.spark_frame.select(col).columns[0],) for col in cols_sel
11661145
]
11671146
data_spark_columns = list(cols_sel)
1168-
elif any(isinstance(key, str) for key in cols_sel) and any(
1169-
isinstance(key, tuple) for key in cols_sel
1147+
elif any(isinstance(key, tuple) for key in cols_sel) and any(
1148+
not is_name_like_tuple(key) for key in cols_sel
11701149
):
1171-
raise TypeError("Expected tuple, got str")
1150+
raise TypeError(
1151+
"Expected tuple, got {}".format(
1152+
type(set(key for key in cols_sel if not is_name_like_tuple(key)).pop())
1153+
)
1154+
)
11721155
else:
11731156
if missing_keys is None and all(isinstance(key, tuple) for key in cols_sel):
11741157
level = self._internal.column_labels_level
@@ -1193,11 +1176,9 @@ def _select_cols_by_iterable(
11931176
return column_labels, data_spark_columns, False, None
11941177

11951178
def _select_cols_else(
1196-
self, cols_sel: Any, missing_keys: Optional[List[Tuple[str, ...]]]
1197-
) -> Tuple[
1198-
List[Tuple[str, ...]], Optional[List[spark.Column]], bool, Optional[Tuple[str, ...]]
1199-
]:
1200-
if isinstance(cols_sel, str):
1179+
self, cols_sel: Any, missing_keys: Optional[List[Tuple]]
1180+
) -> Tuple[List[Tuple], Optional[List[spark.Column]], bool, Optional[Tuple]]:
1181+
if not is_name_like_tuple(cols_sel):
12011182
cols_sel = (cols_sel,)
12021183
return self._get_from_multiindex_column(cols_sel, missing_keys)
12031184

@@ -1513,30 +1494,24 @@ def _select_rows_else(
15131494
)
15141495

15151496
def _select_cols_by_series(
1516-
self, cols_sel: "Series", missing_keys: Optional[List[Tuple[str, ...]]]
1517-
) -> Tuple[
1518-
List[Tuple[str, ...]], Optional[List[spark.Column]], bool, Optional[Tuple[str, ...]]
1519-
]:
1497+
self, cols_sel: "Series", missing_keys: Optional[List[Tuple]]
1498+
) -> Tuple[List[Tuple], Optional[List[spark.Column]], bool, Optional[Tuple]]:
15201499
raise ValueError(
15211500
"Location based indexing can only have [integer, integer slice, "
15221501
"listlike of integers, boolean array] types, got {}".format(cols_sel)
15231502
)
15241503

15251504
def _select_cols_by_spark_column(
1526-
self, cols_sel: spark.Column, missing_keys: Optional[List[Tuple[str, ...]]]
1527-
) -> Tuple[
1528-
List[Tuple[str, ...]], Optional[List[spark.Column]], bool, Optional[Tuple[str, ...]]
1529-
]:
1505+
self, cols_sel: spark.Column, missing_keys: Optional[List[Tuple]]
1506+
) -> Tuple[List[Tuple], Optional[List[spark.Column]], bool, Optional[Tuple]]:
15301507
raise ValueError(
15311508
"Location based indexing can only have [integer, integer slice, "
15321509
"listlike of integers, boolean array] types, got {}".format(cols_sel)
15331510
)
15341511

15351512
def _select_cols_by_slice(
1536-
self, cols_sel: slice, missing_keys: Optional[List[Tuple[str, ...]]]
1537-
) -> Tuple[
1538-
List[Tuple[str, ...]], Optional[List[spark.Column]], bool, Optional[Tuple[str, ...]]
1539-
]:
1513+
self, cols_sel: slice, missing_keys: Optional[List[Tuple]]
1514+
) -> Tuple[List[Tuple], Optional[List[spark.Column]], bool, Optional[Tuple]]:
15401515
if all(
15411516
s is None or isinstance(s, int) for s in (cols_sel.start, cols_sel.stop, cols_sel.step)
15421517
):
@@ -1558,10 +1533,8 @@ def _select_cols_by_slice(
15581533
return column_labels, data_spark_columns, False, None
15591534

15601535
def _select_cols_by_iterable(
1561-
self, cols_sel: Iterable, missing_keys: Optional[List[Tuple[str, ...]]]
1562-
) -> Tuple[
1563-
List[Tuple[str, ...]], Optional[List[spark.Column]], bool, Optional[Tuple[str, ...]]
1564-
]:
1536+
self, cols_sel: Iterable, missing_keys: Optional[List[Tuple]]
1537+
) -> Tuple[List[Tuple], Optional[List[spark.Column]], bool, Optional[Tuple]]:
15651538
if all(isinstance(s, bool) for s in cols_sel):
15661539
cols_sel = [i for i, s in enumerate(cols_sel) if s]
15671540
if all(isinstance(s, int) for s in cols_sel):
@@ -1572,10 +1545,8 @@ def _select_cols_by_iterable(
15721545
raise TypeError("cannot perform reduce with flexible type")
15731546

15741547
def _select_cols_else(
1575-
self, cols_sel: Any, missing_keys: Optional[List[Tuple[str, ...]]]
1576-
) -> Tuple[
1577-
List[Tuple[str, ...]], Optional[List[spark.Column]], bool, Optional[Tuple[str, ...]]
1578-
]:
1548+
self, cols_sel: Any, missing_keys: Optional[List[Tuple]]
1549+
) -> Tuple[List[Tuple], Optional[List[spark.Column]], bool, Optional[Tuple]]:
15791550
if isinstance(cols_sel, int):
15801551
if cols_sel > len(self._internal.column_labels):
15811552
raise KeyError(cols_sel)

0 commit comments

Comments
 (0)