Skip to content

Commit 0f4ddc6

Browse files
committed
Feat(experimental): DBT project conversion
1 parent ec92d47 commit 0f4ddc6

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+3425
-39
lines changed

sqlmesh/cli/main.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
"environments",
3131
"invalidate",
3232
"table_name",
33+
"dbt",
3334
)
3435
SKIP_CONTEXT_COMMANDS = ("init", "ui")
3536

@@ -1203,3 +1204,48 @@ def state_import(obj: Context, input_file: Path, replace: bool, no_confirm: bool
12031204
"""Import a state export file back into the state database"""
12041205
confirm = not no_confirm
12051206
obj.import_state(input_file=input_file, clear=replace, confirm=confirm)
1207+
1208+
1209+
@cli.group(no_args_is_help=True, hidden=True)
def dbt() -> None:
    """Commands for doing dbt-specific things"""
    # Pure container group: the subcommands registered on it do the work.
1213+
1214+
1215+
@dbt.command("convert")
@click.option(
    "-i",
    "--input-dir",
    required=True,
    type=click.Path(exists=True, file_okay=False, dir_okay=True, readable=True, path_type=Path),
    help="Path to the DBT project",
)
@click.option(
    "-o",
    "--output-dir",
    required=True,
    type=click.Path(exists=False, file_okay=False, dir_okay=True, readable=True, path_type=Path),
    help="Path to write out the converted SQLMesh project",
)
@click.option(
    "--external-models-file/--no-external-models-file",
    is_flag=True,
    default=True,
    help="Generate external_models.yaml (requires connectivity to the data warehouse)",
)
@click.option("--no-prompts", is_flag=True, default=False, help="Disable interactive prompts")
@click.pass_obj
@error_handler
@cli_analytics
def dbt_convert(
    obj: Context, input_dir: Path, output_dir: Path, external_models_file: bool, no_prompts: bool
) -> None:
    """Convert a DBT project to a SQLMesh project"""
    # Function-scope import: the converter module is only pulled in when this command runs.
    from sqlmesh.dbt.converter.convert import convert_project_files

    # Both directories are handed to the converter as absolute paths.
    source_root = input_dir.absolute()
    target_root = output_dir.absolute()

    convert_project_files(
        source_root,
        target_root,
        output_external_models=external_models_file,
        no_prompts=no_prompts,
    )

sqlmesh/core/constants.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@
3131
MAX_MODEL_DEFINITION_SIZE = 10000
3232
"""Maximum number of characters in a model definition"""
3333

34+
MIGRATED_DBT_PROJECT_NAME = "__dbt_project_name__"
35+
MIGRATED_DBT_PACKAGES = "__dbt_packages__"
36+
3437

3538
# The maximum number of fork processes, used for loading projects
3639
# None means default to process pool, 1 means don't fork, :N is number of processes

sqlmesh/core/loader.py

Lines changed: 73 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,12 @@
3838
from sqlmesh.core.test import ModelTestMetadata, filter_tests_by_patterns
3939
from sqlmesh.utils import UniqueKeyDict, sys_path
4040
from sqlmesh.utils.errors import ConfigError
41-
from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroExtractor
41+
from sqlmesh.utils.jinja import (
42+
JinjaMacroRegistry,
43+
MacroExtractor,
44+
SQLMESH_JINJA_PACKAGE,
45+
SQLMESH_DBT_COMPATIBILITY_PACKAGE,
46+
)
4247
from sqlmesh.utils.metaprogramming import import_python_file
4348
from sqlmesh.utils.pydantic import validation_error_message
4449
from sqlmesh.utils.yaml import YAML, load as yaml_load
@@ -384,15 +389,42 @@ def _raise_failed_to_load_model_error(self, path: Path, error: t.Union[str, Exce
384389
class SqlMeshLoader(Loader):
385390
"""Loads macros and models for a context using the SQLMesh file formats"""
386391

392+
@property
def is_migrated_dbt_project(self) -> bool:
    """Whether a migrated DBT project name is available for this loader."""
    project_name = self.migrated_dbt_project_name
    return project_name is not None
395+
396+
@property
def migrated_dbt_project_name(self) -> t.Optional[str]:
    """The config variable stored under ``c.MIGRATED_DBT_PROJECT_NAME``, or None if it is not set."""
    config_variables = self.config.variables
    return config_variables.get(c.MIGRATED_DBT_PROJECT_NAME)
399+
387400
def _load_scripts(self) -> t.Tuple[MacroRegistry, JinjaMacroRegistry]:
388401
"""Loads all user defined macros."""
402+
403+
create_builtin_globals_module = (
404+
SQLMESH_DBT_COMPATIBILITY_PACKAGE
405+
if self.is_migrated_dbt_project
406+
else SQLMESH_JINJA_PACKAGE
407+
)
408+
389409
# Store a copy of the macro registry
390410
standard_macros = macro.get_registry()
391-
jinja_macros = JinjaMacroRegistry()
411+
412+
top_level_packages = []
413+
if self.is_migrated_dbt_project:
414+
top_level_packages = ["dbt"]
415+
if self.migrated_dbt_project_name:
416+
top_level_packages.append(self.migrated_dbt_project_name)
417+
418+
jinja_macros = JinjaMacroRegistry(
419+
create_builtins_module=create_builtin_globals_module,
420+
top_level_packages=top_level_packages,
421+
)
392422
extractor = MacroExtractor()
393423

394424
macros_max_mtime: t.Optional[float] = None
395425

426+
migrated_dbt_package_base_path = self.config_path / c.MACROS / c.MIGRATED_DBT_PACKAGES
427+
396428
for path in self._glob_paths(
397429
self.config_path / c.MACROS,
398430
ignore_patterns=self.config.ignore_patterns,
@@ -417,16 +449,51 @@ def _load_scripts(self) -> t.Tuple[MacroRegistry, JinjaMacroRegistry]:
417449
macros_max_mtime = (
418450
max(macros_max_mtime, macro_file_mtime) if macros_max_mtime else macro_file_mtime
419451
)
452+
420453
with open(path, "r", encoding="utf-8") as file:
421-
jinja_macros.add_macros(
422-
extractor.extract(file.read(), dialect=self.config.model_defaults.dialect)
423-
)
454+
try:
455+
package: t.Optional[str] = None
456+
if self.is_migrated_dbt_project:
457+
if path.is_relative_to(migrated_dbt_package_base_path):
458+
package = str(
459+
path.relative_to(migrated_dbt_package_base_path).parents[0]
460+
)
461+
else:
462+
package = self.migrated_dbt_project_name
463+
464+
jinja_macros.add_macros(
465+
extractor.extract(file.read(), dialect=self.config.model_defaults.dialect),
466+
package=package,
467+
)
468+
except:
469+
logger.error(f"Unable to read macro file: {path}")
470+
raise
424471

425472
self._macros_max_mtime = macros_max_mtime
426473

427474
macros = macro.get_registry()
428475
macro.set_registry(standard_macros)
429476

477+
if self.is_migrated_dbt_project:
478+
from sqlmesh.dbt.target import TARGET_TYPE_TO_CONFIG_CLASS
479+
480+
connection_config = self.context._connection_config
481+
# this triggers the DBT create_builtins_module to have a `target` property which is required for a bunch of DBT macros to work
482+
if dbt_config_type := TARGET_TYPE_TO_CONFIG_CLASS.get(connection_config.type_):
483+
try:
484+
jinja_macros.add_globals(
485+
{
486+
"target": dbt_config_type.from_sqlmesh(
487+
self.context._connection_config,
488+
name=self.config.default_gateway_name,
489+
).attribute_dict()
490+
}
491+
)
492+
except NotImplementedError:
493+
raise ConfigError(
494+
f"No DBT 'Target Type' mapping for connection type: {connection_config.type_}"
495+
)
496+
430497
return macros, jinja_macros
431498

432499
def _load_models(
@@ -499,6 +566,7 @@ def _load() -> t.List[Model]:
499566
infer_names=self.config.model_naming.infer_names,
500567
signal_definitions=signals,
501568
default_catalog_per_gateway=self.context.default_catalog_per_gateway,
569+
migrated_dbt_project_name=self.migrated_dbt_project_name,
502570
)
503571
except Exception as ex:
504572
self._raise_failed_to_load_model_error(path, ex)

sqlmesh/core/model/definition.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2015,6 +2015,7 @@ def load_sql_based_model(
20152015
variables: t.Optional[t.Dict[str, t.Any]] = None,
20162016
infer_names: t.Optional[bool] = False,
20172017
blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
2018+
migrated_dbt_project_name: t.Optional[str] = None,
20182019
**kwargs: t.Any,
20192020
) -> Model:
20202021
"""Load a model from a parsed SQLMesh model SQL file.
@@ -2186,6 +2187,7 @@ def load_sql_based_model(
21862187
name,
21872188
query_or_seed_insert,
21882189
time_column_format=time_column_format,
2190+
migrated_dbt_project_name=migrated_dbt_project_name,
21892191
**common_kwargs,
21902192
)
21912193
seed_properties = {
@@ -2393,6 +2395,7 @@ def _create_model(
23932395
signal_definitions: t.Optional[SignalRegistry] = None,
23942396
variables: t.Optional[t.Dict[str, t.Any]] = None,
23952397
blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
2398+
migrated_dbt_project_name: t.Optional[str] = None,
23962399
**kwargs: t.Any,
23972400
) -> Model:
23982401
validate_extra_and_required_fields(
@@ -2448,13 +2451,42 @@ def _create_model(
24482451

24492452
if jinja_macros:
24502453
jinja_macros = (
2451-
jinja_macros if jinja_macros.trimmed else jinja_macros.trim(jinja_macro_references)
2454+
jinja_macros
2455+
if jinja_macros.trimmed
2456+
else jinja_macros.trim(jinja_macro_references, package=migrated_dbt_project_name)
24522457
)
24532458
else:
24542459
jinja_macros = JinjaMacroRegistry()
24552460

2456-
for jinja_macro in jinja_macros.root_macros.values():
2457-
used_variables.update(extract_macro_references_and_variables(jinja_macro.definition)[1])
2461+
# extract {{ var() }} references used in all jinja macro dependencies to check for any variables specific
2462+
# to a migrated DBT package and resolve them accordingly
2463+
# vars are added into __sqlmesh_vars__ in the Python env so that the native SQLMesh var() function can resolve them
2464+
if migrated_dbt_project_name:
2465+
# note: JinjaMacroRegistry is trimmed here so "all_macros" should just be all the macros used by this model
2466+
for _, _, jinja_macro in jinja_macros.all_macros:
2467+
_, extracted_variable_names = extract_macro_references_and_variables(
2468+
jinja_macro.definition
2469+
)
2470+
used_variables.update(extracted_variable_names)
2471+
2472+
variables = variables or {}
2473+
if (dbt_package_variables := variables.get(c.MIGRATED_DBT_PACKAGES)) and isinstance(
2474+
dbt_package_variables, dict
2475+
):
2476+
# flatten the nested dict structure from the migrated dbt package variables in the SQLMesh config into __dbt_packages__.<package>.<variable>
2477+
# to match what extract_macro_references_and_variables() returns. This allows the usage checks in create_python_env() to work
2478+
def _flatten(prefix: str, root: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
2479+
acc = {}
2480+
for k, v in root.items():
2481+
key_with_prefix = f"{prefix}.{k}"
2482+
if isinstance(v, dict):
2483+
acc.update(_flatten(key_with_prefix, v))
2484+
else:
2485+
acc[key_with_prefix] = v
2486+
return acc
2487+
2488+
flattened = _flatten(c.MIGRATED_DBT_PACKAGES, dbt_package_variables)
2489+
variables.update(flattened)
24582490

24592491
model = klass(
24602492
name=name,

sqlmesh/core/model/kind.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,18 @@ def _merge_filter_validator(
491491

492492
return v.transform(d.replace_merge_table_aliases)
493493

494+
@field_validator("batch_concurrency", mode="before")
def _batch_concurrency_validator(
    cls, v: t.Optional[exp.Expression], info: ValidationInfo
) -> int:
    """Coerce the raw ``batch_concurrency`` value to an int before field validation.

    A SQLGlot literal is converted via ``to_py()``, an int is passed through
    unchanged, and any other value falls back to 1.
    """
    # allow batch_concurrency = 1 to be specified in a Model definition without throwing a Pydantic error
    if isinstance(v, exp.Literal):
        literal_value = v.to_py()
        return int(literal_value)

    return v if isinstance(v, int) else 1
505+
494506
@property
495507
def data_hash_values(self) -> t.List[t.Optional[str]]:
496508
return [

sqlmesh/core/renderer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@ def _resolve_table(table: str | exp.Table) -> str:
178178
)
179179

180180
render_kwargs = {
181+
"dialect": self._dialect,
181182
**date_dict(
182183
to_datetime(execution_time or c.EPOCH),
183184
start_time,

sqlmesh/dbt/adapter.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ def __init__(
3838
self.jinja_globals = jinja_globals.copy() if jinja_globals else {}
3939
self.jinja_globals["adapter"] = self
4040
self.project_dialect = project_dialect
41+
self.jinja_globals["dialect"] = (
42+
project_dialect # so the dialect is available in the jinja env created by self.dispatch()
43+
)
4144
self.quote_policy = quote_policy or Policy()
4245

4346
@abc.abstractmethod

sqlmesh/dbt/builtin.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ class Var:
156156
def __init__(self, variables: t.Dict[str, t.Any]) -> None:
157157
self.variables = variables
158158

159-
def __call__(self, name: str, default: t.Optional[t.Any] = None) -> t.Any:
159+
def __call__(self, name: str, default: t.Optional[t.Any] = None, **kwargs: t.Any) -> t.Any:
160160
return self.variables.get(name, default)
161161

162162
def has_var(self, name: str) -> bool:

sqlmesh/dbt/converter/__init__.py

Whitespace-only changes.

sqlmesh/dbt/converter/common.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from __future__ import annotations
2+
import jinja2.nodes as j
3+
from sqlglot import exp
4+
import typing as t
5+
import sqlmesh.core.constants as c
6+
from pathlib import Path
7+
8+
9+
# jinja transform is a function that takes (current node, previous node, parent node) and returns a new Node or None
10+
# returning None means the current node is removed from the tree
11+
# returning a different Node means the current node is replaced with the new Node
12+
JinjaTransform = t.Callable[[j.Node, t.Optional[j.Node], t.Optional[j.Node]], t.Optional[j.Node]]
13+
SQLGlotTransform = t.Callable[[exp.Expression], t.Optional[exp.Expression]]
14+
15+
16+
def _sqlmesh_predefined_macro_variables() -> t.Set[str]:
    """Build the set of predefined macro variable names.

    The result combines every ``<prefix>_<suffix>`` temporal name with a fixed
    list of standalone names.
    """
    temporal_prefixes = ("start", "end", "execution")
    temporal_suffixes = ("dt", "date", "ds", "ts", "tstz", "hour", "epoch", "millis")
    standalone_names = {"runtime_stage", "gateway", "this_model", "this_env", "model_kind_name"}

    temporal_names = {
        f"{prefix}_{suffix}" for suffix in temporal_suffixes for prefix in temporal_prefixes
    }
    return temporal_names | standalone_names


SQLMESH_PREDEFINED_MACRO_VARIABLES = _sqlmesh_predefined_macro_variables()
29+
30+
31+
def infer_dbt_package_from_path(path: Path) -> t.Optional[str]:
    """
    Given a path like "sqlmesh-project/macros/__dbt_packages__/foo/bar.sql"

    Infer that 'foo' is the DBT package

    Returns None when the path has no `__dbt_packages__` component, or when that
    component is the last one (so no package name follows it).
    """
    parts = path.parts
    try:
        marker_idx = parts.index(c.MIGRATED_DBT_PACKAGES)
    except ValueError:
        # the path does not pass through the migrated packages directory
        return None

    package_idx = marker_idx + 1
    # guard against the path being the packages directory itself, which previously raised IndexError
    if package_idx >= len(parts):
        return None
    return parts[package_idx]

0 commit comments

Comments
 (0)