diff --git a/dlt/destinations/impl/clickhouse/clickhouse.py b/dlt/destinations/impl/clickhouse/clickhouse.py
index 48bf632954..f2fcf1d687 100644
--- a/dlt/destinations/impl/clickhouse/clickhouse.py
+++ b/dlt/destinations/impl/clickhouse/clickhouse.py
@@ -271,10 +271,7 @@ def create_load_job(
         )
 
     def _get_table_update_sql(
-        self,
-        table_name: str,
-        new_columns: Sequence[TColumnSchema],
-        generate_alter: bool,
+        self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool
     ) -> List[str]:
         table = self.prepare_load_table(table_name)
         sql = SqlJobClientBase._get_table_update_sql(self, table_name, new_columns, generate_alter)
diff --git a/dlt/destinations/impl/databricks/databricks.py b/dlt/destinations/impl/databricks/databricks.py
index 5209c247f3..29d541e9e4 100644
--- a/dlt/destinations/impl/databricks/databricks.py
+++ b/dlt/destinations/impl/databricks/databricks.py
@@ -388,11 +388,7 @@ def _get_constraints_sql(
         return ""
 
     def _get_table_update_sql(
-        self,
-        table_name: str,
-        new_columns: Sequence[TColumnSchema],
-        generate_alter: bool,
-        separate_alters: bool = False,
+        self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool
     ) -> List[str]:
         table = self.prepare_load_table(table_name)
 
diff --git a/dlt/destinations/impl/dremio/dremio.py b/dlt/destinations/impl/dremio/dremio.py
index 7eb4f21214..84718cd774 100644
--- a/dlt/destinations/impl/dremio/dremio.py
+++ b/dlt/destinations/impl/dremio/dremio.py
@@ -129,11 +129,7 @@ def create_load_job(
         return job
 
     def _get_table_update_sql(
-        self,
-        table_name: str,
-        new_columns: Sequence[TColumnSchema],
-        generate_alter: bool,
-        separate_alters: bool = False,
+        self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool
     ) -> List[str]:
         sql = super()._get_table_update_sql(table_name, new_columns, generate_alter)
 
diff --git a/dlt/destinations/impl/ducklake/ducklake.py b/dlt/destinations/impl/ducklake/ducklake.py
index c9716725b8..9722326424 100644
--- a/dlt/destinations/impl/ducklake/ducklake.py
+++ b/dlt/destinations/impl/ducklake/ducklake.py
@@ -134,11 +134,7 @@ def create_load_job(
         return job
 
     def _get_table_update_sql(
-        self,
-        table_name: str,
-        new_columns: Sequence[TColumnSchema],
-        generate_alter: bool,
-        separate_alters: bool = False,
+        self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool
     ) -> List[str]:
         sql = super()._get_table_update_sql(table_name, new_columns, generate_alter)
 
diff --git a/dlt/destinations/impl/snowflake/snowflake.py b/dlt/destinations/impl/snowflake/snowflake.py
index 9d212bfc9a..8d725c5e25 100644
--- a/dlt/destinations/impl/snowflake/snowflake.py
+++ b/dlt/destinations/impl/snowflake/snowflake.py
@@ -196,24 +196,52 @@ def _get_constraints_sql(
             return f",\nCONSTRAINT {pk_constraint_name} PRIMARY KEY ({quoted_pk_cols})"
         return ""
 
-    def _get_table_update_sql(
+    def _get_cluster_sql(self, cluster_column_names: Sequence[str]) -> str:
+        if cluster_column_names:
+            cluster_column_names_str = ",".join(
+                [self.sql_client.escape_column_name(col) for col in cluster_column_names]
+            )
+            return f"CLUSTER BY ({cluster_column_names_str})"
+        else:
+            return "DROP CLUSTERING KEY"
+
+    def _get_alter_cluster_sql(self, table_name: str, cluster_column_names: Sequence[str]) -> str:
+        qualified_name = self.sql_client.make_qualified_table_name(table_name)
+        return self._make_alter_table(qualified_name) + self._get_cluster_sql(cluster_column_names)
+
+    def _add_cluster_sql(
         self,
+        sql: List[str],
         table_name: str,
-        new_columns: Sequence[TColumnSchema],
         generate_alter: bool,
-        separate_alters: bool = False,
     ) -> List[str]:
-        sql = super()._get_table_update_sql(table_name, new_columns, generate_alter)
+        """Adds CLUSTER BY / DROP CLUSTERING KEY clause to SQL statements based on cluster hints.
+
+        This method modifies the input `sql` list in place and also returns it.
+        """
 
-        cluster_list = [
-            self.sql_client.escape_column_name(c["name"]) for c in new_columns if c.get("cluster")
+        cluster_column_names = [
+            c["name"]
+            for c in self.schema.get_table_columns(table_name).values()
+            if c.get("cluster")
         ]
 
-        if cluster_list:
-            sql[0] = sql[0] + "\nCLUSTER BY (" + ",".join(cluster_list) + ")"
+        if generate_alter:
+            # altering -> need to issue separate ALTER TABLE statement for cluster operations
+            stmt = self._get_alter_cluster_sql(table_name, cluster_column_names)
+            sql.append(stmt)
+        elif not generate_alter and cluster_column_names:
+            # creating -> can append CLUSTER BY clause to CREATE TABLE statement
+            sql[0] = sql[0] + self._get_cluster_sql(cluster_column_names)
 
         return sql
 
+    def _get_table_update_sql(
+        self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool
+    ) -> List[str]:
+        sql = super()._get_table_update_sql(table_name, new_columns, generate_alter)
+        return self._add_cluster_sql(sql, table_name, generate_alter)
+
     def _from_db_type(
         self, bq_t: str, precision: Optional[int], scale: Optional[int]
     ) -> TColumnType:
diff --git a/dlt/destinations/job_client_impl.py b/dlt/destinations/job_client_impl.py
index 7707e5159b..0fcbfa1666 100644
--- a/dlt/destinations/job_client_impl.py
+++ b/dlt/destinations/job_client_impl.py
@@ -695,6 +695,11 @@ def _make_create_table(self, qualified_name: str, table: PreparedTableSchema) ->
             not_exists_clause = " IF NOT EXISTS "
         return f"CREATE TABLE{not_exists_clause}{qualified_name}"
 
+    @staticmethod
+    def _make_alter_table(qualified_name: str) -> str:
+        """Begins ALTER TABLE statement"""
+        return f"ALTER TABLE {qualified_name}\n"
+
     def _get_table_update_sql(
         self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool
     ) -> List[str]:
@@ -712,7 +717,7 @@ def _get_table_update_sql(
                 sql += ")"
             sql_result.append(sql)
         else:
-            sql_base = f"ALTER TABLE {qualified_name}\n"
+            sql_base = self._make_alter_table(qualified_name)
             add_column_statements = self._make_add_column_sql(new_columns, table)
             if self.capabilities.alter_add_multi_column:
                 column_sql = ",\n"
diff --git a/docs/website/docs/dlt-ecosystem/destinations/snowflake.md b/docs/website/docs/dlt-ecosystem/destinations/snowflake.md
index e174ad6ae1..f40bce6921 100644
--- a/docs/website/docs/dlt-ecosystem/destinations/snowflake.md
+++ b/docs/website/docs/dlt-ecosystem/destinations/snowflake.md
@@ -232,7 +232,7 @@ Note that we ignore missing columns `ERROR_ON_COLUMN_COUNT_MISMATCH = FALSE` and
 
 ## Supported column hints
 Snowflake supports the following [column hints](../../general-usage/schema#tables-and-columns):
-* `cluster` - Creates a cluster column(s). Many columns per table are supported and only when a new table is created.
+* `cluster` - Makes column part of [cluster key](https://docs.snowflake.com/en/user-guide/tables-clustering-keys), can be added to many columns. The `cluster` columns are added to the cluster key in order of appearance in the table schema. Changing `cluster` hints after table creation is supported, but the changes will only be applied if/when a new column is added.
 * `unique` - Creates UNIQUE hint on a Snowflake column, can be added to many columns. ([optional](#additional-destination-options))
 * `primary_key` - Creates PRIMARY KEY on selected column(s), may be compound. ([optional](#additional-destination-options))
diff --git a/tests/load/pipeline/test_snowflake_pipeline.py b/tests/load/pipeline/test_snowflake_pipeline.py
index 24c5418bd1..ebb6793a67 100644
--- a/tests/load/pipeline/test_snowflake_pipeline.py
+++ b/tests/load/pipeline/test_snowflake_pipeline.py
@@ -1,6 +1,7 @@
 from copy import deepcopy
 import os
 import pytest
+from typing import cast
 from pytest_mock import MockerFixture
 
 import dlt
@@ -315,6 +316,79 @@ def my_resource():
     )
 
 
+@pytest.mark.parametrize(
+    "destination_config",
+    destinations_configs(default_sql_configs=True, subset=["snowflake"]),
+    ids=lambda x: x.name,
+)
+def test_snowflake_cluster_hints(destination_config: DestinationTestConfiguration) -> None:
+    from dlt.destinations.impl.snowflake.sql_client import SnowflakeSqlClient
+
+    def get_cluster_key(sql_client: SnowflakeSqlClient, table_name: str) -> str:
+        with sql_client:
+            _catalog_name, schema_name, table_names = sql_client._get_information_schema_components(
+                table_name
+            )
+            qry = f"""
+                SELECT CLUSTERING_KEY FROM INFORMATION_SCHEMA.TABLES
+                WHERE TABLE_SCHEMA = '{schema_name}'
+                AND TABLE_NAME = '{table_names[0]}'
+            """
+            return sql_client.execute_sql(qry)[0][0]
+
+    pipeline = destination_config.setup_pipeline("test_snowflake_cluster_hints", dev_mode=True)
+    sql_client = cast(SnowflakeSqlClient, pipeline.sql_client())
+    table_name = "test_snowflake_cluster_hints"
+
+    @dlt.resource(table_name=table_name)
+    def test_data():
+        return [
+            {"c1": 1, "c2": "a"},
+            {"c1": 2, "c2": "b"},
+        ]
+
+    # create new table with clustering
+    test_data.apply_hints(columns=[{"name": "c1", "cluster": True}])
+    info = pipeline.run(test_data(), **destination_config.run_kwargs)
+    assert_load_info(info)
+    assert get_cluster_key(sql_client, table_name) == 'LINEAR("C1")'
+
+    # change cluster hints on existing table without adding new column
+    test_data.apply_hints(columns=[{"name": "c2", "cluster": True}])
+    info = pipeline.run(test_data(), **destination_config.run_kwargs)
+    assert_load_info(info)
+    assert get_cluster_key(sql_client, table_name) == 'LINEAR("C1")'  # unchanged (no new column)
+
+    # add new column to existing table with pending cluster hints from previous run
+    test_data.apply_hints(columns=[{"name": "c3", "data_type": "bool"}])
+    info = pipeline.run(test_data(), **destination_config.run_kwargs)
+    assert_load_info(info)
+    assert get_cluster_key(sql_client, table_name) == 'LINEAR("C1","C2")'  # updated
+
+    # remove clustering from existing table
+    test_data.apply_hints(
+        columns=[
+            {"name": "c1", "cluster": False},
+            {"name": "c2", "cluster": False},
+            {"name": "c4", "data_type": "bool"},  # include new column to trigger alter
+        ]
+    )
+    info = pipeline.run(test_data(), **destination_config.run_kwargs)
+    assert_load_info(info)
+    assert get_cluster_key(sql_client, table_name) is None
+
+    # add clustering to existing table (and add new column to trigger alter)
+    test_data.apply_hints(
+        columns=[
+            {"name": "c1", "cluster": True},
+            {"name": "c5", "data_type": "bool"},  # include new column to trigger alter
+        ]
+    )
+    info = pipeline.run(test_data(), **destination_config.run_kwargs)
+    assert_load_info(info)
+    assert get_cluster_key(sql_client, table_name) == 'LINEAR("C1")'
+
+
 @pytest.mark.skip(reason="perf test for merge")
 @pytest.mark.parametrize(
     "destination_config",
diff --git a/tests/load/snowflake/test_snowflake_table_builder.py b/tests/load/snowflake/test_snowflake_table_builder.py
index 24c1091fa5..72d614ccf0 100644
--- a/tests/load/snowflake/test_snowflake_table_builder.py
+++ b/tests/load/snowflake/test_snowflake_table_builder.py
@@ -8,7 +8,8 @@
 import sqlfluff
 
 from dlt.common.utils import uniq_id
-from dlt.common.schema import Schema, utils
+from dlt.common.schema import Schema
+from dlt.common.schema.utils import new_table
 from dlt.destinations import snowflake
 from dlt.destinations.impl.snowflake.snowflake import SnowflakeClient, SUPPORTED_HINTS
 from dlt.destinations.impl.snowflake.configuration import (
@@ -119,46 +120,103 @@ def test_create_table_with_hints(snowflake_client: SnowflakeClient) -> None:
     assert 'CONSTRAINT "PK_EVENT_TEST_TABLE_' in sql
     assert 'PRIMARY KEY ("COL1", "COL6")' in sql
 
-    # generate alter
-    mod_update = deepcopy(TABLE_UPDATE[11:])
-    mod_update[0]["primary_key"] = True
-    mod_update[1]["unique"] = True
-
-    sql = ";".join(snowflake_client._get_table_update_sql("event_test_table", mod_update, True))
-    # PK constraint ignored for alter
-    assert "PRIMARY KEY" not in sql
-    assert '"COL2_NULL" FLOAT UNIQUE' in sql
-
 
 def test_alter_table(snowflake_client: SnowflakeClient) -> None:
-    statements = snowflake_client._get_table_update_sql("event_test_table", TABLE_UPDATE, True)
-    assert len(statements) == 1
-    sql = statements[0]
+    new_columns = deepcopy(TABLE_UPDATE[1:10])
+    statements = snowflake_client._get_table_update_sql("event_test_table", new_columns, True)
 
-    # TODO: sqlfluff doesn't parse snowflake multi ADD COLUMN clause correctly
-    # sqlfluff.parse(sql, dialect='snowflake')
+    assert len(statements) == 2, "Should have one ADD COLUMN and one DROP CLUSTERING KEY statement"
+    add_column_sql = statements[0]
 
-    assert sql.startswith("ALTER TABLE")
-    assert sql.count("ALTER TABLE") == 1
-    assert sql.count("ADD COLUMN") == 1
-    assert '"EVENT_TEST_TABLE"' in sql
-    assert '"COL1" NUMBER(19,0) NOT NULL' in sql
-    assert '"COL2" FLOAT NOT NULL' in sql
-    assert '"COL3" BOOLEAN NOT NULL' in sql
-    assert '"COL4" TIMESTAMP_TZ NOT NULL' in sql
-    assert '"COL5" VARCHAR' in sql
-    assert '"COL6" NUMBER(38,9) NOT NULL' in sql
-    assert '"COL7" BINARY' in sql
-    assert '"COL8" NUMBER(38,0)' in sql
-    assert '"COL9" VARIANT NOT NULL' in sql
-    assert '"COL10" DATE' in sql
+    # TODO: sqlfluff doesn't parse snowflake multi ADD COLUMN clause correctly
+    # sqlfluff.parse(add_column_sql, dialect='snowflake')
+
+    assert add_column_sql.startswith("ALTER TABLE")
+    assert add_column_sql.count("ALTER TABLE") == 1
+    assert add_column_sql.count("ADD COLUMN") == 1
+    assert '"EVENT_TEST_TABLE"' in add_column_sql
+    assert '"COL1"' not in add_column_sql
+    assert '"COL2" FLOAT NOT NULL' in add_column_sql
+    assert '"COL3" BOOLEAN NOT NULL' in add_column_sql
+    assert '"COL4" TIMESTAMP_TZ NOT NULL' in add_column_sql
+    assert '"COL5" VARCHAR' in add_column_sql
+    assert '"COL6" NUMBER(38,9) NOT NULL' in add_column_sql
+    assert '"COL7" BINARY' in add_column_sql
+    assert '"COL8" NUMBER(38,0)' in add_column_sql
+    assert '"COL9" VARIANT NOT NULL' in add_column_sql
+    assert '"COL10" DATE' in add_column_sql
+
+
+def test_alter_table_with_hints(snowflake_client: SnowflakeClient) -> None:
+    table_name = "event_test_table"
 
-    mod_table = deepcopy(TABLE_UPDATE)
-    mod_table.pop(0)
-    sql = snowflake_client._get_table_update_sql("event_test_table", mod_table, True)[0]
+    # mock hints
+    snowflake_client.active_hints = SUPPORTED_HINTS
 
-    assert '"COL1"' not in sql
-    assert '"COL2" FLOAT NOT NULL' in sql
+    # test primary key and unique hints
+    new_columns = deepcopy(TABLE_UPDATE[11:])
+    new_columns[0]["primary_key"] = True
+    new_columns[1]["unique"] = True
+    statements = snowflake_client._get_table_update_sql(table_name, new_columns, True)
+
+    assert len(statements) == 2, "Should have one ADD COLUMN and one DROP CLUSTERING KEY statement"
+    add_column_sql = statements[0]
+    assert "PRIMARY KEY" not in add_column_sql  # PK constraint ignored for alter
+    assert '"COL2_NULL" FLOAT UNIQUE' in add_column_sql
+
+    # test cluster hint
+
+    # case: drop clustering (always run if no cluster hints present in table schema)
+    cluster_by_sql = statements[1]
+
+    assert cluster_by_sql.startswith("ALTER TABLE")
+    assert f'"{table_name.upper()}"' in cluster_by_sql
+    assert cluster_by_sql.endswith("DROP CLUSTERING KEY")
+
+    # case: add clustering (without clustering -> with clustering)
+    old_columns = deepcopy(TABLE_UPDATE[:1])
+    new_columns = deepcopy(TABLE_UPDATE[1:2])
+    new_columns[0]["cluster"] = True  # COL2
+    all_columns = deepcopy(old_columns + new_columns)
+    snowflake_client.schema.update_table(new_table(table_name, columns=deepcopy(all_columns)))
+    statements = snowflake_client._get_table_update_sql(table_name, new_columns, True)
+
+    assert len(statements) == 2, "Should have one ADD COLUMN and one CLUSTER BY statement"
+    cluster_by_sql = statements[1]
+    assert cluster_by_sql.startswith("ALTER TABLE")
+    assert f'"{table_name.upper()}"' in cluster_by_sql
+    assert 'CLUSTER BY ("COL2")' in cluster_by_sql
+
+    # case: modify clustering (extend cluster columns)
+    old_columns = deepcopy(TABLE_UPDATE[:2])
+    old_columns[1]["cluster"] = True  # COL2
+    new_columns = deepcopy(TABLE_UPDATE[2:5])
+    new_columns[2]["cluster"] = True  # COL5
+    all_columns = deepcopy(old_columns + new_columns)
+    snowflake_client.schema.update_table(new_table(table_name, columns=all_columns))
+    statements = snowflake_client._get_table_update_sql(table_name, new_columns, True)
+
+    assert len(statements) == 2, "Should have one ADD COLUMN and one CLUSTER BY statement"
+    cluster_by_sql = statements[1]
+    assert cluster_by_sql.count("ALTER TABLE") == 1
+    assert cluster_by_sql.count("CLUSTER BY") == 1
+    assert 'CLUSTER BY ("COL2","COL5")' in cluster_by_sql
+
+    # case: modify clustering (reorder cluster columns)
+    old_columns = deepcopy(TABLE_UPDATE[:5])
+    old_columns[1]["cluster"] = True  # COL2
+    old_columns[4]["cluster"] = True  # COL5
+    old_columns[1], old_columns[4] = old_columns[4], old_columns[1]  # swap order
+    new_columns = deepcopy(TABLE_UPDATE[5:6])
+    all_columns = deepcopy(old_columns + new_columns)
+    # cannot change column order in existing table schema, so we drop and recreate
+    snowflake_client.schema.drop_tables([table_name])
+    snowflake_client.schema.update_table(new_table(table_name, columns=all_columns))
+    statements = snowflake_client._get_table_update_sql(table_name, new_columns, True)
+
+    assert len(statements) == 2, "Should have one ADD COLUMN and one CLUSTER BY statement"
+    cluster_by_sql = statements[1]
+    assert 'CLUSTER BY ("COL5","COL2")' in cluster_by_sql  # reordered (COL5 first)
 
 
 def test_create_table_case_sensitive(cs_client: SnowflakeClient) -> None:
@@ -170,9 +228,7 @@ def test_create_table_case_sensitive(cs_client: SnowflakeClient) -> None:
     assert cs_client.sql_client.dataset_name.endswith("staginG")
     assert cs_client.sql_client.staging_dataset_name.endswith("staginG")
     # check tables
-    cs_client.schema.update_table(
-        utils.new_table("event_test_table", columns=deepcopy(TABLE_UPDATE))
-    )
+    cs_client.schema.update_table(new_table("event_test_table", columns=deepcopy(TABLE_UPDATE)))
     sql = cs_client._get_table_update_sql(
         "Event_test_tablE",
         list(cs_client.schema.get_table_columns("Event_test_tablE").values()),
@@ -192,7 +248,9 @@ def test_create_table_with_partition_and_cluster(snowflake_client: SnowflakeClie
     mod_update[3]["partition"] = True
     mod_update[4]["cluster"] = True
     mod_update[1]["cluster"] = True
-    statements = snowflake_client._get_table_update_sql("event_test_table", mod_update, False)
+    table_name = "event_test_table"
+    snowflake_client.schema.update_table(new_table(table_name, columns=deepcopy(mod_update)))
+    statements = snowflake_client._get_table_update_sql(table_name, mod_update, False)
     assert len(statements) == 1
     sql = statements[0]
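
For readers skimming the patch, here is a minimal usage sketch of the `cluster` hint documented above. It mirrors the new `test_snowflake_cluster_hints` pipeline test; the resource, table, and pipeline names are illustrative, and the SQL mentioned in the comments reflects what the reworked `_get_table_update_sql` / `_add_cluster_sql` paths emit.

import dlt

# illustrative resource; "events" is not a name used anywhere in the patch
@dlt.resource(table_name="events")
def events():
    yield [{"c1": 1, "c2": "a"}]

# mark c1 as part of the Snowflake clustering key
events.apply_hints(columns=[{"name": "c1", "cluster": True}])

pipeline = dlt.pipeline(pipeline_name="cluster_demo", destination="snowflake")
pipeline.run(events())
# first run creates the table with a CLUSTER BY ("C1") clause appended to the
# CREATE TABLE statement (the non-alter branch of _add_cluster_sql)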
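The caveat spelled out in the docs change (hint changes on an existing table are applied only together with a schema migration) can be seen by continuing the same sketch:

# move the clustering key to include c2: the hint alone does not alter the table yet
events.apply_hints(columns=[{"name": "c2", "cluster": True}])
pipeline.run(events())  # no new column -> no ALTER, clustering key unchanged

# adding a column triggers an ALTER; the pending cluster hints are then applied
# as a separate ALTER TABLE ... CLUSTER BY (...) statement (or DROP CLUSTERING KEY
# if no column carries the cluster hint anymore)
events.apply_hints(columns=[{"name": "c3", "data_type": "bool"}])
pipeline.run(events())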