Skip to content

Commit eb30dfa

Browse files
authored
Fix `where` to reduce the number of joins. (#1274)
In `DataFrame.where`, the query plan contained too many joins, and the number of joins grew with the number of columns. E.g.,:

```py
>>> df1 = ks.DataFrame({'A': [0, 1, 2, 3, 4], 'B':[100, 200, 300, 400, 500]})
>>> df1.where(df1 > 0).explain()
== Physical Plan ==
*(8) Project [__index_level_0__#704L, CASE WHEN __tmp_cond_col_A__#808 THEN cast(A#1L as double) ELSE NaN END AS A#1002, CASE WHEN B#674 THEN cast(B#2L as double) ELSE NaN END AS B#1003]
+- SortMergeJoin [__index_level_0__#704L], [__index_level_0__#0L], LeftOuter
   :- *(5) Sort [__index_level_0__#704L ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(__index_level_0__#704L, 200)
   :     +- *(4) Project [__index_level_0__#0L AS __index_level_0__#704L, A#1L, B#2L, A#673 AS __tmp_cond_col_A__#808]
   :        +- SortMergeJoin [__index_level_0__#0L], [__index_level_0__#705L], LeftOuter
   :           :- *(1) Sort [__index_level_0__#0L ASC NULLS FIRST], false, 0
   :           :  +- Exchange hashpartitioning(__index_level_0__#0L, 200)
   :           :     +- Scan ExistingRDD[__index_level_0__#0L,A#1L,B#2L]
   :           +- *(3) Sort [__index_level_0__#705L ASC NULLS FIRST], false, 0
   :              +- Exchange hashpartitioning(__index_level_0__#705L, 200)
   :                 +- *(2) Project [__index_level_0__#705L, (A#706L > 0) AS A#673]
   :                    +- Scan ExistingRDD[__index_level_0__#705L,A#706L,B#707L]
   +- *(7) Sort [__index_level_0__#0L ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(__index_level_0__#0L, 200)
         +- *(6) Project [__index_level_0__#0L, (B#2L > 0) AS B#674]
            +- Scan ExistingRDD[__index_level_0__#0L,A#1L,B#2L]

>>> df1.where(df1 > 0, df1 + 100).explain()
== Physical Plan ==
*(14) Project [__index_level_0__#1054L, CASE WHEN __tmp_cond_col_A__#1158 THEN A#1L ELSE __tmp_other_col_A__#1470L END AS A#1677L, CASE WHEN __tmp_cond_col_B__#1302 THEN B#2L ELSE B#1026L END AS B#1678L]
+- SortMergeJoin [__index_level_0__#1054L], [__index_level_0__#0L], LeftOuter
   :- *(11) Project [__index_level_0__#1054L, A#1L, B#2L, __tmp_cond_col_A__#1158, __tmp_cond_col_B__#1302, A#1024L AS __tmp_other_col_A__#1470L]
   :  +- SortMergeJoin [__index_level_0__#1054L], [__index_level_0__#0L], LeftOuter
   :     :- *(8) Project [__index_level_0__#1054L, A#1L, B#2L, __tmp_cond_col_A__#1158, B#1016 AS __tmp_cond_col_B__#1302]
   :     :  +- SortMergeJoin [__index_level_0__#1054L], [__index_level_0__#0L], LeftOuter
   :     :     :- *(5) Sort [__index_level_0__#1054L ASC NULLS FIRST], false, 0
   :     :     :  +- Exchange hashpartitioning(__index_level_0__#1054L, 200)
   :     :     :     +- *(4) Project [__index_level_0__#0L AS __index_level_0__#1054L, A#1L, B#2L, A#1015 AS __tmp_cond_col_A__#1158]
   :     :     :        +- SortMergeJoin [__index_level_0__#0L], [__index_level_0__#1055L], LeftOuter
   :     :     :           :- *(1) Sort [__index_level_0__#0L ASC NULLS FIRST], false, 0
   :     :     :           :  +- Exchange hashpartitioning(__index_level_0__#0L, 200)
   :     :     :           :     +- Scan ExistingRDD[__index_level_0__#0L,A#1L,B#2L]
   :     :     :           +- *(3) Sort [__index_level_0__#1055L ASC NULLS FIRST], false, 0
   :     :     :              +- Exchange hashpartitioning(__index_level_0__#1055L, 200)
   :     :     :                 +- *(2) Project [__index_level_0__#1055L, (A#1056L > 0) AS A#1015]
   :     :     :                    +- Scan ExistingRDD[__index_level_0__#1055L,A#1056L,B#1057L]
   :     :     +- *(7) Sort [__index_level_0__#0L ASC NULLS FIRST], false, 0
   :     :        +- Exchange hashpartitioning(__index_level_0__#0L, 200)
   :     :           +- *(6) Project [__index_level_0__#0L, (B#2L > 0) AS B#1016]
   :     :              +- Scan ExistingRDD[__index_level_0__#0L,A#1L,B#2L]
   :     +- *(10) Sort [__index_level_0__#0L ASC NULLS FIRST], false, 0
   :        +- Exchange hashpartitioning(__index_level_0__#0L, 200)
   :           +- *(9) Project [__index_level_0__#0L, (A#1L + 100) AS A#1024L]
   :              +- Scan ExistingRDD[__index_level_0__#0L,A#1L,B#2L]
   +- *(13) Sort [__index_level_0__#0L ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(__index_level_0__#0L, 200)
         +- *(12) Project [__index_level_0__#0L, (B#2L + 100) AS B#1026L]
            +- Scan ExistingRDD[__index_level_0__#0L,A#1L,B#2L]
```

We can reduce the number of joins to at most two:

```py
>>> df1.where(df1 > 0).explain()
== Physical Plan ==
*(4) Project [__index_level_0__#0L, CASE WHEN __tmp_cond_col_A__#255 THEN cast(A#1L as double) ELSE NaN END AS A#468, CASE WHEN __tmp_cond_col_B__#256 THEN cast(B#2L as double) ELSE NaN END AS B#469]
+- SortMergeJoin [__index_level_0__#0L], [__index_level_0__#292L], LeftOuter
   :- *(1) Sort [__index_level_0__#0L ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(__index_level_0__#0L, 200)
   :     +- Scan ExistingRDD[__index_level_0__#0L,A#1L,B#2L]
   +- *(3) Sort [__index_level_0__#292L ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(__index_level_0__#292L, 200)
         +- *(2) Project [__index_level_0__#292L, (A#293L > 0) AS __tmp_cond_col_A__#255, (B#294L > 0) AS __tmp_cond_col_B__#256]
            +- Scan ExistingRDD[__index_level_0__#292L,A#293L,B#294L]

>>> df1.where(df1 > 0, df1 + 100).explain()
== Physical Plan ==
*(8) Project [__index_level_0__#535L, CASE WHEN __tmp_cond_col_A__#499 THEN A#1L ELSE __tmp_other_col_A__#669L END AS A#882L, CASE WHEN __tmp_cond_col_B__#500 THEN B#2L ELSE __tmp_other_col_B__#670L END AS B#883L]
+- SortMergeJoin [__index_level_0__#535L], [__index_level_0__#0L], LeftOuter
   :- *(5) Sort [__index_level_0__#535L ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(__index_level_0__#535L, 200)
   :     +- *(4) Project [__index_level_0__#0L AS __index_level_0__#535L, A#1L, B#2L, __tmp_cond_col_A__#499, __tmp_cond_col_B__#500]
   :        +- SortMergeJoin [__index_level_0__#0L], [__index_level_0__#536L], LeftOuter
   :           :- *(1) Sort [__index_level_0__#0L ASC NULLS FIRST], false, 0
   :           :  +- Exchange hashpartitioning(__index_level_0__#0L, 200)
   :           :     +- Scan ExistingRDD[__index_level_0__#0L,A#1L,B#2L]
   :           +- *(3) Sort [__index_level_0__#536L ASC NULLS FIRST], false, 0
   :              +- Exchange hashpartitioning(__index_level_0__#536L, 200)
   :                 +- *(2) Project [__index_level_0__#536L, (A#537L > 0) AS __tmp_cond_col_A__#499, (B#538L > 0) AS __tmp_cond_col_B__#500]
   :                    +- Scan ExistingRDD[__index_level_0__#536L,A#537L,B#538L]
   +- *(7) Sort [__index_level_0__#0L ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(__index_level_0__#0L, 200)
         +- *(6) Project [__index_level_0__#0L, (A#1L + 100) AS __tmp_other_col_A__#669L, (B#2L + 100) AS __tmp_other_col_B__#670L]
            +- Scan ExistingRDD[__index_level_0__#0L,A#1L,B#2L]
```
1 parent 2abed90 commit eb30dfa

File tree

1 file changed

+24
-6
lines changed

1 file changed

+24
-6
lines changed

databricks/koalas/frame.py

Lines changed: 24 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -2429,18 +2429,36 @@ def where(self, cond, other=np.nan):
24292429
tmp_other_col_name = '__tmp_other_col_{}__'.format
24302430

24312431
kdf = self.copy()
2432+
2433+
tmp_cond_col_names = [tmp_cond_col_name(name_like_string(label))
2434+
for label in self._internal.column_labels]
24322435
if isinstance(cond, DataFrame):
2433-
for label in self._internal.column_labels:
2434-
kdf[tmp_cond_col_name(name_like_string(label))] = cond.get(label, False)
2436+
cond = cond[[(cond._internal.scol_for(label)
2437+
if label in cond._internal.column_labels else F.lit(False)).alias(name)
2438+
for label, name
2439+
in zip(self._internal.column_labels, tmp_cond_col_names)]]
2440+
kdf[tmp_cond_col_names] = cond
24352441
elif isinstance(cond, Series):
2436-
for label in self._internal.column_labels:
2437-
kdf[tmp_cond_col_name(name_like_string(label))] = cond
2442+
cond = cond.to_frame()
2443+
cond = cond[[cond._internal.column_scols[0].alias(name) for name in tmp_cond_col_names]]
2444+
kdf[tmp_cond_col_names] = cond
24382445
else:
24392446
raise ValueError("type of cond must be a DataFrame or Series")
24402447

2448+
tmp_other_col_names = [tmp_other_col_name(name_like_string(label))
2449+
for label in self._internal.column_labels]
24412450
if isinstance(other, DataFrame):
2442-
for label in self._internal.column_labels:
2443-
kdf[tmp_other_col_name(name_like_string(label))] = other.get(label, np.nan)
2451+
other = other[[(other._internal.scol_for(label)
2452+
if label in other._internal.column_labels else F.lit(np.nan))
2453+
.alias(name)
2454+
for label, name
2455+
in zip(self._internal.column_labels, tmp_other_col_names)]]
2456+
kdf[tmp_other_col_names] = other
2457+
elif isinstance(other, Series):
2458+
other = other.to_frame()
2459+
other = other[[other._internal.column_scols[0].alias(name)
2460+
for name in tmp_other_col_names]]
2461+
kdf[tmp_other_col_names] = other
24442462
else:
24452463
for label in self._internal.column_labels:
24462464
kdf[tmp_other_col_name(name_like_string(label))] = other

0 commit comments

Comments (0)