Skip to content

Commit 37e783c

Browse files
authored
Fix column access with named column index. (#629)
Fix an issue raised in the review comment on #621.
1 parent 7fc7f84 commit 37e783c

File tree

3 files changed

+67
-28
lines changed

3 files changed

+67
-28
lines changed

databricks/koalas/frame.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3685,13 +3685,12 @@ def pivot(self, index=None, columns=None, values=None):
36853685
def columns(self):
36863686
"""The column labels of the DataFrame."""
36873687
if self._internal.column_index is not None:
3688-
if self._internal.column_index_names is not None:
3689-
return pd.MultiIndex.from_tuples(self._internal.column_index,
3690-
names=self._internal.column_index_names)
3691-
else:
3692-
return pd.MultiIndex.from_tuples(self._internal.column_index)
3688+
columns = pd.MultiIndex.from_tuples(self._internal.column_index)
36933689
else:
3694-
return pd.Index(self._internal.data_columns)
3690+
columns = pd.Index(self._internal.data_columns)
3691+
if self._internal.column_index_names is not None:
3692+
columns.names = self._internal.column_index_names
3693+
return columns
36953694

36963695
@columns.setter
36973696
def columns(self, columns):
@@ -3702,17 +3701,24 @@ def columns(self, columns):
37023701
raise ValueError(
37033702
"Length mismatch: Expected axis has %d elements, new values have %d elements"
37043703
% (len(old_names), len(column_index)))
3705-
self._internal = self._internal.copy(column_index=column_index)
3704+
column_index_names = columns.names
3705+
self._internal = self._internal.copy(column_index=column_index,
3706+
column_index_names=column_index_names)
37063707
else:
37073708
old_names = self._internal.data_columns
37083709
if len(old_names) != len(columns):
37093710
raise ValueError(
37103711
"Length mismatch: Expected axis has %d elements, new values have %d elements"
37113712
% (len(old_names), len(columns)))
3713+
if isinstance(columns, pd.Index):
3714+
column_index_names = columns.names
3715+
else:
3716+
column_index_names = None
37123717
sdf = self._sdf.select(self._internal.index_scols +
37133718
[self._internal.scol_for(old_name).alias(new_name)
37143719
for (old_name, new_name) in zip(old_names, columns)])
3715-
self._internal = self._internal.copy(sdf=sdf, data_columns=columns, column_index=None)
3720+
self._internal = self._internal.copy(sdf=sdf, data_columns=columns, column_index=None,
3721+
column_index_names=column_index_names)
37163722

37173723
@property
37183724
def dtypes(self):
@@ -6226,6 +6232,13 @@ def _get_from_multiindex_column(self, key):
62266232
recursive = True
62276233
for i, (col, idx) in enumerate(columns):
62286234
columns[i] = (col, tuple([str(key), *idx[1:]]))
6235+
column_index_names = None
6236+
if self._internal.column_index_names is not None:
6237+
# Manage column index names
6238+
column_index_level = set(len(idx) for _, idx in columns)
6239+
assert len(column_index_level) == 1
6240+
column_index_level = list(column_index_level)[0]
6241+
column_index_names = self._internal.column_index_names[-column_index_level:]
62296242
if all(len(idx) == 1 for _, idx in columns):
62306243
# If len(idx) == 1, then the result is no longer a MultiIndex
62316244
sdf = self._sdf.select(self._internal.index_scols +
@@ -6234,15 +6247,17 @@ def _get_from_multiindex_column(self, key):
62346247
kdf_or_ser = DataFrame(self._internal.copy(
62356248
sdf=sdf,
62366249
data_columns=[idx[0] for _, idx in columns],
6237-
column_index=None))
6250+
column_index=None,
6251+
column_index_names=column_index_names))
62386252
else:
62396253
# Otherwise, the result is still a MultiIndex and column_index needs to be managed.
62406254
sdf = self._sdf.select(self._internal.index_scols +
62416255
[self._internal.scol_for(col) for col, _ in columns])
62426256
kdf_or_ser = DataFrame(self._internal.copy(
62436257
sdf=sdf,
62446258
data_columns=[col for col, _ in columns],
6245-
column_index=[idx for _, idx in columns]))
6259+
column_index=[idx for _, idx in columns],
6260+
column_index_names=column_index_names))
62466261
if recursive:
62476262
kdf_or_ser = kdf_or_ser._pd_getitem(str(key))
62486263
if isinstance(kdf_or_ser, Series):

databricks/koalas/internal.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -491,11 +491,9 @@ def pandas_df(self):
491491
pdf = pdf[self.data_columns]
492492

493493
if self._column_index is not None:
494-
if self.column_index_names is not None:
495-
pdf.columns = pd.MultiIndex.from_tuples(self._column_index,
496-
names=self.column_index_names)
497-
else:
498-
pdf.columns = pd.MultiIndex.from_tuples(self._column_index)
494+
pdf.columns = pd.MultiIndex.from_tuples(self._column_index)
495+
if self._column_index_names is not None:
496+
pdf.columns.names = self._column_index_names
499497

500498
index_names = self.index_names
501499
if len(index_names) > 0:
@@ -547,13 +545,9 @@ def from_pandas(pdf: pd.DataFrame) -> '_InternalFrame':
547545
data_columns = [str(col) for col in columns]
548546
if isinstance(columns, pd.MultiIndex):
549547
column_index = columns.tolist()
550-
if columns.names is not None:
551-
column_index_names = columns.names
552-
else:
553-
column_index_names = None
554548
else:
555549
column_index = None
556-
column_index_names = None
550+
column_index_names = columns.names
557551

