Skip to content

Commit 1e96eb8

Browse files
ueshinHyukjinKwon
authored andcommitted
Fix filter for multi-index columns support. (#859)
1 parent 8d8b79e commit 1e96eb8

File tree

3 files changed

+40
-12
lines changed

3 files changed

+40
-12
lines changed

databricks/koalas/frame.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6844,9 +6844,9 @@ def filter(self, items=None, like=None, regex=None, axis=None):
68446844
sdf = sdf.filter(index_scols[0].contains(like))
68456845
return DataFrame(self._internal.copy(sdf=sdf))
68466846
elif axis in ('columns', 1, None):
6847-
data_columns = self._internal.data_columns
6848-
output_columns = [c for c in data_columns if like in c]
6849-
return self[output_columns]
6847+
column_index = self._internal.column_index
6848+
output_idx = [idx for idx in column_index if any(like in i for i in idx)]
6849+
return self[output_idx]
68506850
elif regex is not None:
68516851
if axis in ('index', 0):
68526852
# TODO: support multi-index here
@@ -6855,10 +6855,11 @@ def filter(self, items=None, like=None, regex=None, axis=None):
68556855
sdf = sdf.filter(index_scols[0].rlike(regex))
68566856
return DataFrame(self._internal.copy(sdf=sdf))
68576857
elif axis in ('columns', 1, None):
6858-
data_columns = self._internal.data_columns
6858+
column_index = self._internal.column_index
68596859
matcher = re.compile(regex)
6860-
output_columns = [c for c in data_columns if matcher.search(c) is not None]
6861-
return self[output_columns]
6860+
output_idx = [idx for idx in column_index
6861+
if any(matcher.search(i) is not None for i in idx)]
6862+
return self[output_idx]
68626863
else:
68636864
raise TypeError("Must pass either `items`, `like`, or `regex`")
68646865

databricks/koalas/internal.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -720,14 +720,14 @@ def from_pandas(pdf: pd.DataFrame) -> '_InternalFrame':
720720
for i, name in enumerate(index.names)]
721721
else:
722722
name = index.name
723-
index_map = [(name if name is not None else '__index_level_0__',
723+
index_map = [(str(name) if name is not None else '__index_level_0__',
724724
name if name is None or isinstance(name, tuple) else (name,))]
725725

726726
index_columns = [index_column for index_column, _ in index_map]
727727

728728
reset_index = pdf.reset_index()
729729
reset_index.columns = index_columns + data_columns
730-
schema = StructType([StructField(name, infer_pd_series_spark_type(col),
730+
schema = StructType([StructField(str(name), infer_pd_series_spark_type(col),
731731
nullable=bool(col.isnull().any()))
732732
for name, col in reset_index.iteritems()])
733733
for name, col in reset_index.iteritems():

databricks/koalas/tests/test_dataframe.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1863,10 +1863,10 @@ def test_filter(self):
18631863
self.assert_eq(kdf.filter(like='b', axis='index'), pdf.filter(like='b', axis='index'))
18641864
self.assert_eq(kdf.filter(like='c', axis='columns'), pdf.filter(like='c', axis='columns'))
18651865

1866-
self.assert_eq(
1867-
kdf.filter(regex='b.*', axis='index'), pdf.filter(regex='b.*', axis='index'))
1868-
self.assert_eq(
1869-
kdf.filter(regex='b.*', axis='columns'), pdf.filter(regex='b.*', axis='columns'))
1866+
self.assert_eq(kdf.filter(regex='b.*', axis='index'),
1867+
pdf.filter(regex='b.*', axis='index'))
1868+
self.assert_eq(kdf.filter(regex='b.*', axis='columns'),
1869+
pdf.filter(regex='b.*', axis='columns'))
18701870

18711871
pdf = pdf.set_index('ba', append=True)
18721872
kdf = ks.from_pandas(pdf)
@@ -1892,6 +1892,33 @@ def test_filter(self):
18921892
with self.assertRaisesRegex(TypeError, "mutually exclusive"):
18931893
kdf.filter(regex='b.*', like="aaa")
18941894

1895+
# multi-index columns
1896+
pdf = pd.DataFrame({
1897+
('x', 'aa'): ['aa', 'ab', 'bc', 'bd', 'ce'],
1898+
('x', 'ba'): [1, 2, 3, 4, 5],
1899+
('y', 'cb'): [1., 2., 3., 4., 5.],
1900+
('z', 'db'): [1., np.nan, 3., np.nan, 5.],
1901+
})
1902+
pdf = pdf.set_index(('x', 'aa'))
1903+
kdf = ks.from_pandas(pdf)
1904+
1905+
self.assert_eq(
1906+
kdf.filter(items=['ab', 'aa'], axis=0).sort_index(),
1907+
pdf.filter(items=['ab', 'aa'], axis=0).sort_index())
1908+
self.assert_eq(
1909+
kdf.filter(items=[('x', 'ba'), ('z', 'db')], axis=1).sort_index(),
1910+
pdf.filter(items=[('x', 'ba'), ('z', 'db')], axis=1).sort_index())
1911+
1912+
self.assert_eq(kdf.filter(like='b', axis='index'),
1913+
pdf.filter(like='b', axis='index'))
1914+
self.assert_eq(kdf.filter(like='c', axis='columns'),
1915+
pdf.filter(like='c', axis='columns'))
1916+
1917+
self.assert_eq(kdf.filter(regex='b.*', axis='index'),
1918+
pdf.filter(regex='b.*', axis='index'))
1919+
self.assert_eq(kdf.filter(regex='b.*', axis='columns'),
1920+
pdf.filter(regex='b.*', axis='columns'))
1921+
18951922
def test_pipe(self):
18961923
kdf = ks.DataFrame({'category': ['A', 'A', 'B'],
18971924
'col1': [1, 2, 3],

0 commit comments

Comments
 (0)