Skip to content

Commit d3eceae

Browse files
authored
Fix assign to support multi-index columns. (#811)
1 parent 4785c48 commit d3eceae

File tree

2 files changed

+53
-22
lines changed

2 files changed

+53
-22
lines changed

databricks/koalas/frame.py

Lines changed: 36 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -3118,6 +3118,10 @@ def assign(self, **kwargs):
31183118
feature is supported in pandas for Python 3.6 and later but not in
31193119
Koalas. In Koalas, all items are computed first, and then assigned.
31203120
"""
3121+
return self._assign(kwargs)
3122+
3123+
def _assign(self, kwargs):
3124+
assert isinstance(kwargs, dict)
31213125
from databricks.koalas.series import Series
31223126
for k, v in kwargs.items():
31233127
if not (isinstance(v, (Series, spark.Column)) or
@@ -3127,23 +3131,39 @@ def assign(self, **kwargs):
31273131
if callable(v):
31283132
kwargs[k] = v(self)
31293133

3130-
pairs = list(kwargs.items())
3131-
sdf = self._sdf
3132-
for (name, c) in pairs:
3133-
if isinstance(c, Series):
3134-
sdf = sdf.withColumn(name, c._scol)
3135-
elif isinstance(c, Column):
3136-
sdf = sdf.withColumn(name, c)
3137-
else:
3138-
sdf = sdf.withColumn(name, F.lit(c))
3134+
pairs = {(k if isinstance(k, tuple) else (k,)):
3135+
(v._scol if isinstance(v, Series)
3136+
else v if isinstance(v, spark.Column)
3137+
else F.lit(v))
3138+
for k, v in kwargs.items()}
31393139

3140-
data_columns = set(self._internal.data_columns)
3141-
adding_columns = [name for name, _ in pairs if name not in data_columns]
3140+
scols = []
3141+
for idx in self._internal.column_index:
3142+
for i in range(len(idx)):
3143+
if idx[:len(idx)-i] in pairs:
3144+
name = self._internal.column_name_for(idx)
3145+
scol = pairs[idx[:len(idx)-i]].alias(name)
3146+
break
3147+
else:
3148+
scol = self._internal.scol_for(idx)
3149+
scols.append(scol)
3150+
3151+
adding_data_columns = []
3152+
adding_column_index = []
3153+
for idx, scol in pairs.items():
3154+
if idx not in set(i[:len(idx)] for i in self._internal.column_index):
3155+
name = str(idx) if len(idx) > 1 else idx[0]
3156+
scols.append(scol.alias(name))
3157+
adding_data_columns.append(name)
3158+
adding_column_index.append(idx)
3159+
3160+
sdf = self._sdf.select(self._internal.index_scols + scols)
31423161
level = self._internal.column_index_level
3143-
adding_column_index = [tuple([col, *([''] * (level - 1))]) for col in adding_columns]
3162+
adding_column_index = [tuple(list(idx) + ([''] * (level - len(idx))))
3163+
for idx in adding_column_index]
31443164
internal = self._internal.copy(
31453165
sdf=sdf,
3146-
data_columns=(self._internal.data_columns + adding_columns),
3166+
data_columns=(self._internal.data_columns + adding_data_columns),
31473167
column_index=(self._internal.column_index + adding_column_index))
31483168
return DataFrame(internal)
31493169

@@ -6991,14 +7011,14 @@ def assign_columns(kdf, this_column_index, that_column_index):
69917011
yield (kdf[this_idx], this_idx)
69927012

69937013
kdf = align_diff_frames(assign_columns, self, value, fillna=False, how="left")
6994-
elif isinstance(key, (tuple, list)):
7014+
elif isinstance(key, list):
69957015
assert isinstance(value, DataFrame)
69967016
# Same DataFrames.
69977017
field_names = value.columns
6998-
kdf = self.assign(**{k: value[c] for k, c in zip(key, field_names)})
7018+
kdf = self._assign({k: value[c] for k, c in zip(key, field_names)})
69997019
else:
70007020
# Same Series.
7001-
kdf = self.assign(**{key: value})
7021+
kdf = self._assign({key: value})
70027022

70037023
self._internal = kdf._internal
70047024

databricks/koalas/tests/test_dataframe.py

Lines changed: 17 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -265,25 +265,36 @@ def test_nullable_object(self):
265265
self.assert_eq(kdf, pdf)
266266

267267
def test_assign(self):
268-
kdf = self.kdf.copy()
269-
pdf = self.pdf.copy()
268+
kdf = self.kdf
269+
pdf = self.pdf
270270

271271
kdf['w'] = 1.0
272272
pdf['w'] = 1.0
273273

274274
self.assert_eq(kdf, pdf)
275275

276-
kdf['a'] = 'abc'
277-
pdf['a'] = 'abc'
276+
kdf = kdf.assign(a=kdf['a'] * 2)
277+
pdf = pdf.assign(a=pdf['a'] * 2)
278278

279279
self.assert_eq(kdf, pdf)
280280

281+
# multi-index columns
281282
columns = pd.MultiIndex.from_tuples([('x', 'a'), ('x', 'b'), ('y', 'w')])
282283
pdf.columns = columns
283284
kdf.columns = columns
284285

285-
pdf['Z'] = 'ZZ'
286-
kdf['Z'] = 'ZZ'
286+
kdf[('a', 'c')] = 'def'
287+
pdf[('a', 'c')] = 'def'
288+
289+
self.assert_eq(kdf, pdf)
290+
291+
kdf = kdf.assign(Z='ZZ')
292+
pdf = pdf.assign(Z='ZZ')
293+
294+
self.assert_eq(kdf, pdf)
295+
296+
kdf['x'] = 'ghi'
297+
pdf['x'] = 'ghi'
287298

288299
self.assert_eq(kdf, pdf)
289300

0 commit comments

Comments
 (0)