 )
 from databricks.koalas.exceptions import SparkPandasIndexingError, SparkPandasNotImplementedError
 from databricks.koalas.utils import (
+    is_name_like_tuple,
+    is_name_like_value,
     lazy_property,
     name_like_string,
-    verify_temp_column_name,
     same_anchor,
     scol_for,
+    verify_temp_column_name,
 )
 
 if TYPE_CHECKING:
@@ -134,19 +136,18 @@ def __getitem__(self, key):
             col_sel = self._kdf_or_kser._column_label
 
         if len(self._internal.index_map) == 1:
-            if is_list_like(row_sel):
+            if not is_name_like_value(row_sel, allow_none=False, allow_tuple=False):
                 raise ValueError("At based indexing on a single index can only have a single value")
             row_sel = (row_sel,)
-        elif not isinstance(row_sel, tuple):
-            raise ValueError("At based indexing on multi-index can only have tuple values")
-        if not (
-            col_sel is None
-            or isinstance(col_sel, str)
-            or (isinstance(col_sel, tuple) and all(isinstance(col, str) for col in col_sel))
-        ):
-            raise ValueError("At based indexing on multi-index can only have tuple values")
-        if isinstance(col_sel, str):
-            col_sel = (col_sel,)
+        else:
+            if not is_name_like_tuple(row_sel, allow_none=False):
+                raise ValueError("At based indexing on multi-index can only have tuple values")
+
+        if col_sel is not None:
+            if not is_name_like_value(col_sel, allow_none=False):
+                raise ValueError("At based indexing on multi-index can only have tuple values")
+            if not is_name_like_tuple(col_sel):
+                col_sel = (col_sel,)
 
         cond = reduce(
             lambda x, y: x & y,
@@ -262,18 +263,16 @@ def _select_rows(
                 # If slice is None - select everything, so nothing to do
                 return None, None, None
             return self._select_rows_by_slice(rows_sel)
-        elif isinstance(rows_sel, (str, tuple)):
+        elif isinstance(rows_sel, tuple):
             return self._select_rows_else(rows_sel)
-        elif isinstance(rows_sel, Iterable):
+        elif is_list_like(rows_sel):
             return self._select_rows_by_iterable(rows_sel)
         else:
             return self._select_rows_else(rows_sel)
 
     def _select_cols(
-        self, cols_sel: Any, missing_keys: Optional[List[Tuple[str, ...]]] = None
-    ) -> Tuple[
-        List[Tuple[str, ...]], Optional[List[spark.Column]], bool, Optional[Tuple[str, ...]]
-    ]:
+        self, cols_sel: Any, missing_keys: Optional[List[Tuple]] = None
+    ) -> Tuple[List[Tuple], Optional[List[spark.Column]], bool, Optional[Tuple]]:
         """
         Dispatch the logic for select columns to more specific methods by `cols_sel` argument types.
 
@@ -307,9 +306,9 @@ def _select_cols(
                 data_spark_columns = self._internal.data_spark_columns
                 return column_labels, data_spark_columns, False, None
             return self._select_cols_by_slice(cols_sel, missing_keys)
-        elif isinstance(cols_sel, (str, tuple)):
+        elif isinstance(cols_sel, tuple):
             return self._select_cols_else(cols_sel, missing_keys)
-        elif isinstance(cols_sel, Iterable):
+        elif is_list_like(cols_sel):
             return self._select_cols_by_iterable(cols_sel, missing_keys)
         else:
             return self._select_cols_else(cols_sel, missing_keys)
@@ -355,46 +354,36 @@ def _select_rows_else(
 
     @abstractmethod
     def _select_cols_by_series(
-        self, cols_sel: "Series", missing_keys: Optional[List[Tuple[str, ...]]]
-    ) -> Tuple[
-        List[Tuple[str, ...]], Optional[List[spark.Column]], bool, Optional[Tuple[str, ...]]
-    ]:
+        self, cols_sel: "Series", missing_keys: Optional[List[Tuple]]
+    ) -> Tuple[List[Tuple], Optional[List[spark.Column]], bool, Optional[Tuple]]:
         """ Select columns by `Series` type key. """
         pass
 
     @abstractmethod
     def _select_cols_by_spark_column(
-        self, cols_sel: spark.Column, missing_keys: Optional[List[Tuple[str, ...]]]
-    ) -> Tuple[
-        List[Tuple[str, ...]], Optional[List[spark.Column]], bool, Optional[Tuple[str, ...]]
-    ]:
+        self, cols_sel: spark.Column, missing_keys: Optional[List[Tuple]]
+    ) -> Tuple[List[Tuple], Optional[List[spark.Column]], bool, Optional[Tuple]]:
         """ Select columns by Spark `Column` type key. """
         pass
 
     @abstractmethod
     def _select_cols_by_slice(
-        self, cols_sel: slice, missing_keys: Optional[List[Tuple[str, ...]]]
-    ) -> Tuple[
-        List[Tuple[str, ...]], Optional[List[spark.Column]], bool, Optional[Tuple[str, ...]]
-    ]:
+        self, cols_sel: slice, missing_keys: Optional[List[Tuple]]
+    ) -> Tuple[List[Tuple], Optional[List[spark.Column]], bool, Optional[Tuple]]:
         """ Select columns by `slice` type key. """
         pass
 
     @abstractmethod
     def _select_cols_by_iterable(
-        self, cols_sel: Iterable, missing_keys: Optional[List[Tuple[str, ...]]]
-    ) -> Tuple[
-        List[Tuple[str, ...]], Optional[List[spark.Column]], bool, Optional[Tuple[str, ...]]
-    ]:
+        self, cols_sel: Iterable, missing_keys: Optional[List[Tuple]]
+    ) -> Tuple[List[Tuple], Optional[List[spark.Column]], bool, Optional[Tuple]]:
         """ Select columns by `Iterable` type key. """
         pass
 
     @abstractmethod
     def _select_cols_else(
-        self, cols_sel: Any, missing_keys: Optional[List[Tuple[str, ...]]]
-    ) -> Tuple[
-        List[Tuple[str, ...]], Optional[List[spark.Column]], bool, Optional[Tuple[str, ...]]
-    ]:
+        self, cols_sel: Any, missing_keys: Optional[List[Tuple]]
+    ) -> Tuple[List[Tuple], Optional[List[spark.Column]], bool, Optional[Tuple]]:
         """ Select columns by other type key. """
         pass
 
@@ -684,7 +673,7 @@ def __setitem__(self, key, value):
 
             column_labels = self._internal.column_labels.copy()
             for label in missing_keys:
-                if isinstance(label, str):
+                if not is_name_like_tuple(label):
                     label = (label,)
                 if len(label) < self._internal.column_labels_level:
                     label = tuple(
@@ -1073,9 +1062,7 @@ def _select_rows_else(
 
     def _get_from_multiindex_column(
         self, key, missing_keys, labels=None, recursed=0
-    ) -> Tuple[
-        List[Tuple[str, ...]], Optional[List[spark.Column]], bool, Optional[Tuple[str, ...]]
-    ]:
+    ) -> Tuple[List[Tuple], Optional[List[spark.Column]], bool, Optional[Tuple]]:
         """ Select columns from multi-index columns. """
         assert isinstance(key, tuple)
         if labels is None:
@@ -1121,40 +1108,32 @@ def _get_from_multiindex_column(
         return column_labels, data_spark_columns, returns_series, series_name
 
     def _select_cols_by_series(
-        self, cols_sel: "Series", missing_keys: Optional[List[Tuple[str, ...]]]
-    ) -> Tuple[
-        List[Tuple[str, ...]], Optional[List[spark.Column]], bool, Optional[Tuple[str, ...]]
-    ]:
+        self, cols_sel: "Series", missing_keys: Optional[List[Tuple]]
+    ) -> Tuple[List[Tuple], Optional[List[spark.Column]], bool, Optional[Tuple]]:
         column_labels = [cols_sel._column_label]
         data_spark_columns = [cols_sel.spark.column]
         return column_labels, data_spark_columns, True, None
 
     def _select_cols_by_spark_column(
-        self, cols_sel: spark.Column, missing_keys: Optional[List[Tuple[str, ...]]]
-    ) -> Tuple[
-        List[Tuple[str, ...]], Optional[List[spark.Column]], bool, Optional[Tuple[str, ...]]
-    ]:
+        self, cols_sel: spark.Column, missing_keys: Optional[List[Tuple]]
+    ) -> Tuple[List[Tuple], Optional[List[spark.Column]], bool, Optional[Tuple]]:
         column_labels = [
             (self._internal.spark_frame.select(cols_sel).columns[0],)
-        ]  # type: List[Tuple[str, ...]]
+        ]  # type: List[Tuple]
         data_spark_columns = [cols_sel]
         return column_labels, data_spark_columns, True, None
 
     def _select_cols_by_slice(
-        self, cols_sel: slice, missing_keys: Optional[List[Tuple[str, ...]]]
-    ) -> Tuple[
-        List[Tuple[str, ...]], Optional[List[spark.Column]], bool, Optional[Tuple[str, ...]]
-    ]:
+        self, cols_sel: slice, missing_keys: Optional[List[Tuple]]
+    ) -> Tuple[List[Tuple], Optional[List[spark.Column]], bool, Optional[Tuple]]:
         start, stop = self._kdf_or_kser.columns.slice_locs(start=cols_sel.start, end=cols_sel.stop)
         column_labels = self._internal.column_labels[start:stop]
         data_spark_columns = self._internal.data_spark_columns[start:stop]
         return column_labels, data_spark_columns, False, None
 
     def _select_cols_by_iterable(
-        self, cols_sel: Iterable, missing_keys: Optional[List[Tuple[str, ...]]]
-    ) -> Tuple[
-        List[Tuple[str, ...]], Optional[List[spark.Column]], bool, Optional[Tuple[str, ...]]
-    ]:
+        self, cols_sel: Iterable, missing_keys: Optional[List[Tuple]]
+    ) -> Tuple[List[Tuple], Optional[List[spark.Column]], bool, Optional[Tuple]]:
         from databricks.koalas.series import Series
 
         if all(isinstance(key, Series) for key in cols_sel):
@@ -1165,10 +1144,14 @@ def _select_cols_by_iterable(
                 (self._internal.spark_frame.select(col).columns[0],) for col in cols_sel
             ]
             data_spark_columns = list(cols_sel)
-        elif any(isinstance(key, str) for key in cols_sel) and any(
-            isinstance(key, tuple) for key in cols_sel
+        elif any(isinstance(key, tuple) for key in cols_sel) and any(
+            not is_name_like_tuple(key) for key in cols_sel
         ):
-            raise TypeError("Expected tuple, got str")
+            raise TypeError(
+                "Expected tuple, got {}".format(
+                    type(set(key for key in cols_sel if not is_name_like_tuple(key)).pop())
+                )
+            )
         else:
             if missing_keys is None and all(isinstance(key, tuple) for key in cols_sel):
                 level = self._internal.column_labels_level
@@ -1193,11 +1176,9 @@ def _select_cols_by_iterable(
         return column_labels, data_spark_columns, False, None
 
     def _select_cols_else(
-        self, cols_sel: Any, missing_keys: Optional[List[Tuple[str, ...]]]
-    ) -> Tuple[
-        List[Tuple[str, ...]], Optional[List[spark.Column]], bool, Optional[Tuple[str, ...]]
-    ]:
-        if isinstance(cols_sel, str):
+        self, cols_sel: Any, missing_keys: Optional[List[Tuple]]
+    ) -> Tuple[List[Tuple], Optional[List[spark.Column]], bool, Optional[Tuple]]:
+        if not is_name_like_tuple(cols_sel):
             cols_sel = (cols_sel,)
         return self._get_from_multiindex_column(cols_sel, missing_keys)
 
@@ -1513,30 +1494,24 @@ def _select_rows_else(
         )
 
     def _select_cols_by_series(
-        self, cols_sel: "Series", missing_keys: Optional[List[Tuple[str, ...]]]
-    ) -> Tuple[
-        List[Tuple[str, ...]], Optional[List[spark.Column]], bool, Optional[Tuple[str, ...]]
-    ]:
+        self, cols_sel: "Series", missing_keys: Optional[List[Tuple]]
+    ) -> Tuple[List[Tuple], Optional[List[spark.Column]], bool, Optional[Tuple]]:
         raise ValueError(
             "Location based indexing can only have [integer, integer slice, "
             "listlike of integers, boolean array] types, got {}".format(cols_sel)
         )
 
     def _select_cols_by_spark_column(
-        self, cols_sel: spark.Column, missing_keys: Optional[List[Tuple[str, ...]]]
-    ) -> Tuple[
-        List[Tuple[str, ...]], Optional[List[spark.Column]], bool, Optional[Tuple[str, ...]]
-    ]:
+        self, cols_sel: spark.Column, missing_keys: Optional[List[Tuple]]
+    ) -> Tuple[List[Tuple], Optional[List[spark.Column]], bool, Optional[Tuple]]:
         raise ValueError(
             "Location based indexing can only have [integer, integer slice, "
             "listlike of integers, boolean array] types, got {}".format(cols_sel)
         )
 
     def _select_cols_by_slice(
-        self, cols_sel: slice, missing_keys: Optional[List[Tuple[str, ...]]]
-    ) -> Tuple[
-        List[Tuple[str, ...]], Optional[List[spark.Column]], bool, Optional[Tuple[str, ...]]
-    ]:
+        self, cols_sel: slice, missing_keys: Optional[List[Tuple]]
+    ) -> Tuple[List[Tuple], Optional[List[spark.Column]], bool, Optional[Tuple]]:
         if all(
             s is None or isinstance(s, int) for s in (cols_sel.start, cols_sel.stop, cols_sel.step)
         ):
@@ -1558,10 +1533,8 @@ def _select_cols_by_slice(
         return column_labels, data_spark_columns, False, None
 
     def _select_cols_by_iterable(
-        self, cols_sel: Iterable, missing_keys: Optional[List[Tuple[str, ...]]]
-    ) -> Tuple[
-        List[Tuple[str, ...]], Optional[List[spark.Column]], bool, Optional[Tuple[str, ...]]
-    ]:
+        self, cols_sel: Iterable, missing_keys: Optional[List[Tuple]]
+    ) -> Tuple[List[Tuple], Optional[List[spark.Column]], bool, Optional[Tuple]]:
         if all(isinstance(s, bool) for s in cols_sel):
             cols_sel = [i for i, s in enumerate(cols_sel) if s]
         if all(isinstance(s, int) for s in cols_sel):
@@ -1572,10 +1545,8 @@ def _select_cols_by_iterable(
             raise TypeError("cannot perform reduce with flexible type")
 
     def _select_cols_else(
-        self, cols_sel: Any, missing_keys: Optional[List[Tuple[str, ...]]]
-    ) -> Tuple[
-        List[Tuple[str, ...]], Optional[List[spark.Column]], bool, Optional[Tuple[str, ...]]
-    ]:
+        self, cols_sel: Any, missing_keys: Optional[List[Tuple]]
+    ) -> Tuple[List[Tuple], Optional[List[spark.Column]], bool, Optional[Tuple]]:
        if isinstance(cols_sel, int):
            if cols_sel > len(self._internal.column_labels):
                raise KeyError(cols_sel)
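
Reviewer note: a minimal sketch of how the `is_name_like_value` / `is_name_like_tuple` helpers are expected to classify keys, inferred only from the call sites in this diff. The exact defaults and edge cases live in `databricks.koalas.utils`, so treat the assertions below as assumptions rather than documented behavior.

from databricks.koalas.utils import is_name_like_tuple, is_name_like_value

# Scalar labels pass as names; list-likes do not, which is what replaces the
# old `is_list_like(row_sel)` guard on the single-index `.at` path above.
assert is_name_like_value("a", allow_none=False, allow_tuple=False)
assert not is_name_like_value(["a", "b"], allow_none=False, allow_tuple=False)

# Tuple labels are required on a multi-index, mirroring the old
# `isinstance(row_sel, tuple)` check.
assert is_name_like_tuple(("a", "b"), allow_none=False)
assert not is_name_like_tuple("a", allow_none=False)

# A non-tuple column label is wrapped into a 1-tuple by the new code,
# since `is_name_like_tuple` rejects bare scalars.
col_sel = "x"
if not is_name_like_tuple(col_sel):
    col_sel = (col_sel,)
assert col_sel == ("x",)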