Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 42 additions & 16 deletions pyiceberg/table/upsert_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,38 +14,61 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import functools
import operator
from math import isnan
from typing import Any

import pyarrow as pa
from pyarrow import Table as pyarrow_table
from pyarrow import compute as pc

from pyiceberg.expressions import (
AlwaysFalse,
And,
BooleanExpression,
EqualTo,
In,
IsNaN,
IsNull,
Or,
)


def create_match_filter(df: pyarrow_table, join_cols: list[str]) -> BooleanExpression:
    """Build a boolean expression matching every distinct key present in ``df``.

    ``In``/``EqualTo`` predicates never match null or NaN values, so those key
    values are expressed explicitly with ``IsNull``/``IsNaN`` predicates instead
    of being folded into the equality predicates.

    Args:
        df: Source table whose distinct ``join_cols`` values define the keys.
        join_cols: Column names forming the (possibly composite) match key.

    Returns:
        ``AlwaysFalse`` when ``df`` has no rows, a single predicate when there
        is exactly one, otherwise an ``Or`` of one predicate per distinct key.
    """
    unique_keys = df.select(join_cols).group_by(join_cols).aggregate([])

    filters: list[BooleanExpression] = []

    if len(join_cols) == 1:
        column = join_cols[0]
        values = set(unique_keys[0].to_pylist())

        # A null key cannot be matched by In(); express it as IsNull.
        if None in values:
            filters.append(IsNull(column))
            values.remove(None)

        # NaN keys likewise need an explicit IsNaN predicate.
        if nans := {value for value in values if isinstance(value, float) and isnan(value)}:
            filters.append(IsNaN(column))
            values -= nans

        filters.append(In(column, values))
    else:

        def equals(column: str, value: Any) -> BooleanExpression:
            # Map one key component to the predicate matching it, treating
            # None and NaN specially (EqualTo would never match either).
            if value is None:
                return IsNull(column)

            if isinstance(value, float) and isnan(value):
                return IsNaN(column)

            return EqualTo(column, value)

        filters = [And(*[equals(col, row[col]) for col in join_cols]) for row in unique_keys.to_pylist()]

    if len(filters) == 0:
        return AlwaysFalse()
    elif len(filters) == 1:
        return filters[0]
    else:
        return Or(*filters)


def has_duplicate_rows(df: pyarrow_table, join_cols: list[str]) -> bool:
Expand Down Expand Up @@ -98,13 +121,16 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols
target_index = target_table.select(join_cols_set).append_column(TARGET_INDEX_COLUMN_NAME, pa.array(range(len(target_table))))

# Step 3: Perform an inner join to find which rows from source exist in target
matching_indices = source_index.join(target_index, keys=list(join_cols_set), join_type="inner")
# PyArrow joins ignore null values, and we want null==null to hold, so we compute the join in Python.
# This is equivalent to:
# matching_indices = source_index.join(target_index, keys=list(join_cols_set), join_type="inner")
source_indices = {tuple(row[col] for col in join_cols): row[SOURCE_INDEX_COLUMN_NAME] for row in source_index.to_pylist()}
target_indices = {tuple(row[col] for col in join_cols): row[TARGET_INDEX_COLUMN_NAME] for row in target_index.to_pylist()}
matching_indices = [(s, t) for key, s in source_indices.items() if (t := target_indices.get(key)) is not None]

# Step 4: Compare all rows using Python
to_update_indices = []
for source_idx, target_idx in zip(
matching_indices[SOURCE_INDEX_COLUMN_NAME].to_pylist(), matching_indices[TARGET_INDEX_COLUMN_NAME].to_pylist()
):
for source_idx, target_idx in matching_indices:
source_row = source_table.slice(source_idx, 1)
target_row = target_table.slice(target_idx, 1)

Expand Down
109 changes: 107 additions & 2 deletions tests/table/test_upsert.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from pathlib import PosixPath
from typing import Any

import pyarrow as pa
import pytest
Expand All @@ -23,8 +24,8 @@

from pyiceberg.catalog import Catalog
from pyiceberg.exceptions import NoSuchTableError
from pyiceberg.expressions import AlwaysTrue, And, EqualTo, Reference
from pyiceberg.expressions.literals import LongLiteral
from pyiceberg.expressions import AlwaysTrue, And, BooleanExpression, EqualTo, In, IsNaN, IsNull, Or, Reference
from pyiceberg.expressions.literals import DoubleLiteral, LongLiteral
from pyiceberg.io.pyarrow import schema_to_pyarrow
from pyiceberg.schema import Schema
from pyiceberg.table import UpsertResult
Expand Down Expand Up @@ -440,6 +441,70 @@ def test_create_match_filter_single_condition() -> None:
)


