diff --git a/data_diff/__init__.py b/data_diff/__init__.py index bbdffb01..60c79b10 100644 --- a/data_diff/__init__.py +++ b/data_diff/__init__.py @@ -1,9 +1,8 @@ from typing import Sequence, Tuple, Iterator, Optional, Union -from data_diff.sqeleton.abcs import DbTime, DbPath - +from data_diff.abcs.database_types import DbTime, DbPath from data_diff.tracking import disable_tracking -from data_diff.databases import connect +from data_diff.databases._connect import connect from data_diff.diff_tables import Algorithm from data_diff.hashdiff_tables import HashDiffer, DEFAULT_BISECTION_THRESHOLD, DEFAULT_BISECTION_FACTOR from data_diff.joindiff_tables import JoinDiffer, TABLE_WRITE_LIMIT diff --git a/data_diff/__main__.py b/data_diff/__main__.py index 481c829f..77dc7fb6 100644 --- a/data_diff/__main__.py +++ b/data_diff/__main__.py @@ -12,8 +12,8 @@ from rich.logging import RichHandler import click -from data_diff.sqeleton.schema import create_schema -from data_diff.sqeleton.queries.api import current_timestamp +from data_diff.schema import create_schema +from data_diff.queries.api import current_timestamp from data_diff.dbt import dbt_diff from data_diff.utils import eval_name_template, remove_password_from_url, safezip, match_like, LogStatusHandler @@ -21,7 +21,7 @@ from data_diff.hashdiff_tables import HashDiffer, DEFAULT_BISECTION_THRESHOLD, DEFAULT_BISECTION_FACTOR from data_diff.joindiff_tables import TABLE_WRITE_LIMIT, JoinDiffer from data_diff.table_segment import TableSegment -from data_diff.databases import connect +from data_diff.databases._connect import connect from data_diff.parse_time import parse_time_before, UNITS_STR, ParseError from data_diff.config import apply_config_from_file from data_diff.tracking import disable_tracking, set_entrypoint_name diff --git a/data_diff/abcs/__init__.py b/data_diff/abcs/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/data_diff/sqeleton/abcs/compiler.py b/data_diff/abcs/compiler.py similarity index 100% rename from data_diff/sqeleton/abcs/compiler.py rename to data_diff/abcs/compiler.py diff --git a/data_diff/sqeleton/abcs/database_types.py b/data_diff/abcs/database_types.py similarity index 99% rename from data_diff/sqeleton/abcs/database_types.py rename to data_diff/abcs/database_types.py index 26909067..82ec8352 100644 --- a/data_diff/sqeleton/abcs/database_types.py +++ b/data_diff/abcs/database_types.py @@ -6,7 +6,7 @@ from runtype import dataclass from typing_extensions import Self -from data_diff.sqeleton.utils import ArithAlphanumeric, ArithUUID, Unknown +from data_diff.utils import ArithAlphanumeric, ArithUUID, Unknown DbPath = Tuple[str, ...] 
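
To make the import migration concrete, here is a minimal sketch of how downstream code updates once `data_diff.sqeleton.*` becomes `data_diff.*`; the DSN and table name below are placeholders:

```python
# Before: types and query helpers lived under the vendored sqeleton package.
# from data_diff.sqeleton.abcs import DbTime, DbPath
# from data_diff.sqeleton.queries.api import table

# After: the same names import from data_diff directly.
from data_diff.abcs.database_types import DbTime, DbPath
from data_diff.queries.api import table
from data_diff import connect  # the public entrypoint is unchanged

db = connect("postgresql://user:password@localhost/mydb")  # placeholder DSN
t = table(("public", "ratings"))  # same query-builder API, new module path
```
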
diff --git a/data_diff/sqeleton/abcs/mixins.py b/data_diff/abcs/mixins.py similarity index 98% rename from data_diff/sqeleton/abcs/mixins.py rename to data_diff/abcs/mixins.py index e33129a2..17f06064 100644 --- a/data_diff/sqeleton/abcs/mixins.py +++ b/data_diff/abcs/mixins.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from data_diff.sqeleton.abcs.database_types import ( +from data_diff.abcs.database_types import ( Array, TemporalType, FractionalType, @@ -10,7 +10,7 @@ JSON, Struct, ) -from data_diff.sqeleton.abcs.compiler import Compilable +from data_diff.abcs.compiler import Compilable class AbstractMixin(ABC): diff --git a/data_diff/sqeleton/bound_exprs.py b/data_diff/bound_exprs.py similarity index 85% rename from data_diff/sqeleton/bound_exprs.py rename to data_diff/bound_exprs.py index 8bbb3063..1742b74c 100644 --- a/data_diff/sqeleton/bound_exprs.py +++ b/data_diff/bound_exprs.py @@ -7,10 +7,11 @@ from runtype import dataclass from typing_extensions import Self -from data_diff.sqeleton.abcs import AbstractDatabase, AbstractCompiler -from data_diff.sqeleton.queries.ast_classes import ExprNode, ITable, TablePath, Compilable -from data_diff.sqeleton.queries.api import table -from data_diff.sqeleton.schema import create_schema +from data_diff.abcs.database_types import AbstractDatabase +from data_diff.abcs.compiler import AbstractCompiler +from data_diff.queries.ast_classes import ExprNode, TablePath, Compilable +from data_diff.queries.api import table +from data_diff.schema import create_schema @dataclass @@ -80,8 +81,8 @@ def bound_table(database: AbstractDatabase, table_path: Union[TablePath, str, tu # Database.table = bound_table # def test(): -# from data_diff.sqeleton. import connect -# from data_diff.sqeleton.queries.api import table +# from data_diff import connect +# from data_diff.queries.api import table # d = connect("mysql://erez:qweqwe123@localhost/erez") # t = table(('Rating',)) diff --git a/data_diff/databases/__init__.py b/data_diff/databases/__init__.py index cae67d1e..842cc731 100644 --- a/data_diff/databases/__init__.py +++ b/data_diff/databases/__init__.py @@ -1,17 +1,16 @@ -from data_diff.sqeleton.databases import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, QueryError, ConnectError - -from data_diff.databases.postgresql import PostgreSQL -from data_diff.databases.mysql import MySQL -from data_diff.databases.oracle import Oracle -from data_diff.databases.snowflake import Snowflake -from data_diff.databases.bigquery import BigQuery -from data_diff.databases.redshift import Redshift -from data_diff.databases.presto import Presto -from data_diff.databases.databricks import Databricks -from data_diff.databases.trino import Trino -from data_diff.databases.clickhouse import Clickhouse -from data_diff.databases.vertica import Vertica -from data_diff.databases.duckdb import DuckDB -from data_diff.databases.mssql import MsSql - -from data_diff.databases._connect import connect +from data_diff.databases.base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, QueryError, ConnectError, BaseDialect, Database +from data_diff.databases._connect import connect as connect +from data_diff.databases._connect import Connect as Connect +from data_diff.databases.postgresql import PostgreSQL as PostgreSQL +from data_diff.databases.mysql import MySQL as MySQL +from data_diff.databases.oracle import Oracle as Oracle +from data_diff.databases.snowflake import Snowflake as Snowflake +from data_diff.databases.bigquery import BigQuery as BigQuery +from data_diff.databases.redshift import Redshift as 
Redshift +from data_diff.databases.presto import Presto as Presto +from data_diff.databases.databricks import Databricks as Databricks +from data_diff.databases.trino import Trino as Trino +from data_diff.databases.clickhouse import Clickhouse as Clickhouse +from data_diff.databases.vertica import Vertica as Vertica +from data_diff.databases.duckdb import DuckDB as DuckDB +from data_diff.databases.mssql import MsSQL as MsSQL diff --git a/data_diff/databases/_connect.py b/data_diff/databases/_connect.py index fcef1069..8f842123 100644 --- a/data_diff/databases/_connect.py +++ b/data_diff/databases/_connect.py @@ -1,7 +1,15 @@ import logging +from typing import Hashable, MutableMapping, Type, Optional, Union, Dict +from itertools import zip_longest +from contextlib import suppress +import weakref +import dsnparse +import toml -from data_diff.sqeleton.databases import Connect +from runtype import dataclass +from typing_extensions import Self +from data_diff.databases.base import Database, ThreadedDatabase from data_diff.databases.postgresql import PostgreSQL from data_diff.databases.mysql import MySQL from data_diff.databases.oracle import Oracle @@ -14,7 +22,57 @@ from data_diff.databases.clickhouse import Clickhouse from data_diff.databases.vertica import Vertica from data_diff.databases.duckdb import DuckDB -from data_diff.databases.mssql import MsSql +from data_diff.databases.mssql import MsSQL + + +@dataclass +class MatchUriPath: + database_cls: Type[Database] + + def match_path(self, dsn): + help_str = self.database_cls.CONNECT_URI_HELP + params = self.database_cls.CONNECT_URI_PARAMS + kwparams = self.database_cls.CONNECT_URI_KWPARAMS + + dsn_dict = dict(dsn.query) + matches = {} + for param, arg in zip_longest(params, dsn.paths): + if param is None: + raise ValueError(f"Too many parts to path. Expected format: {help_str}") + + optional = param.endswith("?") + param = param.rstrip("?") + + if arg is None: + try: + arg = dsn_dict.pop(param) + except KeyError: + if not optional: + raise ValueError(f"URI must specify '{param}'. Expected format: {help_str}") + + arg = None + + assert param and param not in matches + matches[param] = arg + + for param in kwparams: + try: + arg = dsn_dict.pop(param) + except KeyError: + raise ValueError(f"URI must specify '{param}'. Expected format: {help_str}") + + assert param and arg and param not in matches, (param, arg, matches.keys()) + matches[param] = arg + + for param, value in dsn_dict.items(): + if param in matches: + raise ValueError( + f"Parameter '{param}' already provided as positional argument. 
Expected format: {help_str}" + ) + + matches[param] = value + + return matches DATABASE_BY_SCHEME = { @@ -30,10 +88,201 @@ "trino": Trino, "clickhouse": Clickhouse, "vertica": Vertica, - "mssql": MsSql, + "mssql": MsSQL, } +class Connect: + """Provides methods for connecting to a supported database using a URL or connection dict.""" + conn_cache: MutableMapping[Hashable, Database] + + def __init__(self, database_by_scheme: Dict[str, Database] = DATABASE_BY_SCHEME): + self.database_by_scheme = database_by_scheme + self.match_uri_path = {name: MatchUriPath(cls) for name, cls in database_by_scheme.items()} + self.conn_cache = weakref.WeakValueDictionary() + + def for_databases(self, *dbs) -> Self: + database_by_scheme = {k: db for k, db in self.database_by_scheme.items() if k in dbs} + return type(self)(database_by_scheme) + + def connect_to_uri(self, db_uri: str, thread_count: Optional[int] = 1, **kwargs) -> Database: + """Connect to the given database uri + + thread_count determines the max number of worker threads per database, + if relevant. None means no limit. + + Parameters: + db_uri (str): The URI for the database to connect + thread_count (int, optional): Size of the threadpool. Ignored by cloud databases. (default: 1) + + Note: For non-cloud databases, a low thread-pool size may be a performance bottleneck. + + Supported schemes: + - postgresql + - mysql + - oracle + - snowflake + - bigquery + - redshift + - presto + - databricks + - trino + - clickhouse + - vertica + - duckdb + """ + + dsn = dsnparse.parse(db_uri) + if len(dsn.schemes) > 1: + raise NotImplementedError("No support for multiple schemes") + (scheme,) = dsn.schemes + + if scheme == "toml": + toml_path = dsn.path or dsn.host + database = dsn.fragment + if not database: + raise ValueError("Must specify a database name, e.g. 'toml://path#database'. 
") + with open(toml_path) as f: + config = toml.load(f) + try: + conn_dict = config["database"][database] + except KeyError: + raise ValueError(f"Cannot find database config named '{database}'.") + return self.connect_with_dict(conn_dict, thread_count, **kwargs) + + try: + matcher = self.match_uri_path[scheme] + except KeyError: + raise NotImplementedError(f"Scheme '{scheme}' currently not supported") + + cls = matcher.database_cls + + if scheme == "databricks": + assert not dsn.user + kw = {} + kw["access_token"] = dsn.password + kw["http_path"] = dsn.path + kw["server_hostname"] = dsn.host + kw.update(dsn.query) + elif scheme == "duckdb": + kw = {} + kw["filepath"] = dsn.dbname + kw["dbname"] = dsn.user + else: + kw = matcher.match_path(dsn) + + if scheme == "bigquery": + kw["project"] = dsn.host + return cls(**kw, **kwargs) + + if scheme == "snowflake": + kw["account"] = dsn.host + assert not dsn.port + kw["user"] = dsn.user + kw["password"] = dsn.password + else: + if scheme == "oracle": + kw["host"] = dsn.hostloc + else: + kw["host"] = dsn.host + kw["port"] = dsn.port + kw["user"] = dsn.user + if dsn.password: + kw["password"] = dsn.password + + kw = {k: v for k, v in kw.items() if v is not None} + + if issubclass(cls, ThreadedDatabase): + db = cls(thread_count=thread_count, **kw, **kwargs) + else: + db = cls(**kw, **kwargs) + + return self._connection_created(db) + + def connect_with_dict(self, d, thread_count, **kwargs): + d = dict(d) + driver = d.pop("driver") + try: + matcher = self.match_uri_path[driver] + except KeyError: + raise NotImplementedError(f"Driver '{driver}' currently not supported") + + cls = matcher.database_cls + if issubclass(cls, ThreadedDatabase): + db = cls(thread_count=thread_count, **d, **kwargs) + else: + db = cls(**d, **kwargs) + + return self._connection_created(db) + + def _connection_created(self, db): + "Nop function to be overridden by subclasses." + return db + + def __call__( + self, db_conf: Union[str, dict], thread_count: Optional[int] = 1, shared: bool = True, **kwargs + ) -> Database: + """Connect to a database using the given database configuration. + + Configuration can be given either as a URI string, or as a dict of {option: value}. + + The dictionary configuration uses the same keys as the TOML 'database' definition given with --conf. + + thread_count determines the max number of worker threads per database, + if relevant. None means no limit. + + Parameters: + db_conf (str | dict): The configuration for the database to connect. URI or dict. + thread_count (int, optional): Size of the threadpool. Ignored by cloud databases. (default: 1) + shared (bool): Whether to cache and return the same connection for the same db_conf. (default: True) + bigquery_credentials (google.oauth2.credentials.Credentials): Custom Google oAuth2 credential for BigQuery. + (default: None) + + Note: For non-cloud databases, a low thread-pool size may be a performance bottleneck. 
+ + Supported drivers: + - postgresql + - mysql + - oracle + - snowflake + - bigquery + - redshift + - presto + - databricks + - trino + - clickhouse + - vertica + + Example: + >>> connect("mysql://localhost/db") + + >>> connect({"driver": "mysql", "host": "localhost", "database": "db"}) + + """ + cache_key = self.__make_cache_key(db_conf) + if shared: + with suppress(KeyError): + conn = self.conn_cache[cache_key] + if not conn.is_closed: + return conn + + if isinstance(db_conf, str): + conn = self.connect_to_uri(db_conf, thread_count, **kwargs) + elif isinstance(db_conf, dict): + conn = self.connect_with_dict(db_conf, thread_count, **kwargs) + else: + raise TypeError(f"db configuration must be a URI string or a dictionary. Instead got '{db_conf}'.") + + if shared: + self.conn_cache[cache_key] = conn + return conn + + def __make_cache_key(self, db_conf: Union[str, dict]) -> Hashable: + if isinstance(db_conf, dict): + return tuple(db_conf.items()) + return db_conf + + class Connect_SetUTC(Connect): """Provides methods for connecting to a supported database using a URL or connection dict. diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index 5b7ff5ce..a89ab74e 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -1,5 +1,607 @@ -from data_diff.sqeleton.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue +from datetime import datetime +import math +import sys +import logging +from typing import Any, Callable, Dict, Generator, Tuple, Optional, Sequence, Type, List, Union, TypeVar +from functools import partial, wraps +from concurrent.futures import ThreadPoolExecutor +import threading +from abc import abstractmethod +from uuid import UUID +import decimal +from runtype import dataclass +from typing_extensions import Self -class DatadiffDialect(AbstractMixin_MD5, AbstractMixin_NormalizeValue): +from data_diff.utils import is_uuid, safezip +from data_diff.queries.api import Expr, Compiler, table, Select, SKIP, Explain, Code, this +from data_diff.queries.ast_classes import Random +from data_diff.abcs.database_types import ( + AbstractDatabase, + Array, + Struct, + AbstractDialect, + AbstractTable, + ColType, + Integer, + Decimal, + Float, + Native_UUID, + String_UUID, + String_Alphanum, + String_VaryingAlphanum, + TemporalType, + UnknownColType, + TimestampTZ, + Text, + DbTime, + DbPath, + Boolean, + JSON, +) +from data_diff.abcs.mixins import Compilable +from data_diff.abcs.mixins import ( + AbstractMixin_Schema, + AbstractMixin_RandomSample, + AbstractMixin_NormalizeValue, + AbstractMixin_OptimizerHints, +) +from data_diff.bound_exprs import bound_table + +logger = logging.getLogger("database") + + +def parse_table_name(t): + return tuple(t.split(".")) + + +def import_helper(package: str = None, text=""): + def dec(f): + @wraps(f) + def _inner(): + try: + return f() + except ModuleNotFoundError as e: + s = text + if package: + s += f"Please complete setup by running: pip install 'data_diff[{package}]'." + raise ModuleNotFoundError(f"{e}\n\n{s}\n") + + return _inner + + return dec + + +class ConnectError(Exception): + pass + + +class QueryError(Exception): pass + + +def _one(seq): + (x,) = seq + return x + + +class ThreadLocalInterpreter: + """An interpreter used to execute a sequence of queries within the same thread and cursor. + + Useful for cursor-sensitive operations, such as creating a temporary table. 
+ """ + + def __init__(self, compiler: Compiler, gen: Generator): + self.gen = gen + self.compiler = compiler + + def apply_queries(self, callback: Callable[[str], Any]): + q: Expr = next(self.gen) + while True: + sql = self.compiler.compile(q) + logger.debug("Running SQL (%s-TL): %s", self.compiler.database.name, sql) + try: + try: + res = callback(sql) if sql is not SKIP else SKIP + except Exception as e: + q = self.gen.throw(type(e), e) + else: + q = self.gen.send(res) + except StopIteration: + break + + +def apply_query(callback: Callable[[str], Any], sql_code: Union[str, ThreadLocalInterpreter]) -> list: + if isinstance(sql_code, ThreadLocalInterpreter): + return sql_code.apply_queries(callback) + else: + return callback(sql_code) + + +class Mixin_Schema(AbstractMixin_Schema): + def table_information(self) -> Compilable: + return table("information_schema", "tables") + + def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: + return ( + self.table_information() + .where( + this.table_schema == table_schema, + this.table_name.like(like) if like is not None else SKIP, + this.table_type == "BASE TABLE", + ) + .select(this.table_name) + ) + + +class Mixin_RandomSample(AbstractMixin_RandomSample): + def random_sample_n(self, tbl: AbstractTable, size: int) -> AbstractTable: + # TODO use a more efficient algorithm, when the table count is known + return tbl.order_by(Random()).limit(size) + + def random_sample_ratio_approx(self, tbl: AbstractTable, ratio: float) -> AbstractTable: + return tbl.where(Random() < ratio) + + +class Mixin_OptimizerHints(AbstractMixin_OptimizerHints): + def optimizer_hints(self, hints: str) -> str: + return f"/*+ {hints} */ " + + +class BaseDialect(AbstractDialect): + SUPPORTS_PRIMARY_KEY = False + SUPPORTS_INDEXES = False + TYPE_CLASSES: Dict[str, type] = {} + MIXINS = frozenset() + + PLACEHOLDER_TABLE = None # Used for Oracle + + def offset_limit( + self, offset: Optional[int] = None, limit: Optional[int] = None, has_order_by: Optional[bool] = None + ) -> str: + if offset: + raise NotImplementedError("No support for OFFSET in query") + + return f"LIMIT {limit}" + + def concat(self, items: List[str]) -> str: + assert len(items) > 1 + joined_exprs = ", ".join(items) + return f"concat({joined_exprs})" + + def to_comparable(self, value: str, coltype: ColType) -> str: + """Ensure that the expression is comparable in ``IS DISTINCT FROM``.""" + return value + + def is_distinct_from(self, a: str, b: str) -> str: + return f"{a} is distinct from {b}" + + def timestamp_value(self, t: DbTime) -> str: + return f"'{t.isoformat()}'" + + def random(self) -> str: + return "random()" + + def current_timestamp(self) -> str: + return "current_timestamp()" + + def current_database(self) -> str: + return "current_database()" + + def current_schema(self) -> str: + return "current_schema()" + + def explain_as_text(self, query: str) -> str: + return f"EXPLAIN {query}" + + def _constant_value(self, v): + if v is None: + return "NULL" + elif isinstance(v, str): + return f"'{v}'" + elif isinstance(v, datetime): + return self.timestamp_value(v) + elif isinstance(v, UUID): + return f"'{v}'" + elif isinstance(v, decimal.Decimal): + return str(v) + elif isinstance(v, bytearray): + return f"'{v.decode()}'" + elif isinstance(v, Code): + return v.code + return repr(v) + + def constant_values(self, rows) -> str: + values = ", ".join("(%s)" % ", ".join(self._constant_value(v) for v in row) for row in rows) + return f"VALUES {values}" + + def type_repr(self, t) -> str: + if 
isinstance(t, str): + return t + elif isinstance(t, TimestampTZ): + return f"TIMESTAMP({min(t.precision, DEFAULT_DATETIME_PRECISION)})" + return { + int: "INT", + str: "VARCHAR", + bool: "BOOLEAN", + float: "FLOAT", + datetime: "TIMESTAMP", + }[t] + + def _parse_type_repr(self, type_repr: str) -> Optional[Type[ColType]]: + return self.TYPE_CLASSES.get(type_repr) + + def parse_type( + self, + table_path: DbPath, + col_name: str, + type_repr: str, + datetime_precision: int = None, + numeric_precision: int = None, + numeric_scale: int = None, + ) -> ColType: + """Parse type info, as returned by the database, into a ColType instance.""" + + cls = self._parse_type_repr(type_repr) + if cls is None: + return UnknownColType(type_repr) + + if issubclass(cls, TemporalType): + return cls( + precision=datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION, + rounds=self.ROUNDS_ON_PREC_LOSS, + ) + + elif issubclass(cls, Integer): + return cls() + + elif issubclass(cls, Boolean): + return cls() + + elif issubclass(cls, Decimal): + if numeric_scale is None: + numeric_scale = 0 # Needed for Oracle. + return cls(precision=numeric_scale) + + elif issubclass(cls, Float): + # assert numeric_scale is None + return cls( + precision=self._convert_db_precision_to_digits( + numeric_precision if numeric_precision is not None else DEFAULT_NUMERIC_PRECISION + ) + ) + + elif issubclass(cls, (JSON, Array, Struct, Text, Native_UUID)): + return cls() + + raise TypeError(f"Parsing {type_repr} returned an unknown type '{cls}'.") + + def _convert_db_precision_to_digits(self, p: int) -> int: + """Convert from binary precision, used by floats, to decimal precision.""" + # See: https://en.wikipedia.org/wiki/Single-precision_floating-point_format + return math.floor(math.log(2**p, 10)) + + @classmethod + def load_mixins(cls, *abstract_mixins) -> Self: + mixins = {m for m in cls.MIXINS if issubclass(m, abstract_mixins)} + + class _DialectWithMixins(cls, *mixins, *abstract_mixins): + pass + + _DialectWithMixins.__name__ = cls.__name__ + return _DialectWithMixins() + + +T = TypeVar("T", bound=BaseDialect) + + +@dataclass +class QueryResult: + rows: list + columns: list = None + + def __iter__(self): + return iter(self.rows) + + def __len__(self): + return len(self.rows) + + def __getitem__(self, i): + return self.rows[i] + + +class Database(AbstractDatabase[T]): + """Base abstract class for databases. + + Used for providing connection code and implementation-specific SQL utilities. + + Instantiated using :meth:`~data_diff.connect` + """ + + default_schema: str = None + SUPPORTS_ALPHANUMS = True + SUPPORTS_UNIQUE_CONSTAINT = False + + CONNECT_URI_KWPARAMS = [] + + _interactive = False + is_closed = False + + @property + def name(self): + return type(self).__name__ + + def compile(self, sql_ast): + compiler = Compiler(self) + return compiler.compile(sql_ast) + + def query(self, sql_ast: Union[Expr, Generator], res_type: type = None): + """Query the given SQL code/AST, and attempt to convert the result to type 'res_type' + + If given a generator, it will execute all the yielded sql queries with the same thread and cursor. + The results of the queries are returned by the `yield` statement (using the .send() mechanism). + It's a cleaner approach than exposing cursors, but may not be enough in all cases. 
+ """ + + compiler = Compiler(self) + if isinstance(sql_ast, Generator): + sql_code = ThreadLocalInterpreter(compiler, sql_ast) + elif isinstance(sql_ast, list): + for i in sql_ast[:-1]: + self.query(i) + return self.query(sql_ast[-1], res_type) + else: + if isinstance(sql_ast, str): + sql_code = sql_ast + else: + if res_type is None: + res_type = sql_ast.type + sql_code = compiler.compile(sql_ast) + if sql_code is SKIP: + return SKIP + + logger.debug("Running SQL (%s): %s", self.name, sql_code) + + if self._interactive and isinstance(sql_ast, Select): + explained_sql = compiler.compile(Explain(sql_ast)) + explain = self._query(explained_sql) + for row in explain: + # Most returned a 1-tuple. Presto returns a string + if isinstance(row, tuple): + (row,) = row + logger.debug("EXPLAIN: %s", row) + answer = input("Continue? [y/n] ") + if answer.lower() not in ["y", "yes"]: + sys.exit(1) + + res = self._query(sql_code) + if res_type is list: + return list(res) + elif res_type is int: + if not res: + raise ValueError("Query returned 0 rows, expected 1") + row = _one(res) + if not row: + raise ValueError("Row is empty, expected 1 column") + res = _one(row) + if res is None: # May happen due to sum() of 0 items + return None + return int(res) + elif res_type is datetime: + res = _one(_one(res)) + if isinstance(res, str): + res = datetime.fromisoformat(res[:23]) # TODO use a better parsing method + return res + elif res_type is tuple: + assert len(res) == 1, (sql_code, res) + return res[0] + elif getattr(res_type, "__origin__", None) is list and len(res_type.__args__) == 1: + if res_type.__args__ in ((int,), (str,)): + return [_one(row) for row in res] + elif res_type.__args__ in [(Tuple,), (tuple,)]: + return [tuple(row) for row in res] + elif res_type.__args__ == (dict,): + return [dict(safezip(res.columns, row)) for row in res] + else: + raise ValueError(res_type) + return res + + def enable_interactive(self): + self._interactive = True + + def select_table_schema(self, path: DbPath) -> str: + """Provide SQL for selecting the table schema as (name, type, date_prec, num_prec)""" + schema, name = self._normalize_table_path(path) + + return ( + "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale " + "FROM information_schema.columns " + f"WHERE table_name = '{name}' AND table_schema = '{schema}'" + ) + + def query_table_schema(self, path: DbPath) -> Dict[str, tuple]: + rows = self.query(self.select_table_schema(path), list) + if not rows: + raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns") + + d = {r[0]: r for r in rows} + assert len(d) == len(rows) + return d + + def select_table_unique_columns(self, path: DbPath) -> str: + schema, name = self._normalize_table_path(path) + + return ( + "SELECT column_name " + "FROM information_schema.key_column_usage " + f"WHERE table_name = '{name}' AND table_schema = '{schema}'" + ) + + def query_table_unique_columns(self, path: DbPath) -> List[str]: + if not self.SUPPORTS_UNIQUE_CONSTAINT: + raise NotImplementedError("This database doesn't support 'unique' constraints") + res = self.query(self.select_table_unique_columns(path), List[str]) + return list(res) + + def _process_table_schema( + self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str] = None, where: str = None + ): + if filter_columns is None: + filtered_schema = raw_schema + else: + accept = {i.lower() for i in filter_columns} + filtered_schema = {name: row for name, row in raw_schema.items() if 
name.lower() in accept} + + col_dict = {row[0]: self.dialect.parse_type(path, *row) for _name, row in filtered_schema.items()} + + self._refine_coltypes(path, col_dict, where) + + # Return a dict of form {name: type} after normalization + return col_dict + + def _refine_coltypes(self, table_path: DbPath, col_dict: Dict[str, ColType], where: str = None, sample_size=64): + """Refine the types in the column dict, by querying the database for a sample of their values + + 'where' restricts the rows to be sampled. + """ + + text_columns = [k for k, v in col_dict.items() if isinstance(v, Text)] + if not text_columns: + return + + if isinstance(self.dialect, AbstractMixin_NormalizeValue): + fields = [Code(self.dialect.normalize_uuid(self.dialect.quote(c), String_UUID())) for c in text_columns] + else: + fields = this[text_columns] + + samples_by_row = self.query( + table(*table_path).select(*fields).where(Code(where) if where else SKIP).limit(sample_size), list + ) + if not samples_by_row: + raise ValueError(f"Table {table_path} is empty.") + + samples_by_col = list(zip(*samples_by_row)) + + for col_name, samples in safezip(text_columns, samples_by_col): + uuid_samples = [s for s in samples if s and is_uuid(s)] + + if uuid_samples: + if len(uuid_samples) != len(samples): + logger.warning( + f"Mixed UUID/Non-UUID values detected in column {'.'.join(table_path)}.{col_name}, disabling UUID support." + ) + else: + assert col_name in col_dict + col_dict[col_name] = String_UUID() + continue + + if self.SUPPORTS_ALPHANUMS: # Anything but MySQL (so far) + alphanum_samples = [s for s in samples if String_Alphanum.test_value(s)] + if alphanum_samples: + if len(alphanum_samples) != len(samples): + logger.debug( + f"Mixed Alphanum/Non-Alphanum values detected in column {'.'.join(table_path)}.{col_name}. It cannot be used as a key." + ) + else: + assert col_name in col_dict + col_dict[col_name] = String_VaryingAlphanum() + + # @lru_cache() + # def get_table_schema(self, path: DbPath) -> Dict[str, ColType]: + # return self.query_table_schema(path) + + def _normalize_table_path(self, path: DbPath) -> DbPath: + if len(path) == 1: + return self.default_schema, path[0] + elif len(path) == 2: + return path + + raise ValueError(f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected form: schema.table") + + def parse_table_name(self, name: str) -> DbPath: + return parse_table_name(name) + + def _query_cursor(self, c, sql_code: str) -> QueryResult: + assert isinstance(sql_code, str), sql_code + try: + c.execute(sql_code) + if sql_code.lower().startswith(("select", "explain", "show")): + columns = [col[0] for col in c.description] + + fetched = c.fetchall() + result = QueryResult(fetched, columns) + return result + except Exception as _e: + # logger.exception(e) + # logger.error(f'Caused by SQL: {sql_code}') + raise + + def _query_conn(self, conn, sql_code: Union[str, ThreadLocalInterpreter]) -> QueryResult: + c = conn.cursor() + callback = partial(self._query_cursor, c) + return apply_query(callback, sql_code) + + def close(self): + self.is_closed = True + return super().close() + + def list_tables(self, tables_like, schema=None): + return self.query(self.dialect.list_tables(schema or self.default_schema, tables_like)) + + def table(self, *path, **kw): + return bound_table(self, path, **kw) + + +class ThreadedDatabase(Database): + """Access the database through singleton threads. + + Used for database connectors that do not support sharing their connection between different threads. 
+ """ + + def __init__(self, thread_count=1): + self._init_error = None + self._queue = ThreadPoolExecutor(thread_count, initializer=self.set_conn) + self.thread_local = threading.local() + logger.info(f"[{self.name}] Starting a threadpool, size={thread_count}.") + + def set_conn(self): + assert not hasattr(self.thread_local, "conn") + try: + self.thread_local.conn = self.create_connection() + except Exception as e: + self._init_error = e + + def _query(self, sql_code: Union[str, ThreadLocalInterpreter]) -> QueryResult: + r = self._queue.submit(self._query_in_worker, sql_code) + return r.result() + + def _query_in_worker(self, sql_code: Union[str, ThreadLocalInterpreter]): + "This method runs in a worker thread" + if self._init_error: + raise self._init_error + return self._query_conn(self.thread_local.conn, sql_code) + + @abstractmethod + def create_connection(self): + "Return a connection instance, that supports the .cursor() method." + + def close(self): + super().close() + self._queue.shutdown() + + @property + def is_autocommit(self) -> bool: + return False + + +# TODO FYI mssql md5_as_int currently requires this to be reduced +CHECKSUM_HEXDIGITS = 14 # Must be 15 or lower, otherwise SUM() overflows +MD5_HEXDIGITS = 32 + +_CHECKSUM_BITSIZE = CHECKSUM_HEXDIGITS << 2 +CHECKSUM_MASK = (2**_CHECKSUM_BITSIZE) - 1 + +DEFAULT_DATETIME_PRECISION = 6 +DEFAULT_NUMERIC_PRECISION = 24 + +TIMESTAMP_PRECISION_POS = 20 # len("2022-06-03 12:24:35.") == 20 diff --git a/data_diff/databases/bigquery.py b/data_diff/databases/bigquery.py index a6fdbc9c..5925234f 100644 --- a/data_diff/databases/bigquery.py +++ b/data_diff/databases/bigquery.py @@ -1,10 +1,297 @@ -from data_diff.sqeleton.databases import bigquery -from data_diff.databases.base import DatadiffDialect +import re +from typing import Any, List, Union +from data_diff.abcs.database_types import ( + ColType, + Array, + JSON, + Struct, + Timestamp, + Datetime, + Integer, + Decimal, + Float, + Text, + DbPath, + FractionalType, + TemporalType, + Boolean, + UnknownColType, +) +from data_diff.abcs.mixins import ( + AbstractMixin_MD5, + AbstractMixin_NormalizeValue, + AbstractMixin_Schema, + AbstractMixin_TimeTravel, +) +from data_diff.abcs.compiler import Compilable +from data_diff.queries.api import this, table, SKIP, code +from data_diff.databases.base import ( + BaseDialect, + Database, + import_helper, + parse_table_name, + ConnectError, + apply_query, + QueryResult, +) +from data_diff.databases.base import TIMESTAMP_PRECISION_POS, ThreadLocalInterpreter, Mixin_RandomSample -class Dialect(bigquery.Dialect, bigquery.Mixin_MD5, bigquery.Mixin_NormalizeValue, DatadiffDialect): - pass +@import_helper(text="Please install BigQuery and configure your google-cloud access.") +def import_bigquery(): + from google.cloud import bigquery + return bigquery -class BigQuery(bigquery.BigQuery): + +def import_bigquery_service_account(): + from google.oauth2 import service_account + + return service_account + + +class Mixin_MD5(AbstractMixin_MD5): + def md5_as_int(self, s: str) -> str: + return f"cast(cast( ('0x' || substr(TO_HEX(md5({s})), 18)) as int64) as numeric)" + + +class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: + if coltype.rounds: + timestamp = f"timestamp_micros(cast(round(unix_micros(cast({value} as timestamp))/1000000, {coltype.precision})*1000000 as int))" + return f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {timestamp})" + + if coltype.precision == 0: + return 
f"FORMAT_TIMESTAMP('%F %H:%M:%S.000000', {value})" + elif coltype.precision == 6: + return f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {value})" + + timestamp6 = f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {value})" + return ( + f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" + ) + + def normalize_number(self, value: str, coltype: FractionalType) -> str: + return f"format('%.{coltype.precision}f', {value})" + + def normalize_boolean(self, value: str, _coltype: Boolean) -> str: + return self.to_string(f"cast({value} as int)") + + def normalize_json(self, value: str, _coltype: JSON) -> str: + # BigQuery is unable to compare arrays & structs with ==/!=/distinct from, e.g.: + # Got error: 400 Grouping is not defined for arguments of type ARRAY at … + # So we do the best effort and compare it as strings, hoping that the JSON forms + # match on both sides: i.e. have properly ordered keys, same spacing, same quotes, etc. + return f"to_json_string({value})" + + def normalize_array(self, value: str, _coltype: Array) -> str: + # BigQuery is unable to compare arrays & structs with ==/!=/distinct from, e.g.: + # Got error: 400 Grouping is not defined for arguments of type ARRAY at … + # So we do the best effort and compare it as strings, hoping that the JSON forms + # match on both sides: i.e. have properly ordered keys, same spacing, same quotes, etc. + return f"to_json_string({value})" + + def normalize_struct(self, value: str, _coltype: Struct) -> str: + # BigQuery is unable to compare arrays & structs with ==/!=/distinct from, e.g.: + # Got error: 400 Grouping is not defined for arguments of type ARRAY at … + # So we do the best effort and compare it as strings, hoping that the JSON forms + # match on both sides: i.e. have properly ordered keys, same spacing, same quotes, etc. 
+ return f"to_json_string({value})" + + +class Mixin_Schema(AbstractMixin_Schema): + def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: + return ( + table(table_schema, "INFORMATION_SCHEMA", "TABLES") + .where( + this.table_schema == table_schema, + this.table_name.like(like) if like is not None else SKIP, + this.table_type == "BASE TABLE", + ) + .select(this.table_name) + ) + + +class Mixin_TimeTravel(AbstractMixin_TimeTravel): + def time_travel( + self, + table: Compilable, + before: bool = False, + timestamp: Compilable = None, + offset: Compilable = None, + statement: Compilable = None, + ) -> Compilable: + if before: + raise NotImplementedError("before=True not supported for BigQuery time-travel") + + if statement is not None: + raise NotImplementedError("BigQuery time-travel doesn't support querying by statement id") + + if timestamp is not None: + assert offset is None + return code("{table} FOR SYSTEM_TIME AS OF {timestamp}", table=table, timestamp=timestamp) + + assert offset is not None + return code( + "{table} FOR SYSTEM_TIME AS OF TIMESTAMP_SUB(CURRENT_TIMESTAMP(), INTERVAL {offset} HOUR);", + table=table, + offset=offset, + ) + + +class Dialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): + name = "BigQuery" + ROUNDS_ON_PREC_LOSS = False # Technically BigQuery doesn't allow implicit rounding or truncation + TYPE_CLASSES = { + # Dates + "TIMESTAMP": Timestamp, + "DATETIME": Datetime, + # Numbers + "INT64": Integer, + "INT32": Integer, + "NUMERIC": Decimal, + "BIGNUMERIC": Decimal, + "FLOAT64": Float, + "FLOAT32": Float, + "STRING": Text, + "BOOL": Boolean, + "JSON": JSON, + } + TYPE_ARRAY_RE = re.compile(r"ARRAY<(.+)>") + TYPE_STRUCT_RE = re.compile(r"STRUCT<(.+)>") + MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_TimeTravel, Mixin_RandomSample} + + def random(self) -> str: + return "RAND()" + + def quote(self, s: str): + return f"`{s}`" + + def to_string(self, s: str): + return f"cast({s} as string)" + + def type_repr(self, t) -> str: + try: + return {str: "STRING", float: "FLOAT64"}[t] + except KeyError: + return super().type_repr(t) + + def parse_type( + self, + table_path: DbPath, + col_name: str, + type_repr: str, + *args: Any, # pass-through args + **kwargs: Any, # pass-through args + ) -> ColType: + col_type = super().parse_type(table_path, col_name, type_repr, *args, **kwargs) + if isinstance(col_type, UnknownColType): + m = self.TYPE_ARRAY_RE.fullmatch(type_repr) + if m: + item_type = self.parse_type(table_path, col_name, m.group(1), *args, **kwargs) + col_type = Array(item_type=item_type) + + # We currently ignore structs' structure, but later can parse it too. 
Examples: + # - STRUCT (unnamed) + # - STRUCT (named) + # - STRUCT> (with complex fields) + # - STRUCT> (nested) + m = self.TYPE_STRUCT_RE.fullmatch(type_repr) + if m: + col_type = Struct() + + return col_type + + def to_comparable(self, value: str, coltype: ColType) -> str: + """Ensure that the expression is comparable in ``IS DISTINCT FROM``.""" + if isinstance(coltype, (JSON, Array, Struct)): + return self.normalize_value_by_type(value, coltype) + else: + return super().to_comparable(value, coltype) + + def set_timezone_to_utc(self) -> str: + raise NotImplementedError() + + +class BigQuery(Database): + CONNECT_URI_HELP = "bigquery:///" + CONNECT_URI_PARAMS = ["dataset"] dialect = Dialect() + + def __init__(self, project, *, dataset, bigquery_credentials=None, **kw): + credentials = bigquery_credentials + bigquery = import_bigquery() + + keyfile = kw.pop("keyfile", None) + if keyfile: + bigquery_service_account = import_bigquery_service_account() + credentials = bigquery_service_account.Credentials.from_service_account_file( + keyfile, + scopes=["https://www.googleapis.com/auth/cloud-platform"], + ) + + self._client = bigquery.Client(project=project, credentials=credentials, **kw) + self.project = project + self.dataset = dataset + + self.default_schema = dataset + + def _normalize_returned_value(self, value): + if isinstance(value, bytes): + return value.decode() + return value + + def _query_atom(self, sql_code: str): + from google.cloud import bigquery + + try: + result = self._client.query(sql_code).result() + columns = [c.name for c in result.schema] + rows = list(result) + except Exception as e: + msg = "Exception when trying to execute SQL code:\n %s\n\nGot error: %s" + raise ConnectError(msg % (sql_code, e)) + + if rows and isinstance(rows[0], bigquery.table.Row): + rows = [tuple(self._normalize_returned_value(v) for v in row.values()) for row in rows] + return QueryResult(rows, columns) + + def _query(self, sql_code: Union[str, ThreadLocalInterpreter]) -> QueryResult: + return apply_query(self._query_atom, sql_code) + + def close(self): + super().close() + self._client.close() + + def select_table_schema(self, path: DbPath) -> str: + project, schema, name = self._normalize_table_path(path) + return ( + "SELECT column_name, data_type, 6 as datetime_precision, 38 as numeric_precision, 9 as numeric_scale " + f"FROM `{project}`.`{schema}`.INFORMATION_SCHEMA.COLUMNS " + f"WHERE table_name = '{name}' AND table_schema = '{schema}'" + ) + + def query_table_unique_columns(self, path: DbPath) -> List[str]: + return [] + + def _normalize_table_path(self, path: DbPath) -> DbPath: + if len(path) == 0: + raise ValueError(f"{self.name}: Bad table path for {self}: ()") + elif len(path) == 1: + return (self.project, self.default_schema, path[0]) + elif len(path) == 2: + return (self.project,) + path + elif len(path) == 3: + return path + else: + raise ValueError( + f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. 
Expected form: [project.]schema.table" + ) + + def parse_table_name(self, name: str) -> DbPath: + path = parse_table_name(name) + return tuple(i for i in self._normalize_table_path(path) if i is not None) + + @property + def is_autocommit(self) -> bool: + return True diff --git a/data_diff/databases/clickhouse.py b/data_diff/databases/clickhouse.py index ce22943b..9366b922 100644 --- a/data_diff/databases/clickhouse.py +++ b/data_diff/databases/clickhouse.py @@ -1,10 +1,196 @@ -from data_diff.sqeleton.databases import clickhouse -from data_diff.databases.base import DatadiffDialect +from typing import Optional, Type +from data_diff.databases.base import ( + MD5_HEXDIGITS, + CHECKSUM_HEXDIGITS, + TIMESTAMP_PRECISION_POS, + BaseDialect, + ThreadedDatabase, + import_helper, + ConnectError, + Mixin_RandomSample, +) +from data_diff.abcs.database_types import ( + ColType, + Decimal, + Float, + Integer, + FractionalType, + Native_UUID, + TemporalType, + Text, + Timestamp, + Boolean, +) +from data_diff.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue -class Dialect(clickhouse.Dialect, clickhouse.Mixin_MD5, clickhouse.Mixin_NormalizeValue, DatadiffDialect): - pass +# https://clickhouse.com/docs/en/operations/server-configuration-parameters/settings/#default-database +DEFAULT_DATABASE = "default" -class Clickhouse(clickhouse.Clickhouse): +@import_helper("clickhouse") +def import_clickhouse(): + import clickhouse_driver + + return clickhouse_driver + + +class Mixin_MD5(AbstractMixin_MD5): + def md5_as_int(self, s: str) -> str: + substr_idx = 1 + MD5_HEXDIGITS - CHECKSUM_HEXDIGITS + return f"reinterpretAsUInt128(reverse(unhex(lowerUTF8(substr(hex(MD5({s})), {substr_idx})))))" + + +class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): + def normalize_number(self, value: str, coltype: FractionalType) -> str: + # If a decimal value has trailing zeros in a fractional part, when casting to string they are dropped. + # For example: + # select toString(toDecimal128(1.10, 2)); -- the result is 1.1 + # select toString(toDecimal128(1.00, 2)); -- the result is 1 + # So, we should use some custom approach to save these trailing zeros. + # To avoid it, we can add a small value like 0.000001 to prevent dropping of zeros from the end when casting. + # For examples above it looks like: + # select toString(toDecimal128(1.10, 2 + 1) + toDecimal128(0.001, 3)); -- the result is 1.101 + # After that, cut an extra symbol from the string, i.e. 1.101 -> 1.10 + # So, the algorithm is: + # 1. Cast to decimal with precision + 1 + # 2. Add a small value 10^(-precision-1) + # 3. Cast the result to string + # 4. Drop the extra digit from the string. To do that, we need to slice the string + # with length = digits in an integer part + 1 (symbol of ".") + precision + + if coltype.precision == 0: + return self.to_string(f"round({value})") + + precision = coltype.precision + # TODO: too complex, is there better performance way? 
+ value = f""" + if({value} >= 0, '', '-') || left( + toString( + toDecimal128( + round(abs({value}), {precision}), + {precision} + 1 + ) + + + toDecimal128( + exp10(-{precision + 1}), + {precision} + 1 + ) + ), + toUInt8( + greatest( + floor(log10(abs({value}))) + 1, + 1 + ) + ) + 1 + {precision} + ) + """ + return value + + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: + prec = coltype.precision + if coltype.rounds: + timestamp = f"toDateTime64(round(toUnixTimestamp64Micro(toDateTime64({value}, 6)) / 1000000, {prec}), 6)" + return self.to_string(timestamp) + + fractional = f"toUnixTimestamp64Micro(toDateTime64({value}, {prec})) % 1000000" + fractional = f"lpad({self.to_string(fractional)}, 6, '0')" + value = f"formatDateTime({value}, '%Y-%m-%d %H:%M:%S') || '.' || {self.to_string(fractional)}" + return f"rpad({value}, {TIMESTAMP_PRECISION_POS + 6}, '0')" + + +class Dialect(BaseDialect, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): + name = "Clickhouse" + ROUNDS_ON_PREC_LOSS = False + TYPE_CLASSES = { + "Int8": Integer, + "Int16": Integer, + "Int32": Integer, + "Int64": Integer, + "Int128": Integer, + "Int256": Integer, + "UInt8": Integer, + "UInt16": Integer, + "UInt32": Integer, + "UInt64": Integer, + "UInt128": Integer, + "UInt256": Integer, + "Float32": Float, + "Float64": Float, + "Decimal": Decimal, + "UUID": Native_UUID, + "String": Text, + "FixedString": Text, + "DateTime": Timestamp, + "DateTime64": Timestamp, + "Bool": Boolean, + } + MIXINS = {Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample} + + def quote(self, s: str) -> str: + return f'"{s}"' + + def to_string(self, s: str) -> str: + return f"toString({s})" + + def _convert_db_precision_to_digits(self, p: int) -> int: + # Done the same as for PostgreSQL but need to rewrite in another way + # because it does not help for float with a big integer part. 
+ return super()._convert_db_precision_to_digits(p) - 2 + + def _parse_type_repr(self, type_repr: str) -> Optional[Type[ColType]]: + nullable_prefix = "Nullable(" + if type_repr.startswith(nullable_prefix): + type_repr = type_repr[len(nullable_prefix) :].rstrip(")") + + if type_repr.startswith("Decimal"): + type_repr = "Decimal" + elif type_repr.startswith("FixedString"): + type_repr = "FixedString" + elif type_repr.startswith("DateTime64"): + type_repr = "DateTime64" + + return self.TYPE_CLASSES.get(type_repr) + + # def timestamp_value(self, t: DbTime) -> str: + # # return f"'{t}'" + # return f"'{str(t)[:19]}'" + + def set_timezone_to_utc(self) -> str: + raise NotImplementedError() + + def current_timestamp(self) -> str: + return "now()" + + +class Clickhouse(ThreadedDatabase): dialect = Dialect() + CONNECT_URI_HELP = "clickhouse://:@/" + CONNECT_URI_PARAMS = ["database?"] + + def __init__(self, *, thread_count: int, **kw): + super().__init__(thread_count=thread_count) + + self._args = kw + # In Clickhouse database and schema are the same + self.default_schema = kw.get("database", DEFAULT_DATABASE) + + def create_connection(self): + clickhouse = import_clickhouse() + + class SingleConnection(clickhouse.dbapi.connection.Connection): + """Not thread-safe connection to Clickhouse""" + + def cursor(self, cursor_factory=None): + if not len(self.cursors): + _ = super().cursor() + return self.cursors[0] + + try: + return SingleConnection(**self._args) + except clickhouse.OperationError as e: + raise ConnectError(*e.args) from e + + @property + def is_autocommit(self) -> bool: + return True diff --git a/data_diff/databases/databricks.py b/data_diff/databases/databricks.py index 6794c264..1b8aa33a 100644 --- a/data_diff/databases/databricks.py +++ b/data_diff/databases/databricks.py @@ -1,10 +1,199 @@ -from data_diff.sqeleton.databases import databricks -from data_diff.databases.base import DatadiffDialect +import math +from typing import Dict, Sequence +import logging +from data_diff.abcs.database_types import ( + Integer, + Float, + Decimal, + Timestamp, + Text, + TemporalType, + NumericType, + DbPath, + ColType, + UnknownColType, + Boolean, +) +from data_diff.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue +from data_diff.databases.base import ( + MD5_HEXDIGITS, + CHECKSUM_HEXDIGITS, + BaseDialect, + ThreadedDatabase, + import_helper, + parse_table_name, + Mixin_RandomSample, +) -class Dialect(databricks.Dialect, databricks.Mixin_MD5, databricks.Mixin_NormalizeValue, DatadiffDialect): - pass +@import_helper(text="You can install it using 'pip install databricks-sql-connector'") +def import_databricks(): + import databricks.sql -class Databricks(databricks.Databricks): + return databricks + + +class Mixin_MD5(AbstractMixin_MD5): + def md5_as_int(self, s: str) -> str: + return f"cast(conv(substr(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16, 10) as decimal(38, 0))" + + +class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: + """Databricks timestamp contains no more than 6 digits in precision""" + + if coltype.rounds: + timestamp = f"cast(round(unix_micros({value}) / 1000000, {coltype.precision}) * 1000000 as bigint)" + return f"date_format(timestamp_micros({timestamp}), 'yyyy-MM-dd HH:mm:ss.SSSSSS')" + + precision_format = "S" * coltype.precision + "0" * (6 - coltype.precision) + return f"date_format({value}, 'yyyy-MM-dd HH:mm:ss.{precision_format}')" + + def normalize_number(self, value: str, 
coltype: NumericType) -> str: + value = f"cast({value} as decimal(38, {coltype.precision}))" + if coltype.precision > 0: + value = f"format_number({value}, {coltype.precision})" + return f"replace({self.to_string(value)}, ',', '')" + + def normalize_boolean(self, value: str, _coltype: Boolean) -> str: + return self.to_string(f"cast ({value} as int)") + + +class Dialect(BaseDialect, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): + name = "Databricks" + ROUNDS_ON_PREC_LOSS = True + TYPE_CLASSES = { + # Numbers + "INT": Integer, + "SMALLINT": Integer, + "TINYINT": Integer, + "BIGINT": Integer, + "FLOAT": Float, + "DOUBLE": Float, + "DECIMAL": Decimal, + # Timestamps + "TIMESTAMP": Timestamp, + # Text + "STRING": Text, + # Boolean + "BOOLEAN": Boolean, + } + MIXINS = {Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample} + + def quote(self, s: str): + return f"`{s}`" + + def to_string(self, s: str) -> str: + return f"cast({s} as string)" + + def _convert_db_precision_to_digits(self, p: int) -> int: + # Subtracting 2 due to weird precision issues + return max(super()._convert_db_precision_to_digits(p) - 2, 0) + + def set_timezone_to_utc(self) -> str: + return "SET TIME ZONE 'UTC'" + + +class Databricks(ThreadedDatabase): dialect = Dialect() + CONNECT_URI_HELP = "databricks://:@/" + CONNECT_URI_PARAMS = ["catalog", "schema"] + + def __init__(self, *, thread_count, **kw): + logging.getLogger("databricks.sql").setLevel(logging.WARNING) + + self._args = kw + self.default_schema = kw.get("schema", "default") + self.catalog = self._args.get("catalog", "hive_metastore") + super().__init__(thread_count=thread_count) + + def create_connection(self): + databricks = import_databricks() + + try: + return databricks.sql.connect( + server_hostname=self._args["server_hostname"], + http_path=self._args["http_path"], + access_token=self._args["access_token"], + catalog=self.catalog, + ) + except databricks.sql.exc.Error as e: + raise ConnectionError(*e.args) from e + + def query_table_schema(self, path: DbPath) -> Dict[str, tuple]: + # Databricks has INFORMATION_SCHEMA only for Databricks Runtime, not for Databricks SQL. + # https://docs.databricks.com/spark/latest/spark-sql/language-manual/information-schema/columns.html + # So, to obtain information about the schema, we use another approach. 
+ + conn = self.create_connection() + + catalog, schema, table = self._normalize_table_path(path) + with conn.cursor() as cursor: + cursor.columns(catalog_name=catalog, schema_name=schema, table_name=table) + try: + rows = cursor.fetchall() + finally: + conn.close() + if not rows: + raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns") + + d = {r.COLUMN_NAME: (r.COLUMN_NAME, r.TYPE_NAME, r.DECIMAL_DIGITS, None, None) for r in rows} + assert len(d) == len(rows) + return d + + def _process_table_schema( + self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str], where: str = None + ): + accept = {i.lower() for i in filter_columns} + rows = [row for name, row in raw_schema.items() if name.lower() in accept] + + resulted_rows = [] + for row in rows: + row_type = "DECIMAL" if row[1].startswith("DECIMAL") else row[1] + type_cls = self.dialect.TYPE_CLASSES.get(row_type, UnknownColType) + + if issubclass(type_cls, Integer): + row = (row[0], row_type, None, None, 0) + + elif issubclass(type_cls, Float): + numeric_precision = math.ceil(row[2] / math.log(2, 10)) + row = (row[0], row_type, None, numeric_precision, None) + + elif issubclass(type_cls, Decimal): + items = row[1][8:].rstrip(")").split(",") + numeric_precision, numeric_scale = int(items[0]), int(items[1]) + row = (row[0], row_type, None, numeric_precision, numeric_scale) + + elif issubclass(type_cls, Timestamp): + row = (row[0], row_type, row[2], None, None) + + else: + row = (row[0], row_type, None, None, None) + + resulted_rows.append(row) + + col_dict: Dict[str, ColType] = {row[0]: self.dialect.parse_type(path, *row) for row in resulted_rows} + + self._refine_coltypes(path, col_dict, where) + return col_dict + + def parse_table_name(self, name: str) -> DbPath: + path = parse_table_name(name) + return tuple(i for i in self._normalize_table_path(path) if i is not None) + + @property + def is_autocommit(self) -> bool: + return True + + def _normalize_table_path(self, path: DbPath) -> DbPath: + if len(path) == 1: + return self.catalog, self.default_schema, path[0] + elif len(path) == 2: + return self.catalog, path[0], path[1] + elif len(path) == 3: + return path + + raise ValueError( + f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. 
Expected format: table, schema.table, or catalog.schema.table" + ) diff --git a/data_diff/databases/duckdb.py b/data_diff/databases/duckdb.py index e822264e..f7fdaadd 100644 --- a/data_diff/databases/duckdb.py +++ b/data_diff/databases/duckdb.py @@ -1,10 +1,192 @@ -from data_diff.sqeleton.databases import duckdb -from data_diff.databases.base import DatadiffDialect +from typing import Union +from data_diff.utils import match_regexps +from data_diff.abcs.database_types import ( + Timestamp, + TimestampTZ, + DbPath, + ColType, + Float, + Decimal, + Integer, + TemporalType, + Native_UUID, + Text, + FractionalType, + Boolean, + AbstractTable, +) +from data_diff.abcs.mixins import ( + AbstractMixin_MD5, + AbstractMixin_NormalizeValue, + AbstractMixin_RandomSample, + AbstractMixin_Regex, +) +from data_diff.databases.base import ( + Database, + BaseDialect, + import_helper, + ConnectError, + ThreadLocalInterpreter, + TIMESTAMP_PRECISION_POS, +) +from data_diff.databases.base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, Mixin_Schema +from data_diff.queries.ast_classes import Func, Compilable +from data_diff.queries.api import code -class Dialect(duckdb.Dialect, duckdb.Mixin_MD5, duckdb.Mixin_NormalizeValue, DatadiffDialect): - pass +@import_helper("duckdb") +def import_duckdb(): + import duckdb -class DuckDB(duckdb.DuckDB): + return duckdb + + +class Mixin_MD5(AbstractMixin_MD5): + def md5_as_int(self, s: str) -> str: + return f"('0x' || SUBSTRING(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS},{CHECKSUM_HEXDIGITS}))::BIGINT" + + +class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: + # It's precision 6 by default. If precision is less than 6 -> we remove the trailing numbers. + if coltype.rounds and coltype.precision > 0: + return f"CONCAT(SUBSTRING(STRFTIME({value}::TIMESTAMP, '%Y-%m-%d %H:%M:%S.'),1,23), LPAD(((ROUND(strftime({value}::timestamp, '%f')::DECIMAL(15,7)/100000,{coltype.precision-1})*100000)::INT)::VARCHAR,6,'0'))" + + return f"rpad(substring(strftime({value}::timestamp, '%Y-%m-%d %H:%M:%S.%f'),1,{TIMESTAMP_PRECISION_POS+coltype.precision}),26,'0')" + + def normalize_number(self, value: str, coltype: FractionalType) -> str: + return self.to_string(f"{value}::DECIMAL(38, {coltype.precision})") + + def normalize_boolean(self, value: str, _coltype: Boolean) -> str: + return self.to_string(f"{value}::INTEGER") + + +class Mixin_RandomSample(AbstractMixin_RandomSample): + def random_sample_n(self, tbl: AbstractTable, size: int) -> AbstractTable: + return code("SELECT * FROM ({tbl}) USING SAMPLE {size};", tbl=tbl, size=size) + + def random_sample_ratio_approx(self, tbl: AbstractTable, ratio: float) -> AbstractTable: + return code("SELECT * FROM ({tbl}) USING SAMPLE {percent}%;", tbl=tbl, percent=int(100 * ratio)) + + +class Mixin_Regex(AbstractMixin_Regex): + def test_regex(self, string: Compilable, pattern: Compilable) -> Compilable: + return Func("regexp_matches", [string, pattern]) + + +class Dialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): + name = "DuckDB" + ROUNDS_ON_PREC_LOSS = False + SUPPORTS_PRIMARY_KEY = True + SUPPORTS_INDEXES = True + MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample} + + TYPE_CLASSES = { + # Timestamps + "TIMESTAMP WITH TIME ZONE": TimestampTZ, + "TIMESTAMP": Timestamp, + # Numbers + "DOUBLE": Float, + "FLOAT": Float, + "DECIMAL": Decimal, + "INTEGER": Integer, + "BIGINT": Integer, + # Text 
+ "VARCHAR": Text, + "TEXT": Text, + # UUID + "UUID": Native_UUID, + # Bool + "BOOLEAN": Boolean, + } + + def quote(self, s: str): + return f'"{s}"' + + def to_string(self, s: str): + return f"{s}::VARCHAR" + + def _convert_db_precision_to_digits(self, p: int) -> int: + # Subtracting 2 due to wierd precision issues in PostgreSQL + return super()._convert_db_precision_to_digits(p) - 2 + + def parse_type( + self, + table_path: DbPath, + col_name: str, + type_repr: str, + datetime_precision: int = None, + numeric_precision: int = None, + numeric_scale: int = None, + ) -> ColType: + regexps = { + r"DECIMAL\((\d+),(\d+)\)": Decimal, + } + + for m, t_cls in match_regexps(regexps, type_repr): + precision = int(m.group(2)) + return t_cls(precision=precision) + + return super().parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision, numeric_scale) + + def set_timezone_to_utc(self) -> str: + return "SET GLOBAL TimeZone='UTC'" + + def current_timestamp(self) -> str: + return "current_timestamp" + + +class DuckDB(Database): dialect = Dialect() + SUPPORTS_UNIQUE_CONSTAINT = False # Temporary, until we implement it + default_schema = "main" + CONNECT_URI_HELP = "duckdb://@" + CONNECT_URI_PARAMS = ["database", "dbpath"] + + def __init__(self, **kw): + self._args = kw + self._conn = self.create_connection() + + @property + def is_autocommit(self) -> bool: + return True + + def _query(self, sql_code: Union[str, ThreadLocalInterpreter]): + "Uses the standard SQL cursor interface" + return self._query_conn(self._conn, sql_code) + + def close(self): + super().close() + self._conn.close() + + def create_connection(self): + ddb = import_duckdb() + try: + return ddb.connect(self._args["filepath"]) + except ddb.OperationalError as e: + raise ConnectError(*e.args) from e + + def select_table_schema(self, path: DbPath) -> str: + database, schema, table = self._normalize_table_path(path) + + info_schema_path = ["information_schema", "columns"] + if database: + info_schema_path.insert(0, database) + + return ( + f"SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale FROM {'.'.join(info_schema_path)} " + f"WHERE table_name = '{table}' AND table_schema = '{schema}'" + ) + + def _normalize_table_path(self, path: DbPath) -> DbPath: + if len(path) == 1: + return None, self.default_schema, path[0] + elif len(path) == 2: + return None, path[0], path[1] + elif len(path) == 3: + return path + + raise ValueError( + f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. 
Expected format: table, schema.table, or database.schema.table" + ) diff --git a/data_diff/databases/mssql.py b/data_diff/databases/mssql.py index 15163e4f..28d67c99 100644 --- a/data_diff/databases/mssql.py +++ b/data_diff/databases/mssql.py @@ -1,10 +1,214 @@ -from data_diff.sqeleton.databases import mssql -from data_diff.databases.base import DatadiffDialect +from typing import Optional +from data_diff.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue +from data_diff.databases.base import ( + CHECKSUM_HEXDIGITS, + Mixin_OptimizerHints, + Mixin_RandomSample, + QueryError, + ThreadedDatabase, + import_helper, + ConnectError, + BaseDialect, +) +from data_diff.databases.base import Mixin_Schema +from data_diff.abcs.database_types import ( + JSON, + Timestamp, + TimestampTZ, + DbPath, + Float, + Decimal, + Integer, + TemporalType, + Native_UUID, + Text, + FractionalType, + Boolean, +) -class Dialect(mssql.Dialect, mssql.Mixin_MD5, mssql.Mixin_NormalizeValue, DatadiffDialect): - pass +@import_helper("mssql") +def import_mssql(): + import pyodbc + return pyodbc -class MsSql(mssql.MsSQL): + +class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: + if coltype.precision > 0: + formatted_value = ( + f"FORMAT({value}, 'yyyy-MM-dd HH:mm:ss') + '.' + " + f"SUBSTRING(FORMAT({value}, 'fffffff'), 1, {coltype.precision})" + ) + else: + formatted_value = f"FORMAT({value}, 'yyyy-MM-dd HH:mm:ss')" + + return formatted_value + + def normalize_number(self, value: str, coltype: FractionalType) -> str: + if coltype.precision == 0: + return f"CAST(FLOOR({value}) AS VARCHAR)" + + return f"FORMAT({value}, 'N{coltype.precision}')" + + +class Mixin_MD5(AbstractMixin_MD5): + def md5_as_int(self, s: str) -> str: + return f"convert(bigint, convert(varbinary, '0x' + RIGHT(CONVERT(NVARCHAR(32), HashBytes('MD5', {s}), 2), {CHECKSUM_HEXDIGITS}), 1))" + + +class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): + name = "MsSQL" + ROUNDS_ON_PREC_LOSS = True + SUPPORTS_PRIMARY_KEY = True + SUPPORTS_INDEXES = True + TYPE_CLASSES = { + # Timestamps + "datetimeoffset": TimestampTZ, + "datetime": Timestamp, + "datetime2": Timestamp, + "smalldatetime": Timestamp, + "date": Timestamp, + # Numbers + "float": Float, + "real": Float, + "decimal": Decimal, + "money": Decimal, + "smallmoney": Decimal, + # int + "int": Integer, + "bigint": Integer, + "tinyint": Integer, + "smallint": Integer, + # Text + "varchar": Text, + "char": Text, + "text": Text, + "ntext": Text, + "nvarchar": Text, + "nchar": Text, + "binary": Text, + "varbinary": Text, + # UUID + "uniqueidentifier": Native_UUID, + # Bool + "bit": Boolean, + # JSON + "json": JSON, + } + + MIXINS = {Mixin_Schema, Mixin_NormalizeValue, Mixin_RandomSample} + + def quote(self, s: str): + return f"[{s}]" + + def set_timezone_to_utc(self) -> str: + raise NotImplementedError("MsSQL does not support a session timezone setting.") + + def current_timestamp(self) -> str: + return "GETDATE()" + + def current_database(self) -> str: + return "DB_NAME()" + + def current_schema(self) -> str: + return """default_schema_name + FROM sys.database_principals + WHERE name = CURRENT_USER""" + + def to_string(self, s: str): + return f"CONVERT(varchar, {s})" + + def type_repr(self, t) -> str: + try: + return {bool: "bit"}[t] + except KeyError: + return super().type_repr(t) + + def random(self) -> str: + return "rand()" + + 
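+    # A minimal sketch of what the normalization above emits, assuming a
+    # Timestamp coltype with precision=3 and an illustrative column name:
+    #
+    #   d = Dialect()
+    #   d.normalize_timestamp("created_at", Timestamp(precision=3, rounds=False))
+    #   # -> "FORMAT(created_at, 'yyyy-MM-dd HH:mm:ss') + '.' +
+    #   #     SUBSTRING(FORMAT(created_at, 'fffffff'), 1, 3)"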
def is_distinct_from(self, a: str, b: str) -> str: + # IS (NOT) DISTINCT FROM is available only since SQLServer 2022. + # See: https://stackoverflow.com/a/18684859/857383 + return f"(({a}<>{b} OR {a} IS NULL OR {b} IS NULL) AND NOT({a} IS NULL AND {b} IS NULL))" + + def offset_limit( + self, offset: Optional[int] = None, limit: Optional[int] = None, has_order_by: Optional[bool] = None + ) -> str: + if offset: + raise NotImplementedError("No support for OFFSET in query") + + result = "" + if not has_order_by: + result += "ORDER BY 1" + + result += f" OFFSET 0 ROWS FETCH NEXT {limit} ROWS ONLY" + return result + + def constant_values(self, rows) -> str: + values = ", ".join("(%s)" % ", ".join(self._constant_value(v) for v in row) for row in rows) + return f"VALUES {values}" + + +class MsSQL(ThreadedDatabase): dialect = Dialect() + # + CONNECT_URI_HELP = "mssql://:@//" + CONNECT_URI_PARAMS = ["database", "schema"] + + def __init__(self, host, port, user, password, *, database, thread_count, **kw): + args = dict(server=host, port=port, database=database, user=user, password=password, **kw) + self._args = {k: v for k, v in args.items() if v is not None} + self._args["driver"] = "{ODBC Driver 18 for SQL Server}" + + # TODO temp dev debug + self._args["TrustServerCertificate"] = "yes" + + try: + self.default_database = self._args["database"] + self.default_schema = self._args["schema"] + except KeyError: + raise ValueError("Specify a default database and schema.") + + super().__init__(thread_count=thread_count) + + def create_connection(self): + self._mssql = import_mssql() + try: + connection = self._mssql.connect(**self._args) + return connection + except self._mssql.Error as error: + raise ConnectError(*error.args) from error + + def select_table_schema(self, path: DbPath) -> str: + """Provide SQL for selecting the table schema as (name, type, date_prec, num_prec)""" + database, schema, name = self._normalize_table_path(path) + info_schema_path = ["information_schema", "columns"] + if database: + info_schema_path.insert(0, self.dialect.quote(database)) + + return ( + "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale " + f"FROM {'.'.join(info_schema_path)} " + f"WHERE table_name = '{name}' AND table_schema = '{schema}'" + ) + + def _normalize_table_path(self, path: DbPath) -> DbPath: + if len(path) == 1: + return self.default_database, self.default_schema, path[0] + elif len(path) == 2: + return self.default_database, path[0], path[1] + elif len(path) == 3: + return path + + raise ValueError( + f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. 
Expected format: table, schema.table, or database.schema.table" + ) + + def _query_cursor(self, c, sql_code: str): + try: + return super()._query_cursor(c, sql_code) + except self._mssql.DatabaseError as e: + raise QueryError(e) diff --git a/data_diff/databases/mysql.py b/data_diff/databases/mysql.py index 102620b8..910ff78d 100644 --- a/data_diff/databases/mysql.py +++ b/data_diff/databases/mysql.py @@ -1,10 +1,159 @@ -from data_diff.sqeleton.databases import mysql -from data_diff.databases.base import DatadiffDialect +from data_diff.abcs.database_types import ( + Datetime, + Timestamp, + Float, + Decimal, + Integer, + Text, + TemporalType, + FractionalType, + ColType_UUID, + Boolean, + Date, +) +from data_diff.abcs.mixins import ( + AbstractMixin_MD5, + AbstractMixin_NormalizeValue, + AbstractMixin_Regex, +) +from data_diff.databases.base import ( + Mixin_OptimizerHints, + ThreadedDatabase, + import_helper, + ConnectError, + BaseDialect, + Compilable, +) +from data_diff.databases.base import ( + MD5_HEXDIGITS, + CHECKSUM_HEXDIGITS, + TIMESTAMP_PRECISION_POS, + Mixin_Schema, + Mixin_RandomSample, +) +from data_diff.queries.ast_classes import BinBoolOp -class Dialect(mysql.Dialect, mysql.Mixin_MD5, mysql.Mixin_NormalizeValue, DatadiffDialect): - pass +@import_helper("mysql") +def import_mysql(): + import mysql.connector + return mysql.connector -class MySQL(mysql.MySQL): + +class Mixin_MD5(AbstractMixin_MD5): + def md5_as_int(self, s: str) -> str: + return f"cast(conv(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16, 10) as unsigned)" + + +class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: + if coltype.rounds: + return self.to_string(f"cast( cast({value} as datetime({coltype.precision})) as datetime(6))") + + s = self.to_string(f"cast({value} as datetime(6))") + return f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS+coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS+6}, '0')" + + def normalize_number(self, value: str, coltype: FractionalType) -> str: + return self.to_string(f"cast({value} as decimal(38, {coltype.precision}))") + + def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: + return f"TRIM(CAST({value} AS char))" + + +class Mixin_Regex(AbstractMixin_Regex): + def test_regex(self, string: Compilable, pattern: Compilable) -> Compilable: + return BinBoolOp("REGEXP", [string, pattern]) + + +class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): + name = "MySQL" + ROUNDS_ON_PREC_LOSS = True + SUPPORTS_PRIMARY_KEY = True + SUPPORTS_INDEXES = True + TYPE_CLASSES = { + # Dates + "datetime": Datetime, + "timestamp": Timestamp, + "date": Date, + # Numbers + "double": Float, + "float": Float, + "decimal": Decimal, + "int": Integer, + "bigint": Integer, + "mediumint": Integer, + "smallint": Integer, + "tinyint": Integer, + # Text + "varchar": Text, + "char": Text, + "varbinary": Text, + "binary": Text, + "text": Text, + "mediumtext": Text, + "longtext": Text, + "tinytext": Text, + # Boolean + "boolean": Boolean, + } + MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample} + + def quote(self, s: str): + return f"`{s}`" + + def to_string(self, s: str): + return f"cast({s} as char)" + + def is_distinct_from(self, a: str, b: str) -> str: + return f"not ({a} <=> {b})" + + def random(self) -> str: + return "RAND()" + + def type_repr(self, t) -> str: + try: + return { + str: 
"VARCHAR(1024)", + }[t] + except KeyError: + return super().type_repr(t) + + def explain_as_text(self, query: str) -> str: + return f"EXPLAIN FORMAT=TREE {query}" + + def optimizer_hints(self, s: str): + return f"/*+ {s} */ " + + def set_timezone_to_utc(self) -> str: + return "SET @@session.time_zone='+00:00'" + + +class MySQL(ThreadedDatabase): dialect = Dialect() + SUPPORTS_ALPHANUMS = False + SUPPORTS_UNIQUE_CONSTAINT = True + CONNECT_URI_HELP = "mysql://:@/" + CONNECT_URI_PARAMS = ["database?"] + + def __init__(self, *, thread_count, **kw): + self._args = kw + + super().__init__(thread_count=thread_count) + + # In MySQL schema and database are synonymous + try: + self.default_schema = kw["database"] + except KeyError: + raise ValueError("MySQL URL must specify a database") + + def create_connection(self): + mysql = import_mysql() + try: + return mysql.connect(charset="utf8", use_unicode=True, **self._args) + except mysql.Error as e: + if e.errno == mysql.errorcode.ER_ACCESS_DENIED_ERROR: + raise ConnectError("Bad user name or password") from e + elif e.errno == mysql.errorcode.ER_BAD_DB_ERROR: + raise ConnectError("Database does not exist") from e + raise ConnectError(*e.args) from e diff --git a/data_diff/databases/oracle.py b/data_diff/databases/oracle.py index 3ee4a872..f0309c11 100644 --- a/data_diff/databases/oracle.py +++ b/data_diff/databases/oracle.py @@ -1,10 +1,206 @@ -from data_diff.sqeleton.databases import oracle -from data_diff.databases.base import DatadiffDialect +from typing import Dict, List, Optional +from data_diff.utils import match_regexps +from data_diff.abcs.database_types import ( + Decimal, + Float, + Text, + DbPath, + TemporalType, + ColType, + DbTime, + ColType_UUID, + Timestamp, + TimestampTZ, + FractionalType, +) +from data_diff.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue, AbstractMixin_Schema +from data_diff.abcs.compiler import Compilable +from data_diff.queries.api import this, table, SKIP +from data_diff.databases.base import ( + BaseDialect, + Mixin_OptimizerHints, + ThreadedDatabase, + import_helper, + ConnectError, + QueryError, + Mixin_RandomSample, +) +from data_diff.databases.base import TIMESTAMP_PRECISION_POS -class Dialect(oracle.Dialect, oracle.Mixin_MD5, oracle.Mixin_NormalizeValue, DatadiffDialect): - pass +SESSION_TIME_ZONE = None # Changed by the tests -class Oracle(oracle.Oracle): +@import_helper("oracle") +def import_oracle(): + import oracledb + + return oracledb + + +class Mixin_MD5(AbstractMixin_MD5): + def md5_as_int(self, s: str) -> str: + # standard_hash is faster than DBMS_CRYPTO.Hash + # TODO: Find a way to use UTL_RAW.CAST_TO_BINARY_INTEGER ? 
+ return f"to_number(substr(standard_hash({s}, 'MD5'), 18), 'xxxxxxxxxxxxxxx')" + + +class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): + def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: + # Cast is necessary for correct MD5 (trimming not enough) + return f"CAST(TRIM({value}) AS VARCHAR(36))" + + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: + if coltype.rounds: + return f"to_char(cast({value} as timestamp({coltype.precision})), 'YYYY-MM-DD HH24:MI:SS.FF6')" + + if coltype.precision > 0: + truncated = f"to_char({value}, 'YYYY-MM-DD HH24:MI:SS.FF{coltype.precision}')" + else: + truncated = f"to_char({value}, 'YYYY-MM-DD HH24:MI:SS.')" + return f"RPAD({truncated}, {TIMESTAMP_PRECISION_POS+6}, '0')" + + def normalize_number(self, value: str, coltype: FractionalType) -> str: + # FM999.9990 + format_str = "FM" + "9" * (38 - coltype.precision) + if coltype.precision: + format_str += "0." + "9" * (coltype.precision - 1) + "0" + return f"to_char({value}, '{format_str}')" + + +class Mixin_Schema(AbstractMixin_Schema): + def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: + return ( + table("ALL_TABLES") + .where( + this.OWNER == table_schema, + this.TABLE_NAME.like(like) if like is not None else SKIP, + ) + .select(table_name=this.TABLE_NAME) + ) + + +class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): + name = "Oracle" + SUPPORTS_PRIMARY_KEY = True + SUPPORTS_INDEXES = True + TYPE_CLASSES: Dict[str, type] = { + "NUMBER": Decimal, + "FLOAT": Float, + # Text + "CHAR": Text, + "NCHAR": Text, + "NVARCHAR2": Text, + "VARCHAR2": Text, + "DATE": Timestamp, + } + ROUNDS_ON_PREC_LOSS = True + PLACEHOLDER_TABLE = "DUAL" + MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample} + + def quote(self, s: str): + return f'"{s}"' + + def to_string(self, s: str): + return f"cast({s} as varchar(1024))" + + def offset_limit( + self, offset: Optional[int] = None, limit: Optional[int] = None, has_order_by: Optional[bool] = None + ) -> str: + if offset: + raise NotImplementedError("No support for OFFSET in query") + + return f"FETCH NEXT {limit} ROWS ONLY" + + def concat(self, items: List[str]) -> str: + joined_exprs = " || ".join(items) + return f"({joined_exprs})" + + def timestamp_value(self, t: DbTime) -> str: + return "timestamp '%s'" % t.isoformat(" ") + + def random(self) -> str: + return "dbms_random.value" + + def is_distinct_from(self, a: str, b: str) -> str: + return f"DECODE({a}, {b}, 1, 0) = 0" + + def type_repr(self, t) -> str: + try: + return { + str: "VARCHAR(1024)", + }[t] + except KeyError: + return super().type_repr(t) + + def constant_values(self, rows) -> str: + return " UNION ALL ".join( + "SELECT %s FROM DUAL" % ", ".join(self._constant_value(v) for v in row) for row in rows + ) + + def explain_as_text(self, query: str) -> str: + raise NotImplementedError("Explain not yet implemented in Oracle") + + def parse_type( + self, + table_path: DbPath, + col_name: str, + type_repr: str, + datetime_precision: int = None, + numeric_precision: int = None, + numeric_scale: int = None, + ) -> ColType: + regexps = { + r"TIMESTAMP\((\d)\) WITH LOCAL TIME ZONE": Timestamp, + r"TIMESTAMP\((\d)\) WITH TIME ZONE": TimestampTZ, + r"TIMESTAMP\((\d)\)": Timestamp, + } + + for m, t_cls in match_regexps(regexps, type_repr): + precision = int(m.group(1)) + return t_cls(precision=precision, rounds=self.ROUNDS_ON_PREC_LOSS) + + 
return super().parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision, numeric_scale) + + def set_timezone_to_utc(self) -> str: + return "ALTER SESSION SET TIME_ZONE = 'UTC'" + + def current_timestamp(self) -> str: + return "LOCALTIMESTAMP" + + +class Oracle(ThreadedDatabase): dialect = Dialect() + CONNECT_URI_HELP = "oracle://:@/" + CONNECT_URI_PARAMS = ["database?"] + + def __init__(self, *, host, database, thread_count, **kw): + self.kwargs = dict(dsn=f"{host}/{database}" if database else host, **kw) + + self.default_schema = kw.get("user").upper() + + super().__init__(thread_count=thread_count) + + def create_connection(self): + self._oracle = import_oracle() + try: + c = self._oracle.connect(**self.kwargs) + if SESSION_TIME_ZONE: + c.cursor().execute(f"ALTER SESSION SET TIME_ZONE = '{SESSION_TIME_ZONE}'") + return c + except Exception as e: + raise ConnectError(*e.args) from e + + def _query_cursor(self, c, sql_code: str): + try: + return super()._query_cursor(c, sql_code) + except self._oracle.DatabaseError as e: + raise QueryError(e) + + def select_table_schema(self, path: DbPath) -> str: + schema, name = self._normalize_table_path(path) + + return ( + f"SELECT column_name, data_type, 6 as datetime_precision, data_precision as numeric_precision, data_scale as numeric_scale" + f" FROM ALL_TAB_COLUMNS WHERE table_name = '{name}' AND owner = '{schema}'" + ) diff --git a/data_diff/databases/postgresql.py b/data_diff/databases/postgresql.py index b63f050a..dec9b9d3 100644 --- a/data_diff/databases/postgresql.py +++ b/data_diff/databases/postgresql.py @@ -1,10 +1,182 @@ -from data_diff.sqeleton.databases import postgresql as pg -from data_diff.databases.base import DatadiffDialect +from typing import List +from data_diff.abcs.database_types import ( + DbPath, + JSON, + Timestamp, + TimestampTZ, + Float, + Decimal, + Integer, + TemporalType, + Native_UUID, + Text, + FractionalType, + Boolean, + Date, +) +from data_diff.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue +from data_diff.databases.base import BaseDialect, ThreadedDatabase, import_helper, ConnectError, Mixin_Schema +from data_diff.databases.base import ( + MD5_HEXDIGITS, + CHECKSUM_HEXDIGITS, + _CHECKSUM_BITSIZE, + TIMESTAMP_PRECISION_POS, + Mixin_RandomSample, +) +SESSION_TIME_ZONE = None # Changed by the tests -class PostgresqlDialect(pg.PostgresqlDialect, pg.Mixin_MD5, pg.Mixin_NormalizeValue, DatadiffDialect): - pass +@import_helper("postgresql") +def import_postgresql(): + import psycopg2.extras -class PostgreSQL(pg.PostgreSQL): + psycopg2.extensions.set_wait_callback(psycopg2.extras.wait_select) + return psycopg2 + + +class Mixin_MD5(AbstractMixin_MD5): + def md5_as_int(self, s: str) -> str: + return f"('x' || substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}))::bit({_CHECKSUM_BITSIZE})::bigint" + + +class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: + if coltype.rounds: + return f"to_char({value}::timestamp({coltype.precision}), 'YYYY-mm-dd HH24:MI:SS.US')" + + timestamp6 = f"to_char({value}::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')" + return ( + f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" + ) + + def normalize_number(self, value: str, coltype: FractionalType) -> str: + return self.to_string(f"{value}::decimal(38, {coltype.precision})") + + def normalize_boolean(self, value: str, _coltype: Boolean) -> str: + return 
self.to_string(f"{value}::int") + + def normalize_json(self, value: str, _coltype: JSON) -> str: + return f"{value}::text" + + +class PostgresqlDialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): + name = "PostgreSQL" + ROUNDS_ON_PREC_LOSS = True + SUPPORTS_PRIMARY_KEY = True + SUPPORTS_INDEXES = True + MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample} + + TYPE_CLASSES = { + # Timestamps + "timestamp with time zone": TimestampTZ, + "timestamp without time zone": Timestamp, + "timestamp": Timestamp, + "date": Date, + # Numbers + "double precision": Float, + "real": Float, + "decimal": Decimal, + "smallint": Integer, + "integer": Integer, + "numeric": Decimal, + "bigint": Integer, + # Text + "character": Text, + "character varying": Text, + "varchar": Text, + "text": Text, + "json": JSON, + "jsonb": JSON, + "uuid": Native_UUID, + "boolean": Boolean, + } + + def quote(self, s: str): + return f'"{s}"' + + def to_string(self, s: str): + return f"{s}::varchar" + + def concat(self, items: List[str]) -> str: + joined_exprs = " || ".join(items) + return f"({joined_exprs})" + + def _convert_db_precision_to_digits(self, p: int) -> int: + # Subtracting 2 due to wierd precision issues in PostgreSQL + return super()._convert_db_precision_to_digits(p) - 2 + + def set_timezone_to_utc(self) -> str: + return "SET TIME ZONE 'UTC'" + + def current_timestamp(self) -> str: + return "current_timestamp" + + def type_repr(self, t) -> str: + if isinstance(t, TimestampTZ): + return f"timestamp ({t.precision}) with time zone" + return super().type_repr(t) + + +class PostgreSQL(ThreadedDatabase): dialect = PostgresqlDialect() + SUPPORTS_UNIQUE_CONSTAINT = True + CONNECT_URI_HELP = "postgresql://:@/" + CONNECT_URI_PARAMS = ["database?"] + + default_schema = "public" + + def __init__(self, *, thread_count, **kw): + self._args = kw + + super().__init__(thread_count=thread_count) + + def create_connection(self): + if not self._args: + self._args["host"] = None # psycopg2 requires 1+ arguments + + pg = import_postgresql() + try: + c = pg.connect(**self._args) + if SESSION_TIME_ZONE: + c.cursor().execute(f"SET TIME ZONE '{SESSION_TIME_ZONE}'") + return c + except pg.OperationalError as e: + raise ConnectError(*e.args) from e + + def select_table_schema(self, path: DbPath) -> str: + database, schema, table = self._normalize_table_path(path) + + info_schema_path = ["information_schema", "columns"] + if database: + info_schema_path.insert(0, database) + + return ( + f"SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale FROM {'.'.join(info_schema_path)} " + f"WHERE table_name = '{table}' AND table_schema = '{schema}'" + ) + + def select_table_unique_columns(self, path: DbPath) -> str: + database, schema, table = self._normalize_table_path(path) + + info_schema_path = ["information_schema", "key_column_usage"] + if database: + info_schema_path.insert(0, database) + + return ( + "SELECT column_name " + f"FROM {'.'.join(info_schema_path)} " + f"WHERE table_name = '{table}' AND table_schema = '{schema}'" + ) + + def _normalize_table_path(self, path: DbPath) -> DbPath: + if len(path) == 1: + return None, self.default_schema, path[0] + elif len(path) == 2: + return None, path[0], path[1] + elif len(path) == 3: + return path + + raise ValueError( + f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. 
Expected format: table, schema.table, or database.schema.table" + ) diff --git a/data_diff/databases/presto.py b/data_diff/databases/presto.py index 4ac86b3f..b4c45751 100644 --- a/data_diff/databases/presto.py +++ b/data_diff/databases/presto.py @@ -1,10 +1,202 @@ -from data_diff.sqeleton.databases import presto -from data_diff.databases.base import DatadiffDialect +from functools import partial +import re +from data_diff.utils import match_regexps -class Dialect(presto.Dialect, presto.Mixin_MD5, presto.Mixin_NormalizeValue, DatadiffDialect): - pass +from data_diff.abcs.database_types import ( + Timestamp, + TimestampTZ, + Integer, + Float, + Text, + FractionalType, + DbPath, + DbTime, + Decimal, + ColType, + ColType_UUID, + TemporalType, + Boolean, +) +from data_diff.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue +from data_diff.databases.base import ( + BaseDialect, + Database, + import_helper, + ThreadLocalInterpreter, + Mixin_Schema, + Mixin_RandomSample, +) +from data_diff.databases.base import ( + MD5_HEXDIGITS, + CHECKSUM_HEXDIGITS, + TIMESTAMP_PRECISION_POS, +) -class Presto(presto.Presto): +def query_cursor(c, sql_code): + c.execute(sql_code) + if sql_code.lower().startswith("select"): + return c.fetchall() + # Required for the query to actually run 🤯 + if re.match(r"(insert|create|truncate|drop|explain)", sql_code, re.IGNORECASE): + return c.fetchone() + + +@import_helper("presto") +def import_presto(): + import prestodb + + return prestodb + + +class Mixin_MD5(AbstractMixin_MD5): + def md5_as_int(self, s: str) -> str: + return f"cast(from_base(substr(to_hex(md5(to_utf8({s}))), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16) as decimal(38, 0))" + + +class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): + def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: + # Trim doesn't work on CHAR type + return f"TRIM(CAST({value} AS VARCHAR))" + + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: + # TODO rounds + if coltype.rounds: + s = f"date_format(cast({value} as timestamp(6)), '%Y-%m-%d %H:%i:%S.%f')" + else: + s = f"date_format(cast({value} as timestamp(6)), '%Y-%m-%d %H:%i:%S.%f')" + + return f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS+coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS+6}, '0')" + + def normalize_number(self, value: str, coltype: FractionalType) -> str: + return self.to_string(f"cast({value} as decimal(38,{coltype.precision}))") + + def normalize_boolean(self, value: str, _coltype: Boolean) -> str: + return self.to_string(f"cast ({value} as int)") + + +class Dialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): + name = "Presto" + ROUNDS_ON_PREC_LOSS = True + TYPE_CLASSES = { + # Timestamps + "timestamp with time zone": TimestampTZ, + "timestamp without time zone": Timestamp, + "timestamp": Timestamp, + # Numbers + "integer": Integer, + "bigint": Integer, + "real": Float, + "double": Float, + # Text + "varchar": Text, + # Boolean + "boolean": Boolean, + } + MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample} + + def explain_as_text(self, query: str) -> str: + return f"EXPLAIN (FORMAT TEXT) {query}" + + def type_repr(self, t) -> str: + if isinstance(t, TimestampTZ): + return f"timestamp with time zone" + + try: + return {float: "REAL"}[t] + except KeyError: + return super().type_repr(t) + + def timestamp_value(self, t: DbTime) -> str: + return f"timestamp '{t.isoformat(' ')}'" + + def quote(self, s: str): + return f'"{s}"' + 
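+    # A minimal sketch of how md5_as_int above composes, assuming the default
+    # 32/15 hex-digit checksum constants (so the substr offset is 18) and an
+    # illustrative column expression:
+    #
+    #   d = Dialect()
+    #   d.md5_as_int("col_a")
+    #   # -> "cast(from_base(substr(to_hex(md5(to_utf8(col_a))), 18), 16) as decimal(38, 0))"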
+ def to_string(self, s: str): + return f"cast({s} as varchar)" + + def parse_type( + self, + table_path: DbPath, + col_name: str, + type_repr: str, + datetime_precision: int = None, + numeric_precision: int = None, + _numeric_scale: int = None, + ) -> ColType: + timestamp_regexps = { + r"timestamp\((\d)\)": Timestamp, + r"timestamp\((\d)\) with time zone": TimestampTZ, + } + for m, t_cls in match_regexps(timestamp_regexps, type_repr): + precision = int(m.group(1)) + return t_cls(precision=precision, rounds=self.ROUNDS_ON_PREC_LOSS) + + number_regexps = {r"decimal\((\d+),(\d+)\)": Decimal} + for m, n_cls in match_regexps(number_regexps, type_repr): + _prec, scale = map(int, m.groups()) + return n_cls(scale) + + string_regexps = {r"varchar\((\d+)\)": Text, r"char\((\d+)\)": Text} + for m, n_cls in match_regexps(string_regexps, type_repr): + return n_cls() + + return super().parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision) + + def set_timezone_to_utc(self) -> str: + return "SET TIME ZONE '+00:00'" + + def current_timestamp(self) -> str: + return "current_timestamp" + + +class Presto(Database): dialect = Dialect() + CONNECT_URI_HELP = "presto://@//" + CONNECT_URI_PARAMS = ["catalog", "schema"] + + default_schema = "public" + + def __init__(self, **kw): + prestodb = import_presto() + + if kw.get("schema"): + self.default_schema = kw.get("schema") + + if kw.get("auth") == "basic": # if auth=basic, add basic authenticator for Presto + kw["auth"] = prestodb.auth.BasicAuthentication(kw.pop("user"), kw.pop("password")) + + if "cert" in kw: # if a certificate was specified in URI, verify session with cert + cert = kw.pop("cert") + self._conn = prestodb.dbapi.connect(**kw) + self._conn._http_session.verify = cert + else: + self._conn = prestodb.dbapi.connect(**kw) + + def _query(self, sql_code: str) -> list: + "Uses the standard SQL cursor interface" + c = self._conn.cursor() + + if isinstance(sql_code, ThreadLocalInterpreter): + return sql_code.apply_queries(partial(query_cursor, c)) + + return query_cursor(c, sql_code) + + def close(self): + super().close() + self._conn.close() + + def select_table_schema(self, path: DbPath) -> str: + schema, table = self._normalize_table_path(path) + + return ( + "SELECT column_name, data_type, 3 as datetime_precision, 3 as numeric_precision, NULL as numeric_scale " + "FROM INFORMATION_SCHEMA.COLUMNS " + f"WHERE table_name = '{table}' AND table_schema = '{schema}'" + ) + + @property + def is_autocommit(self) -> bool: + return False diff --git a/data_diff/databases/redshift.py b/data_diff/databases/redshift.py index e6eb3b20..d11029c0 100644 --- a/data_diff/databases/redshift.py +++ b/data_diff/databases/redshift.py @@ -1,10 +1,176 @@ -from data_diff.sqeleton.databases import redshift -from data_diff.databases.base import DatadiffDialect +from typing import List, Dict +from data_diff.abcs.database_types import ( + Float, + JSON, + TemporalType, + FractionalType, + DbPath, + TimestampTZ, +) +from data_diff.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue +from data_diff.databases.postgresql import ( + PostgreSQL, + MD5_HEXDIGITS, + CHECKSUM_HEXDIGITS, + TIMESTAMP_PRECISION_POS, + PostgresqlDialect, + Mixin_NormalizeValue, +) -class Dialect(redshift.Dialect, redshift.Mixin_MD5, redshift.Mixin_NormalizeValue, DatadiffDialect): - pass +class Mixin_MD5(AbstractMixin_MD5): + def md5_as_int(self, s: str) -> str: + return f"strtol(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16)::decimal(38)" -class 
Redshift(redshift.Redshift): +class Mixin_NormalizeValue(Mixin_NormalizeValue): + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: + if coltype.rounds: + timestamp = f"{value}::timestamp(6)" + # Get seconds since epoch. Redshift doesn't support milli- or micro-seconds. + secs = f"timestamp 'epoch' + round(extract(epoch from {timestamp})::decimal(38)" + # Get the milliseconds from timestamp. + ms = f"extract(ms from {timestamp})" + # Get the microseconds from timestamp, without the milliseconds! + us = f"extract(us from {timestamp})" + # epoch = Total time since epoch in microseconds. + epoch = f"{secs}*1000000 + {ms}*1000 + {us}" + timestamp6 = ( + f"to_char({epoch}, -6+{coltype.precision}) * interval '0.000001 seconds', 'YYYY-mm-dd HH24:MI:SS.US')" + ) + else: + timestamp6 = f"to_char({value}::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')" + return ( + f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" + ) + + def normalize_number(self, value: str, coltype: FractionalType) -> str: + return self.to_string(f"{value}::decimal(38,{coltype.precision})") + + def normalize_json(self, value: str, _coltype: JSON) -> str: + return f"nvl2({value}, json_serialize({value}), NULL)" + + +class Dialect(PostgresqlDialect, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): + name = "Redshift" + TYPE_CLASSES = { + **PostgresqlDialect.TYPE_CLASSES, + "double": Float, + "real": Float, + "super": JSON, + } + SUPPORTS_INDEXES = False + + def concat(self, items: List[str]) -> str: + joined_exprs = " || ".join(items) + return f"({joined_exprs})" + + def is_distinct_from(self, a: str, b: str) -> str: + return f"({a} IS NULL != {b} IS NULL) OR ({a}!={b})" + + def type_repr(self, t) -> str: + if isinstance(t, TimestampTZ): + return f"timestamptz" + return super().type_repr(t) + + +class Redshift(PostgreSQL): dialect = Dialect() + CONNECT_URI_HELP = "redshift://:@/" + CONNECT_URI_PARAMS = ["database?"] + + def select_table_schema(self, path: DbPath) -> str: + database, schema, table = self._normalize_table_path(path) + + info_schema_path = ["information_schema", "columns"] + if database: + info_schema_path.insert(0, database) + + return ( + f"SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale FROM {'.'.join(info_schema_path)} " + f"WHERE table_name = '{table.lower()}' AND table_schema = '{schema.lower()}'" + ) + + def select_external_table_schema(self, path: DbPath) -> str: + database, schema, table = self._normalize_table_path(path) + + db_clause = "" + if database: + db_clause = f" AND redshift_database_name = '{database.lower()}'" + + return ( + f"""SELECT + columnname AS column_name + , CASE WHEN external_type = 'string' THEN 'varchar' ELSE external_type END AS data_type + , NULL AS datetime_precision + , NULL AS numeric_precision + , NULL AS numeric_scale + FROM svv_external_columns + WHERE tablename = '{table.lower()}' AND schemaname = '{schema.lower()}' + """ + + db_clause + ) + + def query_external_table_schema(self, path: DbPath) -> Dict[str, tuple]: + rows = self.query(self.select_external_table_schema(path), list) + if not rows: + raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns") + + d = {r[0]: r for r in rows} + assert len(d) == len(rows) + return d + + def select_view_columns(self, path: DbPath) -> str: + _, schema, table = self._normalize_table_path(path) + + return """select * from pg_get_cols('{}.{}') + 
cols(view_schema name, view_name name, col_name name, col_type varchar, col_num int) + """.format( + schema, table + ) + + def query_pg_get_cols(self, path: DbPath) -> Dict[str, tuple]: + rows = self.query(self.select_view_columns(path), list) + + if not rows: + raise RuntimeError(f"{self.name}: View '{'.'.join(path)}' does not exist, or has no columns") + + output = {} + for r in rows: + col_name = r[2] + type_info = r[3].split("(") + base_type = type_info[0] + precision = None + scale = None + + if len(type_info) > 1: + if base_type == "numeric": + precision, scale = type_info[1][:-1].split(",") + precision = int(precision) + scale = int(scale) + + out = [col_name, base_type, None, precision, scale] + output[col_name] = tuple(out) + + return output + + def query_table_schema(self, path: DbPath) -> Dict[str, tuple]: + try: + return super().query_table_schema(path) + except RuntimeError: + try: + return self.query_external_table_schema(path) + except RuntimeError: + return self.query_pg_get_cols(path) + + def _normalize_table_path(self, path: DbPath) -> DbPath: + if len(path) == 1: + return None, self.default_schema, path[0] + elif len(path) == 2: + return None, path[0], path[1] + elif len(path) == 3: + return path + + raise ValueError( + f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected format: table, schema.table, or database.schema.table" + ) diff --git a/data_diff/databases/snowflake.py b/data_diff/databases/snowflake.py index 7dd8539f..3a558425 100644 --- a/data_diff/databases/snowflake.py +++ b/data_diff/databases/snowflake.py @@ -1,10 +1,228 @@ -from data_diff.sqeleton.databases import snowflake -from data_diff.databases.base import DatadiffDialect +from typing import Union, List +import logging +from data_diff.abcs.database_types import ( + Timestamp, + TimestampTZ, + Decimal, + Float, + Text, + FractionalType, + TemporalType, + DbPath, + Boolean, + Date, +) +from data_diff.abcs.mixins import ( + AbstractMixin_MD5, + AbstractMixin_NormalizeValue, + AbstractMixin_Schema, + AbstractMixin_TimeTravel, +) +from data_diff.abcs.compiler import Compilable +from data_diff.queries.api import table, this, SKIP, code +from data_diff.databases.base import ( + BaseDialect, + ConnectError, + Database, + import_helper, + CHECKSUM_MASK, + ThreadLocalInterpreter, + Mixin_RandomSample, +) -class Dialect(snowflake.Dialect, snowflake.Mixin_MD5, snowflake.Mixin_NormalizeValue, DatadiffDialect): - pass +@import_helper("snowflake") +def import_snowflake(): + import snowflake.connector + from cryptography.hazmat.primitives import serialization + from cryptography.hazmat.backends import default_backend -class Snowflake(snowflake.Snowflake): + return snowflake, serialization, default_backend + + +class Mixin_MD5(AbstractMixin_MD5): + def md5_as_int(self, s: str) -> str: + return f"BITAND(md5_number_lower64({s}), {CHECKSUM_MASK})" + + +class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: + if coltype.rounds: + timestamp = f"to_timestamp(round(date_part(epoch_nanosecond, convert_timezone('UTC', {value})::timestamp(9))/1000000000, {coltype.precision}))" + else: + timestamp = f"cast(convert_timezone('UTC', {value}) as timestamp({coltype.precision}))" + + return f"to_char({timestamp}, 'YYYY-MM-DD HH24:MI:SS.FF6')" + + def normalize_number(self, value: str, coltype: FractionalType) -> str: + return self.to_string(f"cast({value} as decimal(38, {coltype.precision}))") + + def normalize_boolean(self, value: str, 
_coltype: Boolean) -> str: + return self.to_string(f"{value}::int") + + +class Mixin_Schema(AbstractMixin_Schema): + def table_information(self) -> Compilable: + return table("INFORMATION_SCHEMA", "TABLES") + + def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: + return ( + self.table_information() + .where( + this.TABLE_SCHEMA == table_schema, + this.TABLE_NAME.like(like) if like is not None else SKIP, + this.TABLE_TYPE == "BASE TABLE", + ) + .select(table_name=this.TABLE_NAME) + ) + + +class Mixin_TimeTravel(AbstractMixin_TimeTravel): + def time_travel( + self, + table: Compilable, + before: bool = False, + timestamp: Compilable = None, + offset: Compilable = None, + statement: Compilable = None, + ) -> Compilable: + at_or_before = "BEFORE" if before else "AT" + if timestamp is not None: + assert offset is None and statement is None + key = "timestamp" + value = timestamp + elif offset is not None: + assert statement is None + key = "offset" + value = offset + else: + assert statement is not None + key = "statement" + value = statement + + return code(f"{{table}} {at_or_before}({key} => {{value}})", table=table, value=value) + + +class Dialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): + name = "Snowflake" + ROUNDS_ON_PREC_LOSS = False + TYPE_CLASSES = { + # Timestamps + "TIMESTAMP_NTZ": Timestamp, + "TIMESTAMP_LTZ": Timestamp, + "TIMESTAMP_TZ": TimestampTZ, + "DATE": Date, + # Numbers + "NUMBER": Decimal, + "FLOAT": Float, + # Text + "TEXT": Text, + # Boolean + "BOOLEAN": Boolean, + } + MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_TimeTravel, Mixin_RandomSample} + + def explain_as_text(self, query: str) -> str: + return f"EXPLAIN USING TEXT {query}" + + def quote(self, s: str): + return f'"{s}"' + + def to_string(self, s: str): + return f"cast({s} as string)" + + def table_information(self) -> Compilable: + return table("INFORMATION_SCHEMA", "TABLES") + + def set_timezone_to_utc(self) -> str: + return "ALTER SESSION SET TIMEZONE = 'UTC'" + + def optimizer_hints(self, hints: str) -> str: + raise NotImplementedError("Optimizer hints not yet implemented in Snowflake") + + def type_repr(self, t) -> str: + if isinstance(t, TimestampTZ): + return f"timestamp_tz({t.precision})" + return super().type_repr(t) + + +class Snowflake(Database): dialect = Dialect() + CONNECT_URI_HELP = "snowflake://:@//?warehouse=" + CONNECT_URI_PARAMS = ["database", "schema"] + CONNECT_URI_KWPARAMS = ["warehouse"] + + def __init__(self, *, schema: str, **kw): + snowflake, serialization, default_backend = import_snowflake() + logging.getLogger("snowflake.connector").setLevel(logging.WARNING) + + # Ignore the error: snowflake.connector.network.RetryRequest: could not find io module state + # It's a known issue: https://github.com/snowflakedb/snowflake-connector-python/issues/145 + logging.getLogger("snowflake.connector.network").disabled = True + + assert '"' not in schema, "Schema name should not contain quotes!" + # If a private key is used, read it from the specified path and pass it as "private_key" to the connector. 
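+        # The branch below reads a PEM-encoded key file, decrypts it with the
+        # optional "private_key_passphrase", and re-serializes it to DER/PKCS8,
+        # which is the form snowflake.connector accepts for "private_key".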
+ if "key" in kw: + with open(kw.get("key"), "rb") as key: + if "password" in kw: + raise ConnectError("Cannot use password and key at the same time") + if kw.get("private_key_passphrase"): + encoded_passphrase = kw.get("private_key_passphrase").encode() + else: + encoded_passphrase = None + p_key = serialization.load_pem_private_key( + key.read(), + password=encoded_passphrase, + backend=default_backend(), + ) + + kw["private_key"] = p_key.private_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + + self._conn = snowflake.connector.connect(schema=f'"{schema}"', **kw) + + self.default_schema = schema + + def close(self): + super().close() + self._conn.close() + + def _query(self, sql_code: Union[str, ThreadLocalInterpreter]): + "Uses the standard SQL cursor interface" + return self._query_conn(self._conn, sql_code) + + def select_table_schema(self, path: DbPath) -> str: + """Provide SQL for selecting the table schema as (name, type, date_prec, num_prec)""" + database, schema, name = self._normalize_table_path(path) + info_schema_path = ["information_schema", "columns"] + if database: + info_schema_path.insert(0, database) + + return ( + "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale " + f"FROM {'.'.join(info_schema_path)} " + f"WHERE table_name = '{name}' AND table_schema = '{schema}'" + ) + + def _normalize_table_path(self, path: DbPath) -> DbPath: + if len(path) == 1: + return None, self.default_schema, path[0] + elif len(path) == 2: + return None, path[0], path[1] + elif len(path) == 3: + return path + + raise ValueError( + f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected format: table, schema.table, or database.schema.table" + ) + + @property + def is_autocommit(self) -> bool: + return True + + def query_table_unique_columns(self, path: DbPath) -> List[str]: + return [] diff --git a/data_diff/databases/trino.py b/data_diff/databases/trino.py index a39be906..e2095758 100644 --- a/data_diff/databases/trino.py +++ b/data_diff/databases/trino.py @@ -1,10 +1,48 @@ -from data_diff.sqeleton.databases import trino -from data_diff.databases.base import DatadiffDialect +from data_diff.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue +from data_diff.abcs.database_types import TemporalType, ColType_UUID +from data_diff.databases import presto +from data_diff.databases.base import import_helper +from data_diff.databases.base import TIMESTAMP_PRECISION_POS -class Dialect(trino.Dialect, trino.Mixin_MD5, trino.Mixin_NormalizeValue, DatadiffDialect): - pass +@import_helper("trino") +def import_trino(): + import trino + return trino -class Trino(trino.Trino): + +Mixin_MD5 = presto.Mixin_MD5 + + +class Mixin_NormalizeValue(presto.Mixin_NormalizeValue): + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: + if coltype.rounds: + s = f"date_format(cast({value} as timestamp({coltype.precision})), '%Y-%m-%d %H:%i:%S.%f')" + else: + s = f"date_format(cast({value} as timestamp(6)), '%Y-%m-%d %H:%i:%S.%f')" + + return ( + f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS + coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS + 6}, '0')" + ) + + def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: + return f"TRIM({value})" + + +class Dialect(presto.Dialect, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): + name = "Trino" + + +class Trino(presto.Presto): dialect = Dialect() + 
CONNECT_URI_HELP = "trino://@//" + CONNECT_URI_PARAMS = ["catalog", "schema"] + + def __init__(self, **kw): + trino = import_trino() + + if kw.get("schema"): + self.default_schema = kw.get("schema") + + self._conn = trino.dbapi.connect(**kw) diff --git a/data_diff/databases/vertica.py b/data_diff/databases/vertica.py index 60812a49..e8fe9ec2 100644 --- a/data_diff/databases/vertica.py +++ b/data_diff/databases/vertica.py @@ -1,10 +1,181 @@ -from data_diff.sqeleton.databases import vertica -from data_diff.databases.base import DatadiffDialect +from typing import List +from data_diff.utils import match_regexps +from data_diff.databases.base import ( + CHECKSUM_HEXDIGITS, + MD5_HEXDIGITS, + TIMESTAMP_PRECISION_POS, + BaseDialect, + ConnectError, + DbPath, + ColType, + ThreadedDatabase, + import_helper, + Mixin_RandomSample, +) +from data_diff.abcs.database_types import ( + Decimal, + Float, + FractionalType, + Integer, + TemporalType, + Text, + Timestamp, + TimestampTZ, + Boolean, + ColType_UUID, +) +from data_diff.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue, AbstractMixin_Schema +from data_diff.abcs.compiler import Compilable +from data_diff.queries.api import table, this, SKIP -class Dialect(vertica.Dialect, vertica.Mixin_MD5, vertica.Mixin_NormalizeValue, DatadiffDialect): - pass +@import_helper("vertica") +def import_vertica(): + import vertica_python -class Vertica(vertica.Vertica): + return vertica_python + + +class Mixin_MD5(AbstractMixin_MD5): + def md5_as_int(self, s: str) -> str: + return f"CAST(HEX_TO_INTEGER(SUBSTRING(MD5({s}), {1 + MD5_HEXDIGITS - CHECKSUM_HEXDIGITS})) AS NUMERIC(38, 0))" + + +class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: + if coltype.rounds: + return f"TO_CHAR({value}::TIMESTAMP({coltype.precision}), 'YYYY-MM-DD HH24:MI:SS.US')" + + timestamp6 = f"TO_CHAR({value}::TIMESTAMP(6), 'YYYY-MM-DD HH24:MI:SS.US')" + return ( + f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" + ) + + def normalize_number(self, value: str, coltype: FractionalType) -> str: + return self.to_string(f"CAST({value} AS NUMERIC(38, {coltype.precision}))") + + def normalize_uuid(self, value: str, _coltype: ColType_UUID) -> str: + # Trim doesn't work on CHAR type + return f"TRIM(CAST({value} AS VARCHAR))" + + def normalize_boolean(self, value: str, _coltype: Boolean) -> str: + return self.to_string(f"cast ({value} as int)") + + +class Mixin_Schema(AbstractMixin_Schema): + def table_information(self) -> Compilable: + return table("v_catalog", "tables") + + def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: + return ( + self.table_information() + .where( + this.table_schema == table_schema, + this.table_name.like(like) if like is not None else SKIP, + ) + .select(this.table_name) + ) + + +class Dialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): + name = "Vertica" + ROUNDS_ON_PREC_LOSS = True + + TYPE_CLASSES = { + # Timestamps + "timestamp": Timestamp, + "timestamptz": TimestampTZ, + # Numbers + "numeric": Decimal, + "int": Integer, + "float": Float, + # Text + "char": Text, + "varchar": Text, + # Boolean + "boolean": Boolean, + } + MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample} + + def quote(self, s: str): + return f'"{s}"' + + def concat(self, items: List[str]) -> str: + return " || ".join(items) + + def 
to_string(self, s: str) -> str: + return f"CAST({s} AS VARCHAR)" + + def is_distinct_from(self, a: str, b: str) -> str: + return f"not ({a} <=> {b})" + + def parse_type( + self, + table_path: DbPath, + col_name: str, + type_repr: str, + datetime_precision: int = None, + numeric_precision: int = None, + numeric_scale: int = None, + ) -> ColType: + timestamp_regexps = { + r"timestamp\(?(\d?)\)?": Timestamp, + r"timestamptz\(?(\d?)\)?": TimestampTZ, + } + for m, t_cls in match_regexps(timestamp_regexps, type_repr): + precision = int(m.group(1)) if m.group(1) else 6 + return t_cls(precision=precision, rounds=self.ROUNDS_ON_PREC_LOSS) + + number_regexps = { + r"numeric\((\d+),(\d+)\)": Decimal, + } + for m, n_cls in match_regexps(number_regexps, type_repr): + _prec, scale = map(int, m.groups()) + return n_cls(scale) + + string_regexps = { + r"varchar\((\d+)\)": Text, + r"char\((\d+)\)": Text, + } + for m, n_cls in match_regexps(string_regexps, type_repr): + return n_cls() + + return super().parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision) + + def set_timezone_to_utc(self) -> str: + return "SET TIME ZONE TO 'UTC'" + + def current_timestamp(self) -> str: + return "current_timestamp(6)" + + +class Vertica(ThreadedDatabase): dialect = Dialect() + CONNECT_URI_HELP = "vertica://:@/" + CONNECT_URI_PARAMS = ["database?"] + + default_schema = "public" + + def __init__(self, *, thread_count, **kw): + self._args = kw + self._args["AUTOCOMMIT"] = False + + super().__init__(thread_count=thread_count) + + def create_connection(self): + vertica = import_vertica() + try: + c = vertica.connect(**self._args) + return c + except vertica.errors.ConnectionError as e: + raise ConnectError(*e.args) from e + + def select_table_schema(self, path: DbPath) -> str: + schema, name = self._normalize_table_path(path) + + return ( + "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale " + "FROM V_CATALOG.COLUMNS " + f"WHERE table_name = '{name}' AND table_schema = '{schema}'" + ) diff --git a/data_diff/diff_tables.py b/data_diff/diff_tables.py index 519018f6..08c18391 100644 --- a/data_diff/diff_tables.py +++ b/data_diff/diff_tables.py @@ -1,7 +1,6 @@ """Provides classes for performing a table diff """ -import re import time from abc import ABC, abstractmethod from dataclasses import field @@ -14,12 +13,11 @@ from runtype import dataclass from data_diff.info_tree import InfoTree, SegmentInfo - from data_diff.utils import dbt_diff_string_template, run_as_daemon, safezip, getLogger, truncate_error, Vector from data_diff.thread_utils import ThreadedYielder from data_diff.table_segment import TableSegment, create_mesh_from_points from data_diff.tracking import create_end_event_json, create_start_event_json, send_event_json, is_tracking_enabled -from data_diff.sqeleton.abcs import IKey +from data_diff.abcs.database_types import IKey logger = getLogger(__name__) diff --git a/data_diff/format.py b/data_diff/format.py index bfeb0b1e..8a515e1b 100644 --- a/data_diff/format.py +++ b/data_diff/format.py @@ -4,7 +4,7 @@ from runtype import dataclass from data_diff.diff_tables import DiffResultWrapper -from data_diff.sqeleton.abcs.database_types import ( +from data_diff.abcs.database_types import ( JSON, Boolean, ColType, diff --git a/data_diff/hashdiff_tables.py b/data_diff/hashdiff_tables.py index 65072ed4..3fc030ec 100644 --- a/data_diff/hashdiff_tables.py +++ b/data_diff/hashdiff_tables.py @@ -8,13 +8,11 @@ from runtype import dataclass -from data_diff.sqeleton.abcs 
import ColType_UUID, NumericType, PrecisionType, StringType, Boolean, JSON - +from data_diff.abcs.database_types import ColType_UUID, NumericType, PrecisionType, StringType, Boolean, JSON from data_diff.info_tree import InfoTree from data_diff.utils import safezip, diffs_are_equiv_jsons from data_diff.thread_utils import ThreadedYielder from data_diff.table_segment import TableSegment - from data_diff.diff_tables import TableDiffer BENCHMARK = os.environ.get("BENCHMARK", False) diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index 26ba1e0e..667786a7 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -10,14 +10,11 @@ from runtype import dataclass -from data_diff.sqeleton.databases import Database, MsSQL, MySQL, BigQuery, Presto, Oracle, Snowflake, DbPath -from data_diff.sqeleton.abcs import NumericType -from data_diff.sqeleton.queries import ( +from data_diff.databases import Database, MsSQL, MySQL, BigQuery, Presto, Oracle, Snowflake +from data_diff.abcs.database_types import NumericType, DbPath +from data_diff.queries.api import ( table, sum_, - min_, - max_, - avg, and_, if_, or_, @@ -28,11 +25,9 @@ when, Compiler, ) -from data_diff.sqeleton.queries.ast_classes import Concat, Count, Expr, Func, Random, TablePath, Code, ITable -from data_diff.sqeleton.queries.extras import NormalizeAsString - +from data_diff.queries.ast_classes import Concat, Count, Expr, Random, TablePath, Code, ITable +from data_diff.queries.extras import NormalizeAsString from data_diff.info_tree import InfoTree - from data_diff.query_utils import append_to_table, drop_table from data_diff.utils import safezip from data_diff.table_segment import TableSegment diff --git a/data_diff/queries/__init__.py b/data_diff/queries/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/data_diff/sqeleton/queries/api.py b/data_diff/queries/api.py similarity index 96% rename from data_diff/sqeleton/queries/api.py rename to data_diff/queries/api.py index 301cea32..82786871 100644 --- a/data_diff/sqeleton/queries/api.py +++ b/data_diff/queries/api.py @@ -1,8 +1,6 @@ -from typing import Optional - -from data_diff.sqeleton.utils import CaseAwareMapping, CaseSensitiveDict -from data_diff.sqeleton.queries.ast_classes import * -from data_diff.sqeleton.queries.base import args_as_tuple +from data_diff.utils import CaseAwareMapping, CaseSensitiveDict +from data_diff.queries.ast_classes import * +from data_diff.queries.base import args_as_tuple this = This() diff --git a/data_diff/sqeleton/queries/ast_classes.py b/data_diff/queries/ast_classes.py similarity index 98% rename from data_diff/sqeleton/queries/ast_classes.py rename to data_diff/queries/ast_classes.py index f3b04f73..70cb355f 100644 --- a/data_diff/sqeleton/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -1,18 +1,19 @@ from dataclasses import field from datetime import datetime -from typing import Any, Generator, List, Optional, Sequence, Type, Union, Dict +from typing import Any, Generator, List, Optional, Sequence, Union, Dict from runtype import dataclass from typing_extensions import Self -from data_diff.sqeleton.utils import join_iter, ArithString -from data_diff.sqeleton.abcs import Compilable -from data_diff.sqeleton.abcs.database_types import AbstractTable -from data_diff.sqeleton.abcs.mixins import AbstractMixin_Regex, AbstractMixin_TimeTravel -from data_diff.sqeleton.schema import Schema +from data_diff.utils import join_iter, ArithString +from data_diff.abcs.compiler import Compilable 
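+# A minimal sketch of the flattened import paths in use, with illustrative
+# schema/table/column names:
+#
+#   from data_diff.queries.api import table, this
+#   q = table("my_schema", "my_table").select(this.id, this.name)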
+from data_diff.abcs.database_types import AbstractTable +from data_diff.abcs.mixins import AbstractMixin_Regex, AbstractMixin_TimeTravel +from data_diff.schema import Schema -from data_diff.sqeleton.queries.compiler import Compiler, cv_params, Root, CompileError -from data_diff.sqeleton.queries.base import SKIP, DbPath, args_as_tuple, SqeletonError +from data_diff.queries.compiler import Compiler, cv_params, Root, CompileError +from data_diff.queries.base import SKIP, args_as_tuple, SqeletonError +from data_diff.abcs.database_types import DbPath class QueryBuilderError(SqeletonError): diff --git a/data_diff/sqeleton/queries/base.py b/data_diff/queries/base.py similarity index 76% rename from data_diff/sqeleton/queries/base.py rename to data_diff/queries/base.py index d229e175..205c2211 100644 --- a/data_diff/sqeleton/queries/base.py +++ b/data_diff/queries/base.py @@ -1,8 +1,5 @@ from typing import Generator -from data_diff.sqeleton.abcs import DbPath, DbKey -from data_diff.sqeleton.schema import Schema - class _SKIP: def __repr__(self): diff --git a/data_diff/sqeleton/queries/compiler.py b/data_diff/queries/compiler.py similarity index 91% rename from data_diff/sqeleton/queries/compiler.py rename to data_diff/queries/compiler.py index f9ab7484..e6246236 100644 --- a/data_diff/sqeleton/queries/compiler.py +++ b/data_diff/queries/compiler.py @@ -6,8 +6,9 @@ from runtype import dataclass from typing_extensions import Self -from data_diff.sqeleton.utils import ArithString -from data_diff.sqeleton.abcs import AbstractDatabase, AbstractDialect, DbPath, AbstractCompiler, Compilable +from data_diff.utils import ArithString +from data_diff.abcs.database_types import AbstractDatabase, AbstractDialect, DbPath +from data_diff.abcs.compiler import AbstractCompiler, Compilable import contextvars @@ -44,7 +45,7 @@ def compile(self, elem, params=None) -> str: cv_params.set(params) if self.root and isinstance(elem, Compilable) and not isinstance(elem, Root): - from data_diff.sqeleton.queries.ast_classes import Select + from data_diff.queries.ast_classes import Select elem = Select(columns=[elem]) diff --git a/data_diff/sqeleton/queries/extras.py b/data_diff/queries/extras.py similarity index 89% rename from data_diff/sqeleton/queries/extras.py rename to data_diff/queries/extras.py index 4a1d58c1..8e916601 100644 --- a/data_diff/sqeleton/queries/extras.py +++ b/data_diff/queries/extras.py @@ -3,10 +3,10 @@ from typing import Callable, Sequence from runtype import dataclass -from data_diff.sqeleton.abcs.database_types import ColType, Native_UUID +from data_diff.abcs.database_types import ColType, Native_UUID -from data_diff.sqeleton.queries.compiler import Compiler -from data_diff.sqeleton.queries.ast_classes import Expr, ExprNode, Concat, Code +from data_diff.queries.compiler import Compiler +from data_diff.queries.ast_classes import Expr, ExprNode, Concat, Code @dataclass diff --git a/data_diff/query_utils.py b/data_diff/query_utils.py index 4b963039..a4887728 100644 --- a/data_diff/query_utils.py +++ b/data_diff/query_utils.py @@ -2,8 +2,10 @@ from contextlib import suppress -from data_diff.sqeleton.databases import DbPath, QueryError, Oracle -from data_diff.sqeleton.queries import table, commit, Expr +from data_diff.abcs.database_types import DbPath +from data_diff.databases.base import QueryError +from data_diff.databases.oracle import Oracle +from data_diff.queries.api import table, commit, Expr def _drop_table_oracle(name: DbPath): diff --git a/data_diff/sqeleton/schema.py b/data_diff/schema.py 
similarity index 79% rename from data_diff/sqeleton/schema.py rename to data_diff/schema.py index 01dfeed7..847bbf23 100644 --- a/data_diff/sqeleton/schema.py +++ b/data_diff/schema.py @@ -1,7 +1,7 @@ import logging -from data_diff.sqeleton.utils import CaseAwareMapping, CaseInsensitiveDict, CaseSensitiveDict -from data_diff.sqeleton.abcs import AbstractDatabase, DbPath +from data_diff.utils import CaseAwareMapping, CaseInsensitiveDict, CaseSensitiveDict +from data_diff.abcs.database_types import AbstractDatabase, DbPath logger = logging.getLogger("schema") diff --git a/data_diff/sqeleton/__init__.py b/data_diff/sqeleton/__init__.py deleted file mode 100644 index b6e32cc2..00000000 --- a/data_diff/sqeleton/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from data_diff.sqeleton.databases import connect -from data_diff.sqeleton.queries import table, this, SKIP, code diff --git a/data_diff/sqeleton/abcs/__init__.py b/data_diff/sqeleton/abcs/__init__.py deleted file mode 100644 index 5654ad16..00000000 --- a/data_diff/sqeleton/abcs/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -from data_diff.sqeleton.abcs.database_types import ( - AbstractDatabase, - AbstractDialect, - DbKey, - DbPath, - DbTime, - IKey, - ColType_UUID, - NumericType, - PrecisionType, - StringType, - Boolean, - JSON, -) -from data_diff.sqeleton.abcs.compiler import AbstractCompiler, Compilable diff --git a/data_diff/sqeleton/databases/__init__.py b/data_diff/sqeleton/databases/__init__.py deleted file mode 100644 index 70af2412..00000000 --- a/data_diff/sqeleton/databases/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -from data_diff.sqeleton.databases.base import ( - MD5_HEXDIGITS, - CHECKSUM_HEXDIGITS, - QueryError, - ConnectError, - BaseDialect, - Database, -) -from data_diff.sqeleton.abcs import DbPath, DbKey, DbTime -from data_diff.sqeleton.databases._connect import Connect - -from data_diff.sqeleton.databases.postgresql import PostgreSQL -from data_diff.sqeleton.databases.mysql import MySQL -from data_diff.sqeleton.databases.oracle import Oracle -from data_diff.sqeleton.databases.snowflake import Snowflake -from data_diff.sqeleton.databases.bigquery import BigQuery -from data_diff.sqeleton.databases.redshift import Redshift -from data_diff.sqeleton.databases.presto import Presto -from data_diff.sqeleton.databases.databricks import Databricks -from data_diff.sqeleton.databases.trino import Trino -from data_diff.sqeleton.databases.clickhouse import Clickhouse -from data_diff.sqeleton.databases.vertica import Vertica -from data_diff.sqeleton.databases.duckdb import DuckDB -from data_diff.sqeleton.databases.mssql import MsSQL - -connect = Connect() diff --git a/data_diff/sqeleton/databases/_connect.py b/data_diff/sqeleton/databases/_connect.py deleted file mode 100644 index ad152dda..00000000 --- a/data_diff/sqeleton/databases/_connect.py +++ /dev/null @@ -1,283 +0,0 @@ -from typing import Hashable, MutableMapping, Type, Optional, Union, Dict -from itertools import zip_longest -from contextlib import suppress -import weakref -import dsnparse -import toml - -from runtype import dataclass -from typing_extensions import Self - -from data_diff.sqeleton.abcs.mixins import AbstractMixin -from data_diff.sqeleton.databases.base import Database, ThreadedDatabase -from data_diff.sqeleton.databases.postgresql import PostgreSQL -from data_diff.sqeleton.databases.mysql import MySQL -from data_diff.sqeleton.databases.oracle import Oracle -from data_diff.sqeleton.databases.snowflake import Snowflake -from data_diff.sqeleton.databases.bigquery import 
BigQuery -from data_diff.sqeleton.databases.redshift import Redshift -from data_diff.sqeleton.databases.presto import Presto -from data_diff.sqeleton.databases.databricks import Databricks -from data_diff.sqeleton.databases.trino import Trino -from data_diff.sqeleton.databases.clickhouse import Clickhouse -from data_diff.sqeleton.databases.vertica import Vertica -from data_diff.sqeleton.databases.duckdb import DuckDB -from data_diff.sqeleton.databases.mssql import MsSQL - - -@dataclass -class MatchUriPath: - database_cls: Type[Database] - - def match_path(self, dsn): - help_str = self.database_cls.CONNECT_URI_HELP - params = self.database_cls.CONNECT_URI_PARAMS - kwparams = self.database_cls.CONNECT_URI_KWPARAMS - - dsn_dict = dict(dsn.query) - matches = {} - for param, arg in zip_longest(params, dsn.paths): - if param is None: - raise ValueError(f"Too many parts to path. Expected format: {help_str}") - - optional = param.endswith("?") - param = param.rstrip("?") - - if arg is None: - try: - arg = dsn_dict.pop(param) - except KeyError: - if not optional: - raise ValueError(f"URI must specify '{param}'. Expected format: {help_str}") - - arg = None - - assert param and param not in matches - matches[param] = arg - - for param in kwparams: - try: - arg = dsn_dict.pop(param) - except KeyError: - raise ValueError(f"URI must specify '{param}'. Expected format: {help_str}") - - assert param and arg and param not in matches, (param, arg, matches.keys()) - matches[param] = arg - - for param, value in dsn_dict.items(): - if param in matches: - raise ValueError( - f"Parameter '{param}' already provided as positional argument. Expected format: {help_str}" - ) - - matches[param] = value - - return matches - - -DATABASE_BY_SCHEME = { - "postgresql": PostgreSQL, - "mysql": MySQL, - "oracle": Oracle, - "redshift": Redshift, - "snowflake": Snowflake, - "presto": Presto, - "bigquery": BigQuery, - "databricks": Databricks, - "duckdb": DuckDB, - "trino": Trino, - "clickhouse": Clickhouse, - "vertica": Vertica, - "mssql": MsSQL, -} - - -class Connect: - """Provides methods for connecting to a supported database using a URL or connection dict.""" - conn_cache: MutableMapping[Hashable, Database] - - def __init__(self, database_by_scheme: Dict[str, Database] = DATABASE_BY_SCHEME): - self.database_by_scheme = database_by_scheme - self.match_uri_path = {name: MatchUriPath(cls) for name, cls in database_by_scheme.items()} - self.conn_cache = weakref.WeakValueDictionary() - - def for_databases(self, *dbs) -> Self: - database_by_scheme = {k: db for k, db in self.database_by_scheme.items() if k in dbs} - return type(self)(database_by_scheme) - - def connect_to_uri(self, db_uri: str, thread_count: Optional[int] = 1, **kwargs) -> Database: - """Connect to the given database uri - - thread_count determines the max number of worker threads per database, - if relevant. None means no limit. - - Parameters: - db_uri (str): The URI for the database to connect - thread_count (int, optional): Size of the threadpool. Ignored by cloud databases. (default: 1) - - Note: For non-cloud databases, a low thread-pool size may be a performance bottleneck. 
- - Supported schemes: - - postgresql - - mysql - - oracle - - snowflake - - bigquery - - redshift - - presto - - databricks - - trino - - clickhouse - - vertica - - duckdb - """ - - dsn = dsnparse.parse(db_uri) - if len(dsn.schemes) > 1: - raise NotImplementedError("No support for multiple schemes") - (scheme,) = dsn.schemes - - if scheme == "toml": - toml_path = dsn.path or dsn.host - database = dsn.fragment - if not database: - raise ValueError("Must specify a database name, e.g. 'toml://path#database'. ") - with open(toml_path) as f: - config = toml.load(f) - try: - conn_dict = config["database"][database] - except KeyError: - raise ValueError(f"Cannot find database config named '{database}'.") - return self.connect_with_dict(conn_dict, thread_count, **kwargs) - - try: - matcher = self.match_uri_path[scheme] - except KeyError: - raise NotImplementedError(f"Scheme '{scheme}' currently not supported") - - cls = matcher.database_cls - - if scheme == "databricks": - assert not dsn.user - kw = {} - kw["access_token"] = dsn.password - kw["http_path"] = dsn.path - kw["server_hostname"] = dsn.host - kw.update(dsn.query) - elif scheme == "duckdb": - kw = {} - kw["filepath"] = dsn.dbname - kw["dbname"] = dsn.user - else: - kw = matcher.match_path(dsn) - - if scheme == "bigquery": - kw["project"] = dsn.host - return cls(**kw, **kwargs) - - if scheme == "snowflake": - kw["account"] = dsn.host - assert not dsn.port - kw["user"] = dsn.user - kw["password"] = dsn.password - else: - if scheme == "oracle": - kw["host"] = dsn.hostloc - else: - kw["host"] = dsn.host - kw["port"] = dsn.port - kw["user"] = dsn.user - if dsn.password: - kw["password"] = dsn.password - - kw = {k: v for k, v in kw.items() if v is not None} - - if issubclass(cls, ThreadedDatabase): - db = cls(thread_count=thread_count, **kw, **kwargs) - else: - db = cls(**kw, **kwargs) - - return self._connection_created(db) - - def connect_with_dict(self, d, thread_count, **kwargs): - d = dict(d) - driver = d.pop("driver") - try: - matcher = self.match_uri_path[driver] - except KeyError: - raise NotImplementedError(f"Driver '{driver}' currently not supported") - - cls = matcher.database_cls - if issubclass(cls, ThreadedDatabase): - db = cls(thread_count=thread_count, **d, **kwargs) - else: - db = cls(**d, **kwargs) - - return self._connection_created(db) - - def _connection_created(self, db): - "Nop function to be overridden by subclasses." - return db - - def __call__( - self, db_conf: Union[str, dict], thread_count: Optional[int] = 1, shared: bool = True, **kwargs - ) -> Database: - """Connect to a database using the given database configuration. - - Configuration can be given either as a URI string, or as a dict of {option: value}. - - The dictionary configuration uses the same keys as the TOML 'database' definition given with --conf. - - thread_count determines the max number of worker threads per database, - if relevant. None means no limit. - - Parameters: - db_conf (str | dict): The configuration for the database to connect. URI or dict. - thread_count (int, optional): Size of the threadpool. Ignored by cloud databases. (default: 1) - shared (bool): Whether to cache and return the same connection for the same db_conf. (default: True) - bigquery_credentials (google.oauth2.credentials.Credentials): Custom Google oAuth2 credential for BigQuery. - (default: None) - - Note: For non-cloud databases, a low thread-pool size may be a performance bottleneck. 
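For example, the same connection expressed as a TOML entry, as a toml:// URI, and as a
dict (a sketch; the file name and the 'mydb' entry are hypothetical):

    # dbs.toml
    # [database.mydb]
    # driver = "postgresql"
    # host = "localhost"

    >>> connect("toml://dbs.toml#mydb")
    >>> connect({"driver": "postgresql", "host": "localhost"})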
- - Supported drivers: - - postgresql - - mysql - - oracle - - snowflake - - bigquery - - redshift - - presto - - databricks - - trino - - clickhouse - - vertica - - Example: - >>> connect("mysql://localhost/db") - - >>> connect({"driver": "mysql", "host": "localhost", "database": "db"}) - - """ - cache_key = self.__make_cache_key(db_conf) - if shared: - with suppress(KeyError): - conn = self.conn_cache[cache_key] - if not conn.is_closed: - return conn - - if isinstance(db_conf, str): - conn = self.connect_to_uri(db_conf, thread_count, **kwargs) - elif isinstance(db_conf, dict): - conn = self.connect_with_dict(db_conf, thread_count, **kwargs) - else: - raise TypeError(f"db configuration must be a URI string or a dictionary. Instead got '{db_conf}'.") - - if shared: - self.conn_cache[cache_key] = conn - return conn - - def __make_cache_key(self, db_conf: Union[str, dict]) -> Hashable: - if isinstance(db_conf, dict): - return tuple(db_conf.items()) - return db_conf diff --git a/data_diff/sqeleton/databases/base.py b/data_diff/sqeleton/databases/base.py deleted file mode 100644 index ec41bac4..00000000 --- a/data_diff/sqeleton/databases/base.py +++ /dev/null @@ -1,607 +0,0 @@ -from datetime import datetime -import math -import sys -import logging -from typing import Any, Callable, Dict, Generator, Tuple, Optional, Sequence, Type, List, Union, TypeVar, TYPE_CHECKING -from functools import partial, wraps -from concurrent.futures import ThreadPoolExecutor -import threading -from abc import abstractmethod -from uuid import UUID -import decimal - -from runtype import dataclass -from typing_extensions import Self - -from data_diff.sqeleton.utils import is_uuid, safezip -from data_diff.sqeleton.queries import Expr, Compiler, table, Select, SKIP, Explain, Code, this -from data_diff.sqeleton.queries.ast_classes import Random -from data_diff.sqeleton.abcs.database_types import ( - AbstractDatabase, - Array, - Struct, - AbstractDialect, - AbstractTable, - ColType, - Integer, - Decimal, - Float, - Native_UUID, - String_UUID, - String_Alphanum, - String_VaryingAlphanum, - TemporalType, - UnknownColType, - TimestampTZ, - Text, - DbTime, - DbPath, - Boolean, - JSON, -) -from data_diff.sqeleton.abcs.mixins import Compilable -from data_diff.sqeleton.abcs.mixins import ( - AbstractMixin_Schema, - AbstractMixin_RandomSample, - AbstractMixin_NormalizeValue, - AbstractMixin_OptimizerHints, -) -from data_diff.sqeleton.bound_exprs import bound_table - -logger = logging.getLogger("database") - - -def parse_table_name(t): - return tuple(t.split(".")) - - -def import_helper(package: str = None, text=""): - def dec(f): - @wraps(f) - def _inner(): - try: - return f() - except ModuleNotFoundError as e: - s = text - if package: - s += f"Please complete setup by running: pip install 'data_diff[{package}]'." - raise ModuleNotFoundError(f"{e}\n\n{s}\n") - - return _inner - - return dec - - -class ConnectError(Exception): - pass - - -class QueryError(Exception): - pass - - -def _one(seq): - (x,) = seq - return x - - -class ThreadLocalInterpreter: - """An interpreter used to execute a sequence of queries within the same thread and cursor. - - Useful for cursor-sensitive operations, such as creating a temporary table.
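Example of the generator protocol it drives (a sketch; the table names are hypothetical):

    >>> def make_and_count(db):
    ...     yield "CREATE TEMPORARY TABLE tmp_ids AS SELECT id FROM ratings"
    ...     rows = yield "SELECT count(*) FROM tmp_ids"  # the result arrives via .send()
    ...     yield "DROP TABLE tmp_ids"
    >>> db.query(make_and_count(db))  # all three statements share one thread and cursor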
- """ - - def __init__(self, compiler: Compiler, gen: Generator): - self.gen = gen - self.compiler = compiler - - def apply_queries(self, callback: Callable[[str], Any]): - q: Expr = next(self.gen) - while True: - sql = self.compiler.compile(q) - logger.debug("Running SQL (%s-TL): %s", self.compiler.database.name, sql) - try: - try: - res = callback(sql) if sql is not SKIP else SKIP - except Exception as e: - q = self.gen.throw(type(e), e) - else: - q = self.gen.send(res) - except StopIteration: - break - - -def apply_query(callback: Callable[[str], Any], sql_code: Union[str, ThreadLocalInterpreter]) -> list: - if isinstance(sql_code, ThreadLocalInterpreter): - return sql_code.apply_queries(callback) - else: - return callback(sql_code) - - -class Mixin_Schema(AbstractMixin_Schema): - def table_information(self) -> Compilable: - return table("information_schema", "tables") - - def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: - return ( - self.table_information() - .where( - this.table_schema == table_schema, - this.table_name.like(like) if like is not None else SKIP, - this.table_type == "BASE TABLE", - ) - .select(this.table_name) - ) - - -class Mixin_RandomSample(AbstractMixin_RandomSample): - def random_sample_n(self, tbl: AbstractTable, size: int) -> AbstractTable: - # TODO use a more efficient algorithm, when the table count is known - return tbl.order_by(Random()).limit(size) - - def random_sample_ratio_approx(self, tbl: AbstractTable, ratio: float) -> AbstractTable: - return tbl.where(Random() < ratio) - - -class Mixin_OptimizerHints(AbstractMixin_OptimizerHints): - def optimizer_hints(self, hints: str) -> str: - return f"/*+ {hints} */ " - - -class BaseDialect(AbstractDialect): - SUPPORTS_PRIMARY_KEY = False - SUPPORTS_INDEXES = False - TYPE_CLASSES: Dict[str, type] = {} - MIXINS = frozenset() - - PLACEHOLDER_TABLE = None # Used for Oracle - - def offset_limit( - self, offset: Optional[int] = None, limit: Optional[int] = None, has_order_by: Optional[bool] = None - ) -> str: - if offset: - raise NotImplementedError("No support for OFFSET in query") - - return f"LIMIT {limit}" - - def concat(self, items: List[str]) -> str: - assert len(items) > 1 - joined_exprs = ", ".join(items) - return f"concat({joined_exprs})" - - def to_comparable(self, value: str, coltype: ColType) -> str: - """Ensure that the expression is comparable in ``IS DISTINCT FROM``.""" - return value - - def is_distinct_from(self, a: str, b: str) -> str: - return f"{a} is distinct from {b}" - - def timestamp_value(self, t: DbTime) -> str: - return f"'{t.isoformat()}'" - - def random(self) -> str: - return "random()" - - def current_timestamp(self) -> str: - return "current_timestamp()" - - def current_database(self) -> str: - return "current_database()" - - def current_schema(self) -> str: - return "current_schema()" - - def explain_as_text(self, query: str) -> str: - return f"EXPLAIN {query}" - - def _constant_value(self, v): - if v is None: - return "NULL" - elif isinstance(v, str): - return f"'{v}'" - elif isinstance(v, datetime): - return self.timestamp_value(v) - elif isinstance(v, UUID): - return f"'{v}'" - elif isinstance(v, decimal.Decimal): - return str(v) - elif isinstance(v, bytearray): - return f"'{v.decode()}'" - elif isinstance(v, Code): - return v.code - return repr(v) - - def constant_values(self, rows) -> str: - values = ", ".join("(%s)" % ", ".join(self._constant_value(v) for v in row) for row in rows) - return f"VALUES {values}" - - def type_repr(self, t) -> str: - if 
isinstance(t, str): - return t - elif isinstance(t, TimestampTZ): - return f"TIMESTAMP({min(t.precision, DEFAULT_DATETIME_PRECISION)})" - return { - int: "INT", - str: "VARCHAR", - bool: "BOOLEAN", - float: "FLOAT", - datetime: "TIMESTAMP", - }[t] - - def _parse_type_repr(self, type_repr: str) -> Optional[Type[ColType]]: - return self.TYPE_CLASSES.get(type_repr) - - def parse_type( - self, - table_path: DbPath, - col_name: str, - type_repr: str, - datetime_precision: int = None, - numeric_precision: int = None, - numeric_scale: int = None, - ) -> ColType: - """ """ - - cls = self._parse_type_repr(type_repr) - if cls is None: - return UnknownColType(type_repr) - - if issubclass(cls, TemporalType): - return cls( - precision=datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION, - rounds=self.ROUNDS_ON_PREC_LOSS, - ) - - elif issubclass(cls, Integer): - return cls() - - elif issubclass(cls, Boolean): - return cls() - - elif issubclass(cls, Decimal): - if numeric_scale is None: - numeric_scale = 0 # Needed for Oracle. - return cls(precision=numeric_scale) - - elif issubclass(cls, Float): - # assert numeric_scale is None - return cls( - precision=self._convert_db_precision_to_digits( - numeric_precision if numeric_precision is not None else DEFAULT_NUMERIC_PRECISION - ) - ) - - elif issubclass(cls, (JSON, Array, Struct, Text, Native_UUID)): - return cls() - - raise TypeError(f"Parsing {type_repr} returned an unknown type '{cls}'.") - - def _convert_db_precision_to_digits(self, p: int) -> int: - """Convert from binary precision, used by floats, to decimal precision.""" - # See: https://en.wikipedia.org/wiki/Single-precision_floating-point_format - return math.floor(math.log(2**p, 10)) - - @classmethod - def load_mixins(cls, *abstract_mixins) -> Self: - mixins = {m for m in cls.MIXINS if issubclass(m, abstract_mixins)} - - class _DialectWithMixins(cls, *mixins, *abstract_mixins): - pass - - _DialectWithMixins.__name__ = cls.__name__ - return _DialectWithMixins() - - -T = TypeVar("T", bound=BaseDialect) - - -@dataclass -class QueryResult: - rows: list - columns: list = None - - def __iter__(self): - return iter(self.rows) - - def __len__(self): - return len(self.rows) - - def __getitem__(self, i): - return self.rows[i] - - -class Database(AbstractDatabase[T]): - """Base abstract class for databases. - - Used for providing connection code and implementation-specific SQL utilities. - - Instantiated using :meth:`~data_diff.sqeleton.connect` - """ - - default_schema: str = None - SUPPORTS_ALPHANUMS = True - SUPPORTS_UNIQUE_CONSTAINT = False - - CONNECT_URI_KWPARAMS = [] - - _interactive = False - is_closed = False - - @property - def name(self): - return type(self).__name__ - - def compile(self, sql_ast): - compiler = Compiler(self) - return compiler.compile(sql_ast) - - def query(self, sql_ast: Union[Expr, Generator], res_type: type = None): - """Query the given SQL code/AST, and attempt to convert the result to type 'res_type' - - If given a generator, it will execute all the yielded sql queries with the same thread and cursor. - The results of the queries are returned by the `yield` statement (using the .send() mechanism). - It's a cleaner approach than exposing cursors, but may not be enough in all cases.
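Example of the res_type coercion (a sketch; the table and column names are hypothetical):

    >>> db.query("SELECT count(*) FROM ratings", int)      # single cell -> int
    >>> db.query("SELECT name FROM users", List[str])      # single column -> list of strings
    >>> db.query("SELECT id, name FROM users", list)       # plain rows, as a list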
- """ - - compiler = Compiler(self) - if isinstance(sql_ast, Generator): - sql_code = ThreadLocalInterpreter(compiler, sql_ast) - elif isinstance(sql_ast, list): - for i in sql_ast[:-1]: - self.query(i) - return self.query(sql_ast[-1], res_type) - else: - if isinstance(sql_ast, str): - sql_code = sql_ast - else: - if res_type is None: - res_type = sql_ast.type - sql_code = compiler.compile(sql_ast) - if sql_code is SKIP: - return SKIP - - logger.debug("Running SQL (%s): %s", self.name, sql_code) - - if self._interactive and isinstance(sql_ast, Select): - explained_sql = compiler.compile(Explain(sql_ast)) - explain = self._query(explained_sql) - for row in explain: - # Most returned a 1-tuple. Presto returns a string - if isinstance(row, tuple): - (row,) = row - logger.debug("EXPLAIN: %s", row) - answer = input("Continue? [y/n] ") - if answer.lower() not in ["y", "yes"]: - sys.exit(1) - - res = self._query(sql_code) - if res_type is list: - return list(res) - elif res_type is int: - if not res: - raise ValueError("Query returned 0 rows, expected 1") - row = _one(res) - if not row: - raise ValueError("Row is empty, expected 1 column") - res = _one(row) - if res is None: # May happen due to sum() of 0 items - return None - return int(res) - elif res_type is datetime: - res = _one(_one(res)) - if isinstance(res, str): - res = datetime.fromisoformat(res[:23]) # TODO use a better parsing method - return res - elif res_type is tuple: - assert len(res) == 1, (sql_code, res) - return res[0] - elif getattr(res_type, "__origin__", None) is list and len(res_type.__args__) == 1: - if res_type.__args__ in ((int,), (str,)): - return [_one(row) for row in res] - elif res_type.__args__ in [(Tuple,), (tuple,)]: - return [tuple(row) for row in res] - elif res_type.__args__ == (dict,): - return [dict(safezip(res.columns, row)) for row in res] - else: - raise ValueError(res_type) - return res - - def enable_interactive(self): - self._interactive = True - - def select_table_schema(self, path: DbPath) -> str: - """Provide SQL for selecting the table schema as (name, type, date_prec, num_prec)""" - schema, name = self._normalize_table_path(path) - - return ( - "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale " - "FROM information_schema.columns " - f"WHERE table_name = '{name}' AND table_schema = '{schema}'" - ) - - def query_table_schema(self, path: DbPath) -> Dict[str, tuple]: - rows = self.query(self.select_table_schema(path), list) - if not rows: - raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns") - - d = {r[0]: r for r in rows} - assert len(d) == len(rows) - return d - - def select_table_unique_columns(self, path: DbPath) -> str: - schema, name = self._normalize_table_path(path) - - return ( - "SELECT column_name " - "FROM information_schema.key_column_usage " - f"WHERE table_name = '{name}' AND table_schema = '{schema}'" - ) - - def query_table_unique_columns(self, path: DbPath) -> List[str]: - if not self.SUPPORTS_UNIQUE_CONSTAINT: - raise NotImplementedError("This database doesn't support 'unique' constraints") - res = self.query(self.select_table_unique_columns(path), List[str]) - return list(res) - - def _process_table_schema( - self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str] = None, where: str = None - ): - if filter_columns is None: - filtered_schema = raw_schema - else: - accept = {i.lower() for i in filter_columns} - filtered_schema = {name: row for name, row in raw_schema.items() if 
name.lower() in accept} - - col_dict = {row[0]: self.dialect.parse_type(path, *row) for _name, row in filtered_schema.items()} - - self._refine_coltypes(path, col_dict, where) - - # Return a dict of form {name: type} after normalization - return col_dict - - def _refine_coltypes(self, table_path: DbPath, col_dict: Dict[str, ColType], where: str = None, sample_size=64): - """Refine the types in the column dict, by querying the database for a sample of their values - - 'where' restricts the rows to be sampled. - """ - - text_columns = [k for k, v in col_dict.items() if isinstance(v, Text)] - if not text_columns: - return - - if isinstance(self.dialect, AbstractMixin_NormalizeValue): - fields = [Code(self.dialect.normalize_uuid(self.dialect.quote(c), String_UUID())) for c in text_columns] - else: - fields = this[text_columns] - - samples_by_row = self.query( - table(*table_path).select(*fields).where(Code(where) if where else SKIP).limit(sample_size), list - ) - if not samples_by_row: - raise ValueError(f"Table {table_path} is empty.") - - samples_by_col = list(zip(*samples_by_row)) - - for col_name, samples in safezip(text_columns, samples_by_col): - uuid_samples = [s for s in samples if s and is_uuid(s)] - - if uuid_samples: - if len(uuid_samples) != len(samples): - logger.warning( - f"Mixed UUID/Non-UUID values detected in column {'.'.join(table_path)}.{col_name}, disabling UUID support." - ) - else: - assert col_name in col_dict - col_dict[col_name] = String_UUID() - continue - - if self.SUPPORTS_ALPHANUMS: # Anything but MySQL (so far) - alphanum_samples = [s for s in samples if String_Alphanum.test_value(s)] - if alphanum_samples: - if len(alphanum_samples) != len(samples): - logger.debug( - f"Mixed Alphanum/Non-Alphanum values detected in column {'.'.join(table_path)}.{col_name}. It cannot be used as a key." - ) - else: - assert col_name in col_dict - col_dict[col_name] = String_VaryingAlphanum() - - # @lru_cache() - # def get_table_schema(self, path: DbPath) -> Dict[str, ColType]: - # return self.query_table_schema(path) - - def _normalize_table_path(self, path: DbPath) -> DbPath: - if len(path) == 1: - return self.default_schema, path[0] - elif len(path) == 2: - return path - - raise ValueError(f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected form: schema.table") - - def parse_table_name(self, name: str) -> DbPath: - return parse_table_name(name) - - def _query_cursor(self, c, sql_code: str) -> QueryResult: - assert isinstance(sql_code, str), sql_code - try: - c.execute(sql_code) - if sql_code.lower().startswith(("select", "explain", "show")): - columns = [col[0] for col in c.description] - - fetched = c.fetchall() - result = QueryResult(fetched, columns) - return result - except Exception as _e: - # logger.exception(e) - # logger.error(f'Caused by SQL: {sql_code}') - raise - - def _query_conn(self, conn, sql_code: Union[str, ThreadLocalInterpreter]) -> QueryResult: - c = conn.cursor() - callback = partial(self._query_cursor, c) - return apply_query(callback, sql_code) - - def close(self): - self.is_closed = True - return super().close() - - def list_tables(self, tables_like, schema=None): - return self.query(self.dialect.list_tables(schema or self.default_schema, tables_like)) - - def table(self, *path, **kw): - return bound_table(self, path, **kw) - - -class ThreadedDatabase(Database): - """Access the database through singleton threads. - - Used for database connectors that do not support sharing their connection between different threads. 
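The pattern, reduced to standard-library pieces (a sketch, with sqlite3 standing in for
any driver whose connections must stay on the thread that created them):

    import sqlite3
    import threading
    from concurrent.futures import ThreadPoolExecutor

    local = threading.local()

    def init_conn():
        local.conn = sqlite3.connect(":memory:")  # one connection per worker thread

    pool = ThreadPoolExecutor(max_workers=1, initializer=init_conn)
    # the query runs on the worker thread that owns the connection
    print(pool.submit(lambda: local.conn.execute("SELECT 1").fetchall()).result())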
- """ - - def __init__(self, thread_count=1): - self._init_error = None - self._queue = ThreadPoolExecutor(thread_count, initializer=self.set_conn) - self.thread_local = threading.local() - logger.info(f"[{self.name}] Starting a threadpool, size={thread_count}.") - - def set_conn(self): - assert not hasattr(self.thread_local, "conn") - try: - self.thread_local.conn = self.create_connection() - except Exception as e: - self._init_error = e - - def _query(self, sql_code: Union[str, ThreadLocalInterpreter]) -> QueryResult: - r = self._queue.submit(self._query_in_worker, sql_code) - return r.result() - - def _query_in_worker(self, sql_code: Union[str, ThreadLocalInterpreter]): - "This method runs in a worker thread" - if self._init_error: - raise self._init_error - return self._query_conn(self.thread_local.conn, sql_code) - - @abstractmethod - def create_connection(self): - "Return a connection instance, that supports the .cursor() method." - - def close(self): - super().close() - self._queue.shutdown() - - @property - def is_autocommit(self) -> bool: - return False - - -# TODO FYI mssql md5_as_int currently requires this to be reduced -CHECKSUM_HEXDIGITS = 14 # Must be 15 or lower, otherwise SUM() overflows -MD5_HEXDIGITS = 32 - -_CHECKSUM_BITSIZE = CHECKSUM_HEXDIGITS << 2 -CHECKSUM_MASK = (2**_CHECKSUM_BITSIZE) - 1 - -DEFAULT_DATETIME_PRECISION = 6 -DEFAULT_NUMERIC_PRECISION = 24 - -TIMESTAMP_PRECISION_POS = 20 # len("2022-06-03 12:24:35.") == 20 diff --git a/data_diff/sqeleton/databases/bigquery.py b/data_diff/sqeleton/databases/bigquery.py deleted file mode 100644 index 0bac1ff6..00000000 --- a/data_diff/sqeleton/databases/bigquery.py +++ /dev/null @@ -1,297 +0,0 @@ -import re -from typing import Any, List, Union -from data_diff.sqeleton.abcs.database_types import ( - ColType, - Array, - JSON, - Struct, - Timestamp, - Datetime, - Integer, - Decimal, - Float, - Text, - DbPath, - FractionalType, - TemporalType, - Boolean, - UnknownColType, -) -from data_diff.sqeleton.abcs.mixins import ( - AbstractMixin_MD5, - AbstractMixin_NormalizeValue, - AbstractMixin_Schema, - AbstractMixin_TimeTravel, -) -from data_diff.sqeleton.abcs import Compilable -from data_diff.sqeleton.queries import this, table, SKIP, code -from data_diff.sqeleton.databases.base import ( - BaseDialect, - Database, - import_helper, - parse_table_name, - ConnectError, - apply_query, - QueryResult, -) -from data_diff.sqeleton.databases.base import TIMESTAMP_PRECISION_POS, ThreadLocalInterpreter, Mixin_RandomSample - - -@import_helper(text="Please install BigQuery and configure your google-cloud access.") -def import_bigquery(): - from google.cloud import bigquery - - return bigquery - - -def import_bigquery_service_account(): - from google.oauth2 import service_account - - return service_account - - -class Mixin_MD5(AbstractMixin_MD5): - def md5_as_int(self, s: str) -> str: - return f"cast(cast( ('0x' || substr(TO_HEX(md5({s})), 18)) as int64) as numeric)" - - -class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): - def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: - if coltype.rounds: - timestamp = f"timestamp_micros(cast(round(unix_micros(cast({value} as timestamp))/1000000, {coltype.precision})*1000000 as int))" - return f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {timestamp})" - - if coltype.precision == 0: - return f"FORMAT_TIMESTAMP('%F %H:%M:%S.000000', {value})" - elif coltype.precision == 6: - return f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {value})" - - timestamp6 = f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', 
{value})" - return ( - f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" - ) - - def normalize_number(self, value: str, coltype: FractionalType) -> str: - return f"format('%.{coltype.precision}f', {value})" - - def normalize_boolean(self, value: str, _coltype: Boolean) -> str: - return self.to_string(f"cast({value} as int)") - - def normalize_json(self, value: str, _coltype: JSON) -> str: - # BigQuery is unable to compare arrays & structs with ==/!=/distinct from, e.g.: - # Got error: 400 Grouping is not defined for arguments of type ARRAY at … - # So we do the best effort and compare it as strings, hoping that the JSON forms - # match on both sides: i.e. have properly ordered keys, same spacing, same quotes, etc. - return f"to_json_string({value})" - - def normalize_array(self, value: str, _coltype: Array) -> str: - # BigQuery is unable to compare arrays & structs with ==/!=/distinct from, e.g.: - # Got error: 400 Grouping is not defined for arguments of type ARRAY at … - # So we do the best effort and compare it as strings, hoping that the JSON forms - # match on both sides: i.e. have properly ordered keys, same spacing, same quotes, etc. - return f"to_json_string({value})" - - def normalize_struct(self, value: str, _coltype: Struct) -> str: - # BigQuery is unable to compare arrays & structs with ==/!=/distinct from, e.g.: - # Got error: 400 Grouping is not defined for arguments of type ARRAY at … - # So we do the best effort and compare it as strings, hoping that the JSON forms - # match on both sides: i.e. have properly ordered keys, same spacing, same quotes, etc. - return f"to_json_string({value})" - - -class Mixin_Schema(AbstractMixin_Schema): - def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: - return ( - table(table_schema, "INFORMATION_SCHEMA", "TABLES") - .where( - this.table_schema == table_schema, - this.table_name.like(like) if like is not None else SKIP, - this.table_type == "BASE TABLE", - ) - .select(this.table_name) - ) - - -class Mixin_TimeTravel(AbstractMixin_TimeTravel): - def time_travel( - self, - table: Compilable, - before: bool = False, - timestamp: Compilable = None, - offset: Compilable = None, - statement: Compilable = None, - ) -> Compilable: - if before: - raise NotImplementedError("before=True not supported for BigQuery time-travel") - - if statement is not None: - raise NotImplementedError("BigQuery time-travel doesn't support querying by statement id") - - if timestamp is not None: - assert offset is None - return code("{table} FOR SYSTEM_TIME AS OF {timestamp}", table=table, timestamp=timestamp) - - assert offset is not None - return code( - "{table} FOR SYSTEM_TIME AS OF TIMESTAMP_SUB(CURRENT_TIMESTAMP(), INTERVAL {offset} HOUR);", - table=table, - offset=offset, - ) - - -class Dialect(BaseDialect, Mixin_Schema): - name = "BigQuery" - ROUNDS_ON_PREC_LOSS = False # Technically BigQuery doesn't allow implicit rounding or truncation - TYPE_CLASSES = { - # Dates - "TIMESTAMP": Timestamp, - "DATETIME": Datetime, - # Numbers - "INT64": Integer, - "INT32": Integer, - "NUMERIC": Decimal, - "BIGNUMERIC": Decimal, - "FLOAT64": Float, - "FLOAT32": Float, - "STRING": Text, - "BOOL": Boolean, - "JSON": JSON, - } - TYPE_ARRAY_RE = re.compile(r"ARRAY<(.+)>") - TYPE_STRUCT_RE = re.compile(r"STRUCT<(.+)>") - MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_TimeTravel, Mixin_RandomSample} - - def random(self) -> str: - return "RAND()" - - def quote(self, s: str): - 
return f"`{s}`" - - def to_string(self, s: str): - return f"cast({s} as string)" - - def type_repr(self, t) -> str: - try: - return {str: "STRING", float: "FLOAT64"}[t] - except KeyError: - return super().type_repr(t) - - def parse_type( - self, - table_path: DbPath, - col_name: str, - type_repr: str, - *args: Any, # pass-through args - **kwargs: Any, # pass-through args - ) -> ColType: - col_type = super().parse_type(table_path, col_name, type_repr, *args, **kwargs) - if isinstance(col_type, UnknownColType): - m = self.TYPE_ARRAY_RE.fullmatch(type_repr) - if m: - item_type = self.parse_type(table_path, col_name, m.group(1), *args, **kwargs) - col_type = Array(item_type=item_type) - - # We currently ignore structs' structure, but later can parse it too. Examples: - # - STRUCT (unnamed) - # - STRUCT (named) - # - STRUCT> (with complex fields) - # - STRUCT> (nested) - m = self.TYPE_STRUCT_RE.fullmatch(type_repr) - if m: - col_type = Struct() - - return col_type - - def to_comparable(self, value: str, coltype: ColType) -> str: - """Ensure that the expression is comparable in ``IS DISTINCT FROM``.""" - if isinstance(coltype, (JSON, Array, Struct)): - return self.normalize_value_by_type(value, coltype) - else: - return super().to_comparable(value, coltype) - - def set_timezone_to_utc(self) -> str: - raise NotImplementedError() - - -class BigQuery(Database): - CONNECT_URI_HELP = "bigquery:///" - CONNECT_URI_PARAMS = ["dataset"] - dialect = Dialect() - - def __init__(self, project, *, dataset, bigquery_credentials=None, **kw): - credentials = bigquery_credentials - bigquery = import_bigquery() - - keyfile = kw.pop("keyfile", None) - if keyfile: - bigquery_service_account = import_bigquery_service_account() - credentials = bigquery_service_account.Credentials.from_service_account_file( - keyfile, - scopes=["https://www.googleapis.com/auth/cloud-platform"], - ) - - self._client = bigquery.Client(project=project, credentials=credentials, **kw) - self.project = project - self.dataset = dataset - - self.default_schema = dataset - - def _normalize_returned_value(self, value): - if isinstance(value, bytes): - return value.decode() - return value - - def _query_atom(self, sql_code: str): - from google.cloud import bigquery - - try: - result = self._client.query(sql_code).result() - columns = [c.name for c in result.schema] - rows = list(result) - except Exception as e: - msg = "Exception when trying to execute SQL code:\n %s\n\nGot error: %s" - raise ConnectError(msg % (sql_code, e)) - - if rows and isinstance(rows[0], bigquery.table.Row): - rows = [tuple(self._normalize_returned_value(v) for v in row.values()) for row in rows] - return QueryResult(rows, columns) - - def _query(self, sql_code: Union[str, ThreadLocalInterpreter]) -> QueryResult: - return apply_query(self._query_atom, sql_code) - - def close(self): - super().close() - self._client.close() - - def select_table_schema(self, path: DbPath) -> str: - project, schema, name = self._normalize_table_path(path) - return ( - "SELECT column_name, data_type, 6 as datetime_precision, 38 as numeric_precision, 9 as numeric_scale " - f"FROM `{project}`.`{schema}`.INFORMATION_SCHEMA.COLUMNS " - f"WHERE table_name = '{name}' AND table_schema = '{schema}'" - ) - - def query_table_unique_columns(self, path: DbPath) -> List[str]: - return [] - - def _normalize_table_path(self, path: DbPath) -> DbPath: - if len(path) == 0: - raise ValueError(f"{self.name}: Bad table path for {self}: ()") - elif len(path) == 1: - return (self.project, self.default_schema, 
path[0]) - elif len(path) == 2: - return (self.project,) + path - elif len(path) == 3: - return path - else: - raise ValueError( - f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected form: [project.]schema.table" - ) - - def parse_table_name(self, name: str) -> DbPath: - path = parse_table_name(name) - return tuple(i for i in self._normalize_table_path(path) if i is not None) - - @property - def is_autocommit(self) -> bool: - return True diff --git a/data_diff/sqeleton/databases/clickhouse.py b/data_diff/sqeleton/databases/clickhouse.py deleted file mode 100644 index e14cd226..00000000 --- a/data_diff/sqeleton/databases/clickhouse.py +++ /dev/null @@ -1,196 +0,0 @@ -from typing import Optional, Type - -from data_diff.sqeleton.databases.base import ( - MD5_HEXDIGITS, - CHECKSUM_HEXDIGITS, - TIMESTAMP_PRECISION_POS, - BaseDialect, - ThreadedDatabase, - import_helper, - ConnectError, - Mixin_RandomSample, -) -from data_diff.sqeleton.abcs.database_types import ( - ColType, - Decimal, - Float, - Integer, - FractionalType, - Native_UUID, - TemporalType, - Text, - Timestamp, - Boolean, -) -from data_diff.sqeleton.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue - -# https://clickhouse.com/docs/en/operations/server-configuration-parameters/settings/#default-database -DEFAULT_DATABASE = "default" - - -@import_helper("clickhouse") -def import_clickhouse(): - import clickhouse_driver - - return clickhouse_driver - - -class Mixin_MD5(AbstractMixin_MD5): - def md5_as_int(self, s: str) -> str: - substr_idx = 1 + MD5_HEXDIGITS - CHECKSUM_HEXDIGITS - return f"reinterpretAsUInt128(reverse(unhex(lowerUTF8(substr(hex(MD5({s})), {substr_idx})))))" - - -class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): - def normalize_number(self, value: str, coltype: FractionalType) -> str: - # If a decimal value has trailing zeros in its fractional part, they are dropped when casting to string. - # For example: - # select toString(toDecimal128(1.10, 2)); -- the result is 1.1 - # select toString(toDecimal128(1.00, 2)); -- the result is 1 - # So, we need a custom approach to preserve these trailing zeros. - # To do that, we can add a small value like 0.000001, which prevents the zeros from being dropped when casting. - # For the examples above, it looks like: - # select toString(toDecimal128(1.10, 2 + 1) + toDecimal128(0.001, 3)); -- the result is 1.101 - # After that, we cut the extra digit from the string, i.e. 1.101 -> 1.10 - # So, the algorithm is: - # 1. Cast to decimal with precision + 1 - # 2. Add a small value 10^(-precision-1) - # 3. Cast the result to string - # 4. Drop the extra digit from the string. To do that, we need to slice the string - # with length = digits in the integer part + 1 (for the ".") + precision - - if coltype.precision == 0: - return self.to_string(f"round({value})") - - precision = coltype.precision - # TODO: too complex; is there a better-performing way?
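# Illustrative aside: the same four steps replayed in plain Python (a sketch;
# keep_trailing_zeros is a hypothetical helper, not part of this module):
#
# from decimal import Decimal
#
# def keep_trailing_zeros(x: float, precision: int) -> str:
#     rounded = Decimal(str(round(abs(x), precision)))     # step 1: round to target precision
#     nudged = rounded + Decimal(10) ** -(precision + 1)   # step 2: add 10^(-precision-1)
#     s = str(nudged)                                      # step 3: cast to string
#     int_digits = len(s.split(".")[0])
#     return ("-" if x < 0 else "") + s[: int_digits + 1 + precision]  # step 4: slice
#
# keep_trailing_zeros(1.10, 2)  # -> '1.10', not '1.1'
# keep_trailing_zeros(1.00, 2)  # -> '1.00', not '1'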
- value = f""" - if({value} >= 0, '', '-') || left( - toString( - toDecimal128( - round(abs({value}), {precision}), - {precision} + 1 - ) - + - toDecimal128( - exp10(-{precision + 1}), - {precision} + 1 - ) - ), - toUInt8( - greatest( - floor(log10(abs({value}))) + 1, - 1 - ) - ) + 1 + {precision} - ) - """ - return value - - def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: - prec = coltype.precision - if coltype.rounds: - timestamp = f"toDateTime64(round(toUnixTimestamp64Micro(toDateTime64({value}, 6)) / 1000000, {prec}), 6)" - return self.to_string(timestamp) - - fractional = f"toUnixTimestamp64Micro(toDateTime64({value}, {prec})) % 1000000" - fractional = f"lpad({self.to_string(fractional)}, 6, '0')" - value = f"formatDateTime({value}, '%Y-%m-%d %H:%M:%S') || '.' || {self.to_string(fractional)}" - return f"rpad({value}, {TIMESTAMP_PRECISION_POS + 6}, '0')" - - -class Dialect(BaseDialect): - name = "Clickhouse" - ROUNDS_ON_PREC_LOSS = False - TYPE_CLASSES = { - "Int8": Integer, - "Int16": Integer, - "Int32": Integer, - "Int64": Integer, - "Int128": Integer, - "Int256": Integer, - "UInt8": Integer, - "UInt16": Integer, - "UInt32": Integer, - "UInt64": Integer, - "UInt128": Integer, - "UInt256": Integer, - "Float32": Float, - "Float64": Float, - "Decimal": Decimal, - "UUID": Native_UUID, - "String": Text, - "FixedString": Text, - "DateTime": Timestamp, - "DateTime64": Timestamp, - "Bool": Boolean, - } - MIXINS = {Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample} - - def quote(self, s: str) -> str: - return f'"{s}"' - - def to_string(self, s: str) -> str: - return f"toString({s})" - - def _convert_db_precision_to_digits(self, p: int) -> int: - # Done the same as for PostgreSQL but need to rewrite in another way - # because it does not help for float with a big integer part. 
- return super()._convert_db_precision_to_digits(p) - 2 - - def _parse_type_repr(self, type_repr: str) -> Optional[Type[ColType]]: - nullable_prefix = "Nullable(" - if type_repr.startswith(nullable_prefix): - type_repr = type_repr[len(nullable_prefix) :].rstrip(")") - - if type_repr.startswith("Decimal"): - type_repr = "Decimal" - elif type_repr.startswith("FixedString"): - type_repr = "FixedString" - elif type_repr.startswith("DateTime64"): - type_repr = "DateTime64" - - return self.TYPE_CLASSES.get(type_repr) - - # def timestamp_value(self, t: DbTime) -> str: - # # return f"'{t}'" - # return f"'{str(t)[:19]}'" - - def set_timezone_to_utc(self) -> str: - raise NotImplementedError() - - def current_timestamp(self) -> str: - return "now()" - - -class Clickhouse(ThreadedDatabase): - dialect = Dialect() - CONNECT_URI_HELP = "clickhouse://:@/" - CONNECT_URI_PARAMS = ["database?"] - - def __init__(self, *, thread_count: int, **kw): - super().__init__(thread_count=thread_count) - - self._args = kw - # In Clickhouse database and schema are the same - self.default_schema = kw.get("database", DEFAULT_DATABASE) - - def create_connection(self): - clickhouse = import_clickhouse() - - class SingleConnection(clickhouse.dbapi.connection.Connection): - """Not thread-safe connection to Clickhouse""" - - def cursor(self, cursor_factory=None): - if not len(self.cursors): - _ = super().cursor() - return self.cursors[0] - - try: - return SingleConnection(**self._args) - except clickhouse.OperationError as e: - raise ConnectError(*e.args) from e - - @property - def is_autocommit(self) -> bool: - return True diff --git a/data_diff/sqeleton/databases/databricks.py b/data_diff/sqeleton/databases/databricks.py deleted file mode 100644 index a5474ee2..00000000 --- a/data_diff/sqeleton/databases/databricks.py +++ /dev/null @@ -1,199 +0,0 @@ -import math -from typing import Dict, Sequence -import logging - -from data_diff.sqeleton.abcs.database_types import ( - Integer, - Float, - Decimal, - Timestamp, - Text, - TemporalType, - NumericType, - DbPath, - ColType, - UnknownColType, - Boolean, -) -from data_diff.sqeleton.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue -from data_diff.sqeleton.databases.base import ( - MD5_HEXDIGITS, - CHECKSUM_HEXDIGITS, - BaseDialect, - ThreadedDatabase, - import_helper, - parse_table_name, - Mixin_RandomSample, -) - - -@import_helper(text="You can install it using 'pip install databricks-sql-connector'") -def import_databricks(): - import databricks.sql - - return databricks - - -class Mixin_MD5(AbstractMixin_MD5): - def md5_as_int(self, s: str) -> str: - return f"cast(conv(substr(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16, 10) as decimal(38, 0))" - - -class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): - def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: - """Databricks timestamp contains no more than 6 digits in precision""" - - if coltype.rounds: - timestamp = f"cast(round(unix_micros({value}) / 1000000, {coltype.precision}) * 1000000 as bigint)" - return f"date_format(timestamp_micros({timestamp}), 'yyyy-MM-dd HH:mm:ss.SSSSSS')" - - precision_format = "S" * coltype.precision + "0" * (6 - coltype.precision) - return f"date_format({value}, 'yyyy-MM-dd HH:mm:ss.{precision_format}')" - - def normalize_number(self, value: str, coltype: NumericType) -> str: - value = f"cast({value} as decimal(38, {coltype.precision}))" - if coltype.precision > 0: - value = f"format_number({value}, {coltype.precision})" - return 
f"replace({self.to_string(value)}, ',', '')" - - def normalize_boolean(self, value: str, _coltype: Boolean) -> str: - return self.to_string(f"cast ({value} as int)") - - -class Dialect(BaseDialect): - name = "Databricks" - ROUNDS_ON_PREC_LOSS = True - TYPE_CLASSES = { - # Numbers - "INT": Integer, - "SMALLINT": Integer, - "TINYINT": Integer, - "BIGINT": Integer, - "FLOAT": Float, - "DOUBLE": Float, - "DECIMAL": Decimal, - # Timestamps - "TIMESTAMP": Timestamp, - # Text - "STRING": Text, - # Boolean - "BOOLEAN": Boolean, - } - MIXINS = {Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample} - - def quote(self, s: str): - return f"`{s}`" - - def to_string(self, s: str) -> str: - return f"cast({s} as string)" - - def _convert_db_precision_to_digits(self, p: int) -> int: - # Subtracting 2 due to wierd precision issues - return max(super()._convert_db_precision_to_digits(p) - 2, 0) - - def set_timezone_to_utc(self) -> str: - return "SET TIME ZONE 'UTC'" - - -class Databricks(ThreadedDatabase): - dialect = Dialect() - CONNECT_URI_HELP = "databricks://:@/" - CONNECT_URI_PARAMS = ["catalog", "schema"] - - def __init__(self, *, thread_count, **kw): - logging.getLogger("databricks.sql").setLevel(logging.WARNING) - - self._args = kw - self.default_schema = kw.get("schema", "default") - self.catalog = self._args.get("catalog", "hive_metastore") - super().__init__(thread_count=thread_count) - - def create_connection(self): - databricks = import_databricks() - - try: - return databricks.sql.connect( - server_hostname=self._args["server_hostname"], - http_path=self._args["http_path"], - access_token=self._args["access_token"], - catalog=self.catalog, - ) - except databricks.sql.exc.Error as e: - raise ConnectionError(*e.args) from e - - def query_table_schema(self, path: DbPath) -> Dict[str, tuple]: - # Databricks has INFORMATION_SCHEMA only for Databricks Runtime, not for Databricks SQL. - # https://docs.databricks.com/spark/latest/spark-sql/language-manual/information-schema/columns.html - # So, to obtain information about schema, we should use another approach. 
- - conn = self.create_connection() - - catalog, schema, table = self._normalize_table_path(path) - with conn.cursor() as cursor: - cursor.columns(catalog_name=catalog, schema_name=schema, table_name=table) - try: - rows = cursor.fetchall() - finally: - conn.close() - if not rows: - raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns") - - d = {r.COLUMN_NAME: (r.COLUMN_NAME, r.TYPE_NAME, r.DECIMAL_DIGITS, None, None) for r in rows} - assert len(d) == len(rows) - return d - - def _process_table_schema( - self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str], where: str = None - ): - accept = {i.lower() for i in filter_columns} - rows = [row for name, row in raw_schema.items() if name.lower() in accept] - - resulted_rows = [] - for row in rows: - row_type = "DECIMAL" if row[1].startswith("DECIMAL") else row[1] - type_cls = self.dialect.TYPE_CLASSES.get(row_type, UnknownColType) - - if issubclass(type_cls, Integer): - row = (row[0], row_type, None, None, 0) - - elif issubclass(type_cls, Float): - numeric_precision = math.ceil(row[2] / math.log(2, 10)) - row = (row[0], row_type, None, numeric_precision, None) - - elif issubclass(type_cls, Decimal): - items = row[1][8:].rstrip(")").split(",") - numeric_precision, numeric_scale = int(items[0]), int(items[1]) - row = (row[0], row_type, None, numeric_precision, numeric_scale) - - elif issubclass(type_cls, Timestamp): - row = (row[0], row_type, row[2], None, None) - - else: - row = (row[0], row_type, None, None, None) - - resulted_rows.append(row) - - col_dict: Dict[str, ColType] = {row[0]: self.dialect.parse_type(path, *row) for row in resulted_rows} - - self._refine_coltypes(path, col_dict, where) - return col_dict - - def parse_table_name(self, name: str) -> DbPath: - path = parse_table_name(name) - return tuple(i for i in self._normalize_table_path(path) if i is not None) - - @property - def is_autocommit(self) -> bool: - return True - - def _normalize_table_path(self, path: DbPath) -> DbPath: - if len(path) == 1: - return self.catalog, self.default_schema, path[0] - elif len(path) == 2: - return self.catalog, path[0], path[1] - elif len(path) == 3: - return path - - raise ValueError( - f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. 
Expected format: table, schema.table, or catalog.schema.table" - ) diff --git a/data_diff/sqeleton/databases/duckdb.py b/data_diff/sqeleton/databases/duckdb.py deleted file mode 100644 index 3714b00d..00000000 --- a/data_diff/sqeleton/databases/duckdb.py +++ /dev/null @@ -1,192 +0,0 @@ -from typing import Union - -from data_diff.sqeleton.utils import match_regexps -from data_diff.sqeleton.abcs.database_types import ( - Timestamp, - TimestampTZ, - DbPath, - ColType, - Float, - Decimal, - Integer, - TemporalType, - Native_UUID, - Text, - FractionalType, - Boolean, - AbstractTable, -) -from data_diff.sqeleton.abcs.mixins import ( - AbstractMixin_MD5, - AbstractMixin_NormalizeValue, - AbstractMixin_RandomSample, - AbstractMixin_Regex, -) -from data_diff.sqeleton.databases.base import ( - Database, - BaseDialect, - import_helper, - ConnectError, - ThreadLocalInterpreter, - TIMESTAMP_PRECISION_POS, -) -from data_diff.sqeleton.databases.base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, Mixin_Schema -from data_diff.sqeleton.queries.ast_classes import Func, Compilable -from data_diff.sqeleton.queries.api import code - - -@import_helper("duckdb") -def import_duckdb(): - import duckdb - - return duckdb - - -class Mixin_MD5(AbstractMixin_MD5): - def md5_as_int(self, s: str) -> str: - return f"('0x' || SUBSTRING(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS},{CHECKSUM_HEXDIGITS}))::BIGINT" - - -class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): - def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: - # It's precision 6 by default. If precision is less than 6 -> we remove the trailing numbers. - if coltype.rounds and coltype.precision > 0: - return f"CONCAT(SUBSTRING(STRFTIME({value}::TIMESTAMP, '%Y-%m-%d %H:%M:%S.'),1,23), LPAD(((ROUND(strftime({value}::timestamp, '%f')::DECIMAL(15,7)/100000,{coltype.precision-1})*100000)::INT)::VARCHAR,6,'0'))" - - return f"rpad(substring(strftime({value}::timestamp, '%Y-%m-%d %H:%M:%S.%f'),1,{TIMESTAMP_PRECISION_POS+coltype.precision}),26,'0')" - - def normalize_number(self, value: str, coltype: FractionalType) -> str: - return self.to_string(f"{value}::DECIMAL(38, {coltype.precision})") - - def normalize_boolean(self, value: str, _coltype: Boolean) -> str: - return self.to_string(f"{value}::INTEGER") - - -class Mixin_RandomSample(AbstractMixin_RandomSample): - def random_sample_n(self, tbl: AbstractTable, size: int) -> AbstractTable: - return code("SELECT * FROM ({tbl}) USING SAMPLE {size};", tbl=tbl, size=size) - - def random_sample_ratio_approx(self, tbl: AbstractTable, ratio: float) -> AbstractTable: - return code("SELECT * FROM ({tbl}) USING SAMPLE {percent}%;", tbl=tbl, percent=int(100 * ratio)) - - -class Mixin_Regex(AbstractMixin_Regex): - def test_regex(self, string: Compilable, pattern: Compilable) -> Compilable: - return Func("regexp_matches", [string, pattern]) - - -class Dialect(BaseDialect, Mixin_Schema): - name = "DuckDB" - ROUNDS_ON_PREC_LOSS = False - SUPPORTS_PRIMARY_KEY = True - SUPPORTS_INDEXES = True - MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample} - - TYPE_CLASSES = { - # Timestamps - "TIMESTAMP WITH TIME ZONE": TimestampTZ, - "TIMESTAMP": Timestamp, - # Numbers - "DOUBLE": Float, - "FLOAT": Float, - "DECIMAL": Decimal, - "INTEGER": Integer, - "BIGINT": Integer, - # Text - "VARCHAR": Text, - "TEXT": Text, - # UUID - "UUID": Native_UUID, - # Bool - "BOOLEAN": Boolean, - } - - def quote(self, s: str): - return f'"{s}"' - - def to_string(self, s: str): - return f"{s}::VARCHAR" - - def 
_convert_db_precision_to_digits(self, p: int) -> int: - # Subtracting 2 due to wierd precision issues in PostgreSQL - return super()._convert_db_precision_to_digits(p) - 2 - - def parse_type( - self, - table_path: DbPath, - col_name: str, - type_repr: str, - datetime_precision: int = None, - numeric_precision: int = None, - numeric_scale: int = None, - ) -> ColType: - regexps = { - r"DECIMAL\((\d+),(\d+)\)": Decimal, - } - - for m, t_cls in match_regexps(regexps, type_repr): - precision = int(m.group(2)) - return t_cls(precision=precision) - - return super().parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision, numeric_scale) - - def set_timezone_to_utc(self) -> str: - return "SET GLOBAL TimeZone='UTC'" - - def current_timestamp(self) -> str: - return "current_timestamp" - - -class DuckDB(Database): - dialect = Dialect() - SUPPORTS_UNIQUE_CONSTAINT = False # Temporary, until we implement it - default_schema = "main" - CONNECT_URI_HELP = "duckdb://@" - CONNECT_URI_PARAMS = ["database", "dbpath"] - - def __init__(self, **kw): - self._args = kw - self._conn = self.create_connection() - - @property - def is_autocommit(self) -> bool: - return True - - def _query(self, sql_code: Union[str, ThreadLocalInterpreter]): - "Uses the standard SQL cursor interface" - return self._query_conn(self._conn, sql_code) - - def close(self): - super().close() - self._conn.close() - - def create_connection(self): - ddb = import_duckdb() - try: - return ddb.connect(self._args["filepath"]) - except ddb.OperationalError as e: - raise ConnectError(*e.args) from e - - def select_table_schema(self, path: DbPath) -> str: - database, schema, table = self._normalize_table_path(path) - - info_schema_path = ["information_schema", "columns"] - if database: - info_schema_path.insert(0, database) - - return ( - f"SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale FROM {'.'.join(info_schema_path)} " - f"WHERE table_name = '{table}' AND table_schema = '{schema}'" - ) - - def _normalize_table_path(self, path: DbPath) -> DbPath: - if len(path) == 1: - return None, self.default_schema, path[0] - elif len(path) == 2: - return None, path[0], path[1] - elif len(path) == 3: - return path - - raise ValueError( - f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected format: table, schema.table, or database.schema.table" - ) diff --git a/data_diff/sqeleton/databases/mssql.py b/data_diff/sqeleton/databases/mssql.py deleted file mode 100644 index cc0754a7..00000000 --- a/data_diff/sqeleton/databases/mssql.py +++ /dev/null @@ -1,214 +0,0 @@ -from typing import Optional -from data_diff.sqeleton.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue -from data_diff.sqeleton.databases.base import ( - CHECKSUM_HEXDIGITS, - Mixin_OptimizerHints, - Mixin_RandomSample, - QueryError, - ThreadedDatabase, - import_helper, - ConnectError, - BaseDialect, -) -from data_diff.sqeleton.databases.base import Mixin_Schema -from data_diff.sqeleton.abcs.database_types import ( - JSON, - Timestamp, - TimestampTZ, - DbPath, - Float, - Decimal, - Integer, - TemporalType, - Native_UUID, - Text, - FractionalType, - Boolean, -) - - -@import_helper("mssql") -def import_mssql(): - import pyodbc - - return pyodbc - - -class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): - def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: - if coltype.precision > 0: - formatted_value = ( - f"FORMAT({value}, 'yyyy-MM-dd HH:mm:ss') + '.' 
+ " - f"SUBSTRING(FORMAT({value}, 'fffffff'), 1, {coltype.precision})" - ) - else: - formatted_value = f"FORMAT({value}, 'yyyy-MM-dd HH:mm:ss')" - - return formatted_value - - def normalize_number(self, value: str, coltype: FractionalType) -> str: - if coltype.precision == 0: - return f"CAST(FLOOR({value}) AS VARCHAR)" - - return f"FORMAT({value}, 'N{coltype.precision}')" - - -class Mixin_MD5(AbstractMixin_MD5): - def md5_as_int(self, s: str) -> str: - return f"convert(bigint, convert(varbinary, '0x' + RIGHT(CONVERT(NVARCHAR(32), HashBytes('MD5', {s}), 2), {CHECKSUM_HEXDIGITS}), 1))" - - -class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints): - name = "MsSQL" - ROUNDS_ON_PREC_LOSS = True - SUPPORTS_PRIMARY_KEY = True - SUPPORTS_INDEXES = True - TYPE_CLASSES = { - # Timestamps - "datetimeoffset": TimestampTZ, - "datetime": Timestamp, - "datetime2": Timestamp, - "smalldatetime": Timestamp, - "date": Timestamp, - # Numbers - "float": Float, - "real": Float, - "decimal": Decimal, - "money": Decimal, - "smallmoney": Decimal, - # int - "int": Integer, - "bigint": Integer, - "tinyint": Integer, - "smallint": Integer, - # Text - "varchar": Text, - "char": Text, - "text": Text, - "ntext": Text, - "nvarchar": Text, - "nchar": Text, - "binary": Text, - "varbinary": Text, - # UUID - "uniqueidentifier": Native_UUID, - # Bool - "bit": Boolean, - # JSON - "json": JSON, - } - - MIXINS = {Mixin_Schema, Mixin_NormalizeValue, Mixin_RandomSample} - - def quote(self, s: str): - return f"[{s}]" - - def set_timezone_to_utc(self) -> str: - raise NotImplementedError("MsSQL does not support a session timezone setting.") - - def current_timestamp(self) -> str: - return "GETDATE()" - - def current_database(self) -> str: - return "DB_NAME()" - - def current_schema(self) -> str: - return """default_schema_name - FROM sys.database_principals - WHERE name = CURRENT_USER""" - - def to_string(self, s: str): - return f"CONVERT(varchar, {s})" - - def type_repr(self, t) -> str: - try: - return {bool: "bit"}[t] - except KeyError: - return super().type_repr(t) - - def random(self) -> str: - return "rand()" - - def is_distinct_from(self, a: str, b: str) -> str: - # IS (NOT) DISTINCT FROM is available only since SQLServer 2022. 
-        # See: https://stackoverflow.com/a/18684859/857383
-        return f"(({a}<>{b} OR {a} IS NULL OR {b} IS NULL) AND NOT({a} IS NULL AND {b} IS NULL))"
-
-    def offset_limit(
-        self, offset: Optional[int] = None, limit: Optional[int] = None, has_order_by: Optional[bool] = None
-    ) -> str:
-        if offset:
-            raise NotImplementedError("No support for OFFSET in query")
-
-        result = ""
-        if not has_order_by:
-            result += "ORDER BY 1"
-
-        result += f" OFFSET 0 ROWS FETCH NEXT {limit} ROWS ONLY"
-        return result
-
-    def constant_values(self, rows) -> str:
-        values = ", ".join("(%s)" % ", ".join(self._constant_value(v) for v in row) for row in rows)
-        return f"VALUES {values}"
-
-
-class MsSQL(ThreadedDatabase):
-    dialect = Dialect()
-    #
-    CONNECT_URI_HELP = "mssql://<user>:<password>@<host>/<database>/<schema>"
-    CONNECT_URI_PARAMS = ["database", "schema"]
-
-    def __init__(self, host, port, user, password, *, database, thread_count, **kw):
-        args = dict(server=host, port=port, database=database, user=user, password=password, **kw)
-        self._args = {k: v for k, v in args.items() if v is not None}
-        self._args["driver"] = "{ODBC Driver 18 for SQL Server}"
-
-        # TODO temp dev debug
-        self._args["TrustServerCertificate"] = "yes"
-
-        try:
-            self.default_database = self._args["database"]
-            self.default_schema = self._args["schema"]
-        except KeyError:
-            raise ValueError("Specify a default database and schema.")
-
-        super().__init__(thread_count=thread_count)
-
-    def create_connection(self):
-        self._mssql = import_mssql()
-        try:
-            connection = self._mssql.connect(**self._args)
-            return connection
-        except self._mssql.Error as error:
-            raise ConnectError(*error.args) from error
-
-    def select_table_schema(self, path: DbPath) -> str:
-        """Provide SQL for selecting the table schema as (name, type, date_prec, num_prec)"""
-        database, schema, name = self._normalize_table_path(path)
-        info_schema_path = ["information_schema", "columns"]
-        if database:
-            info_schema_path.insert(0, self.dialect.quote(database))
-
-        return (
-            "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale "
-            f"FROM {'.'.join(info_schema_path)} "
-            f"WHERE table_name = '{name}' AND table_schema = '{schema}'"
-        )
-
-    def _normalize_table_path(self, path: DbPath) -> DbPath:
-        if len(path) == 1:
-            return self.default_database, self.default_schema, path[0]
-        elif len(path) == 2:
-            return self.default_database, path[0], path[1]
-        elif len(path) == 3:
-            return path
-
-        raise ValueError(
-            f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. 
Expected format: table, schema.table, or database.schema.table" - ) - - def _query_cursor(self, c, sql_code: str): - try: - return super()._query_cursor(c, sql_code) - except self._mssql.DatabaseError as e: - raise QueryError(e) diff --git a/data_diff/sqeleton/databases/mysql.py b/data_diff/sqeleton/databases/mysql.py deleted file mode 100644 index a10652b5..00000000 --- a/data_diff/sqeleton/databases/mysql.py +++ /dev/null @@ -1,160 +0,0 @@ -from data_diff.sqeleton.abcs.database_types import ( - Datetime, - Timestamp, - Float, - Decimal, - Integer, - Text, - TemporalType, - FractionalType, - ColType_UUID, - Boolean, - Date, -) -from data_diff.sqeleton.abcs.mixins import ( - AbstractMixin_MD5, - AbstractMixin_NormalizeValue, - AbstractMixin_Regex, - AbstractMixin_RandomSample, -) -from data_diff.sqeleton.databases.base import ( - Mixin_OptimizerHints, - ThreadedDatabase, - import_helper, - ConnectError, - BaseDialect, - Compilable, -) -from data_diff.sqeleton.databases.base import ( - MD5_HEXDIGITS, - CHECKSUM_HEXDIGITS, - TIMESTAMP_PRECISION_POS, - Mixin_Schema, - Mixin_RandomSample, -) -from data_diff.sqeleton.queries.ast_classes import BinBoolOp - - -@import_helper("mysql") -def import_mysql(): - import mysql.connector - - return mysql.connector - - -class Mixin_MD5(AbstractMixin_MD5): - def md5_as_int(self, s: str) -> str: - return f"cast(conv(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16, 10) as unsigned)" - - -class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): - def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: - if coltype.rounds: - return self.to_string(f"cast( cast({value} as datetime({coltype.precision})) as datetime(6))") - - s = self.to_string(f"cast({value} as datetime(6))") - return f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS+coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS+6}, '0')" - - def normalize_number(self, value: str, coltype: FractionalType) -> str: - return self.to_string(f"cast({value} as decimal(38, {coltype.precision}))") - - def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: - return f"TRIM(CAST({value} AS char))" - - -class Mixin_Regex(AbstractMixin_Regex): - def test_regex(self, string: Compilable, pattern: Compilable) -> Compilable: - return BinBoolOp("REGEXP", [string, pattern]) - - -class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints): - name = "MySQL" - ROUNDS_ON_PREC_LOSS = True - SUPPORTS_PRIMARY_KEY = True - SUPPORTS_INDEXES = True - TYPE_CLASSES = { - # Dates - "datetime": Datetime, - "timestamp": Timestamp, - "date": Date, - # Numbers - "double": Float, - "float": Float, - "decimal": Decimal, - "int": Integer, - "bigint": Integer, - "mediumint": Integer, - "smallint": Integer, - "tinyint": Integer, - # Text - "varchar": Text, - "char": Text, - "varbinary": Text, - "binary": Text, - "text": Text, - "mediumtext": Text, - "longtext": Text, - "tinytext": Text, - # Boolean - "boolean": Boolean, - } - MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample} - - def quote(self, s: str): - return f"`{s}`" - - def to_string(self, s: str): - return f"cast({s} as char)" - - def is_distinct_from(self, a: str, b: str) -> str: - return f"not ({a} <=> {b})" - - def random(self) -> str: - return "RAND()" - - def type_repr(self, t) -> str: - try: - return { - str: "VARCHAR(1024)", - }[t] - except KeyError: - return super().type_repr(t) - - def explain_as_text(self, query: str) -> str: - return f"EXPLAIN FORMAT=TREE {query}" - - def optimizer_hints(self, s: str): - return f"/*+ 
{s} */ " - - def set_timezone_to_utc(self) -> str: - return "SET @@session.time_zone='+00:00'" - - -class MySQL(ThreadedDatabase): - dialect = Dialect() - SUPPORTS_ALPHANUMS = False - SUPPORTS_UNIQUE_CONSTAINT = True - CONNECT_URI_HELP = "mysql://:@/" - CONNECT_URI_PARAMS = ["database?"] - - def __init__(self, *, thread_count, **kw): - self._args = kw - - super().__init__(thread_count=thread_count) - - # In MySQL schema and database are synonymous - try: - self.default_schema = kw["database"] - except KeyError: - raise ValueError("MySQL URL must specify a database") - - def create_connection(self): - mysql = import_mysql() - try: - return mysql.connect(charset="utf8", use_unicode=True, **self._args) - except mysql.Error as e: - if e.errno == mysql.errorcode.ER_ACCESS_DENIED_ERROR: - raise ConnectError("Bad user name or password") from e - elif e.errno == mysql.errorcode.ER_BAD_DB_ERROR: - raise ConnectError("Database does not exist") from e - raise ConnectError(*e.args) from e diff --git a/data_diff/sqeleton/databases/oracle.py b/data_diff/sqeleton/databases/oracle.py deleted file mode 100644 index 23fc8e09..00000000 --- a/data_diff/sqeleton/databases/oracle.py +++ /dev/null @@ -1,206 +0,0 @@ -from typing import Dict, List, Optional - -from data_diff.sqeleton.utils import match_regexps -from data_diff.sqeleton.abcs.database_types import ( - Decimal, - Float, - Text, - DbPath, - TemporalType, - ColType, - DbTime, - ColType_UUID, - Timestamp, - TimestampTZ, - FractionalType, -) -from data_diff.sqeleton.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue, AbstractMixin_Schema -from data_diff.sqeleton.abcs import Compilable -from data_diff.sqeleton.queries import this, table, SKIP -from data_diff.sqeleton.databases.base import ( - BaseDialect, - Mixin_OptimizerHints, - ThreadedDatabase, - import_helper, - ConnectError, - QueryError, - Mixin_RandomSample, -) -from data_diff.sqeleton.databases.base import TIMESTAMP_PRECISION_POS - -SESSION_TIME_ZONE = None # Changed by the tests - - -@import_helper("oracle") -def import_oracle(): - import oracledb - - return oracledb - - -class Mixin_MD5(AbstractMixin_MD5): - def md5_as_int(self, s: str) -> str: - # standard_hash is faster than DBMS_CRYPTO.Hash - # TODO: Find a way to use UTL_RAW.CAST_TO_BINARY_INTEGER ? - return f"to_number(substr(standard_hash({s}, 'MD5'), 18), 'xxxxxxxxxxxxxxx')" - - -class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): - def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: - # Cast is necessary for correct MD5 (trimming not enough) - return f"CAST(TRIM({value}) AS VARCHAR(36))" - - def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: - if coltype.rounds: - return f"to_char(cast({value} as timestamp({coltype.precision})), 'YYYY-MM-DD HH24:MI:SS.FF6')" - - if coltype.precision > 0: - truncated = f"to_char({value}, 'YYYY-MM-DD HH24:MI:SS.FF{coltype.precision}')" - else: - truncated = f"to_char({value}, 'YYYY-MM-DD HH24:MI:SS.')" - return f"RPAD({truncated}, {TIMESTAMP_PRECISION_POS+6}, '0')" - - def normalize_number(self, value: str, coltype: FractionalType) -> str: - # FM999.9990 - format_str = "FM" + "9" * (38 - coltype.precision) - if coltype.precision: - format_str += "0." 
+ "9" * (coltype.precision - 1) + "0" - return f"to_char({value}, '{format_str}')" - - -class Mixin_Schema(AbstractMixin_Schema): - def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: - return ( - table("ALL_TABLES") - .where( - this.OWNER == table_schema, - this.TABLE_NAME.like(like) if like is not None else SKIP, - ) - .select(table_name=this.TABLE_NAME) - ) - - -class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints): - name = "Oracle" - SUPPORTS_PRIMARY_KEY = True - SUPPORTS_INDEXES = True - TYPE_CLASSES: Dict[str, type] = { - "NUMBER": Decimal, - "FLOAT": Float, - # Text - "CHAR": Text, - "NCHAR": Text, - "NVARCHAR2": Text, - "VARCHAR2": Text, - "DATE": Timestamp, - } - ROUNDS_ON_PREC_LOSS = True - PLACEHOLDER_TABLE = "DUAL" - MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample} - - def quote(self, s: str): - return f'"{s}"' - - def to_string(self, s: str): - return f"cast({s} as varchar(1024))" - - def offset_limit( - self, offset: Optional[int] = None, limit: Optional[int] = None, has_order_by: Optional[bool] = None - ) -> str: - if offset: - raise NotImplementedError("No support for OFFSET in query") - - return f"FETCH NEXT {limit} ROWS ONLY" - - def concat(self, items: List[str]) -> str: - joined_exprs = " || ".join(items) - return f"({joined_exprs})" - - def timestamp_value(self, t: DbTime) -> str: - return "timestamp '%s'" % t.isoformat(" ") - - def random(self) -> str: - return "dbms_random.value" - - def is_distinct_from(self, a: str, b: str) -> str: - return f"DECODE({a}, {b}, 1, 0) = 0" - - def type_repr(self, t) -> str: - try: - return { - str: "VARCHAR(1024)", - }[t] - except KeyError: - return super().type_repr(t) - - def constant_values(self, rows) -> str: - return " UNION ALL ".join( - "SELECT %s FROM DUAL" % ", ".join(self._constant_value(v) for v in row) for row in rows - ) - - def explain_as_text(self, query: str) -> str: - raise NotImplementedError("Explain not yet implemented in Oracle") - - def parse_type( - self, - table_path: DbPath, - col_name: str, - type_repr: str, - datetime_precision: int = None, - numeric_precision: int = None, - numeric_scale: int = None, - ) -> ColType: - regexps = { - r"TIMESTAMP\((\d)\) WITH LOCAL TIME ZONE": Timestamp, - r"TIMESTAMP\((\d)\) WITH TIME ZONE": TimestampTZ, - r"TIMESTAMP\((\d)\)": Timestamp, - } - - for m, t_cls in match_regexps(regexps, type_repr): - precision = int(m.group(1)) - return t_cls(precision=precision, rounds=self.ROUNDS_ON_PREC_LOSS) - - return super().parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision, numeric_scale) - - def set_timezone_to_utc(self) -> str: - return "ALTER SESSION SET TIME_ZONE = 'UTC'" - - def current_timestamp(self) -> str: - return "LOCALTIMESTAMP" - - -class Oracle(ThreadedDatabase): - dialect = Dialect() - CONNECT_URI_HELP = "oracle://:@/" - CONNECT_URI_PARAMS = ["database?"] - - def __init__(self, *, host, database, thread_count, **kw): - self.kwargs = dict(dsn=f"{host}/{database}" if database else host, **kw) - - self.default_schema = kw.get("user").upper() - - super().__init__(thread_count=thread_count) - - def create_connection(self): - self._oracle = import_oracle() - try: - c = self._oracle.connect(**self.kwargs) - if SESSION_TIME_ZONE: - c.cursor().execute(f"ALTER SESSION SET TIME_ZONE = '{SESSION_TIME_ZONE}'") - return c - except Exception as e: - raise ConnectError(*e.args) from e - - def _query_cursor(self, c, sql_code: str): - try: - return super()._query_cursor(c, sql_code) - except 
self._oracle.DatabaseError as e: - raise QueryError(e) - - def select_table_schema(self, path: DbPath) -> str: - schema, name = self._normalize_table_path(path) - - return ( - f"SELECT column_name, data_type, 6 as datetime_precision, data_precision as numeric_precision, data_scale as numeric_scale" - f" FROM ALL_TAB_COLUMNS WHERE table_name = '{name}' AND owner = '{schema}'" - ) diff --git a/data_diff/sqeleton/databases/postgresql.py b/data_diff/sqeleton/databases/postgresql.py deleted file mode 100644 index db34ec54..00000000 --- a/data_diff/sqeleton/databases/postgresql.py +++ /dev/null @@ -1,183 +0,0 @@ -from typing import List -from data_diff.sqeleton.abcs.database_types import ( - DbPath, - JSON, - Timestamp, - TimestampTZ, - Float, - Decimal, - Integer, - TemporalType, - Native_UUID, - Text, - FractionalType, - Boolean, - Date, -) -from data_diff.sqeleton.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue -from data_diff.sqeleton.databases.base import BaseDialect, ThreadedDatabase, import_helper, ConnectError, Mixin_Schema -from data_diff.sqeleton.databases.base import ( - MD5_HEXDIGITS, - CHECKSUM_HEXDIGITS, - _CHECKSUM_BITSIZE, - TIMESTAMP_PRECISION_POS, - Mixin_RandomSample, -) - -SESSION_TIME_ZONE = None # Changed by the tests - - -@import_helper("postgresql") -def import_postgresql(): - import psycopg2 - import psycopg2.extras - - psycopg2.extensions.set_wait_callback(psycopg2.extras.wait_select) - return psycopg2 - - -class Mixin_MD5(AbstractMixin_MD5): - def md5_as_int(self, s: str) -> str: - return f"('x' || substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}))::bit({_CHECKSUM_BITSIZE})::bigint" - - -class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): - def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: - if coltype.rounds: - return f"to_char({value}::timestamp({coltype.precision}), 'YYYY-mm-dd HH24:MI:SS.US')" - - timestamp6 = f"to_char({value}::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')" - return ( - f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" - ) - - def normalize_number(self, value: str, coltype: FractionalType) -> str: - return self.to_string(f"{value}::decimal(38, {coltype.precision})") - - def normalize_boolean(self, value: str, _coltype: Boolean) -> str: - return self.to_string(f"{value}::int") - - def normalize_json(self, value: str, _coltype: JSON) -> str: - return f"{value}::text" - - -class PostgresqlDialect(BaseDialect, Mixin_Schema): - name = "PostgreSQL" - ROUNDS_ON_PREC_LOSS = True - SUPPORTS_PRIMARY_KEY = True - SUPPORTS_INDEXES = True - MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample} - - TYPE_CLASSES = { - # Timestamps - "timestamp with time zone": TimestampTZ, - "timestamp without time zone": Timestamp, - "timestamp": Timestamp, - "date": Date, - # Numbers - "double precision": Float, - "real": Float, - "decimal": Decimal, - "smallint": Integer, - "integer": Integer, - "numeric": Decimal, - "bigint": Integer, - # Text - "character": Text, - "character varying": Text, - "varchar": Text, - "text": Text, - "json": JSON, - "jsonb": JSON, - "uuid": Native_UUID, - "boolean": Boolean, - } - - def quote(self, s: str): - return f'"{s}"' - - def to_string(self, s: str): - return f"{s}::varchar" - - def concat(self, items: List[str]) -> str: - joined_exprs = " || ".join(items) - return f"({joined_exprs})" - - def _convert_db_precision_to_digits(self, p: int) -> int: - # Subtracting 2 due to wierd precision issues in PostgreSQL - 
return super()._convert_db_precision_to_digits(p) - 2 - - def set_timezone_to_utc(self) -> str: - return "SET TIME ZONE 'UTC'" - - def current_timestamp(self) -> str: - return "current_timestamp" - - def type_repr(self, t) -> str: - if isinstance(t, TimestampTZ): - return f"timestamp ({t.precision}) with time zone" - return super().type_repr(t) - - -class PostgreSQL(ThreadedDatabase): - dialect = PostgresqlDialect() - SUPPORTS_UNIQUE_CONSTAINT = True - CONNECT_URI_HELP = "postgresql://:@/" - CONNECT_URI_PARAMS = ["database?"] - - default_schema = "public" - - def __init__(self, *, thread_count, **kw): - self._args = kw - - super().__init__(thread_count=thread_count) - - def create_connection(self): - if not self._args: - self._args["host"] = None # psycopg2 requires 1+ arguments - - pg = import_postgresql() - try: - c = pg.connect(**self._args) - if SESSION_TIME_ZONE: - c.cursor().execute(f"SET TIME ZONE '{SESSION_TIME_ZONE}'") - return c - except pg.OperationalError as e: - raise ConnectError(*e.args) from e - - def select_table_schema(self, path: DbPath) -> str: - database, schema, table = self._normalize_table_path(path) - - info_schema_path = ["information_schema", "columns"] - if database: - info_schema_path.insert(0, database) - - return ( - f"SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale FROM {'.'.join(info_schema_path)} " - f"WHERE table_name = '{table}' AND table_schema = '{schema}'" - ) - - def select_table_unique_columns(self, path: DbPath) -> str: - database, schema, table = self._normalize_table_path(path) - - info_schema_path = ["information_schema", "key_column_usage"] - if database: - info_schema_path.insert(0, database) - - return ( - "SELECT column_name " - f"FROM {'.'.join(info_schema_path)} " - f"WHERE table_name = '{table}' AND table_schema = '{schema}'" - ) - - def _normalize_table_path(self, path: DbPath) -> DbPath: - if len(path) == 1: - return None, self.default_schema, path[0] - elif len(path) == 2: - return None, path[0], path[1] - elif len(path) == 3: - return path - - raise ValueError( - f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. 
Expected format: table, schema.table, or database.schema.table" - ) diff --git a/data_diff/sqeleton/databases/presto.py b/data_diff/sqeleton/databases/presto.py deleted file mode 100644 index 6c09a879..00000000 --- a/data_diff/sqeleton/databases/presto.py +++ /dev/null @@ -1,202 +0,0 @@ -from functools import partial -import re - -from data_diff.sqeleton.utils import match_regexps - -from data_diff.sqeleton.abcs.database_types import ( - Timestamp, - TimestampTZ, - Integer, - Float, - Text, - FractionalType, - DbPath, - DbTime, - Decimal, - ColType, - ColType_UUID, - TemporalType, - Boolean, -) -from data_diff.sqeleton.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue -from data_diff.sqeleton.databases.base import ( - BaseDialect, - Database, - import_helper, - ThreadLocalInterpreter, - Mixin_Schema, - Mixin_RandomSample, -) -from data_diff.sqeleton.databases.base import ( - MD5_HEXDIGITS, - CHECKSUM_HEXDIGITS, - TIMESTAMP_PRECISION_POS, -) - - -def query_cursor(c, sql_code): - c.execute(sql_code) - if sql_code.lower().startswith("select"): - return c.fetchall() - # Required for the query to actually run 🤯 - if re.match(r"(insert|create|truncate|drop|explain)", sql_code, re.IGNORECASE): - return c.fetchone() - - -@import_helper("presto") -def import_presto(): - import prestodb - - return prestodb - - -class Mixin_MD5(AbstractMixin_MD5): - def md5_as_int(self, s: str) -> str: - return f"cast(from_base(substr(to_hex(md5(to_utf8({s}))), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16) as decimal(38, 0))" - - -class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): - def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: - # Trim doesn't work on CHAR type - return f"TRIM(CAST({value} AS VARCHAR))" - - def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: - # TODO rounds - if coltype.rounds: - s = f"date_format(cast({value} as timestamp(6)), '%Y-%m-%d %H:%i:%S.%f')" - else: - s = f"date_format(cast({value} as timestamp(6)), '%Y-%m-%d %H:%i:%S.%f')" - - return f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS+coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS+6}, '0')" - - def normalize_number(self, value: str, coltype: FractionalType) -> str: - return self.to_string(f"cast({value} as decimal(38,{coltype.precision}))") - - def normalize_boolean(self, value: str, _coltype: Boolean) -> str: - return self.to_string(f"cast ({value} as int)") - - -class Dialect(BaseDialect, Mixin_Schema): - name = "Presto" - ROUNDS_ON_PREC_LOSS = True - TYPE_CLASSES = { - # Timestamps - "timestamp with time zone": TimestampTZ, - "timestamp without time zone": Timestamp, - "timestamp": Timestamp, - # Numbers - "integer": Integer, - "bigint": Integer, - "real": Float, - "double": Float, - # Text - "varchar": Text, - # Boolean - "boolean": Boolean, - } - MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample} - - def explain_as_text(self, query: str) -> str: - return f"EXPLAIN (FORMAT TEXT) {query}" - - def type_repr(self, t) -> str: - if isinstance(t, TimestampTZ): - return f"timestamp with time zone" - - try: - return {float: "REAL"}[t] - except KeyError: - return super().type_repr(t) - - def timestamp_value(self, t: DbTime) -> str: - return f"timestamp '{t.isoformat(' ')}'" - - def quote(self, s: str): - return f'"{s}"' - - def to_string(self, s: str): - return f"cast({s} as varchar)" - - def parse_type( - self, - table_path: DbPath, - col_name: str, - type_repr: str, - datetime_precision: int = None, - numeric_precision: int = None, - _numeric_scale: int = 
None, - ) -> ColType: - timestamp_regexps = { - r"timestamp\((\d)\)": Timestamp, - r"timestamp\((\d)\) with time zone": TimestampTZ, - } - for m, t_cls in match_regexps(timestamp_regexps, type_repr): - precision = int(m.group(1)) - return t_cls(precision=precision, rounds=self.ROUNDS_ON_PREC_LOSS) - - number_regexps = {r"decimal\((\d+),(\d+)\)": Decimal} - for m, n_cls in match_regexps(number_regexps, type_repr): - _prec, scale = map(int, m.groups()) - return n_cls(scale) - - string_regexps = {r"varchar\((\d+)\)": Text, r"char\((\d+)\)": Text} - for m, n_cls in match_regexps(string_regexps, type_repr): - return n_cls() - - return super().parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision) - - def set_timezone_to_utc(self) -> str: - return "SET TIME ZONE '+00:00'" - - def current_timestamp(self) -> str: - return "current_timestamp" - - -class Presto(Database): - dialect = Dialect() - CONNECT_URI_HELP = "presto://@//" - CONNECT_URI_PARAMS = ["catalog", "schema"] - - default_schema = "public" - - def __init__(self, **kw): - prestodb = import_presto() - - if kw.get("schema"): - self.default_schema = kw.get("schema") - - if kw.get("auth") == "basic": # if auth=basic, add basic authenticator for Presto - kw["auth"] = prestodb.auth.BasicAuthentication(kw.pop("user"), kw.pop("password")) - - if "cert" in kw: # if a certificate was specified in URI, verify session with cert - cert = kw.pop("cert") - self._conn = prestodb.dbapi.connect(**kw) - self._conn._http_session.verify = cert - else: - self._conn = prestodb.dbapi.connect(**kw) - - def _query(self, sql_code: str) -> list: - "Uses the standard SQL cursor interface" - c = self._conn.cursor() - - if isinstance(sql_code, ThreadLocalInterpreter): - return sql_code.apply_queries(partial(query_cursor, c)) - - return query_cursor(c, sql_code) - - def close(self): - super().close() - self._conn.close() - - def select_table_schema(self, path: DbPath) -> str: - schema, table = self._normalize_table_path(path) - - return ( - "SELECT column_name, data_type, 3 as datetime_precision, 3 as numeric_precision, NULL as numeric_scale " - "FROM INFORMATION_SCHEMA.COLUMNS " - f"WHERE table_name = '{table}' AND table_schema = '{schema}'" - ) - - @property - def is_autocommit(self) -> bool: - return False diff --git a/data_diff/sqeleton/databases/redshift.py b/data_diff/sqeleton/databases/redshift.py deleted file mode 100644 index 97cbc0e1..00000000 --- a/data_diff/sqeleton/databases/redshift.py +++ /dev/null @@ -1,176 +0,0 @@ -from typing import List, Dict -from data_diff.sqeleton.abcs.database_types import ( - Float, - JSON, - TemporalType, - FractionalType, - DbPath, - TimestampTZ, -) -from data_diff.sqeleton.abcs.mixins import AbstractMixin_MD5 -from data_diff.sqeleton.databases.postgresql import ( - PostgreSQL, - MD5_HEXDIGITS, - CHECKSUM_HEXDIGITS, - TIMESTAMP_PRECISION_POS, - PostgresqlDialect, - Mixin_NormalizeValue, -) - - -class Mixin_MD5(AbstractMixin_MD5): - def md5_as_int(self, s: str) -> str: - return f"strtol(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16)::decimal(38)" - - -class Mixin_NormalizeValue(Mixin_NormalizeValue): - def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: - if coltype.rounds: - timestamp = f"{value}::timestamp(6)" - # Get seconds since epoch. Redshift doesn't support milli- or micro-seconds. - secs = f"timestamp 'epoch' + round(extract(epoch from {timestamp})::decimal(38)" - # Get the milliseconds from timestamp. 
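# A sketch of the arithmetic the surrounding SQL performs, using Python's
# datetime in place of Redshift's extract() (function and variable names here
# are illustrative, not part of the diff): rebuild a microsecond-precision
# epoch from whole seconds, milliseconds, and the leftover microseconds, then
# round away the digits beyond the column's fractional-second precision.
from datetime import datetime, timezone

def epoch_micros(ts: datetime, precision: int) -> int:
    secs = int(ts.timestamp())       # extract(epoch ...): whole seconds only
    ms = ts.microsecond // 1000      # extract(ms ...): milliseconds part
    us = ts.microsecond % 1000       # extract(us ...): micros without the ms part
    epoch = secs * 1_000_000 + ms * 1_000 + us
    return round(epoch, -6 + precision)  # same shape as round(..., -6+precision)

ts = datetime(2023, 1, 2, 3, 4, 5, 678901, tzinfo=timezone.utc)
assert epoch_micros(ts, 6) == int(ts.timestamp()) * 1_000_000 + 678901
assert epoch_micros(ts, 3) % 1_000 == 0  # sub-millisecond digits rounded away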
- ms = f"extract(ms from {timestamp})" - # Get the microseconds from timestamp, without the milliseconds! - us = f"extract(us from {timestamp})" - # epoch = Total time since epoch in microseconds. - epoch = f"{secs}*1000000 + {ms}*1000 + {us}" - timestamp6 = ( - f"to_char({epoch}, -6+{coltype.precision}) * interval '0.000001 seconds', 'YYYY-mm-dd HH24:MI:SS.US')" - ) - else: - timestamp6 = f"to_char({value}::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')" - return ( - f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" - ) - - def normalize_number(self, value: str, coltype: FractionalType) -> str: - return self.to_string(f"{value}::decimal(38,{coltype.precision})") - - def normalize_json(self, value: str, _coltype: JSON) -> str: - return f"nvl2({value}, json_serialize({value}), NULL)" - - -class Dialect(PostgresqlDialect): - name = "Redshift" - TYPE_CLASSES = { - **PostgresqlDialect.TYPE_CLASSES, - "double": Float, - "real": Float, - "super": JSON, - } - SUPPORTS_INDEXES = False - - def concat(self, items: List[str]) -> str: - joined_exprs = " || ".join(items) - return f"({joined_exprs})" - - def is_distinct_from(self, a: str, b: str) -> str: - return f"({a} IS NULL != {b} IS NULL) OR ({a}!={b})" - - def type_repr(self, t) -> str: - if isinstance(t, TimestampTZ): - return f"timestamptz" - return super().type_repr(t) - - -class Redshift(PostgreSQL): - dialect = Dialect() - CONNECT_URI_HELP = "redshift://:@/" - CONNECT_URI_PARAMS = ["database?"] - - def select_table_schema(self, path: DbPath) -> str: - database, schema, table = self._normalize_table_path(path) - - info_schema_path = ["information_schema", "columns"] - if database: - info_schema_path.insert(0, database) - - return ( - f"SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale FROM {'.'.join(info_schema_path)} " - f"WHERE table_name = '{table.lower()}' AND table_schema = '{schema.lower()}'" - ) - - def select_external_table_schema(self, path: DbPath) -> str: - database, schema, table = self._normalize_table_path(path) - - db_clause = "" - if database: - db_clause = f" AND redshift_database_name = '{database.lower()}'" - - return ( - f"""SELECT - columnname AS column_name - , CASE WHEN external_type = 'string' THEN 'varchar' ELSE external_type END AS data_type - , NULL AS datetime_precision - , NULL AS numeric_precision - , NULL AS numeric_scale - FROM svv_external_columns - WHERE tablename = '{table.lower()}' AND schemaname = '{schema.lower()}' - """ - + db_clause - ) - - def query_external_table_schema(self, path: DbPath) -> Dict[str, tuple]: - rows = self.query(self.select_external_table_schema(path), list) - if not rows: - raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns") - - d = {r[0]: r for r in rows} - assert len(d) == len(rows) - return d - - def select_view_columns(self, path: DbPath) -> str: - _, schema, table = self._normalize_table_path(path) - - return """select * from pg_get_cols('{}.{}') - cols(view_schema name, view_name name, col_name name, col_type varchar, col_num int) - """.format( - schema, table - ) - - def query_pg_get_cols(self, path: DbPath) -> Dict[str, tuple]: - rows = self.query(self.select_view_columns(path), list) - - if not rows: - raise RuntimeError(f"{self.name}: View '{'.'.join(path)}' does not exist, or has no columns") - - output = {} - for r in rows: - col_name = r[2] - type_info = r[3].split("(") - base_type = type_info[0] - precision = None - scale = None - - if 
len(type_info) > 1: - if base_type == "numeric": - precision, scale = type_info[1][:-1].split(",") - precision = int(precision) - scale = int(scale) - - out = [col_name, base_type, None, precision, scale] - output[col_name] = tuple(out) - - return output - - def query_table_schema(self, path: DbPath) -> Dict[str, tuple]: - try: - return super().query_table_schema(path) - except RuntimeError: - try: - return self.query_external_table_schema(path) - except RuntimeError: - return self.query_pg_get_cols(path) - - def _normalize_table_path(self, path: DbPath) -> DbPath: - if len(path) == 1: - return None, self.default_schema, path[0] - elif len(path) == 2: - return None, path[0], path[1] - elif len(path) == 3: - return path - - raise ValueError( - f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected format: table, schema.table, or database.schema.table" - ) diff --git a/data_diff/sqeleton/databases/snowflake.py b/data_diff/sqeleton/databases/snowflake.py deleted file mode 100644 index e8bf51f5..00000000 --- a/data_diff/sqeleton/databases/snowflake.py +++ /dev/null @@ -1,228 +0,0 @@ -from typing import Union, List -import logging - -from data_diff.sqeleton.abcs.database_types import ( - Timestamp, - TimestampTZ, - Decimal, - Float, - Text, - FractionalType, - TemporalType, - DbPath, - Boolean, - Date, -) -from data_diff.sqeleton.abcs.mixins import ( - AbstractMixin_MD5, - AbstractMixin_NormalizeValue, - AbstractMixin_Schema, - AbstractMixin_TimeTravel, -) -from data_diff.sqeleton.abcs import Compilable -from data_diff.sqeleton.queries import table, this, SKIP, code -from data_diff.sqeleton.databases.base import ( - BaseDialect, - ConnectError, - Database, - import_helper, - CHECKSUM_MASK, - ThreadLocalInterpreter, - Mixin_RandomSample, -) - - -@import_helper("snowflake") -def import_snowflake(): - import snowflake.connector - from cryptography.hazmat.primitives import serialization - from cryptography.hazmat.backends import default_backend - - return snowflake, serialization, default_backend - - -class Mixin_MD5(AbstractMixin_MD5): - def md5_as_int(self, s: str) -> str: - return f"BITAND(md5_number_lower64({s}), {CHECKSUM_MASK})" - - -class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): - def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: - if coltype.rounds: - timestamp = f"to_timestamp(round(date_part(epoch_nanosecond, convert_timezone('UTC', {value})::timestamp(9))/1000000000, {coltype.precision}))" - else: - timestamp = f"cast(convert_timezone('UTC', {value}) as timestamp({coltype.precision}))" - - return f"to_char({timestamp}, 'YYYY-MM-DD HH24:MI:SS.FF6')" - - def normalize_number(self, value: str, coltype: FractionalType) -> str: - return self.to_string(f"cast({value} as decimal(38, {coltype.precision}))") - - def normalize_boolean(self, value: str, _coltype: Boolean) -> str: - return self.to_string(f"{value}::int") - - -class Mixin_Schema(AbstractMixin_Schema): - def table_information(self) -> Compilable: - return table("INFORMATION_SCHEMA", "TABLES") - - def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: - return ( - self.table_information() - .where( - this.TABLE_SCHEMA == table_schema, - this.TABLE_NAME.like(like) if like is not None else SKIP, - this.TABLE_TYPE == "BASE TABLE", - ) - .select(table_name=this.TABLE_NAME) - ) - - -class Mixin_TimeTravel(AbstractMixin_TimeTravel): - def time_travel( - self, - table: Compilable, - before: bool = False, - timestamp: Compilable = None, - offset: Compilable = None, - 
statement: Compilable = None, - ) -> Compilable: - at_or_before = "AT" if before else "BEFORE" - if timestamp is not None: - assert offset is None and statement is None - key = "timestamp" - value = timestamp - elif offset is not None: - assert statement is None - key = "offset" - value = offset - else: - assert statement is not None - key = "statement" - value = statement - - return code(f"{{table}} {at_or_before}({key} => {{value}})", table=table, value=value) - - -class Dialect(BaseDialect, Mixin_Schema): - name = "Snowflake" - ROUNDS_ON_PREC_LOSS = False - TYPE_CLASSES = { - # Timestamps - "TIMESTAMP_NTZ": Timestamp, - "TIMESTAMP_LTZ": Timestamp, - "TIMESTAMP_TZ": TimestampTZ, - "DATE": Date, - # Numbers - "NUMBER": Decimal, - "FLOAT": Float, - # Text - "TEXT": Text, - # Boolean - "BOOLEAN": Boolean, - } - MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_TimeTravel, Mixin_RandomSample} - - def explain_as_text(self, query: str) -> str: - return f"EXPLAIN USING TEXT {query}" - - def quote(self, s: str): - return f'"{s}"' - - def to_string(self, s: str): - return f"cast({s} as string)" - - def table_information(self) -> Compilable: - return table("INFORMATION_SCHEMA", "TABLES") - - def set_timezone_to_utc(self) -> str: - return "ALTER SESSION SET TIMEZONE = 'UTC'" - - def optimizer_hints(self, hints: str) -> str: - raise NotImplementedError("Optimizer hints not yet implemented in snowflake") - - def type_repr(self, t) -> str: - if isinstance(t, TimestampTZ): - return f"timestamp_tz({t.precision})" - return super().type_repr(t) - - -class Snowflake(Database): - dialect = Dialect() - CONNECT_URI_HELP = "snowflake://:@//?warehouse=" - CONNECT_URI_PARAMS = ["database", "schema"] - CONNECT_URI_KWPARAMS = ["warehouse"] - - def __init__(self, *, schema: str, **kw): - snowflake, serialization, default_backend = import_snowflake() - logging.getLogger("snowflake.connector").setLevel(logging.WARNING) - - # Ignore the error: snowflake.connector.network.RetryRequest: could not find io module state - # It's a known issue: https://github.com/snowflakedb/snowflake-connector-python/issues/145 - logging.getLogger("snowflake.connector.network").disabled = True - - assert '"' not in schema, "Schema name should not contain quotes!" - # If a private key is used, read it from the specified path and pass it as "private_key" to the connector. 
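# A standalone sketch of the key-handling branch that follows: read a PEM
# private key (optionally passphrase-protected) and re-encode it as
# unencrypted PKCS#8 DER, the form the Snowflake connector accepts via its
# "private_key" argument. The function name and path are illustrative only.
from typing import Optional
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization

def load_private_key_der(path: str, passphrase: Optional[str] = None) -> bytes:
    with open(path, "rb") as f:
        p_key = serialization.load_pem_private_key(
            f.read(),
            password=passphrase.encode() if passphrase else None,
            backend=default_backend(),
        )
    return p_key.private_bytes(
        encoding=serialization.Encoding.DER,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )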
- if "key" in kw: - with open(kw.get("key"), "rb") as key: - if "password" in kw: - raise ConnectError("Cannot use password and key at the same time") - if kw.get("private_key_passphrase"): - encoded_passphrase = kw.get("private_key_passphrase").encode() - else: - encoded_passphrase = None - p_key = serialization.load_pem_private_key( - key.read(), - password=encoded_passphrase, - backend=default_backend(), - ) - - kw["private_key"] = p_key.private_bytes( - encoding=serialization.Encoding.DER, - format=serialization.PrivateFormat.PKCS8, - encryption_algorithm=serialization.NoEncryption(), - ) - - self._conn = snowflake.connector.connect(schema=f'"{schema}"', **kw) - - self.default_schema = schema - - def close(self): - super().close() - self._conn.close() - - def _query(self, sql_code: Union[str, ThreadLocalInterpreter]): - "Uses the standard SQL cursor interface" - return self._query_conn(self._conn, sql_code) - - def select_table_schema(self, path: DbPath) -> str: - """Provide SQL for selecting the table schema as (name, type, date_prec, num_prec)""" - database, schema, name = self._normalize_table_path(path) - info_schema_path = ["information_schema", "columns"] - if database: - info_schema_path.insert(0, database) - - return ( - "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale " - f"FROM {'.'.join(info_schema_path)} " - f"WHERE table_name = '{name}' AND table_schema = '{schema}'" - ) - - def _normalize_table_path(self, path: DbPath) -> DbPath: - if len(path) == 1: - return None, self.default_schema, path[0] - elif len(path) == 2: - return None, path[0], path[1] - elif len(path) == 3: - return path - - raise ValueError( - f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected format: table, schema.table, or database.schema.table" - ) - - @property - def is_autocommit(self) -> bool: - return True - - def query_table_unique_columns(self, path: DbPath) -> List[str]: - return [] diff --git a/data_diff/sqeleton/databases/trino.py b/data_diff/sqeleton/databases/trino.py deleted file mode 100644 index 20411749..00000000 --- a/data_diff/sqeleton/databases/trino.py +++ /dev/null @@ -1,47 +0,0 @@ -from data_diff.sqeleton.abcs.database_types import TemporalType, ColType_UUID -from data_diff.sqeleton.databases import presto -from data_diff.sqeleton.databases.base import import_helper -from data_diff.sqeleton.databases.base import TIMESTAMP_PRECISION_POS - - -@import_helper("trino") -def import_trino(): - import trino - - return trino - - -Mixin_MD5 = presto.Mixin_MD5 - - -class Mixin_NormalizeValue(presto.Mixin_NormalizeValue): - def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: - if coltype.rounds: - s = f"date_format(cast({value} as timestamp({coltype.precision})), '%Y-%m-%d %H:%i:%S.%f')" - else: - s = f"date_format(cast({value} as timestamp(6)), '%Y-%m-%d %H:%i:%S.%f')" - - return ( - f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS + coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS + 6}, '0')" - ) - - def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: - return f"TRIM({value})" - - -class Dialect(presto.Dialect): - name = "Trino" - - -class Trino(presto.Presto): - dialect = Dialect() - CONNECT_URI_HELP = "trino://@//" - CONNECT_URI_PARAMS = ["catalog", "schema"] - - def __init__(self, **kw): - trino = import_trino() - - if kw.get("schema"): - self.default_schema = kw.get("schema") - - self._conn = trino.dbapi.connect(**kw) diff --git a/data_diff/sqeleton/databases/vertica.py 
b/data_diff/sqeleton/databases/vertica.py deleted file mode 100644 index 0c03ab79..00000000 --- a/data_diff/sqeleton/databases/vertica.py +++ /dev/null @@ -1,181 +0,0 @@ -from typing import List - -from data_diff.sqeleton.utils import match_regexps -from data_diff.sqeleton.databases.base import ( - CHECKSUM_HEXDIGITS, - MD5_HEXDIGITS, - TIMESTAMP_PRECISION_POS, - BaseDialect, - ConnectError, - DbPath, - ColType, - ThreadedDatabase, - import_helper, - Mixin_RandomSample, -) -from data_diff.sqeleton.abcs.database_types import ( - Decimal, - Float, - FractionalType, - Integer, - TemporalType, - Text, - Timestamp, - TimestampTZ, - Boolean, - ColType_UUID, -) -from data_diff.sqeleton.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue, AbstractMixin_Schema -from data_diff.sqeleton.abcs import Compilable -from data_diff.sqeleton.queries import table, this, SKIP - - -@import_helper("vertica") -def import_vertica(): - import vertica_python - - return vertica_python - - -class Mixin_MD5(AbstractMixin_MD5): - def md5_as_int(self, s: str) -> str: - return f"CAST(HEX_TO_INTEGER(SUBSTRING(MD5({s}), {1 + MD5_HEXDIGITS - CHECKSUM_HEXDIGITS})) AS NUMERIC(38, 0))" - - -class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): - def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: - if coltype.rounds: - return f"TO_CHAR({value}::TIMESTAMP({coltype.precision}), 'YYYY-MM-DD HH24:MI:SS.US')" - - timestamp6 = f"TO_CHAR({value}::TIMESTAMP(6), 'YYYY-MM-DD HH24:MI:SS.US')" - return ( - f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" - ) - - def normalize_number(self, value: str, coltype: FractionalType) -> str: - return self.to_string(f"CAST({value} AS NUMERIC(38, {coltype.precision}))") - - def normalize_uuid(self, value: str, _coltype: ColType_UUID) -> str: - # Trim doesn't work on CHAR type - return f"TRIM(CAST({value} AS VARCHAR))" - - def normalize_boolean(self, value: str, _coltype: Boolean) -> str: - return self.to_string(f"cast ({value} as int)") - - -class Mixin_Schema(AbstractMixin_Schema): - def table_information(self) -> Compilable: - return table("v_catalog", "tables") - - def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: - return ( - self.table_information() - .where( - this.table_schema == table_schema, - this.table_name.like(like) if like is not None else SKIP, - ) - .select(this.table_name) - ) - - -class Dialect(BaseDialect, Mixin_Schema): - name = "Vertica" - ROUNDS_ON_PREC_LOSS = True - - TYPE_CLASSES = { - # Timestamps - "timestamp": Timestamp, - "timestamptz": TimestampTZ, - # Numbers - "numeric": Decimal, - "int": Integer, - "float": Float, - # Text - "char": Text, - "varchar": Text, - # Boolean - "boolean": Boolean, - } - MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample} - - def quote(self, s: str): - return f'"{s}"' - - def concat(self, items: List[str]) -> str: - return " || ".join(items) - - def to_string(self, s: str) -> str: - return f"CAST({s} AS VARCHAR)" - - def is_distinct_from(self, a: str, b: str) -> str: - return f"not ({a} <=> {b})" - - def parse_type( - self, - table_path: DbPath, - col_name: str, - type_repr: str, - datetime_precision: int = None, - numeric_precision: int = None, - numeric_scale: int = None, - ) -> ColType: - timestamp_regexps = { - r"timestamp\(?(\d?)\)?": Timestamp, - r"timestamptz\(?(\d?)\)?": TimestampTZ, - } - for m, t_cls in match_regexps(timestamp_regexps, type_repr): - precision = int(m.group(1)) if 
m.group(1) else 6 - return t_cls(precision=precision, rounds=self.ROUNDS_ON_PREC_LOSS) - - number_regexps = { - r"numeric\((\d+),(\d+)\)": Decimal, - } - for m, n_cls in match_regexps(number_regexps, type_repr): - _prec, scale = map(int, m.groups()) - return n_cls(scale) - - string_regexps = { - r"varchar\((\d+)\)": Text, - r"char\((\d+)\)": Text, - } - for m, n_cls in match_regexps(string_regexps, type_repr): - return n_cls() - - return super().parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision) - - def set_timezone_to_utc(self) -> str: - return "SET TIME ZONE TO 'UTC'" - - def current_timestamp(self) -> str: - return "current_timestamp(6)" - - -class Vertica(ThreadedDatabase): - dialect = Dialect() - CONNECT_URI_HELP = "vertica://:@/" - CONNECT_URI_PARAMS = ["database?"] - - default_schema = "public" - - def __init__(self, *, thread_count, **kw): - self._args = kw - self._args["AUTOCOMMIT"] = False - - super().__init__(thread_count=thread_count) - - def create_connection(self): - vertica = import_vertica() - try: - c = vertica.connect(**self._args) - return c - except vertica.errors.ConnectionError as e: - raise ConnectError(*e.args) from e - - def select_table_schema(self, path: DbPath) -> str: - schema, name = self._normalize_table_path(path) - - return ( - "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale " - "FROM V_CATALOG.COLUMNS " - f"WHERE table_name = '{name}' AND table_schema = '{schema}'" - ) diff --git a/data_diff/sqeleton/queries/__init__.py b/data_diff/sqeleton/queries/__init__.py deleted file mode 100644 index f1eea7b1..00000000 --- a/data_diff/sqeleton/queries/__init__.py +++ /dev/null @@ -1,25 +0,0 @@ -from data_diff.sqeleton.queries.compiler import Compiler, CompileError -from data_diff.sqeleton.queries.api import ( - this, - join, - outerjoin, - table, - SKIP, - sum_, - avg, - min_, - max_, - cte, - commit, - when, - coalesce, - and_, - if_, - or_, - leftjoin, - rightjoin, - current_timestamp, - code, -) -from data_diff.sqeleton.queries.ast_classes import Expr, ExprNode, Select, Count, BinOp, Explain, In, Code, Column -from data_diff.sqeleton.queries.extras import Checksum, NormalizeAsString, ApplyFuncAndNormalizeAsString diff --git a/data_diff/sqeleton/utils.py b/data_diff/sqeleton/utils.py deleted file mode 100644 index d1f596db..00000000 --- a/data_diff/sqeleton/utils.py +++ /dev/null @@ -1,311 +0,0 @@ -from typing import ( - Iterable, - Iterator, - MutableMapping, - Type, - Union, - Any, - Sequence, - Dict, - TypeVar, - List, -) -from abc import abstractmethod -import math -import string -import re -from uuid import UUID -from urllib.parse import urlparse - -from typing_extensions import Self - -# -- Common -- - - -def join_iter(joiner: Any, iterable: Iterable) -> Iterable: - it = iter(iterable) - try: - yield next(it) - except StopIteration: - return - for i in it: - yield joiner - yield i - - -def safezip(*args): - "zip but makes sure all sequences are the same length" - lens = list(map(len, args)) - if len(set(lens)) != 1: - raise ValueError(f"Mismatching lengths in arguments to safezip: {lens}") - return zip(*args) - - -def is_uuid(u): - try: - UUID(u) - except ValueError: - return False - return True - - -def match_regexps(regexps: Dict[str, Any], s: str) -> Sequence[tuple]: - for regexp, v in regexps.items(): - m = re.match(regexp + "$", s) - if m: - yield m, v - - -# -- Schema -- - -V = TypeVar("V") - - -class CaseAwareMapping(MutableMapping[str, V]): - @abstractmethod - def get_key(self, key: 
str) -> str: - ... - - def new(self, initial=()) -> Self: - return type(self)(initial) - - -class CaseInsensitiveDict(CaseAwareMapping): - def __init__(self, initial): - self._dict = {k.lower(): (k, v) for k, v in dict(initial).items()} - - def __getitem__(self, key: str) -> V: - return self._dict[key.lower()][1] - - def __iter__(self) -> Iterator[V]: - return iter(self._dict) - - def __len__(self) -> int: - return len(self._dict) - - def __setitem__(self, key: str, value): - k = key.lower() - if k in self._dict: - key = self._dict[k][0] - self._dict[k] = key, value - - def __delitem__(self, key: str): - del self._dict[key.lower()] - - def get_key(self, key: str) -> str: - return self._dict[key.lower()][0] - - def __repr__(self) -> str: - return repr(dict(self.items())) - - -class CaseSensitiveDict(dict, CaseAwareMapping): - def get_key(self, key): - self[key] # Throw KeyError if key doesn't exist - return key - - def as_insensitive(self): - return CaseInsensitiveDict(self) - - -# -- Alphanumerics -- - -alphanums = " -" + string.digits + string.ascii_uppercase + "_" + string.ascii_lowercase - - -class ArithString: - @classmethod - def new(cls, *args, **kw) -> Self: - return cls(*args, **kw) - - def range(self, other: "ArithString", count: int) -> List[Self]: - assert isinstance(other, ArithString) - checkpoints = split_space(self.int, other.int, count) - return [self.new(int=i) for i in checkpoints] - - -class ArithUUID(UUID, ArithString): - "A UUID that supports basic arithmetic (add, sub)" - - def __int__(self): - return self.int - - def __add__(self, other: int) -> Self: - if isinstance(other, int): - return self.new(int=self.int + other) - return NotImplemented - - def __sub__(self, other: Union[UUID, int]): - if isinstance(other, int): - return self.new(int=self.int - other) - elif isinstance(other, UUID): - return self.int - other.int - return NotImplemented - - -def numberToAlphanum(num: int, base: str = alphanums) -> str: - digits = [] - while num > 0: - num, remainder = divmod(num, len(base)) - digits.append(remainder) - return "".join(base[i] for i in digits[::-1]) - - -def alphanumToNumber(alphanum: str, base: str = alphanums) -> int: - num = 0 - for c in alphanum: - num = num * len(base) + base.index(c) - return num - - -def justify_alphanums(s1: str, s2: str): - max_len = max(len(s1), len(s2)) - s1 = s1.ljust(max_len) - s2 = s2.ljust(max_len) - return s1, s2 - - -def alphanums_to_numbers(s1: str, s2: str): - s1, s2 = justify_alphanums(s1, s2) - n1 = alphanumToNumber(s1) - n2 = alphanumToNumber(s2) - return n1, n2 - - -class ArithAlphanumeric(ArithString): - def __init__(self, s: str, max_len=None): - if s is None: - raise ValueError("Alphanum string cannot be None") - if max_len and len(s) > max_len: - raise ValueError(f"Length of alphanum value '{str}' is longer than the expected {max_len}") - - for ch in s: - if ch not in alphanums: - raise ValueError(f"Unexpected character {ch} in alphanum string") - - self._str = s - self._max_len = max_len - - # @property - # def int(self): - # return alphanumToNumber(self._str, alphanums) - - def __str__(self): - s = self._str - if self._max_len: - s = s.rjust(self._max_len, alphanums[0]) - return s - - def __len__(self): - return len(self._str) - - def __repr__(self): - return f'alphanum"{self._str}"' - - def __add__(self, other: "Union[ArithAlphanumeric, int]") -> Self: - if isinstance(other, int): - if other != 1: - raise NotImplementedError("not implemented for arbitrary numbers") - num = alphanumToNumber(self._str) - return 
self.new(numberToAlphanum(num + 1)) - - return NotImplemented - - def range(self, other: "ArithAlphanumeric", count: int) -> List[Self]: - assert isinstance(other, ArithAlphanumeric) - n1, n2 = alphanums_to_numbers(self._str, other._str) - split = split_space(n1, n2, count) - return [self.new(numberToAlphanum(s)) for s in split] - - def __sub__(self, other: "Union[ArithAlphanumeric, int]") -> float: - if isinstance(other, ArithAlphanumeric): - n1, n2 = alphanums_to_numbers(self._str, other._str) - return n1 - n2 - - return NotImplemented - - def __ge__(self, other): - if not isinstance(other, type(self)): - return NotImplemented - return self._str >= other._str - - def __lt__(self, other): - if not isinstance(other, type(self)): - return NotImplemented - return self._str < other._str - - def __eq__(self, other): - if not isinstance(other, type(self)): - return NotImplemented - return self._str == other._str - - def new(self, *args, **kw) -> Self: - return type(self)(*args, **kw, max_len=self._max_len) - - -def number_to_human(n): - millnames = ["", "k", "m", "b"] - n = float(n) - millidx = max( - 0, - min(len(millnames) - 1, int(math.floor(0 if n == 0 else math.log10(abs(n)) / 3))), - ) - - return "{:.0f}{}".format(n / 10 ** (3 * millidx), millnames[millidx]) - - -def split_space(start, end, count) -> List[int]: - size = end - start - assert count <= size, (count, size) - return list(range(start, end, (size + 1) // (count + 1)))[1 : count + 1] - - -def remove_passwords_in_dict(d: dict, replace_with: str = "***"): - for k, v in d.items(): - if k == "password": - d[k] = replace_with - elif isinstance(v, dict): - remove_passwords_in_dict(v, replace_with) - elif k.startswith("database"): - d[k] = remove_password_from_url(v, replace_with) - - -def _join_if_any(sym, args): - args = list(args) - if not args: - return "" - return sym.join(str(a) for a in args if a) - - -def remove_password_from_url(url: str, replace_with: str = "***") -> str: - parsed = urlparse(url) - account = parsed.username or "" - if parsed.password: - account += ":" + replace_with - host = _join_if_any(":", filter(None, [parsed.hostname, parsed.port])) - netloc = _join_if_any("@", filter(None, [account, host])) - replaced = parsed._replace(netloc=netloc) - return replaced.geturl() - - -def match_like(pattern: str, strs: Sequence[str]) -> Iterable[str]: - reo = re.compile(pattern.replace("%", ".*").replace("?", ".") + "$") - for s in strs: - if reo.match(s): - yield s - - -class UnknownMeta(type): - def __instancecheck__(self, instance): - return instance is Unknown - - def __repr__(self): - return "Unknown" - - -class Unknown(metaclass=UnknownMeta): - def __nonzero__(self): - raise TypeError() - - def __new__(class_, *args, **kwargs): - raise RuntimeError("Unknown is a singleton") diff --git a/data_diff/table_segment.py b/data_diff/table_segment.py index 46672304..9864824a 100644 --- a/data_diff/table_segment.py +++ b/data_diff/table_segment.py @@ -7,11 +7,13 @@ from typing_extensions import Self from data_diff.utils import safezip, Vector -from data_diff.sqeleton.utils import ArithString, split_space -from data_diff.sqeleton.databases import Database, DbPath, DbKey, DbTime -from data_diff.sqeleton.schema import Schema, create_schema -from data_diff.sqeleton.queries import Count, Checksum, SKIP, table, this, Expr, min_, max_, Code -from data_diff.sqeleton.queries.extras import ApplyFuncAndNormalizeAsString, NormalizeAsString +from data_diff.utils import ArithString, split_space +from data_diff.databases.base import Database 
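# A quick usage sketch for two of the helpers deleted above; the rest of this
# diff consolidates them into data_diff/utils.py, which is assumed here to be
# where they are imported from. remove_password_from_url() scrubs credentials
# before logging a connection string; match_like() does SQL LIKE-style matching.
from data_diff.utils import match_like, remove_password_from_url

assert (
    remove_password_from_url("postgresql://user:secret@localhost:5432/db")
    == "postgresql://user:***@localhost:5432/db"
)
assert list(match_like("user%", ["users", "accounts", "user_roles"])) == ["users", "user_roles"]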
diff --git a/data_diff/utils.py b/data_diff/utils.py
index 02870f60..b725285e 100644
--- a/data_diff/utils.py
+++ b/data_diff/utils.py
@@ -1,18 +1,39 @@
 import json
 import logging
 import re
-from typing import Dict, Iterable, Sequence
+import math
+import string
+from abc import abstractmethod
+from typing import Any, Dict, Iterable, Iterator, List, MutableMapping, Sequence, TypeVar, Union
 from urllib.parse import urlparse
 import operator
 import threading
 from datetime import datetime
+from uuid import UUID
+
 from packaging.version import parse as parse_version
 import requests
 from tabulate import tabulate
+from typing_extensions import Self
+
 from data_diff.version import __version__
 from rich.status import Status
 
+# -- Common --
+
+
+def join_iter(joiner: Any, iterable: Iterable) -> Iterable:
+    it = iter(iterable)
+    try:
+        yield next(it)
+    except StopIteration:
+        return
+    for i in it:
+        yield joiner
+        yield i
+
+
 def safezip(*args):
     "zip but makes sure all sequences are the same length"
     lens = list(map(len, args))
@@ -21,6 +41,236 @@ def safezip(*args):
     return zip(*args)
+
+
+def is_uuid(u):
+    try:
+        UUID(u)
+    except ValueError:
+        return False
+    return True
+
+
+def match_regexps(regexps: Dict[str, Any], s: str) -> Sequence[tuple]:
+    for regexp, v in regexps.items():
+        m = re.match(regexp + "$", s)
+        if m:
+            yield m, v
+
+
+# -- Schema --
+
+V = TypeVar("V")
+
+
+class CaseAwareMapping(MutableMapping[str, V]):
+    @abstractmethod
+    def get_key(self, key: str) -> str:
+        ...
+
+    def new(self, initial=()) -> Self:
+        return type(self)(initial)
+
+
+class CaseInsensitiveDict(CaseAwareMapping):
+    def __init__(self, initial):
+        self._dict = {k.lower(): (k, v) for k, v in dict(initial).items()}
+
+    def __getitem__(self, key: str) -> V:
+        return self._dict[key.lower()][1]
+
+    def __iter__(self) -> Iterator[V]:
+        return iter(self._dict)
+
+    def __len__(self) -> int:
+        return len(self._dict)
+
+    def __setitem__(self, key: str, value):
+        k = key.lower()
+        if k in self._dict:
+            key = self._dict[k][0]
+        self._dict[k] = key, value
+
+    def __delitem__(self, key: str):
+        del self._dict[key.lower()]
+
+    def get_key(self, key: str) -> str:
+        return self._dict[key.lower()][0]
+
+    def __repr__(self) -> str:
+        return repr(dict(self.items()))
+
+
+class CaseSensitiveDict(dict, CaseAwareMapping):
+    def get_key(self, key):
+        self[key]  # Throw KeyError if key doesn't exist
+        return key
+
+    def as_insensitive(self):
+        return CaseInsensitiveDict(self)
+
+
+# -- Alphanumerics --
+
+alphanums = " -" + string.digits + string.ascii_uppercase + "_" + string.ascii_lowercase
+
+
+class ArithString:
+    @classmethod
+    def new(cls, *args, **kw) -> Self:
+        return cls(*args, **kw)
+
+    def range(self, other: "ArithString", count: int) -> List[Self]:
+        assert isinstance(other, ArithString)
+        checkpoints = split_space(self.int, other.int, count)
+        return [self.new(int=i) for i in checkpoints]
+
+
+class ArithUUID(UUID, ArithString):
+    "A UUID that supports basic arithmetic (add, sub)"
+
+    def __int__(self):
+        return self.int
+
+    def __add__(self, other: int) -> Self:
+        if isinstance(other, int):
+            return self.new(int=self.int + other)
+        return NotImplemented
+
+    def __sub__(self, other: Union[UUID, int]):
+        if isinstance(other, int):
+            return self.new(int=self.int - other)
+        elif isinstance(other, UUID):
+            return self.int - other.int
+        return NotImplemented
+
+
+def numberToAlphanum(num: int, base: str = alphanums) -> str:
+    digits = []
+    while num > 0:
+        num, remainder = divmod(num, len(base))
+        digits.append(remainder)
+    return "".join(base[i] for i in digits[::-1])
+
+
+def alphanumToNumber(alphanum: str, base: str = alphanums) -> int:
+    num = 0
+    for c in alphanum:
+        num = num * len(base) + base.index(c)
+    return num
+
+
+def justify_alphanums(s1: str, s2: str):
+    max_len = max(len(s1), len(s2))
+    s1 = s1.ljust(max_len)
+    s2 = s2.ljust(max_len)
+    return s1, s2
+
+
+def alphanums_to_numbers(s1: str, s2: str):
+    s1, s2 = justify_alphanums(s1, s2)
+    n1 = alphanumToNumber(s1)
+    n2 = alphanumToNumber(s2)
+    return n1, n2
+
+
+class ArithAlphanumeric(ArithString):
+    def __init__(self, s: str, max_len=None):
+        if s is None:
+            raise ValueError("Alphanum string cannot be None")
+        if max_len and len(s) > max_len:
+            raise ValueError(f"Length of alphanum value '{s}' is longer than the expected {max_len}")
+
+        for ch in s:
+            if ch not in alphanums:
+                raise ValueError(f"Unexpected character {ch} in alphanum string")
+
+        self._str = s
+        self._max_len = max_len
+
+    # @property
+    # def int(self):
+    #     return alphanumToNumber(self._str, alphanums)
+
+    def __str__(self):
+        s = self._str
+        if self._max_len:
+            s = s.rjust(self._max_len, alphanums[0])
+        return s
+
+    def __len__(self):
+        return len(self._str)
+
+    def __repr__(self):
+        return f'alphanum"{self._str}"'
+
+    def __add__(self, other: "Union[ArithAlphanumeric, int]") -> Self:
+        if isinstance(other, int):
+            if other != 1:
+                raise NotImplementedError("not implemented for arbitrary numbers")
+            num = alphanumToNumber(self._str)
+            return self.new(numberToAlphanum(num + 1))
+
+        return NotImplemented
+
+    def range(self, other: "ArithAlphanumeric", count: int) -> List[Self]:
+        assert isinstance(other, ArithAlphanumeric)
+        n1, n2 = alphanums_to_numbers(self._str, other._str)
+        split = split_space(n1, n2, count)
+        return [self.new(numberToAlphanum(s)) for s in split]
+
+    def __sub__(self, other: "Union[ArithAlphanumeric, int]") -> float:
+        if isinstance(other, ArithAlphanumeric):
+            n1, n2 = alphanums_to_numbers(self._str, other._str)
+            return n1 - n2
+
+        return NotImplemented
+
+    def __ge__(self, other):
+        if not isinstance(other, type(self)):
+            return NotImplemented
+        return self._str >= other._str
+
+    def __lt__(self, other):
+        if not isinstance(other, type(self)):
+            return NotImplemented
+        return self._str < other._str
+
+    def __eq__(self, other):
+        if not isinstance(other, type(self)):
+            return NotImplemented
+        return self._str == other._str
+
+    def new(self, *args, **kw) -> Self:
+        return type(self)(*args, **kw, max_len=self._max_len)
+
+
+def number_to_human(n):
+    millnames = ["", "k", "m", "b"]
+    n = float(n)
+    millidx = max(
+        0,
+        min(len(millnames) - 1, int(math.floor(0 if n == 0 else math.log10(abs(n)) / 3))),
+    )
+
+    return "{:.0f}{}".format(n / 10 ** (3 * millidx), millnames[millidx])
+
+
+def split_space(start, end, count) -> List[int]:
+    size = end - start
+    assert count <= size, (count, size)
+    return list(range(start, end, (size + 1) // (count + 1)))[1 : count + 1]
+
+
+def remove_passwords_in_dict(d: dict, replace_with: str = "***"):
+    for k, v in d.items():
+        if k == "password":
+            d[k] = replace_with
+        elif isinstance(v, dict):
+            remove_passwords_in_dict(v, replace_with)
+        elif k.startswith("database"):
+            d[k] = remove_password_from_url(v, replace_with)
+
+
 def _join_if_any(sym, args):
     args = list(args)
     if not args:
         return ""
     return sym.join(str(a) for a in args if a)
@@ -248,3 +498,19 @@ def _update_cloud_status(self, log=None):
         for model_name, status in self.cloud_diff_status.items():
             cloud_status_string += f"{status} {model_name}\n"
         self.status.update(f"{cloud_status_string}{log or ''}")
+
+
+class UnknownMeta(type):
+    def __instancecheck__(self, instance):
+        return instance is Unknown
+
+    def __repr__(self):
+        return "Unknown"
+
+
+class Unknown(metaclass=UnknownMeta):
+    def __nonzero__(self):
+        raise TypeError()
+
+    def __new__(class_, *args, **kwargs):
+        raise RuntimeError("Unknown is a singleton")
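The relocated helpers keep their previous behavior. A worked sketch, with the expected values derived from the definitions above (illustrative only, not part of the patch):

from data_diff.utils import ArithAlphanumeric, alphanumToNumber, numberToAlphanum, split_space

# split_space picks `count` evenly spaced checkpoints strictly inside (start, end).
assert split_space(0, 100, 4) == [20, 40, 60, 80]

# Alphanumeric strings map to integers over the 65-character `alphanums` base and back.
n = alphanumToNumber("data")
assert numberToAlphanum(n) == "data"

# ArithAlphanumeric supports increment by one and ordered comparison.
a = ArithAlphanumeric("data")
assert str(a + 1) == numberToAlphanum(n + 1)
assert ArithAlphanumeric("abc") < ArithAlphanumeric("abd")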
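The case-aware mappings added above back schema lookups; a small usage sketch, with outcomes following directly from the methods shown (illustrative only):

from data_diff.utils import CaseInsensitiveDict, CaseSensitiveDict

d = CaseInsensitiveDict({"Id": 1})
assert d["id"] == 1 and d["ID"] == 1   # lookups ignore case
assert d.get_key("id") == "Id"         # the original casing is preserved
d["ID"] = 2                            # overwriting keeps the first-seen casing
assert d["Id"] == 2 and d.get_key("ID") == "Id"

s = CaseSensitiveDict({"Id": 1})
assert s.get_key("Id") == "Id"
assert s.as_insensitive()["id"] == 1   # explicit downgrade to case-insensitive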
diff --git a/tests/common.py b/tests/common.py
index e434eaa7..222ae94b 100644
--- a/tests/common.py
+++ b/tests/common.py
@@ -9,8 +9,8 @@
 
 from parameterized import parameterized_class
 
-from data_diff.sqeleton.queries import table
-from data_diff.sqeleton.databases import Database
+from data_diff.queries.api import table
+from data_diff.databases.base import Database
 
 from data_diff import databases as db
 from data_diff import tracking
@@ -81,7 +81,7 @@ def get_git_revision_short_hash() -> str:
     db.Clickhouse: TEST_CLICKHOUSE_CONN_STRING,
     db.Vertica: TEST_VERTICA_CONN_STRING,
     db.DuckDB: TEST_DUCKDB_CONN_STRING,
-    db.MsSql: TEST_MSSQL_CONN_STRING,
+    db.MsSQL: TEST_MSSQL_CONN_STRING,
 }
 
 _database_instances = {}
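The redaction helpers moved with the same semantics; a short sketch traced through urlparse (the URL and dict below are made up):

from data_diff.utils import remove_password_from_url, remove_passwords_in_dict

url = "postgresql://user:hunter2@localhost:5432/mydb"
assert remove_password_from_url(url) == "postgresql://user:***@localhost:5432/mydb"

conf = {"password": "hunter2", "database1": url, "nested": {"password": "x"}}
remove_passwords_in_dict(conf)  # mutates in place
assert conf["password"] == "***"
assert conf["nested"]["password"] == "***"
assert "hunter2" not in conf["database1"]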
diff --git a/tests/test_api.py b/tests/test_api.py
index d97f59a9..07a4d57d 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -1,8 +1,8 @@
 from datetime import datetime, timedelta
 
 from data_diff import diff_tables, connect_to_table, Algorithm
-from data_diff.databases import MySQL
-from data_diff.sqeleton.queries import table, commit
+from data_diff.databases.mysql import MySQL
+from data_diff.queries.api import table, commit
 
 from tests.common import TEST_MYSQL_CONN_STRING, get_conn, random_table_suffix, DiffTestCase
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 2e8111da..1fc4833e 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -3,7 +3,7 @@
 import sys
 from datetime import datetime, timedelta
 
-from data_diff.sqeleton.queries import commit, current_timestamp
+from data_diff.queries.api import commit, current_timestamp
 
 from tests.common import DiffTestCase, CONN_STRINGS
 from tests.test_diff_tables import test_each_database
diff --git a/tests/test_database.py b/tests/test_database.py
index 1b967cc8..4f4c8ce1 100644
--- a/tests/test_database.py
+++ b/tests/test_database.py
@@ -4,11 +4,12 @@
 
 import pytz
 
-from data_diff.sqeleton import connect
-from data_diff.sqeleton import databases as dbs
-from data_diff.sqeleton.queries import table, current_timestamp, NormalizeAsString
+from data_diff import connect
+from data_diff import databases as dbs
+from data_diff.queries.api import table, current_timestamp
+from data_diff.queries.extras import NormalizeAsString
 from tests.common import TEST_MYSQL_CONN_STRING, test_each_database_in_list, get_conn, str_to_checksum, random_table_suffix
-from data_diff.sqeleton.abcs.database_types import TimestampTZ
+from data_diff.abcs.database_types import TimestampTZ
 
 TEST_DATABASES = {
     dbs.MySQL,
diff --git a/tests/test_database_types.py b/tests/test_database_types.py
index 75b0acee..3d345296 100644
--- a/tests/test_database_types.py
+++ b/tests/test_database_types.py
@@ -1,9 +1,7 @@
-from contextlib import suppress
 import unittest
 import time
 import json
 import re
-import rich.progress
 import math
 import uuid
 from datetime import datetime, timedelta, timezone
@@ -13,9 +11,9 @@
 
 from parameterized import parameterized
 
-from data_diff.sqeleton.utils import number_to_human
-from data_diff.sqeleton.queries import table, commit, this, Code
-from data_diff.sqeleton.queries.api import insert_rows_in_batches
+from data_diff.utils import number_to_human
+from data_diff.queries.api import table, commit, this, Code
+from data_diff.queries.api import insert_rows_in_batches
 
 from data_diff import databases as db
 from data_diff.query_utils import drop_table
@@ -351,7 +349,7 @@ def init_conns():
             "boolean",
         ],
     },
-    db.MsSql: {
+    db.MsSQL: {
         "int": ["INT", "BIGINT"],
         "datetime": ["datetime2(6)"],
         "float": ["DECIMAL(6, 2)", "FLOAT", "REAL"],
@@ -625,7 +623,7 @@ def _insert_to_table(conn, table_path, values, coltype):
             for i, sample in values
         ]
     # mssql represents with int
-    elif isinstance(conn, db.MsSql) and coltype in ("BIT"):
+    elif isinstance(conn, db.MsSQL) and coltype in ("BIT",):
        values = [(i, int(sample)) for i, sample in values]
 
     insert_rows_in_batches(conn, tbl, values, columns=["id", "col"])
diff --git a/tests/test_diff_tables.py b/tests/test_diff_tables.py
index eb41a3ee..b5885a26 100644
--- a/tests/test_diff_tables.py
+++ b/tests/test_diff_tables.py
@@ -3,8 +3,8 @@
 import uuid
 import unittest
 
-from data_diff.sqeleton.queries import table, this, commit, code
-from data_diff.sqeleton.utils import ArithAlphanumeric, numberToAlphanum
+from data_diff.queries.api import table, this, commit, code
+from data_diff.utils import ArithAlphanumeric, numberToAlphanum
 
 from data_diff.hashdiff_tables import HashDiffer
 from data_diff.joindiff_tables import JoinDiffer
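number_to_human and match_like are now imported from data_diff.utils by the tests below; their behavior follows from the utils definitions above (a worked sketch, illustrative only):

from data_diff.utils import match_like, number_to_human

# '%' behaves like SQL LIKE's wildcard; '?' matches exactly one character.
assert list(match_like("users_%", ["users_1", "users_24", "accounts"])) == ["users_1", "users_24"]
assert list(match_like("user?", ["users", "user1", "user"])) == ["users", "user1"]

# Magnitudes are collapsed to k/m/b suffixes with zero decimal places.
assert number_to_human(1234) == "1k"
assert number_to_human(56_789_000) == "57m"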
diff --git a/tests/test_format.py b/tests/test_format.py
index 0aa8ee8e..4743acc4 100644
--- a/tests/test_format.py
+++ b/tests/test_format.py
@@ -1,8 +1,8 @@
 import unittest
 
 from data_diff.diff_tables import DiffResultWrapper, InfoTree, SegmentInfo, TableSegment
 from data_diff.format import jsonify
-from data_diff.sqeleton.abcs.database_types import Integer
-from data_diff.sqeleton.databases import Database
+from data_diff.abcs.database_types import Integer
+from data_diff.databases.base import Database
 
 
 class TestFormat(unittest.TestCase):
diff --git a/tests/test_joindiff.py b/tests/test_joindiff.py
index 6a1559d7..b2c5c419 100644
--- a/tests/test_joindiff.py
+++ b/tests/test_joindiff.py
@@ -1,8 +1,8 @@
 from typing import List
 from datetime import datetime
 
-from data_diff.sqeleton.queries.ast_classes import TablePath
-from data_diff.sqeleton.queries import table, commit
+from data_diff.queries.ast_classes import TablePath
+from data_diff.queries.api import table, commit
 from data_diff.table_segment import TableSegment
 from data_diff import databases as db
 from data_diff.joindiff_tables import JoinDiffer
diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py
index 418f44fb..b5e9fa10 100644
--- a/tests/test_postgresql.py
+++ b/tests/test_postgresql.py
@@ -1,7 +1,6 @@
 import unittest
 
-from data_diff.sqeleton.queries import table, commit
-
+from data_diff.queries.api import table, commit
 from data_diff import TableSegment, HashDiffer
 from data_diff import databases as db
 from tests.common import get_conn, random_table_suffix
diff --git a/tests/test_query.py b/tests/test_query.py
index cfa6ada8..cc11b533 100644
--- a/tests/test_query.py
+++ b/tests/test_query.py
@@ -1,12 +1,13 @@
 from datetime import datetime
 from typing import List, Optional
 import unittest
-from data_diff.sqeleton.abcs import AbstractDatabase, AbstractDialect
-from data_diff.sqeleton.utils import CaseInsensitiveDict, CaseSensitiveDict
+from data_diff.abcs.database_types import AbstractDatabase, AbstractDialect
+from data_diff.utils import CaseInsensitiveDict, CaseSensitiveDict
 
-from data_diff.sqeleton.queries import this, table, Compiler, outerjoin, cte, when, coalesce, CompileError
-from data_diff.sqeleton.queries.ast_classes import Random
-from data_diff.sqeleton import code, this, table
+from data_diff.queries.compiler import Compiler, CompileError
+from data_diff.queries.api import outerjoin, cte, when, coalesce
+from data_diff.queries.ast_classes import Random
+from data_diff.queries.api import code, this, table
 
 
 def normalize_spaces(s: str):
diff --git a/tests/test_sql.py b/tests/test_sql.py
index d8e07046..2dcab403 100644
--- a/tests/test_sql.py
+++ b/tests/test_sql.py
@@ -2,8 +2,8 @@
 
 from tests.common import TEST_MYSQL_CONN_STRING
 
-from data_diff.sqeleton import connect
-from data_diff.sqeleton.queries import Compiler, Count, Explain, Select, table, In, BinOp, Code
+from data_diff.databases import connect
+from data_diff.queries.api import Compiler, Count, Explain, Select, table, In, BinOp, Code
 
 
 class TestSQL(unittest.TestCase):
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 973121a2..1277d5be 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,6 +1,6 @@
 import unittest
 
-from data_diff.sqeleton.utils import remove_passwords_in_dict, match_regexps, match_like, number_to_human
+from data_diff.utils import remove_passwords_in_dict, match_regexps, match_like, number_to_human
 
 
 class TestUtils(unittest.TestCase):
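Finally, the Unknown sentinel that now closes data_diff/utils.py keeps its singleton contract; a minimal sketch of the semantics:

from data_diff.utils import Unknown

assert isinstance(Unknown, Unknown)       # only the sentinel itself passes the check
assert not isinstance(object(), Unknown)
assert repr(Unknown) == "Unknown"

try:
    Unknown()
except RuntimeError:
    pass  # cannot be instantiated; use the class object itself as the sentinel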