Skip to content

Commit eb5427c

Browse files
authored
Fix drop for multi-index columns support. (#658)
1 parent 3f0911d commit eb5427c

File tree

2 files changed: 27 additions and 5 deletions

2 files changed

+27
-5
lines changed

databricks/koalas/frame.py

Lines changed: 16 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -4087,7 +4087,8 @@ def count(self, axis=None):
40874087
return self._reduce_for_stat_function(
40884088
_Frame._count_expr, name="count", axis=axis, numeric_only=False)
40894089

4090-
def drop(self, labels=None, axis=1, columns: Union[str, List[str]] = None):
4090+
def drop(self, labels=None, axis=1,
4091+
columns: Union[str, Tuple[str], List[str], List[Tuple[str]]] = None):
40914092
"""
40924093
Drop specified labels from columns.
40934094
@@ -4150,11 +4151,21 @@ def drop(self, labels=None, axis=1, columns: Union[str, List[str]] = None):
41504151
raise NotImplementedError("Drop currently only works for axis=1")
41514152
elif columns is not None:
41524153
if isinstance(columns, str):
4154+
columns = [(columns,)]
4155+
elif isinstance(columns, tuple):
41534156
columns = [columns]
4154-
sdf = self._sdf.drop(*columns)
4155-
internal = self._internal.copy(
4156-
sdf=sdf,
4157-
data_columns=[column for column in self.columns if column not in columns])
4157+
else:
4158+
columns = [col if isinstance(col, tuple) else (col,) for col in columns]
4159+
drop_column_index = set(idx for idx in self._internal.column_index
4160+
for col in columns
4161+
if idx[:len(col)] == col)
4162+
if len(drop_column_index) == 0:
4163+
raise KeyError(columns)
4164+
cols, idx = zip(*((column, idx)
4165+
for column, idx
4166+
in zip(self._internal.data_columns, self._internal.column_index)
4167+
if idx not in drop_column_index))
4168+
internal = self._internal.copy(data_columns=list(cols), column_index=list(idx))
41584169
return DataFrame(internal)
41594170
else:
41604171
raise ValueError("Need to specify at least one of 'labels' or 'columns'")

databricks/koalas/tests/test_dataframe.py

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -378,6 +378,17 @@ def test_drop(self):
378378
expected_output = pd.DataFrame({'y': [3, 4], 'z': [5, 6]})
379379
self.assert_eq(kdf.drop(labels=['x'], columns=['y']), expected_output)
380380

381+
columns = pd.MultiIndex.from_tuples([('a', 'x'), ('a', 'y'), ('b', 'z')])
382+
kdf.columns = columns
383+
pdf = kdf.to_pandas()
384+
385+
self.assert_eq(kdf.drop(columns='a'), pdf.drop(columns='a'))
386+
self.assert_eq(kdf.drop(columns=('a', 'x')), pdf.drop(columns=('a', 'x')))
387+
self.assert_eq(kdf.drop(columns=[('a', 'x'), 'b']), pdf.drop(columns=[('a', 'x'), 'b']))
388+
389+
self.assertRaises(KeyError, lambda: kdf.drop(columns='c'))
390+
self.assertRaises(KeyError, lambda: kdf.drop(columns=('a', 'z')))
391+
381392
def test_dropna(self):
382393
pdf = pd.DataFrame({'x': [np.nan, 2, 3, 4, np.nan, 6],
383394
'y': [1, 2, np.nan, 4, np.nan, np.nan],

0 commit comments

Comments
 (0)