Skip to content

Commit 9619002

Browse files
feat: snowflake clustering key modifications (#3365)
* add support for snowflake clustering key modifications
* add cluster column order test case
* update snowflake cluster hint docs
* switch to reading snowflake cluster hints from table schema
1 parent 1e73d67 commit 9619002

File tree

9 files changed

+218
-68
lines changed

9 files changed

+218
-68
lines changed

dlt/destinations/impl/clickhouse/clickhouse.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -271,10 +271,7 @@ def create_load_job(
271271
)
272272

273273
def _get_table_update_sql(
274-
self,
275-
table_name: str,
276-
new_columns: Sequence[TColumnSchema],
277-
generate_alter: bool,
274+
self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool
278275
) -> List[str]:
279276
table = self.prepare_load_table(table_name)
280277
sql = SqlJobClientBase._get_table_update_sql(self, table_name, new_columns, generate_alter)

dlt/destinations/impl/databricks/databricks.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -393,11 +393,7 @@ def _get_constraints_sql(
393393
return ""
394394

395395
def _get_table_update_sql(
396-
self,
397-
table_name: str,
398-
new_columns: Sequence[TColumnSchema],
399-
generate_alter: bool,
400-
separate_alters: bool = False,
396+
self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool
401397
) -> List[str]:
402398
table = self.prepare_load_table(table_name)
403399

dlt/destinations/impl/dremio/dremio.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -129,11 +129,7 @@ def create_load_job(
129129
return job
130130

131131
def _get_table_update_sql(
132-
self,
133-
table_name: str,
134-
new_columns: Sequence[TColumnSchema],
135-
generate_alter: bool,
136-
separate_alters: bool = False,
132+
self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool
137133
) -> List[str]:
138134
sql = super()._get_table_update_sql(table_name, new_columns, generate_alter)
139135

dlt/destinations/impl/ducklake/ducklake.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -134,11 +134,7 @@ def create_load_job(
134134
return job
135135

136136
def _get_table_update_sql(
137-
self,
138-
table_name: str,
139-
new_columns: Sequence[TColumnSchema],
140-
generate_alter: bool,
141-
separate_alters: bool = False,
137+
self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool
142138
) -> List[str]:
143139
sql = super()._get_table_update_sql(table_name, new_columns, generate_alter)
144140

dlt/destinations/impl/snowflake/snowflake.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -196,24 +196,52 @@ def _get_constraints_sql(
196196
return f",\nCONSTRAINT {pk_constraint_name} PRIMARY KEY ({quoted_pk_cols})"
197197
return ""
198198

199-
def _get_table_update_sql(
199+
def _get_cluster_sql(self, cluster_column_names: Sequence[str]) -> str:
200+
if cluster_column_names:
201+
cluster_column_names_str = ",".join(
202+
[self.sql_client.escape_column_name(col) for col in cluster_column_names]
203+
)
204+
return f"CLUSTER BY ({cluster_column_names_str})"
205+
else:
206+
return "DROP CLUSTERING KEY"
207+
208+
def _get_alter_cluster_sql(self, table_name: str, cluster_column_names: Sequence[str]) -> str:
209+
qualified_name = self.sql_client.make_qualified_table_name(table_name)
210+
return self._make_alter_table(qualified_name) + self._get_cluster_sql(cluster_column_names)
211+
212+
def _add_cluster_sql(
200213
self,
214+
sql: List[str],
201215
table_name: str,
202-
new_columns: Sequence[TColumnSchema],
203216
generate_alter: bool,
204-
separate_alters: bool = False,
205217
) -> List[str]:
206-
sql = super()._get_table_update_sql(table_name, new_columns, generate_alter)
218+
"""Adds CLUSTER BY / DROP CLUSTERING KEY clause to SQL statements based on cluster hints.
219+
220+
This method modifies the input `sql` list in place and also returns it.
221+
"""
207222

208-
cluster_list = [
209-
self.sql_client.escape_column_name(c["name"]) for c in new_columns if c.get("cluster")
223+
cluster_column_names = [
224+
c["name"]
225+
for c in self.schema.get_table_columns(table_name).values()
226+
if c.get("cluster")
210227
]
211228

212-
if cluster_list:
213-
sql[0] = sql[0] + "\nCLUSTER BY (" + ",".join(cluster_list) + ")"
229+
if generate_alter:
230+
# altering -> need to issue separate ALTER TABLE statement for cluster operations
231+
stmt = self._get_alter_cluster_sql(table_name, cluster_column_names)
232+
sql.append(stmt)
233+
elif not generate_alter and cluster_column_names:
234+
# creating -> can append CLUSTER BY clause to CREATE TABLE statement
235+
sql[0] = sql[0] + self._get_cluster_sql(cluster_column_names)
214236

215237
return sql
216238

239+
def _get_table_update_sql(
240+
self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool
241+
) -> List[str]:
242+
sql = super()._get_table_update_sql(table_name, new_columns, generate_alter)
243+
return self._add_cluster_sql(sql, table_name, generate_alter)
244+
217245
def _from_db_type(
218246
self, bq_t: str, precision: Optional[int], scale: Optional[int]
219247
) -> TColumnType:

dlt/destinations/job_client_impl.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -695,6 +695,11 @@ def _make_create_table(self, qualified_name: str, table: PreparedTableSchema) ->
695695
not_exists_clause = " IF NOT EXISTS "
696696
return f"CREATE TABLE{not_exists_clause}{qualified_name}"
697697

698+
@staticmethod
699+
def _make_alter_table(qualified_name: str) -> str:
700+
"""Begins ALTER TABLE statement"""
701+
return f"ALTER TABLE {qualified_name}\n"
702+
698703
def _get_table_update_sql(
699704
self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool
700705
) -> List[str]:
@@ -712,7 +717,7 @@ def _get_table_update_sql(
712717
sql += ")"
713718
sql_result.append(sql)
714719
else:
715-
sql_base = f"ALTER TABLE {qualified_name}\n"
720+
sql_base = self._make_alter_table(qualified_name)
716721
add_column_statements = self._make_add_column_sql(new_columns, table)
717722
if self.capabilities.alter_add_multi_column:
718723
column_sql = ",\n"

docs/website/docs/dlt-ecosystem/destinations/snowflake.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ Note that we ignore missing columns `ERROR_ON_COLUMN_COUNT_MISMATCH = FALSE` and
232232

233233
## Supported column hints
234234
Snowflake supports the following [column hints](../../general-usage/schema#tables-and-columns):
235-
* `cluster` - Creates a cluster column(s). Many columns per table are supported and only when a new table is created.
235+
* `cluster` - Makes the column part of the [cluster key](https://docs.snowflake.com/en/user-guide/tables-clustering-keys); it can be applied to multiple columns. The `cluster` columns are added to the cluster key in their order of appearance in the table schema. Changing `cluster` hints after table creation is supported, but the changes are only applied if/when a new column is added.
236236
* `unique` - Creates UNIQUE hint on a Snowflake column, can be added to many columns. ([optional](#additional-destination-options))
237237
* `primary_key` - Creates PRIMARY KEY on selected column(s), may be compound. ([optional](#additional-destination-options))
238238

tests/load/pipeline/test_snowflake_pipeline.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from copy import deepcopy
22
import os
33
import pytest
4+
from typing import cast
45
from pytest_mock import MockerFixture
56

67
import dlt
@@ -315,6 +316,79 @@ def my_resource():
315316
)
316317

317318

319+
@pytest.mark.parametrize(
320+
"destination_config",
321+
destinations_configs(default_sql_configs=True, subset=["snowflake"]),
322+
ids=lambda x: x.name,
323+
)
324+
def test_snowflake_cluster_hints(destination_config: DestinationTestConfiguration) -> None:
325+
from dlt.destinations.impl.snowflake.sql_client import SnowflakeSqlClient
326+
327+
def get_cluster_key(sql_client: SnowflakeSqlClient, table_name: str) -> str:
328+
with sql_client:
329+
_catalog_name, schema_name, table_names = sql_client._get_information_schema_components(
330+
table_name
331+
)
332+
qry = f"""
333+
SELECT CLUSTERING_KEY FROM INFORMATION_SCHEMA.TABLES
334+
WHERE TABLE_SCHEMA = '{schema_name}'
335+
AND TABLE_NAME = '{table_names[0]}'
336+
"""
337+
return sql_client.execute_sql(qry)[0][0]
338+
339+
pipeline = destination_config.setup_pipeline("test_snowflake_cluster_hints", dev_mode=True)
340+
sql_client = cast(SnowflakeSqlClient, pipeline.sql_client())
341+
table_name = "test_snowflake_cluster_hints"
342+
343+
@dlt.resource(table_name=table_name)
344+
def test_data():
345+
return [
346+
{"c1": 1, "c2": "a"},
347+
{"c1": 2, "c2": "b"},
348+
]
349+
350+
# create new table with clustering
351+
test_data.apply_hints(columns=[{"name": "c1", "cluster": True}])
352+
info = pipeline.run(test_data(), **destination_config.run_kwargs)
353+
assert_load_info(info)
354+
assert get_cluster_key(sql_client, table_name) == 'LINEAR("C1")'
355+
356+
# change cluster hints on existing table without adding new column
357+
test_data.apply_hints(columns=[{"name": "c2", "cluster": True}])
358+
info = pipeline.run(test_data(), **destination_config.run_kwargs)
359+
assert_load_info(info)
360+
assert get_cluster_key(sql_client, table_name) == 'LINEAR("C1")' # unchanged (no new column)
361+
362+
# add new column to existing table with pending cluster hints from previous run
363+
test_data.apply_hints(columns=[{"name": "c3", "data_type": "bool"}])
364+
info = pipeline.run(test_data(), **destination_config.run_kwargs)
365+
assert_load_info(info)
366+
assert get_cluster_key(sql_client, table_name) == 'LINEAR("C1","C2")' # updated
367+
368+
# remove clustering from existing table
369+
test_data.apply_hints(
370+
columns=[
371+
{"name": "c1", "cluster": False},
372+
{"name": "c2", "cluster": False},
373+
{"name": "c4", "data_type": "bool"}, # include new column to trigger alter
374+
]
375+
)
376+
info = pipeline.run(test_data(), **destination_config.run_kwargs)
377+
assert_load_info(info)
378+
assert get_cluster_key(sql_client, table_name) is None
379+
380+
# add clustering to existing table (and add new column to trigger alter)
381+
test_data.apply_hints(
382+
columns=[
383+
{"name": "c1", "cluster": True},
384+
{"name": "c5", "data_type": "bool"}, # include new column to trigger alter
385+
]
386+
)
387+
info = pipeline.run(test_data(), **destination_config.run_kwargs)
388+
assert_load_info(info)
389+
assert get_cluster_key(sql_client, table_name) == 'LINEAR("C1")'
390+
391+
318392
@pytest.mark.skip(reason="perf test for merge")
319393
@pytest.mark.parametrize(
320394
"destination_config",

tests/load/snowflake/test_snowflake_table_builder.py

Lines changed: 97 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
import sqlfluff
99

1010
from dlt.common.utils import uniq_id
11-
from dlt.common.schema import Schema, utils
11+
from dlt.common.schema import Schema
12+
from dlt.common.schema.utils import new_table
1213
from dlt.destinations import snowflake
1314
from dlt.destinations.impl.snowflake.snowflake import SnowflakeClient, SUPPORTED_HINTS
1415
from dlt.destinations.impl.snowflake.configuration import (
@@ -119,46 +120,103 @@ def test_create_table_with_hints(snowflake_client: SnowflakeClient) -> None:
119120
assert 'CONSTRAINT "PK_EVENT_TEST_TABLE_' in sql
120121
assert 'PRIMARY KEY ("COL1", "COL6")' in sql
121122

122-
# generate alter
123-
mod_update = deepcopy(TABLE_UPDATE[11:])
124-
mod_update[0]["primary_key"] = True
125-
mod_update[1]["unique"] = True
126-
127-
sql = ";".join(snowflake_client._get_table_update_sql("event_test_table", mod_update, True))
128-
# PK constraint ignored for alter
129-
assert "PRIMARY KEY" not in sql
130-
assert '"COL2_NULL" FLOAT UNIQUE' in sql
131-
132123

133124
def test_alter_table(snowflake_client: SnowflakeClient) -> None:
134-
statements = snowflake_client._get_table_update_sql("event_test_table", TABLE_UPDATE, True)
135-
assert len(statements) == 1
136-
sql = statements[0]
125+
new_columns = deepcopy(TABLE_UPDATE[1:10])
126+
statements = snowflake_client._get_table_update_sql("event_test_table", new_columns, True)
137127

138-
# TODO: sqlfluff doesn't parse snowflake multi ADD COLUMN clause correctly
139-
# sqlfluff.parse(sql, dialect='snowflake')
128+
assert len(statements) == 2, "Should have one ADD COLUMN and one DROP CLUSTERING KEY statement"
129+
add_column_sql = statements[0]
140130

141-
assert sql.startswith("ALTER TABLE")
142-
assert sql.count("ALTER TABLE") == 1
143-
assert sql.count("ADD COLUMN") == 1
144-
assert '"EVENT_TEST_TABLE"' in sql
145-
assert '"COL1" NUMBER(19,0) NOT NULL' in sql
146-
assert '"COL2" FLOAT NOT NULL' in sql
147-
assert '"COL3" BOOLEAN NOT NULL' in sql
148-
assert '"COL4" TIMESTAMP_TZ NOT NULL' in sql
149-
assert '"COL5" VARCHAR' in sql
150-
assert '"COL6" NUMBER(38,9) NOT NULL' in sql
151-
assert '"COL7" BINARY' in sql
152-
assert '"COL8" NUMBER(38,0)' in sql
153-
assert '"COL9" VARIANT NOT NULL' in sql
154-
assert '"COL10" DATE' in sql
131+
# TODO: sqlfluff doesn't parse snowflake multi ADD COLUMN clause correctly
132+
# sqlfluff.parse(add_column_sql, dialect='snowflake')
133+
134+
assert add_column_sql.startswith("ALTER TABLE")
135+
assert add_column_sql.count("ALTER TABLE") == 1
136+
assert add_column_sql.count("ADD COLUMN") == 1
137+
assert '"EVENT_TEST_TABLE"' in add_column_sql
138+
assert '"COL1"' not in add_column_sql
139+
assert '"COL2" FLOAT NOT NULL' in add_column_sql
140+
assert '"COL3" BOOLEAN NOT NULL' in add_column_sql
141+
assert '"COL4" TIMESTAMP_TZ NOT NULL' in add_column_sql
142+
assert '"COL5" VARCHAR' in add_column_sql
143+
assert '"COL6" NUMBER(38,9) NOT NULL' in add_column_sql
144+
assert '"COL7" BINARY' in add_column_sql
145+
assert '"COL8" NUMBER(38,0)' in add_column_sql
146+
assert '"COL9" VARIANT NOT NULL' in add_column_sql
147+
assert '"COL10" DATE' in add_column_sql
148+
149+
150+
def test_alter_table_with_hints(snowflake_client: SnowflakeClient) -> None:
151+
table_name = "event_test_table"
155152

156-
mod_table = deepcopy(TABLE_UPDATE)
157-
mod_table.pop(0)
158-
sql = snowflake_client._get_table_update_sql("event_test_table", mod_table, True)[0]
153+
# mock hints
154+
snowflake_client.active_hints = SUPPORTED_HINTS
159155

160-
assert '"COL1"' not in sql
161-
assert '"COL2" FLOAT NOT NULL' in sql
156+
# test primary key and unique hints
157+
new_columns = deepcopy(TABLE_UPDATE[11:])
158+
new_columns[0]["primary_key"] = True
159+
new_columns[1]["unique"] = True
160+
statements = snowflake_client._get_table_update_sql(table_name, new_columns, True)
161+
162+
assert len(statements) == 2, "Should have one ADD COLUMN and one DROP CLUSTERING KEY statement"
163+
add_column_sql = statements[0]
164+
assert "PRIMARY KEY" not in add_column_sql # PK constraint ignored for alter
165+
assert '"COL2_NULL" FLOAT UNIQUE' in add_column_sql
166+
167+
# test cluster hint
168+
169+
# case: drop clustering (always run if no cluster hints present in table schema)
170+
cluster_by_sql = statements[1]
171+
172+
assert cluster_by_sql.startswith("ALTER TABLE")
173+
assert f'"{table_name.upper()}"' in cluster_by_sql
174+
assert cluster_by_sql.endswith("DROP CLUSTERING KEY")
175+
176+
# case: add clustering (without clustering -> with clustering)
177+
old_columns = deepcopy(TABLE_UPDATE[:1])
178+
new_columns = deepcopy(TABLE_UPDATE[1:2])
179+
new_columns[0]["cluster"] = True # COL2
180+
all_columns = deepcopy(old_columns + new_columns)
181+
snowflake_client.schema.update_table(new_table(table_name, columns=deepcopy(all_columns)))
182+
statements = snowflake_client._get_table_update_sql(table_name, new_columns, True)
183+
184+
assert len(statements) == 2, "Should have one ADD COLUMN and one CLUSTER BY statement"
185+
cluster_by_sql = statements[1]
186+
assert cluster_by_sql.startswith("ALTER TABLE")
187+
assert f'"{table_name.upper()}"' in cluster_by_sql
188+
assert 'CLUSTER BY ("COL2")' in cluster_by_sql
189+
190+
# case: modify clustering (extend cluster columns)
191+
old_columns = deepcopy(TABLE_UPDATE[:2])
192+
old_columns[1]["cluster"] = True # COL2
193+
new_columns = deepcopy(TABLE_UPDATE[2:5])
194+
new_columns[2]["cluster"] = True # COL5
195+
all_columns = deepcopy(old_columns + new_columns)
196+
snowflake_client.schema.update_table(new_table(table_name, columns=all_columns))
197+
statements = snowflake_client._get_table_update_sql(table_name, new_columns, True)
198+
199+
assert len(statements) == 2, "Should have one ADD COLUMN and one CLUSTER BY statement"
200+
cluster_by_sql = statements[1]
201+
assert cluster_by_sql.count("ALTER TABLE") == 1
202+
assert cluster_by_sql.count("CLUSTER BY") == 1
203+
assert 'CLUSTER BY ("COL2","COL5")' in cluster_by_sql
204+
205+
# case: modify clustering (reorder cluster columns)
206+
old_columns = deepcopy(TABLE_UPDATE[:5])
207+
old_columns[1]["cluster"] = True # COL2
208+
old_columns[4]["cluster"] = True # COL5
209+
old_columns[1], old_columns[4] = old_columns[4], old_columns[1] # swap order
210+
new_columns = deepcopy(TABLE_UPDATE[5:6])
211+
all_columns = deepcopy(old_columns + new_columns)
212+
# cannot change column order in existing table schema, so we drop and recreate
213+
snowflake_client.schema.drop_tables([table_name])
214+
snowflake_client.schema.update_table(new_table(table_name, columns=all_columns))
215+
statements = snowflake_client._get_table_update_sql(table_name, new_columns, True)
216+
217+
assert len(statements) == 2, "Should have one ADD COLUMN and one CLUSTER BY statement"
218+
cluster_by_sql = statements[1]
219+
assert 'CLUSTER BY ("COL5","COL2")' in cluster_by_sql # reordered (COL5 first)
162220

163221

164222
def test_create_table_case_sensitive(cs_client: SnowflakeClient) -> None:
@@ -170,9 +228,7 @@ def test_create_table_case_sensitive(cs_client: SnowflakeClient) -> None:
170228
assert cs_client.sql_client.dataset_name.endswith("staginG")
171229
assert cs_client.sql_client.staging_dataset_name.endswith("staginG")
172230
# check tables
173-
cs_client.schema.update_table(
174-
utils.new_table("event_test_table", columns=deepcopy(TABLE_UPDATE))
175-
)
231+
cs_client.schema.update_table(new_table("event_test_table", columns=deepcopy(TABLE_UPDATE)))
176232
sql = cs_client._get_table_update_sql(
177233
"Event_test_tablE",
178234
list(cs_client.schema.get_table_columns("Event_test_tablE").values()),
@@ -192,7 +248,9 @@ def test_create_table_with_partition_and_cluster(snowflake_client: SnowflakeClie
192248
mod_update[3]["partition"] = True
193249
mod_update[4]["cluster"] = True
194250
mod_update[1]["cluster"] = True
195-
statements = snowflake_client._get_table_update_sql("event_test_table", mod_update, False)
251+
table_name = "event_test_table"
252+
snowflake_client.schema.update_table(new_table(table_name, columns=deepcopy(mod_update)))
253+
statements = snowflake_client._get_table_update_sql(table_name, mod_update, False)
196254
assert len(statements) == 1
197255
sql = statements[0]
198256

0 commit comments

Comments (0)