Skip to content

Commit 1bf41b8

Browse files
authored
Fix round to support multi-index columns. (#802)
1 parent 8ad9b4f commit 1bf41b8

File tree

2 files changed

+26
-5
lines changed

2 files changed

+26
-5
lines changed

databricks/koalas/frame.py

Lines changed: 9 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -2531,17 +2531,21 @@ def round(self, decimals=0):
25312531
third 0.9 0.0 0.49
25322532
"""
25332533
if isinstance(decimals, ks.Series):
2534-
decimals_list = [kv for kv in decimals._to_internal_pandas().items()]
2534+
decimals_list = [(k if isinstance(k, tuple) else (k,), v)
2535+
for k, v in decimals._to_internal_pandas().items()]
25352536
elif isinstance(decimals, dict):
2536-
decimals_list = [(k, v) for k, v in decimals.items()]
2537+
decimals_list = [(k if isinstance(k, tuple) else (k,), v)
2538+
for k, v in decimals.items()]
25372539
elif isinstance(decimals, int):
2538-
decimals_list = [(v, decimals) for v in self._internal.data_columns]
2540+
decimals_list = [(k, decimals) for k in self._internal.column_index]
25392541
else:
25402542
raise ValueError("decimals must be an integer, a dict-like or a Series")
25412543

25422544
sdf = self._sdf
2543-
for decimal in decimals_list:
2544-
sdf = sdf.withColumn(decimal[0], F.round(scol_for(sdf, decimal[0]), decimal[1]))
2545+
for idx, decimal in decimals_list:
2546+
if idx in self._internal.column_index:
2547+
col = self._internal.column_name_for(idx)
2548+
sdf = sdf.withColumn(col, F.round(scol_for(sdf, col), decimal))
25452549
return DataFrame(self._internal.copy(sdf=sdf))
25462550

25472551
def duplicated(self, subset=None, keep='first'):

databricks/koalas/tests/test_dataframe.py

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1480,12 +1480,29 @@ def test_round(self):
14801480
kdf.round(2))
14811481
self.assert_eq(pdf.round({'A': 1, 'C': 2}),
14821482
kdf.round({'A': 1, 'C': 2}))
1483+
self.assert_eq(pdf.round({'A': 1, 'D': 2}),
1484+
kdf.round({'A': 1, 'D': 2}))
14831485
self.assert_eq(pdf.round(pser),
14841486
kdf.round(kser))
14851487
msg = "decimals must be an integer, a dict-like or a Series"
14861488
with self.assertRaisesRegex(ValueError, msg):
14871489
kdf.round(1.5)
14881490

1491+
# multi-index columns
1492+
columns = pd.MultiIndex.from_tuples([('X', 'A'), ('X', 'B'), ('Y', 'C')])
1493+
pdf.columns = columns
1494+
kdf.columns = columns
1495+
pser = pd.Series([1, 0, 2], index=columns)
1496+
kser = ks.Series([1, 0, 2], index=columns)
1497+
self.assert_eq(pdf.round(2),
1498+
kdf.round(2))
1499+
self.assert_eq(pdf.round({('X', 'A'): 1, ('Y', 'C'): 2}),
1500+
kdf.round({('X', 'A'): 1, ('Y', 'C'): 2}))
1501+
self.assert_eq(pdf.round({('X', 'A'): 1, 'Y': 2}),
1502+
kdf.round({('X', 'A'): 1, 'Y': 2}))
1503+
self.assert_eq(pdf.round(pser),
1504+
kdf.round(kser))
1505+
14891506
def test_shift(self):
14901507
pdf = pd.DataFrame({'Col1': [10, 20, 15, 30, 45],
14911508
'Col2': [13, 23, 18, 33, 48],

0 commit comments

Comments
 (0)