diff --git a/data_diff/sqeleton/databases/_connect.py b/data_diff/sqeleton/databases/_connect.py index 4417a7df..ad152dda 100644 --- a/data_diff/sqeleton/databases/_connect.py +++ b/data_diff/sqeleton/databases/_connect.py @@ -105,11 +105,6 @@ def for_databases(self, *dbs) -> Self: database_by_scheme = {k: db for k, db in self.database_by_scheme.items() if k in dbs} return type(self)(database_by_scheme) - def load_mixins(self, *abstract_mixins: AbstractMixin) -> Self: - "Extend all the databases with a list of mixins that implement the given abstract mixins." - database_by_scheme = {k: db.load_mixins(*abstract_mixins) for k, db in self.database_by_scheme.items()} - return type(self)(database_by_scheme) - def connect_to_uri(self, db_uri: str, thread_count: Optional[int] = 1, **kwargs) -> Database: """Connect to the given database uri diff --git a/data_diff/sqeleton/databases/base.py b/data_diff/sqeleton/databases/base.py index 7b65ad6f..ec41bac4 100644 --- a/data_diff/sqeleton/databases/base.py +++ b/data_diff/sqeleton/databases/base.py @@ -551,14 +551,6 @@ def list_tables(self, tables_like, schema=None): def table(self, *path, **kw): return bound_table(self, path, **kw) - @classmethod - def load_mixins(cls, *abstract_mixins) -> type: - class _DatabaseWithMixins(cls): - dialect = cls.dialect.load_mixins(*abstract_mixins) - - _DatabaseWithMixins.__name__ = cls.__name__ - return _DatabaseWithMixins - class ThreadedDatabase(Database): """Access the database through singleton threads. diff --git a/tests/sqeleton/__init__.py b/tests/sqeleton/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/sqeleton/common.py b/tests/sqeleton/common.py deleted file mode 100644 index f5a29c44..00000000 --- a/tests/sqeleton/common.py +++ /dev/null @@ -1,160 +0,0 @@ -import hashlib -import os -import string -import random -from typing import Callable -import unittest -import logging -import subprocess - -from parameterized import parameterized_class - -import data_diff.sqeleton -from data_diff.sqeleton import databases as db -from data_diff.sqeleton.abcs.mixins import AbstractMixin_NormalizeValue -from data_diff.sqeleton.queries import table -from data_diff.sqeleton.databases import Database -from data_diff.sqeleton.query_utils import drop_table -from tests.common import ( - TEST_MYSQL_CONN_STRING, - TEST_POSTGRESQL_CONN_STRING, - TEST_SNOWFLAKE_CONN_STRING, - TEST_PRESTO_CONN_STRING, - TEST_BIGQUERY_CONN_STRING, - TEST_REDSHIFT_CONN_STRING, - TEST_ORACLE_CONN_STRING, - TEST_DATABRICKS_CONN_STRING, - TEST_TRINO_CONN_STRING, - TEST_CLICKHOUSE_CONN_STRING, - TEST_VERTICA_CONN_STRING, - TEST_DUCKDB_CONN_STRING, - N_THREADS, - TEST_ACROSS_ALL_DBS, - TEST_MSSQL_CONN_STRING, -) - - -def get_git_revision_short_hash() -> str: - return subprocess.check_output(["git", "rev-parse", "--short", "HEAD"]).decode("ascii").strip() - - -GIT_REVISION = get_git_revision_short_hash() - -level = logging.ERROR -if os.environ.get("LOG_LEVEL", False): - level = getattr(logging, os.environ["LOG_LEVEL"].upper()) - -logging.basicConfig(level=level) -logging.getLogger("database").setLevel(level) - -try: - from tests.sqeleton.local_settings import * -except ImportError: - pass # No local settings - - -CONN_STRINGS = { - db.BigQuery: TEST_BIGQUERY_CONN_STRING, - db.MySQL: TEST_MYSQL_CONN_STRING, - db.PostgreSQL: TEST_POSTGRESQL_CONN_STRING, - db.Snowflake: TEST_SNOWFLAKE_CONN_STRING, - db.Redshift: TEST_REDSHIFT_CONN_STRING, - db.Oracle: TEST_ORACLE_CONN_STRING, - db.Presto: TEST_PRESTO_CONN_STRING, - db.Databricks: TEST_DATABRICKS_CONN_STRING, - db.Trino: TEST_TRINO_CONN_STRING, - db.Clickhouse: TEST_CLICKHOUSE_CONN_STRING, - db.Vertica: TEST_VERTICA_CONN_STRING, - db.DuckDB: TEST_DUCKDB_CONN_STRING, - db.MsSQL: TEST_MSSQL_CONN_STRING, -} - -_database_instances = {} - - -def get_conn(cls: type, shared: bool = True) -> Database: - if shared: - if cls not in _database_instances: - _database_instances[cls] = get_conn(cls, shared=False) - return _database_instances[cls] - - con = data_diff.sqeleton.connect.load_mixins(AbstractMixin_NormalizeValue) - return con(CONN_STRINGS[cls], N_THREADS) - - -def _print_used_dbs(): - used = {k.__name__ for k, v in CONN_STRINGS.items() if v is not None} - unused = {k.__name__ for k, v in CONN_STRINGS.items() if v is None} - - print(f"Testing databases: {', '.join(used)}") - if unused: - logging.info(f"Connection not configured; skipping tests for: {', '.join(unused)}") - if TEST_ACROSS_ALL_DBS: - logging.info( - f"Full tests enabled (every db<->db). May take very long when many dbs are involved. ={TEST_ACROSS_ALL_DBS}" - ) - - -_print_used_dbs() -CONN_STRINGS = {k: v for k, v in CONN_STRINGS.items() if v is not None} - - -def random_table_suffix() -> str: - char_set = string.ascii_lowercase + string.digits - suffix = "_" - suffix += "".join(random.choice(char_set) for _ in range(5)) - return suffix - - -def str_to_checksum(str: str): - # hello world - # => 5eb63bbbe01eeed093cb22bb8f5acdc3 - # => cb22bb8f5acdc3 - # => 273350391345368515 - m = hashlib.md5() - m.update(str.encode("utf-8")) # encode to binary - md5 = m.hexdigest() - # 0-indexed, unlike DBs which are 1-indexed here, so +1 in dbs - half_pos = db.MD5_HEXDIGITS - db.CHECKSUM_HEXDIGITS - return int(md5[half_pos:], 16) - - -class DbTestCase(unittest.TestCase): - "Sets up a table for testing" - db_cls = None - table1_schema = None - shared_connection = True - - def setUp(self): - assert self.db_cls, self.db_cls - - self.connection = get_conn(self.db_cls, self.shared_connection) - - table_suffix = random_table_suffix() - self.table1_name = f"src{table_suffix}" - - self.table1_path = self.connection.parse_table_name(self.table1_name) - - drop_table(self.connection, self.table1_path) - - self.src_table = table(self.table1_path, schema=self.table1_schema) - if self.table1_schema: - self.connection.query(self.src_table.create()) - - return super().setUp() - - def tearDown(self): - drop_table(self.connection, self.table1_path) - - -def _parameterized_class_per_conn(test_databases): - test_databases = set(test_databases) - names = [(cls.__name__, cls) for cls in CONN_STRINGS if cls in test_databases] - return parameterized_class(("name", "db_cls"), names) - - -def test_each_database_in_list(databases) -> Callable: - def _test_per_database(cls): - return _parameterized_class_per_conn(databases)(cls) - - return _test_per_database diff --git a/tests/sqeleton/test_mixins.py b/tests/sqeleton/test_mixins.py deleted file mode 100644 index 02ee8b3e..00000000 --- a/tests/sqeleton/test_mixins.py +++ /dev/null @@ -1,36 +0,0 @@ -import unittest - -from data_diff.sqeleton import connect - -from data_diff.sqeleton.abcs import AbstractDialect, AbstractDatabase -from data_diff.sqeleton.abcs.mixins import ( - AbstractMixin_NormalizeValue, - AbstractMixin_RandomSample, - AbstractMixin_TimeTravel, -) - - -class TestMixins(unittest.TestCase): - def test_normalize(self): - # - Test sanity - ddb1 = connect("duckdb://:memory:") - assert not hasattr(ddb1.dialect, "normalize_boolean") - - # - Test abstract mixins - class NewAbstractDialect(AbstractDialect, AbstractMixin_NormalizeValue, AbstractMixin_RandomSample): - pass - - new_connect = connect.load_mixins(AbstractMixin_NormalizeValue, AbstractMixin_RandomSample) - ddb2: AbstractDatabase[NewAbstractDialect] = new_connect("duckdb://:memory:") - # Implementation may change; Just update the test - assert ddb2.dialect.normalize_boolean("bool", None) == "bool::INTEGER::VARCHAR" - assert ddb2.dialect.random_sample_n("x", 10) - - # - Test immutability - ddb3 = connect("duckdb://:memory:") - assert not hasattr(ddb3.dialect, "normalize_boolean") - - self.assertRaises(TypeError, connect.load_mixins, AbstractMixin_TimeTravel) - - new_connect = connect.for_databases("bigquery", "snowflake").load_mixins(AbstractMixin_TimeTravel) - self.assertRaises(NotImplementedError, new_connect, "duckdb://:memory:") diff --git a/tests/sqeleton/test_database.py b/tests/test_database.py similarity index 97% rename from tests/sqeleton/test_database.py rename to tests/test_database.py index 5faa9abf..1b967cc8 100644 --- a/tests/sqeleton/test_database.py +++ b/tests/test_database.py @@ -7,8 +7,7 @@ from data_diff.sqeleton import connect from data_diff.sqeleton import databases as dbs from data_diff.sqeleton.queries import table, current_timestamp, NormalizeAsString -from tests.common import TEST_MYSQL_CONN_STRING -from tests.sqeleton.common import str_to_checksum, test_each_database_in_list, get_conn, random_table_suffix +from tests.common import TEST_MYSQL_CONN_STRING, test_each_database_in_list, get_conn, str_to_checksum, random_table_suffix from data_diff.sqeleton.abcs.database_types import TimestampTZ TEST_DATABASES = { diff --git a/tests/sqeleton/test_query.py b/tests/test_query.py similarity index 100% rename from tests/sqeleton/test_query.py rename to tests/test_query.py diff --git a/tests/sqeleton/test_sql.py b/tests/test_sql.py similarity index 100% rename from tests/sqeleton/test_sql.py rename to tests/test_sql.py diff --git a/tests/sqeleton/test_utils.py b/tests/test_utils.py similarity index 100% rename from tests/sqeleton/test_utils.py rename to tests/test_utils.py