@@ -121,13 +121,31 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols
121
121
target_index = target_table .select (join_cols_set ).append_column (TARGET_INDEX_COLUMN_NAME , pa .array (range (len (target_table ))))
122
122
123
123
# Step 3: Perform an inner join to find which rows from source exist in target
124
- matching_indices = source_index .join (target_index , keys = list (join_cols_set ), join_type = "inner" )
124
+ # PyArrow joins ignore null values, and we want null==null to hold, so we compute the join in Python.
125
+ # This is equivalent to:
126
+ # matching_indices = source_index.join(target_index, keys=list(join_cols_set), join_type="inner")
127
+ source_indices = {}
128
+ target_indices = {}
129
+
130
+ for row in source_index .to_pylist ():
131
+ idx = row .pop (SOURCE_INDEX_COLUMN_NAME )
132
+ key = tuple (row .values ())
133
+ source_indices [key ] = idx
134
+
135
+ for row in target_index .to_pylist ():
136
+ idx = row .pop (TARGET_INDEX_COLUMN_NAME )
137
+ key = tuple (row .values ())
138
+ target_indices [key ] = idx
139
+
140
+ matching_indices = [
141
+ (source_idx , target_idx )
142
+ for key , source_idx in source_indices .items ()
143
+ if (target_idx := target_indices .get (key )) is not None
144
+ ]
125
145
126
146
# Step 4: Compare all rows using Python
127
147
to_update_indices = []
128
- for source_idx , target_idx in zip (
129
- matching_indices [SOURCE_INDEX_COLUMN_NAME ].to_pylist (), matching_indices [TARGET_INDEX_COLUMN_NAME ].to_pylist ()
130
- ):
148
+ for source_idx , target_idx in matching_indices :
131
149
source_row = source_table .slice (source_idx , 1 )
132
150
target_row = target_table .slice (target_idx , 1 )
133
151
0 commit comments