From c57bfed9f41b9446b375a38f721336cf0294713c Mon Sep 17 00:00:00 2001 From: Seba Arriagada Date: Thu, 20 Jul 2023 12:56:32 +0100 Subject: [PATCH] get prod alias from manifest when provided --- data_diff/dbt.py | 18 ++++++++++-------- tests/test_dbt.py | 10 +++++++--- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/data_diff/dbt.py b/data_diff/dbt.py index fd3223de..7815c317 100644 --- a/data_diff/dbt.py +++ b/data_diff/dbt.py @@ -163,22 +163,22 @@ def _get_diff_vars( ) -> TDiffVars: dev_database = model.database dev_schema = model.schema_ - + dev_alias = prod_alias = model.alias primary_keys = dbt_parser.get_pk_from_model(model, dbt_parser.unique_columns, "primary-key") # prod path is constructed via configuration or the prod manifest via --state if dbt_parser.prod_manifest_obj: - prod_database, prod_schema = _get_prod_path_from_manifest(model, dbt_parser.prod_manifest_obj) + prod_database, prod_schema, prod_alias = _get_prod_path_from_manifest(model, dbt_parser.prod_manifest_obj) else: prod_database, prod_schema = _get_prod_path_from_config(config, model, dev_database, dev_schema) if dbt_parser.requires_upper: - dev_qualified_list = [x.upper() for x in [dev_database, dev_schema, model.alias] if x] - prod_qualified_list = [x.upper() for x in [prod_database, prod_schema, model.alias] if x] + dev_qualified_list = [x.upper() for x in [dev_database, dev_schema, dev_alias] if x] + prod_qualified_list = [x.upper() for x in [prod_database, prod_schema, prod_alias] if x] primary_keys = [x.upper() for x in primary_keys] else: - dev_qualified_list = [x for x in [dev_database, dev_schema, model.alias] if x] - prod_qualified_list = [x for x in [prod_database, prod_schema, model.alias] if x] + dev_qualified_list = [x for x in [dev_database, dev_schema, dev_alias] if x] + prod_qualified_list = [x for x in [prod_database, prod_schema, prod_alias] if x] datadiff_model_config = dbt_parser.get_datadiff_model_config(model.meta) @@ -225,14 +225,16 @@ def _get_prod_path_from_config(config, model, dev_database, dev_schema) -> Tuple return prod_database, prod_schema -def _get_prod_path_from_manifest(model, prod_manifest) -> Union[Tuple[str, str], Tuple[None, None]]: +def _get_prod_path_from_manifest(model, prod_manifest) -> Union[Tuple[str, str, str], Tuple[None, None, None]]: prod_database = None prod_schema = None + prod_alias = None prod_model = prod_manifest.nodes.get(model.unique_id, None) if prod_model: prod_database = prod_model.database prod_schema = prod_model.schema_ - return prod_database, prod_schema + prod_alias = prod_model.alias + return prod_database, prod_schema, prod_alias def _local_diff(diff_vars: TDiffVars, json_output: bool = False) -> None: diff --git a/tests/test_dbt.py b/tests/test_dbt.py index 014e73ae..0213c4c8 100644 --- a/tests/test_dbt.py +++ b/tests/test_dbt.py @@ -684,9 +684,11 @@ def test_get_prod_path_from_manifest_model_exists(self): mock_prod_manifest.nodes.get.return_value = mock_prod_model mock_prod_model.database = "prod_db" mock_prod_model.schema_ = "prod_schema" - prod_database, prod_schema = _get_prod_path_from_manifest(mock_model, mock_prod_manifest) + mock_prod_model.alias = "prod_alias" + prod_database, prod_schema, prod_alias = _get_prod_path_from_manifest(mock_model, mock_prod_manifest) self.assertEqual(prod_database, mock_prod_model.database) self.assertEqual(prod_schema, mock_prod_model.schema_) + self.assertEqual(prod_alias, mock_prod_model.alias) def test_get_prod_path_from_manifest_model_not_exists(self): mock_model = Mock() @@ -696,9 +698,11 @@ def test_get_prod_path_from_manifest_model_not_exists(self): mock_prod_manifest.nodes.get.return_value = None mock_prod_model.database = "prod_db" mock_prod_model.schema_ = "prod_schema" - prod_database, prod_schema = _get_prod_path_from_manifest(mock_model, mock_prod_manifest) + mock_prod_model.alias = "prod_alias" + prod_database, prod_schema, prod_alias = _get_prod_path_from_manifest(mock_model, mock_prod_manifest) self.assertEqual(prod_database, None) self.assertEqual(prod_schema, None) + self.assertEqual(prod_alias, None) def test_get_diff_custom_schema_no_config_exception(self): config = TDatadiffConfig(prod_database="prod_db", prod_schema="prod_schema") @@ -926,7 +930,7 @@ def test_get_diff_vars_call_get_prod_path_from_manifest( mock_dbt_parser.requires_upper = False mock_model.meta = None mock_dbt_parser.prod_manifest_obj = {"manifest_key": "manifest_value"} - mock_prod_path_from_manifest.return_value = ("prod_db", "prod_schema") + mock_prod_path_from_manifest.return_value = ("prod_db", "prod_schema", "prod_alias") diff_vars = _get_diff_vars(mock_dbt_parser, config, mock_model)