Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

get prod alias from manifest file when provided #652

Merged
merged 1 commit into from
Jul 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions data_diff/dbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,22 +163,22 @@ def _get_diff_vars(
) -> TDiffVars:
dev_database = model.database
dev_schema = model.schema_

dev_alias = prod_alias = model.alias
primary_keys = dbt_parser.get_pk_from_model(model, dbt_parser.unique_columns, "primary-key")

# prod path is constructed via configuration or the prod manifest via --state
if dbt_parser.prod_manifest_obj:
prod_database, prod_schema = _get_prod_path_from_manifest(model, dbt_parser.prod_manifest_obj)
prod_database, prod_schema, prod_alias = _get_prod_path_from_manifest(model, dbt_parser.prod_manifest_obj)
else:
prod_database, prod_schema = _get_prod_path_from_config(config, model, dev_database, dev_schema)

if dbt_parser.requires_upper:
dev_qualified_list = [x.upper() for x in [dev_database, dev_schema, model.alias] if x]
prod_qualified_list = [x.upper() for x in [prod_database, prod_schema, model.alias] if x]
dev_qualified_list = [x.upper() for x in [dev_database, dev_schema, dev_alias] if x]
prod_qualified_list = [x.upper() for x in [prod_database, prod_schema, prod_alias] if x]
primary_keys = [x.upper() for x in primary_keys]
else:
dev_qualified_list = [x for x in [dev_database, dev_schema, model.alias] if x]
prod_qualified_list = [x for x in [prod_database, prod_schema, model.alias] if x]
dev_qualified_list = [x for x in [dev_database, dev_schema, dev_alias] if x]
prod_qualified_list = [x for x in [prod_database, prod_schema, prod_alias] if x]

datadiff_model_config = dbt_parser.get_datadiff_model_config(model.meta)

Expand Down Expand Up @@ -225,14 +225,16 @@ def _get_prod_path_from_config(config, model, dev_database, dev_schema) -> Tuple
return prod_database, prod_schema


def _get_prod_path_from_manifest(model, prod_manifest) -> Union[Tuple[str, str], Tuple[None, None]]:
def _get_prod_path_from_manifest(model, prod_manifest) -> Union[Tuple[str, str, str], Tuple[None, None, None]]:
prod_database = None
prod_schema = None
prod_alias = None
prod_model = prod_manifest.nodes.get(model.unique_id, None)
if prod_model:
prod_database = prod_model.database
prod_schema = prod_model.schema_
return prod_database, prod_schema
prod_alias = prod_model.alias
return prod_database, prod_schema, prod_alias


def _local_diff(diff_vars: TDiffVars, json_output: bool = False) -> None:
Expand Down
10 changes: 7 additions & 3 deletions tests/test_dbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,9 +684,11 @@ def test_get_prod_path_from_manifest_model_exists(self):
mock_prod_manifest.nodes.get.return_value = mock_prod_model
mock_prod_model.database = "prod_db"
mock_prod_model.schema_ = "prod_schema"
prod_database, prod_schema = _get_prod_path_from_manifest(mock_model, mock_prod_manifest)
mock_prod_model.alias = "prod_alias"
prod_database, prod_schema, prod_alias = _get_prod_path_from_manifest(mock_model, mock_prod_manifest)
self.assertEqual(prod_database, mock_prod_model.database)
self.assertEqual(prod_schema, mock_prod_model.schema_)
self.assertEqual(prod_alias, mock_prod_model.alias)

def test_get_prod_path_from_manifest_model_not_exists(self):
mock_model = Mock()
Expand All @@ -696,9 +698,11 @@ def test_get_prod_path_from_manifest_model_not_exists(self):
mock_prod_manifest.nodes.get.return_value = None
mock_prod_model.database = "prod_db"
mock_prod_model.schema_ = "prod_schema"
prod_database, prod_schema = _get_prod_path_from_manifest(mock_model, mock_prod_manifest)
mock_prod_model.alias = "prod_alias"
prod_database, prod_schema, prod_alias = _get_prod_path_from_manifest(mock_model, mock_prod_manifest)
self.assertEqual(prod_database, None)
self.assertEqual(prod_schema, None)
self.assertEqual(prod_alias, None)

def test_get_diff_custom_schema_no_config_exception(self):
config = TDatadiffConfig(prod_database="prod_db", prod_schema="prod_schema")
Expand Down Expand Up @@ -926,7 +930,7 @@ def test_get_diff_vars_call_get_prod_path_from_manifest(
mock_dbt_parser.requires_upper = False
mock_model.meta = None
mock_dbt_parser.prod_manifest_obj = {"manifest_key": "manifest_value"}
mock_prod_path_from_manifest.return_value = ("prod_db", "prod_schema")
mock_prod_path_from_manifest.return_value = ("prod_db", "prod_schema", "prod_alias")

diff_vars = _get_diff_vars(mock_dbt_parser, config, mock_model)

Expand Down