diff --git a/.gitignore b/.gitignore
index 644d7186..e1f0a901 100644
--- a/.gitignore
+++ b/.gitignore
@@ -149,4 +149,6 @@ benchmark_*.png
 .vscode
 
 # History
-.history
\ No newline at end of file
+.history
+
+docker-compose-local.yml
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 03eaed22..95b1e7cf 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -66,7 +66,7 @@ Make sure to update the appropriate `TEST_*_CONN_STRING`, so that it will be inc
 
 #### Run the tests
 
-You can run the tests with `unittest`.
+You can run the tests with `python -m unittest`.
 
 When running against multiple databases, the tests can take a long while.
 
@@ -111,6 +111,15 @@ $ poetry install # Install dependencies
 $ docker-compose up -d mysql postgres # run mysql and postgres dbs in background
 ```
 
+If you want to change the docker-compose configuration before running the DB containers, copy docker-compose.yml to docker-compose-local.yml, make your changes, and run:
+```shell-session
+$ cp docker-compose.yml docker-compose-local.yml
+$ docker-compose -f docker-compose-local.yml up -d mysql postgres # run mysql and postgres dbs in background
+```
+You will also have to set up `tests/local_settings.py`, where the `TEST_*_CONN_STRING` values can be edited.
+
+`docker-compose-local.yml` and `tests/local_settings.py` are git-ignored, so they should not show up in git changes.
+
 [docker-compose]: https://docs.docker.com/compose/install/
 
 **3. Run Unit Tests**
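A minimal `tests/local_settings.py` to pair with `docker-compose-local.yml` might look like the sketch below. The variable names follow the `TEST_*_CONN_STRING` pattern used in `tests/common.py`; the hosts, ports, and credentials are placeholders that must match whatever your local containers expose:

```python
# tests/local_settings.py -- git-ignored local overrides for the test suite.
# Placeholder values; point them at the databases from docker-compose-local.yml.
TEST_MYSQL_CONN_STRING = "mysql://mysql:Password1@localhost:3306/mysql"
TEST_POSTGRESQL_CONN_STRING = "postgresql://postgres:Password1@localhost:5432/postgres"
```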
diff --git a/data_diff/__main__.py b/data_diff/__main__.py
index 02f3b31d..ffd9bfa7 100644
--- a/data_diff/__main__.py
+++ b/data_diff/__main__.py
@@ -1,34 +1,32 @@
-from copy import deepcopy
-from datetime import datetime
+import json
+import logging
 import os
 import sys
 import time
-import json
-import logging
+from copy import deepcopy
+from datetime import datetime
 from itertools import islice
-from typing import Dict, Optional, Tuple
+from typing import Dict, Optional, Tuple, Union, List, Set
 
+import click
 import rich
 from rich.logging import RichHandler
-import click
 
 from data_diff import Database, DbPath
-from data_diff.schema import RawColumnInfo, create_schema
-from data_diff.queries.api import current_timestamp
-
+from data_diff.config import apply_config_from_file
+from data_diff.databases._connect import connect
 from data_diff.dbt import dbt_diff
-from data_diff.utils import eval_name_template, remove_password_from_url, safezip, match_like, LogStatusHandler
-from data_diff.diff_tables import Algorithm
+from data_diff.diff_tables import Algorithm, TableDiffer
 from data_diff.hashdiff_tables import HashDiffer, DEFAULT_BISECTION_THRESHOLD, DEFAULT_BISECTION_FACTOR
 from data_diff.joindiff_tables import TABLE_WRITE_LIMIT, JoinDiffer
-from data_diff.table_segment import TableSegment
-from data_diff.databases._connect import connect
 from data_diff.parse_time import parse_time_before, UNITS_STR, ParseError
-from data_diff.config import apply_config_from_file
+from data_diff.queries.api import current_timestamp
+from data_diff.schema import RawColumnInfo, create_schema
+from data_diff.table_segment import TableSegment
 from data_diff.tracking import disable_tracking, set_entrypoint_name
+from data_diff.utils import eval_name_template, remove_password_from_url, safezip, match_like, LogStatusHandler
 from data_diff.version import __version__
 
-
 COLOR_SCHEME = {
     "+": "green",
     "-": "red",
@@ -347,6 +345,144 @@ def main(conf, run, **kw) -> None:
             raise
 
 
+def _get_dbs(
+    threads: int, database1: str, threads1: int, database2: str, threads2: int, interactive: bool
+) -> Tuple[Database, Database]:
+    db1 = connect(database1, threads1 or threads)
+    if database1 == database2:
+        db2 = db1
+    else:
+        db2 = connect(database2, threads2 or threads)
+
+    if interactive:
+        db1.enable_interactive()
+        db2.enable_interactive()
+
+    return db1, db2
+
+
+def _set_age(options: dict, min_age: Optional[str], max_age: Optional[str], db: Database) -> None:
+    if min_age or max_age:
+        now: datetime = db.query(current_timestamp(), datetime).replace(tzinfo=None)
+        try:
+            if max_age:
+                options["min_update"] = parse_time_before(now, max_age)
+            if min_age:
+                options["max_update"] = parse_time_before(now, min_age)
+        except ParseError as e:
+            logging.error(f"Error while parsing age expression: {e}")
+
+
+def _get_table_differ(
+    algorithm: str,
+    db1: Database,
+    db2: Database,
+    threaded: bool,
+    threads: int,
+    assume_unique_key: bool,
+    sample_exclusive_rows: bool,
+    materialize_all_rows: bool,
+    table_write_limit: int,
+    materialize_to_table: Optional[str],
+    bisection_factor: Optional[int],
+    bisection_threshold: Optional[int],
+) -> TableDiffer:
+    algorithm = Algorithm(algorithm)
+    if algorithm == Algorithm.AUTO:
+        algorithm = Algorithm.JOINDIFF if db1 == db2 else Algorithm.HASHDIFF
+
+    logging.info(f"Using algorithm '{algorithm.name.lower()}'.")
+
+    if algorithm == Algorithm.JOINDIFF:
+        return JoinDiffer(
+            threaded=threaded,
+            max_threadpool_size=threads and threads * 2,
+            validate_unique_key=not assume_unique_key,
+            sample_exclusive_rows=sample_exclusive_rows,
+            materialize_all_rows=materialize_all_rows,
+            table_write_limit=table_write_limit,
+            materialize_to_table=(
+                materialize_to_table and db1.dialect.parse_table_name(eval_name_template(materialize_to_table))
+            ),
+        )
+
+    assert algorithm == Algorithm.HASHDIFF
+    return HashDiffer(
+        bisection_factor=DEFAULT_BISECTION_FACTOR if bisection_factor is None else bisection_factor,
+        bisection_threshold=DEFAULT_BISECTION_THRESHOLD if bisection_threshold is None else bisection_threshold,
+        threaded=threaded,
+        max_threadpool_size=threads and threads * 2,
+    )
+
+
+def _print_result(stats, json_output, diff_iter) -> None:
+    if stats:
+        if json_output:
+            rich.print(json.dumps(diff_iter.get_stats_dict()))
+        else:
+            rich.print(diff_iter.get_stats_string())
+
+    else:
+        for op, values in diff_iter:
+            color = COLOR_SCHEME.get(op, "grey62")
+
+            if json_output:
+                jsonl = json.dumps([op, list(values)])
+                rich.print(f"[{color}]{jsonl}[/{color}]")
+            else:
+                text = f"{op} {', '.join(map(str, values))}"
+                rich.print(f"[{color}]{text}[/{color}]")
+
+            sys.stdout.flush()
+
+
+def _get_expanded_columns(
+    columns: List[str],
+    case_sensitive: bool,
+    mutual: Set[str],
+    db1: Database,
+    schema1: dict,
+    table1: str,
+    db2: Database,
+    schema2: dict,
+    table2: str,
+) -> Set[str]:
+    expanded_columns: Set[str] = set()
+    for c in columns:
+        cc = c if case_sensitive else c.lower()
+        match = set(match_like(cc, mutual))
+        if not match:
+            m1 = None if any(match_like(cc, schema1.keys())) else f"{db1}/{table1}"
+            m2 = None if any(match_like(cc, schema2.keys())) else f"{db2}/{table2}"
+            not_matched = ", ".join(m for m in [m1, m2] if m)
+            raise ValueError(f"Column '{c}' not found in: {not_matched}")
+
+        expanded_columns |= match
+    return expanded_columns
+
+
+def _get_threads(threads: Union[int, str, None], threads1: Optional[int], threads2: Optional[int]) -> Tuple[bool, int]:
+    threaded = True
+    if threads is None:
+        threads = 1
+    elif isinstance(threads, str) and threads.lower() == "serial":
+        assert not (threads1 or threads2)
+        threaded = False
+        threads = 1
+    else:
+        try:
+            threads = int(threads)
+        except ValueError:
+            logging.error("Error: threads must be a number, or 'serial'.")
+            raise
+
+    if threads < 1:
+        logging.error("Error: threads must be >= 1")
+        raise ValueError("Error: threads must be >= 1")
+
+    return threaded, threads
+
+
 def _data_diff(
     database1,
     table1,
@@ -393,26 +529,7 @@ def _data_diff(
         return
 
     key_columns = key_columns or ("id",)
-    bisection_factor = DEFAULT_BISECTION_FACTOR if bisection_factor is None else int(bisection_factor)
-    bisection_threshold = DEFAULT_BISECTION_THRESHOLD if bisection_threshold is None else int(bisection_threshold)
-
-    threaded = True
-    if threads is None:
-        threads = 1
-    elif isinstance(threads, str) and threads.lower() == "serial":
-        assert not (threads1 or threads2)
-        threaded = False
-        threads = 1
-    else:
-        try:
-            threads = int(threads)
-        except ValueError:
-            logging.error("Error: threads must be a number, or 'serial'.")
-            return
-        if threads < 1:
-            logging.error("Error: threads must be >= 1")
-            return
-
+    threaded, threads = _get_threads(threads, threads1, threads2)
     start = time.monotonic()
 
     if database1 is None or database2 is None:
@@ -421,133 +538,79 @@ def _data_diff(
         )
         return
 
-    db1 = connect(database1, threads1 or threads)
-    if database1 == database2:
-        db2 = db1
-    else:
-        db2 = connect(database2, threads2 or threads)
-
-    options = dict(
-        case_sensitive=case_sensitive,
-        where=where,
-    )
-
-    if min_age or max_age:
-        now: datetime = db1.query(current_timestamp(), datetime)
-        now = now.replace(tzinfo=None)
-        try:
-            if max_age:
-                options["min_update"] = parse_time_before(now, max_age)
-            if min_age:
-                options["max_update"] = parse_time_before(now, min_age)
-        except ParseError as e:
-            logging.error(f"Error while parsing age expression: {e}")
-            return
-
-    dbs: Tuple[Database, Database] = db1, db2
-
-    if interactive:
-        for db in dbs:
-            db.enable_interactive()
-
-    algorithm = Algorithm(algorithm)
-    if algorithm == Algorithm.AUTO:
-        algorithm = Algorithm.JOINDIFF if db1 == db2 else Algorithm.HASHDIFF
-
-    if algorithm == Algorithm.JOINDIFF:
-        differ = JoinDiffer(
-            threaded=threaded,
-            max_threadpool_size=threads and threads * 2,
-            validate_unique_key=not assume_unique_key,
-            sample_exclusive_rows=sample_exclusive_rows,
-            materialize_all_rows=materialize_all_rows,
-            table_write_limit=table_write_limit,
-            materialize_to_table=materialize_to_table
-            and db1.dialect.parse_table_name(eval_name_template(materialize_to_table)),
-        )
-    else:
-        assert algorithm == Algorithm.HASHDIFF
-        differ = HashDiffer(
-            bisection_factor=bisection_factor,
-            bisection_threshold=bisection_threshold,
-            threaded=threaded,
-            max_threadpool_size=threads and threads * 2,
+    db1: Database
+    db2: Database
+    db1, db2 = _get_dbs(threads, database1, threads1, database2, threads2, interactive)
+    with db1, db2:
+        options = {
+            "case_sensitive": case_sensitive,
+            "where": where,
+        }
+
+        _set_age(options, min_age, max_age, db1)
+        dbs: Tuple[Database, Database] = db1, db2
+
+        differ = _get_table_differ(
+            algorithm,
+            db1,
+            db2,
+            threaded,
+            threads,
+            assume_unique_key,
+            sample_exclusive_rows,
+            materialize_all_rows,
+            table_write_limit,
+            materialize_to_table,
+            bisection_factor,
+            bisection_threshold,
         )
 
-    table_names = table1, table2
-    table_paths = [db.dialect.parse_table_name(t) for db, t in safezip(dbs, table_names)]
+        table_names = table1, table2
+        table_paths = [db.dialect.parse_table_name(t) for db, t in safezip(dbs, table_names)]
 
-    schemas = list(differ._thread_map(_get_schema, safezip(dbs, table_paths)))
-    schema1, schema2 = schemas = [
-        create_schema(db.name, table_path, schema, case_sensitive)
-        for db, table_path, schema in safezip(dbs, table_paths, schemas)
-    ]
+        schemas = list(differ._thread_map(_get_schema, safezip(dbs, table_paths)))
+        schema1, schema2 = schemas = [
+            create_schema(db.name, table_path, schema, case_sensitive)
+            for db, table_path, schema in safezip(dbs, table_paths, schemas)
+        ]
 
-    mutual = schema1.keys() & schema2.keys()  # Case-aware, according to case_sensitive
-    logging.debug(f"Available mutual columns: {mutual}")
-
-    expanded_columns = set()
-    for c in columns:
-        cc = c if case_sensitive else c.lower()
-        match = set(match_like(cc, mutual))
-        if not match:
-            m1 = None if any(match_like(cc, schema1.keys())) else f"{db1}/{table1}"
-            m2 = None if any(match_like(cc, schema2.keys())) else f"{db2}/{table2}"
-            not_matched = ", ".join(m for m in [m1, m2] if m)
-            raise ValueError(f"Column '{c}' not found in: {not_matched}")
+        mutual = schema1.keys() & schema2.keys()  # Case-aware, according to case_sensitive
+        logging.debug(f"Available mutual columns: {mutual}")
 
-        expanded_columns |= match
-
-    columns = tuple(expanded_columns - {*key_columns, update_column})
-
-    if db1 == db2:
-        diff_schemas(
-            table_names[0],
-            table_names[1],
-            schema1,
-            schema2,
-            (
-                *key_columns,
-                update_column,
-                *columns,
-            ),
+        expanded_columns = _get_expanded_columns(
+            columns, case_sensitive, mutual, db1, schema1, table1, db2, schema2, table2
         )
+        columns = tuple(expanded_columns - {*key_columns, update_column})
+
+        if db1 == db2:
+            diff_schemas(
+                table_names[0],
+                table_names[1],
+                schema1,
+                schema2,
+                (
+                    *key_columns,
+                    update_column,
+                    *columns,
+                ),
+            )
 
-    logging.info(f"Diffing using columns: key={key_columns} update={update_column} extra={columns}.")
-    logging.info(f"Using algorithm '{algorithm.name.lower()}'.")
-
-    segments = [
-        TableSegment(db, table_path, key_columns, update_column, columns, **options)._with_raw_schema(raw_schema)
-        for db, table_path, raw_schema in safezip(dbs, table_paths, schemas)
-    ]
-
-    diff_iter = differ.diff_tables(*segments)
-
-    if limit:
-        assert not stats
-        diff_iter = islice(diff_iter, int(limit))
+        logging.info(f"Diffing using columns: key={key_columns} update={update_column} extra={columns}.")
 
-    if stats:
-        if json_output:
-            rich.print(json.dumps(diff_iter.get_stats_dict()))
-        else:
-            rich.print(diff_iter.get_stats_string())
+        segments = [
+            TableSegment(db, table_path, key_columns, update_column, columns, **options)._with_raw_schema(raw_schema)
+            for db, table_path, raw_schema in safezip(dbs, table_paths, schemas)
+        ]
 
-    else:
-        for op, values in diff_iter:
-            color = COLOR_SCHEME.get(op, "grey62")
+        diff_iter = differ.diff_tables(*segments)
 
-            if json_output:
-                jsonl = json.dumps([op, list(values)])
-                rich.print(f"[{color}]{jsonl}[/{color}]")
-            else:
-                text = f"{op} {', '.join(map(str, values))}"
-                rich.print(f"[{color}]{text}[/{color}]")
+        if limit:
+            assert not stats
+            diff_iter = islice(diff_iter, int(limit))
 
-            sys.stdout.flush()
+        _print_result(stats, json_output, diff_iter)
 
     end = time.monotonic()
-
     logging.info(f"Duration: {end-start:.2f} seconds.")
diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py
index d6549f71..8adea0b5 100644
--- a/data_diff/databases/base.py
+++ b/data_diff/databases/base.py
@@ -938,6 +938,12 @@ class Database(abc.ABC):
     is_closed: bool = False
     _dialect: BaseDialect = None
 
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.close()
+
     @property
     def name(self):
         return type(self).__name__
@@ -1058,7 +1064,7 @@ def query_table_schema(self, path: DbPath) -> Dict[str, RawColumnInfo]:
         return d
 
     def select_table_unique_columns(self, path: DbPath) -> str:
-        "Provide SQL for selecting the names of unique columns in the table"
+        """Provide SQL for selecting the names of unique columns in the table"""
         schema, name = self._normalize_table_path(path)
 
         return (
@@ -1180,9 +1186,8 @@ def _query_conn(self, conn, sql_code: Union[str, ThreadLocalInterpreter]) -> Que
         return apply_query(callback, sql_code)
 
     def close(self):
-        "Close connection(s) to the database instance. Querying will stop functioning."
+        """Close connection(s) to the database instance. Querying will stop functioning."""
         self.is_closed = True
-        return super().close()
 
     @property
     def dialect(self) -> BaseDialect:
@@ -1241,18 +1246,20 @@ def _query(self, sql_code: Union[str, ThreadLocalInterpreter]) -> QueryResult:
         return r.result()
 
     def _query_in_worker(self, sql_code: Union[str, ThreadLocalInterpreter]):
-        "This method runs in a worker thread"
+        """This method runs in a worker thread"""
        if self._init_error:
             raise self._init_error
         return self._query_conn(self.thread_local.conn, sql_code)
 
     @abstractmethod
     def create_connection(self):
-        "Return a connection instance, that supports the .cursor() method."
+        """Return a connection instance, that supports the .cursor() method."""
 
     def close(self):
         super().close()
         self._queue.shutdown()
+        if hasattr(self.thread_local, "conn"):
+            self.thread_local.conn.close()
 
     @property
     def is_autocommit(self) -> bool:
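The `__enter__`/`__exit__` pair added above is what lets the refactored CLI wrap its work in `with db1, db2:`. A minimal usage sketch — the connection string is a placeholder, and the `query` call mirrors the one in `TestCloseMethod` further down:

```python
from data_diff import connect

# Placeholder DSN; any of your TEST_*_CONN_STRING values works here.
with connect("postgresql://postgres:Password1@localhost:5432/postgres") as db:
    db.query("SELECT 1")  # fine while the connection is open

# __exit__ has called close(), so the connection is gone and further
# queries raise -- exactly what TestCloseMethod asserts.
```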
diff --git a/data_diff/databases/mssql.py b/data_diff/databases/mssql.py
index caa1d0b2..59d8a49d 100644
--- a/data_diff/databases/mssql.py
+++ b/data_diff/databases/mssql.py
@@ -25,7 +25,7 @@
     Text,
     Boolean,
     Date,
-    Time
+    Time,
 )
 
 
diff --git a/data_diff/databases/postgresql.py b/data_diff/databases/postgresql.py
index 631b5c49..d93a46d7 100644
--- a/data_diff/databases/postgresql.py
+++ b/data_diff/databases/postgresql.py
@@ -17,7 +17,7 @@
     FractionalType,
     Boolean,
     Date,
-    Time
+    Time,
 )
 from data_diff.databases.base import BaseDialect, ThreadedDatabase, import_helper, ConnectError
 from data_diff.databases.base import (
@@ -251,3 +251,8 @@ def _normalize_table_path(self, path: DbPath) -> DbPath:
         raise ValueError(
             f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected format: table, schema.table, or database.schema.table"
         )
+
+    def close(self):
+        super().close()
+        if self._conn is not None:
+            self._conn.close()
diff --git a/data_diff/databases/vertica.py b/data_diff/databases/vertica.py
index cfe046d2..9b1de1f4 100644
--- a/data_diff/databases/vertica.py
+++ b/data_diff/databases/vertica.py
@@ -146,8 +146,7 @@ def __init__(self, *, thread_count, **kw) -> None:
     def create_connection(self):
         vertica = import_vertica()
         try:
-            c = vertica.connect(**self._args)
-            return c
+            return vertica.connect(**self._args)
         except vertica.errors.ConnectionError as e:
             raise ConnectError(*e.args) from e
 
diff --git a/tests/cloud/test_data_source.py b/tests/cloud/test_data_source.py
index cc524a32..1e4a8129 100644
--- a/tests/cloud/test_data_source.py
+++ b/tests/cloud/test_data_source.py
@@ -22,7 +22,7 @@
     _test_data_source,
 )
 from data_diff.dbt_parser import TDatadiffConfig
-
+from tests.common import ansi_stdout_cleanup
 
 DATA_SOURCE_CONFIGS = {
     "snowflake": TDsConfig(
@@ -262,7 +262,7 @@ def test_create_ds_snowflake_config_from_dbt_profiles_one_param_passed_through_i
         )
         self.assertEqual(actual_config, config)
         self.assertEqual(
-            mock_stdout.getvalue().strip(),
+            ansi_stdout_cleanup(mock_stdout.getvalue().strip()),
             'Cannot extract "account" from dbt profiles.yml. Please, type it manually',
         )
 
@@ -294,7 +294,7 @@ def test_create_ds_config_validate_required_parameter(self, mock_stdout):
             data_source_name=config.name,
         )
         self.assertEqual(actual_config, config)
-        self.assertEqual(mock_stdout.getvalue().strip(), "Parameter must not be empty")
+        self.assertEqual(ansi_stdout_cleanup(mock_stdout.getvalue().strip()), "Parameter must not be empty")
 
     def test_check_data_source_exists(self):
         self.assertEqual(_check_data_source_exists(self.data_sources, self.data_sources[0].name), self.data_sources[0])
diff --git a/tests/common.py b/tests/common.py
index 2fc6be19..d6dd94e1 100644
--- a/tests/common.py
+++ b/tests/common.py
@@ -1,5 +1,6 @@
 import hashlib
 import os
+import re
 import string
 import random
 from typing import Callable
@@ -84,15 +85,8 @@ def get_git_revision_short_hash() -> str:
     db.MsSQL: TEST_MSSQL_CONN_STRING,
 }
 
-_database_instances = {}
-
-
-def get_conn(cls: type, shared: bool = True) -> Database:
-    if shared:
-        if cls not in _database_instances:
-            _database_instances[cls] = get_conn(cls, shared=False)
-        return _database_instances[cls]
-
+def get_conn(cls: type) -> Database:
     return connect(CONN_STRINGS[cls], N_THREADS)
 
 
@@ -134,17 +128,16 @@ def str_to_checksum(str: str):
 
 
 class DiffTestCase(unittest.TestCase):
-    "Sets up two tables for diffing"
+    """Sets up two tables for diffing"""
 
     db_cls = None
     src_schema = None
     dst_schema = None
-    shared_connection = True
 
     def setUp(self):
         assert self.db_cls, self.db_cls
 
-        self.connection = get_conn(self.db_cls, self.shared_connection)
+        self.connection = get_conn(self.db_cls)
 
         table_suffix = random_table_suffix()
         self.table_src_name = f"src{table_suffix}"
@@ -187,3 +180,7 @@ def table_segment(database, table_path, key_columns, *args, **kw):
     if isinstance(key_columns, str):
         key_columns = (key_columns,)
     return TableSegment(database, table_path, key_columns, *args, **kw)
+
+
+def ansi_stdout_cleanup(ansi_input) -> str:
+    return re.sub(r"\x1B\[[0-?]*[ -/]*[@-~]", "", ansi_input)
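A quick illustration of what `ansi_stdout_cleanup` removes — the sample string is invented, but `\x1b[31m` (red) and `\x1b[0m` (reset) are standard SGR escape sequences matched by the regex above, which is why the exact-match assertions in `tests/cloud/test_data_source.py` now pass against rich-colored output:

```python
from tests.common import ansi_stdout_cleanup

colored = "\x1b[31mParameter must not be empty\x1b[0m"
assert ansi_stdout_cleanup(colored) == "Parameter must not be empty"
```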
diff --git a/tests/test_database.py b/tests/test_database.py
index 2713ac16..f5055ce7 100644
--- a/tests/test_database.py
+++ b/tests/test_database.py
@@ -5,8 +5,9 @@
 import attrs
 import pytz
 
-from data_diff import connect
+from data_diff import connect, Database
 from data_diff import databases as dbs
+from data_diff.abcs.database_types import TimestampTZ
 from data_diff.queries.api import table, current_timestamp
 from data_diff.queries.extras import NormalizeAsString
 from data_diff.schema import create_schema
@@ -17,7 +18,6 @@
     str_to_checksum,
     random_table_suffix,
 )
-from data_diff.abcs.database_types import TimestampTZ
 
 TEST_DATABASES = {
     dbs.MySQL,
@@ -48,11 +48,11 @@ class TestMD5(unittest.TestCase):
     def test_md5_as_int(self):
         self.mysql = connect(TEST_MYSQL_CONN_STRING)
 
-        str = "hello world"
-        query_fragment = self.mysql.dialect.md5_as_int("'{0}'".format(str))
+        message = "hello world"
+        query_fragment = self.mysql.dialect.md5_as_int(f"'{message}'")
         query = f"SELECT {query_fragment}"
 
-        self.assertEqual(str_to_checksum(str), self.mysql.query(query, int))
+        self.assertEqual(str_to_checksum(message), self.mysql.query(query, int))
 
 
 class TestConnect(unittest.TestCase):
@@ -74,42 +74,48 @@ def test_correct_timezone(self):
         if self.db_cls in [dbs.MsSQL]:
             self.skipTest("No support for session tz.")
         name = "tbl_" + random_table_suffix()
-        db = get_conn(self.db_cls)
-        tbl = table(name, schema={"id": int, "created_at": TimestampTZ(9), "updated_at": TimestampTZ(9)})
-        db.query(tbl.create())
+        db_connection = get_conn(self.db_cls)
+        with db_connection:
+            tbl = table(name, schema={"id": int, "created_at": TimestampTZ(9), "updated_at": TimestampTZ(9)})
+
+            db_connection.query(tbl.create())
 
-        tz = pytz.timezone("Europe/Berlin")
+            tz = pytz.timezone("Europe/Berlin")
 
-        now = datetime.now(tz)
-        if isinstance(db, dbs.Presto):
-            ms = now.microsecond // 1000 * 1000  # Presto max precision is 3
-            now = now.replace(microsecond=ms)
+            now = datetime.now(tz)
+            if isinstance(db_connection, dbs.Presto):
+                ms = now.microsecond // 1000 * 1000  # Presto max precision is 3
+                now = now.replace(microsecond=ms)
 
-        db.query(table(name).insert_row(1, now, now))
-        db.query(db.dialect.set_timezone_to_utc())
+            db_connection.query(table(name).insert_row(1, now, now))
+            db_connection.query(db_connection.dialect.set_timezone_to_utc())
 
-        t = table(name)
-        raw_schema = db.query_table_schema(t.path)
-        schema = db._process_table_schema(t.path, raw_schema)
-        schema = create_schema(db.name, t, schema, case_sensitive=True)
-        t = attrs.evolve(t, schema=schema)
-        t.schema["created_at"] = attrs.evolve(t.schema["created_at"], precision=t.schema["created_at"].precision)
+            table_object = table(name)
+            raw_schema = db_connection.query_table_schema(table_object.path)
+            schema = db_connection._process_table_schema(table_object.path, raw_schema)
+            schema = create_schema(db_connection.name, table_object, schema, case_sensitive=True)
+            table_object = attrs.evolve(table_object, schema=schema)
+            table_object.schema["created_at"] = attrs.evolve(
+                table_object.schema["created_at"], precision=table_object.schema["created_at"].precision
+            )
 
-        tbl = table(name, schema=t.schema)
+            tbl = table(name, schema=table_object.schema)
 
-        results = db.query(tbl.select(NormalizeAsString(tbl[c]) for c in ["created_at", "updated_at"]), List[Tuple])
+            results = db_connection.query(
+                tbl.select(NormalizeAsString(tbl[c]) for c in ["created_at", "updated_at"]), List[Tuple]
+            )
 
-        created_at = results[0][1]
-        updated_at = results[0][1]
+            created_at = results[0][1]
+            updated_at = results[0][1]
 
-        utc = now.astimezone(pytz.UTC)
-        expected = utc.__format__("%Y-%m-%d %H:%M:%S.%f")
+            utc = now.astimezone(pytz.UTC)
+            expected = utc.__format__("%Y-%m-%d %H:%M:%S.%f")
 
-        self.assertEqual(created_at, expected)
-        self.assertEqual(updated_at, expected)
+            self.assertEqual(created_at, expected)
+            self.assertEqual(updated_at, expected)
 
-        db.query(tbl.drop())
+            db_connection.query(tbl.drop())
 
 
 @test_each_database
@@ -119,51 +125,77 @@ def test_three_part_support(self):
             self.skipTest("Limited support for 3 part ids")
         table_name = "tbl_" + random_table_suffix()
 
-        db = get_conn(self.db_cls)
-        db_res = db.query(f"SELECT {db.dialect.current_database()}")
-        schema_res = db.query(f"SELECT {db.dialect.current_schema()}")
-        db_name = db_res.rows[0][0]
-        schema_name = schema_res.rows[0][0]
+        db_connection = get_conn(self.db_cls)
+        with db_connection:
+            db_res = db_connection.query(f"SELECT {db_connection.dialect.current_database()}")
+            schema_res = db_connection.query(f"SELECT {db_connection.dialect.current_schema()}")
+            db_name = db_res.rows[0][0]
+            schema_name = schema_res.rows[0][0]
 
-        table_one_part = table((table_name,), schema={"id": int})
-        table_two_part = table((schema_name, table_name), schema={"id": int})
-        table_three_part = table((db_name, schema_name, table_name), schema={"id": int})
+            table_one_part = table((table_name,), schema={"id": int})
+            table_two_part = table((schema_name, table_name), schema={"id": int})
+            table_three_part = table((db_name, schema_name, table_name), schema={"id": int})
 
-        for part in (table_one_part, table_two_part, table_three_part):
-            db.query(part.create())
-            d = db.query_table_schema(part.path)
-            assert len(d) == 1
-            db.query(part.drop())
+            for part in (table_one_part, table_two_part, table_three_part):
+                db_connection.query(part.create())
+                schema = db_connection.query_table_schema(part.path)
+                assert len(schema) == 1
+                db_connection.query(part.drop())
 
 
 @test_each_database
 class TestNumericPrecisionParsing(unittest.TestCase):
     def test_specified_precision(self):
         name = "tbl_" + random_table_suffix()
-        db = get_conn(self.db_cls)
-        tbl = table(name, schema={"value": "DECIMAL(10, 2)"})
-        db.query(tbl.create())
-        t = table(name)
-        raw_schema = db.query_table_schema(t.path)
-        schema = db._process_table_schema(t.path, raw_schema)
-        self.assertEqual(schema["value"].precision, 2)
+        db_connection = get_conn(self.db_cls)
+        with db_connection:
+            table_object = table(name, schema={"value": "DECIMAL(10, 2)"})
+            db_connection.query(table_object.create())
+            table_object = table(name)
+            raw_schema = db_connection.query_table_schema(table_object.path)
+            schema = db_connection._process_table_schema(table_object.path, raw_schema)
+            self.assertEqual(schema["value"].precision, 2)
 
     def test_specified_zero_precision(self):
         name = "tbl_" + random_table_suffix()
-        db = get_conn(self.db_cls)
-        tbl = table(name, schema={"value": "DECIMAL(10)"})
-        db.query(tbl.create())
-        t = table(name)
-        raw_schema = db.query_table_schema(t.path)
-        schema = db._process_table_schema(t.path, raw_schema)
-        self.assertEqual(schema["value"].precision, 0)
+        db_connection = get_conn(self.db_cls)
+        with db_connection:
+            table_object = table(name, schema={"value": "DECIMAL(10)"})
+            db_connection.query(table_object.create())
+            table_object = table(name)
+            raw_schema = db_connection.query_table_schema(table_object.path)
+            schema = db_connection._process_table_schema(table_object.path, raw_schema)
+            self.assertEqual(schema["value"].precision, 0)
 
     def test_default_precision(self):
         name = "tbl_" + random_table_suffix()
-        db = get_conn(self.db_cls)
-        tbl = table(name, schema={"value": "DECIMAL"})
-        db.query(tbl.create())
-        t = table(name)
-        raw_schema = db.query_table_schema(t.path)
-        schema = db._process_table_schema(t.path, raw_schema)
self.assertEqual(schema["value"].precision, db.dialect.DEFAULT_NUMERIC_PRECISION) + db_connection = get_conn(self.db_cls) + with db_connection: + table_object = table(name, schema={"value": "DECIMAL"}) + db_connection.query(table_object.create()) + table_object = table(name) + raw_schema = db_connection.query_table_schema(table_object.path) + schema = db_connection._process_table_schema(table_object.path, raw_schema) + self.assertEqual(schema["value"].precision, db_connection.dialect.DEFAULT_NUMERIC_PRECISION) + + +# Skip presto as it doesn't support a close method: +# https://github.com/prestodb/presto-python-client/blob/be2610e524fa8400c9f2baa41ba0159d44ac2b11/prestodb/dbapi.py#L130 +closeable_databases = TEST_DATABASES.copy() +closeable_databases.discard(dbs.Presto) + +test_closeable_databases: Callable = test_each_database_in_list(closeable_databases) + + +@test_closeable_databases +class TestCloseMethod(unittest.TestCase): + def test_close_connection(self): + database: Database = get_conn(self.db_cls) + + # Perform a query to verify the connection is established + with database: + database.query("SELECT 1") + + # Now the connection should be closed, and trying to execute a query should fail. + with self.assertRaises(Exception): # Catch any type of exception. + database.query("SELECT 1") diff --git a/tests/test_main.py b/tests/test_main.py new file mode 100644 index 00000000..cc1333cb --- /dev/null +++ b/tests/test_main.py @@ -0,0 +1,249 @@ +import unittest + +from data_diff import Database, JoinDiffer, HashDiffer +from data_diff import databases as db +from data_diff.__main__ import _get_dbs, _set_age, _get_table_differ, _get_expanded_columns, _get_threads +from data_diff.databases.mysql import MySQL +from data_diff.diff_tables import TableDiffer +from tests.common import CONN_STRINGS, get_conn, DiffTestCase + + +class TestGetDBS(unittest.TestCase): + def test__get_dbs(self) -> None: + db1: Database + db2: Database + db1_str: str = CONN_STRINGS[db.PostgreSQL] + db2_str: str = CONN_STRINGS[db.PostgreSQL] + + # no threads and 2 threads1 + db1, db2 = _get_dbs(0, db1_str, 2, db2_str, 0, False) + with db1, db2: + assert db1 == db2 + assert db1.thread_count == 2 + + # 3 threads and 0 threads1 + db1, db2 = _get_dbs(3, db1_str, 0, db2_str, 0, False) + with db1, db2: + assert db1 == db2 + assert db1.thread_count == 3 + + # not interactive + db1, db2 = _get_dbs(1, db1_str, 0, db2_str, 0, False) + with db1, db2: + assert db1 == db2 + assert not db1._interactive + + # interactive + db1, db2 = _get_dbs(1, db1_str, 0, db2_str, 0, True) + with db1, db2: + assert db1 == db2 + assert db1._interactive + + db2_str: str = CONN_STRINGS[db.MySQL] + + # no threads and 1 threads1 and 2 thread2 + db1, db2 = _get_dbs(0, db1_str, 1, db2_str, 2, False) + with db1, db2: + assert db1 != db2 + assert db1.thread_count == 1 + assert db2.thread_count == 2 + + # 3 threads and 0 threads1 and 0 thread2 + db1, db2 = _get_dbs(3, db1_str, 0, db2_str, 0, False) + with db1, db2: + assert db1 != db2 + assert db1.thread_count == 3 + assert db2.thread_count == 3 + assert db1.thread_count == db2.thread_count + + # not interactive + db1, db2 = _get_dbs(1, db1_str, 0, db2_str, 0, False) + with db1, db2: + assert db1 != db2 + assert not db1._interactive + assert not db2._interactive + + # interactive + db1, db2 = _get_dbs(1, db1_str, 0, db2_str, 0, True) + with db1, db2: + assert db1 != db2 + assert db1._interactive + assert db2._interactive + + def test_database_connection_failure(self) -> None: + """Test when 
database connection fails.""" + with self.assertRaises(Exception): # Assuming that connect() raises Exception on connection failure + _get_dbs(1, "db1_str", 0, "db2_str", 0, False) + + def test_invalid_inputs(self) -> None: + """Test invalid inputs.""" + with self.assertRaises(Exception): # Assuming that connect() raises Exception on failure + _get_dbs(0, "", 0, "", 0, False) # Empty connection strings + + def test_database_object(self) -> None: + """Test returned database objects are valid and not None.""" + db1_str: str = CONN_STRINGS[db.PostgreSQL] + db2_str: str = CONN_STRINGS[db.PostgreSQL] + db1, db2 = _get_dbs(1, db1_str, 0, db2_str, 0, False) + self.assertIsNotNone(db1) + self.assertIsNotNone(db2) + self.assertIsInstance(db1, Database) + self.assertIsInstance(db2, Database) + + def test_databases_are_different(self) -> None: + """Test separate connections for different databases.""" + db1_str: str = CONN_STRINGS[db.PostgreSQL] + db2_str: str = CONN_STRINGS[db.MySQL] + db1, db2 = _get_dbs(0, db1_str, 1, db2_str, 2, False) + with db1, db2: + self.assertIsNot(db1, db2) # Check that db1 and db2 are not the same object + + +class TestSetAge(unittest.TestCase): + def setUp(self) -> None: + self.database: Database = get_conn(db.PostgreSQL) + + def tearDown(self): + self.database.close() + + def test__set_age(self): + options = {} + _set_age(options, None, None, self.database) + assert len(options) == 0 + + options = {} + _set_age(options, "1d", None, self.database) + assert len(options) == 1 + assert options.get("max_update") is not None + + options = {} + _set_age(options, None, "1d", self.database) + assert len(options) == 1 + assert options.get("min_update") is not None + + options = {} + _set_age(options, "1d", "1d", self.database) + assert len(options) == 2 + assert options.get("max_update") is not None + assert options.get("min_update") is not None + + def test__set_age_db_query_failure(self): + with self.assertRaises(Exception): + options = {} + _set_age(options, "1d", "1d", self.mock_database) + + +class TestGetTableDiffer(unittest.TestCase): + def test__get_table_differ(self): + db1: Database + db2: Database + db1_str: str = CONN_STRINGS[db.PostgreSQL] + db2_str: str = CONN_STRINGS[db.PostgreSQL] + + db1, db2 = _get_dbs(1, db1_str, 0, db2_str, 0, False) + with db1, db2: + assert db1 == db2 + table_differ: TableDiffer = self._get_differ("auto", db1, db2) + assert isinstance(table_differ, JoinDiffer) + + table_differ: TableDiffer = self._get_differ("joindiff", db1, db2) + assert isinstance(table_differ, JoinDiffer) + + table_differ: TableDiffer = self._get_differ("hashdiff", db1, db2) + assert isinstance(table_differ, HashDiffer) + + db2_str: str = CONN_STRINGS[db.MySQL] + db1, db2 = _get_dbs(1, db1_str, 0, db2_str, 0, False) + with db1, db2: + assert db1 != db2 + table_differ: TableDiffer = self._get_differ("auto", db1, db2) + assert isinstance(table_differ, HashDiffer) + + table_differ: TableDiffer = self._get_differ("joindiff", db1, db2) + assert isinstance(table_differ, JoinDiffer) + + table_differ: TableDiffer = self._get_differ("hashdiff", db1, db2) + assert isinstance(table_differ, HashDiffer) + + @staticmethod + def _get_differ(algorithm, db1, db2): + return _get_table_differ(algorithm, db1, db2, False, 1, False, False, False, 1, None, None, None) + + +class TestGetExpandedColumns(DiffTestCase): + db_cls = MySQL + + def setUp(self): + super().setUp() + + def test__get_expanded_columns(self): + columns = ["user_id", "movie_id", "rating"] + kwargs = { + "db1": 
+            "db1": self.connection,
+            "schema1": self.src_schema,
+            "table1": self.table_src_name,
+            "db2": self.connection,
+            "schema2": self.dst_schema,
+            "table2": self.table_dst_name,
+        }
+        expanded_columns = _get_expanded_columns(columns, False, set(columns), **kwargs)
+
+        assert len(expanded_columns) == 3
+        assert len(set(expanded_columns) & set(columns)) == 3
+
+    def test__get_expanded_columns_case_sensitive(self):
+        columns = ["UserID", "MovieID", "Rating"]
+        kwargs = {
+            "db1": self.connection,
+            "schema1": self.src_schema,
+            "table1": self.table_src_name,
+            "db2": self.connection,
+            "schema2": self.dst_schema,
+            "table2": self.table_dst_name,
+        }
+        expanded_columns = _get_expanded_columns(columns, True, set(columns), **kwargs)
+
+        assert len(expanded_columns) == 3
+        assert len(set(expanded_columns) & set(columns)) == 3
+
+
+class TestGetThreads(unittest.TestCase):
+    def test__get_threads(self):
+        threaded, threads = _get_threads(None, None, None)
+        assert threaded
+        assert threads == 1
+
+        threaded, threads = _get_threads(None, 2, 3)
+        assert threaded
+        assert threads == 1
+
+        threaded, threads = _get_threads("serial", None, None)
+        assert not threaded
+        assert threads == 1
+
+        with self.assertRaises(AssertionError):
+            _get_threads("serial", 1, 2)
+
+        threaded, threads = _get_threads("4", None, None)
+        assert threaded
+        assert threads == 4
+
+        with self.assertRaises(ValueError) as value_error:
+            _get_threads("auto", None, None)
+        assert str(value_error.exception) == "invalid literal for int() with base 10: 'auto'"
+
+        threaded, threads = _get_threads(5, None, None)
+        assert threaded
+        assert threads == 5
+
+        threaded, threads = _get_threads(6, 7, 8)
+        assert threaded
+        assert threads == 6
+
+        with self.assertRaises(ValueError) as value_error:
+            _get_threads(0, None, None)
+        assert str(value_error.exception) == "Error: threads must be >= 1"
+
+        with self.assertRaises(ValueError) as value_error:
+            _get_threads(-1, None, None)
+        assert str(value_error.exception) == "Error: threads must be >= 1"
diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py
index ed1baecf..4d040204 100644
--- a/tests/test_postgresql.py
+++ b/tests/test_postgresql.py
@@ -1,11 +1,12 @@
 import unittest
-
+from copy import deepcopy
 from urllib.parse import quote
-from data_diff.queries.api import table, commit
-from data_diff import TableSegment, HashDiffer
+
+from data_diff import TableSegment, HashDiffer, Database
+from data_diff import connect_to_table
 from data_diff import databases as db
+from data_diff.queries.api import table, commit
 from tests.common import get_conn, random_table_suffix, connect
-from data_diff import connect_to_table
 
 
 class TestUUID(unittest.TestCase):
@@ -118,38 +119,40 @@ def test_100_fields(self):
 
 
 class TestSpecialCharacterPassword(unittest.TestCase):
+    username: str = "test"
+    password: str = "passw!!!@rd"
+
     def setUp(self) -> None:
-        self.connection = get_conn(db.PostgreSQL)
+        self.connection: Database = get_conn(db.PostgreSQL)
+        self.table_name = f"table{random_table_suffix()}"
 
-        table_suffix = random_table_suffix()
+        # Setup user with special character '@' in password
+        self.connection.query(f"DROP USER IF EXISTS {self.username};", None)
+        self.connection.query(f"CREATE USER {self.username} WITH PASSWORD '{self.password}';", None)
 
-        self.table_name = f"table{table_suffix}"
-        self.table = table(self.table_name)
+    def tearDown(self):
+        self.connection.query(f"DROP USER IF EXISTS {self.username};", None)
+        self.connection.close()
 
     def test_special_char_password(self):
-        password = "passw!!!@rd"
"passw!!!@rd" - # Setup user with special character '@' in password - self.connection.query("DROP USER IF EXISTS test;", None) - self.connection.query(f"CREATE USER test WITH PASSWORD '{password}';", None) - - password_quoted = quote(password) - db_config = { - "driver": "postgresql", - "host": "localhost", - "port": 5432, - "dbname": "postgres", - "user": "test", - "password": password_quoted, - } + db_config = deepcopy(self.connection._args) + db_config.update( + { + "driver": "postgresql", + "dbname": db_config.pop("database"), + "user": self.username, + "password": quote(self.password), + } + ) # verify pythonic connection method - connect_to_table( - db_config, - self.table_name, - ) + connect_to_table(db_config, self.table_name) # verify connection method with URL string unquoted after it's verified - db_url = f"postgresql://{db_config['user']}:{db_config['password']}@{db_config['host']}:{db_config['port']}/{db_config['dbname']}" + db_url = ( + f"postgresql://{db_config['user']}:{db_config['password']}@{db_config['host']}:" + f"{db_config.get('port', 5432)}/{db_config['dbname']}" + ) - connection_verified = connect(db_url) - assert connection_verified._args.get("password") == password + with connect(db_url) as connection_verified: + assert connection_verified._args.get("password") == self.password