Commit bfed2a3

Fix where to support multi-index columns. (#1249)
1 parent f058b7b commit bfed2a3

3 files changed (+49 -36 lines)
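
For context, the fix targets where/mask on koalas objects whose columns form a pandas MultiIndex. A minimal usage sketch, assuming a local pyspark/koalas installation (it mirrors the data used in the new tests below):

    import pandas as pd
    import databricks.koalas as ks

    pdf = pd.DataFrame({('X', 'A'): [0, 1, 2, 3, 4],
                        ('X', 'B'): [100, 200, 300, 400, 500]})
    kdf = ks.from_pandas(pdf)

    # where() keeps values that satisfy the condition and fills NaN elsewhere;
    # multi-index columns are the case this commit makes work.
    print(kdf.where(kdf > 100).sort_index())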

databricks/koalas/frame.py

Lines changed: 21 additions & 24 deletions
@@ -2200,26 +2200,27 @@ def where(self, cond, other=np.nan):
         >>> reset_option("compute.ops_on_diff_frames")
         """
         from databricks.koalas.series import Series
-        tmp_cond_col_name = '__tmp_cond_col_{}__'
-        tmp_other_col_name = '__tmp_other_col_{}__'
+
+        tmp_cond_col_name = '__tmp_cond_col_{}__'.format
+        tmp_other_col_name = '__tmp_other_col_{}__'.format
+
         kdf = self.copy()
         if isinstance(cond, DataFrame):
-            for column in self._internal.data_columns:
-                kdf[tmp_cond_col_name.format(column)] = cond.get(column, False)
+            for idx in self._internal.column_index:
+                kdf[tmp_cond_col_name(name_like_string(idx))] = cond.get(idx, False)
         elif isinstance(cond, Series):
-            for column in self._internal.data_columns:
-                kdf[tmp_cond_col_name.format(column)] = cond
+            for idx in self._internal.column_index:
+                kdf[tmp_cond_col_name(name_like_string(idx))] = cond
         else:
             raise ValueError("type of cond must be a DataFrame or Series")

         if isinstance(other, DataFrame):
-            for column in self._internal.data_columns:
-                kdf[tmp_other_col_name.format(column)] = other.get(column, np.nan)
+            for idx in self._internal.column_index:
+                kdf[tmp_other_col_name(name_like_string(idx))] = other.get(idx, np.nan)
         else:
-            for column in self._internal.data_columns:
-                kdf[tmp_other_col_name.format(column)] = other
+            for idx in self._internal.column_index:
+                kdf[tmp_other_col_name(name_like_string(idx))] = other

-        sdf = kdf._sdf
         # above logic make spark dataframe looks like below:
         # +-----------------+---+---+------------------+-------------------+------------------+--...
         # |__index_level_0__| A| B|__tmp_cond_col_A__|__tmp_other_col_A__|__tmp_cond_col_B__|__...
@@ -2231,22 +2232,18 @@ def where(self, cond, other=np.nan):
         # | 4| 4|500| false| -4| false| ...
         # +-----------------+---+---+------------------+-------------------+------------------+--...

-        output = []
-        for column in self._internal.data_columns:
-            data_col_name = self._internal.column_name_for(column)
-            output.append(
+        column_scols = []
+        for idx in self._internal.column_index:
+            column_scols.append(
                 F.when(
-                    scol_for(sdf, tmp_cond_col_name.format(column)), scol_for(sdf, data_col_name)
+                    kdf[tmp_cond_col_name(name_like_string(idx))]._scol,
+                    kdf[idx]._scol
                 ).otherwise(
-                    scol_for(sdf, tmp_other_col_name.format(column))
-                ).alias(data_col_name))
+                    kdf[tmp_other_col_name(name_like_string(idx))]._scol
+                ).alias(kdf._internal.column_name_for(idx)))

-        index_scols = kdf._internal.index_scols
-        sdf = sdf.select(index_scols + output + list(HIDDEN_COLUMNS))
-
-        return DataFrame(self._internal.copy(
-            sdf=sdf,
-            column_scols=[scol_for(sdf, column) for column in self._internal.data_columns]))
+        return DataFrame(kdf._internal.with_new_columns(column_scols,
+                                                        column_index=self._internal.column_index))

     def mask(self, cond, other=np.nan):
         """

databricks/koalas/series.py

Lines changed: 6 additions & 12 deletions
@@ -3886,8 +3886,6 @@ def where(self, cond, other=np.nan):

         >>> reset_option("compute.ops_on_diff_frames")
         """
-        data_col_name = self._internal.column_name_for(self._internal.column_index[0])
-
         assert isinstance(cond, Series)

         # We should check the DataFrame from both `cond` and `other`.
@@ -3901,7 +3899,6 @@ def where(self, cond, other=np.nan):
             kdf['__tmp_cond_col__'] = cond
             kdf['__tmp_other_col__'] = other

-            sdf = kdf._sdf
             # above logic makes a Spark DataFrame looks like below:
             # +-----------------+---+----------------+-----------------+
             # |__index_level_0__| 0|__tmp_cond_col__|__tmp_other_col__|
@@ -3913,21 +3910,18 @@ def where(self, cond, other=np.nan):
             # | 4| 4| true| 500|
             # +-----------------+---+----------------+-----------------+
             condition = F.when(
-                sdf['__tmp_cond_col__'], sdf[data_col_name]
-            ).otherwise(sdf['__tmp_other_col__']).alias(data_col_name)
+                kdf['__tmp_cond_col__']._scol, kdf[self._internal.column_index[0]]._scol
+            ).otherwise(kdf['__tmp_other_col__']._scol).alias(self._internal.data_columns[0])

-            sdf = sdf.select(*self._internal.index_columns + [condition])
-            return _col(ks.DataFrame(_InternalFrame(
-                sdf=sdf,
-                index_map=self._internal.index_map,
-                column_index=self._internal.column_index,
-                column_index_names=self._internal.column_index_names)))
+            internal = kdf._internal.with_new_columns([condition],
+                                                      column_index=self._internal.column_index)
+            return _col(DataFrame(internal))
         else:
             if isinstance(other, Series):
                 other = other._scol
             condition = F.when(
                 cond._scol, self._scol
-            ).otherwise(other).alias(data_col_name)
+            ).otherwise(other).alias(self._internal.data_columns[0])
             return self._with_new_scol(condition)

     def mask(self, cond, other=np.nan):
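
In series.py, the when/otherwise expression is now built from the koalas columns' underlying Spark expressions (._scol) and wrapped via with_new_columns, instead of re-selecting from the raw Spark DataFrame. A standalone PySpark sketch of that expression pattern, with illustrative column names rather than the internal temp names (assumes a local pyspark installation):

    from pyspark.sql import SparkSession, functions as F

    spark = SparkSession.builder.master('local[1]').getOrCreate()
    sdf = spark.createDataFrame([(1, True, -1), (2, False, -2), (3, True, -3)],
                                ['value', 'cond', 'other'])

    # Keep `value` where `cond` holds, otherwise fall back to `other`.
    sdf.select(F.when(sdf['cond'], sdf['value'])
                .otherwise(sdf['other'])
                .alias('value')).show()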

databricks/koalas/tests/test_ops_on_diff_frames.py

Lines changed: 22 additions & 0 deletions
@@ -571,6 +571,17 @@ def test_where(self):
         self.assert_eq(repr(pdf1.where(pdf2 < -250)),
                        repr(kdf1.where(kdf2 < -250).sort_index()))

+        # multi-index columns
+        pdf1 = pd.DataFrame({('X', 'A'): [0, 1, 2, 3, 4],
+                             ('X', 'B'): [100, 200, 300, 400, 500]})
+        pdf2 = pd.DataFrame({('X', 'A'): [0, -1, -2, -3, -4],
+                             ('X', 'B'): [-100, -200, -300, -400, -500]})
+        kdf1 = ks.from_pandas(pdf1)
+        kdf2 = ks.from_pandas(pdf2)
+
+        self.assert_eq(repr(pdf1.where(pdf2 > 100)),
+                       repr(kdf1.where(kdf2 > 100).sort_index()))
+
     def test_mask(self):
         pdf1 = pd.DataFrame({'A': [0, 1, 2, 3, 4], 'B': [100, 200, 300, 400, 500]})
         pdf2 = pd.DataFrame({'A': [0, -1, -2, -3, -4], 'B': [-100, -200, -300, -400, -500]})
@@ -588,6 +599,17 @@ def test_mask(self):
         self.assert_eq(repr(pdf1.mask(pdf2 > -250)),
                        repr(kdf1.mask(kdf2 > -250).sort_index()))

+        # multi-index columns
+        pdf1 = pd.DataFrame({('X', 'A'): [0, 1, 2, 3, 4],
+                             ('X', 'B'): [100, 200, 300, 400, 500]})
+        pdf2 = pd.DataFrame({('X', 'A'): [0, -1, -2, -3, -4],
+                             ('X', 'B'): [-100, -200, -300, -400, -500]})
+        kdf1 = ks.from_pandas(pdf1)
+        kdf2 = ks.from_pandas(pdf2)
+
+        self.assert_eq(repr(pdf1.mask(pdf2 < 100)),
+                       repr(kdf1.mask(kdf2 < 100).sort_index()))
+
     def test_multi_index_column_assignment_frame(self):
         pdf = pd.DataFrame({'a': [1, 2, 3, 2], 'b': [4.0, 2.0, 3.0, 1.0]})
         pdf.columns = pd.MultiIndex.from_tuples([('a', 'x'), ('a', 'y')])
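
The new tests compare koalas output against pandas. The pandas side can be checked without Spark; a quick sketch of the frames the assertions compare against:

    import pandas as pd

    pdf1 = pd.DataFrame({('X', 'A'): [0, 1, 2, 3, 4],
                         ('X', 'B'): [100, 200, 300, 400, 500]})
    pdf2 = pd.DataFrame({('X', 'A'): [0, -1, -2, -3, -4],
                         ('X', 'B'): [-100, -200, -300, -400, -500]})

    # No entry of pdf2 exceeds 100, so where() replaces every value with NaN;
    # every entry of pdf2 is below 100, so mask() blanks everything as well.
    print(pdf1.where(pdf2 > 100))
    print(pdf1.mask(pdf2 < 100))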
