|
17 | 17 | """ |
18 | 18 | A loc indexer for Koalas DataFrame/Series. |
19 | 19 | """ |
20 | | -from collections import OrderedDict |
| 20 | +from collections import OrderedDict, Iterable |
21 | 21 | from functools import reduce |
22 | 22 |
|
23 | 23 | from pandas.api.types import is_list_like |
24 | 24 | from pyspark import sql as spark |
25 | 25 | from pyspark.sql import functions as F |
26 | 26 | from pyspark.sql.types import BooleanType, LongType |
27 | 27 | from pyspark.sql.utils import AnalysisException |
| 28 | +import numpy as np |
28 | 29 |
|
29 | 30 | from databricks.koalas.internal import _InternalFrame, NATURAL_ORDER_COLUMN_NAME |
30 | 31 | from databricks.koalas.exceptions import SparkPandasIndexingError, SparkPandasNotImplementedError |
@@ -991,10 +992,47 @@ def verify_type(i): |
991 | 992 | elif isinstance(rows_sel, int): |
992 | 993 | sdf = self._internal.spark_frame |
993 | 994 | return (sdf[self._sequence_col] == rows_sel), None, 0 |
| 995 | + elif isinstance(rows_sel, Iterable): |
| 996 | + sdf = self._internal.spark_frame |
| 997 | + |
| 998 | + if any( |
| 999 | + isinstance(key, (int, np.int, np.int64, np.int32)) and key < 0 for key in rows_sel |
| 1000 | + ): |
| 1001 | + offset = sdf.count() |
| 1002 | + else: |
| 1003 | + offset = 0 |
| 1004 | + |
| 1005 | + new_rows_sel = [] |
| 1006 | + for key in list(rows_sel): |
| 1007 | + if not isinstance(key, (int, np.int, np.int64, np.int32)): |
| 1008 | + raise TypeError( |
| 1009 | + "cannot do positional indexing with these indexers [{}] of {}".format( |
| 1010 | + key, type(key) |
| 1011 | + ) |
| 1012 | + ) |
| 1013 | + if key < 0: |
| 1014 | + key = key + offset |
| 1015 | + new_rows_sel.append(key) |
| 1016 | + |
| 1017 | + if len(new_rows_sel) != len(set(new_rows_sel)): |
| 1018 | + raise NotImplementedError( |
| 1019 | + "Duplicated row selection is not currently supported; " |
| 1020 | + "however, normalised index was [%s]" % new_rows_sel |
| 1021 | + ) |
| 1022 | + |
| 1023 | + sequence_scol = sdf[self._sequence_col] |
| 1024 | + cond = [] |
| 1025 | + for key in new_rows_sel: |
| 1026 | + cond.append(sequence_scol == F.lit(int(key)).cast(LongType())) |
| 1027 | + |
| 1028 | + if len(cond) == 0: |
| 1029 | + cond = [F.lit(False)] |
| 1030 | + return reduce(lambda x, y: x | y, cond), None, None |
994 | 1031 | else: |
995 | 1032 | iLocIndexer._raiseNotImplemented( |
996 | | - ".iloc requires numeric slice or conditional " |
997 | | - "boolean Index, got {}".format(type(rows_sel)) |
| 1033 | + ".iloc requires numeric slice, conditional " |
| 1034 | + "boolean Index or a sequence of positions as int, " |
| 1035 | + "got {}".format(type(rows_sel)) |
998 | 1036 | ) |
999 | 1037 |
|
1000 | 1038 | def _select_cols(self, cols_sel): |
|
0 commit comments