558552
index = pdf.index
559553

databricks/koalas/tests/test_dataframe.py

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -100,20 +100,25 @@ def test_dataframe_multiindex_columns(self):
100100
self.assert_eq(kdf[('x', 'a')], pdf[('x', 'a')])
101101
self.assert_eq(kdf[('x', 'a', '1')], pdf[('x', 'a', '1')])
102102

103-
def test_dataframe_multiindex_names_level(self):
104-
columns = pd.MultiIndex.from_tuples([('X', 'A'), ('X', 'B'), ('Y', 'C'), ('Y', 'D')],
105-
names=['lvl_1', 'lvl_2'])
106-
kdf = ks.DataFrame([[1, 2, 3, 4],
107-
[5, 6, 7, 8],
108-
[9, 10, 11, 12],
109-
[13, 14, 15, 16],
110-
[17, 18, 19, 20]], columns=columns)
103+
def test_dataframe_column_level_name(self):
104+
column = pd.Index(['A', 'B', 'C'], name='X')
105+
pdf = pd.DataFrame([[1, 2, 3], [4, 5, 6]], columns=column)
106+
kdf = ks.from_pandas(pdf)
111107

108+
self.assert_eq(kdf, pdf)
109+
self.assert_eq(kdf.columns.names, pdf.columns.names)
110+
self.assert_eq(kdf.to_pandas().columns.names, pdf.columns.names)
111+
112+
def test_dataframe_multiindex_names_level(self):
113+
columns = pd.MultiIndex.from_tuples([('X', 'A', 'Z'), ('X', 'B', 'Z'),
114+
('Y', 'C', 'Z'), ('Y', 'D', 'Z')],
115+
names=['lvl_1', 'lvl_2', 'lv_3'])
112116
pdf = pd.DataFrame([[1, 2, 3, 4],
113117
[5, 6, 7, 8],
114118
[9, 10, 11, 12],
115119
[13, 14, 15, 16],
116120
[17, 18, 19, 20]], columns=columns)
121+
kdf = ks.from_pandas(pdf)
117122

118123
self.assert_eq(kdf.columns.names, pdf.columns.names)
119124
self.assert_eq(kdf.to_pandas().columns.names, pdf.columns.names)
@@ -125,6 +130,17 @@ def test_dataframe_multiindex_names_level(self):
125130
'be list-like or None for a MultiIndex'):
126131
ks.DataFrame(kdf1._internal.copy(column_index_names='level'))
127132

133+
self.assert_eq(kdf['X'], pdf['X'])
134+
self.assert_eq(kdf['X'].columns.names, pdf['X'].columns.names)
135+
self.assert_eq(kdf['X'].to_pandas().columns.names, pdf['X'].columns.names)
136+
self.assert_eq(kdf['X']['A'], pdf['X']['A'])
137+
self.assert_eq(kdf['X']['A'].columns.names, pdf['X']['A'].columns.names)
138+
self.assert_eq(kdf['X']['A'].to_pandas().columns.names, pdf['X']['A'].columns.names)
139+
self.assert_eq(kdf[('X', 'A')], pdf[('X', 'A')])
140+
self.assert_eq(kdf[('X', 'A')].columns.names, pdf[('X', 'A')].columns.names)
141+
self.assert_eq(kdf[('X', 'A')].to_pandas().columns.names, pdf[('X', 'A')].columns.names)
142+
self.assert_eq(kdf[('X', 'A', 'Z')], pdf[('X', 'A', 'Z')])
143+
128144
def test_reset_index_with_multiindex_columns(self):
129145
index = pd.MultiIndex.from_tuples([('bird', 'falcon'),
130146
('bird', 'parrot'),
@@ -290,6 +306,13 @@ def test_rename_columns(self):
290306
self.assert_eq(kdf.columns, pd.Index(['x', 'y']))
291307
self.assert_eq(kdf, pdf)
292308

309+
columns = pdf.columns
310+
columns.name = 'lvl_1'
311+
312+
kdf.columns = columns
313+
self.assert_eq(kdf.columns.names, ['lvl_1'])
314+
self.assert_eq(kdf, pdf)
315+
293316
msg = "Length mismatch: Expected axis has 2 elements, new values have 4 elements"
294317
with self.assertRaisesRegex(ValueError, msg):
295318
kdf.columns = [1, 2, 3, 4]
@@ -300,6 +323,7 @@ def test_rename_columns(self):
300323

301324
columns = pdf.columns
302325
self.assert_eq(kdf.columns, columns)
326+
self.assert_eq(kdf, pdf)
303327

304328
pdf.columns = ['x', 'y']
305329
kdf.columns = ['x', 'y']
@@ -311,6 +335,12 @@ def test_rename_columns(self):
311335
self.assert_eq(kdf.columns, columns)
312336
self.assert_eq(kdf, pdf)
313337

338+
columns.names = ['lvl_1', 'lvl_2']
339+
340+
kdf.columns = columns
341+
self.assert_eq(kdf.columns.names, ['lvl_1', 'lvl_2'])
342+
self.assert_eq(kdf, pdf)
343+
314344
def test_dot_in_column_name(self):
315345
self.assert_eq(
316346
ks.DataFrame(ks.range(1)._sdf.selectExpr("1 as `a.b`"))['a.b'],

0 commit comments

Comments
 (0)