Skip to content

Add iceberg support to table_diff #4441

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
May 26, 2025
23 changes: 22 additions & 1 deletion sqlmesh/core/table_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@

from sqlmesh.core.dialect import to_schema
from sqlmesh.core.engine_adapter.mixins import RowDiffMixin
from sqlmesh.core.engine_adapter.athena import AthenaEngineAdapter
from sqlglot import exp, parse_one
from sqlglot.helper import ensure_list
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
from sqlglot.optimizer.qualify_columns import quote_identifiers
from sqlglot.optimizer.scope import find_all_in_scope

from sqlmesh.utils.pydantic import PydanticModel
from sqlmesh.utils.errors import SQLMeshError

if t.TYPE_CHECKING:
from sqlmesh.core._typing import TableName
Expand Down Expand Up @@ -431,7 +433,26 @@ def name(e: exp.Expression) -> str:
schema = to_schema(temp_schema, dialect=self.dialect)
temp_table = exp.table_("diff", db=schema.db, catalog=schema.catalog, quoted=True)

with self.adapter.temp_table(query, name=temp_table) as table:
temp_table_kwargs = {}
if isinstance(self.adapter, AthenaEngineAdapter):
# Athena has two table formats: Hive (the default) and Iceberg. TableDiff requires that
# the formats be the same for the source, target, and temp tables.
source_table_type = self.adapter._query_table_type(self.source_table)
target_table_type = self.adapter._query_table_type(self.target_table)

if source_table_type == "iceberg" and target_table_type == "iceberg":
temp_table_kwargs["table_format"] = "iceberg"
# Sets the temp table's format to Iceberg.
# If neither source nor target table is Iceberg, it defaults to Hive (Athena's default).
elif source_table_type == "iceberg" or target_table_type == "iceberg":
raise SQLMeshError(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is right. Table diff relies on copying a sample of data from each table into a single table, which on most engines is fine because they have a consistent table format.

I know that timestamp types in particular are handled differently between Athena/Iceberg and Athena/Hive, so I don't think we can say "if Iceberg is detected - make the diff table Iceberg", because copying Hive data into it verbatim can still cause an error.

f"Source table '{self.source}' format '{source_table_type}' and target table '{self.target}' format '{target_table_type}' "
f"do not match for Athena. Diffing between different table formats is not supported."
)

