Skip to content

Commit da9ba26

Browse files
authored
feat: support left_index and right_index for merge (#2220)
* feat: support left_index and right_index for merge * checkpoint: managed to let code run without error. need to handle column coalescing next * checkpoint: single-index dev complete. still facing errors when dealing with multi-index * wrap up support for single index * fix format * fix tests * fix test * remove unnecessary deps
1 parent bfcc08f commit da9ba26

File tree

7 files changed

+347
-72
lines changed

7 files changed

+347
-72
lines changed

bigframes/core/blocks.py

Lines changed: 98 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2332,41 +2332,61 @@ def merge(
23322332
right_join_ids: typing.Sequence[str],
23332333
sort: bool,
23342334
suffixes: tuple[str, str] = ("_x", "_y"),
2335+
left_index: bool = False,
2336+
right_index: bool = False,
23352337
) -> Block:
23362338
conditions = tuple(
23372339
(lid, rid) for lid, rid in zip(left_join_ids, right_join_ids)
23382340
)
23392341
joined_expr, (get_column_left, get_column_right) = self.expr.relational_join(
23402342
other.expr, type=how, conditions=conditions
23412343
)
2342-
result_columns = []
2343-
matching_join_labels = []
23442344

23452345
left_post_join_ids = tuple(get_column_left[id] for id in left_join_ids)
23462346
right_post_join_ids = tuple(get_column_right[id] for id in right_join_ids)
23472347

2348-
joined_expr, coalesced_ids = coalesce_columns(
2349-
joined_expr, left_post_join_ids, right_post_join_ids, how=how, drop=False
2350-
)
2348+
if left_index or right_index:
2349+
# For some reason pandas coalesces two joining columns if one side is an index.
2350+
joined_expr, resolved_join_ids = coalesce_columns(
2351+
joined_expr, left_post_join_ids, right_post_join_ids
2352+
)
2353+
else:
2354+
joined_expr, resolved_join_ids = resolve_col_join_ids( # type: ignore
2355+
joined_expr,
2356+
left_post_join_ids,
2357+
right_post_join_ids,
2358+
how=how,
2359+
drop=False,
2360+
)
2361+
2362+
result_columns = []
2363+
matching_join_labels = []
23512364

2365+
# Select left value columns
23522366
for col_id in self.value_columns:
23532367
if col_id in left_join_ids:
23542368
key_part = left_join_ids.index(col_id)
23552369
matching_right_id = right_join_ids[key_part]
23562370
if (
2357-
self.col_id_to_label[col_id]
2371+
right_index
2372+
or self.col_id_to_label[col_id]
23582373
== other.col_id_to_label[matching_right_id]
23592374
):
23602375
matching_join_labels.append(self.col_id_to_label[col_id])
2361-
result_columns.append(coalesced_ids[key_part])
2376+
result_columns.append(resolved_join_ids[key_part])
23622377
else:
23632378
result_columns.append(get_column_left[col_id])
23642379
else:
23652380
result_columns.append(get_column_left[col_id])
2381+
2382+
# Select right value columns
23662383
for col_id in other.value_columns:
23672384
if col_id in right_join_ids:
23682385
if other.col_id_to_label[col_id] in matching_join_labels:
23692386
pass
2387+
elif left_index:
2388+
key_part = right_join_ids.index(col_id)
2389+
result_columns.append(resolved_join_ids[key_part])
23702390
else:
23712391
result_columns.append(get_column_right[col_id])
23722392
else:
@@ -2377,11 +2397,22 @@ def merge(
23772397
joined_expr = joined_expr.order_by(
23782398
[
23792399
ordering.OrderingExpression(ex.deref(col_id))
2380-
for col_id in coalesced_ids
2400+
for col_id in resolved_join_ids
23812401
],
23822402
)
23832403

2384-
joined_expr = joined_expr.select_columns(result_columns)
2404+
left_idx_id_post_join = [get_column_left[id] for id in self.index_columns]
2405+
right_idx_id_post_join = [get_column_right[id] for id in other.index_columns]
2406+
index_cols = _resolve_index_col(
2407+
left_idx_id_post_join,
2408+
right_idx_id_post_join,
2409+
resolved_join_ids,
2410+
left_index,
2411+
right_index,
2412+
how,
2413+
)
2414+
2415+
joined_expr = joined_expr.select_columns(result_columns + index_cols)
23852416
labels = utils.merge_column_labels(
23862417
self.column_labels,
23872418
other.column_labels,
@@ -2400,13 +2431,13 @@ def merge(
24002431
or other.index.is_null
24012432
or self.session._default_index_type == bigframes.enums.DefaultIndexKind.NULL
24022433
):
2403-
expr = joined_expr
2404-
index_columns = []
2434+
return Block(joined_expr, index_columns=[], column_labels=labels)
2435+
elif index_cols:
2436+
return Block(joined_expr, index_columns=index_cols, column_labels=labels)
24052437
else:
24062438
expr, offset_index_id = joined_expr.promote_offsets()
24072439
index_columns = [offset_index_id]
2408-
2409-
return Block(expr, index_columns=index_columns, column_labels=labels)
2440+
return Block(expr, index_columns=index_columns, column_labels=labels)
24102441

24112442
def _align_both_axes(
24122443
self, other: Block, how: str
@@ -3115,7 +3146,7 @@ def join_mono_indexed(
31153146
left_index = get_column_left[left.index_columns[0]]
31163147
right_index = get_column_right[right.index_columns[0]]
31173148
# Drop the original indices from each side and use the coalesced combination generated by the join.
3118-
combined_expr, coalesced_join_cols = coalesce_columns(
3149+
combined_expr, coalesced_join_cols = resolve_col_join_ids(
31193150
combined_expr, [left_index], [right_index], how=how
31203151
)
31213152
if sort:
@@ -3180,7 +3211,7 @@ def join_multi_indexed(
31803211
left_ids_post_join = [get_column_left[id] for id in left_join_ids]
31813212
right_ids_post_join = [get_column_right[id] for id in right_join_ids]
31823213
# Drop the original indices from each side and use the coalesced combination generated by the join.
3183-
combined_expr, coalesced_join_cols = coalesce_columns(
3214+
combined_expr, coalesced_join_cols = resolve_col_join_ids(
31843215
combined_expr, left_ids_post_join, right_ids_post_join, how=how
31853216
)
31863217
if sort:
@@ -3223,13 +3254,17 @@ def resolve_label_id(label: Label) -> str:
32233254

32243255

32253256
# TODO: Rewrite just to return expressions
3226-
def coalesce_columns(
3257+
def resolve_col_join_ids(
32273258
expr: core.ArrayValue,
32283259
left_ids: typing.Sequence[str],
32293260
right_ids: typing.Sequence[str],
32303261
how: str,
32313262
drop: bool = True,
32323263
) -> Tuple[core.ArrayValue, Sequence[str]]:
3264+
"""
3265+
Collapses and selects the joining column IDs, with the assumption that
3266+
the ids all belong to value columns.
3267+
"""
32333268
result_ids = []
32343269
for left_id, right_id in zip(left_ids, right_ids):
32353270
if how == "left" or how == "inner" or how == "cross":
@@ -3241,7 +3276,6 @@ def coalesce_columns(
32413276
if drop:
32423277
expr = expr.drop_columns([left_id])
32433278
elif how == "outer":
3244-
coalesced_id = guid.generate_guid()
32453279
expr, coalesced_id = expr.project_to_id(
32463280
ops.coalesce_op.as_expr(left_id, right_id)
32473281
)
@@ -3253,6 +3287,21 @@ def coalesce_columns(
32533287
return expr, result_ids
32543288

32553289

3290+
def coalesce_columns(
    expr: core.ArrayValue,
    left_ids: typing.Sequence[str],
    right_ids: typing.Sequence[str],
) -> tuple[core.ArrayValue, list[str]]:
    """Project one coalesced column per (left, right) id pair.

    The ids are walked in lockstep; for every pair a coalesce expression
    over the two columns is projected onto ``expr`` under a newly
    assigned id. This function does not drop the source columns.

    Args:
        expr: The expression to project the coalesced columns onto.
        left_ids: Column ids taken as the first coalesce operand.
        right_ids: Column ids taken as the second coalesce operand.

    Returns:
        The updated expression and the ids of the projected columns, in
        pair order.
    """
    coalesced_ids: list[str] = []
    for id_pair in zip(left_ids, right_ids):
        # project_to_id hands back both the new expression and the new id.
        expr, new_id = expr.project_to_id(ops.coalesce_op.as_expr(*id_pair))
        coalesced_ids.append(new_id)
    return expr, coalesced_ids
3303+
3304+
32563305
def _cast_index(block: Block, dtypes: typing.Sequence[bigframes.dtypes.Dtype]):
32573306
original_block = block
32583307
result_ids = []
@@ -3468,3 +3517,35 @@ def _pd_index_to_array_value(
34683517
rows.append(row)
34693518

34703519
return core.ArrayValue.from_pyarrow(pa.Table.from_pylist(rows), session=session)
3520+
3521+
3522+
def _resolve_index_col(
3523+
left_index_cols: list[str],
3524+
right_index_cols: list[str],
3525+
resolved_join_ids: list[str],
3526+
left_index: bool,
3527+
right_index: bool,
3528+
how: typing.Literal[
3529+
"inner",
3530+
"left",
3531+
"outer",
3532+
"right",
3533+
"cross",
3534+
],
3535+
) -> list[str]:
3536+
if left_index and right_index:
3537+
if how == "inner" or how == "left":
3538+
return left_index_cols
3539+
if how == "right":
3540+
return right_index_cols
3541+
if how == "outer":
3542+
return resolved_join_ids
3543+
else:
3544+
return []
3545+
elif left_index and not right_index:
3546+
return right_index_cols
3547+
elif right_index and not left_index:
3548+
return left_index_cols
3549+
else:
3550+
# Joining with value columns only. Existing indices will be discarded.
3551+
return []

0 commit comments

Comments
 (0)