From 7a235cb6e61f270803a5d6bd07a18e537a6eb05e Mon Sep 17 00:00:00 2001 From: Sergey Vasilyev Date: Fri, 22 Sep 2023 14:28:36 +0200 Subject: [PATCH 1/3] Squash sqeleton's utils & data_diff's utils --- data_diff/sqeleton/abcs/database_types.py | 2 +- data_diff/sqeleton/databases/base.py | 2 +- data_diff/sqeleton/databases/duckdb.py | 2 +- data_diff/sqeleton/databases/oracle.py | 2 +- data_diff/sqeleton/databases/presto.py | 2 +- data_diff/sqeleton/databases/vertica.py | 2 +- data_diff/sqeleton/queries/api.py | 2 +- data_diff/sqeleton/queries/ast_classes.py | 2 +- data_diff/sqeleton/queries/compiler.py | 2 +- data_diff/sqeleton/schema.py | 2 +- data_diff/sqeleton/utils.py | 311 ---------------------- data_diff/table_segment.py | 2 +- data_diff/utils.py | 267 ++++++++++++++++++- tests/test_database_types.py | 2 +- tests/test_diff_tables.py | 2 +- tests/test_query.py | 2 +- tests/test_utils.py | 2 +- 17 files changed, 281 insertions(+), 327 deletions(-) delete mode 100644 data_diff/sqeleton/utils.py diff --git a/data_diff/sqeleton/abcs/database_types.py b/data_diff/sqeleton/abcs/database_types.py index 26909067..82ec8352 100644 --- a/data_diff/sqeleton/abcs/database_types.py +++ b/data_diff/sqeleton/abcs/database_types.py @@ -6,7 +6,7 @@ from runtype import dataclass from typing_extensions import Self -from data_diff.sqeleton.utils import ArithAlphanumeric, ArithUUID, Unknown +from data_diff.utils import ArithAlphanumeric, ArithUUID, Unknown DbPath = Tuple[str, ...] diff --git a/data_diff/sqeleton/databases/base.py b/data_diff/sqeleton/databases/base.py index ec41bac4..720bc029 100644 --- a/data_diff/sqeleton/databases/base.py +++ b/data_diff/sqeleton/databases/base.py @@ -13,7 +13,7 @@ from runtype import dataclass from typing_extensions import Self -from data_diff.sqeleton.utils import is_uuid, safezip +from data_diff.utils import is_uuid, safezip from data_diff.sqeleton.queries import Expr, Compiler, table, Select, SKIP, Explain, Code, this from data_diff.sqeleton.queries.ast_classes import Random from data_diff.sqeleton.abcs.database_types import ( diff --git a/data_diff/sqeleton/databases/duckdb.py b/data_diff/sqeleton/databases/duckdb.py index 3714b00d..2066da5a 100644 --- a/data_diff/sqeleton/databases/duckdb.py +++ b/data_diff/sqeleton/databases/duckdb.py @@ -1,6 +1,6 @@ from typing import Union -from data_diff.sqeleton.utils import match_regexps +from data_diff.utils import match_regexps from data_diff.sqeleton.abcs.database_types import ( Timestamp, TimestampTZ, diff --git a/data_diff/sqeleton/databases/oracle.py b/data_diff/sqeleton/databases/oracle.py index 23fc8e09..3f249441 100644 --- a/data_diff/sqeleton/databases/oracle.py +++ b/data_diff/sqeleton/databases/oracle.py @@ -1,6 +1,6 @@ from typing import Dict, List, Optional -from data_diff.sqeleton.utils import match_regexps +from data_diff.utils import match_regexps from data_diff.sqeleton.abcs.database_types import ( Decimal, Float, diff --git a/data_diff/sqeleton/databases/presto.py b/data_diff/sqeleton/databases/presto.py index 6c09a879..a09d7846 100644 --- a/data_diff/sqeleton/databases/presto.py +++ b/data_diff/sqeleton/databases/presto.py @@ -1,7 +1,7 @@ from functools import partial import re -from data_diff.sqeleton.utils import match_regexps +from data_diff.utils import match_regexps from data_diff.sqeleton.abcs.database_types import ( Timestamp, diff --git a/data_diff/sqeleton/databases/vertica.py b/data_diff/sqeleton/databases/vertica.py index 0c03ab79..6a59bcc3 100644 --- a/data_diff/sqeleton/databases/vertica.py +++ 
b/data_diff/sqeleton/databases/vertica.py @@ -1,6 +1,6 @@ from typing import List -from data_diff.sqeleton.utils import match_regexps +from data_diff.utils import match_regexps from data_diff.sqeleton.databases.base import ( CHECKSUM_HEXDIGITS, MD5_HEXDIGITS, diff --git a/data_diff/sqeleton/queries/api.py b/data_diff/sqeleton/queries/api.py index 301cea32..97a6b00c 100644 --- a/data_diff/sqeleton/queries/api.py +++ b/data_diff/sqeleton/queries/api.py @@ -1,6 +1,6 @@ from typing import Optional -from data_diff.sqeleton.utils import CaseAwareMapping, CaseSensitiveDict +from data_diff.utils import CaseAwareMapping, CaseSensitiveDict from data_diff.sqeleton.queries.ast_classes import * from data_diff.sqeleton.queries.base import args_as_tuple diff --git a/data_diff/sqeleton/queries/ast_classes.py b/data_diff/sqeleton/queries/ast_classes.py index f3b04f73..6ee25ecb 100644 --- a/data_diff/sqeleton/queries/ast_classes.py +++ b/data_diff/sqeleton/queries/ast_classes.py @@ -5,7 +5,7 @@ from runtype import dataclass from typing_extensions import Self -from data_diff.sqeleton.utils import join_iter, ArithString +from data_diff.utils import join_iter, ArithString from data_diff.sqeleton.abcs import Compilable from data_diff.sqeleton.abcs.database_types import AbstractTable from data_diff.sqeleton.abcs.mixins import AbstractMixin_Regex, AbstractMixin_TimeTravel diff --git a/data_diff/sqeleton/queries/compiler.py b/data_diff/sqeleton/queries/compiler.py index f9ab7484..0aaf8dd6 100644 --- a/data_diff/sqeleton/queries/compiler.py +++ b/data_diff/sqeleton/queries/compiler.py @@ -6,7 +6,7 @@ from runtype import dataclass from typing_extensions import Self -from data_diff.sqeleton.utils import ArithString +from data_diff.utils import ArithString from data_diff.sqeleton.abcs import AbstractDatabase, AbstractDialect, DbPath, AbstractCompiler, Compilable import contextvars diff --git a/data_diff/sqeleton/schema.py b/data_diff/sqeleton/schema.py index 01dfeed7..35ebe8ef 100644 --- a/data_diff/sqeleton/schema.py +++ b/data_diff/sqeleton/schema.py @@ -1,6 +1,6 @@ import logging -from data_diff.sqeleton.utils import CaseAwareMapping, CaseInsensitiveDict, CaseSensitiveDict +from data_diff.utils import CaseAwareMapping, CaseInsensitiveDict, CaseSensitiveDict from data_diff.sqeleton.abcs import AbstractDatabase, DbPath logger = logging.getLogger("schema") diff --git a/data_diff/sqeleton/utils.py b/data_diff/sqeleton/utils.py deleted file mode 100644 index d1f596db..00000000 --- a/data_diff/sqeleton/utils.py +++ /dev/null @@ -1,311 +0,0 @@ -from typing import ( - Iterable, - Iterator, - MutableMapping, - Type, - Union, - Any, - Sequence, - Dict, - TypeVar, - List, -) -from abc import abstractmethod -import math -import string -import re -from uuid import UUID -from urllib.parse import urlparse - -from typing_extensions import Self - -# -- Common -- - - -def join_iter(joiner: Any, iterable: Iterable) -> Iterable: - it = iter(iterable) - try: - yield next(it) - except StopIteration: - return - for i in it: - yield joiner - yield i - - -def safezip(*args): - "zip but makes sure all sequences are the same length" - lens = list(map(len, args)) - if len(set(lens)) != 1: - raise ValueError(f"Mismatching lengths in arguments to safezip: {lens}") - return zip(*args) - - -def is_uuid(u): - try: - UUID(u) - except ValueError: - return False - return True - - -def match_regexps(regexps: Dict[str, Any], s: str) -> Sequence[tuple]: - for regexp, v in regexps.items(): - m = re.match(regexp + "$", s) - if m: - yield m, v - - -# -- 
Schema -- -V = TypeVar("V") - - -class CaseAwareMapping(MutableMapping[str, V]): - @abstractmethod - def get_key(self, key: str) -> str: - ... - - def new(self, initial=()) -> Self: - return type(self)(initial) - - -class CaseInsensitiveDict(CaseAwareMapping): - def __init__(self, initial): - self._dict = {k.lower(): (k, v) for k, v in dict(initial).items()} - - def __getitem__(self, key: str) -> V: - return self._dict[key.lower()][1] - - def __iter__(self) -> Iterator[V]: - return iter(self._dict) - - def __len__(self) -> int: - return len(self._dict) - - def __setitem__(self, key: str, value): - k = key.lower() - if k in self._dict: - key = self._dict[k][0] - self._dict[k] = key, value - - def __delitem__(self, key: str): - del self._dict[key.lower()] - - def get_key(self, key: str) -> str: - return self._dict[key.lower()][0] - - def __repr__(self) -> str: - return repr(dict(self.items())) - - -class CaseSensitiveDict(dict, CaseAwareMapping): - def get_key(self, key): - self[key] # Throw KeyError if key doesn't exist - return key - - def as_insensitive(self): - return CaseInsensitiveDict(self) - - -# -- Alphanumerics -- - -alphanums = " -" + string.digits + string.ascii_uppercase + "_" + string.ascii_lowercase - - -class ArithString: - @classmethod - def new(cls, *args, **kw) -> Self: - return cls(*args, **kw) - - def range(self, other: "ArithString", count: int) -> List[Self]: - assert isinstance(other, ArithString) - checkpoints = split_space(self.int, other.int, count) - return [self.new(int=i) for i in checkpoints] - - -class ArithUUID(UUID, ArithString): - "A UUID that supports basic arithmetic (add, sub)" - - def __int__(self): - return self.int - - def __add__(self, other: int) -> Self: - if isinstance(other, int): - return self.new(int=self.int + other) - return NotImplemented - - def __sub__(self, other: Union[UUID, int]): - if isinstance(other, int): - return self.new(int=self.int - other) - elif isinstance(other, UUID): - return self.int - other.int - return NotImplemented - - -def numberToAlphanum(num: int, base: str = alphanums) -> str: - digits = [] - while num > 0: - num, remainder = divmod(num, len(base)) - digits.append(remainder) - return "".join(base[i] for i in digits[::-1]) - - -def alphanumToNumber(alphanum: str, base: str = alphanums) -> int: - num = 0 - for c in alphanum: - num = num * len(base) + base.index(c) - return num - - -def justify_alphanums(s1: str, s2: str): - max_len = max(len(s1), len(s2)) - s1 = s1.ljust(max_len) - s2 = s2.ljust(max_len) - return s1, s2 - - -def alphanums_to_numbers(s1: str, s2: str): - s1, s2 = justify_alphanums(s1, s2) - n1 = alphanumToNumber(s1) - n2 = alphanumToNumber(s2) - return n1, n2 - - -class ArithAlphanumeric(ArithString): - def __init__(self, s: str, max_len=None): - if s is None: - raise ValueError("Alphanum string cannot be None") - if max_len and len(s) > max_len: - raise ValueError(f"Length of alphanum value '{str}' is longer than the expected {max_len}") - - for ch in s: - if ch not in alphanums: - raise ValueError(f"Unexpected character {ch} in alphanum string") - - self._str = s - self._max_len = max_len - - # @property - # def int(self): - # return alphanumToNumber(self._str, alphanums) - - def __str__(self): - s = self._str - if self._max_len: - s = s.rjust(self._max_len, alphanums[0]) - return s - - def __len__(self): - return len(self._str) - - def __repr__(self): - return f'alphanum"{self._str}"' - - def __add__(self, other: "Union[ArithAlphanumeric, int]") -> Self: - if isinstance(other, int): - if other != 
1: - raise NotImplementedError("not implemented for arbitrary numbers") - num = alphanumToNumber(self._str) - return self.new(numberToAlphanum(num + 1)) - - return NotImplemented - - def range(self, other: "ArithAlphanumeric", count: int) -> List[Self]: - assert isinstance(other, ArithAlphanumeric) - n1, n2 = alphanums_to_numbers(self._str, other._str) - split = split_space(n1, n2, count) - return [self.new(numberToAlphanum(s)) for s in split] - - def __sub__(self, other: "Union[ArithAlphanumeric, int]") -> float: - if isinstance(other, ArithAlphanumeric): - n1, n2 = alphanums_to_numbers(self._str, other._str) - return n1 - n2 - - return NotImplemented - - def __ge__(self, other): - if not isinstance(other, type(self)): - return NotImplemented - return self._str >= other._str - - def __lt__(self, other): - if not isinstance(other, type(self)): - return NotImplemented - return self._str < other._str - - def __eq__(self, other): - if not isinstance(other, type(self)): - return NotImplemented - return self._str == other._str - - def new(self, *args, **kw) -> Self: - return type(self)(*args, **kw, max_len=self._max_len) - - -def number_to_human(n): - millnames = ["", "k", "m", "b"] - n = float(n) - millidx = max( - 0, - min(len(millnames) - 1, int(math.floor(0 if n == 0 else math.log10(abs(n)) / 3))), - ) - - return "{:.0f}{}".format(n / 10 ** (3 * millidx), millnames[millidx]) - - -def split_space(start, end, count) -> List[int]: - size = end - start - assert count <= size, (count, size) - return list(range(start, end, (size + 1) // (count + 1)))[1 : count + 1] - - -def remove_passwords_in_dict(d: dict, replace_with: str = "***"): - for k, v in d.items(): - if k == "password": - d[k] = replace_with - elif isinstance(v, dict): - remove_passwords_in_dict(v, replace_with) - elif k.startswith("database"): - d[k] = remove_password_from_url(v, replace_with) - - -def _join_if_any(sym, args): - args = list(args) - if not args: - return "" - return sym.join(str(a) for a in args if a) - - -def remove_password_from_url(url: str, replace_with: str = "***") -> str: - parsed = urlparse(url) - account = parsed.username or "" - if parsed.password: - account += ":" + replace_with - host = _join_if_any(":", filter(None, [parsed.hostname, parsed.port])) - netloc = _join_if_any("@", filter(None, [account, host])) - replaced = parsed._replace(netloc=netloc) - return replaced.geturl() - - -def match_like(pattern: str, strs: Sequence[str]) -> Iterable[str]: - reo = re.compile(pattern.replace("%", ".*").replace("?", ".") + "$") - for s in strs: - if reo.match(s): - yield s - - -class UnknownMeta(type): - def __instancecheck__(self, instance): - return instance is Unknown - - def __repr__(self): - return "Unknown" - - -class Unknown(metaclass=UnknownMeta): - def __nonzero__(self): - raise TypeError() - - def __new__(class_, *args, **kwargs): - raise RuntimeError("Unknown is a singleton") diff --git a/data_diff/table_segment.py b/data_diff/table_segment.py index 46672304..8568ffc4 100644 --- a/data_diff/table_segment.py +++ b/data_diff/table_segment.py @@ -7,7 +7,7 @@ from typing_extensions import Self from data_diff.utils import safezip, Vector -from data_diff.sqeleton.utils import ArithString, split_space +from data_diff.utils import ArithString, split_space from data_diff.sqeleton.databases import Database, DbPath, DbKey, DbTime from data_diff.sqeleton.schema import Schema, create_schema from data_diff.sqeleton.queries import Count, Checksum, SKIP, table, this, Expr, min_, max_, Code diff --git a/data_diff/utils.py 
b/data_diff/utils.py index 02870f60..67d74dce 100644 --- a/data_diff/utils.py +++ b/data_diff/utils.py @@ -1,18 +1,39 @@ import json import logging import re -from typing import Dict, Iterable, Sequence +import math +import string +from abc import abstractmethod +from typing import Any, Dict, Iterable, Iterator, List, MutableMapping, Sequence, TypeVar, Union from urllib.parse import urlparse import operator import threading from datetime import datetime +from uuid import UUID + from packaging.version import parse as parse_version import requests from tabulate import tabulate +from typing_extensions import Self + from data_diff.version import __version__ from rich.status import Status +# -- Common -- + + +def join_iter(joiner: Any, iterable: Iterable) -> Iterable: + it = iter(iterable) + try: + yield next(it) + except StopIteration: + return + for i in it: + yield joiner + yield i + + def safezip(*args): "zip but makes sure all sequences are the same length" lens = list(map(len, args)) @@ -21,6 +42,236 @@ def safezip(*args): return zip(*args) +def is_uuid(u): + try: + UUID(u) + except ValueError: + return False + return True + + +def match_regexps(regexps: Dict[str, Any], s: str) -> Sequence[tuple]: + for regexp, v in regexps.items(): + m = re.match(regexp + "$", s) + if m: + yield m, v + + +# -- Schema -- + +V = TypeVar("V") + + +class CaseAwareMapping(MutableMapping[str, V]): + @abstractmethod + def get_key(self, key: str) -> str: + ... + + def new(self, initial=()) -> Self: + return type(self)(initial) + + +class CaseInsensitiveDict(CaseAwareMapping): + def __init__(self, initial): + self._dict = {k.lower(): (k, v) for k, v in dict(initial).items()} + + def __getitem__(self, key: str) -> V: + return self._dict[key.lower()][1] + + def __iter__(self) -> Iterator[V]: + return iter(self._dict) + + def __len__(self) -> int: + return len(self._dict) + + def __setitem__(self, key: str, value): + k = key.lower() + if k in self._dict: + key = self._dict[k][0] + self._dict[k] = key, value + + def __delitem__(self, key: str): + del self._dict[key.lower()] + + def get_key(self, key: str) -> str: + return self._dict[key.lower()][0] + + def __repr__(self) -> str: + return repr(dict(self.items())) + + +class CaseSensitiveDict(dict, CaseAwareMapping): + def get_key(self, key): + self[key] # Throw KeyError if key doesn't exist + return key + + def as_insensitive(self): + return CaseInsensitiveDict(self) + + + +# -- Alphanumerics -- + +alphanums = " -" + string.digits + string.ascii_uppercase + "_" + string.ascii_lowercase + + +class ArithString: + @classmethod + def new(cls, *args, **kw) -> Self: + return cls(*args, **kw) + + def range(self, other: "ArithString", count: int) -> List[Self]: + assert isinstance(other, ArithString) + checkpoints = split_space(self.int, other.int, count) + return [self.new(int=i) for i in checkpoints] + + +class ArithUUID(UUID, ArithString): + "A UUID that supports basic arithmetic (add, sub)" + + def __int__(self): + return self.int + + def __add__(self, other: int) -> Self: + if isinstance(other, int): + return self.new(int=self.int + other) + return NotImplemented + + def __sub__(self, other: Union[UUID, int]): + if isinstance(other, int): + return self.new(int=self.int - other) + elif isinstance(other, UUID): + return self.int - other.int + return NotImplemented + + +def numberToAlphanum(num: int, base: str = alphanums) -> str: + digits = [] + while num > 0: + num, remainder = divmod(num, len(base)) + digits.append(remainder) + return "".join(base[i] for i in digits[::-1]) + + +def alphanumToNumber(alphanum: 
str, base: str = alphanums) -> int: + num = 0 + for c in alphanum: + num = num * len(base) + base.index(c) + return num + + +def justify_alphanums(s1: str, s2: str): + max_len = max(len(s1), len(s2)) + s1 = s1.ljust(max_len) + s2 = s2.ljust(max_len) + return s1, s2 + + +def alphanums_to_numbers(s1: str, s2: str): + s1, s2 = justify_alphanums(s1, s2) + n1 = alphanumToNumber(s1) + n2 = alphanumToNumber(s2) + return n1, n2 + + +class ArithAlphanumeric(ArithString): + def __init__(self, s: str, max_len=None): + if s is None: + raise ValueError("Alphanum string cannot be None") + if max_len and len(s) > max_len: + raise ValueError(f"Length of alphanum value '{str}' is longer than the expected {max_len}") + + for ch in s: + if ch not in alphanums: + raise ValueError(f"Unexpected character {ch} in alphanum string") + + self._str = s + self._max_len = max_len + + # @property + # def int(self): + # return alphanumToNumber(self._str, alphanums) + + def __str__(self): + s = self._str + if self._max_len: + s = s.rjust(self._max_len, alphanums[0]) + return s + + def __len__(self): + return len(self._str) + + def __repr__(self): + return f'alphanum"{self._str}"' + + def __add__(self, other: "Union[ArithAlphanumeric, int]") -> Self: + if isinstance(other, int): + if other != 1: + raise NotImplementedError("not implemented for arbitrary numbers") + num = alphanumToNumber(self._str) + return self.new(numberToAlphanum(num + 1)) + + return NotImplemented + + def range(self, other: "ArithAlphanumeric", count: int) -> List[Self]: + assert isinstance(other, ArithAlphanumeric) + n1, n2 = alphanums_to_numbers(self._str, other._str) + split = split_space(n1, n2, count) + return [self.new(numberToAlphanum(s)) for s in split] + + def __sub__(self, other: "Union[ArithAlphanumeric, int]") -> float: + if isinstance(other, ArithAlphanumeric): + n1, n2 = alphanums_to_numbers(self._str, other._str) + return n1 - n2 + + return NotImplemented + + def __ge__(self, other): + if not isinstance(other, type(self)): + return NotImplemented + return self._str >= other._str + + def __lt__(self, other): + if not isinstance(other, type(self)): + return NotImplemented + return self._str < other._str + + def __eq__(self, other): + if not isinstance(other, type(self)): + return NotImplemented + return self._str == other._str + + def new(self, *args, **kw) -> Self: + return type(self)(*args, **kw, max_len=self._max_len) + + +def number_to_human(n): + millnames = ["", "k", "m", "b"] + n = float(n) + millidx = max( + 0, + min(len(millnames) - 1, int(math.floor(0 if n == 0 else math.log10(abs(n)) / 3))), + ) + + return "{:.0f}{}".format(n / 10 ** (3 * millidx), millnames[millidx]) + + +def split_space(start, end, count) -> List[int]: + size = end - start + assert count <= size, (count, size) + return list(range(start, end, (size + 1) // (count + 1)))[1 : count + 1] + + +def remove_passwords_in_dict(d: dict, replace_with: str = "***"): + for k, v in d.items(): + if k == "password": + d[k] = replace_with + elif isinstance(v, dict): + remove_passwords_in_dict(v, replace_with) + elif k.startswith("database"): + d[k] = remove_password_from_url(v, replace_with) + + def _join_if_any(sym, args): args = list(args) if not args: @@ -248,3 +499,19 @@ def _update_cloud_status(self, log=None): for model_name, status in self.cloud_diff_status.items(): cloud_status_string += f"{status} {model_name}\n" self.status.update(f"{cloud_status_string}{log or ''}") + + +class UnknownMeta(type): + def __instancecheck__(self, instance): + return instance is Unknown + 
+ def __repr__(self): + return "Unknown" + + +class Unknown(metaclass=UnknownMeta): + def __nonzero__(self): + raise TypeError() + + def __new__(class_, *args, **kwargs): + raise RuntimeError("Unknown is a singleton") diff --git a/tests/test_database_types.py b/tests/test_database_types.py index 75b0acee..203731c4 100644 --- a/tests/test_database_types.py +++ b/tests/test_database_types.py @@ -13,7 +13,7 @@ from parameterized import parameterized -from data_diff.sqeleton.utils import number_to_human +from data_diff.utils import number_to_human from data_diff.sqeleton.queries import table, commit, this, Code from data_diff.sqeleton.queries.api import insert_rows_in_batches diff --git a/tests/test_diff_tables.py b/tests/test_diff_tables.py index eb41a3ee..052c48ce 100644 --- a/tests/test_diff_tables.py +++ b/tests/test_diff_tables.py @@ -4,7 +4,7 @@ import unittest from data_diff.sqeleton.queries import table, this, commit, code -from data_diff.sqeleton.utils import ArithAlphanumeric, numberToAlphanum +from data_diff.utils import ArithAlphanumeric, numberToAlphanum from data_diff.hashdiff_tables import HashDiffer from data_diff.joindiff_tables import JoinDiffer diff --git a/tests/test_query.py b/tests/test_query.py index cfa6ada8..b1937028 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -2,7 +2,7 @@ from typing import List, Optional import unittest from data_diff.sqeleton.abcs import AbstractDatabase, AbstractDialect -from data_diff.sqeleton.utils import CaseInsensitiveDict, CaseSensitiveDict +from data_diff.utils import CaseInsensitiveDict, CaseSensitiveDict from data_diff.sqeleton.queries import this, table, Compiler, outerjoin, cte, when, coalesce, CompileError from data_diff.sqeleton.queries.ast_classes import Random diff --git a/tests/test_utils.py b/tests/test_utils.py index 973121a2..1277d5be 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,6 +1,6 @@ import unittest -from data_diff.sqeleton.utils import remove_passwords_in_dict, match_regexps, match_like, number_to_human +from data_diff.utils import remove_passwords_in_dict, match_regexps, match_like, number_to_human class TestUtils(unittest.TestCase): From 3b87269c26462595c9596ff06e046cbc6ef6a856 Mon Sep 17 00:00:00 2001 From: Sergey Vasilyev Date: Tue, 19 Sep 2023 17:10:12 +0200 Subject: [PATCH 2/3] Squash the hierarchies of databases & dialects from sqeleton & data_diff --- data_diff/databases/__init__.py | 26 +++++++++++----------- data_diff/databases/base.py | 5 ----- data_diff/databases/bigquery.py | 12 ++-------- data_diff/databases/clickhouse.py | 12 ++-------- data_diff/databases/databricks.py | 12 ++-------- data_diff/databases/duckdb.py | 12 ++-------- data_diff/databases/mssql.py | 12 ++-------- data_diff/databases/mysql.py | 12 ++-------- data_diff/databases/oracle.py | 12 ++-------- data_diff/databases/postgresql.py | 12 ++-------- data_diff/databases/presto.py | 12 ++-------- data_diff/databases/redshift.py | 12 ++-------- data_diff/databases/snowflake.py | 12 ++-------- data_diff/databases/trino.py | 12 ++-------- data_diff/databases/vertica.py | 12 ++-------- data_diff/sqeleton/databases/bigquery.py | 2 +- data_diff/sqeleton/databases/clickhouse.py | 2 +- data_diff/sqeleton/databases/databricks.py | 2 +- data_diff/sqeleton/databases/duckdb.py | 2 +- data_diff/sqeleton/databases/mssql.py | 2 +- data_diff/sqeleton/databases/mysql.py | 2 +- data_diff/sqeleton/databases/oracle.py | 2 +- data_diff/sqeleton/databases/postgresql.py | 2 +- data_diff/sqeleton/databases/presto.py | 2 +- 
data_diff/sqeleton/databases/redshift.py | 4 ++-- data_diff/sqeleton/databases/snowflake.py | 2 +- data_diff/sqeleton/databases/trino.py | 3 ++- data_diff/sqeleton/databases/vertica.py | 2 +- 28 files changed, 54 insertions(+), 162 deletions(-) delete mode 100644 data_diff/databases/base.py diff --git a/data_diff/databases/__init__.py b/data_diff/databases/__init__.py index cae67d1e..eaf97281 100644 --- a/data_diff/databases/__init__.py +++ b/data_diff/databases/__init__.py @@ -1,17 +1,17 @@ from data_diff.sqeleton.databases import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, QueryError, ConnectError -from data_diff.databases.postgresql import PostgreSQL -from data_diff.databases.mysql import MySQL -from data_diff.databases.oracle import Oracle -from data_diff.databases.snowflake import Snowflake -from data_diff.databases.bigquery import BigQuery -from data_diff.databases.redshift import Redshift -from data_diff.databases.presto import Presto -from data_diff.databases.databricks import Databricks -from data_diff.databases.trino import Trino -from data_diff.databases.clickhouse import Clickhouse -from data_diff.databases.vertica import Vertica -from data_diff.databases.duckdb import DuckDB -from data_diff.databases.mssql import MsSql +from data_diff.sqeleton.databases.postgresql import PostgreSQL as PostgreSQL +from data_diff.sqeleton.databases.mysql import MySQL as MySQL +from data_diff.sqeleton.databases.oracle import Oracle as Oracle +from data_diff.sqeleton.databases.snowflake import Snowflake as Snowflake +from data_diff.sqeleton.databases.bigquery import BigQuery as BigQuery +from data_diff.sqeleton.databases.redshift import Redshift as Redshift +from data_diff.sqeleton.databases.presto import Presto as Presto +from data_diff.sqeleton.databases.databricks import Databricks as Databricks +from data_diff.sqeleton.databases.trino import Trino as Trino +from data_diff.sqeleton.databases.clickhouse import Clickhouse as Clickhouse +from data_diff.sqeleton.databases.vertica import Vertica as Vertica +from data_diff.sqeleton.databases.duckdb import DuckDB as DuckDB +from data_diff.sqeleton.databases.mssql import MsSQL as MsSql from data_diff.databases._connect import connect diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py deleted file mode 100644 index 5b7ff5ce..00000000 --- a/data_diff/databases/base.py +++ /dev/null @@ -1,5 +0,0 @@ -from data_diff.sqeleton.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue - - -class DatadiffDialect(AbstractMixin_MD5, AbstractMixin_NormalizeValue): - pass diff --git a/data_diff/databases/bigquery.py b/data_diff/databases/bigquery.py index a6fdbc9c..8b0bc1e5 100644 --- a/data_diff/databases/bigquery.py +++ b/data_diff/databases/bigquery.py @@ -1,10 +1,2 @@ -from data_diff.sqeleton.databases import bigquery -from data_diff.databases.base import DatadiffDialect - - -class Dialect(bigquery.Dialect, bigquery.Mixin_MD5, bigquery.Mixin_NormalizeValue, DatadiffDialect): - pass - - -class BigQuery(bigquery.BigQuery): - dialect = Dialect() +from data_diff.sqeleton.databases.bigquery import Dialect as Dialect +from data_diff.sqeleton.databases.bigquery import BigQuery as BigQuery diff --git a/data_diff/databases/clickhouse.py b/data_diff/databases/clickhouse.py index ce22943b..85083383 100644 --- a/data_diff/databases/clickhouse.py +++ b/data_diff/databases/clickhouse.py @@ -1,10 +1,2 @@ -from data_diff.sqeleton.databases import clickhouse -from data_diff.databases.base import DatadiffDialect - - -class Dialect(clickhouse.Dialect, 
clickhouse.Mixin_MD5, clickhouse.Mixin_NormalizeValue, DatadiffDialect): - pass - - -class Clickhouse(clickhouse.Clickhouse): - dialect = Dialect() +from data_diff.sqeleton.databases.clickhouse import Dialect as Dialect +from data_diff.sqeleton.databases.clickhouse import Clickhouse as Clickhouse diff --git a/data_diff/databases/databricks.py b/data_diff/databases/databricks.py index 6794c264..36348ffb 100644 --- a/data_diff/databases/databricks.py +++ b/data_diff/databases/databricks.py @@ -1,10 +1,2 @@ -from data_diff.sqeleton.databases import databricks -from data_diff.databases.base import DatadiffDialect - - -class Dialect(databricks.Dialect, databricks.Mixin_MD5, databricks.Mixin_NormalizeValue, DatadiffDialect): - pass - - -class Databricks(databricks.Databricks): - dialect = Dialect() +from data_diff.sqeleton.databases.databricks import Dialect as Dialect +from data_diff.sqeleton.databases.databricks import Databricks as Databricks diff --git a/data_diff/databases/duckdb.py b/data_diff/databases/duckdb.py index e822264e..26558634 100644 --- a/data_diff/databases/duckdb.py +++ b/data_diff/databases/duckdb.py @@ -1,10 +1,2 @@ -from data_diff.sqeleton.databases import duckdb -from data_diff.databases.base import DatadiffDialect - - -class Dialect(duckdb.Dialect, duckdb.Mixin_MD5, duckdb.Mixin_NormalizeValue, DatadiffDialect): - pass - - -class DuckDB(duckdb.DuckDB): - dialect = Dialect() +from data_diff.sqeleton.databases.duckdb import Dialect as Dialect +from data_diff.sqeleton.databases.duckdb import DuckDB as DuckDB diff --git a/data_diff/databases/mssql.py b/data_diff/databases/mssql.py index 15163e4f..be1c2cae 100644 --- a/data_diff/databases/mssql.py +++ b/data_diff/databases/mssql.py @@ -1,10 +1,2 @@ -from data_diff.sqeleton.databases import mssql -from data_diff.databases.base import DatadiffDialect - - -class Dialect(mssql.Dialect, mssql.Mixin_MD5, mssql.Mixin_NormalizeValue, DatadiffDialect): - pass - - -class MsSql(mssql.MsSQL): - dialect = Dialect() +from data_diff.sqeleton.databases.mssql import Dialect as Dialect +from data_diff.sqeleton.databases.mssql import MsSQL as MsSql diff --git a/data_diff/databases/mysql.py b/data_diff/databases/mysql.py index 102620b8..0a715600 100644 --- a/data_diff/databases/mysql.py +++ b/data_diff/databases/mysql.py @@ -1,10 +1,2 @@ -from data_diff.sqeleton.databases import mysql -from data_diff.databases.base import DatadiffDialect - - -class Dialect(mysql.Dialect, mysql.Mixin_MD5, mysql.Mixin_NormalizeValue, DatadiffDialect): - pass - - -class MySQL(mysql.MySQL): - dialect = Dialect() +from data_diff.sqeleton.databases.mysql import Dialect as Dialect +from data_diff.sqeleton.databases.mysql import MySQL as MySQL diff --git a/data_diff/databases/oracle.py b/data_diff/databases/oracle.py index 3ee4a872..7c10fd11 100644 --- a/data_diff/databases/oracle.py +++ b/data_diff/databases/oracle.py @@ -1,10 +1,2 @@ -from data_diff.sqeleton.databases import oracle -from data_diff.databases.base import DatadiffDialect - - -class Dialect(oracle.Dialect, oracle.Mixin_MD5, oracle.Mixin_NormalizeValue, DatadiffDialect): - pass - - -class Oracle(oracle.Oracle): - dialect = Dialect() +from data_diff.sqeleton.databases.oracle import Dialect as Dialect +from data_diff.sqeleton.databases.oracle import Oracle as Oracle diff --git a/data_diff/databases/postgresql.py b/data_diff/databases/postgresql.py index b63f050a..befe8d44 100644 --- a/data_diff/databases/postgresql.py +++ b/data_diff/databases/postgresql.py @@ -1,10 +1,2 @@ -from data_diff.sqeleton.databases 
import postgresql as pg -from data_diff.databases.base import DatadiffDialect - - -class PostgresqlDialect(pg.PostgresqlDialect, pg.Mixin_MD5, pg.Mixin_NormalizeValue, DatadiffDialect): - pass - - -class PostgreSQL(pg.PostgreSQL): - dialect = PostgresqlDialect() +from data_diff.sqeleton.databases.postgresql import PostgresqlDialect as PostgresqlDialect +from data_diff.sqeleton.databases.postgresql import PostgreSQL as PostgreSQL diff --git a/data_diff/databases/presto.py b/data_diff/databases/presto.py index 4ac86b3f..db7c4749 100644 --- a/data_diff/databases/presto.py +++ b/data_diff/databases/presto.py @@ -1,10 +1,2 @@ -from data_diff.sqeleton.databases import presto -from data_diff.databases.base import DatadiffDialect - - -class Dialect(presto.Dialect, presto.Mixin_MD5, presto.Mixin_NormalizeValue, DatadiffDialect): - pass - - -class Presto(presto.Presto): - dialect = Dialect() +from data_diff.sqeleton.databases.presto import Dialect as Dialect +from data_diff.sqeleton.databases.presto import Presto as Presto diff --git a/data_diff/databases/redshift.py b/data_diff/databases/redshift.py index e6eb3b20..54e9ecc1 100644 --- a/data_diff/databases/redshift.py +++ b/data_diff/databases/redshift.py @@ -1,10 +1,2 @@ -from data_diff.sqeleton.databases import redshift -from data_diff.databases.base import DatadiffDialect - - -class Dialect(redshift.Dialect, redshift.Mixin_MD5, redshift.Mixin_NormalizeValue, DatadiffDialect): - pass - - -class Redshift(redshift.Redshift): - dialect = Dialect() +from data_diff.sqeleton.databases.redshift import Dialect as Dialect +from data_diff.sqeleton.databases.redshift import Redshift as Redshift diff --git a/data_diff/databases/snowflake.py b/data_diff/databases/snowflake.py index 7dd8539f..2029a73d 100644 --- a/data_diff/databases/snowflake.py +++ b/data_diff/databases/snowflake.py @@ -1,10 +1,2 @@ -from data_diff.sqeleton.databases import snowflake -from data_diff.databases.base import DatadiffDialect - - -class Dialect(snowflake.Dialect, snowflake.Mixin_MD5, snowflake.Mixin_NormalizeValue, DatadiffDialect): - pass - - -class Snowflake(snowflake.Snowflake): - dialect = Dialect() +from data_diff.sqeleton.databases.snowflake import Dialect as Dialect +from data_diff.sqeleton.databases.snowflake import Snowflake as Snowflake diff --git a/data_diff/databases/trino.py b/data_diff/databases/trino.py index a39be906..e60bfb90 100644 --- a/data_diff/databases/trino.py +++ b/data_diff/databases/trino.py @@ -1,10 +1,2 @@ -from data_diff.sqeleton.databases import trino -from data_diff.databases.base import DatadiffDialect - - -class Dialect(trino.Dialect, trino.Mixin_MD5, trino.Mixin_NormalizeValue, DatadiffDialect): - pass - - -class Trino(trino.Trino): - dialect = Dialect() +from data_diff.sqeleton.databases.trino import Dialect as Dialect +from data_diff.sqeleton.databases.trino import Trino as Trino diff --git a/data_diff/databases/vertica.py b/data_diff/databases/vertica.py index 60812a49..83675939 100644 --- a/data_diff/databases/vertica.py +++ b/data_diff/databases/vertica.py @@ -1,10 +1,2 @@ -from data_diff.sqeleton.databases import vertica -from data_diff.databases.base import DatadiffDialect - - -class Dialect(vertica.Dialect, vertica.Mixin_MD5, vertica.Mixin_NormalizeValue, DatadiffDialect): - pass - - -class Vertica(vertica.Vertica): - dialect = Dialect() +from data_diff.sqeleton.databases.vertica import Dialect as Dialect +from data_diff.sqeleton.databases.vertica import Vertica as Vertica diff --git a/data_diff/sqeleton/databases/bigquery.py 
b/data_diff/sqeleton/databases/bigquery.py index 0bac1ff6..bdf9e07d 100644 --- a/data_diff/sqeleton/databases/bigquery.py +++ b/data_diff/sqeleton/databases/bigquery.py @@ -139,7 +139,7 @@ def time_travel( ) -class Dialect(BaseDialect, Mixin_Schema): +class Dialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): name = "BigQuery" ROUNDS_ON_PREC_LOSS = False # Technically BigQuery doesn't allow implicit rounding or truncation TYPE_CLASSES = { diff --git a/data_diff/sqeleton/databases/clickhouse.py b/data_diff/sqeleton/databases/clickhouse.py index e14cd226..578bb1e5 100644 --- a/data_diff/sqeleton/databases/clickhouse.py +++ b/data_diff/sqeleton/databases/clickhouse.py @@ -99,7 +99,7 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: return f"rpad({value}, {TIMESTAMP_PRECISION_POS + 6}, '0')" -class Dialect(BaseDialect): +class Dialect(BaseDialect, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): name = "Clickhouse" ROUNDS_ON_PREC_LOSS = False TYPE_CLASSES = { diff --git a/data_diff/sqeleton/databases/databricks.py b/data_diff/sqeleton/databases/databricks.py index a5474ee2..e478039f 100644 --- a/data_diff/sqeleton/databases/databricks.py +++ b/data_diff/sqeleton/databases/databricks.py @@ -60,7 +60,7 @@ def normalize_boolean(self, value: str, _coltype: Boolean) -> str: return self.to_string(f"cast ({value} as int)") -class Dialect(BaseDialect): +class Dialect(BaseDialect, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): name = "Databricks" ROUNDS_ON_PREC_LOSS = True TYPE_CLASSES = { diff --git a/data_diff/sqeleton/databases/duckdb.py b/data_diff/sqeleton/databases/duckdb.py index 2066da5a..827a0483 100644 --- a/data_diff/sqeleton/databases/duckdb.py +++ b/data_diff/sqeleton/databases/duckdb.py @@ -75,7 +75,7 @@ def test_regex(self, string: Compilable, pattern: Compilable) -> Compilable: return Func("regexp_matches", [string, pattern]) -class Dialect(BaseDialect, Mixin_Schema): +class Dialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): name = "DuckDB" ROUNDS_ON_PREC_LOSS = False SUPPORTS_PRIMARY_KEY = True diff --git a/data_diff/sqeleton/databases/mssql.py b/data_diff/sqeleton/databases/mssql.py index cc0754a7..d18f3fda 100644 --- a/data_diff/sqeleton/databases/mssql.py +++ b/data_diff/sqeleton/databases/mssql.py @@ -58,7 +58,7 @@ def md5_as_int(self, s: str) -> str: return f"convert(bigint, convert(varbinary, '0x' + RIGHT(CONVERT(NVARCHAR(32), HashBytes('MD5', {s}), 2), {CHECKSUM_HEXDIGITS}), 1))" -class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints): +class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): name = "MsSQL" ROUNDS_ON_PREC_LOSS = True SUPPORTS_PRIMARY_KEY = True diff --git a/data_diff/sqeleton/databases/mysql.py b/data_diff/sqeleton/databases/mysql.py index a10652b5..7c659749 100644 --- a/data_diff/sqeleton/databases/mysql.py +++ b/data_diff/sqeleton/databases/mysql.py @@ -67,7 +67,7 @@ def test_regex(self, string: Compilable, pattern: Compilable) -> Compilable: return BinBoolOp("REGEXP", [string, pattern]) -class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints): +class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): name = "MySQL" ROUNDS_ON_PREC_LOSS = True 
SUPPORTS_PRIMARY_KEY = True diff --git a/data_diff/sqeleton/databases/oracle.py b/data_diff/sqeleton/databases/oracle.py index 3f249441..825510a1 100644 --- a/data_diff/sqeleton/databases/oracle.py +++ b/data_diff/sqeleton/databases/oracle.py @@ -80,7 +80,7 @@ def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: ) -class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints): +class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): name = "Oracle" SUPPORTS_PRIMARY_KEY = True SUPPORTS_INDEXES = True diff --git a/data_diff/sqeleton/databases/postgresql.py b/data_diff/sqeleton/databases/postgresql.py index db34ec54..41228439 100644 --- a/data_diff/sqeleton/databases/postgresql.py +++ b/data_diff/sqeleton/databases/postgresql.py @@ -61,7 +61,7 @@ def normalize_json(self, value: str, _coltype: JSON) -> str: return f"{value}::text" -class PostgresqlDialect(BaseDialect, Mixin_Schema): +class PostgresqlDialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): name = "PostgreSQL" ROUNDS_ON_PREC_LOSS = True SUPPORTS_PRIMARY_KEY = True diff --git a/data_diff/sqeleton/databases/presto.py b/data_diff/sqeleton/databases/presto.py index a09d7846..3a033ed9 100644 --- a/data_diff/sqeleton/databases/presto.py +++ b/data_diff/sqeleton/databases/presto.py @@ -76,7 +76,7 @@ def normalize_boolean(self, value: str, _coltype: Boolean) -> str: return self.to_string(f"cast ({value} as int)") -class Dialect(BaseDialect, Mixin_Schema): +class Dialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): name = "Presto" ROUNDS_ON_PREC_LOSS = True TYPE_CLASSES = { diff --git a/data_diff/sqeleton/databases/redshift.py b/data_diff/sqeleton/databases/redshift.py index 97cbc0e1..e41d961e 100644 --- a/data_diff/sqeleton/databases/redshift.py +++ b/data_diff/sqeleton/databases/redshift.py @@ -7,7 +7,7 @@ DbPath, TimestampTZ, ) -from data_diff.sqeleton.abcs.mixins import AbstractMixin_MD5 +from data_diff.sqeleton.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue from data_diff.sqeleton.databases.postgresql import ( PostgreSQL, MD5_HEXDIGITS, @@ -51,7 +51,7 @@ def normalize_json(self, value: str, _coltype: JSON) -> str: return f"nvl2({value}, json_serialize({value}), NULL)" -class Dialect(PostgresqlDialect): +class Dialect(PostgresqlDialect, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): name = "Redshift" TYPE_CLASSES = { **PostgresqlDialect.TYPE_CLASSES, diff --git a/data_diff/sqeleton/databases/snowflake.py b/data_diff/sqeleton/databases/snowflake.py index e8bf51f5..6868f52f 100644 --- a/data_diff/sqeleton/databases/snowflake.py +++ b/data_diff/sqeleton/databases/snowflake.py @@ -104,7 +104,7 @@ def time_travel( return code(f"{{table}} {at_or_before}({key} => {{value}})", table=table, value=value) -class Dialect(BaseDialect, Mixin_Schema): +class Dialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): name = "Snowflake" ROUNDS_ON_PREC_LOSS = False TYPE_CLASSES = { diff --git a/data_diff/sqeleton/databases/trino.py b/data_diff/sqeleton/databases/trino.py index 20411749..a255b9a7 100644 --- a/data_diff/sqeleton/databases/trino.py +++ b/data_diff/sqeleton/databases/trino.py @@ -1,3 +1,4 @@ +from data_diff.sqeleton.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue 
from data_diff.sqeleton.abcs.database_types import TemporalType, ColType_UUID from data_diff.sqeleton.databases import presto from data_diff.sqeleton.databases.base import import_helper @@ -29,7 +30,7 @@ def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: return f"TRIM({value})" -class Dialect(presto.Dialect): +class Dialect(presto.Dialect, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): name = "Trino" diff --git a/data_diff/sqeleton/databases/vertica.py b/data_diff/sqeleton/databases/vertica.py index 6a59bcc3..9642ff7c 100644 --- a/data_diff/sqeleton/databases/vertica.py +++ b/data_diff/sqeleton/databases/vertica.py @@ -78,7 +78,7 @@ def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: ) -class Dialect(BaseDialect, Mixin_Schema): +class Dialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): name = "Vertica" ROUNDS_ON_PREC_LOSS = True From 871c201988c2c44a0901700768e6a30cc283dd62 Mon Sep 17 00:00:00 2001 From: Sergey Vasilyev Date: Fri, 22 Sep 2023 17:30:27 +0200 Subject: [PATCH 3/3] Squash the modules of databases & dialects from sqeleton & data_diff --- data_diff/__init__.py | 5 +- data_diff/__main__.py | 6 +- data_diff/abcs/__init__.py | 0 data_diff/{sqeleton => }/abcs/compiler.py | 0 .../{sqeleton => }/abcs/database_types.py | 0 data_diff/{sqeleton => }/abcs/mixins.py | 4 +- data_diff/{sqeleton => }/bound_exprs.py | 13 +- data_diff/databases/__init__.py | 33 +- data_diff/databases/_connect.py | 255 ++++++++++++++- data_diff/{sqeleton => }/databases/base.py | 16 +- data_diff/databases/bigquery.py | 299 +++++++++++++++++- data_diff/databases/clickhouse.py | 198 +++++++++++- data_diff/databases/databricks.py | 201 +++++++++++- data_diff/databases/duckdb.py | 194 +++++++++++- data_diff/databases/mssql.py | 216 ++++++++++++- data_diff/databases/mysql.py | 161 +++++++++- data_diff/databases/oracle.py | 208 +++++++++++- data_diff/databases/postgresql.py | 184 ++++++++++- data_diff/databases/presto.py | 204 +++++++++++- data_diff/databases/redshift.py | 178 ++++++++++- data_diff/databases/snowflake.py | 230 +++++++++++++- data_diff/databases/trino.py | 50 ++- data_diff/databases/vertica.py | 183 ++++++++++- data_diff/diff_tables.py | 4 +- data_diff/format.py | 2 +- data_diff/hashdiff_tables.py | 4 +- data_diff/joindiff_tables.py | 15 +- data_diff/queries/__init__.py | 0 data_diff/{sqeleton => }/queries/api.py | 6 +- .../{sqeleton => }/queries/ast_classes.py | 17 +- data_diff/{sqeleton => }/queries/base.py | 3 - data_diff/{sqeleton => }/queries/compiler.py | 5 +- data_diff/{sqeleton => }/queries/extras.py | 6 +- data_diff/query_utils.py | 6 +- data_diff/{sqeleton => }/schema.py | 2 +- data_diff/sqeleton/__init__.py | 2 - data_diff/sqeleton/abcs/__init__.py | 15 - data_diff/sqeleton/databases/__init__.py | 26 -- data_diff/sqeleton/databases/_connect.py | 283 ----------------- data_diff/sqeleton/databases/bigquery.py | 297 ----------------- data_diff/sqeleton/databases/clickhouse.py | 196 ------------ data_diff/sqeleton/databases/databricks.py | 199 ------------ data_diff/sqeleton/databases/duckdb.py | 192 ----------- data_diff/sqeleton/databases/mssql.py | 214 ------------- data_diff/sqeleton/databases/mysql.py | 160 ---------- data_diff/sqeleton/databases/oracle.py | 206 ------------ data_diff/sqeleton/databases/postgresql.py | 183 ----------- data_diff/sqeleton/databases/presto.py | 202 ------------ data_diff/sqeleton/databases/redshift.py | 176 
----------- data_diff/sqeleton/databases/snowflake.py | 228 ------------- data_diff/sqeleton/databases/trino.py | 48 --- data_diff/sqeleton/databases/vertica.py | 181 ----------- data_diff/sqeleton/queries/__init__.py | 25 -- data_diff/table_segment.py | 10 +- data_diff/utils.py | 3 +- tests/common.py | 6 +- tests/test_api.py | 4 +- tests/test_cli.py | 2 +- tests/test_database.py | 9 +- tests/test_database_types.py | 10 +- tests/test_diff_tables.py | 2 +- tests/test_format.py | 4 +- tests/test_joindiff.py | 4 +- tests/test_postgresql.py | 3 +- tests/test_query.py | 9 +- tests/test_sql.py | 4 +- 66 files changed, 2836 insertions(+), 2975 deletions(-) create mode 100644 data_diff/abcs/__init__.py rename data_diff/{sqeleton => }/abcs/compiler.py (100%) rename data_diff/{sqeleton => }/abcs/database_types.py (100%) rename data_diff/{sqeleton => }/abcs/mixins.py (98%) rename data_diff/{sqeleton => }/bound_exprs.py (85%) rename data_diff/{sqeleton => }/databases/base.py (97%) create mode 100644 data_diff/queries/__init__.py rename data_diff/{sqeleton => }/queries/api.py (97%) rename data_diff/{sqeleton => }/queries/ast_classes.py (98%) rename data_diff/{sqeleton => }/queries/base.py (76%) rename data_diff/{sqeleton => }/queries/compiler.py (92%) rename data_diff/{sqeleton => }/queries/extras.py (89%) rename data_diff/{sqeleton => }/schema.py (91%) delete mode 100644 data_diff/sqeleton/__init__.py delete mode 100644 data_diff/sqeleton/abcs/__init__.py delete mode 100644 data_diff/sqeleton/databases/__init__.py delete mode 100644 data_diff/sqeleton/databases/_connect.py delete mode 100644 data_diff/sqeleton/databases/bigquery.py delete mode 100644 data_diff/sqeleton/databases/clickhouse.py delete mode 100644 data_diff/sqeleton/databases/databricks.py delete mode 100644 data_diff/sqeleton/databases/duckdb.py delete mode 100644 data_diff/sqeleton/databases/mssql.py delete mode 100644 data_diff/sqeleton/databases/mysql.py delete mode 100644 data_diff/sqeleton/databases/oracle.py delete mode 100644 data_diff/sqeleton/databases/postgresql.py delete mode 100644 data_diff/sqeleton/databases/presto.py delete mode 100644 data_diff/sqeleton/databases/redshift.py delete mode 100644 data_diff/sqeleton/databases/snowflake.py delete mode 100644 data_diff/sqeleton/databases/trino.py delete mode 100644 data_diff/sqeleton/databases/vertica.py delete mode 100644 data_diff/sqeleton/queries/__init__.py diff --git a/data_diff/__init__.py b/data_diff/__init__.py index bbdffb01..60c79b10 100644 --- a/data_diff/__init__.py +++ b/data_diff/__init__.py @@ -1,9 +1,8 @@ from typing import Sequence, Tuple, Iterator, Optional, Union -from data_diff.sqeleton.abcs import DbTime, DbPath - +from data_diff.abcs.database_types import DbTime, DbPath from data_diff.tracking import disable_tracking -from data_diff.databases import connect +from data_diff.databases._connect import connect from data_diff.diff_tables import Algorithm from data_diff.hashdiff_tables import HashDiffer, DEFAULT_BISECTION_THRESHOLD, DEFAULT_BISECTION_FACTOR from data_diff.joindiff_tables import JoinDiffer, TABLE_WRITE_LIMIT diff --git a/data_diff/__main__.py b/data_diff/__main__.py index 481c829f..77dc7fb6 100644 --- a/data_diff/__main__.py +++ b/data_diff/__main__.py @@ -12,8 +12,8 @@ from rich.logging import RichHandler import click -from data_diff.sqeleton.schema import create_schema -from data_diff.sqeleton.queries.api import current_timestamp +from data_diff.schema import create_schema +from data_diff.queries.api import current_timestamp from data_diff.dbt 
import dbt_diff from data_diff.utils import eval_name_template, remove_password_from_url, safezip, match_like, LogStatusHandler @@ -21,7 +21,7 @@ from data_diff.hashdiff_tables import HashDiffer, DEFAULT_BISECTION_THRESHOLD, DEFAULT_BISECTION_FACTOR from data_diff.joindiff_tables import TABLE_WRITE_LIMIT, JoinDiffer from data_diff.table_segment import TableSegment -from data_diff.databases import connect +from data_diff.databases._connect import connect from data_diff.parse_time import parse_time_before, UNITS_STR, ParseError from data_diff.config import apply_config_from_file from data_diff.tracking import disable_tracking, set_entrypoint_name diff --git a/data_diff/abcs/__init__.py b/data_diff/abcs/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/data_diff/sqeleton/abcs/compiler.py b/data_diff/abcs/compiler.py similarity index 100% rename from data_diff/sqeleton/abcs/compiler.py rename to data_diff/abcs/compiler.py diff --git a/data_diff/sqeleton/abcs/database_types.py b/data_diff/abcs/database_types.py similarity index 100% rename from data_diff/sqeleton/abcs/database_types.py rename to data_diff/abcs/database_types.py diff --git a/data_diff/sqeleton/abcs/mixins.py b/data_diff/abcs/mixins.py similarity index 98% rename from data_diff/sqeleton/abcs/mixins.py rename to data_diff/abcs/mixins.py index e33129a2..17f06064 100644 --- a/data_diff/sqeleton/abcs/mixins.py +++ b/data_diff/abcs/mixins.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from data_diff.sqeleton.abcs.database_types import ( +from data_diff.abcs.database_types import ( Array, TemporalType, FractionalType, @@ -10,7 +10,7 @@ JSON, Struct, ) -from data_diff.sqeleton.abcs.compiler import Compilable +from data_diff.abcs.compiler import Compilable class AbstractMixin(ABC): diff --git a/data_diff/sqeleton/bound_exprs.py b/data_diff/bound_exprs.py similarity index 85% rename from data_diff/sqeleton/bound_exprs.py rename to data_diff/bound_exprs.py index 8bbb3063..1742b74c 100644 --- a/data_diff/sqeleton/bound_exprs.py +++ b/data_diff/bound_exprs.py @@ -7,10 +7,11 @@ from runtype import dataclass from typing_extensions import Self -from data_diff.sqeleton.abcs import AbstractDatabase, AbstractCompiler -from data_diff.sqeleton.queries.ast_classes import ExprNode, ITable, TablePath, Compilable -from data_diff.sqeleton.queries.api import table -from data_diff.sqeleton.schema import create_schema +from data_diff.abcs.database_types import AbstractDatabase +from data_diff.abcs.compiler import AbstractCompiler +from data_diff.queries.ast_classes import ExprNode, TablePath, Compilable +from data_diff.queries.api import table +from data_diff.schema import create_schema @dataclass @@ -80,8 +81,8 @@ def bound_table(database: AbstractDatabase, table_path: Union[TablePath, str, tu # Database.table = bound_table # def test(): -# from data_diff.sqeleton. 
import connect -# from data_diff.sqeleton.queries.api import table +# from data_diff import connect +# from data_diff.queries.api import table # d = connect("mysql://erez:qweqwe123@localhost/erez") # t = table(('Rating',)) diff --git a/data_diff/databases/__init__.py b/data_diff/databases/__init__.py index eaf97281..842cc731 100644 --- a/data_diff/databases/__init__.py +++ b/data_diff/databases/__init__.py @@ -1,17 +1,16 @@ -from data_diff.sqeleton.databases import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, QueryError, ConnectError - -from data_diff.sqeleton.databases.postgresql import PostgreSQL as PostgreSQL -from data_diff.sqeleton.databases.mysql import MySQL as MySQL -from data_diff.sqeleton.databases.oracle import Oracle as Oracle -from data_diff.sqeleton.databases.snowflake import Snowflake as Snowflake -from data_diff.sqeleton.databases.bigquery import BigQuery as BigQuery -from data_diff.sqeleton.databases.redshift import Redshift as Redshift -from data_diff.sqeleton.databases.presto import Presto as Presto -from data_diff.sqeleton.databases.databricks import Databricks as Databricks -from data_diff.sqeleton.databases.trino import Trino as Trino -from data_diff.sqeleton.databases.clickhouse import Clickhouse as Clickhouse -from data_diff.sqeleton.databases.vertica import Vertica as Vertica -from data_diff.sqeleton.databases.duckdb import DuckDB as DuckDB -from data_diff.sqeleton.databases.mssql import MsSQL as MsSql - -from data_diff.databases._connect import connect +from data_diff.databases.base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, QueryError, ConnectError, BaseDialect, Database +from data_diff.databases._connect import connect as connect +from data_diff.databases._connect import Connect as Connect +from data_diff.databases.postgresql import PostgreSQL as PostgreSQL +from data_diff.databases.mysql import MySQL as MySQL +from data_diff.databases.oracle import Oracle as Oracle +from data_diff.databases.snowflake import Snowflake as Snowflake +from data_diff.databases.bigquery import BigQuery as BigQuery +from data_diff.databases.redshift import Redshift as Redshift +from data_diff.databases.presto import Presto as Presto +from data_diff.databases.databricks import Databricks as Databricks +from data_diff.databases.trino import Trino as Trino +from data_diff.databases.clickhouse import Clickhouse as Clickhouse +from data_diff.databases.vertica import Vertica as Vertica +from data_diff.databases.duckdb import DuckDB as DuckDB +from data_diff.databases.mssql import MsSQL as MsSQL diff --git a/data_diff/databases/_connect.py b/data_diff/databases/_connect.py index fcef1069..8f842123 100644 --- a/data_diff/databases/_connect.py +++ b/data_diff/databases/_connect.py @@ -1,7 +1,15 @@ import logging +from typing import Hashable, MutableMapping, Type, Optional, Union, Dict +from itertools import zip_longest +from contextlib import suppress +import weakref +import dsnparse +import toml -from data_diff.sqeleton.databases import Connect +from runtype import dataclass +from typing_extensions import Self +from data_diff.databases.base import Database, ThreadedDatabase from data_diff.databases.postgresql import PostgreSQL from data_diff.databases.mysql import MySQL from data_diff.databases.oracle import Oracle @@ -14,7 +22,57 @@ from data_diff.databases.clickhouse import Clickhouse from data_diff.databases.vertica import Vertica from data_diff.databases.duckdb import DuckDB -from data_diff.databases.mssql import MsSql +from data_diff.databases.mssql import MsSQL + + +@dataclass +class MatchUriPath: 
+ database_cls: Type[Database] + + def match_path(self, dsn): + help_str = self.database_cls.CONNECT_URI_HELP + params = self.database_cls.CONNECT_URI_PARAMS + kwparams = self.database_cls.CONNECT_URI_KWPARAMS + + dsn_dict = dict(dsn.query) + matches = {} + for param, arg in zip_longest(params, dsn.paths): + if param is None: + raise ValueError(f"Too many parts to path. Expected format: {help_str}") + + optional = param.endswith("?") + param = param.rstrip("?") + + if arg is None: + try: + arg = dsn_dict.pop(param) + except KeyError: + if not optional: + raise ValueError(f"URI must specify '{param}'. Expected format: {help_str}") + + arg = None + + assert param and param not in matches + matches[param] = arg + + for param in kwparams: + try: + arg = dsn_dict.pop(param) + except KeyError: + raise ValueError(f"URI must specify '{param}'. Expected format: {help_str}") + + assert param and arg and param not in matches, (param, arg, matches.keys()) + matches[param] = arg + + for param, value in dsn_dict.items(): + if param in matches: + raise ValueError( + f"Parameter '{param}' already provided as positional argument. Expected format: {help_str}" + ) + + matches[param] = value + + return matches DATABASE_BY_SCHEME = { @@ -30,10 +88,201 @@ "trino": Trino, "clickhouse": Clickhouse, "vertica": Vertica, - "mssql": MsSql, + "mssql": MsSQL, } +class Connect: + """Provides methods for connecting to a supported database using a URL or connection dict.""" + conn_cache: MutableMapping[Hashable, Database] + + def __init__(self, database_by_scheme: Dict[str, Database] = DATABASE_BY_SCHEME): + self.database_by_scheme = database_by_scheme + self.match_uri_path = {name: MatchUriPath(cls) for name, cls in database_by_scheme.items()} + self.conn_cache = weakref.WeakValueDictionary() + + def for_databases(self, *dbs) -> Self: + database_by_scheme = {k: db for k, db in self.database_by_scheme.items() if k in dbs} + return type(self)(database_by_scheme) + + def connect_to_uri(self, db_uri: str, thread_count: Optional[int] = 1, **kwargs) -> Database: + """Connect to the given database uri + + thread_count determines the max number of worker threads per database, + if relevant. None means no limit. + + Parameters: + db_uri (str): The URI for the database to connect + thread_count (int, optional): Size of the threadpool. Ignored by cloud databases. (default: 1) + + Note: For non-cloud databases, a low thread-pool size may be a performance bottleneck. + + Supported schemes: + - postgresql + - mysql + - oracle + - snowflake + - bigquery + - redshift + - presto + - databricks + - trino + - clickhouse + - vertica + - duckdb + """ + + dsn = dsnparse.parse(db_uri) + if len(dsn.schemes) > 1: + raise NotImplementedError("No support for multiple schemes") + (scheme,) = dsn.schemes + + if scheme == "toml": + toml_path = dsn.path or dsn.host + database = dsn.fragment + if not database: + raise ValueError("Must specify a database name, e.g. 'toml://path#database'. 
") + with open(toml_path) as f: + config = toml.load(f) + try: + conn_dict = config["database"][database] + except KeyError: + raise ValueError(f"Cannot find database config named '{database}'.") + return self.connect_with_dict(conn_dict, thread_count, **kwargs) + + try: + matcher = self.match_uri_path[scheme] + except KeyError: + raise NotImplementedError(f"Scheme '{scheme}' currently not supported") + + cls = matcher.database_cls + + if scheme == "databricks": + assert not dsn.user + kw = {} + kw["access_token"] = dsn.password + kw["http_path"] = dsn.path + kw["server_hostname"] = dsn.host + kw.update(dsn.query) + elif scheme == "duckdb": + kw = {} + kw["filepath"] = dsn.dbname + kw["dbname"] = dsn.user + else: + kw = matcher.match_path(dsn) + + if scheme == "bigquery": + kw["project"] = dsn.host + return cls(**kw, **kwargs) + + if scheme == "snowflake": + kw["account"] = dsn.host + assert not dsn.port + kw["user"] = dsn.user + kw["password"] = dsn.password + else: + if scheme == "oracle": + kw["host"] = dsn.hostloc + else: + kw["host"] = dsn.host + kw["port"] = dsn.port + kw["user"] = dsn.user + if dsn.password: + kw["password"] = dsn.password + + kw = {k: v for k, v in kw.items() if v is not None} + + if issubclass(cls, ThreadedDatabase): + db = cls(thread_count=thread_count, **kw, **kwargs) + else: + db = cls(**kw, **kwargs) + + return self._connection_created(db) + + def connect_with_dict(self, d, thread_count, **kwargs): + d = dict(d) + driver = d.pop("driver") + try: + matcher = self.match_uri_path[driver] + except KeyError: + raise NotImplementedError(f"Driver '{driver}' currently not supported") + + cls = matcher.database_cls + if issubclass(cls, ThreadedDatabase): + db = cls(thread_count=thread_count, **d, **kwargs) + else: + db = cls(**d, **kwargs) + + return self._connection_created(db) + + def _connection_created(self, db): + "Nop function to be overridden by subclasses." + return db + + def __call__( + self, db_conf: Union[str, dict], thread_count: Optional[int] = 1, shared: bool = True, **kwargs + ) -> Database: + """Connect to a database using the given database configuration. + + Configuration can be given either as a URI string, or as a dict of {option: value}. + + The dictionary configuration uses the same keys as the TOML 'database' definition given with --conf. + + thread_count determines the max number of worker threads per database, + if relevant. None means no limit. + + Parameters: + db_conf (str | dict): The configuration for the database to connect. URI or dict. + thread_count (int, optional): Size of the threadpool. Ignored by cloud databases. (default: 1) + shared (bool): Whether to cache and return the same connection for the same db_conf. (default: True) + bigquery_credentials (google.oauth2.credentials.Credentials): Custom Google oAuth2 credential for BigQuery. + (default: None) + + Note: For non-cloud databases, a low thread-pool size may be a performance bottleneck. 
+ + Supported drivers: + - postgresql + - mysql + - oracle + - snowflake + - bigquery + - redshift + - presto + - databricks + - trino + - clickhouse + - vertica + + Example: + >>> connect("mysql://localhost/db") + + >>> connect({"driver": "mysql", "host": "localhost", "database": "db"}) + + """ + cache_key = self.__make_cache_key(db_conf) + if shared: + with suppress(KeyError): + conn = self.conn_cache[cache_key] + if not conn.is_closed: + return conn + + if isinstance(db_conf, str): + conn = self.connect_to_uri(db_conf, thread_count, **kwargs) + elif isinstance(db_conf, dict): + conn = self.connect_with_dict(db_conf, thread_count, **kwargs) + else: + raise TypeError(f"db configuration must be a URI string or a dictionary. Instead got '{db_conf}'.") + + if shared: + self.conn_cache[cache_key] = conn + return conn + + def __make_cache_key(self, db_conf: Union[str, dict]) -> Hashable: + if isinstance(db_conf, dict): + return tuple(db_conf.items()) + return db_conf + + class Connect_SetUTC(Connect): """Provides methods for connecting to a supported database using a URL or connection dict. diff --git a/data_diff/sqeleton/databases/base.py b/data_diff/databases/base.py similarity index 97% rename from data_diff/sqeleton/databases/base.py rename to data_diff/databases/base.py index 720bc029..a89ab74e 100644 --- a/data_diff/sqeleton/databases/base.py +++ b/data_diff/databases/base.py @@ -2,7 +2,7 @@ import math import sys import logging -from typing import Any, Callable, Dict, Generator, Tuple, Optional, Sequence, Type, List, Union, TypeVar, TYPE_CHECKING +from typing import Any, Callable, Dict, Generator, Tuple, Optional, Sequence, Type, List, Union, TypeVar from functools import partial, wraps from concurrent.futures import ThreadPoolExecutor import threading @@ -14,9 +14,9 @@ from typing_extensions import Self from data_diff.utils import is_uuid, safezip -from data_diff.sqeleton.queries import Expr, Compiler, table, Select, SKIP, Explain, Code, this -from data_diff.sqeleton.queries.ast_classes import Random -from data_diff.sqeleton.abcs.database_types import ( +from data_diff.queries.api import Expr, Compiler, table, Select, SKIP, Explain, Code, this +from data_diff.queries.ast_classes import Random +from data_diff.abcs.database_types import ( AbstractDatabase, Array, Struct, @@ -39,14 +39,14 @@ Boolean, JSON, ) -from data_diff.sqeleton.abcs.mixins import Compilable -from data_diff.sqeleton.abcs.mixins import ( +from data_diff.abcs.mixins import Compilable +from data_diff.abcs.mixins import ( AbstractMixin_Schema, AbstractMixin_RandomSample, AbstractMixin_NormalizeValue, AbstractMixin_OptimizerHints, ) -from data_diff.sqeleton.bound_exprs import bound_table +from data_diff.bound_exprs import bound_table logger = logging.getLogger("database") @@ -315,7 +315,7 @@ class Database(AbstractDatabase[T]): Used for providing connection code and implementation specific SQL utilities. 
- Instanciated using :meth:`~data_diff.sqeleton.connect`
+ Instantiated using :meth:`~data_diff.connect`
 """
 default_schema: str = None
diff --git a/data_diff/databases/bigquery.py b/data_diff/databases/bigquery.py
index 8b0bc1e5..5925234f 100644
--- a/data_diff/databases/bigquery.py
+++ b/data_diff/databases/bigquery.py
@@ -1,2 +1,297 @@
-from data_diff.sqeleton.databases.bigquery import Dialect as Dialect
-from data_diff.sqeleton.databases.bigquery import BigQuery as BigQuery
+import re
+from typing import Any, List, Union
+from data_diff.abcs.database_types import (
+    ColType,
+    Array,
+    JSON,
+    Struct,
+    Timestamp,
+    Datetime,
+    Integer,
+    Decimal,
+    Float,
+    Text,
+    DbPath,
+    FractionalType,
+    TemporalType,
+    Boolean,
+    UnknownColType,
+)
+from data_diff.abcs.mixins import (
+    AbstractMixin_MD5,
+    AbstractMixin_NormalizeValue,
+    AbstractMixin_Schema,
+    AbstractMixin_TimeTravel,
+)
+from data_diff.abcs.compiler import Compilable
+from data_diff.queries.api import this, table, SKIP, code
+from data_diff.databases.base import (
+    BaseDialect,
+    Database,
+    import_helper,
+    parse_table_name,
+    ConnectError,
+    apply_query,
+    QueryResult,
+)
+from data_diff.databases.base import TIMESTAMP_PRECISION_POS, ThreadLocalInterpreter, Mixin_RandomSample
+
+
+@import_helper(text="Please install BigQuery and configure your google-cloud access.")
+def import_bigquery():
+    from google.cloud import bigquery
+
+    return bigquery
+
+
+def import_bigquery_service_account():
+    from google.oauth2 import service_account
+
+    return service_account
+
+
+class Mixin_MD5(AbstractMixin_MD5):
+    def md5_as_int(self, s: str) -> str:
+        return f"cast(cast( ('0x' || substr(TO_HEX(md5({s})), 18)) as int64) as numeric)"
+
+
+class Mixin_NormalizeValue(AbstractMixin_NormalizeValue):
+    def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
+        if coltype.rounds:
+            timestamp = f"timestamp_micros(cast(round(unix_micros(cast({value} as timestamp))/1000000, {coltype.precision})*1000000 as int))"
+            return f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {timestamp})"
+
+        if coltype.precision == 0:
+            return f"FORMAT_TIMESTAMP('%F %H:%M:%S.000000', {value})"
+        elif coltype.precision == 6:
+            return f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {value})"
+
+        timestamp6 = f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {value})"
+        return (
+            f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')"
+        )
+
+    def normalize_number(self, value: str, coltype: FractionalType) -> str:
+        return f"format('%.{coltype.precision}f', {value})"
+
+    def normalize_boolean(self, value: str, _coltype: Boolean) -> str:
+        return self.to_string(f"cast({value} as int)")
+
+    def normalize_json(self, value: str, _coltype: JSON) -> str:
+        # BigQuery is unable to compare arrays & structs with ==/!=/distinct from, e.g.:
+        #   Got error: 400 Grouping is not defined for arguments of type ARRAY at …
+        # So we do the best effort and compare it as strings, hoping that the JSON forms
+        # match on both sides: i.e. have properly ordered keys, same spacing, same quotes, etc.
+        return f"to_json_string({value})"
+
+    def normalize_array(self, value: str, _coltype: Array) -> str:
+        # BigQuery is unable to compare arrays & structs with ==/!=/distinct from, e.g.:
+        #   Got error: 400 Grouping is not defined for arguments of type ARRAY at …
+        # So we do the best effort and compare it as strings, hoping that the JSON forms
+        # match on both sides: i.e. have properly ordered keys, same spacing, same quotes, etc.
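+        # Illustrative (an assumption mirroring normalize_json above): for an ARRAY
+        # column "a", this returns the SQL snippet "to_json_string(a)", so both sides
+        # of the diff are compared as JSON text rather than as ARRAY values.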
+ return f"to_json_string({value})" + + def normalize_struct(self, value: str, _coltype: Struct) -> str: + # BigQuery is unable to compare arrays & structs with ==/!=/distinct from, e.g.: + # Got error: 400 Grouping is not defined for arguments of type ARRAY at … + # So we do the best effort and compare it as strings, hoping that the JSON forms + # match on both sides: i.e. have properly ordered keys, same spacing, same quotes, etc. + return f"to_json_string({value})" + + +class Mixin_Schema(AbstractMixin_Schema): + def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: + return ( + table(table_schema, "INFORMATION_SCHEMA", "TABLES") + .where( + this.table_schema == table_schema, + this.table_name.like(like) if like is not None else SKIP, + this.table_type == "BASE TABLE", + ) + .select(this.table_name) + ) + + +class Mixin_TimeTravel(AbstractMixin_TimeTravel): + def time_travel( + self, + table: Compilable, + before: bool = False, + timestamp: Compilable = None, + offset: Compilable = None, + statement: Compilable = None, + ) -> Compilable: + if before: + raise NotImplementedError("before=True not supported for BigQuery time-travel") + + if statement is not None: + raise NotImplementedError("BigQuery time-travel doesn't support querying by statement id") + + if timestamp is not None: + assert offset is None + return code("{table} FOR SYSTEM_TIME AS OF {timestamp}", table=table, timestamp=timestamp) + + assert offset is not None + return code( + "{table} FOR SYSTEM_TIME AS OF TIMESTAMP_SUB(CURRENT_TIMESTAMP(), INTERVAL {offset} HOUR);", + table=table, + offset=offset, + ) + + +class Dialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): + name = "BigQuery" + ROUNDS_ON_PREC_LOSS = False # Technically BigQuery doesn't allow implicit rounding or truncation + TYPE_CLASSES = { + # Dates + "TIMESTAMP": Timestamp, + "DATETIME": Datetime, + # Numbers + "INT64": Integer, + "INT32": Integer, + "NUMERIC": Decimal, + "BIGNUMERIC": Decimal, + "FLOAT64": Float, + "FLOAT32": Float, + "STRING": Text, + "BOOL": Boolean, + "JSON": JSON, + } + TYPE_ARRAY_RE = re.compile(r"ARRAY<(.+)>") + TYPE_STRUCT_RE = re.compile(r"STRUCT<(.+)>") + MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_TimeTravel, Mixin_RandomSample} + + def random(self) -> str: + return "RAND()" + + def quote(self, s: str): + return f"`{s}`" + + def to_string(self, s: str): + return f"cast({s} as string)" + + def type_repr(self, t) -> str: + try: + return {str: "STRING", float: "FLOAT64"}[t] + except KeyError: + return super().type_repr(t) + + def parse_type( + self, + table_path: DbPath, + col_name: str, + type_repr: str, + *args: Any, # pass-through args + **kwargs: Any, # pass-through args + ) -> ColType: + col_type = super().parse_type(table_path, col_name, type_repr, *args, **kwargs) + if isinstance(col_type, UnknownColType): + m = self.TYPE_ARRAY_RE.fullmatch(type_repr) + if m: + item_type = self.parse_type(table_path, col_name, m.group(1), *args, **kwargs) + col_type = Array(item_type=item_type) + + # We currently ignore structs' structure, but later can parse it too. 
Examples:
+            # - STRUCT<INT64, STRING(10)> (unnamed)
+            # - STRUCT<foo INT64, bar STRING(10)> (named)
+            # - STRUCT<foo INT64, bar ARRAY<INT64>> (with complex fields)
+            # - STRUCT<foo INT64, bar STRUCT<a INT64, b INT64>> (nested)
+            m = self.TYPE_STRUCT_RE.fullmatch(type_repr)
+            if m:
+                col_type = Struct()
+
+        return col_type
+
+    def to_comparable(self, value: str, coltype: ColType) -> str:
+        """Ensure that the expression is comparable in ``IS DISTINCT FROM``."""
+        if isinstance(coltype, (JSON, Array, Struct)):
+            return self.normalize_value_by_type(value, coltype)
+        else:
+            return super().to_comparable(value, coltype)
+
+    def set_timezone_to_utc(self) -> str:
+        raise NotImplementedError()
+
+
+class BigQuery(Database):
+    CONNECT_URI_HELP = "bigquery://<project>/<dataset>"
+    CONNECT_URI_PARAMS = ["dataset"]
+    dialect = Dialect()
+
+    def __init__(self, project, *, dataset, bigquery_credentials=None, **kw):
+        credentials = bigquery_credentials
+        bigquery = import_bigquery()
+
+        keyfile = kw.pop("keyfile", None)
+        if keyfile:
+            bigquery_service_account = import_bigquery_service_account()
+            credentials = bigquery_service_account.Credentials.from_service_account_file(
+                keyfile,
+                scopes=["https://www.googleapis.com/auth/cloud-platform"],
+            )
+
+        self._client = bigquery.Client(project=project, credentials=credentials, **kw)
+        self.project = project
+        self.dataset = dataset
+
+        self.default_schema = dataset
+
+    def _normalize_returned_value(self, value):
+        if isinstance(value, bytes):
+            return value.decode()
+        return value
+
+    def _query_atom(self, sql_code: str):
+        from google.cloud import bigquery
+
+        try:
+            result = self._client.query(sql_code).result()
+            columns = [c.name for c in result.schema]
+            rows = list(result)
+        except Exception as e:
+            msg = "Exception when trying to execute SQL code:\n %s\n\nGot error: %s"
+            raise ConnectError(msg % (sql_code, e))
+
+        if rows and isinstance(rows[0], bigquery.table.Row):
+            rows = [tuple(self._normalize_returned_value(v) for v in row.values()) for row in rows]
+        return QueryResult(rows, columns)
+
+    def _query(self, sql_code: Union[str, ThreadLocalInterpreter]) -> QueryResult:
+        return apply_query(self._query_atom, sql_code)
+
+    def close(self):
+        super().close()
+        self._client.close()
+
+    def select_table_schema(self, path: DbPath) -> str:
+        project, schema, name = self._normalize_table_path(path)
+        return (
+            "SELECT column_name, data_type, 6 as datetime_precision, 38 as numeric_precision, 9 as numeric_scale "
+            f"FROM `{project}`.`{schema}`.INFORMATION_SCHEMA.COLUMNS "
+            f"WHERE table_name = '{name}' AND table_schema = '{schema}'"
+        )
+
+    def query_table_unique_columns(self, path: DbPath) -> List[str]:
+        return []
+
+    def _normalize_table_path(self, path: DbPath) -> DbPath:
+        if len(path) == 0:
+            raise ValueError(f"{self.name}: Bad table path for {self}: ()")
+        elif len(path) == 1:
+            return (self.project, self.default_schema, path[0])
+        elif len(path) == 2:
+            return (self.project,) + path
+        elif len(path) == 3:
+            return path
+        else:
+            raise ValueError(
+                f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. 
Expected form: [project.]schema.table" + ) + + def parse_table_name(self, name: str) -> DbPath: + path = parse_table_name(name) + return tuple(i for i in self._normalize_table_path(path) if i is not None) + + @property + def is_autocommit(self) -> bool: + return True diff --git a/data_diff/databases/clickhouse.py b/data_diff/databases/clickhouse.py index 85083383..9366b922 100644 --- a/data_diff/databases/clickhouse.py +++ b/data_diff/databases/clickhouse.py @@ -1,2 +1,196 @@ -from data_diff.sqeleton.databases.clickhouse import Dialect as Dialect -from data_diff.sqeleton.databases.clickhouse import Clickhouse as Clickhouse +from typing import Optional, Type + +from data_diff.databases.base import ( + MD5_HEXDIGITS, + CHECKSUM_HEXDIGITS, + TIMESTAMP_PRECISION_POS, + BaseDialect, + ThreadedDatabase, + import_helper, + ConnectError, + Mixin_RandomSample, +) +from data_diff.abcs.database_types import ( + ColType, + Decimal, + Float, + Integer, + FractionalType, + Native_UUID, + TemporalType, + Text, + Timestamp, + Boolean, +) +from data_diff.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue + +# https://clickhouse.com/docs/en/operations/server-configuration-parameters/settings/#default-database +DEFAULT_DATABASE = "default" + + +@import_helper("clickhouse") +def import_clickhouse(): + import clickhouse_driver + + return clickhouse_driver + + +class Mixin_MD5(AbstractMixin_MD5): + def md5_as_int(self, s: str) -> str: + substr_idx = 1 + MD5_HEXDIGITS - CHECKSUM_HEXDIGITS + return f"reinterpretAsUInt128(reverse(unhex(lowerUTF8(substr(hex(MD5({s})), {substr_idx})))))" + + +class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): + def normalize_number(self, value: str, coltype: FractionalType) -> str: + # If a decimal value has trailing zeros in a fractional part, when casting to string they are dropped. + # For example: + # select toString(toDecimal128(1.10, 2)); -- the result is 1.1 + # select toString(toDecimal128(1.00, 2)); -- the result is 1 + # So, we should use some custom approach to save these trailing zeros. + # To avoid it, we can add a small value like 0.000001 to prevent dropping of zeros from the end when casting. + # For examples above it looks like: + # select toString(toDecimal128(1.10, 2 + 1) + toDecimal128(0.001, 3)); -- the result is 1.101 + # After that, cut an extra symbol from the string, i.e. 1.101 -> 1.10 + # So, the algorithm is: + # 1. Cast to decimal with precision + 1 + # 2. Add a small value 10^(-precision-1) + # 3. Cast the result to string + # 4. Drop the extra digit from the string. To do that, we need to slice the string + # with length = digits in an integer part + 1 (symbol of ".") + precision + + if coltype.precision == 0: + return self.to_string(f"round({value})") + + precision = coltype.precision + # TODO: too complex, is there better performance way? 
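+        # Worked example of the steps above (assuming value = 1.10, precision = 2):
+        #   1. toDecimal128(round(abs(1.10), 2), 3)  -> 1.100
+        #   2. + toDecimal128(exp10(-3), 3)          -> 1.101
+        #   3. toString(...)                         -> '1.101'
+        #   4. left(..., 1 + 1 + 2)                  -> '1.10', sign restored by if()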
+ value = f""" + if({value} >= 0, '', '-') || left( + toString( + toDecimal128( + round(abs({value}), {precision}), + {precision} + 1 + ) + + + toDecimal128( + exp10(-{precision + 1}), + {precision} + 1 + ) + ), + toUInt8( + greatest( + floor(log10(abs({value}))) + 1, + 1 + ) + ) + 1 + {precision} + ) + """ + return value + + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: + prec = coltype.precision + if coltype.rounds: + timestamp = f"toDateTime64(round(toUnixTimestamp64Micro(toDateTime64({value}, 6)) / 1000000, {prec}), 6)" + return self.to_string(timestamp) + + fractional = f"toUnixTimestamp64Micro(toDateTime64({value}, {prec})) % 1000000" + fractional = f"lpad({self.to_string(fractional)}, 6, '0')" + value = f"formatDateTime({value}, '%Y-%m-%d %H:%M:%S') || '.' || {self.to_string(fractional)}" + return f"rpad({value}, {TIMESTAMP_PRECISION_POS + 6}, '0')" + + +class Dialect(BaseDialect, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): + name = "Clickhouse" + ROUNDS_ON_PREC_LOSS = False + TYPE_CLASSES = { + "Int8": Integer, + "Int16": Integer, + "Int32": Integer, + "Int64": Integer, + "Int128": Integer, + "Int256": Integer, + "UInt8": Integer, + "UInt16": Integer, + "UInt32": Integer, + "UInt64": Integer, + "UInt128": Integer, + "UInt256": Integer, + "Float32": Float, + "Float64": Float, + "Decimal": Decimal, + "UUID": Native_UUID, + "String": Text, + "FixedString": Text, + "DateTime": Timestamp, + "DateTime64": Timestamp, + "Bool": Boolean, + } + MIXINS = {Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample} + + def quote(self, s: str) -> str: + return f'"{s}"' + + def to_string(self, s: str) -> str: + return f"toString({s})" + + def _convert_db_precision_to_digits(self, p: int) -> int: + # Done the same as for PostgreSQL but need to rewrite in another way + # because it does not help for float with a big integer part. 
+ return super()._convert_db_precision_to_digits(p) - 2 + + def _parse_type_repr(self, type_repr: str) -> Optional[Type[ColType]]: + nullable_prefix = "Nullable(" + if type_repr.startswith(nullable_prefix): + type_repr = type_repr[len(nullable_prefix) :].rstrip(")") + + if type_repr.startswith("Decimal"): + type_repr = "Decimal" + elif type_repr.startswith("FixedString"): + type_repr = "FixedString" + elif type_repr.startswith("DateTime64"): + type_repr = "DateTime64" + + return self.TYPE_CLASSES.get(type_repr) + + # def timestamp_value(self, t: DbTime) -> str: + # # return f"'{t}'" + # return f"'{str(t)[:19]}'" + + def set_timezone_to_utc(self) -> str: + raise NotImplementedError() + + def current_timestamp(self) -> str: + return "now()" + + +class Clickhouse(ThreadedDatabase): + dialect = Dialect() + CONNECT_URI_HELP = "clickhouse://:@/" + CONNECT_URI_PARAMS = ["database?"] + + def __init__(self, *, thread_count: int, **kw): + super().__init__(thread_count=thread_count) + + self._args = kw + # In Clickhouse database and schema are the same + self.default_schema = kw.get("database", DEFAULT_DATABASE) + + def create_connection(self): + clickhouse = import_clickhouse() + + class SingleConnection(clickhouse.dbapi.connection.Connection): + """Not thread-safe connection to Clickhouse""" + + def cursor(self, cursor_factory=None): + if not len(self.cursors): + _ = super().cursor() + return self.cursors[0] + + try: + return SingleConnection(**self._args) + except clickhouse.OperationError as e: + raise ConnectError(*e.args) from e + + @property + def is_autocommit(self) -> bool: + return True diff --git a/data_diff/databases/databricks.py b/data_diff/databases/databricks.py index 36348ffb..1b8aa33a 100644 --- a/data_diff/databases/databricks.py +++ b/data_diff/databases/databricks.py @@ -1,2 +1,199 @@ -from data_diff.sqeleton.databases.databricks import Dialect as Dialect -from data_diff.sqeleton.databases.databricks import Databricks as Databricks +import math +from typing import Dict, Sequence +import logging + +from data_diff.abcs.database_types import ( + Integer, + Float, + Decimal, + Timestamp, + Text, + TemporalType, + NumericType, + DbPath, + ColType, + UnknownColType, + Boolean, +) +from data_diff.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue +from data_diff.databases.base import ( + MD5_HEXDIGITS, + CHECKSUM_HEXDIGITS, + BaseDialect, + ThreadedDatabase, + import_helper, + parse_table_name, + Mixin_RandomSample, +) + + +@import_helper(text="You can install it using 'pip install databricks-sql-connector'") +def import_databricks(): + import databricks.sql + + return databricks + + +class Mixin_MD5(AbstractMixin_MD5): + def md5_as_int(self, s: str) -> str: + return f"cast(conv(substr(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16, 10) as decimal(38, 0))" + + +class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: + """Databricks timestamp contains no more than 6 digits in precision""" + + if coltype.rounds: + timestamp = f"cast(round(unix_micros({value}) / 1000000, {coltype.precision}) * 1000000 as bigint)" + return f"date_format(timestamp_micros({timestamp}), 'yyyy-MM-dd HH:mm:ss.SSSSSS')" + + precision_format = "S" * coltype.precision + "0" * (6 - coltype.precision) + return f"date_format({value}, 'yyyy-MM-dd HH:mm:ss.{precision_format}')" + + def normalize_number(self, value: str, coltype: NumericType) -> str: + value = f"cast({value} as decimal(38, {coltype.precision}))" + if 
coltype.precision > 0:
+            value = f"format_number({value}, {coltype.precision})"
+        return f"replace({self.to_string(value)}, ',', '')"
+
+    def normalize_boolean(self, value: str, _coltype: Boolean) -> str:
+        return self.to_string(f"cast ({value} as int)")
+
+
+class Dialect(BaseDialect, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue):
+    name = "Databricks"
+    ROUNDS_ON_PREC_LOSS = True
+    TYPE_CLASSES = {
+        # Numbers
+        "INT": Integer,
+        "SMALLINT": Integer,
+        "TINYINT": Integer,
+        "BIGINT": Integer,
+        "FLOAT": Float,
+        "DOUBLE": Float,
+        "DECIMAL": Decimal,
+        # Timestamps
+        "TIMESTAMP": Timestamp,
+        # Text
+        "STRING": Text,
+        # Boolean
+        "BOOLEAN": Boolean,
+    }
+    MIXINS = {Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample}
+
+    def quote(self, s: str):
+        return f"`{s}`"
+
+    def to_string(self, s: str) -> str:
+        return f"cast({s} as string)"
+
+    def _convert_db_precision_to_digits(self, p: int) -> int:
+        # Subtracting 2 due to weird precision issues
+        return max(super()._convert_db_precision_to_digits(p) - 2, 0)
+
+    def set_timezone_to_utc(self) -> str:
+        return "SET TIME ZONE 'UTC'"
+
+
+class Databricks(ThreadedDatabase):
+    dialect = Dialect()
+    CONNECT_URI_HELP = "databricks://:<access_token>@<server_hostname>/<http_path>"
+    CONNECT_URI_PARAMS = ["catalog", "schema"]
+
+    def __init__(self, *, thread_count, **kw):
+        logging.getLogger("databricks.sql").setLevel(logging.WARNING)
+
+        self._args = kw
+        self.default_schema = kw.get("schema", "default")
+        self.catalog = self._args.get("catalog", "hive_metastore")
+        super().__init__(thread_count=thread_count)
+
+    def create_connection(self):
+        databricks = import_databricks()
+
+        try:
+            return databricks.sql.connect(
+                server_hostname=self._args["server_hostname"],
+                http_path=self._args["http_path"],
+                access_token=self._args["access_token"],
+                catalog=self.catalog,
+            )
+        except databricks.sql.exc.Error as e:
+            raise ConnectionError(*e.args) from e
+
+    def query_table_schema(self, path: DbPath) -> Dict[str, tuple]:
+        # Databricks has INFORMATION_SCHEMA only for Databricks Runtime, not for Databricks SQL.
+        # https://docs.databricks.com/spark/latest/spark-sql/language-manual/information-schema/columns.html
+        # So, to obtain information about schema, we should use another approach.
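+        # The "another approach" (see note above): fetch column metadata through the
+        # connector's cursor.columns() call. Assumed row shape: attributes COLUMN_NAME,
+        # TYPE_NAME and DECIMAL_DIGITS, repacked below into schema tuples.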
+ + conn = self.create_connection() + + catalog, schema, table = self._normalize_table_path(path) + with conn.cursor() as cursor: + cursor.columns(catalog_name=catalog, schema_name=schema, table_name=table) + try: + rows = cursor.fetchall() + finally: + conn.close() + if not rows: + raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns") + + d = {r.COLUMN_NAME: (r.COLUMN_NAME, r.TYPE_NAME, r.DECIMAL_DIGITS, None, None) for r in rows} + assert len(d) == len(rows) + return d + + def _process_table_schema( + self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str], where: str = None + ): + accept = {i.lower() for i in filter_columns} + rows = [row for name, row in raw_schema.items() if name.lower() in accept] + + resulted_rows = [] + for row in rows: + row_type = "DECIMAL" if row[1].startswith("DECIMAL") else row[1] + type_cls = self.dialect.TYPE_CLASSES.get(row_type, UnknownColType) + + if issubclass(type_cls, Integer): + row = (row[0], row_type, None, None, 0) + + elif issubclass(type_cls, Float): + numeric_precision = math.ceil(row[2] / math.log(2, 10)) + row = (row[0], row_type, None, numeric_precision, None) + + elif issubclass(type_cls, Decimal): + items = row[1][8:].rstrip(")").split(",") + numeric_precision, numeric_scale = int(items[0]), int(items[1]) + row = (row[0], row_type, None, numeric_precision, numeric_scale) + + elif issubclass(type_cls, Timestamp): + row = (row[0], row_type, row[2], None, None) + + else: + row = (row[0], row_type, None, None, None) + + resulted_rows.append(row) + + col_dict: Dict[str, ColType] = {row[0]: self.dialect.parse_type(path, *row) for row in resulted_rows} + + self._refine_coltypes(path, col_dict, where) + return col_dict + + def parse_table_name(self, name: str) -> DbPath: + path = parse_table_name(name) + return tuple(i for i in self._normalize_table_path(path) if i is not None) + + @property + def is_autocommit(self) -> bool: + return True + + def _normalize_table_path(self, path: DbPath) -> DbPath: + if len(path) == 1: + return self.catalog, self.default_schema, path[0] + elif len(path) == 2: + return self.catalog, path[0], path[1] + elif len(path) == 3: + return path + + raise ValueError( + f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. 
Expected format: table, schema.table, or catalog.schema.table" + ) diff --git a/data_diff/databases/duckdb.py b/data_diff/databases/duckdb.py index 26558634..f7fdaadd 100644 --- a/data_diff/databases/duckdb.py +++ b/data_diff/databases/duckdb.py @@ -1,2 +1,192 @@ -from data_diff.sqeleton.databases.duckdb import Dialect as Dialect -from data_diff.sqeleton.databases.duckdb import DuckDB as DuckDB +from typing import Union + +from data_diff.utils import match_regexps +from data_diff.abcs.database_types import ( + Timestamp, + TimestampTZ, + DbPath, + ColType, + Float, + Decimal, + Integer, + TemporalType, + Native_UUID, + Text, + FractionalType, + Boolean, + AbstractTable, +) +from data_diff.abcs.mixins import ( + AbstractMixin_MD5, + AbstractMixin_NormalizeValue, + AbstractMixin_RandomSample, + AbstractMixin_Regex, +) +from data_diff.databases.base import ( + Database, + BaseDialect, + import_helper, + ConnectError, + ThreadLocalInterpreter, + TIMESTAMP_PRECISION_POS, +) +from data_diff.databases.base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, Mixin_Schema +from data_diff.queries.ast_classes import Func, Compilable +from data_diff.queries.api import code + + +@import_helper("duckdb") +def import_duckdb(): + import duckdb + + return duckdb + + +class Mixin_MD5(AbstractMixin_MD5): + def md5_as_int(self, s: str) -> str: + return f"('0x' || SUBSTRING(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS},{CHECKSUM_HEXDIGITS}))::BIGINT" + + +class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: + # It's precision 6 by default. If precision is less than 6 -> we remove the trailing numbers. + if coltype.rounds and coltype.precision > 0: + return f"CONCAT(SUBSTRING(STRFTIME({value}::TIMESTAMP, '%Y-%m-%d %H:%M:%S.'),1,23), LPAD(((ROUND(strftime({value}::timestamp, '%f')::DECIMAL(15,7)/100000,{coltype.precision-1})*100000)::INT)::VARCHAR,6,'0'))" + + return f"rpad(substring(strftime({value}::timestamp, '%Y-%m-%d %H:%M:%S.%f'),1,{TIMESTAMP_PRECISION_POS+coltype.precision}),26,'0')" + + def normalize_number(self, value: str, coltype: FractionalType) -> str: + return self.to_string(f"{value}::DECIMAL(38, {coltype.precision})") + + def normalize_boolean(self, value: str, _coltype: Boolean) -> str: + return self.to_string(f"{value}::INTEGER") + + +class Mixin_RandomSample(AbstractMixin_RandomSample): + def random_sample_n(self, tbl: AbstractTable, size: int) -> AbstractTable: + return code("SELECT * FROM ({tbl}) USING SAMPLE {size};", tbl=tbl, size=size) + + def random_sample_ratio_approx(self, tbl: AbstractTable, ratio: float) -> AbstractTable: + return code("SELECT * FROM ({tbl}) USING SAMPLE {percent}%;", tbl=tbl, percent=int(100 * ratio)) + + +class Mixin_Regex(AbstractMixin_Regex): + def test_regex(self, string: Compilable, pattern: Compilable) -> Compilable: + return Func("regexp_matches", [string, pattern]) + + +class Dialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): + name = "DuckDB" + ROUNDS_ON_PREC_LOSS = False + SUPPORTS_PRIMARY_KEY = True + SUPPORTS_INDEXES = True + MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample} + + TYPE_CLASSES = { + # Timestamps + "TIMESTAMP WITH TIME ZONE": TimestampTZ, + "TIMESTAMP": Timestamp, + # Numbers + "DOUBLE": Float, + "FLOAT": Float, + "DECIMAL": Decimal, + "INTEGER": Integer, + "BIGINT": Integer, + # Text + "VARCHAR": Text, + "TEXT": Text, + # UUID + "UUID": Native_UUID, + # Bool + "BOOLEAN": 
Boolean,
+    }
+
+    def quote(self, s: str):
+        return f'"{s}"'
+
+    def to_string(self, s: str):
+        return f"{s}::VARCHAR"
+
+    def _convert_db_precision_to_digits(self, p: int) -> int:
+        # Subtracting 2 due to weird precision issues in PostgreSQL
+        return super()._convert_db_precision_to_digits(p) - 2
+
+    def parse_type(
+        self,
+        table_path: DbPath,
+        col_name: str,
+        type_repr: str,
+        datetime_precision: int = None,
+        numeric_precision: int = None,
+        numeric_scale: int = None,
+    ) -> ColType:
+        regexps = {
+            r"DECIMAL\((\d+),(\d+)\)": Decimal,
+        }
+
+        for m, t_cls in match_regexps(regexps, type_repr):
+            precision = int(m.group(2))
+            return t_cls(precision=precision)
+
+        return super().parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision, numeric_scale)
+
+    def set_timezone_to_utc(self) -> str:
+        return "SET GLOBAL TimeZone='UTC'"
+
+    def current_timestamp(self) -> str:
+        return "current_timestamp"
+
+
+class DuckDB(Database):
+    dialect = Dialect()
+    SUPPORTS_UNIQUE_CONSTAINT = False  # Temporary, until we implement it
+    default_schema = "main"
+    CONNECT_URI_HELP = "duckdb://<dbname>@<filepath>"
+    CONNECT_URI_PARAMS = ["database", "dbpath"]
+
+    def __init__(self, **kw):
+        self._args = kw
+        self._conn = self.create_connection()
+
+    @property
+    def is_autocommit(self) -> bool:
+        return True
+
+    def _query(self, sql_code: Union[str, ThreadLocalInterpreter]):
+        "Uses the standard SQL cursor interface"
+        return self._query_conn(self._conn, sql_code)
+
+    def close(self):
+        super().close()
+        self._conn.close()
+
+    def create_connection(self):
+        ddb = import_duckdb()
+        try:
+            return ddb.connect(self._args["filepath"])
+        except ddb.OperationalError as e:
+            raise ConnectError(*e.args) from e
+
+    def select_table_schema(self, path: DbPath) -> str:
+        database, schema, table = self._normalize_table_path(path)
+
+        info_schema_path = ["information_schema", "columns"]
+        if database:
+            info_schema_path.insert(0, database)
+
+        return (
+            f"SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale FROM {'.'.join(info_schema_path)} "
+            f"WHERE table_name = '{table}' AND table_schema = '{schema}'"
+        )
+
+    def _normalize_table_path(self, path: DbPath) -> DbPath:
+        if len(path) == 1:
+            return None, self.default_schema, path[0]
+        elif len(path) == 2:
+            return None, path[0], path[1]
+        elif len(path) == 3:
+            return path
+
+        raise ValueError(
+            f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. 
Expected format: table, schema.table, or database.schema.table" + ) diff --git a/data_diff/databases/mssql.py b/data_diff/databases/mssql.py index be1c2cae..28d67c99 100644 --- a/data_diff/databases/mssql.py +++ b/data_diff/databases/mssql.py @@ -1,2 +1,214 @@ -from data_diff.sqeleton.databases.mssql import Dialect as Dialect -from data_diff.sqeleton.databases.mssql import MsSQL as MsSql +from typing import Optional +from data_diff.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue +from data_diff.databases.base import ( + CHECKSUM_HEXDIGITS, + Mixin_OptimizerHints, + Mixin_RandomSample, + QueryError, + ThreadedDatabase, + import_helper, + ConnectError, + BaseDialect, +) +from data_diff.databases.base import Mixin_Schema +from data_diff.abcs.database_types import ( + JSON, + Timestamp, + TimestampTZ, + DbPath, + Float, + Decimal, + Integer, + TemporalType, + Native_UUID, + Text, + FractionalType, + Boolean, +) + + +@import_helper("mssql") +def import_mssql(): + import pyodbc + + return pyodbc + + +class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: + if coltype.precision > 0: + formatted_value = ( + f"FORMAT({value}, 'yyyy-MM-dd HH:mm:ss') + '.' + " + f"SUBSTRING(FORMAT({value}, 'fffffff'), 1, {coltype.precision})" + ) + else: + formatted_value = f"FORMAT({value}, 'yyyy-MM-dd HH:mm:ss')" + + return formatted_value + + def normalize_number(self, value: str, coltype: FractionalType) -> str: + if coltype.precision == 0: + return f"CAST(FLOOR({value}) AS VARCHAR)" + + return f"FORMAT({value}, 'N{coltype.precision}')" + + +class Mixin_MD5(AbstractMixin_MD5): + def md5_as_int(self, s: str) -> str: + return f"convert(bigint, convert(varbinary, '0x' + RIGHT(CONVERT(NVARCHAR(32), HashBytes('MD5', {s}), 2), {CHECKSUM_HEXDIGITS}), 1))" + + +class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): + name = "MsSQL" + ROUNDS_ON_PREC_LOSS = True + SUPPORTS_PRIMARY_KEY = True + SUPPORTS_INDEXES = True + TYPE_CLASSES = { + # Timestamps + "datetimeoffset": TimestampTZ, + "datetime": Timestamp, + "datetime2": Timestamp, + "smalldatetime": Timestamp, + "date": Timestamp, + # Numbers + "float": Float, + "real": Float, + "decimal": Decimal, + "money": Decimal, + "smallmoney": Decimal, + # int + "int": Integer, + "bigint": Integer, + "tinyint": Integer, + "smallint": Integer, + # Text + "varchar": Text, + "char": Text, + "text": Text, + "ntext": Text, + "nvarchar": Text, + "nchar": Text, + "binary": Text, + "varbinary": Text, + # UUID + "uniqueidentifier": Native_UUID, + # Bool + "bit": Boolean, + # JSON + "json": JSON, + } + + MIXINS = {Mixin_Schema, Mixin_NormalizeValue, Mixin_RandomSample} + + def quote(self, s: str): + return f"[{s}]" + + def set_timezone_to_utc(self) -> str: + raise NotImplementedError("MsSQL does not support a session timezone setting.") + + def current_timestamp(self) -> str: + return "GETDATE()" + + def current_database(self) -> str: + return "DB_NAME()" + + def current_schema(self) -> str: + return """default_schema_name + FROM sys.database_principals + WHERE name = CURRENT_USER""" + + def to_string(self, s: str): + return f"CONVERT(varchar, {s})" + + def type_repr(self, t) -> str: + try: + return {bool: "bit"}[t] + except KeyError: + return super().type_repr(t) + + def random(self) -> str: + return "rand()" + + def is_distinct_from(self, a: str, b: str) -> str: + # IS (NOT) DISTINCT FROM is available 
only since SQLServer 2022.
+        # See: https://stackoverflow.com/a/18684859/857383
+        return f"(({a}<>{b} OR {a} IS NULL OR {b} IS NULL) AND NOT({a} IS NULL AND {b} IS NULL))"
+
+    def offset_limit(
+        self, offset: Optional[int] = None, limit: Optional[int] = None, has_order_by: Optional[bool] = None
+    ) -> str:
+        if offset:
+            raise NotImplementedError("No support for OFFSET in query")
+
+        result = ""
+        if not has_order_by:
+            result += "ORDER BY 1"
+
+        result += f" OFFSET 0 ROWS FETCH NEXT {limit} ROWS ONLY"
+        return result
+
+    def constant_values(self, rows) -> str:
+        values = ", ".join("(%s)" % ", ".join(self._constant_value(v) for v in row) for row in rows)
+        return f"VALUES {values}"
+
+
+class MsSQL(ThreadedDatabase):
+    dialect = Dialect()
+    #
+    CONNECT_URI_HELP = "mssql://<user>:<password>@<host>/<database>/<schema>"
+    CONNECT_URI_PARAMS = ["database", "schema"]
+
+    def __init__(self, host, port, user, password, *, database, thread_count, **kw):
+        args = dict(server=host, port=port, database=database, user=user, password=password, **kw)
+        self._args = {k: v for k, v in args.items() if v is not None}
+        self._args["driver"] = "{ODBC Driver 18 for SQL Server}"
+
+        # TODO temp dev debug
+        self._args["TrustServerCertificate"] = "yes"
+
+        try:
+            self.default_database = self._args["database"]
+            self.default_schema = self._args["schema"]
+        except KeyError:
+            raise ValueError("Specify a default database and schema.")
+
+        super().__init__(thread_count=thread_count)
+
+    def create_connection(self):
+        self._mssql = import_mssql()
+        try:
+            connection = self._mssql.connect(**self._args)
+            return connection
+        except self._mssql.Error as error:
+            raise ConnectError(*error.args) from error
+
+    def select_table_schema(self, path: DbPath) -> str:
+        """Provide SQL for selecting the table schema as (name, type, date_prec, num_prec)"""
+        database, schema, name = self._normalize_table_path(path)
+        info_schema_path = ["information_schema", "columns"]
+        if database:
+            info_schema_path.insert(0, self.dialect.quote(database))
+
+        return (
+            "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale "
+            f"FROM {'.'.join(info_schema_path)} "
+            f"WHERE table_name = '{name}' AND table_schema = '{schema}'"
+        )
+
+    def _normalize_table_path(self, path: DbPath) -> DbPath:
+        if len(path) == 1:
+            return self.default_database, self.default_schema, path[0]
+        elif len(path) == 2:
+            return self.default_database, path[0], path[1]
+        elif len(path) == 3:
+            return path
+
+        raise ValueError(
+            f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. 
Expected format: table, schema.table, or database.schema.table" + ) + + def _query_cursor(self, c, sql_code: str): + try: + return super()._query_cursor(c, sql_code) + except self._mssql.DatabaseError as e: + raise QueryError(e) diff --git a/data_diff/databases/mysql.py b/data_diff/databases/mysql.py index 0a715600..910ff78d 100644 --- a/data_diff/databases/mysql.py +++ b/data_diff/databases/mysql.py @@ -1,2 +1,159 @@ -from data_diff.sqeleton.databases.mysql import Dialect as Dialect -from data_diff.sqeleton.databases.mysql import MySQL as MySQL +from data_diff.abcs.database_types import ( + Datetime, + Timestamp, + Float, + Decimal, + Integer, + Text, + TemporalType, + FractionalType, + ColType_UUID, + Boolean, + Date, +) +from data_diff.abcs.mixins import ( + AbstractMixin_MD5, + AbstractMixin_NormalizeValue, + AbstractMixin_Regex, +) +from data_diff.databases.base import ( + Mixin_OptimizerHints, + ThreadedDatabase, + import_helper, + ConnectError, + BaseDialect, + Compilable, +) +from data_diff.databases.base import ( + MD5_HEXDIGITS, + CHECKSUM_HEXDIGITS, + TIMESTAMP_PRECISION_POS, + Mixin_Schema, + Mixin_RandomSample, +) +from data_diff.queries.ast_classes import BinBoolOp + + +@import_helper("mysql") +def import_mysql(): + import mysql.connector + + return mysql.connector + + +class Mixin_MD5(AbstractMixin_MD5): + def md5_as_int(self, s: str) -> str: + return f"cast(conv(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16, 10) as unsigned)" + + +class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: + if coltype.rounds: + return self.to_string(f"cast( cast({value} as datetime({coltype.precision})) as datetime(6))") + + s = self.to_string(f"cast({value} as datetime(6))") + return f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS+coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS+6}, '0')" + + def normalize_number(self, value: str, coltype: FractionalType) -> str: + return self.to_string(f"cast({value} as decimal(38, {coltype.precision}))") + + def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: + return f"TRIM(CAST({value} AS char))" + + +class Mixin_Regex(AbstractMixin_Regex): + def test_regex(self, string: Compilable, pattern: Compilable) -> Compilable: + return BinBoolOp("REGEXP", [string, pattern]) + + +class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): + name = "MySQL" + ROUNDS_ON_PREC_LOSS = True + SUPPORTS_PRIMARY_KEY = True + SUPPORTS_INDEXES = True + TYPE_CLASSES = { + # Dates + "datetime": Datetime, + "timestamp": Timestamp, + "date": Date, + # Numbers + "double": Float, + "float": Float, + "decimal": Decimal, + "int": Integer, + "bigint": Integer, + "mediumint": Integer, + "smallint": Integer, + "tinyint": Integer, + # Text + "varchar": Text, + "char": Text, + "varbinary": Text, + "binary": Text, + "text": Text, + "mediumtext": Text, + "longtext": Text, + "tinytext": Text, + # Boolean + "boolean": Boolean, + } + MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample} + + def quote(self, s: str): + return f"`{s}`" + + def to_string(self, s: str): + return f"cast({s} as char)" + + def is_distinct_from(self, a: str, b: str) -> str: + return f"not ({a} <=> {b})" + + def random(self) -> str: + return "RAND()" + + def type_repr(self, t) -> str: + try: + return { + str: "VARCHAR(1024)", + }[t] + except KeyError: + return super().type_repr(t) + + def explain_as_text(self, 
query: str) -> str: + return f"EXPLAIN FORMAT=TREE {query}" + + def optimizer_hints(self, s: str): + return f"/*+ {s} */ " + + def set_timezone_to_utc(self) -> str: + return "SET @@session.time_zone='+00:00'" + + +class MySQL(ThreadedDatabase): + dialect = Dialect() + SUPPORTS_ALPHANUMS = False + SUPPORTS_UNIQUE_CONSTAINT = True + CONNECT_URI_HELP = "mysql://:@/" + CONNECT_URI_PARAMS = ["database?"] + + def __init__(self, *, thread_count, **kw): + self._args = kw + + super().__init__(thread_count=thread_count) + + # In MySQL schema and database are synonymous + try: + self.default_schema = kw["database"] + except KeyError: + raise ValueError("MySQL URL must specify a database") + + def create_connection(self): + mysql = import_mysql() + try: + return mysql.connect(charset="utf8", use_unicode=True, **self._args) + except mysql.Error as e: + if e.errno == mysql.errorcode.ER_ACCESS_DENIED_ERROR: + raise ConnectError("Bad user name or password") from e + elif e.errno == mysql.errorcode.ER_BAD_DB_ERROR: + raise ConnectError("Database does not exist") from e + raise ConnectError(*e.args) from e diff --git a/data_diff/databases/oracle.py b/data_diff/databases/oracle.py index 7c10fd11..f0309c11 100644 --- a/data_diff/databases/oracle.py +++ b/data_diff/databases/oracle.py @@ -1,2 +1,206 @@ -from data_diff.sqeleton.databases.oracle import Dialect as Dialect -from data_diff.sqeleton.databases.oracle import Oracle as Oracle +from typing import Dict, List, Optional + +from data_diff.utils import match_regexps +from data_diff.abcs.database_types import ( + Decimal, + Float, + Text, + DbPath, + TemporalType, + ColType, + DbTime, + ColType_UUID, + Timestamp, + TimestampTZ, + FractionalType, +) +from data_diff.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue, AbstractMixin_Schema +from data_diff.abcs.compiler import Compilable +from data_diff.queries.api import this, table, SKIP +from data_diff.databases.base import ( + BaseDialect, + Mixin_OptimizerHints, + ThreadedDatabase, + import_helper, + ConnectError, + QueryError, + Mixin_RandomSample, +) +from data_diff.databases.base import TIMESTAMP_PRECISION_POS + +SESSION_TIME_ZONE = None # Changed by the tests + + +@import_helper("oracle") +def import_oracle(): + import oracledb + + return oracledb + + +class Mixin_MD5(AbstractMixin_MD5): + def md5_as_int(self, s: str) -> str: + # standard_hash is faster than DBMS_CRYPTO.Hash + # TODO: Find a way to use UTL_RAW.CAST_TO_BINARY_INTEGER ? + return f"to_number(substr(standard_hash({s}, 'MD5'), 18), 'xxxxxxxxxxxxxxx')" + + +class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): + def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: + # Cast is necessary for correct MD5 (trimming not enough) + return f"CAST(TRIM({value}) AS VARCHAR(36))" + + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: + if coltype.rounds: + return f"to_char(cast({value} as timestamp({coltype.precision})), 'YYYY-MM-DD HH24:MI:SS.FF6')" + + if coltype.precision > 0: + truncated = f"to_char({value}, 'YYYY-MM-DD HH24:MI:SS.FF{coltype.precision}')" + else: + truncated = f"to_char({value}, 'YYYY-MM-DD HH24:MI:SS.')" + return f"RPAD({truncated}, {TIMESTAMP_PRECISION_POS+6}, '0')" + + def normalize_number(self, value: str, coltype: FractionalType) -> str: + # FM999.9990 + format_str = "FM" + "9" * (38 - coltype.precision) + if coltype.precision: + format_str += "0." 
+ "9" * (coltype.precision - 1) + "0" + return f"to_char({value}, '{format_str}')" + + +class Mixin_Schema(AbstractMixin_Schema): + def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: + return ( + table("ALL_TABLES") + .where( + this.OWNER == table_schema, + this.TABLE_NAME.like(like) if like is not None else SKIP, + ) + .select(table_name=this.TABLE_NAME) + ) + + +class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): + name = "Oracle" + SUPPORTS_PRIMARY_KEY = True + SUPPORTS_INDEXES = True + TYPE_CLASSES: Dict[str, type] = { + "NUMBER": Decimal, + "FLOAT": Float, + # Text + "CHAR": Text, + "NCHAR": Text, + "NVARCHAR2": Text, + "VARCHAR2": Text, + "DATE": Timestamp, + } + ROUNDS_ON_PREC_LOSS = True + PLACEHOLDER_TABLE = "DUAL" + MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample} + + def quote(self, s: str): + return f'"{s}"' + + def to_string(self, s: str): + return f"cast({s} as varchar(1024))" + + def offset_limit( + self, offset: Optional[int] = None, limit: Optional[int] = None, has_order_by: Optional[bool] = None + ) -> str: + if offset: + raise NotImplementedError("No support for OFFSET in query") + + return f"FETCH NEXT {limit} ROWS ONLY" + + def concat(self, items: List[str]) -> str: + joined_exprs = " || ".join(items) + return f"({joined_exprs})" + + def timestamp_value(self, t: DbTime) -> str: + return "timestamp '%s'" % t.isoformat(" ") + + def random(self) -> str: + return "dbms_random.value" + + def is_distinct_from(self, a: str, b: str) -> str: + return f"DECODE({a}, {b}, 1, 0) = 0" + + def type_repr(self, t) -> str: + try: + return { + str: "VARCHAR(1024)", + }[t] + except KeyError: + return super().type_repr(t) + + def constant_values(self, rows) -> str: + return " UNION ALL ".join( + "SELECT %s FROM DUAL" % ", ".join(self._constant_value(v) for v in row) for row in rows + ) + + def explain_as_text(self, query: str) -> str: + raise NotImplementedError("Explain not yet implemented in Oracle") + + def parse_type( + self, + table_path: DbPath, + col_name: str, + type_repr: str, + datetime_precision: int = None, + numeric_precision: int = None, + numeric_scale: int = None, + ) -> ColType: + regexps = { + r"TIMESTAMP\((\d)\) WITH LOCAL TIME ZONE": Timestamp, + r"TIMESTAMP\((\d)\) WITH TIME ZONE": TimestampTZ, + r"TIMESTAMP\((\d)\)": Timestamp, + } + + for m, t_cls in match_regexps(regexps, type_repr): + precision = int(m.group(1)) + return t_cls(precision=precision, rounds=self.ROUNDS_ON_PREC_LOSS) + + return super().parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision, numeric_scale) + + def set_timezone_to_utc(self) -> str: + return "ALTER SESSION SET TIME_ZONE = 'UTC'" + + def current_timestamp(self) -> str: + return "LOCALTIMESTAMP" + + +class Oracle(ThreadedDatabase): + dialect = Dialect() + CONNECT_URI_HELP = "oracle://:@/" + CONNECT_URI_PARAMS = ["database?"] + + def __init__(self, *, host, database, thread_count, **kw): + self.kwargs = dict(dsn=f"{host}/{database}" if database else host, **kw) + + self.default_schema = kw.get("user").upper() + + super().__init__(thread_count=thread_count) + + def create_connection(self): + self._oracle = import_oracle() + try: + c = self._oracle.connect(**self.kwargs) + if SESSION_TIME_ZONE: + c.cursor().execute(f"ALTER SESSION SET TIME_ZONE = '{SESSION_TIME_ZONE}'") + return c + except Exception as e: + raise ConnectError(*e.args) from e + + def 
_query_cursor(self, c, sql_code: str): + try: + return super()._query_cursor(c, sql_code) + except self._oracle.DatabaseError as e: + raise QueryError(e) + + def select_table_schema(self, path: DbPath) -> str: + schema, name = self._normalize_table_path(path) + + return ( + f"SELECT column_name, data_type, 6 as datetime_precision, data_precision as numeric_precision, data_scale as numeric_scale" + f" FROM ALL_TAB_COLUMNS WHERE table_name = '{name}' AND owner = '{schema}'" + ) diff --git a/data_diff/databases/postgresql.py b/data_diff/databases/postgresql.py index befe8d44..dec9b9d3 100644 --- a/data_diff/databases/postgresql.py +++ b/data_diff/databases/postgresql.py @@ -1,2 +1,182 @@ -from data_diff.sqeleton.databases.postgresql import PostgresqlDialect as PostgresqlDialect -from data_diff.sqeleton.databases.postgresql import PostgreSQL as PostgreSQL +from typing import List +from data_diff.abcs.database_types import ( + DbPath, + JSON, + Timestamp, + TimestampTZ, + Float, + Decimal, + Integer, + TemporalType, + Native_UUID, + Text, + FractionalType, + Boolean, + Date, +) +from data_diff.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue +from data_diff.databases.base import BaseDialect, ThreadedDatabase, import_helper, ConnectError, Mixin_Schema +from data_diff.databases.base import ( + MD5_HEXDIGITS, + CHECKSUM_HEXDIGITS, + _CHECKSUM_BITSIZE, + TIMESTAMP_PRECISION_POS, + Mixin_RandomSample, +) + +SESSION_TIME_ZONE = None # Changed by the tests + + +@import_helper("postgresql") +def import_postgresql(): + import psycopg2.extras + + psycopg2.extensions.set_wait_callback(psycopg2.extras.wait_select) + return psycopg2 + + +class Mixin_MD5(AbstractMixin_MD5): + def md5_as_int(self, s: str) -> str: + return f"('x' || substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}))::bit({_CHECKSUM_BITSIZE})::bigint" + + +class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: + if coltype.rounds: + return f"to_char({value}::timestamp({coltype.precision}), 'YYYY-mm-dd HH24:MI:SS.US')" + + timestamp6 = f"to_char({value}::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')" + return ( + f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" + ) + + def normalize_number(self, value: str, coltype: FractionalType) -> str: + return self.to_string(f"{value}::decimal(38, {coltype.precision})") + + def normalize_boolean(self, value: str, _coltype: Boolean) -> str: + return self.to_string(f"{value}::int") + + def normalize_json(self, value: str, _coltype: JSON) -> str: + return f"{value}::text" + + +class PostgresqlDialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): + name = "PostgreSQL" + ROUNDS_ON_PREC_LOSS = True + SUPPORTS_PRIMARY_KEY = True + SUPPORTS_INDEXES = True + MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample} + + TYPE_CLASSES = { + # Timestamps + "timestamp with time zone": TimestampTZ, + "timestamp without time zone": Timestamp, + "timestamp": Timestamp, + "date": Date, + # Numbers + "double precision": Float, + "real": Float, + "decimal": Decimal, + "smallint": Integer, + "integer": Integer, + "numeric": Decimal, + "bigint": Integer, + # Text + "character": Text, + "character varying": Text, + "varchar": Text, + "text": Text, + "json": JSON, + "jsonb": JSON, + "uuid": Native_UUID, + "boolean": Boolean, + } + + def quote(self, s: str): + return f'"{s}"' + + def 
to_string(self, s: str):
+        return f"{s}::varchar"
+
+    def concat(self, items: List[str]) -> str:
+        joined_exprs = " || ".join(items)
+        return f"({joined_exprs})"
+
+    def _convert_db_precision_to_digits(self, p: int) -> int:
+        # Subtracting 2 due to weird precision issues in PostgreSQL
+        return super()._convert_db_precision_to_digits(p) - 2
+
+    def set_timezone_to_utc(self) -> str:
+        return "SET TIME ZONE 'UTC'"
+
+    def current_timestamp(self) -> str:
+        return "current_timestamp"
+
+    def type_repr(self, t) -> str:
+        if isinstance(t, TimestampTZ):
+            return f"timestamp ({t.precision}) with time zone"
+        return super().type_repr(t)
+
+
+class PostgreSQL(ThreadedDatabase):
+    dialect = PostgresqlDialect()
+    SUPPORTS_UNIQUE_CONSTAINT = True
+    CONNECT_URI_HELP = "postgresql://<user>:<password>@<host>/<database>"
+    CONNECT_URI_PARAMS = ["database?"]
+
+    default_schema = "public"
+
+    def __init__(self, *, thread_count, **kw):
+        self._args = kw
+
+        super().__init__(thread_count=thread_count)
+
+    def create_connection(self):
+        if not self._args:
+            self._args["host"] = None  # psycopg2 requires 1+ arguments
+
+        pg = import_postgresql()
+        try:
+            c = pg.connect(**self._args)
+            if SESSION_TIME_ZONE:
+                c.cursor().execute(f"SET TIME ZONE '{SESSION_TIME_ZONE}'")
+            return c
+        except pg.OperationalError as e:
+            raise ConnectError(*e.args) from e
+
+    def select_table_schema(self, path: DbPath) -> str:
+        database, schema, table = self._normalize_table_path(path)
+
+        info_schema_path = ["information_schema", "columns"]
+        if database:
+            info_schema_path.insert(0, database)
+
+        return (
+            f"SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale FROM {'.'.join(info_schema_path)} "
+            f"WHERE table_name = '{table}' AND table_schema = '{schema}'"
+        )
+
+    def select_table_unique_columns(self, path: DbPath) -> str:
+        database, schema, table = self._normalize_table_path(path)
+
+        info_schema_path = ["information_schema", "key_column_usage"]
+        if database:
+            info_schema_path.insert(0, database)
+
+        return (
+            "SELECT column_name "
+            f"FROM {'.'.join(info_schema_path)} "
+            f"WHERE table_name = '{table}' AND table_schema = '{schema}'"
+        )
+
+    def _normalize_table_path(self, path: DbPath) -> DbPath:
+        if len(path) == 1:
+            return None, self.default_schema, path[0]
+        elif len(path) == 2:
+            return None, path[0], path[1]
+        elif len(path) == 3:
+            return path
+
+        raise ValueError(
+            f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. 
+        )
diff --git a/data_diff/databases/presto.py b/data_diff/databases/presto.py
index db7c4749..b4c45751 100644
--- a/data_diff/databases/presto.py
+++ b/data_diff/databases/presto.py
@@ -1,2 +1,200 @@
-from data_diff.sqeleton.databases.presto import Dialect as Dialect
-from data_diff.sqeleton.databases.presto import Presto as Presto
+from functools import partial
+import re
+
+from data_diff.utils import match_regexps
+
+from data_diff.abcs.database_types import (
+    Timestamp,
+    TimestampTZ,
+    Integer,
+    Float,
+    Text,
+    FractionalType,
+    DbPath,
+    DbTime,
+    Decimal,
+    ColType,
+    ColType_UUID,
+    TemporalType,
+    Boolean,
+)
+from data_diff.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue
+from data_diff.databases.base import (
+    BaseDialect,
+    Database,
+    import_helper,
+    ThreadLocalInterpreter,
+    Mixin_Schema,
+    Mixin_RandomSample,
+)
+from data_diff.databases.base import (
+    MD5_HEXDIGITS,
+    CHECKSUM_HEXDIGITS,
+    TIMESTAMP_PRECISION_POS,
+)
+
+
+def query_cursor(c, sql_code):
+    c.execute(sql_code)
+    if sql_code.lower().startswith("select"):
+        return c.fetchall()
+    # Required for the query to actually run 🤯
+    if re.match(r"(insert|create|truncate|drop|explain)", sql_code, re.IGNORECASE):
+        return c.fetchone()
+
+
+@import_helper("presto")
+def import_presto():
+    import prestodb
+
+    return prestodb
+
+
+class Mixin_MD5(AbstractMixin_MD5):
+    def md5_as_int(self, s: str) -> str:
+        return f"cast(from_base(substr(to_hex(md5(to_utf8({s}))), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16) as decimal(38, 0))"
+
+
+class Mixin_NormalizeValue(AbstractMixin_NormalizeValue):
+    def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
+        # Trim doesn't work on CHAR type
+        return f"TRIM(CAST({value} AS VARCHAR))"
+
+    def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
+        # TODO: rounding (coltype.rounds) is not handled yet; the rounding and
+        # non-rounding cases currently produce the exact same expression.
+        s = f"date_format(cast({value} as timestamp(6)), '%Y-%m-%d %H:%i:%S.%f')"
+
+        return f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS+coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS+6}, '0')"
+
+    def normalize_number(self, value: str, coltype: FractionalType) -> str:
+        return self.to_string(f"cast({value} as decimal(38,{coltype.precision}))")
+
+    def normalize_boolean(self, value: str, _coltype: Boolean) -> str:
+        return self.to_string(f"cast ({value} as int)")
+
+
+class Dialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue):
+    name = "Presto"
+    ROUNDS_ON_PREC_LOSS = True
+    TYPE_CLASSES = {
+        # Timestamps
+        "timestamp with time zone": TimestampTZ,
+        "timestamp without time zone": Timestamp,
+        "timestamp": Timestamp,
+        # Numbers
+        "integer": Integer,
+        "bigint": Integer,
+        "real": Float,
+        "double": Float,
+        # Text
+        "varchar": Text,
+        # Boolean
+        "boolean": Boolean,
+    }
+    MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample}
+
+    def explain_as_text(self, query: str) -> str:
+        return f"EXPLAIN (FORMAT TEXT) {query}"
+
+    def type_repr(self, t) -> str:
+        if isinstance(t, TimestampTZ):
+            return f"timestamp with time zone"
+
+        try:
+            return {float: "REAL"}[t]
+        except KeyError:
+            return super().type_repr(t)
+
+    def timestamp_value(self, t: DbTime) -> str:
+        return f"timestamp '{t.isoformat(' ')}'"
+
+    def quote(self, s: str):
+        return f'"{s}"'
+
+    def to_string(self, s: str):
+        return f"cast({s} as varchar)"
+
+    def parse_type(
+        self,
+        table_path: DbPath,
+        col_name: str,
+        type_repr: str,
+        datetime_precision: int = None,
+        numeric_precision: int = None,
+        _numeric_scale: int = None,
+    ) -> ColType:
+        timestamp_regexps = {
+            r"timestamp\((\d)\)": Timestamp,
+            r"timestamp\((\d)\) with time zone": TimestampTZ,
+        }
+        for m, t_cls in match_regexps(timestamp_regexps, type_repr):
+            precision = int(m.group(1))
+            return t_cls(precision=precision, rounds=self.ROUNDS_ON_PREC_LOSS)
+
+        number_regexps = {r"decimal\((\d+),(\d+)\)": Decimal}
+        for m, n_cls in match_regexps(number_regexps, type_repr):
+            _prec, scale = map(int, m.groups())
+            return n_cls(scale)
+
+        string_regexps = {r"varchar\((\d+)\)": Text, r"char\((\d+)\)": Text}
+        for m, n_cls in match_regexps(string_regexps, type_repr):
+            return n_cls()
+
+        return super().parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision)
+
+    def set_timezone_to_utc(self) -> str:
+        return "SET TIME ZONE '+00:00'"
+
+    def current_timestamp(self) -> str:
+        return "current_timestamp"
+
+
+class Presto(Database):
+    dialect = Dialect()
+    CONNECT_URI_HELP = "presto://<user>@<host>/<catalog>/<schema>"
+    CONNECT_URI_PARAMS = ["catalog", "schema"]
+
+    default_schema = "public"
+
+    def __init__(self, **kw):
+        prestodb = import_presto()
+
+        if kw.get("schema"):
+            self.default_schema = kw.get("schema")
+
+        if kw.get("auth") == "basic":  # if auth=basic, add basic authenticator for Presto
+            kw["auth"] = prestodb.auth.BasicAuthentication(kw.pop("user"), kw.pop("password"))
+
+        if "cert" in kw:  # if a certificate was specified in URI, verify session with cert
+            cert = kw.pop("cert")
+            self._conn = prestodb.dbapi.connect(**kw)
+            self._conn._http_session.verify = cert
+        else:
+            self._conn = prestodb.dbapi.connect(**kw)
+
+    def _query(self, sql_code: str) -> list:
+        "Uses the standard SQL cursor interface"
+        c = self._conn.cursor()
+
+        if isinstance(sql_code, ThreadLocalInterpreter):
+            return sql_code.apply_queries(partial(query_cursor, c))
+
+        return query_cursor(c, sql_code)
+
+    def close(self):
+        super().close()
+        self._conn.close()
+
+    def select_table_schema(self, path: DbPath) -> str:
+        schema, table = self._normalize_table_path(path)
+
+        return (
+            "SELECT column_name, data_type, 3 as datetime_precision, 3 as numeric_precision, NULL as numeric_scale "
+            "FROM INFORMATION_SCHEMA.COLUMNS "
+            f"WHERE table_name = '{table}' AND table_schema = '{schema}'"
+        )
+
+    @property
+    def is_autocommit(self) -> bool:
+        return False
diff --git a/data_diff/databases/redshift.py b/data_diff/databases/redshift.py
index 54e9ecc1..d11029c0 100644
--- a/data_diff/databases/redshift.py
+++ b/data_diff/databases/redshift.py
@@ -1,2 +1,176 @@
-from data_diff.sqeleton.databases.redshift import Dialect as Dialect
-from data_diff.sqeleton.databases.redshift import Redshift as Redshift
+from typing import List, Dict
+from data_diff.abcs.database_types import (
+    Float,
+    JSON,
+    TemporalType,
+    FractionalType,
+    DbPath,
+    TimestampTZ,
+)
+from data_diff.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue
+from data_diff.databases.postgresql import (
+    PostgreSQL,
+    MD5_HEXDIGITS,
+    CHECKSUM_HEXDIGITS,
+    TIMESTAMP_PRECISION_POS,
+    PostgresqlDialect,
+    Mixin_NormalizeValue,
+)
+
+
+class Mixin_MD5(AbstractMixin_MD5):
+    def md5_as_int(self, s: str) -> str:
+        return f"strtol(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16)::decimal(38)"
+
+
+class Mixin_NormalizeValue(Mixin_NormalizeValue):
+    def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
+        if coltype.rounds:
+            timestamp = f"{value}::timestamp(6)"
f"{value}::timestamp(6)" + # Get seconds since epoch. Redshift doesn't support milli- or micro-seconds. + secs = f"timestamp 'epoch' + round(extract(epoch from {timestamp})::decimal(38)" + # Get the milliseconds from timestamp. + ms = f"extract(ms from {timestamp})" + # Get the microseconds from timestamp, without the milliseconds! + us = f"extract(us from {timestamp})" + # epoch = Total time since epoch in microseconds. + epoch = f"{secs}*1000000 + {ms}*1000 + {us}" + timestamp6 = ( + f"to_char({epoch}, -6+{coltype.precision}) * interval '0.000001 seconds', 'YYYY-mm-dd HH24:MI:SS.US')" + ) + else: + timestamp6 = f"to_char({value}::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')" + return ( + f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" + ) + + def normalize_number(self, value: str, coltype: FractionalType) -> str: + return self.to_string(f"{value}::decimal(38,{coltype.precision})") + + def normalize_json(self, value: str, _coltype: JSON) -> str: + return f"nvl2({value}, json_serialize({value}), NULL)" + + +class Dialect(PostgresqlDialect, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): + name = "Redshift" + TYPE_CLASSES = { + **PostgresqlDialect.TYPE_CLASSES, + "double": Float, + "real": Float, + "super": JSON, + } + SUPPORTS_INDEXES = False + + def concat(self, items: List[str]) -> str: + joined_exprs = " || ".join(items) + return f"({joined_exprs})" + + def is_distinct_from(self, a: str, b: str) -> str: + return f"({a} IS NULL != {b} IS NULL) OR ({a}!={b})" + + def type_repr(self, t) -> str: + if isinstance(t, TimestampTZ): + return f"timestamptz" + return super().type_repr(t) + + +class Redshift(PostgreSQL): + dialect = Dialect() + CONNECT_URI_HELP = "redshift://:@/" + CONNECT_URI_PARAMS = ["database?"] + + def select_table_schema(self, path: DbPath) -> str: + database, schema, table = self._normalize_table_path(path) + + info_schema_path = ["information_schema", "columns"] + if database: + info_schema_path.insert(0, database) + + return ( + f"SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale FROM {'.'.join(info_schema_path)} " + f"WHERE table_name = '{table.lower()}' AND table_schema = '{schema.lower()}'" + ) + + def select_external_table_schema(self, path: DbPath) -> str: + database, schema, table = self._normalize_table_path(path) + + db_clause = "" + if database: + db_clause = f" AND redshift_database_name = '{database.lower()}'" + + return ( + f"""SELECT + columnname AS column_name + , CASE WHEN external_type = 'string' THEN 'varchar' ELSE external_type END AS data_type + , NULL AS datetime_precision + , NULL AS numeric_precision + , NULL AS numeric_scale + FROM svv_external_columns + WHERE tablename = '{table.lower()}' AND schemaname = '{schema.lower()}' + """ + + db_clause + ) + + def query_external_table_schema(self, path: DbPath) -> Dict[str, tuple]: + rows = self.query(self.select_external_table_schema(path), list) + if not rows: + raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns") + + d = {r[0]: r for r in rows} + assert len(d) == len(rows) + return d + + def select_view_columns(self, path: DbPath) -> str: + _, schema, table = self._normalize_table_path(path) + + return """select * from pg_get_cols('{}.{}') + cols(view_schema name, view_name name, col_name name, col_type varchar, col_num int) + """.format( + schema, table + ) + + def query_pg_get_cols(self, path: DbPath) -> Dict[str, tuple]: + rows = 
self.query(self.select_view_columns(path), list) + + if not rows: + raise RuntimeError(f"{self.name}: View '{'.'.join(path)}' does not exist, or has no columns") + + output = {} + for r in rows: + col_name = r[2] + type_info = r[3].split("(") + base_type = type_info[0] + precision = None + scale = None + + if len(type_info) > 1: + if base_type == "numeric": + precision, scale = type_info[1][:-1].split(",") + precision = int(precision) + scale = int(scale) + + out = [col_name, base_type, None, precision, scale] + output[col_name] = tuple(out) + + return output + + def query_table_schema(self, path: DbPath) -> Dict[str, tuple]: + try: + return super().query_table_schema(path) + except RuntimeError: + try: + return self.query_external_table_schema(path) + except RuntimeError: + return self.query_pg_get_cols(path) + + def _normalize_table_path(self, path: DbPath) -> DbPath: + if len(path) == 1: + return None, self.default_schema, path[0] + elif len(path) == 2: + return None, path[0], path[1] + elif len(path) == 3: + return path + + raise ValueError( + f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected format: table, schema.table, or database.schema.table" + ) diff --git a/data_diff/databases/snowflake.py b/data_diff/databases/snowflake.py index 2029a73d..3a558425 100644 --- a/data_diff/databases/snowflake.py +++ b/data_diff/databases/snowflake.py @@ -1,2 +1,228 @@ -from data_diff.sqeleton.databases.snowflake import Dialect as Dialect -from data_diff.sqeleton.databases.snowflake import Snowflake as Snowflake +from typing import Union, List +import logging + +from data_diff.abcs.database_types import ( + Timestamp, + TimestampTZ, + Decimal, + Float, + Text, + FractionalType, + TemporalType, + DbPath, + Boolean, + Date, +) +from data_diff.abcs.mixins import ( + AbstractMixin_MD5, + AbstractMixin_NormalizeValue, + AbstractMixin_Schema, + AbstractMixin_TimeTravel, +) +from data_diff.abcs.compiler import Compilable +from data_diff.queries.api import table, this, SKIP, code +from data_diff.databases.base import ( + BaseDialect, + ConnectError, + Database, + import_helper, + CHECKSUM_MASK, + ThreadLocalInterpreter, + Mixin_RandomSample, +) + + +@import_helper("snowflake") +def import_snowflake(): + import snowflake.connector + from cryptography.hazmat.primitives import serialization + from cryptography.hazmat.backends import default_backend + + return snowflake, serialization, default_backend + + +class Mixin_MD5(AbstractMixin_MD5): + def md5_as_int(self, s: str) -> str: + return f"BITAND(md5_number_lower64({s}), {CHECKSUM_MASK})" + + +class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: + if coltype.rounds: + timestamp = f"to_timestamp(round(date_part(epoch_nanosecond, convert_timezone('UTC', {value})::timestamp(9))/1000000000, {coltype.precision}))" + else: + timestamp = f"cast(convert_timezone('UTC', {value}) as timestamp({coltype.precision}))" + + return f"to_char({timestamp}, 'YYYY-MM-DD HH24:MI:SS.FF6')" + + def normalize_number(self, value: str, coltype: FractionalType) -> str: + return self.to_string(f"cast({value} as decimal(38, {coltype.precision}))") + + def normalize_boolean(self, value: str, _coltype: Boolean) -> str: + return self.to_string(f"{value}::int") + + +class Mixin_Schema(AbstractMixin_Schema): + def table_information(self) -> Compilable: + return table("INFORMATION_SCHEMA", "TABLES") + + def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: + return ( + 
self.table_information()
+            .where(
+                this.TABLE_SCHEMA == table_schema,
+                this.TABLE_NAME.like(like) if like is not None else SKIP,
+                this.TABLE_TYPE == "BASE TABLE",
+            )
+            .select(table_name=this.TABLE_NAME)
+        )
+
+
+class Mixin_TimeTravel(AbstractMixin_TimeTravel):
+    def time_travel(
+        self,
+        table: Compilable,
+        before: bool = False,
+        timestamp: Compilable = None,
+        offset: Compilable = None,
+        statement: Compilable = None,
+    ) -> Compilable:
+        at_or_before = "BEFORE" if before else "AT"
+        if timestamp is not None:
+            assert offset is None and statement is None
+            key = "timestamp"
+            value = timestamp
+        elif offset is not None:
+            assert statement is None
+            key = "offset"
+            value = offset
+        else:
+            assert statement is not None
+            key = "statement"
+            value = statement
+
+        return code(f"{{table}} {at_or_before}({key} => {{value}})", table=table, value=value)
+
+
+class Dialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue):
+    name = "Snowflake"
+    ROUNDS_ON_PREC_LOSS = False
+    TYPE_CLASSES = {
+        # Timestamps
+        "TIMESTAMP_NTZ": Timestamp,
+        "TIMESTAMP_LTZ": Timestamp,
+        "TIMESTAMP_TZ": TimestampTZ,
+        "DATE": Date,
+        # Numbers
+        "NUMBER": Decimal,
+        "FLOAT": Float,
+        # Text
+        "TEXT": Text,
+        # Boolean
+        "BOOLEAN": Boolean,
+    }
+    MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_TimeTravel, Mixin_RandomSample}
+
+    def explain_as_text(self, query: str) -> str:
+        return f"EXPLAIN USING TEXT {query}"
+
+    def quote(self, s: str):
+        return f'"{s}"'
+
+    def to_string(self, s: str):
+        return f"cast({s} as string)"
+
+    def table_information(self) -> Compilable:
+        return table("INFORMATION_SCHEMA", "TABLES")
+
+    def set_timezone_to_utc(self) -> str:
+        return "ALTER SESSION SET TIMEZONE = 'UTC'"
+
+    def optimizer_hints(self, hints: str) -> str:
+        raise NotImplementedError("Optimizer hints not yet implemented in snowflake")
+
+    def type_repr(self, t) -> str:
+        if isinstance(t, TimestampTZ):
+            return f"timestamp_tz({t.precision})"
+        return super().type_repr(t)
+
+
+class Snowflake(Database):
+    dialect = Dialect()
+    CONNECT_URI_HELP = "snowflake://<user>:<password>@<account>/<database>/<SCHEMA>?warehouse=<WAREHOUSE>"
+    CONNECT_URI_PARAMS = ["database", "schema"]
+    CONNECT_URI_KWPARAMS = ["warehouse"]
+
+    def __init__(self, *, schema: str, **kw):
+        snowflake, serialization, default_backend = import_snowflake()
+        logging.getLogger("snowflake.connector").setLevel(logging.WARNING)
+
+        # Ignore the error: snowflake.connector.network.RetryRequest: could not find io module state
+        # It's a known issue: https://github.com/snowflakedb/snowflake-connector-python/issues/145
+        logging.getLogger("snowflake.connector.network").disabled = True
+
+        assert '"' not in schema, "Schema name should not contain quotes!"
+        # If a private key is used, read it from the specified path and pass it as "private_key" to the connector.
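+        # Sketch of the expected inputs (illustrative assumptions, not checked here):
+        #   kw["key"]                    - path to a PEM-encoded private key file
+        #   kw["private_key_passphrase"] - optional passphrase for an encrypted PEM
+        # The cryptography library decrypts the PEM and re-serializes it as
+        # unencrypted DER (PKCS#8), the form the Snowflake connector accepts
+        # as its "private_key" argument.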
+        if "key" in kw:
+            with open(kw.pop("key"), "rb") as key:
+                if "password" in kw:
+                    raise ConnectError("Cannot use password and key at the same time")
+                if kw.get("private_key_passphrase"):
+                    encoded_passphrase = kw.pop("private_key_passphrase").encode()
+                else:
+                    encoded_passphrase = None
+                p_key = serialization.load_pem_private_key(
+                    key.read(),
+                    password=encoded_passphrase,
+                    backend=default_backend(),
+                )
+
+            kw["private_key"] = p_key.private_bytes(
+                encoding=serialization.Encoding.DER,
+                format=serialization.PrivateFormat.PKCS8,
+                encryption_algorithm=serialization.NoEncryption(),
+            )
+
+        self._conn = snowflake.connector.connect(schema=f'"{schema}"', **kw)
+
+        self.default_schema = schema
+
+    def close(self):
+        super().close()
+        self._conn.close()
+
+    def _query(self, sql_code: Union[str, ThreadLocalInterpreter]):
+        "Uses the standard SQL cursor interface"
+        return self._query_conn(self._conn, sql_code)
+
+    def select_table_schema(self, path: DbPath) -> str:
+        """Provide SQL for selecting the table schema as (name, type, date_prec, num_prec)"""
+        database, schema, name = self._normalize_table_path(path)
+        info_schema_path = ["information_schema", "columns"]
+        if database:
+            info_schema_path.insert(0, database)
+
+        return (
+            "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale "
+            f"FROM {'.'.join(info_schema_path)} "
+            f"WHERE table_name = '{name}' AND table_schema = '{schema}'"
+        )
+
+    def _normalize_table_path(self, path: DbPath) -> DbPath:
+        if len(path) == 1:
+            return None, self.default_schema, path[0]
+        elif len(path) == 2:
+            return None, path[0], path[1]
+        elif len(path) == 3:
+            return path
+
+        raise ValueError(
+            f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected format: table, schema.table, or database.schema.table"
+        )
+
+    @property
+    def is_autocommit(self) -> bool:
+        return True
+
+    def query_table_unique_columns(self, path: DbPath) -> List[str]:
+        return []
diff --git a/data_diff/databases/trino.py b/data_diff/databases/trino.py
index e60bfb90..e2095758 100644
--- a/data_diff/databases/trino.py
+++ b/data_diff/databases/trino.py
@@ -1,2 +1,48 @@
-from data_diff.sqeleton.databases.trino import Dialect as Dialect
-from data_diff.sqeleton.databases.trino import Trino as Trino
+from data_diff.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue
+from data_diff.abcs.database_types import TemporalType, ColType_UUID
+from data_diff.databases import presto
+from data_diff.databases.base import import_helper
+from data_diff.databases.base import TIMESTAMP_PRECISION_POS
+
+
+@import_helper("trino")
+def import_trino():
+    import trino
+
+    return trino
+
+
+Mixin_MD5 = presto.Mixin_MD5
+
+
+class Mixin_NormalizeValue(presto.Mixin_NormalizeValue):
+    def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
+        if coltype.rounds:
+            s = f"date_format(cast({value} as timestamp({coltype.precision})), '%Y-%m-%d %H:%i:%S.%f')"
+        else:
+            s = f"date_format(cast({value} as timestamp(6)), '%Y-%m-%d %H:%i:%S.%f')"
+
+        return (
+            f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS + coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS + 6}, '0')"
+        )
+
+    def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
+        return f"TRIM({value})"
+
+
+class Dialect(presto.Dialect, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue):
+    name = "Trino"
+
+
+class Trino(presto.Presto):
+    dialect = Dialect()
+    CONNECT_URI_HELP = "trino://<user>@<host>/<catalog>/<schema>"
+    CONNECT_URI_PARAMS = ["catalog", "schema"]
+
+    def 
__init__(self, **kw): + trino = import_trino() + + if kw.get("schema"): + self.default_schema = kw.get("schema") + + self._conn = trino.dbapi.connect(**kw) diff --git a/data_diff/databases/vertica.py b/data_diff/databases/vertica.py index 83675939..e8fe9ec2 100644 --- a/data_diff/databases/vertica.py +++ b/data_diff/databases/vertica.py @@ -1,2 +1,181 @@ -from data_diff.sqeleton.databases.vertica import Dialect as Dialect -from data_diff.sqeleton.databases.vertica import Vertica as Vertica +from typing import List + +from data_diff.utils import match_regexps +from data_diff.databases.base import ( + CHECKSUM_HEXDIGITS, + MD5_HEXDIGITS, + TIMESTAMP_PRECISION_POS, + BaseDialect, + ConnectError, + DbPath, + ColType, + ThreadedDatabase, + import_helper, + Mixin_RandomSample, +) +from data_diff.abcs.database_types import ( + Decimal, + Float, + FractionalType, + Integer, + TemporalType, + Text, + Timestamp, + TimestampTZ, + Boolean, + ColType_UUID, +) +from data_diff.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue, AbstractMixin_Schema +from data_diff.abcs.compiler import Compilable +from data_diff.queries.api import table, this, SKIP + + +@import_helper("vertica") +def import_vertica(): + import vertica_python + + return vertica_python + + +class Mixin_MD5(AbstractMixin_MD5): + def md5_as_int(self, s: str) -> str: + return f"CAST(HEX_TO_INTEGER(SUBSTRING(MD5({s}), {1 + MD5_HEXDIGITS - CHECKSUM_HEXDIGITS})) AS NUMERIC(38, 0))" + + +class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: + if coltype.rounds: + return f"TO_CHAR({value}::TIMESTAMP({coltype.precision}), 'YYYY-MM-DD HH24:MI:SS.US')" + + timestamp6 = f"TO_CHAR({value}::TIMESTAMP(6), 'YYYY-MM-DD HH24:MI:SS.US')" + return ( + f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" + ) + + def normalize_number(self, value: str, coltype: FractionalType) -> str: + return self.to_string(f"CAST({value} AS NUMERIC(38, {coltype.precision}))") + + def normalize_uuid(self, value: str, _coltype: ColType_UUID) -> str: + # Trim doesn't work on CHAR type + return f"TRIM(CAST({value} AS VARCHAR))" + + def normalize_boolean(self, value: str, _coltype: Boolean) -> str: + return self.to_string(f"cast ({value} as int)") + + +class Mixin_Schema(AbstractMixin_Schema): + def table_information(self) -> Compilable: + return table("v_catalog", "tables") + + def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: + return ( + self.table_information() + .where( + this.table_schema == table_schema, + this.table_name.like(like) if like is not None else SKIP, + ) + .select(this.table_name) + ) + + +class Dialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): + name = "Vertica" + ROUNDS_ON_PREC_LOSS = True + + TYPE_CLASSES = { + # Timestamps + "timestamp": Timestamp, + "timestamptz": TimestampTZ, + # Numbers + "numeric": Decimal, + "int": Integer, + "float": Float, + # Text + "char": Text, + "varchar": Text, + # Boolean + "boolean": Boolean, + } + MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample} + + def quote(self, s: str): + return f'"{s}"' + + def concat(self, items: List[str]) -> str: + return " || ".join(items) + + def to_string(self, s: str) -> str: + return f"CAST({s} AS VARCHAR)" + + def is_distinct_from(self, a: str, b: str) -> str: + return f"not ({a} <=> {b})" + + def parse_type( + self, + 
table_path: DbPath,
+        col_name: str,
+        type_repr: str,
+        datetime_precision: int = None,
+        numeric_precision: int = None,
+        numeric_scale: int = None,
+    ) -> ColType:
+        timestamp_regexps = {
+            r"timestamp\(?(\d?)\)?": Timestamp,
+            r"timestamptz\(?(\d?)\)?": TimestampTZ,
+        }
+        for m, t_cls in match_regexps(timestamp_regexps, type_repr):
+            precision = int(m.group(1)) if m.group(1) else 6
+            return t_cls(precision=precision, rounds=self.ROUNDS_ON_PREC_LOSS)
+
+        number_regexps = {
+            r"numeric\((\d+),(\d+)\)": Decimal,
+        }
+        for m, n_cls in match_regexps(number_regexps, type_repr):
+            _prec, scale = map(int, m.groups())
+            return n_cls(scale)
+
+        string_regexps = {
+            r"varchar\((\d+)\)": Text,
+            r"char\((\d+)\)": Text,
+        }
+        for m, n_cls in match_regexps(string_regexps, type_repr):
+            return n_cls()
+
+        return super().parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision)
+
+    def set_timezone_to_utc(self) -> str:
+        return "SET TIME ZONE TO 'UTC'"
+
+    def current_timestamp(self) -> str:
+        return "current_timestamp(6)"
+
+
+class Vertica(ThreadedDatabase):
+    dialect = Dialect()
+    CONNECT_URI_HELP = "vertica://<user>:<password>@<host>/<database>"
+    CONNECT_URI_PARAMS = ["database?"]
+
+    default_schema = "public"
+
+    def __init__(self, *, thread_count, **kw):
+        self._args = kw
+        self._args["AUTOCOMMIT"] = False
+
+        super().__init__(thread_count=thread_count)
+
+    def create_connection(self):
+        vertica = import_vertica()
+        try:
+            c = vertica.connect(**self._args)
+            return c
+        except vertica.errors.ConnectionError as e:
+            raise ConnectError(*e.args) from e
+
+    def select_table_schema(self, path: DbPath) -> str:
+        schema, name = self._normalize_table_path(path)
+
+        return (
+            "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale "
+            "FROM V_CATALOG.COLUMNS "
+            f"WHERE table_name = '{name}' AND table_schema = '{schema}'"
+        )
diff --git a/data_diff/diff_tables.py b/data_diff/diff_tables.py
index 519018f6..08c18391 100644
--- a/data_diff/diff_tables.py
+++ b/data_diff/diff_tables.py
@@ -1,7 +1,6 @@
 """Provides classes for performing a table diff
 """
 
-import re
 import time
 from abc import ABC, abstractmethod
 from dataclasses import field
@@ -14,12 +13,11 @@
 from runtype import dataclass
 
 from data_diff.info_tree import InfoTree, SegmentInfo
-
 from data_diff.utils import dbt_diff_string_template, run_as_daemon, safezip, getLogger, truncate_error, Vector
 from data_diff.thread_utils import ThreadedYielder
 from data_diff.table_segment import TableSegment, create_mesh_from_points
 from data_diff.tracking import create_end_event_json, create_start_event_json, send_event_json, is_tracking_enabled
-from data_diff.sqeleton.abcs import IKey
+from data_diff.abcs.database_types import IKey
 
 logger = getLogger(__name__)
diff --git a/data_diff/format.py b/data_diff/format.py
index bfeb0b1e..8a515e1b 100644
--- a/data_diff/format.py
+++ b/data_diff/format.py
@@ -4,7 +4,7 @@
 from runtype import dataclass
 
 from data_diff.diff_tables import DiffResultWrapper
-from data_diff.sqeleton.abcs.database_types import (
+from data_diff.abcs.database_types import (
     JSON,
     Boolean,
     ColType,
diff --git a/data_diff/hashdiff_tables.py b/data_diff/hashdiff_tables.py
index 65072ed4..3fc030ec 100644
--- a/data_diff/hashdiff_tables.py
+++ b/data_diff/hashdiff_tables.py
@@ -8,13 +8,11 @@
 
 from runtype import dataclass
 
-from data_diff.sqeleton.abcs import ColType_UUID, NumericType, PrecisionType, StringType, Boolean, JSON
-
+from data_diff.abcs.database_types import ColType_UUID, NumericType, PrecisionType, StringType, 
Boolean, JSON from data_diff.info_tree import InfoTree from data_diff.utils import safezip, diffs_are_equiv_jsons from data_diff.thread_utils import ThreadedYielder from data_diff.table_segment import TableSegment - from data_diff.diff_tables import TableDiffer BENCHMARK = os.environ.get("BENCHMARK", False) diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index 26ba1e0e..667786a7 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -10,14 +10,11 @@ from runtype import dataclass -from data_diff.sqeleton.databases import Database, MsSQL, MySQL, BigQuery, Presto, Oracle, Snowflake, DbPath -from data_diff.sqeleton.abcs import NumericType -from data_diff.sqeleton.queries import ( +from data_diff.databases import Database, MsSQL, MySQL, BigQuery, Presto, Oracle, Snowflake +from data_diff.abcs.database_types import NumericType, DbPath +from data_diff.queries.api import ( table, sum_, - min_, - max_, - avg, and_, if_, or_, @@ -28,11 +25,9 @@ when, Compiler, ) -from data_diff.sqeleton.queries.ast_classes import Concat, Count, Expr, Func, Random, TablePath, Code, ITable -from data_diff.sqeleton.queries.extras import NormalizeAsString - +from data_diff.queries.ast_classes import Concat, Count, Expr, Random, TablePath, Code, ITable +from data_diff.queries.extras import NormalizeAsString from data_diff.info_tree import InfoTree - from data_diff.query_utils import append_to_table, drop_table from data_diff.utils import safezip from data_diff.table_segment import TableSegment diff --git a/data_diff/queries/__init__.py b/data_diff/queries/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/data_diff/sqeleton/queries/api.py b/data_diff/queries/api.py similarity index 97% rename from data_diff/sqeleton/queries/api.py rename to data_diff/queries/api.py index 97a6b00c..82786871 100644 --- a/data_diff/sqeleton/queries/api.py +++ b/data_diff/queries/api.py @@ -1,8 +1,6 @@ -from typing import Optional - from data_diff.utils import CaseAwareMapping, CaseSensitiveDict -from data_diff.sqeleton.queries.ast_classes import * -from data_diff.sqeleton.queries.base import args_as_tuple +from data_diff.queries.ast_classes import * +from data_diff.queries.base import args_as_tuple this = This() diff --git a/data_diff/sqeleton/queries/ast_classes.py b/data_diff/queries/ast_classes.py similarity index 98% rename from data_diff/sqeleton/queries/ast_classes.py rename to data_diff/queries/ast_classes.py index 6ee25ecb..70cb355f 100644 --- a/data_diff/sqeleton/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -1,18 +1,19 @@ from dataclasses import field from datetime import datetime -from typing import Any, Generator, List, Optional, Sequence, Type, Union, Dict +from typing import Any, Generator, List, Optional, Sequence, Union, Dict from runtype import dataclass from typing_extensions import Self from data_diff.utils import join_iter, ArithString -from data_diff.sqeleton.abcs import Compilable -from data_diff.sqeleton.abcs.database_types import AbstractTable -from data_diff.sqeleton.abcs.mixins import AbstractMixin_Regex, AbstractMixin_TimeTravel -from data_diff.sqeleton.schema import Schema - -from data_diff.sqeleton.queries.compiler import Compiler, cv_params, Root, CompileError -from data_diff.sqeleton.queries.base import SKIP, DbPath, args_as_tuple, SqeletonError +from data_diff.abcs.compiler import Compilable +from data_diff.abcs.database_types import AbstractTable +from data_diff.abcs.mixins import AbstractMixin_Regex, 
AbstractMixin_TimeTravel +from data_diff.schema import Schema + +from data_diff.queries.compiler import Compiler, cv_params, Root, CompileError +from data_diff.queries.base import SKIP, args_as_tuple, SqeletonError +from data_diff.abcs.database_types import DbPath class QueryBuilderError(SqeletonError): diff --git a/data_diff/sqeleton/queries/base.py b/data_diff/queries/base.py similarity index 76% rename from data_diff/sqeleton/queries/base.py rename to data_diff/queries/base.py index d229e175..205c2211 100644 --- a/data_diff/sqeleton/queries/base.py +++ b/data_diff/queries/base.py @@ -1,8 +1,5 @@ from typing import Generator -from data_diff.sqeleton.abcs import DbPath, DbKey -from data_diff.sqeleton.schema import Schema - class _SKIP: def __repr__(self): diff --git a/data_diff/sqeleton/queries/compiler.py b/data_diff/queries/compiler.py similarity index 92% rename from data_diff/sqeleton/queries/compiler.py rename to data_diff/queries/compiler.py index 0aaf8dd6..e6246236 100644 --- a/data_diff/sqeleton/queries/compiler.py +++ b/data_diff/queries/compiler.py @@ -7,7 +7,8 @@ from typing_extensions import Self from data_diff.utils import ArithString -from data_diff.sqeleton.abcs import AbstractDatabase, AbstractDialect, DbPath, AbstractCompiler, Compilable +from data_diff.abcs.database_types import AbstractDatabase, AbstractDialect, DbPath +from data_diff.abcs.compiler import AbstractCompiler, Compilable import contextvars @@ -44,7 +45,7 @@ def compile(self, elem, params=None) -> str: cv_params.set(params) if self.root and isinstance(elem, Compilable) and not isinstance(elem, Root): - from data_diff.sqeleton.queries.ast_classes import Select + from data_diff.queries.ast_classes import Select elem = Select(columns=[elem]) diff --git a/data_diff/sqeleton/queries/extras.py b/data_diff/queries/extras.py similarity index 89% rename from data_diff/sqeleton/queries/extras.py rename to data_diff/queries/extras.py index 4a1d58c1..8e916601 100644 --- a/data_diff/sqeleton/queries/extras.py +++ b/data_diff/queries/extras.py @@ -3,10 +3,10 @@ from typing import Callable, Sequence from runtype import dataclass -from data_diff.sqeleton.abcs.database_types import ColType, Native_UUID +from data_diff.abcs.database_types import ColType, Native_UUID -from data_diff.sqeleton.queries.compiler import Compiler -from data_diff.sqeleton.queries.ast_classes import Expr, ExprNode, Concat, Code +from data_diff.queries.compiler import Compiler +from data_diff.queries.ast_classes import Expr, ExprNode, Concat, Code @dataclass diff --git a/data_diff/query_utils.py b/data_diff/query_utils.py index 4b963039..a4887728 100644 --- a/data_diff/query_utils.py +++ b/data_diff/query_utils.py @@ -2,8 +2,10 @@ from contextlib import suppress -from data_diff.sqeleton.databases import DbPath, QueryError, Oracle -from data_diff.sqeleton.queries import table, commit, Expr +from data_diff.abcs.database_types import DbPath +from data_diff.databases.base import QueryError +from data_diff.databases.oracle import Oracle +from data_diff.queries.api import table, commit, Expr def _drop_table_oracle(name: DbPath): diff --git a/data_diff/sqeleton/schema.py b/data_diff/schema.py similarity index 91% rename from data_diff/sqeleton/schema.py rename to data_diff/schema.py index 35ebe8ef..847bbf23 100644 --- a/data_diff/sqeleton/schema.py +++ b/data_diff/schema.py @@ -1,7 +1,7 @@ import logging from data_diff.utils import CaseAwareMapping, CaseInsensitiveDict, CaseSensitiveDict -from data_diff.sqeleton.abcs import AbstractDatabase, DbPath +from 
data_diff.abcs.database_types import AbstractDatabase, DbPath logger = logging.getLogger("schema") diff --git a/data_diff/sqeleton/__init__.py b/data_diff/sqeleton/__init__.py deleted file mode 100644 index b6e32cc2..00000000 --- a/data_diff/sqeleton/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from data_diff.sqeleton.databases import connect -from data_diff.sqeleton.queries import table, this, SKIP, code diff --git a/data_diff/sqeleton/abcs/__init__.py b/data_diff/sqeleton/abcs/__init__.py deleted file mode 100644 index 5654ad16..00000000 --- a/data_diff/sqeleton/abcs/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -from data_diff.sqeleton.abcs.database_types import ( - AbstractDatabase, - AbstractDialect, - DbKey, - DbPath, - DbTime, - IKey, - ColType_UUID, - NumericType, - PrecisionType, - StringType, - Boolean, - JSON, -) -from data_diff.sqeleton.abcs.compiler import AbstractCompiler, Compilable diff --git a/data_diff/sqeleton/databases/__init__.py b/data_diff/sqeleton/databases/__init__.py deleted file mode 100644 index 70af2412..00000000 --- a/data_diff/sqeleton/databases/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -from data_diff.sqeleton.databases.base import ( - MD5_HEXDIGITS, - CHECKSUM_HEXDIGITS, - QueryError, - ConnectError, - BaseDialect, - Database, -) -from data_diff.sqeleton.abcs import DbPath, DbKey, DbTime -from data_diff.sqeleton.databases._connect import Connect - -from data_diff.sqeleton.databases.postgresql import PostgreSQL -from data_diff.sqeleton.databases.mysql import MySQL -from data_diff.sqeleton.databases.oracle import Oracle -from data_diff.sqeleton.databases.snowflake import Snowflake -from data_diff.sqeleton.databases.bigquery import BigQuery -from data_diff.sqeleton.databases.redshift import Redshift -from data_diff.sqeleton.databases.presto import Presto -from data_diff.sqeleton.databases.databricks import Databricks -from data_diff.sqeleton.databases.trino import Trino -from data_diff.sqeleton.databases.clickhouse import Clickhouse -from data_diff.sqeleton.databases.vertica import Vertica -from data_diff.sqeleton.databases.duckdb import DuckDB -from data_diff.sqeleton.databases.mssql import MsSQL - -connect = Connect() diff --git a/data_diff/sqeleton/databases/_connect.py b/data_diff/sqeleton/databases/_connect.py deleted file mode 100644 index ad152dda..00000000 --- a/data_diff/sqeleton/databases/_connect.py +++ /dev/null @@ -1,283 +0,0 @@ -from typing import Hashable, MutableMapping, Type, Optional, Union, Dict -from itertools import zip_longest -from contextlib import suppress -import weakref -import dsnparse -import toml - -from runtype import dataclass -from typing_extensions import Self - -from data_diff.sqeleton.abcs.mixins import AbstractMixin -from data_diff.sqeleton.databases.base import Database, ThreadedDatabase -from data_diff.sqeleton.databases.postgresql import PostgreSQL -from data_diff.sqeleton.databases.mysql import MySQL -from data_diff.sqeleton.databases.oracle import Oracle -from data_diff.sqeleton.databases.snowflake import Snowflake -from data_diff.sqeleton.databases.bigquery import BigQuery -from data_diff.sqeleton.databases.redshift import Redshift -from data_diff.sqeleton.databases.presto import Presto -from data_diff.sqeleton.databases.databricks import Databricks -from data_diff.sqeleton.databases.trino import Trino -from data_diff.sqeleton.databases.clickhouse import Clickhouse -from data_diff.sqeleton.databases.vertica import Vertica -from data_diff.sqeleton.databases.duckdb import DuckDB -from data_diff.sqeleton.databases.mssql 
import MsSQL - - -@dataclass -class MatchUriPath: - database_cls: Type[Database] - - def match_path(self, dsn): - help_str = self.database_cls.CONNECT_URI_HELP - params = self.database_cls.CONNECT_URI_PARAMS - kwparams = self.database_cls.CONNECT_URI_KWPARAMS - - dsn_dict = dict(dsn.query) - matches = {} - for param, arg in zip_longest(params, dsn.paths): - if param is None: - raise ValueError(f"Too many parts to path. Expected format: {help_str}") - - optional = param.endswith("?") - param = param.rstrip("?") - - if arg is None: - try: - arg = dsn_dict.pop(param) - except KeyError: - if not optional: - raise ValueError(f"URI must specify '{param}'. Expected format: {help_str}") - - arg = None - - assert param and param not in matches - matches[param] = arg - - for param in kwparams: - try: - arg = dsn_dict.pop(param) - except KeyError: - raise ValueError(f"URI must specify '{param}'. Expected format: {help_str}") - - assert param and arg and param not in matches, (param, arg, matches.keys()) - matches[param] = arg - - for param, value in dsn_dict.items(): - if param in matches: - raise ValueError( - f"Parameter '{param}' already provided as positional argument. Expected format: {help_str}" - ) - - matches[param] = value - - return matches - - -DATABASE_BY_SCHEME = { - "postgresql": PostgreSQL, - "mysql": MySQL, - "oracle": Oracle, - "redshift": Redshift, - "snowflake": Snowflake, - "presto": Presto, - "bigquery": BigQuery, - "databricks": Databricks, - "duckdb": DuckDB, - "trino": Trino, - "clickhouse": Clickhouse, - "vertica": Vertica, - "mssql": MsSQL, -} - - -class Connect: - """Provides methods for connecting to a supported database using a URL or connection dict.""" - conn_cache: MutableMapping[Hashable, Database] - - def __init__(self, database_by_scheme: Dict[str, Database] = DATABASE_BY_SCHEME): - self.database_by_scheme = database_by_scheme - self.match_uri_path = {name: MatchUriPath(cls) for name, cls in database_by_scheme.items()} - self.conn_cache = weakref.WeakValueDictionary() - - def for_databases(self, *dbs) -> Self: - database_by_scheme = {k: db for k, db in self.database_by_scheme.items() if k in dbs} - return type(self)(database_by_scheme) - - def connect_to_uri(self, db_uri: str, thread_count: Optional[int] = 1, **kwargs) -> Database: - """Connect to the given database uri - - thread_count determines the max number of worker threads per database, - if relevant. None means no limit. - - Parameters: - db_uri (str): The URI for the database to connect - thread_count (int, optional): Size of the threadpool. Ignored by cloud databases. (default: 1) - - Note: For non-cloud databases, a low thread-pool size may be a performance bottleneck. - - Supported schemes: - - postgresql - - mysql - - oracle - - snowflake - - bigquery - - redshift - - presto - - databricks - - trino - - clickhouse - - vertica - - duckdb - """ - - dsn = dsnparse.parse(db_uri) - if len(dsn.schemes) > 1: - raise NotImplementedError("No support for multiple schemes") - (scheme,) = dsn.schemes - - if scheme == "toml": - toml_path = dsn.path or dsn.host - database = dsn.fragment - if not database: - raise ValueError("Must specify a database name, e.g. 'toml://path#database'. 
") - with open(toml_path) as f: - config = toml.load(f) - try: - conn_dict = config["database"][database] - except KeyError: - raise ValueError(f"Cannot find database config named '{database}'.") - return self.connect_with_dict(conn_dict, thread_count, **kwargs) - - try: - matcher = self.match_uri_path[scheme] - except KeyError: - raise NotImplementedError(f"Scheme '{scheme}' currently not supported") - - cls = matcher.database_cls - - if scheme == "databricks": - assert not dsn.user - kw = {} - kw["access_token"] = dsn.password - kw["http_path"] = dsn.path - kw["server_hostname"] = dsn.host - kw.update(dsn.query) - elif scheme == "duckdb": - kw = {} - kw["filepath"] = dsn.dbname - kw["dbname"] = dsn.user - else: - kw = matcher.match_path(dsn) - - if scheme == "bigquery": - kw["project"] = dsn.host - return cls(**kw, **kwargs) - - if scheme == "snowflake": - kw["account"] = dsn.host - assert not dsn.port - kw["user"] = dsn.user - kw["password"] = dsn.password - else: - if scheme == "oracle": - kw["host"] = dsn.hostloc - else: - kw["host"] = dsn.host - kw["port"] = dsn.port - kw["user"] = dsn.user - if dsn.password: - kw["password"] = dsn.password - - kw = {k: v for k, v in kw.items() if v is not None} - - if issubclass(cls, ThreadedDatabase): - db = cls(thread_count=thread_count, **kw, **kwargs) - else: - db = cls(**kw, **kwargs) - - return self._connection_created(db) - - def connect_with_dict(self, d, thread_count, **kwargs): - d = dict(d) - driver = d.pop("driver") - try: - matcher = self.match_uri_path[driver] - except KeyError: - raise NotImplementedError(f"Driver '{driver}' currently not supported") - - cls = matcher.database_cls - if issubclass(cls, ThreadedDatabase): - db = cls(thread_count=thread_count, **d, **kwargs) - else: - db = cls(**d, **kwargs) - - return self._connection_created(db) - - def _connection_created(self, db): - "Nop function to be overridden by subclasses." - return db - - def __call__( - self, db_conf: Union[str, dict], thread_count: Optional[int] = 1, shared: bool = True, **kwargs - ) -> Database: - """Connect to a database using the given database configuration. - - Configuration can be given either as a URI string, or as a dict of {option: value}. - - The dictionary configuration uses the same keys as the TOML 'database' definition given with --conf. - - thread_count determines the max number of worker threads per database, - if relevant. None means no limit. - - Parameters: - db_conf (str | dict): The configuration for the database to connect. URI or dict. - thread_count (int, optional): Size of the threadpool. Ignored by cloud databases. (default: 1) - shared (bool): Whether to cache and return the same connection for the same db_conf. (default: True) - bigquery_credentials (google.oauth2.credentials.Credentials): Custom Google oAuth2 credential for BigQuery. - (default: None) - - Note: For non-cloud databases, a low thread-pool size may be a performance bottleneck. 
- - Supported drivers: - - postgresql - - mysql - - oracle - - snowflake - - bigquery - - redshift - - presto - - databricks - - trino - - clickhouse - - vertica - - Example: - >>> connect("mysql://localhost/db") - - >>> connect({"driver": "mysql", "host": "localhost", "database": "db"}) - - """ - cache_key = self.__make_cache_key(db_conf) - if shared: - with suppress(KeyError): - conn = self.conn_cache[cache_key] - if not conn.is_closed: - return conn - - if isinstance(db_conf, str): - conn = self.connect_to_uri(db_conf, thread_count, **kwargs) - elif isinstance(db_conf, dict): - conn = self.connect_with_dict(db_conf, thread_count, **kwargs) - else: - raise TypeError(f"db configuration must be a URI string or a dictionary. Instead got '{db_conf}'.") - - if shared: - self.conn_cache[cache_key] = conn - return conn - - def __make_cache_key(self, db_conf: Union[str, dict]) -> Hashable: - if isinstance(db_conf, dict): - return tuple(db_conf.items()) - return db_conf diff --git a/data_diff/sqeleton/databases/bigquery.py b/data_diff/sqeleton/databases/bigquery.py deleted file mode 100644 index bdf9e07d..00000000 --- a/data_diff/sqeleton/databases/bigquery.py +++ /dev/null @@ -1,297 +0,0 @@ -import re -from typing import Any, List, Union -from data_diff.sqeleton.abcs.database_types import ( - ColType, - Array, - JSON, - Struct, - Timestamp, - Datetime, - Integer, - Decimal, - Float, - Text, - DbPath, - FractionalType, - TemporalType, - Boolean, - UnknownColType, -) -from data_diff.sqeleton.abcs.mixins import ( - AbstractMixin_MD5, - AbstractMixin_NormalizeValue, - AbstractMixin_Schema, - AbstractMixin_TimeTravel, -) -from data_diff.sqeleton.abcs import Compilable -from data_diff.sqeleton.queries import this, table, SKIP, code -from data_diff.sqeleton.databases.base import ( - BaseDialect, - Database, - import_helper, - parse_table_name, - ConnectError, - apply_query, - QueryResult, -) -from data_diff.sqeleton.databases.base import TIMESTAMP_PRECISION_POS, ThreadLocalInterpreter, Mixin_RandomSample - - -@import_helper(text="Please install BigQuery and configure your google-cloud access.") -def import_bigquery(): - from google.cloud import bigquery - - return bigquery - - -def import_bigquery_service_account(): - from google.oauth2 import service_account - - return service_account - - -class Mixin_MD5(AbstractMixin_MD5): - def md5_as_int(self, s: str) -> str: - return f"cast(cast( ('0x' || substr(TO_HEX(md5({s})), 18)) as int64) as numeric)" - - -class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): - def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: - if coltype.rounds: - timestamp = f"timestamp_micros(cast(round(unix_micros(cast({value} as timestamp))/1000000, {coltype.precision})*1000000 as int))" - return f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {timestamp})" - - if coltype.precision == 0: - return f"FORMAT_TIMESTAMP('%F %H:%M:%S.000000', {value})" - elif coltype.precision == 6: - return f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {value})" - - timestamp6 = f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {value})" - return ( - f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" - ) - - def normalize_number(self, value: str, coltype: FractionalType) -> str: - return f"format('%.{coltype.precision}f', {value})" - - def normalize_boolean(self, value: str, _coltype: Boolean) -> str: - return self.to_string(f"cast({value} as int)") - - def normalize_json(self, value: str, _coltype: JSON) -> str: - # BigQuery is unable to compare arrays & 
structs with ==/!=/distinct from, e.g.: - # Got error: 400 Grouping is not defined for arguments of type ARRAY at … - # So we do the best effort and compare it as strings, hoping that the JSON forms - # match on both sides: i.e. have properly ordered keys, same spacing, same quotes, etc. - return f"to_json_string({value})" - - def normalize_array(self, value: str, _coltype: Array) -> str: - # BigQuery is unable to compare arrays & structs with ==/!=/distinct from, e.g.: - # Got error: 400 Grouping is not defined for arguments of type ARRAY at … - # So we do the best effort and compare it as strings, hoping that the JSON forms - # match on both sides: i.e. have properly ordered keys, same spacing, same quotes, etc. - return f"to_json_string({value})" - - def normalize_struct(self, value: str, _coltype: Struct) -> str: - # BigQuery is unable to compare arrays & structs with ==/!=/distinct from, e.g.: - # Got error: 400 Grouping is not defined for arguments of type ARRAY at … - # So we do the best effort and compare it as strings, hoping that the JSON forms - # match on both sides: i.e. have properly ordered keys, same spacing, same quotes, etc. - return f"to_json_string({value})" - - -class Mixin_Schema(AbstractMixin_Schema): - def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: - return ( - table(table_schema, "INFORMATION_SCHEMA", "TABLES") - .where( - this.table_schema == table_schema, - this.table_name.like(like) if like is not None else SKIP, - this.table_type == "BASE TABLE", - ) - .select(this.table_name) - ) - - -class Mixin_TimeTravel(AbstractMixin_TimeTravel): - def time_travel( - self, - table: Compilable, - before: bool = False, - timestamp: Compilable = None, - offset: Compilable = None, - statement: Compilable = None, - ) -> Compilable: - if before: - raise NotImplementedError("before=True not supported for BigQuery time-travel") - - if statement is not None: - raise NotImplementedError("BigQuery time-travel doesn't support querying by statement id") - - if timestamp is not None: - assert offset is None - return code("{table} FOR SYSTEM_TIME AS OF {timestamp}", table=table, timestamp=timestamp) - - assert offset is not None - return code( - "{table} FOR SYSTEM_TIME AS OF TIMESTAMP_SUB(CURRENT_TIMESTAMP(), INTERVAL {offset} HOUR);", - table=table, - offset=offset, - ) - - -class Dialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): - name = "BigQuery" - ROUNDS_ON_PREC_LOSS = False # Technically BigQuery doesn't allow implicit rounding or truncation - TYPE_CLASSES = { - # Dates - "TIMESTAMP": Timestamp, - "DATETIME": Datetime, - # Numbers - "INT64": Integer, - "INT32": Integer, - "NUMERIC": Decimal, - "BIGNUMERIC": Decimal, - "FLOAT64": Float, - "FLOAT32": Float, - "STRING": Text, - "BOOL": Boolean, - "JSON": JSON, - } - TYPE_ARRAY_RE = re.compile(r"ARRAY<(.+)>") - TYPE_STRUCT_RE = re.compile(r"STRUCT<(.+)>") - MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_TimeTravel, Mixin_RandomSample} - - def random(self) -> str: - return "RAND()" - - def quote(self, s: str): - return f"`{s}`" - - def to_string(self, s: str): - return f"cast({s} as string)" - - def type_repr(self, t) -> str: - try: - return {str: "STRING", float: "FLOAT64"}[t] - except KeyError: - return super().type_repr(t) - - def parse_type( - self, - table_path: DbPath, - col_name: str, - type_repr: str, - *args: Any, # pass-through args - **kwargs: Any, # pass-through args - ) -> ColType: - col_type = 
super().parse_type(table_path, col_name, type_repr, *args, **kwargs) - if isinstance(col_type, UnknownColType): - m = self.TYPE_ARRAY_RE.fullmatch(type_repr) - if m: - item_type = self.parse_type(table_path, col_name, m.group(1), *args, **kwargs) - col_type = Array(item_type=item_type) - - # We currently ignore structs' structure, but later can parse it too. Examples: - # - STRUCT (unnamed) - # - STRUCT (named) - # - STRUCT> (with complex fields) - # - STRUCT> (nested) - m = self.TYPE_STRUCT_RE.fullmatch(type_repr) - if m: - col_type = Struct() - - return col_type - - def to_comparable(self, value: str, coltype: ColType) -> str: - """Ensure that the expression is comparable in ``IS DISTINCT FROM``.""" - if isinstance(coltype, (JSON, Array, Struct)): - return self.normalize_value_by_type(value, coltype) - else: - return super().to_comparable(value, coltype) - - def set_timezone_to_utc(self) -> str: - raise NotImplementedError() - - -class BigQuery(Database): - CONNECT_URI_HELP = "bigquery:///" - CONNECT_URI_PARAMS = ["dataset"] - dialect = Dialect() - - def __init__(self, project, *, dataset, bigquery_credentials=None, **kw): - credentials = bigquery_credentials - bigquery = import_bigquery() - - keyfile = kw.pop("keyfile", None) - if keyfile: - bigquery_service_account = import_bigquery_service_account() - credentials = bigquery_service_account.Credentials.from_service_account_file( - keyfile, - scopes=["https://www.googleapis.com/auth/cloud-platform"], - ) - - self._client = bigquery.Client(project=project, credentials=credentials, **kw) - self.project = project - self.dataset = dataset - - self.default_schema = dataset - - def _normalize_returned_value(self, value): - if isinstance(value, bytes): - return value.decode() - return value - - def _query_atom(self, sql_code: str): - from google.cloud import bigquery - - try: - result = self._client.query(sql_code).result() - columns = [c.name for c in result.schema] - rows = list(result) - except Exception as e: - msg = "Exception when trying to execute SQL code:\n %s\n\nGot error: %s" - raise ConnectError(msg % (sql_code, e)) - - if rows and isinstance(rows[0], bigquery.table.Row): - rows = [tuple(self._normalize_returned_value(v) for v in row.values()) for row in rows] - return QueryResult(rows, columns) - - def _query(self, sql_code: Union[str, ThreadLocalInterpreter]) -> QueryResult: - return apply_query(self._query_atom, sql_code) - - def close(self): - super().close() - self._client.close() - - def select_table_schema(self, path: DbPath) -> str: - project, schema, name = self._normalize_table_path(path) - return ( - "SELECT column_name, data_type, 6 as datetime_precision, 38 as numeric_precision, 9 as numeric_scale " - f"FROM `{project}`.`{schema}`.INFORMATION_SCHEMA.COLUMNS " - f"WHERE table_name = '{name}' AND table_schema = '{schema}'" - ) - - def query_table_unique_columns(self, path: DbPath) -> List[str]: - return [] - - def _normalize_table_path(self, path: DbPath) -> DbPath: - if len(path) == 0: - raise ValueError(f"{self.name}: Bad table path for {self}: ()") - elif len(path) == 1: - return (self.project, self.default_schema, path[0]) - elif len(path) == 2: - return (self.project,) + path - elif len(path) == 3: - return path - else: - raise ValueError( - f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. 
Expected form: [project.]schema.table" - ) - - def parse_table_name(self, name: str) -> DbPath: - path = parse_table_name(name) - return tuple(i for i in self._normalize_table_path(path) if i is not None) - - @property - def is_autocommit(self) -> bool: - return True diff --git a/data_diff/sqeleton/databases/clickhouse.py b/data_diff/sqeleton/databases/clickhouse.py deleted file mode 100644 index 578bb1e5..00000000 --- a/data_diff/sqeleton/databases/clickhouse.py +++ /dev/null @@ -1,196 +0,0 @@ -from typing import Optional, Type - -from data_diff.sqeleton.databases.base import ( - MD5_HEXDIGITS, - CHECKSUM_HEXDIGITS, - TIMESTAMP_PRECISION_POS, - BaseDialect, - ThreadedDatabase, - import_helper, - ConnectError, - Mixin_RandomSample, -) -from data_diff.sqeleton.abcs.database_types import ( - ColType, - Decimal, - Float, - Integer, - FractionalType, - Native_UUID, - TemporalType, - Text, - Timestamp, - Boolean, -) -from data_diff.sqeleton.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue - -# https://clickhouse.com/docs/en/operations/server-configuration-parameters/settings/#default-database -DEFAULT_DATABASE = "default" - - -@import_helper("clickhouse") -def import_clickhouse(): - import clickhouse_driver - - return clickhouse_driver - - -class Mixin_MD5(AbstractMixin_MD5): - def md5_as_int(self, s: str) -> str: - substr_idx = 1 + MD5_HEXDIGITS - CHECKSUM_HEXDIGITS - return f"reinterpretAsUInt128(reverse(unhex(lowerUTF8(substr(hex(MD5({s})), {substr_idx})))))" - - -class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): - def normalize_number(self, value: str, coltype: FractionalType) -> str: - # If a decimal value has trailing zeros in a fractional part, when casting to string they are dropped. - # For example: - # select toString(toDecimal128(1.10, 2)); -- the result is 1.1 - # select toString(toDecimal128(1.00, 2)); -- the result is 1 - # So, we should use some custom approach to save these trailing zeros. - # To avoid it, we can add a small value like 0.000001 to prevent dropping of zeros from the end when casting. - # For examples above it looks like: - # select toString(toDecimal128(1.10, 2 + 1) + toDecimal128(0.001, 3)); -- the result is 1.101 - # After that, cut an extra symbol from the string, i.e. 1.101 -> 1.10 - # So, the algorithm is: - # 1. Cast to decimal with precision + 1 - # 2. Add a small value 10^(-precision-1) - # 3. Cast the result to string - # 4. Drop the extra digit from the string. To do that, we need to slice the string - # with length = digits in an integer part + 1 (symbol of ".") + precision - - if coltype.precision == 0: - return self.to_string(f"round({value})") - - precision = coltype.precision - # TODO: too complex, is there better performance way? - value = f""" - if({value} >= 0, '', '-') || left( - toString( - toDecimal128( - round(abs({value}), {precision}), - {precision} + 1 - ) - + - toDecimal128( - exp10(-{precision + 1}), - {precision} + 1 - ) - ), - toUInt8( - greatest( - floor(log10(abs({value}))) + 1, - 1 - ) - ) + 1 + {precision} - ) - """ - return value - - def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: - prec = coltype.precision - if coltype.rounds: - timestamp = f"toDateTime64(round(toUnixTimestamp64Micro(toDateTime64({value}, 6)) / 1000000, {prec}), 6)" - return self.to_string(timestamp) - - fractional = f"toUnixTimestamp64Micro(toDateTime64({value}, {prec})) % 1000000" - fractional = f"lpad({self.to_string(fractional)}, 6, '0')" - value = f"formatDateTime({value}, '%Y-%m-%d %H:%M:%S') || '.' 
|| {self.to_string(fractional)}" - return f"rpad({value}, {TIMESTAMP_PRECISION_POS + 6}, '0')" - - -class Dialect(BaseDialect, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): - name = "Clickhouse" - ROUNDS_ON_PREC_LOSS = False - TYPE_CLASSES = { - "Int8": Integer, - "Int16": Integer, - "Int32": Integer, - "Int64": Integer, - "Int128": Integer, - "Int256": Integer, - "UInt8": Integer, - "UInt16": Integer, - "UInt32": Integer, - "UInt64": Integer, - "UInt128": Integer, - "UInt256": Integer, - "Float32": Float, - "Float64": Float, - "Decimal": Decimal, - "UUID": Native_UUID, - "String": Text, - "FixedString": Text, - "DateTime": Timestamp, - "DateTime64": Timestamp, - "Bool": Boolean, - } - MIXINS = {Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample} - - def quote(self, s: str) -> str: - return f'"{s}"' - - def to_string(self, s: str) -> str: - return f"toString({s})" - - def _convert_db_precision_to_digits(self, p: int) -> int: - # Done the same as for PostgreSQL but need to rewrite in another way - # because it does not help for float with a big integer part. - return super()._convert_db_precision_to_digits(p) - 2 - - def _parse_type_repr(self, type_repr: str) -> Optional[Type[ColType]]: - nullable_prefix = "Nullable(" - if type_repr.startswith(nullable_prefix): - type_repr = type_repr[len(nullable_prefix) :].rstrip(")") - - if type_repr.startswith("Decimal"): - type_repr = "Decimal" - elif type_repr.startswith("FixedString"): - type_repr = "FixedString" - elif type_repr.startswith("DateTime64"): - type_repr = "DateTime64" - - return self.TYPE_CLASSES.get(type_repr) - - # def timestamp_value(self, t: DbTime) -> str: - # # return f"'{t}'" - # return f"'{str(t)[:19]}'" - - def set_timezone_to_utc(self) -> str: - raise NotImplementedError() - - def current_timestamp(self) -> str: - return "now()" - - -class Clickhouse(ThreadedDatabase): - dialect = Dialect() - CONNECT_URI_HELP = "clickhouse://:@/" - CONNECT_URI_PARAMS = ["database?"] - - def __init__(self, *, thread_count: int, **kw): - super().__init__(thread_count=thread_count) - - self._args = kw - # In Clickhouse database and schema are the same - self.default_schema = kw.get("database", DEFAULT_DATABASE) - - def create_connection(self): - clickhouse = import_clickhouse() - - class SingleConnection(clickhouse.dbapi.connection.Connection): - """Not thread-safe connection to Clickhouse""" - - def cursor(self, cursor_factory=None): - if not len(self.cursors): - _ = super().cursor() - return self.cursors[0] - - try: - return SingleConnection(**self._args) - except clickhouse.OperationError as e: - raise ConnectError(*e.args) from e - - @property - def is_autocommit(self) -> bool: - return True diff --git a/data_diff/sqeleton/databases/databricks.py b/data_diff/sqeleton/databases/databricks.py deleted file mode 100644 index e478039f..00000000 --- a/data_diff/sqeleton/databases/databricks.py +++ /dev/null @@ -1,199 +0,0 @@ -import math -from typing import Dict, Sequence -import logging - -from data_diff.sqeleton.abcs.database_types import ( - Integer, - Float, - Decimal, - Timestamp, - Text, - TemporalType, - NumericType, - DbPath, - ColType, - UnknownColType, - Boolean, -) -from data_diff.sqeleton.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue -from data_diff.sqeleton.databases.base import ( - MD5_HEXDIGITS, - CHECKSUM_HEXDIGITS, - BaseDialect, - ThreadedDatabase, - import_helper, - parse_table_name, - Mixin_RandomSample, -) - - -@import_helper(text="You can install it using 'pip install 
databricks-sql-connector'") -def import_databricks(): - import databricks.sql - - return databricks - - -class Mixin_MD5(AbstractMixin_MD5): - def md5_as_int(self, s: str) -> str: - return f"cast(conv(substr(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16, 10) as decimal(38, 0))" - - -class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): - def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: - """Databricks timestamp contains no more than 6 digits in precision""" - - if coltype.rounds: - timestamp = f"cast(round(unix_micros({value}) / 1000000, {coltype.precision}) * 1000000 as bigint)" - return f"date_format(timestamp_micros({timestamp}), 'yyyy-MM-dd HH:mm:ss.SSSSSS')" - - precision_format = "S" * coltype.precision + "0" * (6 - coltype.precision) - return f"date_format({value}, 'yyyy-MM-dd HH:mm:ss.{precision_format}')" - - def normalize_number(self, value: str, coltype: NumericType) -> str: - value = f"cast({value} as decimal(38, {coltype.precision}))" - if coltype.precision > 0: - value = f"format_number({value}, {coltype.precision})" - return f"replace({self.to_string(value)}, ',', '')" - - def normalize_boolean(self, value: str, _coltype: Boolean) -> str: - return self.to_string(f"cast ({value} as int)") - - -class Dialect(BaseDialect, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): - name = "Databricks" - ROUNDS_ON_PREC_LOSS = True - TYPE_CLASSES = { - # Numbers - "INT": Integer, - "SMALLINT": Integer, - "TINYINT": Integer, - "BIGINT": Integer, - "FLOAT": Float, - "DOUBLE": Float, - "DECIMAL": Decimal, - # Timestamps - "TIMESTAMP": Timestamp, - # Text - "STRING": Text, - # Boolean - "BOOLEAN": Boolean, - } - MIXINS = {Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample} - - def quote(self, s: str): - return f"`{s}`" - - def to_string(self, s: str) -> str: - return f"cast({s} as string)" - - def _convert_db_precision_to_digits(self, p: int) -> int: - # Subtracting 2 due to wierd precision issues - return max(super()._convert_db_precision_to_digits(p) - 2, 0) - - def set_timezone_to_utc(self) -> str: - return "SET TIME ZONE 'UTC'" - - -class Databricks(ThreadedDatabase): - dialect = Dialect() - CONNECT_URI_HELP = "databricks://:@/" - CONNECT_URI_PARAMS = ["catalog", "schema"] - - def __init__(self, *, thread_count, **kw): - logging.getLogger("databricks.sql").setLevel(logging.WARNING) - - self._args = kw - self.default_schema = kw.get("schema", "default") - self.catalog = self._args.get("catalog", "hive_metastore") - super().__init__(thread_count=thread_count) - - def create_connection(self): - databricks = import_databricks() - - try: - return databricks.sql.connect( - server_hostname=self._args["server_hostname"], - http_path=self._args["http_path"], - access_token=self._args["access_token"], - catalog=self.catalog, - ) - except databricks.sql.exc.Error as e: - raise ConnectionError(*e.args) from e - - def query_table_schema(self, path: DbPath) -> Dict[str, tuple]: - # Databricks has INFORMATION_SCHEMA only for Databricks Runtime, not for Databricks SQL. - # https://docs.databricks.com/spark/latest/spark-sql/language-manual/information-schema/columns.html - # So, to obtain information about schema, we should use another approach. 
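        # A minimal sketch of that approach, using the DB-API metadata call from
        # databricks-sql-connector (connection parameters here are hypothetical;
        # the real values come from self._args, as in create_connection() above):
        #
        #     conn = databricks.sql.connect(server_hostname="...", http_path="...",
        #                                   access_token="...")
        #     with conn.cursor() as cursor:
        #         cursor.columns(catalog_name="hive_metastore", schema_name="default",
        #                        table_name="my_table")
        #         for r in cursor.fetchall():
        #             print(r.COLUMN_NAME, r.TYPE_NAME, r.DECIMAL_DIGITS)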
- - conn = self.create_connection() - - catalog, schema, table = self._normalize_table_path(path) - with conn.cursor() as cursor: - cursor.columns(catalog_name=catalog, schema_name=schema, table_name=table) - try: - rows = cursor.fetchall() - finally: - conn.close() - if not rows: - raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns") - - d = {r.COLUMN_NAME: (r.COLUMN_NAME, r.TYPE_NAME, r.DECIMAL_DIGITS, None, None) for r in rows} - assert len(d) == len(rows) - return d - - def _process_table_schema( - self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str], where: str = None - ): - accept = {i.lower() for i in filter_columns} - rows = [row for name, row in raw_schema.items() if name.lower() in accept] - - resulted_rows = [] - for row in rows: - row_type = "DECIMAL" if row[1].startswith("DECIMAL") else row[1] - type_cls = self.dialect.TYPE_CLASSES.get(row_type, UnknownColType) - - if issubclass(type_cls, Integer): - row = (row[0], row_type, None, None, 0) - - elif issubclass(type_cls, Float): - numeric_precision = math.ceil(row[2] / math.log(2, 10)) - row = (row[0], row_type, None, numeric_precision, None) - - elif issubclass(type_cls, Decimal): - items = row[1][8:].rstrip(")").split(",") - numeric_precision, numeric_scale = int(items[0]), int(items[1]) - row = (row[0], row_type, None, numeric_precision, numeric_scale) - - elif issubclass(type_cls, Timestamp): - row = (row[0], row_type, row[2], None, None) - - else: - row = (row[0], row_type, None, None, None) - - resulted_rows.append(row) - - col_dict: Dict[str, ColType] = {row[0]: self.dialect.parse_type(path, *row) for row in resulted_rows} - - self._refine_coltypes(path, col_dict, where) - return col_dict - - def parse_table_name(self, name: str) -> DbPath: - path = parse_table_name(name) - return tuple(i for i in self._normalize_table_path(path) if i is not None) - - @property - def is_autocommit(self) -> bool: - return True - - def _normalize_table_path(self, path: DbPath) -> DbPath: - if len(path) == 1: - return self.catalog, self.default_schema, path[0] - elif len(path) == 2: - return self.catalog, path[0], path[1] - elif len(path) == 3: - return path - - raise ValueError( - f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. 
Expected format: table, schema.table, or catalog.schema.table" - ) diff --git a/data_diff/sqeleton/databases/duckdb.py b/data_diff/sqeleton/databases/duckdb.py deleted file mode 100644 index 827a0483..00000000 --- a/data_diff/sqeleton/databases/duckdb.py +++ /dev/null @@ -1,192 +0,0 @@ -from typing import Union - -from data_diff.utils import match_regexps -from data_diff.sqeleton.abcs.database_types import ( - Timestamp, - TimestampTZ, - DbPath, - ColType, - Float, - Decimal, - Integer, - TemporalType, - Native_UUID, - Text, - FractionalType, - Boolean, - AbstractTable, -) -from data_diff.sqeleton.abcs.mixins import ( - AbstractMixin_MD5, - AbstractMixin_NormalizeValue, - AbstractMixin_RandomSample, - AbstractMixin_Regex, -) -from data_diff.sqeleton.databases.base import ( - Database, - BaseDialect, - import_helper, - ConnectError, - ThreadLocalInterpreter, - TIMESTAMP_PRECISION_POS, -) -from data_diff.sqeleton.databases.base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, Mixin_Schema -from data_diff.sqeleton.queries.ast_classes import Func, Compilable -from data_diff.sqeleton.queries.api import code - - -@import_helper("duckdb") -def import_duckdb(): - import duckdb - - return duckdb - - -class Mixin_MD5(AbstractMixin_MD5): - def md5_as_int(self, s: str) -> str: - return f"('0x' || SUBSTRING(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS},{CHECKSUM_HEXDIGITS}))::BIGINT" - - -class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): - def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: - # It's precision 6 by default. If precision is less than 6 -> we remove the trailing numbers. - if coltype.rounds and coltype.precision > 0: - return f"CONCAT(SUBSTRING(STRFTIME({value}::TIMESTAMP, '%Y-%m-%d %H:%M:%S.'),1,23), LPAD(((ROUND(strftime({value}::timestamp, '%f')::DECIMAL(15,7)/100000,{coltype.precision-1})*100000)::INT)::VARCHAR,6,'0'))" - - return f"rpad(substring(strftime({value}::timestamp, '%Y-%m-%d %H:%M:%S.%f'),1,{TIMESTAMP_PRECISION_POS+coltype.precision}),26,'0')" - - def normalize_number(self, value: str, coltype: FractionalType) -> str: - return self.to_string(f"{value}::DECIMAL(38, {coltype.precision})") - - def normalize_boolean(self, value: str, _coltype: Boolean) -> str: - return self.to_string(f"{value}::INTEGER") - - -class Mixin_RandomSample(AbstractMixin_RandomSample): - def random_sample_n(self, tbl: AbstractTable, size: int) -> AbstractTable: - return code("SELECT * FROM ({tbl}) USING SAMPLE {size};", tbl=tbl, size=size) - - def random_sample_ratio_approx(self, tbl: AbstractTable, ratio: float) -> AbstractTable: - return code("SELECT * FROM ({tbl}) USING SAMPLE {percent}%;", tbl=tbl, percent=int(100 * ratio)) - - -class Mixin_Regex(AbstractMixin_Regex): - def test_regex(self, string: Compilable, pattern: Compilable) -> Compilable: - return Func("regexp_matches", [string, pattern]) - - -class Dialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): - name = "DuckDB" - ROUNDS_ON_PREC_LOSS = False - SUPPORTS_PRIMARY_KEY = True - SUPPORTS_INDEXES = True - MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample} - - TYPE_CLASSES = { - # Timestamps - "TIMESTAMP WITH TIME ZONE": TimestampTZ, - "TIMESTAMP": Timestamp, - # Numbers - "DOUBLE": Float, - "FLOAT": Float, - "DECIMAL": Decimal, - "INTEGER": Integer, - "BIGINT": Integer, - # Text - "VARCHAR": Text, - "TEXT": Text, - # UUID - "UUID": Native_UUID, - # Bool - "BOOLEAN": Boolean, - } - - def quote(self, s: str): - return f'"{s}"' - - 
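    # Illustrative renderings of the two helpers adjacent to this point,
    # assuming a column named "id":
    #
    #     quote("id")      ->  "id"           (double-quoted identifier)
    #     to_string("id")  ->  id::VARCHAR    (cast helper, defined just below)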
def to_string(self, s: str): - return f"{s}::VARCHAR" - - def _convert_db_precision_to_digits(self, p: int) -> int: - # Subtracting 2 due to wierd precision issues in PostgreSQL - return super()._convert_db_precision_to_digits(p) - 2 - - def parse_type( - self, - table_path: DbPath, - col_name: str, - type_repr: str, - datetime_precision: int = None, - numeric_precision: int = None, - numeric_scale: int = None, - ) -> ColType: - regexps = { - r"DECIMAL\((\d+),(\d+)\)": Decimal, - } - - for m, t_cls in match_regexps(regexps, type_repr): - precision = int(m.group(2)) - return t_cls(precision=precision) - - return super().parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision, numeric_scale) - - def set_timezone_to_utc(self) -> str: - return "SET GLOBAL TimeZone='UTC'" - - def current_timestamp(self) -> str: - return "current_timestamp" - - -class DuckDB(Database): - dialect = Dialect() - SUPPORTS_UNIQUE_CONSTAINT = False # Temporary, until we implement it - default_schema = "main" - CONNECT_URI_HELP = "duckdb://@" - CONNECT_URI_PARAMS = ["database", "dbpath"] - - def __init__(self, **kw): - self._args = kw - self._conn = self.create_connection() - - @property - def is_autocommit(self) -> bool: - return True - - def _query(self, sql_code: Union[str, ThreadLocalInterpreter]): - "Uses the standard SQL cursor interface" - return self._query_conn(self._conn, sql_code) - - def close(self): - super().close() - self._conn.close() - - def create_connection(self): - ddb = import_duckdb() - try: - return ddb.connect(self._args["filepath"]) - except ddb.OperationalError as e: - raise ConnectError(*e.args) from e - - def select_table_schema(self, path: DbPath) -> str: - database, schema, table = self._normalize_table_path(path) - - info_schema_path = ["information_schema", "columns"] - if database: - info_schema_path.insert(0, database) - - return ( - f"SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale FROM {'.'.join(info_schema_path)} " - f"WHERE table_name = '{table}' AND table_schema = '{schema}'" - ) - - def _normalize_table_path(self, path: DbPath) -> DbPath: - if len(path) == 1: - return None, self.default_schema, path[0] - elif len(path) == 2: - return None, path[0], path[1] - elif len(path) == 3: - return path - - raise ValueError( - f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. 
Expected format: table, schema.table, or database.schema.table" - ) diff --git a/data_diff/sqeleton/databases/mssql.py b/data_diff/sqeleton/databases/mssql.py deleted file mode 100644 index d18f3fda..00000000 --- a/data_diff/sqeleton/databases/mssql.py +++ /dev/null @@ -1,214 +0,0 @@ -from typing import Optional -from data_diff.sqeleton.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue -from data_diff.sqeleton.databases.base import ( - CHECKSUM_HEXDIGITS, - Mixin_OptimizerHints, - Mixin_RandomSample, - QueryError, - ThreadedDatabase, - import_helper, - ConnectError, - BaseDialect, -) -from data_diff.sqeleton.databases.base import Mixin_Schema -from data_diff.sqeleton.abcs.database_types import ( - JSON, - Timestamp, - TimestampTZ, - DbPath, - Float, - Decimal, - Integer, - TemporalType, - Native_UUID, - Text, - FractionalType, - Boolean, -) - - -@import_helper("mssql") -def import_mssql(): - import pyodbc - - return pyodbc - - -class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): - def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: - if coltype.precision > 0: - formatted_value = ( - f"FORMAT({value}, 'yyyy-MM-dd HH:mm:ss') + '.' + " - f"SUBSTRING(FORMAT({value}, 'fffffff'), 1, {coltype.precision})" - ) - else: - formatted_value = f"FORMAT({value}, 'yyyy-MM-dd HH:mm:ss')" - - return formatted_value - - def normalize_number(self, value: str, coltype: FractionalType) -> str: - if coltype.precision == 0: - return f"CAST(FLOOR({value}) AS VARCHAR)" - - return f"FORMAT({value}, 'N{coltype.precision}')" - - -class Mixin_MD5(AbstractMixin_MD5): - def md5_as_int(self, s: str) -> str: - return f"convert(bigint, convert(varbinary, '0x' + RIGHT(CONVERT(NVARCHAR(32), HashBytes('MD5', {s}), 2), {CHECKSUM_HEXDIGITS}), 1))" - - -class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): - name = "MsSQL" - ROUNDS_ON_PREC_LOSS = True - SUPPORTS_PRIMARY_KEY = True - SUPPORTS_INDEXES = True - TYPE_CLASSES = { - # Timestamps - "datetimeoffset": TimestampTZ, - "datetime": Timestamp, - "datetime2": Timestamp, - "smalldatetime": Timestamp, - "date": Timestamp, - # Numbers - "float": Float, - "real": Float, - "decimal": Decimal, - "money": Decimal, - "smallmoney": Decimal, - # int - "int": Integer, - "bigint": Integer, - "tinyint": Integer, - "smallint": Integer, - # Text - "varchar": Text, - "char": Text, - "text": Text, - "ntext": Text, - "nvarchar": Text, - "nchar": Text, - "binary": Text, - "varbinary": Text, - # UUID - "uniqueidentifier": Native_UUID, - # Bool - "bit": Boolean, - # JSON - "json": JSON, - } - - MIXINS = {Mixin_Schema, Mixin_NormalizeValue, Mixin_RandomSample} - - def quote(self, s: str): - return f"[{s}]" - - def set_timezone_to_utc(self) -> str: - raise NotImplementedError("MsSQL does not support a session timezone setting.") - - def current_timestamp(self) -> str: - return "GETDATE()" - - def current_database(self) -> str: - return "DB_NAME()" - - def current_schema(self) -> str: - return """default_schema_name - FROM sys.database_principals - WHERE name = CURRENT_USER""" - - def to_string(self, s: str): - return f"CONVERT(varchar, {s})" - - def type_repr(self, t) -> str: - try: - return {bool: "bit"}[t] - except KeyError: - return super().type_repr(t) - - def random(self) -> str: - return "rand()" - - def is_distinct_from(self, a: str, b: str) -> str: - # IS (NOT) DISTINCT FROM is available only since SQLServer 2022. 
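        # The expression built below emulates it for pre-2022 servers: for
        # operands a and b it is TRUE when exactly one side is NULL, or when
        # both are non-NULL and unequal. Illustrative rendering for columns
        # x and y:
        #     ((x<>y OR x IS NULL OR y IS NULL) AND NOT(x IS NULL AND y IS NULL))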
- # See: https://stackoverflow.com/a/18684859/857383 - return f"(({a}<>{b} OR {a} IS NULL OR {b} IS NULL) AND NOT({a} IS NULL AND {b} IS NULL))" - - def offset_limit( - self, offset: Optional[int] = None, limit: Optional[int] = None, has_order_by: Optional[bool] = None - ) -> str: - if offset: - raise NotImplementedError("No support for OFFSET in query") - - result = "" - if not has_order_by: - result += "ORDER BY 1" - - result += f" OFFSET 0 ROWS FETCH NEXT {limit} ROWS ONLY" - return result - - def constant_values(self, rows) -> str: - values = ", ".join("(%s)" % ", ".join(self._constant_value(v) for v in row) for row in rows) - return f"VALUES {values}" - - -class MsSQL(ThreadedDatabase): - dialect = Dialect() - # - CONNECT_URI_HELP = "mssql://:@//" - CONNECT_URI_PARAMS = ["database", "schema"] - - def __init__(self, host, port, user, password, *, database, thread_count, **kw): - args = dict(server=host, port=port, database=database, user=user, password=password, **kw) - self._args = {k: v for k, v in args.items() if v is not None} - self._args["driver"] = "{ODBC Driver 18 for SQL Server}" - - # TODO temp dev debug - self._args["TrustServerCertificate"] = "yes" - - try: - self.default_database = self._args["database"] - self.default_schema = self._args["schema"] - except KeyError: - raise ValueError("Specify a default database and schema.") - - super().__init__(thread_count=thread_count) - - def create_connection(self): - self._mssql = import_mssql() - try: - connection = self._mssql.connect(**self._args) - return connection - except self._mssql.Error as error: - raise ConnectError(*error.args) from error - - def select_table_schema(self, path: DbPath) -> str: - """Provide SQL for selecting the table schema as (name, type, date_prec, num_prec)""" - database, schema, name = self._normalize_table_path(path) - info_schema_path = ["information_schema", "columns"] - if database: - info_schema_path.insert(0, self.dialect.quote(database)) - - return ( - "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale " - f"FROM {'.'.join(info_schema_path)} " - f"WHERE table_name = '{name}' AND table_schema = '{schema}'" - ) - - def _normalize_table_path(self, path: DbPath) -> DbPath: - if len(path) == 1: - return self.default_database, self.default_schema, path[0] - elif len(path) == 2: - return self.default_database, path[0], path[1] - elif len(path) == 3: - return path - - raise ValueError( - f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. 
Expected format: table, schema.table, or database.schema.table" - ) - - def _query_cursor(self, c, sql_code: str): - try: - return super()._query_cursor(c, sql_code) - except self._mssql.DatabaseError as e: - raise QueryError(e) diff --git a/data_diff/sqeleton/databases/mysql.py b/data_diff/sqeleton/databases/mysql.py deleted file mode 100644 index 7c659749..00000000 --- a/data_diff/sqeleton/databases/mysql.py +++ /dev/null @@ -1,160 +0,0 @@ -from data_diff.sqeleton.abcs.database_types import ( - Datetime, - Timestamp, - Float, - Decimal, - Integer, - Text, - TemporalType, - FractionalType, - ColType_UUID, - Boolean, - Date, -) -from data_diff.sqeleton.abcs.mixins import ( - AbstractMixin_MD5, - AbstractMixin_NormalizeValue, - AbstractMixin_Regex, - AbstractMixin_RandomSample, -) -from data_diff.sqeleton.databases.base import ( - Mixin_OptimizerHints, - ThreadedDatabase, - import_helper, - ConnectError, - BaseDialect, - Compilable, -) -from data_diff.sqeleton.databases.base import ( - MD5_HEXDIGITS, - CHECKSUM_HEXDIGITS, - TIMESTAMP_PRECISION_POS, - Mixin_Schema, - Mixin_RandomSample, -) -from data_diff.sqeleton.queries.ast_classes import BinBoolOp - - -@import_helper("mysql") -def import_mysql(): - import mysql.connector - - return mysql.connector - - -class Mixin_MD5(AbstractMixin_MD5): - def md5_as_int(self, s: str) -> str: - return f"cast(conv(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16, 10) as unsigned)" - - -class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): - def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: - if coltype.rounds: - return self.to_string(f"cast( cast({value} as datetime({coltype.precision})) as datetime(6))") - - s = self.to_string(f"cast({value} as datetime(6))") - return f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS+coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS+6}, '0')" - - def normalize_number(self, value: str, coltype: FractionalType) -> str: - return self.to_string(f"cast({value} as decimal(38, {coltype.precision}))") - - def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: - return f"TRIM(CAST({value} AS char))" - - -class Mixin_Regex(AbstractMixin_Regex): - def test_regex(self, string: Compilable, pattern: Compilable) -> Compilable: - return BinBoolOp("REGEXP", [string, pattern]) - - -class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): - name = "MySQL" - ROUNDS_ON_PREC_LOSS = True - SUPPORTS_PRIMARY_KEY = True - SUPPORTS_INDEXES = True - TYPE_CLASSES = { - # Dates - "datetime": Datetime, - "timestamp": Timestamp, - "date": Date, - # Numbers - "double": Float, - "float": Float, - "decimal": Decimal, - "int": Integer, - "bigint": Integer, - "mediumint": Integer, - "smallint": Integer, - "tinyint": Integer, - # Text - "varchar": Text, - "char": Text, - "varbinary": Text, - "binary": Text, - "text": Text, - "mediumtext": Text, - "longtext": Text, - "tinytext": Text, - # Boolean - "boolean": Boolean, - } - MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample} - - def quote(self, s: str): - return f"`{s}`" - - def to_string(self, s: str): - return f"cast({s} as char)" - - def is_distinct_from(self, a: str, b: str) -> str: - return f"not ({a} <=> {b})" - - def random(self) -> str: - return "RAND()" - - def type_repr(self, t) -> str: - try: - return { - str: "VARCHAR(1024)", - }[t] - except KeyError: - return super().type_repr(t) - - def explain_as_text(self, query: str) -> str: - return 
f"EXPLAIN FORMAT=TREE {query}" - - def optimizer_hints(self, s: str): - return f"/*+ {s} */ " - - def set_timezone_to_utc(self) -> str: - return "SET @@session.time_zone='+00:00'" - - -class MySQL(ThreadedDatabase): - dialect = Dialect() - SUPPORTS_ALPHANUMS = False - SUPPORTS_UNIQUE_CONSTAINT = True - CONNECT_URI_HELP = "mysql://:@/" - CONNECT_URI_PARAMS = ["database?"] - - def __init__(self, *, thread_count, **kw): - self._args = kw - - super().__init__(thread_count=thread_count) - - # In MySQL schema and database are synonymous - try: - self.default_schema = kw["database"] - except KeyError: - raise ValueError("MySQL URL must specify a database") - - def create_connection(self): - mysql = import_mysql() - try: - return mysql.connect(charset="utf8", use_unicode=True, **self._args) - except mysql.Error as e: - if e.errno == mysql.errorcode.ER_ACCESS_DENIED_ERROR: - raise ConnectError("Bad user name or password") from e - elif e.errno == mysql.errorcode.ER_BAD_DB_ERROR: - raise ConnectError("Database does not exist") from e - raise ConnectError(*e.args) from e diff --git a/data_diff/sqeleton/databases/oracle.py b/data_diff/sqeleton/databases/oracle.py deleted file mode 100644 index 825510a1..00000000 --- a/data_diff/sqeleton/databases/oracle.py +++ /dev/null @@ -1,206 +0,0 @@ -from typing import Dict, List, Optional - -from data_diff.utils import match_regexps -from data_diff.sqeleton.abcs.database_types import ( - Decimal, - Float, - Text, - DbPath, - TemporalType, - ColType, - DbTime, - ColType_UUID, - Timestamp, - TimestampTZ, - FractionalType, -) -from data_diff.sqeleton.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue, AbstractMixin_Schema -from data_diff.sqeleton.abcs import Compilable -from data_diff.sqeleton.queries import this, table, SKIP -from data_diff.sqeleton.databases.base import ( - BaseDialect, - Mixin_OptimizerHints, - ThreadedDatabase, - import_helper, - ConnectError, - QueryError, - Mixin_RandomSample, -) -from data_diff.sqeleton.databases.base import TIMESTAMP_PRECISION_POS - -SESSION_TIME_ZONE = None # Changed by the tests - - -@import_helper("oracle") -def import_oracle(): - import oracledb - - return oracledb - - -class Mixin_MD5(AbstractMixin_MD5): - def md5_as_int(self, s: str) -> str: - # standard_hash is faster than DBMS_CRYPTO.Hash - # TODO: Find a way to use UTL_RAW.CAST_TO_BINARY_INTEGER ? - return f"to_number(substr(standard_hash({s}, 'MD5'), 18), 'xxxxxxxxxxxxxxx')" - - -class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): - def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: - # Cast is necessary for correct MD5 (trimming not enough) - return f"CAST(TRIM({value}) AS VARCHAR(36))" - - def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: - if coltype.rounds: - return f"to_char(cast({value} as timestamp({coltype.precision})), 'YYYY-MM-DD HH24:MI:SS.FF6')" - - if coltype.precision > 0: - truncated = f"to_char({value}, 'YYYY-MM-DD HH24:MI:SS.FF{coltype.precision}')" - else: - truncated = f"to_char({value}, 'YYYY-MM-DD HH24:MI:SS.')" - return f"RPAD({truncated}, {TIMESTAMP_PRECISION_POS+6}, '0')" - - def normalize_number(self, value: str, coltype: FractionalType) -> str: - # FM999.9990 - format_str = "FM" + "9" * (38 - coltype.precision) - if coltype.precision: - format_str += "0." 
+ "9" * (coltype.precision - 1) + "0" - return f"to_char({value}, '{format_str}')" - - -class Mixin_Schema(AbstractMixin_Schema): - def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: - return ( - table("ALL_TABLES") - .where( - this.OWNER == table_schema, - this.TABLE_NAME.like(like) if like is not None else SKIP, - ) - .select(table_name=this.TABLE_NAME) - ) - - -class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): - name = "Oracle" - SUPPORTS_PRIMARY_KEY = True - SUPPORTS_INDEXES = True - TYPE_CLASSES: Dict[str, type] = { - "NUMBER": Decimal, - "FLOAT": Float, - # Text - "CHAR": Text, - "NCHAR": Text, - "NVARCHAR2": Text, - "VARCHAR2": Text, - "DATE": Timestamp, - } - ROUNDS_ON_PREC_LOSS = True - PLACEHOLDER_TABLE = "DUAL" - MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample} - - def quote(self, s: str): - return f'"{s}"' - - def to_string(self, s: str): - return f"cast({s} as varchar(1024))" - - def offset_limit( - self, offset: Optional[int] = None, limit: Optional[int] = None, has_order_by: Optional[bool] = None - ) -> str: - if offset: - raise NotImplementedError("No support for OFFSET in query") - - return f"FETCH NEXT {limit} ROWS ONLY" - - def concat(self, items: List[str]) -> str: - joined_exprs = " || ".join(items) - return f"({joined_exprs})" - - def timestamp_value(self, t: DbTime) -> str: - return "timestamp '%s'" % t.isoformat(" ") - - def random(self) -> str: - return "dbms_random.value" - - def is_distinct_from(self, a: str, b: str) -> str: - return f"DECODE({a}, {b}, 1, 0) = 0" - - def type_repr(self, t) -> str: - try: - return { - str: "VARCHAR(1024)", - }[t] - except KeyError: - return super().type_repr(t) - - def constant_values(self, rows) -> str: - return " UNION ALL ".join( - "SELECT %s FROM DUAL" % ", ".join(self._constant_value(v) for v in row) for row in rows - ) - - def explain_as_text(self, query: str) -> str: - raise NotImplementedError("Explain not yet implemented in Oracle") - - def parse_type( - self, - table_path: DbPath, - col_name: str, - type_repr: str, - datetime_precision: int = None, - numeric_precision: int = None, - numeric_scale: int = None, - ) -> ColType: - regexps = { - r"TIMESTAMP\((\d)\) WITH LOCAL TIME ZONE": Timestamp, - r"TIMESTAMP\((\d)\) WITH TIME ZONE": TimestampTZ, - r"TIMESTAMP\((\d)\)": Timestamp, - } - - for m, t_cls in match_regexps(regexps, type_repr): - precision = int(m.group(1)) - return t_cls(precision=precision, rounds=self.ROUNDS_ON_PREC_LOSS) - - return super().parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision, numeric_scale) - - def set_timezone_to_utc(self) -> str: - return "ALTER SESSION SET TIME_ZONE = 'UTC'" - - def current_timestamp(self) -> str: - return "LOCALTIMESTAMP" - - -class Oracle(ThreadedDatabase): - dialect = Dialect() - CONNECT_URI_HELP = "oracle://:@/" - CONNECT_URI_PARAMS = ["database?"] - - def __init__(self, *, host, database, thread_count, **kw): - self.kwargs = dict(dsn=f"{host}/{database}" if database else host, **kw) - - self.default_schema = kw.get("user").upper() - - super().__init__(thread_count=thread_count) - - def create_connection(self): - self._oracle = import_oracle() - try: - c = self._oracle.connect(**self.kwargs) - if SESSION_TIME_ZONE: - c.cursor().execute(f"ALTER SESSION SET TIME_ZONE = '{SESSION_TIME_ZONE}'") - return c - except Exception as e: - raise ConnectError(*e.args) from e - - def 
_query_cursor(self, c, sql_code: str): - try: - return super()._query_cursor(c, sql_code) - except self._oracle.DatabaseError as e: - raise QueryError(e) - - def select_table_schema(self, path: DbPath) -> str: - schema, name = self._normalize_table_path(path) - - return ( - f"SELECT column_name, data_type, 6 as datetime_precision, data_precision as numeric_precision, data_scale as numeric_scale" - f" FROM ALL_TAB_COLUMNS WHERE table_name = '{name}' AND owner = '{schema}'" - ) diff --git a/data_diff/sqeleton/databases/postgresql.py b/data_diff/sqeleton/databases/postgresql.py deleted file mode 100644 index 41228439..00000000 --- a/data_diff/sqeleton/databases/postgresql.py +++ /dev/null @@ -1,183 +0,0 @@ -from typing import List -from data_diff.sqeleton.abcs.database_types import ( - DbPath, - JSON, - Timestamp, - TimestampTZ, - Float, - Decimal, - Integer, - TemporalType, - Native_UUID, - Text, - FractionalType, - Boolean, - Date, -) -from data_diff.sqeleton.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue -from data_diff.sqeleton.databases.base import BaseDialect, ThreadedDatabase, import_helper, ConnectError, Mixin_Schema -from data_diff.sqeleton.databases.base import ( - MD5_HEXDIGITS, - CHECKSUM_HEXDIGITS, - _CHECKSUM_BITSIZE, - TIMESTAMP_PRECISION_POS, - Mixin_RandomSample, -) - -SESSION_TIME_ZONE = None # Changed by the tests - - -@import_helper("postgresql") -def import_postgresql(): - import psycopg2 - import psycopg2.extras - - psycopg2.extensions.set_wait_callback(psycopg2.extras.wait_select) - return psycopg2 - - -class Mixin_MD5(AbstractMixin_MD5): - def md5_as_int(self, s: str) -> str: - return f"('x' || substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}))::bit({_CHECKSUM_BITSIZE})::bigint" - - -class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): - def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: - if coltype.rounds: - return f"to_char({value}::timestamp({coltype.precision}), 'YYYY-mm-dd HH24:MI:SS.US')" - - timestamp6 = f"to_char({value}::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')" - return ( - f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" - ) - - def normalize_number(self, value: str, coltype: FractionalType) -> str: - return self.to_string(f"{value}::decimal(38, {coltype.precision})") - - def normalize_boolean(self, value: str, _coltype: Boolean) -> str: - return self.to_string(f"{value}::int") - - def normalize_json(self, value: str, _coltype: JSON) -> str: - return f"{value}::text" - - -class PostgresqlDialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): - name = "PostgreSQL" - ROUNDS_ON_PREC_LOSS = True - SUPPORTS_PRIMARY_KEY = True - SUPPORTS_INDEXES = True - MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample} - - TYPE_CLASSES = { - # Timestamps - "timestamp with time zone": TimestampTZ, - "timestamp without time zone": Timestamp, - "timestamp": Timestamp, - "date": Date, - # Numbers - "double precision": Float, - "real": Float, - "decimal": Decimal, - "smallint": Integer, - "integer": Integer, - "numeric": Decimal, - "bigint": Integer, - # Text - "character": Text, - "character varying": Text, - "varchar": Text, - "text": Text, - "json": JSON, - "jsonb": JSON, - "uuid": Native_UUID, - "boolean": Boolean, - } - - def quote(self, s: str): - return f'"{s}"' - - def to_string(self, s: str): - return f"{s}::varchar" - - def concat(self, items: List[str]) -> str: - 
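        # Joins the given SQL expressions with PostgreSQL's || operator and
        # wraps the result in parentheses. Illustrative example:
        #     concat(['"a"', '"b"', "'-'"])  ->  ("a" || "b" || '-')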
joined_exprs = " || ".join(items) - return f"({joined_exprs})" - - def _convert_db_precision_to_digits(self, p: int) -> int: - # Subtracting 2 due to wierd precision issues in PostgreSQL - return super()._convert_db_precision_to_digits(p) - 2 - - def set_timezone_to_utc(self) -> str: - return "SET TIME ZONE 'UTC'" - - def current_timestamp(self) -> str: - return "current_timestamp" - - def type_repr(self, t) -> str: - if isinstance(t, TimestampTZ): - return f"timestamp ({t.precision}) with time zone" - return super().type_repr(t) - - -class PostgreSQL(ThreadedDatabase): - dialect = PostgresqlDialect() - SUPPORTS_UNIQUE_CONSTAINT = True - CONNECT_URI_HELP = "postgresql://:@/" - CONNECT_URI_PARAMS = ["database?"] - - default_schema = "public" - - def __init__(self, *, thread_count, **kw): - self._args = kw - - super().__init__(thread_count=thread_count) - - def create_connection(self): - if not self._args: - self._args["host"] = None # psycopg2 requires 1+ arguments - - pg = import_postgresql() - try: - c = pg.connect(**self._args) - if SESSION_TIME_ZONE: - c.cursor().execute(f"SET TIME ZONE '{SESSION_TIME_ZONE}'") - return c - except pg.OperationalError as e: - raise ConnectError(*e.args) from e - - def select_table_schema(self, path: DbPath) -> str: - database, schema, table = self._normalize_table_path(path) - - info_schema_path = ["information_schema", "columns"] - if database: - info_schema_path.insert(0, database) - - return ( - f"SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale FROM {'.'.join(info_schema_path)} " - f"WHERE table_name = '{table}' AND table_schema = '{schema}'" - ) - - def select_table_unique_columns(self, path: DbPath) -> str: - database, schema, table = self._normalize_table_path(path) - - info_schema_path = ["information_schema", "key_column_usage"] - if database: - info_schema_path.insert(0, database) - - return ( - "SELECT column_name " - f"FROM {'.'.join(info_schema_path)} " - f"WHERE table_name = '{table}' AND table_schema = '{schema}'" - ) - - def _normalize_table_path(self, path: DbPath) -> DbPath: - if len(path) == 1: - return None, self.default_schema, path[0] - elif len(path) == 2: - return None, path[0], path[1] - elif len(path) == 3: - return path - - raise ValueError( - f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. 
Expected format: table, schema.table, or database.schema.table" - ) diff --git a/data_diff/sqeleton/databases/presto.py b/data_diff/sqeleton/databases/presto.py deleted file mode 100644 index 3a033ed9..00000000 --- a/data_diff/sqeleton/databases/presto.py +++ /dev/null @@ -1,202 +0,0 @@ -from functools import partial -import re - -from data_diff.utils import match_regexps - -from data_diff.sqeleton.abcs.database_types import ( - Timestamp, - TimestampTZ, - Integer, - Float, - Text, - FractionalType, - DbPath, - DbTime, - Decimal, - ColType, - ColType_UUID, - TemporalType, - Boolean, -) -from data_diff.sqeleton.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue -from data_diff.sqeleton.databases.base import ( - BaseDialect, - Database, - import_helper, - ThreadLocalInterpreter, - Mixin_Schema, - Mixin_RandomSample, -) -from data_diff.sqeleton.databases.base import ( - MD5_HEXDIGITS, - CHECKSUM_HEXDIGITS, - TIMESTAMP_PRECISION_POS, -) - - -def query_cursor(c, sql_code): - c.execute(sql_code) - if sql_code.lower().startswith("select"): - return c.fetchall() - # Required for the query to actually run 🤯 - if re.match(r"(insert|create|truncate|drop|explain)", sql_code, re.IGNORECASE): - return c.fetchone() - - -@import_helper("presto") -def import_presto(): - import prestodb - - return prestodb - - -class Mixin_MD5(AbstractMixin_MD5): - def md5_as_int(self, s: str) -> str: - return f"cast(from_base(substr(to_hex(md5(to_utf8({s}))), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16) as decimal(38, 0))" - - -class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): - def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: - # Trim doesn't work on CHAR type - return f"TRIM(CAST({value} AS VARCHAR))" - - def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: - # TODO rounds - if coltype.rounds: - s = f"date_format(cast({value} as timestamp(6)), '%Y-%m-%d %H:%i:%S.%f')" - else: - s = f"date_format(cast({value} as timestamp(6)), '%Y-%m-%d %H:%i:%S.%f')" - - return f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS+coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS+6}, '0')" - - def normalize_number(self, value: str, coltype: FractionalType) -> str: - return self.to_string(f"cast({value} as decimal(38,{coltype.precision}))") - - def normalize_boolean(self, value: str, _coltype: Boolean) -> str: - return self.to_string(f"cast ({value} as int)") - - -class Dialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): - name = "Presto" - ROUNDS_ON_PREC_LOSS = True - TYPE_CLASSES = { - # Timestamps - "timestamp with time zone": TimestampTZ, - "timestamp without time zone": Timestamp, - "timestamp": Timestamp, - # Numbers - "integer": Integer, - "bigint": Integer, - "real": Float, - "double": Float, - # Text - "varchar": Text, - # Boolean - "boolean": Boolean, - } - MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample} - - def explain_as_text(self, query: str) -> str: - return f"EXPLAIN (FORMAT TEXT) {query}" - - def type_repr(self, t) -> str: - if isinstance(t, TimestampTZ): - return f"timestamp with time zone" - - try: - return {float: "REAL"}[t] - except KeyError: - return super().type_repr(t) - - def timestamp_value(self, t: DbTime) -> str: - return f"timestamp '{t.isoformat(' ')}'" - - def quote(self, s: str): - return f'"{s}"' - - def to_string(self, s: str): - return f"cast({s} as varchar)" - - def parse_type( - self, - table_path: DbPath, - col_name: str, - type_repr: str, - 
datetime_precision: int = None, - numeric_precision: int = None, - _numeric_scale: int = None, - ) -> ColType: - timestamp_regexps = { - r"timestamp\((\d)\)": Timestamp, - r"timestamp\((\d)\) with time zone": TimestampTZ, - } - for m, t_cls in match_regexps(timestamp_regexps, type_repr): - precision = int(m.group(1)) - return t_cls(precision=precision, rounds=self.ROUNDS_ON_PREC_LOSS) - - number_regexps = {r"decimal\((\d+),(\d+)\)": Decimal} - for m, n_cls in match_regexps(number_regexps, type_repr): - _prec, scale = map(int, m.groups()) - return n_cls(scale) - - string_regexps = {r"varchar\((\d+)\)": Text, r"char\((\d+)\)": Text} - for m, n_cls in match_regexps(string_regexps, type_repr): - return n_cls() - - return super().parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision) - - def set_timezone_to_utc(self) -> str: - return "SET TIME ZONE '+00:00'" - - def current_timestamp(self) -> str: - return "current_timestamp" - - -class Presto(Database): - dialect = Dialect() - CONNECT_URI_HELP = "presto://@//" - CONNECT_URI_PARAMS = ["catalog", "schema"] - - default_schema = "public" - - def __init__(self, **kw): - prestodb = import_presto() - - if kw.get("schema"): - self.default_schema = kw.get("schema") - - if kw.get("auth") == "basic": # if auth=basic, add basic authenticator for Presto - kw["auth"] = prestodb.auth.BasicAuthentication(kw.pop("user"), kw.pop("password")) - - if "cert" in kw: # if a certificate was specified in URI, verify session with cert - cert = kw.pop("cert") - self._conn = prestodb.dbapi.connect(**kw) - self._conn._http_session.verify = cert - else: - self._conn = prestodb.dbapi.connect(**kw) - - def _query(self, sql_code: str) -> list: - "Uses the standard SQL cursor interface" - c = self._conn.cursor() - - if isinstance(sql_code, ThreadLocalInterpreter): - return sql_code.apply_queries(partial(query_cursor, c)) - - return query_cursor(c, sql_code) - - def close(self): - super().close() - self._conn.close() - - def select_table_schema(self, path: DbPath) -> str: - schema, table = self._normalize_table_path(path) - - return ( - "SELECT column_name, data_type, 3 as datetime_precision, 3 as numeric_precision, NULL as numeric_scale " - "FROM INFORMATION_SCHEMA.COLUMNS " - f"WHERE table_name = '{table}' AND table_schema = '{schema}'" - ) - - @property - def is_autocommit(self) -> bool: - return False diff --git a/data_diff/sqeleton/databases/redshift.py b/data_diff/sqeleton/databases/redshift.py deleted file mode 100644 index e41d961e..00000000 --- a/data_diff/sqeleton/databases/redshift.py +++ /dev/null @@ -1,176 +0,0 @@ -from typing import List, Dict -from data_diff.sqeleton.abcs.database_types import ( - Float, - JSON, - TemporalType, - FractionalType, - DbPath, - TimestampTZ, -) -from data_diff.sqeleton.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue -from data_diff.sqeleton.databases.postgresql import ( - PostgreSQL, - MD5_HEXDIGITS, - CHECKSUM_HEXDIGITS, - TIMESTAMP_PRECISION_POS, - PostgresqlDialect, - Mixin_NormalizeValue, -) - - -class Mixin_MD5(AbstractMixin_MD5): - def md5_as_int(self, s: str) -> str: - return f"strtol(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16)::decimal(38)" - - -class Mixin_NormalizeValue(Mixin_NormalizeValue): - def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: - if coltype.rounds: - timestamp = f"{value}::timestamp(6)" - # Get seconds since epoch. Redshift doesn't support milli- or micro-seconds. 
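            # Illustrative arithmetic (per the extract() semantics noted below)
            # for the value '2023-01-01 00:00:01.123456':
            #     seconds since epoch -> 1672531201
            #     ms                  -> 123   (milliseconds)
            #     us                  -> 456   (microseconds, excluding the ms part)
            # so secs*1000000 + ms*1000 + us yields 1672531201123456, the full
            # microsecond-precision epoch reconstructed below.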
- secs = f"timestamp 'epoch' + round(extract(epoch from {timestamp})::decimal(38)" - # Get the milliseconds from timestamp. - ms = f"extract(ms from {timestamp})" - # Get the microseconds from timestamp, without the milliseconds! - us = f"extract(us from {timestamp})" - # epoch = Total time since epoch in microseconds. - epoch = f"{secs}*1000000 + {ms}*1000 + {us}" - timestamp6 = ( - f"to_char({epoch}, -6+{coltype.precision}) * interval '0.000001 seconds', 'YYYY-mm-dd HH24:MI:SS.US')" - ) - else: - timestamp6 = f"to_char({value}::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')" - return ( - f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" - ) - - def normalize_number(self, value: str, coltype: FractionalType) -> str: - return self.to_string(f"{value}::decimal(38,{coltype.precision})") - - def normalize_json(self, value: str, _coltype: JSON) -> str: - return f"nvl2({value}, json_serialize({value}), NULL)" - - -class Dialect(PostgresqlDialect, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): - name = "Redshift" - TYPE_CLASSES = { - **PostgresqlDialect.TYPE_CLASSES, - "double": Float, - "real": Float, - "super": JSON, - } - SUPPORTS_INDEXES = False - - def concat(self, items: List[str]) -> str: - joined_exprs = " || ".join(items) - return f"({joined_exprs})" - - def is_distinct_from(self, a: str, b: str) -> str: - return f"({a} IS NULL != {b} IS NULL) OR ({a}!={b})" - - def type_repr(self, t) -> str: - if isinstance(t, TimestampTZ): - return f"timestamptz" - return super().type_repr(t) - - -class Redshift(PostgreSQL): - dialect = Dialect() - CONNECT_URI_HELP = "redshift://:@/" - CONNECT_URI_PARAMS = ["database?"] - - def select_table_schema(self, path: DbPath) -> str: - database, schema, table = self._normalize_table_path(path) - - info_schema_path = ["information_schema", "columns"] - if database: - info_schema_path.insert(0, database) - - return ( - f"SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale FROM {'.'.join(info_schema_path)} " - f"WHERE table_name = '{table.lower()}' AND table_schema = '{schema.lower()}'" - ) - - def select_external_table_schema(self, path: DbPath) -> str: - database, schema, table = self._normalize_table_path(path) - - db_clause = "" - if database: - db_clause = f" AND redshift_database_name = '{database.lower()}'" - - return ( - f"""SELECT - columnname AS column_name - , CASE WHEN external_type = 'string' THEN 'varchar' ELSE external_type END AS data_type - , NULL AS datetime_precision - , NULL AS numeric_precision - , NULL AS numeric_scale - FROM svv_external_columns - WHERE tablename = '{table.lower()}' AND schemaname = '{schema.lower()}' - """ - + db_clause - ) - - def query_external_table_schema(self, path: DbPath) -> Dict[str, tuple]: - rows = self.query(self.select_external_table_schema(path), list) - if not rows: - raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns") - - d = {r[0]: r for r in rows} - assert len(d) == len(rows) - return d - - def select_view_columns(self, path: DbPath) -> str: - _, schema, table = self._normalize_table_path(path) - - return """select * from pg_get_cols('{}.{}') - cols(view_schema name, view_name name, col_name name, col_type varchar, col_num int) - """.format( - schema, table - ) - - def query_pg_get_cols(self, path: DbPath) -> Dict[str, tuple]: - rows = self.query(self.select_view_columns(path), list) - - if not rows: - raise RuntimeError(f"{self.name}: View 
'{'.'.join(path)}' does not exist, or has no columns") - - output = {} - for r in rows: - col_name = r[2] - type_info = r[3].split("(") - base_type = type_info[0] - precision = None - scale = None - - if len(type_info) > 1: - if base_type == "numeric": - precision, scale = type_info[1][:-1].split(",") - precision = int(precision) - scale = int(scale) - - out = [col_name, base_type, None, precision, scale] - output[col_name] = tuple(out) - - return output - - def query_table_schema(self, path: DbPath) -> Dict[str, tuple]: - try: - return super().query_table_schema(path) - except RuntimeError: - try: - return self.query_external_table_schema(path) - except RuntimeError: - return self.query_pg_get_cols(path) - - def _normalize_table_path(self, path: DbPath) -> DbPath: - if len(path) == 1: - return None, self.default_schema, path[0] - elif len(path) == 2: - return None, path[0], path[1] - elif len(path) == 3: - return path - - raise ValueError( - f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected format: table, schema.table, or database.schema.table" - ) diff --git a/data_diff/sqeleton/databases/snowflake.py b/data_diff/sqeleton/databases/snowflake.py deleted file mode 100644 index 6868f52f..00000000 --- a/data_diff/sqeleton/databases/snowflake.py +++ /dev/null @@ -1,228 +0,0 @@ -from typing import Union, List -import logging - -from data_diff.sqeleton.abcs.database_types import ( - Timestamp, - TimestampTZ, - Decimal, - Float, - Text, - FractionalType, - TemporalType, - DbPath, - Boolean, - Date, -) -from data_diff.sqeleton.abcs.mixins import ( - AbstractMixin_MD5, - AbstractMixin_NormalizeValue, - AbstractMixin_Schema, - AbstractMixin_TimeTravel, -) -from data_diff.sqeleton.abcs import Compilable -from data_diff.sqeleton.queries import table, this, SKIP, code -from data_diff.sqeleton.databases.base import ( - BaseDialect, - ConnectError, - Database, - import_helper, - CHECKSUM_MASK, - ThreadLocalInterpreter, - Mixin_RandomSample, -) - - -@import_helper("snowflake") -def import_snowflake(): - import snowflake.connector - from cryptography.hazmat.primitives import serialization - from cryptography.hazmat.backends import default_backend - - return snowflake, serialization, default_backend - - -class Mixin_MD5(AbstractMixin_MD5): - def md5_as_int(self, s: str) -> str: - return f"BITAND(md5_number_lower64({s}), {CHECKSUM_MASK})" - - -class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): - def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: - if coltype.rounds: - timestamp = f"to_timestamp(round(date_part(epoch_nanosecond, convert_timezone('UTC', {value})::timestamp(9))/1000000000, {coltype.precision}))" - else: - timestamp = f"cast(convert_timezone('UTC', {value}) as timestamp({coltype.precision}))" - - return f"to_char({timestamp}, 'YYYY-MM-DD HH24:MI:SS.FF6')" - - def normalize_number(self, value: str, coltype: FractionalType) -> str: - return self.to_string(f"cast({value} as decimal(38, {coltype.precision}))") - - def normalize_boolean(self, value: str, _coltype: Boolean) -> str: - return self.to_string(f"{value}::int") - - -class Mixin_Schema(AbstractMixin_Schema): - def table_information(self) -> Compilable: - return table("INFORMATION_SCHEMA", "TABLES") - - def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: - return ( - self.table_information() - .where( - this.TABLE_SCHEMA == table_schema, - this.TABLE_NAME.like(like) if like is not None else SKIP, - this.TABLE_TYPE == "BASE TABLE", - ) - 
.select(table_name=this.TABLE_NAME) - ) - - -class Mixin_TimeTravel(AbstractMixin_TimeTravel): - def time_travel( - self, - table: Compilable, - before: bool = False, - timestamp: Compilable = None, - offset: Compilable = None, - statement: Compilable = None, - ) -> Compilable: - at_or_before = "AT" if before else "BEFORE" - if timestamp is not None: - assert offset is None and statement is None - key = "timestamp" - value = timestamp - elif offset is not None: - assert statement is None - key = "offset" - value = offset - else: - assert statement is not None - key = "statement" - value = statement - - return code(f"{{table}} {at_or_before}({key} => {{value}})", table=table, value=value) - - -class Dialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): - name = "Snowflake" - ROUNDS_ON_PREC_LOSS = False - TYPE_CLASSES = { - # Timestamps - "TIMESTAMP_NTZ": Timestamp, - "TIMESTAMP_LTZ": Timestamp, - "TIMESTAMP_TZ": TimestampTZ, - "DATE": Date, - # Numbers - "NUMBER": Decimal, - "FLOAT": Float, - # Text - "TEXT": Text, - # Boolean - "BOOLEAN": Boolean, - } - MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_TimeTravel, Mixin_RandomSample} - - def explain_as_text(self, query: str) -> str: - return f"EXPLAIN USING TEXT {query}" - - def quote(self, s: str): - return f'"{s}"' - - def to_string(self, s: str): - return f"cast({s} as string)" - - def table_information(self) -> Compilable: - return table("INFORMATION_SCHEMA", "TABLES") - - def set_timezone_to_utc(self) -> str: - return "ALTER SESSION SET TIMEZONE = 'UTC'" - - def optimizer_hints(self, hints: str) -> str: - raise NotImplementedError("Optimizer hints not yet implemented in snowflake") - - def type_repr(self, t) -> str: - if isinstance(t, TimestampTZ): - return f"timestamp_tz({t.precision})" - return super().type_repr(t) - - -class Snowflake(Database): - dialect = Dialect() - CONNECT_URI_HELP = "snowflake://:@//?warehouse=" - CONNECT_URI_PARAMS = ["database", "schema"] - CONNECT_URI_KWPARAMS = ["warehouse"] - - def __init__(self, *, schema: str, **kw): - snowflake, serialization, default_backend = import_snowflake() - logging.getLogger("snowflake.connector").setLevel(logging.WARNING) - - # Ignore the error: snowflake.connector.network.RetryRequest: could not find io module state - # It's a known issue: https://github.com/snowflakedb/snowflake-connector-python/issues/145 - logging.getLogger("snowflake.connector.network").disabled = True - - assert '"' not in schema, "Schema name should not contain quotes!" - # If a private key is used, read it from the specified path and pass it as "private_key" to the connector. 
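        # Hypothetical invocation illustrating how the key-pair options arrive
        # in **kw ("key" and "private_key_passphrase" are consumed here; all
        # other keywords are forwarded to snowflake.connector.connect()):
        #
        #     Snowflake(account="...", user="...", warehouse="...", schema="PUBLIC",
        #               key="/path/to/rsa_key.p8", private_key_passphrase="...")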
- if "key" in kw: - with open(kw.get("key"), "rb") as key: - if "password" in kw: - raise ConnectError("Cannot use password and key at the same time") - if kw.get("private_key_passphrase"): - encoded_passphrase = kw.get("private_key_passphrase").encode() - else: - encoded_passphrase = None - p_key = serialization.load_pem_private_key( - key.read(), - password=encoded_passphrase, - backend=default_backend(), - ) - - kw["private_key"] = p_key.private_bytes( - encoding=serialization.Encoding.DER, - format=serialization.PrivateFormat.PKCS8, - encryption_algorithm=serialization.NoEncryption(), - ) - - self._conn = snowflake.connector.connect(schema=f'"{schema}"', **kw) - - self.default_schema = schema - - def close(self): - super().close() - self._conn.close() - - def _query(self, sql_code: Union[str, ThreadLocalInterpreter]): - "Uses the standard SQL cursor interface" - return self._query_conn(self._conn, sql_code) - - def select_table_schema(self, path: DbPath) -> str: - """Provide SQL for selecting the table schema as (name, type, date_prec, num_prec)""" - database, schema, name = self._normalize_table_path(path) - info_schema_path = ["information_schema", "columns"] - if database: - info_schema_path.insert(0, database) - - return ( - "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale " - f"FROM {'.'.join(info_schema_path)} " - f"WHERE table_name = '{name}' AND table_schema = '{schema}'" - ) - - def _normalize_table_path(self, path: DbPath) -> DbPath: - if len(path) == 1: - return None, self.default_schema, path[0] - elif len(path) == 2: - return None, path[0], path[1] - elif len(path) == 3: - return path - - raise ValueError( - f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected format: table, schema.table, or database.schema.table" - ) - - @property - def is_autocommit(self) -> bool: - return True - - def query_table_unique_columns(self, path: DbPath) -> List[str]: - return [] diff --git a/data_diff/sqeleton/databases/trino.py b/data_diff/sqeleton/databases/trino.py deleted file mode 100644 index a255b9a7..00000000 --- a/data_diff/sqeleton/databases/trino.py +++ /dev/null @@ -1,48 +0,0 @@ -from data_diff.sqeleton.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue -from data_diff.sqeleton.abcs.database_types import TemporalType, ColType_UUID -from data_diff.sqeleton.databases import presto -from data_diff.sqeleton.databases.base import import_helper -from data_diff.sqeleton.databases.base import TIMESTAMP_PRECISION_POS - - -@import_helper("trino") -def import_trino(): - import trino - - return trino - - -Mixin_MD5 = presto.Mixin_MD5 - - -class Mixin_NormalizeValue(presto.Mixin_NormalizeValue): - def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: - if coltype.rounds: - s = f"date_format(cast({value} as timestamp({coltype.precision})), '%Y-%m-%d %H:%i:%S.%f')" - else: - s = f"date_format(cast({value} as timestamp(6)), '%Y-%m-%d %H:%i:%S.%f')" - - return ( - f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS + coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS + 6}, '0')" - ) - - def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: - return f"TRIM({value})" - - -class Dialect(presto.Dialect, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): - name = "Trino" - - -class Trino(presto.Presto): - dialect = Dialect() - CONNECT_URI_HELP = "trino://@//" - CONNECT_URI_PARAMS = ["catalog", "schema"] - - def __init__(self, **kw): - trino = import_trino() - - if 
kw.get("schema"): - self.default_schema = kw.get("schema") - - self._conn = trino.dbapi.connect(**kw) diff --git a/data_diff/sqeleton/databases/vertica.py b/data_diff/sqeleton/databases/vertica.py deleted file mode 100644 index 9642ff7c..00000000 --- a/data_diff/sqeleton/databases/vertica.py +++ /dev/null @@ -1,181 +0,0 @@ -from typing import List - -from data_diff.utils import match_regexps -from data_diff.sqeleton.databases.base import ( - CHECKSUM_HEXDIGITS, - MD5_HEXDIGITS, - TIMESTAMP_PRECISION_POS, - BaseDialect, - ConnectError, - DbPath, - ColType, - ThreadedDatabase, - import_helper, - Mixin_RandomSample, -) -from data_diff.sqeleton.abcs.database_types import ( - Decimal, - Float, - FractionalType, - Integer, - TemporalType, - Text, - Timestamp, - TimestampTZ, - Boolean, - ColType_UUID, -) -from data_diff.sqeleton.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue, AbstractMixin_Schema -from data_diff.sqeleton.abcs import Compilable -from data_diff.sqeleton.queries import table, this, SKIP - - -@import_helper("vertica") -def import_vertica(): - import vertica_python - - return vertica_python - - -class Mixin_MD5(AbstractMixin_MD5): - def md5_as_int(self, s: str) -> str: - return f"CAST(HEX_TO_INTEGER(SUBSTRING(MD5({s}), {1 + MD5_HEXDIGITS - CHECKSUM_HEXDIGITS})) AS NUMERIC(38, 0))" - - -class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): - def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: - if coltype.rounds: - return f"TO_CHAR({value}::TIMESTAMP({coltype.precision}), 'YYYY-MM-DD HH24:MI:SS.US')" - - timestamp6 = f"TO_CHAR({value}::TIMESTAMP(6), 'YYYY-MM-DD HH24:MI:SS.US')" - return ( - f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" - ) - - def normalize_number(self, value: str, coltype: FractionalType) -> str: - return self.to_string(f"CAST({value} AS NUMERIC(38, {coltype.precision}))") - - def normalize_uuid(self, value: str, _coltype: ColType_UUID) -> str: - # Trim doesn't work on CHAR type - return f"TRIM(CAST({value} AS VARCHAR))" - - def normalize_boolean(self, value: str, _coltype: Boolean) -> str: - return self.to_string(f"cast ({value} as int)") - - -class Mixin_Schema(AbstractMixin_Schema): - def table_information(self) -> Compilable: - return table("v_catalog", "tables") - - def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: - return ( - self.table_information() - .where( - this.table_schema == table_schema, - this.table_name.like(like) if like is not None else SKIP, - ) - .select(this.table_name) - ) - - -class Dialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): - name = "Vertica" - ROUNDS_ON_PREC_LOSS = True - - TYPE_CLASSES = { - # Timestamps - "timestamp": Timestamp, - "timestamptz": TimestampTZ, - # Numbers - "numeric": Decimal, - "int": Integer, - "float": Float, - # Text - "char": Text, - "varchar": Text, - # Boolean - "boolean": Boolean, - } - MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample} - - def quote(self, s: str): - return f'"{s}"' - - def concat(self, items: List[str]) -> str: - return " || ".join(items) - - def to_string(self, s: str) -> str: - return f"CAST({s} AS VARCHAR)" - - def is_distinct_from(self, a: str, b: str) -> str: - return f"not ({a} <=> {b})" - - def parse_type( - self, - table_path: DbPath, - col_name: str, - type_repr: str, - datetime_precision: int = None, - numeric_precision: int = None, - numeric_scale: int = 
diff --git a/data_diff/sqeleton/databases/vertica.py b/data_diff/sqeleton/databases/vertica.py
deleted file mode 100644
index 9642ff7c..00000000
--- a/data_diff/sqeleton/databases/vertica.py
+++ /dev/null
@@ -1,181 +0,0 @@
-from typing import List
-
-from data_diff.utils import match_regexps
-from data_diff.sqeleton.databases.base import (
-    CHECKSUM_HEXDIGITS,
-    MD5_HEXDIGITS,
-    TIMESTAMP_PRECISION_POS,
-    BaseDialect,
-    ConnectError,
-    DbPath,
-    ColType,
-    ThreadedDatabase,
-    import_helper,
-    Mixin_RandomSample,
-)
-from data_diff.sqeleton.abcs.database_types import (
-    Decimal,
-    Float,
-    FractionalType,
-    Integer,
-    TemporalType,
-    Text,
-    Timestamp,
-    TimestampTZ,
-    Boolean,
-    ColType_UUID,
-)
-from data_diff.sqeleton.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue, AbstractMixin_Schema
-from data_diff.sqeleton.abcs import Compilable
-from data_diff.sqeleton.queries import table, this, SKIP
-
-
-@import_helper("vertica")
-def import_vertica():
-    import vertica_python
-
-    return vertica_python
-
-
-class Mixin_MD5(AbstractMixin_MD5):
-    def md5_as_int(self, s: str) -> str:
-        return f"CAST(HEX_TO_INTEGER(SUBSTRING(MD5({s}), {1 + MD5_HEXDIGITS - CHECKSUM_HEXDIGITS})) AS NUMERIC(38, 0))"
-
-
-class Mixin_NormalizeValue(AbstractMixin_NormalizeValue):
-    def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
-        if coltype.rounds:
-            return f"TO_CHAR({value}::TIMESTAMP({coltype.precision}), 'YYYY-MM-DD HH24:MI:SS.US')"
-
-        timestamp6 = f"TO_CHAR({value}::TIMESTAMP(6), 'YYYY-MM-DD HH24:MI:SS.US')"
-        return (
-            f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')"
-        )
-
-    def normalize_number(self, value: str, coltype: FractionalType) -> str:
-        return self.to_string(f"CAST({value} AS NUMERIC(38, {coltype.precision}))")
-
-    def normalize_uuid(self, value: str, _coltype: ColType_UUID) -> str:
-        # Trim doesn't work on CHAR type
-        return f"TRIM(CAST({value} AS VARCHAR))"
-
-    def normalize_boolean(self, value: str, _coltype: Boolean) -> str:
-        return self.to_string(f"cast ({value} as int)")
-
-
-class Mixin_Schema(AbstractMixin_Schema):
-    def table_information(self) -> Compilable:
-        return table("v_catalog", "tables")
-
-    def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable:
-        return (
-            self.table_information()
-            .where(
-                this.table_schema == table_schema,
-                this.table_name.like(like) if like is not None else SKIP,
-            )
-            .select(this.table_name)
-        )
-
-
-class Dialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue):
-    name = "Vertica"
-    ROUNDS_ON_PREC_LOSS = True
-
-    TYPE_CLASSES = {
-        # Timestamps
-        "timestamp": Timestamp,
-        "timestamptz": TimestampTZ,
-        # Numbers
-        "numeric": Decimal,
-        "int": Integer,
-        "float": Float,
-        # Text
-        "char": Text,
-        "varchar": Text,
-        # Boolean
-        "boolean": Boolean,
-    }
-    MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample}
-
-    def quote(self, s: str):
-        return f'"{s}"'
-
-    def concat(self, items: List[str]) -> str:
-        return " || ".join(items)
-
-    def to_string(self, s: str) -> str:
-        return f"CAST({s} AS VARCHAR)"
-
-    def is_distinct_from(self, a: str, b: str) -> str:
-        return f"not ({a} <=> {b})"
-
-    def parse_type(
-        self,
-        table_path: DbPath,
-        col_name: str,
-        type_repr: str,
-        datetime_precision: int = None,
-        numeric_precision: int = None,
-        numeric_scale: int = None,
-    ) -> ColType:
-        timestamp_regexps = {
-            r"timestamp\(?(\d?)\)?": Timestamp,
-            r"timestamptz\(?(\d?)\)?": TimestampTZ,
-        }
-        for m, t_cls in match_regexps(timestamp_regexps, type_repr):
-            precision = int(m.group(1)) if m.group(1) else 6
-            return t_cls(precision=precision, rounds=self.ROUNDS_ON_PREC_LOSS)
-
-        number_regexps = {
-            r"numeric\((\d+),(\d+)\)": Decimal,
-        }
-        for m, n_cls in match_regexps(number_regexps, type_repr):
-            _prec, scale = map(int, m.groups())
-            return n_cls(scale)
-
-        string_regexps = {
-            r"varchar\((\d+)\)": Text,
-            r"char\((\d+)\)": Text,
-        }
-        for m, n_cls in match_regexps(string_regexps, type_repr):
-            return n_cls()
-
-        return super().parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision)
-
-    def set_timezone_to_utc(self) -> str:
-        return "SET TIME ZONE TO 'UTC'"
-
-    def current_timestamp(self) -> str:
-        return "current_timestamp(6)"
-
-
-class Vertica(ThreadedDatabase):
-    dialect = Dialect()
-    CONNECT_URI_HELP = "vertica://<user>:<password>@<host>/<database>"
-    CONNECT_URI_PARAMS = ["database?"]
-
-    default_schema = "public"
-
-    def __init__(self, *, thread_count, **kw):
-        self._args = kw
-        self._args["AUTOCOMMIT"] = False
-
-        super().__init__(thread_count=thread_count)
-
-    def create_connection(self):
-        vertica = import_vertica()
-        try:
-            c = vertica.connect(**self._args)
-            return c
-        except vertica.errors.ConnectionError as e:
-            raise ConnectError(*e.args) from e
-
-    def select_table_schema(self, path: DbPath) -> str:
-        schema, name = self._normalize_table_path(path)
-
-        return (
-            "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale "
-            "FROM V_CATOLOG.COLUMNS "
-            f"WHERE table_name = '{name}' AND table_schema = '{schema}'"
-        )
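Vertica's parse_type above resolves a raw type string by running it through small regex tables via match_regexps. A self-contained sketch of that dispatch; the local match_regexps below is an assumed re-implementation for illustration (the real helper lives in data_diff.utils, and its full-string matching behavior is inferred from how parse_type uses it):

    import re

    def match_regexps(regexps: dict, s: str):
        # Assumed behavior: yield (match, value) for every pattern that matches s in full.
        for pattern, value in regexps.items():
            m = re.fullmatch(pattern, s)
            if m:
                yield m, value

    timestamp_regexps = {
        r"timestamp\(?(\d?)\)?": "Timestamp",
        r"timestamptz\(?(\d?)\)?": "TimestampTZ",
    }
    for m, cls_name in match_regexps(timestamp_regexps, "timestamptz(3)"):
        # Falls back to precision 6 when the type string carries no digit.
        print(cls_name, int(m.group(1)) if m.group(1) else 6)  # -> TimestampTZ 3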
diff --git a/data_diff/sqeleton/queries/__init__.py b/data_diff/sqeleton/queries/__init__.py
deleted file mode 100644
index f1eea7b1..00000000
--- a/data_diff/sqeleton/queries/__init__.py
+++ /dev/null
@@ -1,25 +0,0 @@
-from data_diff.sqeleton.queries.compiler import Compiler, CompileError
-from data_diff.sqeleton.queries.api import (
-    this,
-    join,
-    outerjoin,
-    table,
-    SKIP,
-    sum_,
-    avg,
-    min_,
-    max_,
-    cte,
-    commit,
-    when,
-    coalesce,
-    and_,
-    if_,
-    or_,
-    leftjoin,
-    rightjoin,
-    current_timestamp,
-    code,
-)
-from data_diff.sqeleton.queries.ast_classes import Expr, ExprNode, Select, Count, BinOp, Explain, In, Code, Column
-from data_diff.sqeleton.queries.extras import Checksum, NormalizeAsString, ApplyFuncAndNormalizeAsString
diff --git a/data_diff/table_segment.py b/data_diff/table_segment.py
index 8568ffc4..9864824a 100644
--- a/data_diff/table_segment.py
+++ b/data_diff/table_segment.py
@@ -8,10 +8,12 @@
 
 from data_diff.utils import safezip, Vector
 from data_diff.utils import ArithString, split_space
-from data_diff.sqeleton.databases import Database, DbPath, DbKey, DbTime
-from data_diff.sqeleton.schema import Schema, create_schema
-from data_diff.sqeleton.queries import Count, Checksum, SKIP, table, this, Expr, min_, max_, Code
-from data_diff.sqeleton.queries.extras import ApplyFuncAndNormalizeAsString, NormalizeAsString
+from data_diff.databases.base import Database
+from data_diff.abcs.database_types import DbPath, DbKey, DbTime
+from data_diff.schema import Schema, create_schema
+from data_diff.queries.extras import Checksum
+from data_diff.queries.api import Count, SKIP, table, this, Expr, min_, max_, Code
+from data_diff.queries.extras import ApplyFuncAndNormalizeAsString, NormalizeAsString
 
 
 logger = logging.getLogger("table_segment")
diff --git a/data_diff/utils.py b/data_diff/utils.py
index 67d74dce..b725285e 100644
--- a/data_diff/utils.py
+++ b/data_diff/utils.py
@@ -1,8 +1,9 @@
 import json
 import logging
 import re
+import string
 from abc import abstractmethod
-from typing import Any, Dict, Iterable, Iterator, MutableMapping, Sequence, TypeVar
+from typing import Any, Dict, Iterable, Iterator, List, MutableMapping, Sequence, TypeVar, Union
 from urllib.parse import urlparse
 import operator
 import threading
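With the sqeleton helpers folded into data_diff/utils.py, downstream imports collapse to a single module, as the table_segment.py hunk above already shows for safezip. For example:

    from data_diff.utils import safezip

    # safezip behaves like zip but is expected to reject length mismatches,
    # e.g. list(safezip([1, 2], ["a", "b"])) == [(1, "a"), (2, "b")],
    # while safezip([1], [1, 2]) should raise instead of silently truncating.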
diff --git a/tests/common.py b/tests/common.py
index e434eaa7..222ae94b 100644
--- a/tests/common.py
+++ b/tests/common.py
@@ -9,8 +9,8 @@
 
 from parameterized import parameterized_class
 
-from data_diff.sqeleton.queries import table
-from data_diff.sqeleton.databases import Database
+from data_diff.queries.api import table
+from data_diff.databases.base import Database
 
 from data_diff import databases as db
 from data_diff import tracking
@@ -81,7 +81,7 @@ def get_git_revision_short_hash() -> str:
     db.Clickhouse: TEST_CLICKHOUSE_CONN_STRING,
     db.Vertica: TEST_VERTICA_CONN_STRING,
     db.DuckDB: TEST_DUCKDB_CONN_STRING,
-    db.MsSql: TEST_MSSQL_CONN_STRING,
+    db.MsSQL: TEST_MSSQL_CONN_STRING,
 }
 
 _database_instances = {}
diff --git a/tests/test_api.py b/tests/test_api.py
index d97f59a9..07a4d57d 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -1,8 +1,8 @@
 from datetime import datetime, timedelta
 
 from data_diff import diff_tables, connect_to_table, Algorithm
-from data_diff.databases import MySQL
-from data_diff.sqeleton.queries import table, commit
+from data_diff.databases.mysql import MySQL
+from data_diff.queries.api import table, commit
 
 from tests.common import TEST_MYSQL_CONN_STRING, get_conn, random_table_suffix, DiffTestCase
 
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 2e8111da..1fc4833e 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -3,7 +3,7 @@
 import sys
 from datetime import datetime, timedelta
 
-from data_diff.sqeleton.queries import commit, current_timestamp
+from data_diff.queries.api import commit, current_timestamp
 
 from tests.common import DiffTestCase, CONN_STRINGS
 from tests.test_diff_tables import test_each_database
diff --git a/tests/test_database.py b/tests/test_database.py
index 1b967cc8..4f4c8ce1 100644
--- a/tests/test_database.py
+++ b/tests/test_database.py
@@ -4,11 +4,12 @@
 
 import pytz
 
-from data_diff.sqeleton import connect
-from data_diff.sqeleton import databases as dbs
-from data_diff.sqeleton.queries import table, current_timestamp, NormalizeAsString
+from data_diff import connect
+from data_diff import databases as dbs
+from data_diff.queries.api import table, current_timestamp
+from data_diff.queries.extras import NormalizeAsString
 from tests.common import TEST_MYSQL_CONN_STRING, test_each_database_in_list, get_conn, str_to_checksum, random_table_suffix
-from data_diff.sqeleton.abcs.database_types import TimestampTZ
+from data_diff.abcs.database_types import TimestampTZ
 
 TEST_DATABASES = {
     dbs.MySQL,
diff --git a/tests/test_database_types.py b/tests/test_database_types.py
index 203731c4..3d345296 100644
--- a/tests/test_database_types.py
+++ b/tests/test_database_types.py
@@ -1,9 +1,7 @@
-from contextlib import suppress
 import unittest
 import time
 import json
 import re
-import rich.progress
 import math
 import uuid
 from datetime import datetime, timedelta, timezone
@@ -14,8 +12,8 @@
 from parameterized import parameterized
 
 from data_diff.utils import number_to_human
-from data_diff.sqeleton.queries import table, commit, this, Code
-from data_diff.sqeleton.queries.api import insert_rows_in_batches
+from data_diff.queries.api import table, commit, this, Code
+from data_diff.queries.api import insert_rows_in_batches
 
 from data_diff import databases as db
 from data_diff.query_utils import drop_table
@@ -351,7 +349,7 @@ def init_conns():
             "boolean",
         ],
     },
-    db.MsSql: {
+    db.MsSQL: {
         "int": ["INT", "BIGINT"],
         "datetime": ["datetime2(6)"],
         "float": ["DECIMAL(6, 2)", "FLOAT", "REAL"],
@@ -625,7 +623,7 @@ def _insert_to_table(conn, table_path, values, coltype):
             for i, sample in values
         ]
     # mssql represents with int
-    elif isinstance(conn, db.MsSql) and coltype in ("BIT"):
+    elif isinstance(conn, db.MsSQL) and coltype in ("BIT"):
        values = [(i, int(sample)) for i, sample in values]
 
     insert_rows_in_batches(conn, tbl, values, columns=["id", "col"])
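The type tests above keep loading data through insert_rows_in_batches, only from its new import path. A self-contained sketch of the batching idea behind such a helper; the function name chunk_rows and the batch size are illustrative assumptions, not the library's implementation:

    from typing import Iterator, List, Sequence, Tuple

    def chunk_rows(rows: Sequence[Tuple], batch_size: int = 1000) -> Iterator[List[Tuple]]:
        # Yield rows in fixed-size chunks; a real helper would issue one INSERT per chunk.
        for i in range(0, len(rows), batch_size):
            yield list(rows[i : i + batch_size])

    batches = list(chunk_rows([(i, str(i)) for i in range(2500)], batch_size=1000))
    assert [len(b) for b in batches] == [1000, 1000, 500]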
diff --git a/tests/test_diff_tables.py b/tests/test_diff_tables.py
index 052c48ce..b5885a26 100644
--- a/tests/test_diff_tables.py
+++ b/tests/test_diff_tables.py
@@ -3,7 +3,7 @@
 import uuid
 import unittest
 
-from data_diff.sqeleton.queries import table, this, commit, code
+from data_diff.queries.api import table, this, commit, code
 from data_diff.utils import ArithAlphanumeric, numberToAlphanum
 
 from data_diff.hashdiff_tables import HashDiffer
diff --git a/tests/test_format.py b/tests/test_format.py
index 0aa8ee8e..4743acc4 100644
--- a/tests/test_format.py
+++ b/tests/test_format.py
@@ -1,8 +1,8 @@
 import unittest
 from data_diff.diff_tables import DiffResultWrapper, InfoTree, SegmentInfo, TableSegment
 from data_diff.format import jsonify
-from data_diff.sqeleton.abcs.database_types import Integer
-from data_diff.sqeleton.databases import Database
+from data_diff.abcs.database_types import Integer
+from data_diff.databases.base import Database
 
 
 class TestFormat(unittest.TestCase):
diff --git a/tests/test_joindiff.py b/tests/test_joindiff.py
index 6a1559d7..b2c5c419 100644
--- a/tests/test_joindiff.py
+++ b/tests/test_joindiff.py
@@ -1,8 +1,8 @@
 from typing import List
 from datetime import datetime
 
-from data_diff.sqeleton.queries.ast_classes import TablePath
-from data_diff.sqeleton.queries import table, commit
+from data_diff.queries.ast_classes import TablePath
+from data_diff.queries.api import table, commit
 from data_diff.table_segment import TableSegment
 from data_diff import databases as db
 
 from data_diff.joindiff_tables import JoinDiffer
diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py
index 418f44fb..b5e9fa10 100644
--- a/tests/test_postgresql.py
+++ b/tests/test_postgresql.py
@@ -1,7 +1,6 @@
 import unittest
 
-from data_diff.sqeleton.queries import table, commit
-
+from data_diff.queries.api import table, commit
 from data_diff import TableSegment, HashDiffer
 from data_diff import databases as db
 from tests.common import get_conn, random_table_suffix
diff --git a/tests/test_query.py b/tests/test_query.py
index b1937028..cc11b533 100644
--- a/tests/test_query.py
+++ b/tests/test_query.py
@@ -1,12 +1,13 @@
 from datetime import datetime
 from typing import List, Optional
 import unittest
 
-from data_diff.sqeleton.abcs import AbstractDatabase, AbstractDialect
+from data_diff.abcs.database_types import AbstractDatabase, AbstractDialect
 from data_diff.utils import CaseInsensitiveDict, CaseSensitiveDict
-from data_diff.sqeleton.queries import this, table, Compiler, outerjoin, cte, when, coalesce, CompileError
-from data_diff.sqeleton.queries.ast_classes import Random
-from data_diff.sqeleton import code, this, table
+from data_diff.queries.compiler import Compiler, CompileError
+from data_diff.queries.api import outerjoin, cte, when, coalesce
+from data_diff.queries.ast_classes import Random
+from data_diff.queries.api import code, this, table
 
 
 def normalize_spaces(s: str):
diff --git a/tests/test_sql.py b/tests/test_sql.py
index d8e07046..2dcab403 100644
--- a/tests/test_sql.py
+++ b/tests/test_sql.py
@@ -2,8 +2,8 @@
 
 from tests.common import TEST_MYSQL_CONN_STRING
 
-from data_diff.sqeleton import connect
-from data_diff.sqeleton.queries import Compiler, Count, Explain, Select, table, In, BinOp, Code
+from data_diff.databases import connect
+from data_diff.queries.api import Compiler, Count, Explain, Select, table, In, BinOp, Code
 
 
 class TestSQL(unittest.TestCase):
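Taken together, the test changes exercise the consolidated query-builder paths. A short usage sketch against the new modules; the schema, table, and column names are placeholders, and rendering to SQL requires a dialect-bound Compiler, which tests/test_sql.py obtains via connect():

    from data_diff.queries.api import table, this

    # Build a query AST without touching a database.
    q = table("my_schema", "my_table").where(this.id > 100).select(this.id)
    # A concrete dialect's Compiler then renders it along the lines of:
    #   SELECT id FROM my_schema.my_table WHERE (id > 100)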