with self.adapter.temp_table(
query, name=temp_table, columns_to_types=None, **temp_table_kwargs
) as table:
summary_sums = [
exp.func("SUM", "s_exists").as_("s_count"),
exp.func("SUM", "t_exists").as_("t_count"),
Expand Down
82 changes: 82 additions & 0 deletions tests/core/engine_adapter/test_athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from sqlmesh.core.model import load_sql_based_model
from sqlmesh.core.model.definition import SqlModel
from sqlmesh.utils.errors import SQLMeshError
from sqlmesh.core.table_diff import TableDiff

from tests.core.engine_adapter import to_sql_calls

Expand All @@ -21,6 +22,16 @@ def adapter(make_mocked_engine_adapter: t.Callable) -> AthenaEngineAdapter:
return make_mocked_engine_adapter(AthenaEngineAdapter)


@pytest.fixture
def table_diff(adapter: AthenaEngineAdapter) -> TableDiff:
    """Provide a ``TableDiff`` bound to the mocked Athena adapter.

    The diff compares ``source_table`` against ``target_table`` and joins
    the two on the ``id`` column.
    """
    join_columns = ["id"]
    diff = TableDiff(
        adapter=adapter,
        source="source_table",
        target="target_table",
        on=join_columns,
    )
    return diff


@pytest.mark.parametrize(
"config_s3_warehouse_location,table_properties,table,expected_location",
[
Expand Down Expand Up @@ -483,3 +494,74 @@ def test_iceberg_partition_transforms(adapter: AthenaEngineAdapter):
# Trino syntax - CTAS
"""CREATE TABLE IF NOT EXISTS "test_table" WITH (table_type='iceberg', partitioning=ARRAY['MONTH(business_date)', 'BUCKET(colb, 4)', 'colc'], location='s3://bucket/prefix/test_table/', is_external=false) AS SELECT CAST("business_date" AS TIMESTAMP) AS "business_date", CAST("colb" AS VARCHAR) AS "colb", CAST("colc" AS VARCHAR) AS "colc" FROM (SELECT CAST(1 AS TIMESTAMP) AS "business_date", CAST(2 AS VARCHAR) AS "colb", 'foo' AS "colc" LIMIT 0) AS "_subquery\"""",
]


@pytest.mark.parametrize(
    "source_format, target_format, expected_temp_format, expect_error",
    [
        ("hive", "hive", None, False),
        ("iceberg", "hive", None, True),  # Expect error for mismatched formats
        ("hive", "iceberg", None, True),  # Expect error for mismatched formats
        ("iceberg", "iceberg", "iceberg", False),
        # Source doesn't exist or type unknown, target is iceberg
        (None, "iceberg", None, True),
        # Target doesn't exist or type unknown, source is iceberg.
        # No temp table is created when an error is expected, so there is
        # no expected temp format for this case.
        ("iceberg", None, None, True),
        (None, "hive", None, False),  # Source doesn't exist or type unknown, target is hive
        ("hive", None, None, False),  # Target doesn't exist or type unknown, source is hive
        (None, None, None, False),  # Both don't exist or types unknown
    ],
)
def test_table_diff_temp_table_format(
    table_diff: TableDiff,
    mocker: MockerFixture,
    source_format: t.Optional[str],
    target_format: t.Optional[str],
    expected_temp_format: t.Optional[str],
    expect_error: bool,
):
    """Check which ``table_format`` ``row_diff`` forwards to ``temp_table`` on Athena.

    Args:
        table_diff: ``TableDiff`` fixture wired to the mocked Athena adapter.
        mocker: pytest-mock fixture used to patch adapter methods.
        source_format: Value the mocked ``_query_table_type`` returns for the source table.
        target_format: Value the mocked ``_query_table_type`` returns for the target table.
        expected_temp_format: ``table_format`` kwarg expected on the ``temp_table`` call,
            or ``None`` when the kwarg must be absent. Only checked when no error is expected.
        expect_error: Whether ``row_diff`` must raise ``SQLMeshError`` before ``temp_table``
            is called.
    """
    adapter = t.cast(AthenaEngineAdapter, table_diff.adapter)

    # Mock _query_table_type to return specified formats
    def mock_query_table_type(table_name: exp.Table) -> t.Optional[str]:
        if table_name.name == "source_table":
            return source_format
        if table_name.name == "target_table":
            return target_format
        return "hive"  # Default for other tables if any

    mocker.patch.object(adapter, "_query_table_type", side_effect=mock_query_table_type)

    # Mock temp_table to capture kwargs
    mock_temp_table = mocker.patch.object(adapter, "temp_table", autospec=True)
    mock_temp_table.return_value.__enter__.return_value = exp.to_table("diff_table")

    # Mock fetchdf and other calls made within row_diff to avoid actual DB interaction
    mocker.patch.object(adapter, "fetchdf", return_value=pd.DataFrame())
    mocker.patch.object(adapter, "get_data_objects", return_value=[])
    mocker.patch.object(adapter, "columns", return_value={"id": exp.DataType.build("int")})

    if expect_error:
        # ``match`` is a regular expression, so the dots are escaped to pin the
        # literal message rather than "any character".
        with pytest.raises(
            SQLMeshError,
            match=r"do not match for Athena\. Diffing between different table formats is not supported\.",
        ):
            table_diff.row_diff()
        # Checked after the ``with`` block: a statement placed after the raising
        # call inside the block would never execute.
        mock_temp_table.assert_not_called()  # temp_table should not be called if formats mismatch
        return

    try:
        table_diff.row_diff()
    except Exception:
        # Deliberate best effort: only the temp_table call args matter for
        # non-error cases. If temp_table was never reached, the
        # assert_called_once below still fails the test.
        pass

    mock_temp_table.assert_called_once()
    _, called_kwargs = mock_temp_table.call_args

    if expected_temp_format:
        assert called_kwargs.get("table_format") == expected_temp_format
    else:
        assert "table_format" not in called_kwargs