Skip to content

Commit 6a7aa3f

Browse files
committed
fix: respect null values in inner join in get_rows_to_update
1 parent 075a966 commit 6a7aa3f

File tree

1 file changed

+22
-4
lines changed

1 file changed

+22
-4
lines changed

pyiceberg/table/upsert_util.py

Lines changed: 22 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -121,13 +121,31 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols
121121
target_index = target_table.select(join_cols_set).append_column(TARGET_INDEX_COLUMN_NAME, pa.array(range(len(target_table))))
122122

123123
# Step 3: Perform an inner join to find which rows from source exist in target
124-
matching_indices = source_index.join(target_index, keys=list(join_cols_set), join_type="inner")
124+
# PyArrow joins ignore null values, and we want null==null to hold, so we compute the join in Python.
125+
# This is equivalent to:
126+
# matching_indices = source_index.join(target_index, keys=list(join_cols_set), join_type="inner")
127+
source_indices = {}
128+
target_indices = {}
129+
130+
for row in source_index.to_pylist():
131+
idx = row.pop(SOURCE_INDEX_COLUMN_NAME)
132+
key = tuple(row.values())
133+
source_indices[key] = idx
134+
135+
for row in target_index.to_pylist():
136+
idx = row.pop(TARGET_INDEX_COLUMN_NAME)
137+
key = tuple(row.values())
138+
target_indices[key] = idx
139+
140+
matching_indices = [
141+
(source_idx, target_idx)
142+
for key, source_idx in source_indices.items()
143+
if (target_idx := target_indices.get(key)) is not None
144+
]
125145

126146
# Step 4: Compare all rows using Python
127147
to_update_indices = []
128-
for source_idx, target_idx in zip(
129-
matching_indices[SOURCE_INDEX_COLUMN_NAME].to_pylist(), matching_indices[TARGET_INDEX_COLUMN_NAME].to_pylist()
130-
):
148+
for source_idx, target_idx in matching_indices:
131149
source_row = source_table.slice(source_idx, 1)
132150
target_row = target_table.slice(target_idx, 1)
133151

0 commit comments

Comments
 (0)