diff --git a/data_diff/__main__.py b/data_diff/__main__.py
index 0e5255e6..ba006d36 100644
--- a/data_diff/__main__.py
+++ b/data_diff/__main__.py
@@ -461,7 +461,7 @@ def _data_diff(
     schemas = list(differ._thread_map(_get_schema, safezip(dbs, table_paths)))
     schema1, schema2 = schemas = [
-        create_schema(db, table_path, schema, case_sensitive)
+        create_schema(db.name, table_path, schema, case_sensitive)
         for db, table_path, schema in safezip(dbs, table_paths, schemas)
     ]
 
diff --git a/data_diff/databases/_connect.py b/data_diff/databases/_connect.py
index 8f842123..9abb7d54 100644
--- a/data_diff/databases/_connect.py
+++ b/data_diff/databases/_connect.py
@@ -94,6 +94,7 @@ def match_path(self, dsn):
 
 class Connect:
     """Provides methods for connecting to a supported database using a URL or connection dict."""
+
     conn_cache: MutableMapping[Hashable, Database]
 
     def __init__(self, database_by_scheme: Dict[str, Database] = DATABASE_BY_SCHEME):
diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py
index d55c59a6..3b2f1809 100644
--- a/data_diff/databases/base.py
+++ b/data_diff/databases/base.py
@@ -21,16 +21,39 @@
 from data_diff.queries.extras import ApplyFuncAndNormalizeAsString, Checksum, NormalizeAsString
 from data_diff.utils import ArithString, is_uuid, join_iter, safezip
 from data_diff.queries.api import Expr, table, Select, SKIP, Explain, Code, this
-from data_diff.queries.ast_classes import Alias, BinOp, CaseWhen, Cast, Column, Commit, Concat, ConstantTable, Count, \
-    CreateTable, Cte, \
-    CurrentTimestamp, DropTable, Func, \
-    GroupBy, \
-    ITable, In, InsertToTable, IsDistinctFrom, \
-    Join, \
-    Param, \
-    Random, \
-    Root, TableAlias, TableOp, TablePath, \
-    TimeTravel, TruncateTable, UnaryOp, WhenThen, _ResolveColumn
+from data_diff.queries.ast_classes import (
+    Alias,
+    BinOp,
+    CaseWhen,
+    Cast,
+    Column,
+    Commit,
+    Concat,
+    ConstantTable,
+    Count,
+    CreateTable,
+    Cte,
+    CurrentTimestamp,
+    DropTable,
+    Func,
+    GroupBy,
+    ITable,
+    In,
+    InsertToTable,
+    IsDistinctFrom,
+    Join,
+    Param,
+    Random,
+    Root,
+    TableAlias,
+    TableOp,
+    TablePath,
+    TimeTravel,
+    TruncateTable,
+    UnaryOp,
+    WhenThen,
+    _ResolveColumn,
+)
 from data_diff.abcs.database_types import (
     Array,
     Struct,
@@ -67,17 +90,11 @@ class CompileError(Exception):
     pass
 
 
-# TODO: LATER: Resolve the circular imports of databases-compiler-dialects:
-# A database uses a compiler to render the SQL query.
-# The compiler delegates to a dialect.
-# The dialect renders the SQL.
-# AS IS: The dialect requires the db to normalize table paths — leading to the back-dependency.
-# TO BE: All the tables paths must be pre-normalized before SQL rendering.
-# Also: c.database.is_autocommit in render_commit().
-# After this, the Compiler can cease referring Database/Dialect at all,
-# and be used only as a CompilingContext (a counter/data-bearing class).
-# As a result, it becomes low-level util, and the circular dependency auto-resolves.
-# Meanwhile, the easy fix is to simply move the Compiler here.
+# TODO: remove once switched to attrs, where ForwardRef[]/strings are resolved.
+class _RuntypeHackToFixCicularRefrencedDatabase:
+    dialect: "BaseDialect"
+
+
 @dataclass
 class Compiler(AbstractCompiler):
     """
@@ -90,7 +107,7 @@ class Compiler(AbstractCompiler):
     # Database is needed to normalize tables. Dialect is needed for recursive compilations.
     # In theory, it is many-to-many relations: e.g. a generic ODBC driver with multiple dialects.
     # In practice, we currently bind the dialects to the specific database classes.
- database: "Database" + database: _RuntypeHackToFixCicularRefrencedDatabase in_select: bool = False # Compilation runtime flag in_join: bool = False # Compilation runtime flag @@ -102,7 +119,7 @@ class Compiler(AbstractCompiler): _counter: List = field(default_factory=lambda: [0]) @property - def dialect(self) -> "Dialect": + def dialect(self) -> "BaseDialect": return self.database.dialect # TODO: DEPRECATED: Remove once the dialect is used directly in all places. @@ -223,7 +240,6 @@ class BaseDialect(abc.ABC): SUPPORTS_PRIMARY_KEY = False SUPPORTS_INDEXES = False TYPE_CLASSES: Dict[str, type] = {} - MIXINS = frozenset() PLACEHOLDER_TABLE = None # Used for Oracle @@ -414,7 +430,9 @@ def render_checksum(self, c: Compiler, elem: Checksum) -> str: def render_concat(self, c: Compiler, elem: Concat) -> str: # We coalesce because on some DBs (e.g. MySQL) concat('a', NULL) is NULL - items = [f"coalesce({self.compile(c, Code(self.to_string(self.compile(c, expr))))}, '')" for expr in elem.exprs] + items = [ + f"coalesce({self.compile(c, Code(self.to_string(self.compile(c, expr))))}, '')" for expr in elem.exprs + ] assert items if len(items) == 1: return items[0] @@ -559,7 +577,7 @@ def render_groupby(self, c: Compiler, elem: GroupBy) -> str: columns=columns, group_by_exprs=[Code(k) for k in keys], having_exprs=elem.having_exprs, - ) + ), ) keys_str = ", ".join(keys) @@ -567,9 +585,7 @@ def render_groupby(self, c: Compiler, elem: GroupBy) -> str: having_str = ( " HAVING " + " AND ".join(map(compile_fn, elem.having_exprs)) if elem.having_exprs is not None else "" ) - select = ( - f"SELECT {columns_str} FROM {self.compile(c.replace(in_select=True), elem.table)} GROUP BY {keys_str}{having_str}" - ) + select = f"SELECT {columns_str} FROM {self.compile(c.replace(in_select=True), elem.table)} GROUP BY {keys_str}{having_str}" if c.in_select: select = f"({select}) {c.new_unique_name()}" @@ -601,7 +617,7 @@ def render_timetravel(self, c: Compiler, elem: TimeTravel) -> str: # TODO: why is it c.? why not self? time-trvelling is the dialect's thing, isnt't it? c.time_travel( elem.table, before=elem.before, timestamp=elem.timestamp, offset=elem.offset, statement=elem.statement - ) + ), ) def render_createtable(self, c: Compiler, elem: CreateTable) -> str: @@ -768,18 +784,6 @@ def _convert_db_precision_to_digits(self, p: int) -> int: # See: https://en.wikipedia.org/wiki/Single-precision_floating-point_format return math.floor(math.log(2**p, 10)) - @classmethod - def load_mixins(cls, *abstract_mixins) -> Self: - "Load a list of mixins that implement the given abstract mixins" - mixins = {m for m in cls.MIXINS if issubclass(m, abstract_mixins)} - - class _DialectWithMixins(cls, *mixins, *abstract_mixins): - pass - - _DialectWithMixins.__name__ = cls.__name__ - return _DialectWithMixins() - - @property @abstractmethod def name(self) -> str: @@ -822,7 +826,7 @@ def __getitem__(self, i): return self.rows[i] -class Database(abc.ABC): +class Database(abc.ABC, _RuntypeHackToFixCicularRefrencedDatabase): """Base abstract class for databases. Used for providing connection code and implementation specific SQL utilities. 
diff --git a/data_diff/databases/bigquery.py b/data_diff/databases/bigquery.py
index feb98bde..15d60511 100644
--- a/data_diff/databases/bigquery.py
+++ b/data_diff/databases/bigquery.py
@@ -139,7 +139,9 @@ def time_travel(
         )
 
 
-class Dialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue):
+class Dialect(
+    BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue
+):
     name = "BigQuery"
     ROUNDS_ON_PREC_LOSS = False  # Technically BigQuery doesn't allow implicit rounding or truncation
     TYPE_CLASSES = {
@@ -159,7 +161,6 @@ class Dialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Abstra
     }
     TYPE_ARRAY_RE = re.compile(r"ARRAY<(.+)>")
     TYPE_STRUCT_RE = re.compile(r"STRUCT<(.+)>")
-    MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_TimeTravel, Mixin_RandomSample}
 
     def random(self) -> str:
         return "RAND()"
diff --git a/data_diff/databases/clickhouse.py b/data_diff/databases/clickhouse.py
index 9366b922..70070934 100644
--- a/data_diff/databases/clickhouse.py
+++ b/data_diff/databases/clickhouse.py
@@ -125,7 +125,6 @@ class Dialect(BaseDialect, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, A
         "DateTime64": Timestamp,
         "Bool": Boolean,
     }
-    MIXINS = {Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample}
 
     def quote(self, s: str) -> str:
         return f'"{s}"'
diff --git a/data_diff/databases/databricks.py b/data_diff/databases/databricks.py
index 67d0528d..eba88248 100644
--- a/data_diff/databases/databricks.py
+++ b/data_diff/databases/databricks.py
@@ -79,7 +79,6 @@ class Dialect(BaseDialect, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, A
         # Boolean
         "BOOLEAN": Boolean,
     }
-    MIXINS = {Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample}
 
     def quote(self, s: str):
         return f"`{s}`"
diff --git a/data_diff/databases/duckdb.py b/data_diff/databases/duckdb.py
index ba6afd63..ca9f5733 100644
--- a/data_diff/databases/duckdb.py
+++ b/data_diff/databases/duckdb.py
@@ -68,12 +68,13 @@ def random_sample_ratio_approx(self, tbl: ITable, ratio: float) -> ITable:
         return code("SELECT * FROM ({tbl}) USING SAMPLE {percent}%;", tbl=tbl, percent=int(100 * ratio))
 
 
-class Dialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue):
+class Dialect(
+    BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue
+):
     name = "DuckDB"
     ROUNDS_ON_PREC_LOSS = False
     SUPPORTS_PRIMARY_KEY = True
     SUPPORTS_INDEXES = True
-    MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample}
 
     TYPE_CLASSES = {
         # Timestamps
diff --git a/data_diff/databases/mssql.py b/data_diff/databases/mssql.py
index 28d67c99..0e767417 100644
--- a/data_diff/databases/mssql.py
+++ b/data_diff/databases/mssql.py
@@ -58,7 +58,15 @@ def md5_as_int(self, s: str) -> str:
         return f"convert(bigint, convert(varbinary, '0x' + RIGHT(CONVERT(NVARCHAR(32), HashBytes('MD5', {s}), 2), {CHECKSUM_HEXDIGITS}), 1))"
 
 
-class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue):
+class Dialect(
+    BaseDialect,
+    Mixin_Schema,
+    Mixin_OptimizerHints,
+    Mixin_MD5,
+    Mixin_NormalizeValue,
+    AbstractMixin_MD5,
+    AbstractMixin_NormalizeValue,
+):
     name = "MsSQL"
     ROUNDS_ON_PREC_LOSS = True
     SUPPORTS_PRIMARY_KEY = True
@@ -98,8 +106,6 @@ class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints, Mixin_MD5, Mixin_
         "json": JSON,
     }
 
-    MIXINS = {Mixin_Schema, Mixin_NormalizeValue, Mixin_RandomSample}
-
     def quote(self, s: str):
         return f"[{s}]"
 
diff --git a/data_diff/databases/mysql.py b/data_diff/databases/mysql.py
index d6dcba9e..e32d34af 100644
--- a/data_diff/databases/mysql.py
+++ b/data_diff/databases/mysql.py
@@ -60,7 +60,15 @@ def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
         return f"TRIM(CAST({value} AS char))"
 
 
-class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue):
+class Dialect(
+    BaseDialect,
+    Mixin_Schema,
+    Mixin_OptimizerHints,
+    Mixin_MD5,
+    Mixin_NormalizeValue,
+    AbstractMixin_MD5,
+    AbstractMixin_NormalizeValue,
+):
     name = "MySQL"
     ROUNDS_ON_PREC_LOSS = True
     SUPPORTS_PRIMARY_KEY = True
@@ -91,7 +99,6 @@ class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints, Mixin_MD5, Mixin_
         # Boolean
         "boolean": Boolean,
     }
-    MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample}
 
     def quote(self, s: str):
         return f"`{s}`"
diff --git a/data_diff/databases/oracle.py b/data_diff/databases/oracle.py
index f0309c11..3b4940c5 100644
--- a/data_diff/databases/oracle.py
+++ b/data_diff/databases/oracle.py
@@ -80,7 +80,15 @@ def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable:
         )
 
 
-class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue):
+class Dialect(
+    BaseDialect,
+    Mixin_Schema,
+    Mixin_OptimizerHints,
+    Mixin_MD5,
+    Mixin_NormalizeValue,
+    AbstractMixin_MD5,
+    AbstractMixin_NormalizeValue,
+):
     name = "Oracle"
     SUPPORTS_PRIMARY_KEY = True
     SUPPORTS_INDEXES = True
@@ -96,7 +104,6 @@ class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints, Mixin_MD5, Mixin_
     }
     ROUNDS_ON_PREC_LOSS = True
     PLACEHOLDER_TABLE = "DUAL"
-    MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample}
 
     def quote(self, s: str):
         return f'"{s}"'
diff --git a/data_diff/databases/postgresql.py b/data_diff/databases/postgresql.py
index dec9b9d3..f3495c99 100644
--- a/data_diff/databases/postgresql.py
+++ b/data_diff/databases/postgresql.py
@@ -60,12 +60,13 @@ def normalize_json(self, value: str, _coltype: JSON) -> str:
         return f"{value}::text"
 
 
-class PostgresqlDialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue):
+class PostgresqlDialect(
+    BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue
+):
     name = "PostgreSQL"
     ROUNDS_ON_PREC_LOSS = True
     SUPPORTS_PRIMARY_KEY = True
     SUPPORTS_INDEXES = True
-    MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample}
 
     TYPE_CLASSES = {
         # Timestamps
diff --git a/data_diff/databases/presto.py b/data_diff/databases/presto.py
index b4c45751..0ba4e09d 100644
--- a/data_diff/databases/presto.py
+++ b/data_diff/databases/presto.py
@@ -76,7 +76,9 @@ def normalize_boolean(self, value: str, _coltype: Boolean) -> str:
         return self.to_string(f"cast ({value} as int)")
 
 
-class Dialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue):
+class Dialect(
+    BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue
+):
     name = "Presto"
     ROUNDS_ON_PREC_LOSS = True
     TYPE_CLASSES = {
@@ -94,7 +96,6 @@ class Dialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Abstra
         # Boolean
         "boolean": Boolean,
     }
-    MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample}
 
     def explain_as_text(self, query: str) -> str:
         return f"EXPLAIN (FORMAT TEXT) {query}"
diff --git a/data_diff/databases/snowflake.py b/data_diff/databases/snowflake.py
index 3a558425..f3a70b76 100644
--- a/data_diff/databases/snowflake.py
+++ b/data_diff/databases/snowflake.py
@@ -104,7 +104,9 @@ def time_travel(
         return code(f"{{table}} {at_or_before}({key} => {{value}})", table=table, value=value)
 
 
-class Dialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue):
+class Dialect(
+    BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue
+):
     name = "Snowflake"
     ROUNDS_ON_PREC_LOSS = False
     TYPE_CLASSES = {
@@ -121,7 +123,6 @@ class Dialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Abstra
         # Boolean
         "BOOLEAN": Boolean,
     }
-    MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_TimeTravel, Mixin_RandomSample}
 
     def explain_as_text(self, query: str) -> str:
         return f"EXPLAIN USING TEXT {query}"
diff --git a/data_diff/databases/vertica.py b/data_diff/databases/vertica.py
index e8fe9ec2..fc9edd04 100644
--- a/data_diff/databases/vertica.py
+++ b/data_diff/databases/vertica.py
@@ -78,7 +78,9 @@ def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable:
         )
 
 
-class Dialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue):
+class Dialect(
+    BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue
+):
     name = "Vertica"
     ROUNDS_ON_PREC_LOSS = True
 
@@ -96,7 +98,6 @@ class Dialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Abstra
         # Boolean
         "boolean": Boolean,
     }
-    MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample}
 
     def quote(self, s: str):
         return f'"{s}"'
diff --git a/data_diff/format.py b/data_diff/format.py
index 8a515e1b..a8900e84 100644
--- a/data_diff/format.py
+++ b/data_diff/format.py
@@ -253,7 +253,7 @@ def _make_rows_diff(
     t1_exclusive_rows: List[Dict[str, Any]],
     t2_exclusive_rows: List[Dict[str, Any]],
     diff_rows: List[Dict[str, Any]],
-    key_columns: List[str]
+    key_columns: List[str],
 ) -> RowsDiff:
     diff_rows_jsonified = []
     for row in diff_rows:
@@ -268,10 +268,7 @@ def _make_rows_diff(
         t2_exclusive_rows_jsonified.append(_jsonify_exclusive(row, key_columns))
 
     return RowsDiff(
-        exclusive=ExclusiveDiff(
-            dataset1=t1_exclusive_rows_jsonified,
-            dataset2=t2_exclusive_rows_jsonified
-        ),
+        exclusive=ExclusiveDiff(dataset1=t1_exclusive_rows_jsonified, dataset2=t2_exclusive_rows_jsonified),
         diff=diff_rows_jsonified,
     )
 
diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py
index c40a2b99..91e2aecd 100644
--- a/data_diff/joindiff_tables.py
+++ b/data_diff/joindiff_tables.py
@@ -12,6 +12,7 @@
 
 from data_diff.databases import Database, MsSQL, MySQL, BigQuery, Presto, Oracle, Snowflake
 from data_diff.abcs.database_types import NumericType, DbPath
+from data_diff.databases.base import Compiler
 from data_diff.queries.api import (
     table,
     sum_,
@@ -23,7 +24,6 @@
     rightjoin,
     this,
     when,
-    Compiler,
 )
 from data_diff.queries.ast_classes import Concat, Count, Expr, Random, TablePath, Code, ITable
 from data_diff.queries.extras import NormalizeAsString
diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py
index 56efdb20..4c5c45f4 100644
--- a/data_diff/queries/ast_classes.py
+++ b/data_diff/queries/ast_classes.py
@@ -796,6 +796,7 @@ class Commit(Statement):
 @dataclass
 class Param(ExprNode, ITable):
     """A value placeholder, to be specified at compilation time using the `cv_params` context variable."""
+
     name: str
 
     @property
diff --git a/data_diff/schema.py b/data_diff/schema.py
index ae0b3935..db17d287 100644
--- a/data_diff/schema.py
+++ b/data_diff/schema.py
@@ -1,6 +1,5 @@
 import logging
 
-from data_diff import Database
 from data_diff.utils import CaseAwareMapping, CaseInsensitiveDict, CaseSensitiveDict
 from data_diff.abcs.database_types import DbPath
 
@@ -9,13 +8,13 @@
 Schema = CaseAwareMapping
 
 
-def create_schema(db: Database, table_path: DbPath, schema: dict, case_sensitive: bool) -> CaseAwareMapping:
-    logger.debug(f"[{db.name}] Schema = {schema}")
+def create_schema(db_name: str, table_path: DbPath, schema: dict, case_sensitive: bool) -> CaseAwareMapping:
+    logger.debug(f"[{db_name}] Schema = {schema}")
 
     if case_sensitive:
         return CaseSensitiveDict(schema)
 
     if len({k.lower() for k in schema}) < len(schema):
-        logger.warning(f'Ambiguous schema for {db}:{".".join(table_path)} | Columns = {", ".join(list(schema))}')
+        logger.warning(f'Ambiguous schema for {db_name}:{".".join(table_path)} | Columns = {", ".join(list(schema))}')
         logger.warning("We recommend to disable case-insensitivity (set --case-sensitive).")
     return CaseInsensitiveDict(schema)
diff --git a/data_diff/table_segment.py b/data_diff/table_segment.py
index 9864824a..aaf747f6 100644
--- a/data_diff/table_segment.py
+++ b/data_diff/table_segment.py
@@ -142,7 +142,7 @@ def _where(self):
 
     def _with_raw_schema(self, raw_schema: dict) -> Self:
         schema = self.database._process_table_schema(self.table_path, raw_schema, self.relevant_columns, self._where())
-        return self.new(_schema=create_schema(self.database, self.table_path, schema, self.case_sensitive))
+        return self.new(_schema=create_schema(self.database.name, self.table_path, schema, self.case_sensitive))
 
     def with_schema(self) -> Self:
         "Queries the table schema from the database, and returns a new instance of TableSegment, with a schema."
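
With the schema.py change above, create_schema() takes the database name (a plain string) instead of the Database object, which is what lets schema.py drop its `from data_diff import Database` import; callers such as __main__.py and table_segment.py now pass db.name. A hedged usage sketch against the new signature (the table path and column values here are made up):

    from data_diff.schema import create_schema

    raw_schema = {"ID": "int", "Rating": "float"}   # placeholder column info
    table_path = ("my_db", "public", "ratings")     # hypothetical table path

    # The first argument is now a string such as db.name, not a Database instance.
    schema = create_schema("PostgreSQL", table_path, raw_schema, case_sensitive=False)
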
diff --git a/data_diff/utils.py b/data_diff/utils.py
index b725285e..a3ce90cb 100644
--- a/data_diff/utils.py
+++ b/data_diff/utils.py
@@ -1,5 +1,6 @@
 import json
 import logging
+import math
 import re
 import string
 from abc import abstractmethod
@@ -108,7 +109,6 @@ def as_insensitive(self):
         return CaseInsensitiveDict(self)
 
 
-
 # -- Alphanumerics --
 
 alphanums = " -" + string.digits + string.ascii_uppercase + "_" + string.ascii_lowercase
diff --git a/tests/common.py b/tests/common.py
index 222ae94b..cca8e798 100644
--- a/tests/common.py
+++ b/tests/common.py
@@ -149,8 +149,8 @@ def setUp(self):
         self.table_src_name = f"src{table_suffix}"
         self.table_dst_name = f"dst{table_suffix}"
 
-        self.table_src_path = self.connection.parse_table_name(self.table_src_name)
-        self.table_dst_path = self.connection.parse_table_name(self.table_dst_name)
+        self.table_src_path = self.connection.dialect.parse_table_name(self.table_src_name)
+        self.table_dst_path = self.connection.dialect.parse_table_name(self.table_dst_name)
 
         drop_table(self.connection, self.table_src_path)
         drop_table(self.connection, self.table_dst_path)
diff --git a/tests/test_database.py b/tests/test_database.py
index f5998609..c63fd95c 100644
--- a/tests/test_database.py
+++ b/tests/test_database.py
@@ -9,7 +9,13 @@
 from data_diff.queries.api import table, current_timestamp
 from data_diff.queries.extras import NormalizeAsString
 from data_diff.schema import create_schema
-from tests.common import TEST_MYSQL_CONN_STRING, test_each_database_in_list, get_conn, str_to_checksum, random_table_suffix
+from tests.common import (
+    TEST_MYSQL_CONN_STRING,
+    test_each_database_in_list,
+    get_conn,
+    str_to_checksum,
+    random_table_suffix,
+)
 from data_diff.abcs.database_types import TimestampTZ
 
 TEST_DATABASES = {
@@ -39,11 +45,7 @@ def test_connect_to_db(self):
 
 class TestMD5(unittest.TestCase):
     def test_md5_as_int(self):
-        class MD5Dialect(dbs.mysql.Dialect, dbs.mysql.Mixin_MD5):
-            pass
-
         self.mysql = connect(TEST_MYSQL_CONN_STRING)
-        self.mysql.dialect = MD5Dialect()
 
         str = "hello world"
         query_fragment = self.mysql.dialect.md5_as_int("'{0}'".format(str))
@@ -65,7 +67,7 @@ class TestSchema(unittest.TestCase):
     def test_table_list(self):
         name = "tbl_" + random_table_suffix()
         db = get_conn(self.db_cls)
-        tbl = table(db.parse_table_name(name), schema={"id": int})
+        tbl = table(db.dialect.parse_table_name(name), schema={"id": int})
         q = db.dialect.list_tables(db.default_schema, name)
         assert not db.query(q)
 
@@ -79,7 +81,7 @@ def test_type_mapping(self):
         name = "tbl_" + random_table_suffix()
         db = get_conn(self.db_cls)
         tbl = table(
-            db.parse_table_name(name),
+            db.dialect.parse_table_name(name),
             schema={
                 "int": int,
                 "float": float,
@@ -127,7 +129,7 @@ def test_correct_timezone(self):
         t = table(name)
         raw_schema = db.query_table_schema(t.path)
         schema = db._process_table_schema(t.path, raw_schema)
-        schema = create_schema(self.database, t, schema, case_sensitive=True)
+        schema = create_schema(db.name, t, schema, case_sensitive=True)
         t = t.replace(schema=schema)
 
         t.schema["created_at"] = t.schema["created_at"].replace(precision=t.schema["created_at"].precision)
diff --git a/tests/test_database_types.py b/tests/test_database_types.py
index 3d345296..6e0e5215 100644
--- a/tests/test_database_types.py
+++ b/tests/test_database_types.py
@@ -689,8 +689,8 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego
         src_table_name = f"src_{self._testMethodName[11:]}{table_suffix}"
         dst_table_name = f"dst_{self._testMethodName[11:]}{table_suffix}"
 
-        self.src_table_path = src_table_path = src_conn.parse_table_name(src_table_name)
-        self.dst_table_path = dst_table_path = dst_conn.parse_table_name(dst_table_name)
+        self.src_table_path = src_table_path = src_conn.dialect.parse_table_name(src_table_name)
+        self.dst_table_path = dst_table_path = dst_conn.dialect.parse_table_name(dst_table_name)
 
         start = time.monotonic()
         if not BENCHMARK:
diff --git a/tests/test_dbt_parser.py b/tests/test_dbt_parser.py
index 4e8b20d5..4fbdbde1 100644
--- a/tests/test_dbt_parser.py
+++ b/tests/test_dbt_parser.py
@@ -128,7 +128,7 @@ def test_get_run_results_models_bad_lower_dbt_version(self, mock_open, mock_arti
 
         with self.assertRaises(DataDiffDbtRunResultsVersionError) as ex:
             DbtParser.get_run_results_models(mock_self)
-
+
         mock_open.assert_called_once_with(Path(RUN_RESULTS_PATH))
         mock_artifact_parser.assert_called_once_with({})
         self.assertIn("version to be", ex.exception.args[0])
@@ -145,10 +145,10 @@ def test_get_run_results_models_no_success(self, mock_open, mock_artifact_parser
         mock_run_results.metadata.dbt_version = "1.0.0"
         mock_fail_result.unique_id = "fail_unique_id"
         mock_run_results.results = [mock_fail_result]
-
+
         with self.assertRaises(DataDiffDbtNoSuccessfulModelsInRunError):
             DbtParser.get_run_results_models(mock_self)
-
+
         mock_open.assert_any_call(Path(RUN_RESULTS_PATH))
         mock_artifact_parser.assert_called_once_with({})
diff --git a/tests/test_format.py b/tests/test_format.py
index 4743acc4..e7ee132c 100644
--- a/tests/test_format.py
+++ b/tests/test_format.py
@@ -2,7 +2,7 @@
 from data_diff.diff_tables import DiffResultWrapper, InfoTree, SegmentInfo, TableSegment
 from data_diff.format import jsonify
 from data_diff.abcs.database_types import Integer
-from data_diff.databases.base import Database
+from tests.test_query import MockDatabase
 
 
 class TestFormat(unittest.TestCase):
@@ -13,8 +13,12 @@ def test_jsonify_diff(self):
             info_tree=InfoTree(
                 info=SegmentInfo(
                     tables=[
-                        TableSegment(table_path=("db", "schema", "table1"), key_columns=("id",), database=Database()),
-                        TableSegment(table_path=("db", "schema", "table2"), key_columns=("id",), database=Database()),
+                        TableSegment(
+                            table_path=("db", "schema", "table1"), key_columns=("id",), database=MockDatabase()
+                        ),
+                        TableSegment(
+                            table_path=("db", "schema", "table2"), key_columns=("id",), database=MockDatabase()
+                        ),
                     ],
                     diff_schema=(
                         ("is_exclusive_a", bool),
@@ -100,8 +104,12 @@ def test_jsonify_no_stats(self):
             info_tree=InfoTree(
                 info=SegmentInfo(
                     tables=[
-                        TableSegment(table_path=("db", "schema", "table1"), key_columns=("id",), database=Database()),
-                        TableSegment(table_path=("db", "schema", "table2"), key_columns=("id",), database=Database()),
+                        TableSegment(
+                            table_path=("db", "schema", "table1"), key_columns=("id",), database=MockDatabase()
+                        ),
+                        TableSegment(
+                            table_path=("db", "schema", "table2"), key_columns=("id",), database=MockDatabase()
+                        ),
                     ],
                     diff_schema=(
                         ("is_exclusive_a", bool),
@@ -139,7 +147,7 @@ def test_jsonify_no_stats(self):
                 "removed": [],
                 "typeChanged": [],
             },
-            stats_only=True
+            stats_only=True,
         )
 
         self.assertEqual(
@@ -177,8 +185,12 @@ def test_jsonify_diff_no_difeference(self):
             info_tree=InfoTree(
                 info=SegmentInfo(
                     tables=[
-                        TableSegment(table_path=("db", "schema", "table1"), key_columns=("id",), database=Database()),
-                        TableSegment(table_path=("db", "schema", "table2"), key_columns=("id",), database=Database()),
+                        TableSegment(
+                            table_path=("db", "schema", "table1"), key_columns=("id",), database=MockDatabase()
+                        ),
+                        TableSegment(
+                            table_path=("db", "schema", "table2"), key_columns=("id",), database=MockDatabase()
+                        ),
                     ],
                     diff_schema=(
                         ("is_exclusive_a", bool),
@@ -251,8 +263,12 @@ def test_jsonify_column_suffix_fix(self):
             info_tree=InfoTree(
                 info=SegmentInfo(
                     tables=[
-                        TableSegment(table_path=("db", "schema", "table1"), key_columns=("id_a",), database=Database()),
-                        TableSegment(table_path=("db", "schema", "table2"), key_columns=("id_a",), database=Database()),
+                        TableSegment(
+                            table_path=("db", "schema", "table1"), key_columns=("id_a",), database=MockDatabase()
+                        ),
+                        TableSegment(
+                            table_path=("db", "schema", "table2"), key_columns=("id_a",), database=MockDatabase()
+                        ),
                     ],
                     diff_schema=(
                         ("is_exclusive_a", bool),
diff --git a/tests/test_joindiff.py b/tests/test_joindiff.py
index b2c5c419..e7e9ec86 100644
--- a/tests/test_joindiff.py
+++ b/tests/test_joindiff.py
@@ -113,7 +113,7 @@ def test_diff_small_tables(self):
         # self.assertEqual(1, self.differ.stats["table2_min_id"])
 
         # Test materialize
-        materialize_path = self.connection.parse_table_name(f"test_mat_{random_table_suffix()}")
+        materialize_path = self.connection.dialect.parse_table_name(f"test_mat_{random_table_suffix()}")
         mdiffer = self.differ.replace(materialize_to_table=materialize_path)
         diff = list(mdiffer.diff_tables(self.table, self.table2))
         self.assertEqual(expected, diff)
diff --git a/tests/test_query.py b/tests/test_query.py
index bd731cfb..0d253dd5 100644
--- a/tests/test_query.py
+++ b/tests/test_query.py
@@ -1,7 +1,7 @@
 from datetime import datetime
 from typing import List, Optional
 import unittest
-from data_diff.abcs.database_types import AbstractDatabase, AbstractDialect
+from data_diff.databases.base import Database, BaseDialect
 from data_diff.utils import CaseInsensitiveDict, CaseSensitiveDict
 
 from data_diff.databases.base import Compiler, CompileError
@@ -14,7 +14,7 @@ def normalize_spaces(s: str):
     return " ".join(s.split())
 
 
-class MockDialect(AbstractDialect):
+class MockDialect(BaseDialect):
     name = "MockDialect"
 
     PLACEHOLDER_TABLE = None
@@ -66,13 +66,12 @@ def set_timezone_to_utc(self) -> str:
     def optimizer_hints(self, s: str):
         return f"/*+ {s} */ "
 
-    def load_mixins(self):
-        raise NotImplementedError()
-
     parse_type = NotImplemented
 
 
-class MockDatabase(AbstractDatabase):
+class MockDatabase(Database):
+    CONNECT_URI_HELP = "mock://"
+    CONNECT_URI_PARAMS = []
     dialect = MockDialect()
 
     _query = NotImplemented
diff --git a/tests/test_sql.py b/tests/test_sql.py
index 2dcab403..6293d0bd 100644
--- a/tests/test_sql.py
+++ b/tests/test_sql.py
@@ -3,7 +3,8 @@
 from tests.common import TEST_MYSQL_CONN_STRING
 from data_diff.databases import connect
 
-from data_diff.queries.api import Compiler, Count, Explain, Select, table, In, BinOp, Code
+from data_diff.databases.base import Compiler
+from data_diff.queries.api import Count, Explain, Select, table, In, BinOp, Code
 
 
 class TestSQL(unittest.TestCase):