From 93d2ada904457a00a1b310dd321f47528f55bc59 Mon Sep 17 00:00:00 2001 From: Sergey Vasilyev Date: Fri, 29 Sep 2023 15:06:13 +0200 Subject: [PATCH] Resolve circular imports between schemas, databases, and compilers --- data_diff/__main__.py | 2 +- data_diff/joindiff_tables.py | 2 +- data_diff/schema.py | 7 +++---- data_diff/table_segment.py | 2 +- tests/test_database.py | 2 +- 5 files changed, 7 insertions(+), 8 deletions(-) diff --git a/data_diff/__main__.py b/data_diff/__main__.py index 0e5255e6..ba006d36 100644 --- a/data_diff/__main__.py +++ b/data_diff/__main__.py @@ -461,7 +461,7 @@ def _data_diff( schemas = list(differ._thread_map(_get_schema, safezip(dbs, table_paths))) schema1, schema2 = schemas = [ - create_schema(db, table_path, schema, case_sensitive) + create_schema(db.name, table_path, schema, case_sensitive) for db, table_path, schema in safezip(dbs, table_paths, schemas) ] diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index c40a2b99..91e2aecd 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -12,6 +12,7 @@ from data_diff.databases import Database, MsSQL, MySQL, BigQuery, Presto, Oracle, Snowflake from data_diff.abcs.database_types import NumericType, DbPath +from data_diff.databases.base import Compiler from data_diff.queries.api import ( table, sum_, @@ -23,7 +24,6 @@ rightjoin, this, when, - Compiler, ) from data_diff.queries.ast_classes import Concat, Count, Expr, Random, TablePath, Code, ITable from data_diff.queries.extras import NormalizeAsString diff --git a/data_diff/schema.py b/data_diff/schema.py index ae0b3935..db17d287 100644 --- a/data_diff/schema.py +++ b/data_diff/schema.py @@ -1,6 +1,5 @@ import logging -from data_diff import Database from data_diff.utils import CaseAwareMapping, CaseInsensitiveDict, CaseSensitiveDict from data_diff.abcs.database_types import DbPath @@ -9,13 +8,13 @@ Schema = CaseAwareMapping -def create_schema(db: Database, table_path: DbPath, schema: dict, 
case_sensitive: bool) -> CaseAwareMapping: - logger.debug(f"[{db.name}] Schema = {schema}") +def create_schema(db_name: str, table_path: DbPath, schema: dict, case_sensitive: bool) -> CaseAwareMapping: + logger.debug(f"[{db_name}] Schema = {schema}") if case_sensitive: return CaseSensitiveDict(schema) if len({k.lower() for k in schema}) < len(schema): - logger.warning(f'Ambiguous schema for {db}:{".".join(table_path)} | Columns = {", ".join(list(schema))}') + logger.warning(f'Ambiguous schema for {db_name}:{".".join(table_path)} | Columns = {", ".join(list(schema))}') logger.warning("We recommend to disable case-insensitivity (set --case-sensitive).") return CaseInsensitiveDict(schema) diff --git a/data_diff/table_segment.py b/data_diff/table_segment.py index 9864824a..aaf747f6 100644 --- a/data_diff/table_segment.py +++ b/data_diff/table_segment.py @@ -142,7 +142,7 @@ def _where(self): def _with_raw_schema(self, raw_schema: dict) -> Self: schema = self.database._process_table_schema(self.table_path, raw_schema, self.relevant_columns, self._where()) - return self.new(_schema=create_schema(self.database, self.table_path, schema, self.case_sensitive)) + return self.new(_schema=create_schema(self.database.name, self.table_path, schema, self.case_sensitive)) def with_schema(self) -> Self: "Queries the table schema from the database, and returns a new instance of TableSegment, with a schema." 
diff --git a/tests/test_database.py b/tests/test_database.py index f5998609..b17cb7f0 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -127,7 +127,7 @@ def test_correct_timezone(self): t = table(name) raw_schema = db.query_table_schema(t.path) schema = db._process_table_schema(t.path, raw_schema) - schema = create_schema(self.database, t, schema, case_sensitive=True) + schema = create_schema(self.database.name, t, schema, case_sensitive=True) t = t.replace(schema=schema) t.schema["created_at"] = t.schema["created_at"].replace(precision=t.schema["created_at"].precision)