@pytest.mark.parametrize(
    "data, expected",
    [
        pytest.param(
            [{"x": 1.0}, {"x": 2.0}, {"x": None}, {"x": 4.0}, {"x": float("nan")}],
            # Single join column: null and NaN keys become IsNull/IsNaN
            # predicates, the remaining values collapse into one In().
            Or(
                left=IsNull(term=Reference(name="x")),
                right=Or(
                    left=IsNaN(term=Reference(name="x")),
                    right=In(Reference(name="x"), {DoubleLiteral(1.0), DoubleLiteral(2.0), DoubleLiteral(4.0)}),
                ),
            ),
            id="single-column",
        ),
        pytest.param(
            [
                {"x": 1.0, "y": 9.0},
                {"x": 2.0, "y": None},
                {"x": None, "y": 7.0},
                {"x": 4.0, "y": float("nan")},
                {"x": float("nan"), "y": 0.0},
            ],
            # Multiple join columns: one And() per distinct key row, with
            # None/NaN components rendered as IsNull/IsNaN instead of EqualTo;
            # the per-row predicates are combined into a nested Or tree.
            Or(
                left=Or(
                    left=And(
                        left=EqualTo(term=Reference(name="x"), literal=DoubleLiteral(1.0)),
                        right=EqualTo(term=Reference(name="y"), literal=DoubleLiteral(9.0)),
                    ),
                    right=And(
                        left=EqualTo(term=Reference(name="x"), literal=DoubleLiteral(2.0)),
                        right=IsNull(term=Reference(name="y")),
                    ),
                ),
                right=Or(
                    left=And(
                        left=IsNull(term=Reference(name="x")),
                        right=EqualTo(term=Reference(name="y"), literal=DoubleLiteral(7.0)),
                    ),
                    right=Or(
                        left=And(
                            left=EqualTo(term=Reference(name="x"), literal=DoubleLiteral(4.0)),
                            right=IsNaN(term=Reference(name="y")),
                        ),
                        right=And(
                            left=IsNaN(term=Reference(name="x")),
                            right=EqualTo(term=Reference(name="y"), literal=DoubleLiteral(0.0)),
                        ),
                    ),
                ),
            ),
            id="multi-column",
        ),
    ],
)
def test_create_match_filter_with_nulls(data: list[dict[str, Any]], expected: BooleanExpression) -> None:
    """Verify create_match_filter handles null and NaN key values.

    Builds a PyArrow table from ``data`` and checks that the generated match
    filter equals the expected expression tree exactly.
    """
    # Schema declares both columns even though the single-column case only populates "x".
    schema = pa.schema([pa.field("x", pa.float64()), pa.field("y", pa.float64())])
    table = pa.Table.from_pylist(data, schema=schema)
    # Join on every column that appears in the test data, in sorted order.
    join_cols = sorted({col for record in data for col in record})

    expr = create_match_filter(table, join_cols)

    assert expr == expected


def test_upsert_with_duplicate_rows_in_table(catalog: Catalog) -> None:
identifier = "default.test_upsert_with_duplicate_rows_in_table"

Expand Down Expand Up @@ -710,6 +775,46 @@ def test_upsert_with_nulls(catalog: Catalog) -> None:
schema=schema,
)

# upsert table with null value
data_with_null = pa.Table.from_pylist(
[
{"foo": None, "bar": 1, "baz": False},
],
schema=schema,
)
upd = table.upsert(data_with_null, join_cols=["foo"])
assert upd.rows_updated == 0
assert upd.rows_inserted == 1
assert table.scan().to_arrow() == pa.Table.from_pylist(
[
{"foo": None, "bar": 1, "baz": False},
{"foo": "apple", "bar": 7, "baz": False},
{"foo": "banana", "bar": None, "baz": False},
],
schema=schema,
)

# upsert table with null and non-null values, in two join columns
data_with_null = pa.Table.from_pylist(
[
{"foo": None, "bar": 1, "baz": True},
{"foo": "lemon", "bar": None, "baz": False},
],
schema=schema,
)
upd = table.upsert(data_with_null, join_cols=["foo", "bar"])
assert upd.rows_updated == 1
assert upd.rows_inserted == 1
assert table.scan().to_arrow() == pa.Table.from_pylist(
[
{"foo": "lemon", "bar": None, "baz": False},
{"foo": None, "bar": 1, "baz": True},
{"foo": "apple", "bar": 7, "baz": False},
{"foo": "banana", "bar": None, "baz": False},
],
schema=schema,
)


def test_transaction(catalog: Catalog) -> None:
"""Test the upsert within a Transaction. Make sure that if something fails the entire Transaction is
Expand Down