@@ -2332,41 +2332,61 @@ def merge(
23322332 right_join_ids : typing .Sequence [str ],
23332333 sort : bool ,
23342334 suffixes : tuple [str , str ] = ("_x" , "_y" ),
2335+ left_index : bool = False ,
2336+ right_index : bool = False ,
23352337 ) -> Block :
23362338 conditions = tuple (
23372339 (lid , rid ) for lid , rid in zip (left_join_ids , right_join_ids )
23382340 )
23392341 joined_expr , (get_column_left , get_column_right ) = self .expr .relational_join (
23402342 other .expr , type = how , conditions = conditions
23412343 )
2342- result_columns = []
2343- matching_join_labels = []
23442344
23452345 left_post_join_ids = tuple (get_column_left [id ] for id in left_join_ids )
23462346 right_post_join_ids = tuple (get_column_right [id ] for id in right_join_ids )
23472347
2348- joined_expr , coalesced_ids = coalesce_columns (
2349- joined_expr , left_post_join_ids , right_post_join_ids , how = how , drop = False
2350- )
2348+ if left_index or right_index :
2349+ # For some reason pandas coalesces two joining columns if one side is an index.
2350+ joined_expr , resolved_join_ids = coalesce_columns (
2351+ joined_expr , left_post_join_ids , right_post_join_ids
2352+ )
2353+ else :
2354+ joined_expr , resolved_join_ids = resolve_col_join_ids ( # type: ignore
2355+ joined_expr ,
2356+ left_post_join_ids ,
2357+ right_post_join_ids ,
2358+ how = how ,
2359+ drop = False ,
2360+ )
2361+
2362+ result_columns = []
2363+ matching_join_labels = []
23512364
2365+ # Select left value columns
23522366 for col_id in self .value_columns :
23532367 if col_id in left_join_ids :
23542368 key_part = left_join_ids .index (col_id )
23552369 matching_right_id = right_join_ids [key_part ]
23562370 if (
2357- self .col_id_to_label [col_id ]
2371+ right_index
2372+ or self .col_id_to_label [col_id ]
23582373 == other .col_id_to_label [matching_right_id ]
23592374 ):
23602375 matching_join_labels .append (self .col_id_to_label [col_id ])
2361- result_columns .append (coalesced_ids [key_part ])
2376+ result_columns .append (resolved_join_ids [key_part ])
23622377 else :
23632378 result_columns .append (get_column_left [col_id ])
23642379 else :
23652380 result_columns .append (get_column_left [col_id ])
2381+
2382+ # Select right value columns
23662383 for col_id in other .value_columns :
23672384 if col_id in right_join_ids :
23682385 if other .col_id_to_label [col_id ] in matching_join_labels :
23692386 pass
2387+ elif left_index :
2388+ key_part = right_join_ids .index (col_id )
2389+ result_columns .append (resolved_join_ids [key_part ])
23702390 else :
23712391 result_columns .append (get_column_right [col_id ])
23722392 else :
@@ -2377,11 +2397,22 @@ def merge(
23772397 joined_expr = joined_expr .order_by (
23782398 [
23792399 ordering .OrderingExpression (ex .deref (col_id ))
2380- for col_id in coalesced_ids
2400+ for col_id in resolved_join_ids
23812401 ],
23822402 )
23832403
2384- joined_expr = joined_expr .select_columns (result_columns )
2404+ left_idx_id_post_join = [get_column_left [id ] for id in self .index_columns ]
2405+ right_idx_id_post_join = [get_column_right [id ] for id in other .index_columns ]
2406+ index_cols = _resolve_index_col (
2407+ left_idx_id_post_join ,
2408+ right_idx_id_post_join ,
2409+ resolved_join_ids ,
2410+ left_index ,
2411+ right_index ,
2412+ how ,
2413+ )
2414+
2415+ joined_expr = joined_expr .select_columns (result_columns + index_cols )
23852416 labels = utils .merge_column_labels (
23862417 self .column_labels ,
23872418 other .column_labels ,
@@ -2400,13 +2431,13 @@ def merge(
24002431 or other .index .is_null
24012432 or self .session ._default_index_type == bigframes .enums .DefaultIndexKind .NULL
24022433 ):
2403- expr = joined_expr
2404- index_columns = []
2434+ return Block (joined_expr , index_columns = [], column_labels = labels )
2435+ elif index_cols :
2436+ return Block (joined_expr , index_columns = index_cols , column_labels = labels )
24052437 else :
24062438 expr , offset_index_id = joined_expr .promote_offsets ()
24072439 index_columns = [offset_index_id ]
2408-
2409- return Block (expr , index_columns = index_columns , column_labels = labels )
2440+ return Block (expr , index_columns = index_columns , column_labels = labels )
24102441
24112442 def _align_both_axes (
24122443 self , other : Block , how : str
@@ -3115,7 +3146,7 @@ def join_mono_indexed(
31153146 left_index = get_column_left [left .index_columns [0 ]]
31163147 right_index = get_column_right [right .index_columns [0 ]]
31173148 # Drop original indices from each side. and used the coalesced combination generated by the join.
3118- combined_expr , coalesced_join_cols = coalesce_columns (
3149+ combined_expr , coalesced_join_cols = resolve_col_join_ids (
31193150 combined_expr , [left_index ], [right_index ], how = how
31203151 )
31213152 if sort :
@@ -3180,7 +3211,7 @@ def join_multi_indexed(
31803211 left_ids_post_join = [get_column_left [id ] for id in left_join_ids ]
31813212 right_ids_post_join = [get_column_right [id ] for id in right_join_ids ]
31823213 # Drop original indices from each side. and used the coalesced combination generated by the join.
3183- combined_expr , coalesced_join_cols = coalesce_columns (
3214+ combined_expr , coalesced_join_cols = resolve_col_join_ids (
31843215 combined_expr , left_ids_post_join , right_ids_post_join , how = how
31853216 )
31863217 if sort :
@@ -3223,13 +3254,17 @@ def resolve_label_id(label: Label) -> str:
32233254
32243255
32253256# TODO: Rewrite just to return expressions
3226- def coalesce_columns (
3257+ def resolve_col_join_ids (
32273258 expr : core .ArrayValue ,
32283259 left_ids : typing .Sequence [str ],
32293260 right_ids : typing .Sequence [str ],
32303261 how : str ,
32313262 drop : bool = True ,
32323263) -> Tuple [core .ArrayValue , Sequence [str ]]:
3264+ """
3265+ Collapses and selects the joining column IDs, with the assumption that
3266+ the IDs all belong to value columns.
3267+ """
32333268 result_ids = []
32343269 for left_id , right_id in zip (left_ids , right_ids ):
32353270 if how == "left" or how == "inner" or how == "cross" :
@@ -3241,7 +3276,6 @@ def coalesce_columns(
32413276 if drop :
32423277 expr = expr .drop_columns ([left_id ])
32433278 elif how == "outer" :
3244- coalesced_id = guid .generate_guid ()
32453279 expr , coalesced_id = expr .project_to_id (
32463280 ops .coalesce_op .as_expr (left_id , right_id )
32473281 )
@@ -3253,6 +3287,21 @@ def coalesce_columns(
32533287 return expr , result_ids
32543288
32553289
def coalesce_columns(
    expr: core.ArrayValue,
    left_ids: typing.Sequence[str],
    right_ids: typing.Sequence[str],
) -> tuple[core.ArrayValue, list[str]]:
    """Project one coalesced column for each (left, right) pair of column ids.

    The ids are paired positionally with ``zip``, so any surplus ids on the
    longer side are ignored. For each pair, a new column built from
    ``ops.coalesce_op`` over the two ids is projected onto ``expr``. This
    function does not drop the input columns itself.

    Args:
        expr: The array value holding both the left and right columns.
        left_ids: Column ids coming from the left side of the join.
        right_ids: Column ids coming from the right side of the join.

    Returns:
        A tuple of the expression with the projected columns added, and the
        ids of those new columns in pair order.
    """
    coalesced_ids: list[str] = []
    for lhs_id, rhs_id in zip(left_ids, right_ids):
        coalesce_expr = ops.coalesce_op.as_expr(lhs_id, rhs_id)
        # Each projection returns a new expression; thread it through the loop.
        expr, new_id = expr.project_to_id(coalesce_expr)
        coalesced_ids.append(new_id)

    return expr, coalesced_ids
3303+
3304+
32563305def _cast_index (block : Block , dtypes : typing .Sequence [bigframes .dtypes .Dtype ]):
32573306 original_block = block
32583307 result_ids = []
@@ -3468,3 +3517,35 @@ def _pd_index_to_array_value(
34683517 rows .append (row )
34693518
34703519 return core .ArrayValue .from_pyarrow (pa .Table .from_pylist (rows ), session = session )
3520+
3521+
3522+ def _resolve_index_col (
3523+ left_index_cols : list [str ],
3524+ right_index_cols : list [str ],
3525+ resolved_join_ids : list [str ],
3526+ left_index : bool ,
3527+ right_index : bool ,
3528+ how : typing .Literal [
3529+ "inner" ,
3530+ "left" ,
3531+ "outer" ,
3532+ "right" ,
3533+ "cross" ,
3534+ ],
3535+ ) -> list [str ]:
3536+ if left_index and right_index :
3537+ if how == "inner" or how == "left" :
3538+ return left_index_cols
3539+ if how == "right" :
3540+ return right_index_cols
3541+ if how == "outer" :
3542+ return resolved_join_ids
3543+ else :
3544+ return []
3545+ elif left_index and not right_index :
3546+ return right_index_cols
3547+ elif right_index and not left_index :
3548+ return left_index_cols
3549+ else :
3550+ # Joining with value columns only. Existing indices will be discarded.
3551+ return []
0 commit comments