Skip to content

Commit 76fc451

Browse files
committed
test: add test cases for create_match_filter
1 parent 31ed623 commit 76fc451

File tree

1 file changed

+67
-2
lines changed

1 file changed

+67
-2
lines changed

tests/table/test_upsert.py

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323

2424
from pyiceberg.catalog import Catalog
2525
from pyiceberg.exceptions import NoSuchTableError
26-
from pyiceberg.expressions import AlwaysTrue, And, EqualTo, Reference
27-
from pyiceberg.expressions.literals import LongLiteral
26+
from pyiceberg.expressions import AlwaysTrue, And, BooleanExpression, EqualTo, In, IsNaN, IsNull, Or, Reference
27+
from pyiceberg.expressions.literals import DoubleLiteral, LongLiteral
2828
from pyiceberg.io.pyarrow import schema_to_pyarrow
2929
from pyiceberg.schema import Schema
3030
from pyiceberg.table import UpsertResult
@@ -440,6 +440,71 @@ def test_create_match_filter_single_condition() -> None:
440440
)
441441

442442

443+
@pytest.mark.parametrize(
444+
"data, join_cols, expected",
445+
[
446+
pytest.param(
447+
[{"x": 1.0}, {"x": 2.0}, {"x": None}, {"x": 4.0}, {"x": float("nan")}],
448+
["x"],
449+
Or(
450+
left=IsNull(term=Reference(name="x")),
451+
right=Or(
452+
left=IsNaN(term=Reference(name="x")),
453+
right=In(Reference(name="x"), {DoubleLiteral(1.0), DoubleLiteral(2.0), DoubleLiteral(4.0)}),
454+
),
455+
),
456+
id="single-column",
457+
),
458+
pytest.param(
459+
[
460+
{"x": 1.0, "y": 9.0},
461+
{"x": 2.0, "y": None},
462+
{"x": None, "y": 7.0},
463+
{"x": 4.0, "y": float("nan")},
464+
{"x": float("nan"), "y": 0.0},
465+
],
466+
["x", "y"],
467+
Or(
468+
left=Or(
469+
left=And(
470+
left=EqualTo(term=Reference(name="x"), literal=DoubleLiteral(1.0)),
471+
right=EqualTo(term=Reference(name="y"), literal=DoubleLiteral(9.0)),
472+
),
473+
right=And(
474+
left=EqualTo(term=Reference(name="x"), literal=DoubleLiteral(2.0)),
475+
right=IsNull(term=Reference(name="y")),
476+
),
477+
),
478+
right=Or(
479+
left=And(
480+
left=IsNull(term=Reference(name="x")),
481+
right=EqualTo(term=Reference(name="y"), literal=DoubleLiteral(7.0)),
482+
),
483+
right=Or(
484+
left=And(
485+
left=EqualTo(term=Reference(name="x"), literal=DoubleLiteral(4.0)),
486+
right=IsNaN(term=Reference(name="y")),
487+
),
488+
right=And(
489+
left=IsNaN(term=Reference(name="x")),
490+
right=EqualTo(term=Reference(name="y"), literal=DoubleLiteral(0.0)),
491+
),
492+
),
493+
),
494+
),
495+
id="multi-column",
496+
),
497+
],
498+
)
499+
def test_create_match_filter_with_nulls(data: list[dict], join_cols: list[str], expected: BooleanExpression) -> None:
500+
schema = pa.schema([pa.field("x", pa.float64()), pa.field("y", pa.float64())])
501+
table = pa.Table.from_pylist(data, schema=schema)
502+
503+
expr = create_match_filter(table, join_cols)
504+
505+
assert expr == expected
506+
507+
443508
def test_upsert_with_duplicate_rows_in_table(catalog: Catalog) -> None:
444509
identifier = "default.test_upsert_with_duplicate_rows_in_table"
445510

0 commit comments

Comments
 (0)