Skip to content

feat: set query label session property in bq session #4314

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion sqlmesh/core/engine_adapter/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,33 @@ def query_factory() -> Query:
def _begin_session(self, properties: SessionProperties) -> None:
from google.cloud.bigquery import QueryJobConfig

job = self.client.query("SELECT 1;", job_config=QueryJobConfig(create_session=True))
query_label_property = properties.get("query_label")
parsed_query_label = []
if query_label_property and isinstance(
query_label_property, (exp.Array, exp.Paren, exp.Tuple)
):
label_tuples = (
[query_label_property.unnest()]
if isinstance(query_label_property, exp.Paren)
else query_label_property.expressions
)

# query_label is a Paren, Array or Tuple of 2-tuples and validated at load time
for label_tuple in label_tuples:
parsed_query_label.append(
(label_tuple.expressions[0].this, label_tuple.expressions[1].this)
)

if parsed_query_label:
query_label_str = ",".join([":".join(label) for label in parsed_query_label])
query = f'SET @@query_label = "{query_label_str}";SELECT 1;'
else:
query = "SELECT 1;"

job = self.client.query(
query,
job_config=QueryJobConfig(create_session=True),
)
session_info = job.session_info
session_id = session_info.session_id if session_info else None
self._session_id = session_id
Expand Down
61 changes: 61 additions & 0 deletions sqlmesh/core/model/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -2144,6 +2144,9 @@ def load_sql_based_model(
path,
)

# Validate virtual, physical and session properties
_validate_model_properties(meta_fields, path)

common_kwargs = dict(
pre_statements=pre_statements,
post_statements=post_statements,
Expand Down Expand Up @@ -2634,6 +2637,64 @@ def _validate_model_fields(klass: t.Type[_Model], provided_fields: t.Set[str], p
raise_config_error(f"Invalid extra fields {extra_fields} in the model definition", path)


def _validate_model_properties(meta_fields: dict[str, t.Any], path: Path) -> None:
# store for later validation of specific properties
model_properties: dict[str, t.Any] = {}

# validate that all properties kinds are key-value mappings
for kind in PROPERTIES:
if kind in meta_fields:
kind_properties = meta_fields[kind]
model_properties[kind] = {}

if not isinstance(kind_properties, (exp.Array, exp.Paren, exp.Tuple)):
raise_config_error(
f"Invalid MODEL statement: `{kind}` must be a tuple or array of key-value mappings.",
path,
)

key_value_mappings: t.List[exp.Expression] = (
[kind_properties.unnest()]
if isinstance(kind_properties, exp.Paren)
else kind_properties.expressions
)

for expression in key_value_mappings:
if isinstance(expression, exp.EQ):
model_properties[kind][expression.left.name] = expression.right
else:
raise_config_error(
f"Invalid MODEL statement: all expressions in `{kind}` must be key-value pairs.",
path,
)

# The query_label can be attached to the actual queries at execution time and is expected to be a sequence of 2-tuples
if (
"session_properties" in model_properties
and "query_label" in model_properties["session_properties"]
):
query_label_property = model_properties["session_properties"]["query_label"]
if not (
isinstance(query_label_property, exp.Array)
or isinstance(query_label_property, exp.Tuple)
or isinstance(query_label_property, exp.Paren)
):
raise_config_error(
"Invalid MODEL statement: `session_properties.query_label` must be an array or tuple.",
path,
)
for label_tuple in query_label_property.expressions:
if not (
isinstance(label_tuple, exp.Tuple)
and len(label_tuple.expressions) == 2
and all(isinstance(label, exp.Literal) for label in label_tuple.expressions)
):
raise_config_error(
"Invalid MODEL statement: expressions inside `session_properties.query_label` must be tuples of string literals with length 2.",
path,
)


def _list_of_calls_to_exp(value: t.List[t.Tuple[str, t.Dict[str, t.Any]]]) -> exp.Expression:
return exp.Tuple(
expressions=[
Expand Down
43 changes: 43 additions & 0 deletions tests/core/engine_adapter/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,7 @@ def test_begin_end_session(mocker: MockerFixture):

adapter = BigQueryEngineAdapter(lambda: connection_mock, job_retries=0)

# starting a session without session properties
with adapter.session({}):
assert adapter._connection_pool.get_attribute("session_id") is not None
adapter.execute("SELECT 2;")
Expand All @@ -551,6 +552,48 @@ def test_begin_end_session(mocker: MockerFixture):
assert execute_b_call[1]["query"] == "SELECT 3;"
assert not execute_b_call[1]["job_config"].connection_properties

# starting a new session with session property query_label and array value
with adapter.session(
{
"query_label": exp.Array(
expressions=[
exp.Tuple(
expressions=[
exp.Literal.string("key1"),
exp.Literal.string("value1"),
]
),
exp.Tuple(
expressions=[
exp.Literal.string("key2"),
exp.Literal.string("value2"),
]
),
]
)
}
):
adapter.execute("SELECT 4;")
begin_new_session_call = connection_mock._client.query.call_args_list[3]
assert begin_new_session_call[0][0] == 'SET @@query_label = "key1:value1,key2:value2";SELECT 1;'

# starting a new session with session property query_label and Paren value
with adapter.session(
{
"query_label": exp.Paren(
this=exp.Tuple(
expressions=[
exp.Literal.string("key1"),
exp.Literal.string("value1"),
]
)
)
}
):
adapter.execute("SELECT 5;")
begin_new_session_call = connection_mock._client.query.call_args_list[5]
assert begin_new_session_call[0][0] == 'SET @@query_label = "key1:value1";SELECT 1;'


def _to_sql_calls(execute_mock: t.Any, identify: bool = True) -> t.List[str]:
output = []
Expand Down
62 changes: 61 additions & 1 deletion tests/core/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3587,7 +3587,9 @@ def my_model(context, **kwargs):
"""('key_a' = 'value_a', 'key_b' = 1, 'key_c' = TRUE, 'key_d' = 2.0)"""
)

with pytest.raises(ConfigError, match=r"Invalid property 'invalid'.*"):
with pytest.raises(
ConfigError, match=r"Invalid MODEL statement: all expressions in `physical_properties`.*"
):
load_sql_based_model(
d.parse(
"""
Expand Down Expand Up @@ -4248,6 +4250,64 @@ def test_model_session_properties(sushi_context):
"warehouse": "test_warehouse",
}

model = load_sql_based_model(
d.parse(
"""
MODEL (
name test_schema.test_model,
session_properties (
'query_label' = [
('key1', 'value1'),
('key2', 'value2')
]
)
);
SELECT a FROM tbl;
""",
default_dialect="bigquery",
)
)
assert model.session_properties == {
"query_label": exp.Array(
expressions=[
exp.Tuple(
expressions=[
exp.Literal.string("key1"),
exp.Literal.string("value1"),
]
),
exp.Tuple(
expressions=[
exp.Literal.string("key2"),
exp.Literal.string("value2"),
]
),
]
)
}

model = load_sql_based_model(
d.parse(
"""
MODEL (
name test_schema.test_model,
session_properties (
'query_label' = (
('key1', 'value1')
)
)
);
SELECT a FROM tbl;
""",
default_dialect="bigquery",
)
)
assert model.session_properties == {
"query_label": exp.Paren(
this=exp.Tuple(expressions=[exp.Literal.string("key1"), exp.Literal.string("value1")])
)
}


def test_model_jinja_macro_rendering():
expressions = d.parse(
Expand Down