
Feat(experimental): DBT project conversion #4495

Open · wants to merge 1 commit into base: main
46 changes: 46 additions & 0 deletions sqlmesh/cli/main.py
@@ -30,6 +30,7 @@
"environments",
"invalidate",
"table_name",
"dbt",
)
SKIP_CONTEXT_COMMANDS = ("init", "ui")

@@ -1203,3 +1204,48 @@ def state_import(obj: Context, input_file: Path, replace: bool, no_confirm: bool
"""Import a state export file back into the state database"""
confirm = not no_confirm
obj.import_state(input_file=input_file, clear=replace, confirm=confirm)


@cli.group(no_args_is_help=True, hidden=True)
def dbt() -> None:
"""Commands for doing dbt-specific things"""
pass


@dbt.command("convert")
@click.option(
"-i",
"--input-dir",
help="Path to the DBT project",
required=True,
type=click.Path(exists=True, dir_okay=True, file_okay=False, readable=True, path_type=Path),
)
@click.option(
"-o",
"--output-dir",
required=True,
help="Path to write out the converted SQLMesh project",
type=click.Path(exists=False, dir_okay=True, file_okay=False, readable=True, path_type=Path),
)
@click.option(
"--external-models-file/--no-external-models-file",
is_flag=True,
default=True,
help="Generate external_models.yaml (requires connectivity to the data warehouse)",
)
@click.option("--no-prompts", is_flag=True, help="Disable interactive prompts", default=False)
@click.pass_obj
@error_handler
@cli_analytics
def dbt_convert(
Review comment (Member): Should we instead extend the init command like we do for dlt generation?

obj: Context, input_dir: Path, output_dir: Path, external_models_file: bool, no_prompts: bool
) -> None:
"""Convert a DBT project to a SQLMesh project"""
from sqlmesh.dbt.converter.convert import convert_project_files

convert_project_files(
input_dir.absolute(),
output_dir.absolute(),
output_external_models=external_models_file,
no_prompts=no_prompts,
)
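
For reference, a minimal sketch of driving the new command through click's test runner; the project paths here are placeholders and must point at a real dbt project for the `exists=True` check to pass:

from click.testing import CliRunner

from sqlmesh.cli.main import cli

runner = CliRunner()
result = runner.invoke(
    cli,
    # "my_dbt_project" and "sqlmesh_output" are illustrative paths
    ["dbt", "convert", "-i", "my_dbt_project", "-o", "sqlmesh_output", "--no-prompts"],
)
print(result.exit_code, result.output)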
3 changes: 3 additions & 0 deletions sqlmesh/core/constants.py
@@ -31,6 +31,9 @@
MAX_MODEL_DEFINITION_SIZE = 10000
"""Maximum number of characters in a model definition"""

MIGRATED_DBT_PROJECT_NAME = "__dbt_project_name__"
MIGRATED_DBT_PACKAGES = "__dbt_packages__"


# The maximum number of fork processes, used for loading projects
# None means default to process pool, 1 means don't fork, :N is number of processes
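These sentinel keys surface in a converted project's configuration; the loader (see sqlmesh/core/loader.py below) treats the presence of the project-name key as the signal that a project was migrated from dbt. A sketch of how they might appear, with made-up project and package names:

import sqlmesh.core.constants as c

# hypothetical `variables` section of a converted project's SQLMesh config
variables = {
    c.MIGRATED_DBT_PROJECT_NAME: "jaffle_shop",  # "__dbt_project_name__"
    c.MIGRATED_DBT_PACKAGES: {                   # "__dbt_packages__"
        "dbt_utils": {"surrogate_key_separator": "-"},
    },
}

# mirrors SqlMeshLoader.migrated_dbt_project_name / is_migrated_dbt_project
is_migrated = variables.get(c.MIGRATED_DBT_PROJECT_NAME) is not None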
78 changes: 73 additions & 5 deletions sqlmesh/core/loader.py
@@ -38,7 +38,12 @@
from sqlmesh.core.test import ModelTestMetadata, filter_tests_by_patterns
from sqlmesh.utils import UniqueKeyDict, sys_path
from sqlmesh.utils.errors import ConfigError
from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroExtractor
from sqlmesh.utils.jinja import (
JinjaMacroRegistry,
MacroExtractor,
SQLMESH_JINJA_PACKAGE,
SQLMESH_DBT_COMPATIBILITY_PACKAGE,
)
from sqlmesh.utils.metaprogramming import import_python_file
from sqlmesh.utils.pydantic import validation_error_message
from sqlmesh.utils.yaml import YAML, load as yaml_load
@@ -384,15 +389,42 @@ def _raise_failed_to_load_model_error(self, path: Path, error: t.Union[str, Exce
class SqlMeshLoader(Loader):
"""Loads macros and models for a context using the SQLMesh file formats"""

@property
def is_migrated_dbt_project(self) -> bool:
return self.migrated_dbt_project_name is not None

@property
def migrated_dbt_project_name(self) -> t.Optional[str]:
return self.config.variables.get(c.MIGRATED_DBT_PROJECT_NAME)

def _load_scripts(self) -> t.Tuple[MacroRegistry, JinjaMacroRegistry]:
"""Loads all user defined macros."""

create_builtin_globals_module = (
SQLMESH_DBT_COMPATIBILITY_PACKAGE
if self.is_migrated_dbt_project
else SQLMESH_JINJA_PACKAGE
)

# Store a copy of the macro registry
standard_macros = macro.get_registry()
jinja_macros = JinjaMacroRegistry()

top_level_packages = []
if self.is_migrated_dbt_project:
top_level_packages = ["dbt"]
if self.migrated_dbt_project_name:
top_level_packages.append(self.migrated_dbt_project_name)

jinja_macros = JinjaMacroRegistry(
create_builtins_module=create_builtin_globals_module,
top_level_packages=top_level_packages,
)
extractor = MacroExtractor()

macros_max_mtime: t.Optional[float] = None

migrated_dbt_package_base_path = self.config_path / c.MACROS / c.MIGRATED_DBT_PACKAGES

for path in self._glob_paths(
self.config_path / c.MACROS,
ignore_patterns=self.config.ignore_patterns,
@@ -417,16 +449,51 @@ def _load_scripts(self) -> t.Tuple[MacroRegistry, JinjaMacroRegistry]:
macros_max_mtime = (
max(macros_max_mtime, macro_file_mtime) if macros_max_mtime else macro_file_mtime
)

with open(path, "r", encoding="utf-8") as file:
jinja_macros.add_macros(
extractor.extract(file.read(), dialect=self.config.model_defaults.dialect)
)
try:
package: t.Optional[str] = None
if self.is_migrated_dbt_project:
if path.is_relative_to(migrated_dbt_package_base_path):
package = str(
path.relative_to(migrated_dbt_package_base_path).parents[0]
)
else:
package = self.migrated_dbt_project_name

jinja_macros.add_macros(
extractor.extract(file.read(), dialect=self.config.model_defaults.dialect),
package=package,
)
except Exception:
logger.error(f"Unable to read macro file: {path}")
raise

self._macros_max_mtime = macros_max_mtime

macros = macro.get_registry()
macro.set_registry(standard_macros)

if self.is_migrated_dbt_project:
from sqlmesh.dbt.target import TARGET_TYPE_TO_CONFIG_CLASS

connection_config = self.context._connection_config
# this triggers the DBT create_builtins_module to have a `target` property, which many DBT macros require
if dbt_config_type := TARGET_TYPE_TO_CONFIG_CLASS.get(connection_config.type_):
try:
jinja_macros.add_globals(
{
"target": dbt_config_type.from_sqlmesh(
self.context._connection_config,
name=self.config.default_gateway_name,
).attribute_dict()
}
)
except NotImplementedError:
raise ConfigError(
f"No DBT 'Target Type' mapping for connection type: {connection_config.type_}"
)

return macros, jinja_macros

def _load_models(
@@ -499,6 +566,7 @@ def _load() -> t.List[Model]:
infer_names=self.config.model_naming.infer_names,
signal_definitions=signals,
default_catalog_per_gateway=self.context.default_catalog_per_gateway,
migrated_dbt_project_name=self.migrated_dbt_project_name,
)
except Exception as ex:
self._raise_failed_to_load_model_error(path, ex)
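The package resolution above keys off the macro file's location under macros/__dbt_packages__. A standalone sketch of that inference, with illustrative paths (Path.is_relative_to requires Python 3.9+):

from pathlib import Path

BASE = Path("project/macros/__dbt_packages__")

def infer_package(path: Path, project_name: str) -> str:
    # macros under macros/__dbt_packages__/<pkg>/... belong to <pkg>;
    # everything else belongs to the migrated project itself
    if path.is_relative_to(BASE):
        return str(path.relative_to(BASE).parents[0])
    return project_name

assert infer_package(Path("project/macros/__dbt_packages__/dbt_utils/star.sql"), "jaffle_shop") == "dbt_utils"
assert infer_package(Path("project/macros/helpers.sql"), "jaffle_shop") == "jaffle_shop"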
38 changes: 35 additions & 3 deletions sqlmesh/core/model/definition.py
@@ -2015,6 +2015,7 @@ def load_sql_based_model(
variables: t.Optional[t.Dict[str, t.Any]] = None,
infer_names: t.Optional[bool] = False,
blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
migrated_dbt_project_name: t.Optional[str] = None,
**kwargs: t.Any,
) -> Model:
"""Load a model from a parsed SQLMesh model SQL file.
@@ -2186,6 +2187,7 @@
name,
query_or_seed_insert,
time_column_format=time_column_format,
migrated_dbt_project_name=migrated_dbt_project_name,
**common_kwargs,
)
seed_properties = {
@@ -2393,6 +2395,7 @@ def _create_model(
signal_definitions: t.Optional[SignalRegistry] = None,
variables: t.Optional[t.Dict[str, t.Any]] = None,
blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
migrated_dbt_project_name: t.Optional[str] = None,
**kwargs: t.Any,
) -> Model:
validate_extra_and_required_fields(
@@ -2448,13 +2451,42 @@

if jinja_macros:
jinja_macros = (
jinja_macros if jinja_macros.trimmed else jinja_macros.trim(jinja_macro_references)
jinja_macros
if jinja_macros.trimmed
else jinja_macros.trim(jinja_macro_references, package=migrated_dbt_project_name)
)
else:
jinja_macros = JinjaMacroRegistry()

for jinja_macro in jinja_macros.root_macros.values():
used_variables.update(extract_macro_references_and_variables(jinja_macro.definition)[1])
# extract {{ var() }} references used in all jinja macro dependencies to check for any variables specific
# to a migrated DBT package and resolve them accordingly
# vars are added into __sqlmesh_vars__ in the Python env so that the native SQLMesh var() function can resolve them
if migrated_dbt_project_name:
Review comment (Member): Can this be encapsulated into its own function?

# note: JinjaMacroRegistry is trimmed here so "all_macros" should just be the macros used by this model
for _, _, jinja_macro in jinja_macros.all_macros:
_, extracted_variable_names = extract_macro_references_and_variables(
jinja_macro.definition
)
used_variables.update(extracted_variable_names)

variables = variables or {}
if (dbt_package_variables := variables.get(c.MIGRATED_DBT_PACKAGES)) and isinstance(
dbt_package_variables, dict
):
# flatten the nested dict structure from the migrated dbt package variables in the SQLMesh config into __dbt_packages__.<package>.<variable>
# to match what extract_macro_references_and_variables() returns. This allows the usage checks in create_python_env() to work
def _flatten(prefix: str, root: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
acc = {}
for k, v in root.items():
key_with_prefix = f"{prefix}.{k}"
if isinstance(v, dict):
acc.update(_flatten(key_with_prefix, v))
else:
acc[key_with_prefix] = v
return acc

flattened = _flatten(c.MIGRATED_DBT_PACKAGES, dbt_package_variables)
variables.update(flattened)

model = klass(
name=name,
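A worked example of the flattening performed above, restated standalone with invented package and variable names:

import typing as t

def flatten(prefix: str, root: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
    # same shape as the nested _flatten helper in _create_model
    acc: t.Dict[str, t.Any] = {}
    for k, v in root.items():
        key_with_prefix = f"{prefix}.{k}"
        if isinstance(v, dict):
            acc.update(flatten(key_with_prefix, v))
        else:
            acc[key_with_prefix] = v
    return acc

nested = {"dbt_utils": {"surrogate_key_separator": "-", "flags": {"strict": True}}}
flatten("__dbt_packages__", nested)
# {'__dbt_packages__.dbt_utils.surrogate_key_separator': '-',
#  '__dbt_packages__.dbt_utils.flags.strict': True}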
12 changes: 12 additions & 0 deletions sqlmesh/core/model/kind.py
@@ -491,6 +491,18 @@

return v.transform(d.replace_merge_table_aliases)

@field_validator("batch_concurrency", mode="before")
Review comment (Member): Why is this needed? There's already a validator for this field

def _batch_concurrency_validator(
cls, v: t.Optional[exp.Expression], info: ValidationInfo
) -> int:
if isinstance(v, exp.Literal):
# allow batch_concurrency = 1 to be specified in a Model definition without throwing a Pydantic error
return int(v.to_py())
if isinstance(v, int):
return v
return 1

@property
def data_hash_values(self) -> t.List[t.Optional[str]]:
return [
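For context, `batch_concurrency 1` in a MODEL definition reaches the validator as a sqlglot literal rather than a Python int; a minimal illustration of the coercion:

from sqlglot import exp

raw = exp.Literal.number("1")  # what `batch_concurrency 1` parses to
int(raw.to_py())               # 1, matching the validator's coercion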
1 change: 1 addition & 0 deletions sqlmesh/core/renderer.py
@@ -178,6 +178,7 @@ def _resolve_table(table: str | exp.Table) -> str:
)

render_kwargs = {
"dialect": self._dialect,
**date_dict(
to_datetime(execution_time or c.EPOCH),
start_time,
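With `dialect` in render_kwargs, jinja templates rendered by the model renderer can branch on the target dialect. A hedged jinja2 illustration, not the renderer's actual call site (the SQL functions are invented):

from jinja2 import Template

render_kwargs = {"dialect": "duckdb"}
sql = Template(
    "SELECT {% if dialect == 'duckdb' %}epoch(ts){% else %}some_other_fn(ts){% endif %} FROM t"
).render(**render_kwargs)
# "SELECT epoch(ts) FROM t"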
3 changes: 3 additions & 0 deletions sqlmesh/dbt/adapter.py
@@ -38,6 +38,9 @@ def __init__(
self.jinja_globals = jinja_globals.copy() if jinja_globals else {}
self.jinja_globals["adapter"] = self
self.project_dialect = project_dialect
self.jinja_globals["dialect"] = (
project_dialect # so the dialect is available in the jinja env created by self.dispatch()
)
self.quote_policy = quote_policy or Policy()

@abc.abstractmethod
Expand Down
2 changes: 1 addition & 1 deletion sqlmesh/dbt/builtin.py
@@ -156,7 +156,7 @@ class Var:
def __init__(self, variables: t.Dict[str, t.Any]) -> None:
self.variables = variables

def __call__(self, name: str, default: t.Optional[t.Any] = None) -> t.Any:
def __call__(self, name: str, default: t.Optional[t.Any] = None, **kwargs: t.Any) -> t.Any:
return self.variables.get(name, default)

def has_var(self, name: str) -> bool:
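The widened signature lets `var()` tolerate extra keyword arguments that dbt packages may pass along; a quick sketch (the `node` kwarg is just an arbitrary example):

from sqlmesh.dbt.builtin import Var

v = Var({"start_date": "2020-01-01"})
v("start_date")                      # "2020-01-01"
v("missing", default="fallback")     # "fallback"
v("missing", "fallback", node=None)  # extra kwargs are accepted and ignored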
Empty file.
40 changes: 40 additions & 0 deletions sqlmesh/dbt/converter/common.py
@@ -0,0 +1,40 @@
from __future__ import annotations
import jinja2.nodes as j
from sqlglot import exp
import typing as t
import sqlmesh.core.constants as c
from pathlib import Path


# jinja transform is a function that takes (current node, previous node, parent node) and returns a new Node or None
# returning None means the current node is removed from the tree
# returning a different Node means the current node is replaced with the new Node
JinjaTransform = t.Callable[[j.Node, t.Optional[j.Node], t.Optional[j.Node]], t.Optional[j.Node]]
SQLGlotTransform = t.Callable[[exp.Expression], t.Optional[exp.Expression]]


def _sqlmesh_predefined_macro_variables() -> t.Set[str]:
def _gen() -> t.Iterable[str]:
for suffix in ("dt", "date", "ds", "ts", "tstz", "hour", "epoch", "millis"):
for prefix in ("start", "end", "execution"):
yield f"{prefix}_{suffix}"

for item in ("runtime_stage", "gateway", "this_model", "this_env", "model_kind_name"):
yield item

return set(_gen())


SQLMESH_PREDEFINED_MACRO_VARIABLES = _sqlmesh_predefined_macro_variables()


def infer_dbt_package_from_path(path: Path) -> t.Optional[str]:
"""
Given a path like "sqlmesh-project/macros/__dbt_packages__/foo/bar.sql"

Infer that 'foo' is the DBT package
"""
if c.MIGRATED_DBT_PACKAGES in path.parts:
idx = path.parts.index(c.MIGRATED_DBT_PACKAGES)
return path.parts[idx + 1]
return None
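
Usage of the helpers above, following the docstring's own example:

from pathlib import Path

from sqlmesh.dbt.converter.common import (
    SQLMESH_PREDEFINED_MACRO_VARIABLES,
    infer_dbt_package_from_path,
)

infer_dbt_package_from_path(Path("sqlmesh-project/macros/__dbt_packages__/foo/bar.sql"))  # "foo"
infer_dbt_package_from_path(Path("sqlmesh-project/macros/bar.sql"))  # None

"start_ds" in SQLMESH_PREDEFINED_MACRO_VARIABLES       # True
"runtime_stage" in SQLMESH_PREDEFINED_MACRO_VARIABLES  # True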