From 9131f472a64550bf4c300426889695d4c53cf88a Mon Sep 17 00:00:00 2001 From: Ilia Pinchuk <pinchuk.ilia@gmail.com> Date: Sat, 28 Oct 2023 18:40:37 +0600 Subject: [PATCH 1/8] feat: prevent type overflow when long string concatenating --- data_diff/databases/base.py | 27 ++++++++++++++++++++++++--- data_diff/databases/bigquery.py | 3 +++ data_diff/databases/clickhouse.py | 3 +++ data_diff/databases/databricks.py | 3 +++ data_diff/databases/duckdb.py | 3 +++ data_diff/databases/mssql.py | 3 +++ data_diff/databases/mysql.py | 3 +++ data_diff/databases/oracle.py | 3 +++ data_diff/databases/postgresql.py | 3 +++ data_diff/databases/presto.py | 3 +++ data_diff/databases/redshift.py | 3 +++ data_diff/databases/snowflake.py | 3 +++ data_diff/databases/vertica.py | 4 ++++ data_diff/diff_tables.py | 4 ++++ tests/test_query.py | 3 +++ 15 files changed, 68 insertions(+), 3 deletions(-) diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index 9dc03909..1c733d41 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -203,6 +203,15 @@ class BaseDialect(abc.ABC): PLACEHOLDER_TABLE = None # Used for Oracle + # Some database do not support long string so concatenation might lead to type overflow + PREVENT_OVERFLOW_WHEN_CONCAT: bool = False + + _prevent_overflow_when_concat: bool = False + + def enable_preventing_type_overflow(self) -> None: + logger.info("Preventing type overflow when concatenation is enabled") + self._prevent_overflow_when_concat = True + def parse_table_name(self, name: str) -> DbPath: "Parse the given table name into a DbPath" return parse_table_name(name) @@ -392,10 +401,18 @@ def render_checksum(self, c: Compiler, elem: Checksum) -> str: return f"sum({md5})" def render_concat(self, c: Compiler, elem: Concat) -> str: + if self._prevent_overflow_when_concat: + items = [ + f"{self.compile(c, Code(self.to_md5(self.to_string(self.compile(c, expr)))))}" for expr in elem.exprs + ] + # We coalesce because on some DBs (e.g. MySQL) concat('a', NULL) is NULL - items = [ - f"coalesce({self.compile(c, Code(self.to_string(self.compile(c, expr))))}, '<null>')" for expr in elem.exprs - ] + else: + items = [ + f"coalesce({self.compile(c, Code(self.to_string(self.compile(c, expr))))}, '<null>')" + for expr in elem.exprs + ] + assert items if len(items) == 1: return items[0] @@ -769,6 +786,10 @@ def set_timezone_to_utc(self) -> str: def md5_as_int(self, s: str) -> str: "Provide SQL for computing md5 and returning an int" + @abstractmethod + def to_md5(self, s: str) -> str: + """Method to calculate MD5""" + @abstractmethod def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: """Creates an SQL expression, that converts 'value' to a normalized timestamp. diff --git a/data_diff/databases/bigquery.py b/data_diff/databases/bigquery.py index e672b928..976c7ad4 100644 --- a/data_diff/databases/bigquery.py +++ b/data_diff/databases/bigquery.py @@ -134,6 +134,9 @@ def parse_table_name(self, name: str) -> DbPath: def md5_as_int(self, s: str) -> str: return f"cast(cast( ('0x' || substr(TO_HEX(md5({s})), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS})) as int64) as numeric) - {CHECKSUM_OFFSET}" + def to_md5(self, s: str) -> str: + return f"md5({s})" + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: if coltype.rounds: timestamp = f"timestamp_micros(cast(round(unix_micros(cast({value} as timestamp))/1000000, {coltype.precision})*1000000 as int))" diff --git a/data_diff/databases/clickhouse.py b/data_diff/databases/clickhouse.py index 7a8816d8..5e5c9453 100644 --- a/data_diff/databases/clickhouse.py +++ b/data_diff/databases/clickhouse.py @@ -105,6 +105,9 @@ def md5_as_int(self, s: str) -> str: f"reinterpretAsUInt128(reverse(unhex(lowerUTF8(substr(hex(MD5({s})), {substr_idx}))))) - {CHECKSUM_OFFSET}" ) + def to_md5(self, s: str) -> str: + return f"hex(MD5({s}))" + def normalize_number(self, value: str, coltype: FractionalType) -> str: # If a decimal value has trailing zeros in a fractional part, when casting to string they are dropped. # For example: diff --git a/data_diff/databases/databricks.py b/data_diff/databases/databricks.py index 7394f2df..6e815c76 100644 --- a/data_diff/databases/databricks.py +++ b/data_diff/databases/databricks.py @@ -82,6 +82,9 @@ def parse_table_name(self, name: str) -> DbPath: def md5_as_int(self, s: str) -> str: return f"cast(conv(substr(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16, 10) as decimal(38, 0)) - {CHECKSUM_OFFSET}" + def to_md5(self, s: str) -> str: + return f"md5({s})" + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: """Databricks timestamp contains no more than 6 digits in precision""" diff --git a/data_diff/databases/duckdb.py b/data_diff/databases/duckdb.py index a105b71a..a73055f1 100644 --- a/data_diff/databases/duckdb.py +++ b/data_diff/databases/duckdb.py @@ -100,6 +100,9 @@ def current_timestamp(self) -> str: def md5_as_int(self, s: str) -> str: return f"('0x' || SUBSTRING(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS},{CHECKSUM_HEXDIGITS}))::BIGINT - {CHECKSUM_OFFSET}" + def to_md5(self, s: str) -> str: + return f"md5({s})" + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: # It's precision 6 by default. If precision is less than 6 -> we remove the trailing numbers. if coltype.rounds and coltype.precision > 0: diff --git a/data_diff/databases/mssql.py b/data_diff/databases/mssql.py index fd23bef1..f92edca7 100644 --- a/data_diff/databases/mssql.py +++ b/data_diff/databases/mssql.py @@ -151,6 +151,9 @@ def normalize_number(self, value: str, coltype: NumericType) -> str: def md5_as_int(self, s: str) -> str: return f"convert(bigint, convert(varbinary, '0x' + RIGHT(CONVERT(NVARCHAR(32), HashBytes('MD5', {s}), 2), {CHECKSUM_HEXDIGITS}), 1)) - {CHECKSUM_OFFSET}" + def to_md5(self, s: str) -> str: + return f"HashBytes('MD5', {s})" + @attrs.define(frozen=False, init=False, kw_only=True) class MsSQL(ThreadedDatabase): diff --git a/data_diff/databases/mysql.py b/data_diff/databases/mysql.py index f4993b87..d45c31cc 100644 --- a/data_diff/databases/mysql.py +++ b/data_diff/databases/mysql.py @@ -101,6 +101,9 @@ def set_timezone_to_utc(self) -> str: def md5_as_int(self, s: str) -> str: return f"conv(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16, 10) - {CHECKSUM_OFFSET}" + def to_md5(self, s: str) -> str: + return f"md5({s})" + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: if coltype.rounds: return self.to_string(f"cast( cast({value} as datetime({coltype.precision})) as datetime(6))") diff --git a/data_diff/databases/oracle.py b/data_diff/databases/oracle.py index 32bd30ef..b0233d93 100644 --- a/data_diff/databases/oracle.py +++ b/data_diff/databases/oracle.py @@ -137,6 +137,9 @@ def md5_as_int(self, s: str) -> str: # TODO: Find a way to use UTL_RAW.CAST_TO_BINARY_INTEGER ? return f"to_number(substr(standard_hash({s}, 'MD5'), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 'xxxxxxxxxxxxxxx') - {CHECKSUM_OFFSET}" + def to_md5(self, s: str) -> str: + return f"standard_hash({s}, 'MD5'" + def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: # Cast is necessary for correct MD5 (trimming not enough) return f"CAST(TRIM({value}) AS VARCHAR(36))" diff --git a/data_diff/databases/postgresql.py b/data_diff/databases/postgresql.py index 075d6aff..8c63b261 100644 --- a/data_diff/databases/postgresql.py +++ b/data_diff/databases/postgresql.py @@ -98,6 +98,9 @@ def type_repr(self, t) -> str: def md5_as_int(self, s: str) -> str: return f"('x' || substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}))::bit({_CHECKSUM_BITSIZE})::bigint - {CHECKSUM_OFFSET}" + def to_md5(self, s: str) -> str: + return f"md5({s})" + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: if coltype.rounds: return f"to_char({value}::timestamp({coltype.precision}), 'YYYY-mm-dd HH24:MI:SS.US')" diff --git a/data_diff/databases/presto.py b/data_diff/databases/presto.py index f575719a..cb6ae47f 100644 --- a/data_diff/databases/presto.py +++ b/data_diff/databases/presto.py @@ -128,6 +128,9 @@ def current_timestamp(self) -> str: def md5_as_int(self, s: str) -> str: return f"cast(from_base(substr(to_hex(md5(to_utf8({s}))), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16) as decimal(38, 0)) - {CHECKSUM_OFFSET}" + def to_md5(self, s: str) -> str: + return f"to_hex(md5(to_utf8({s})))" + def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: # Trim doesn't work on CHAR type return f"TRIM(CAST({value} AS VARCHAR))" diff --git a/data_diff/databases/redshift.py b/data_diff/databases/redshift.py index 968f57bb..6e13d495 100644 --- a/data_diff/databases/redshift.py +++ b/data_diff/databases/redshift.py @@ -47,6 +47,9 @@ def type_repr(self, t) -> str: def md5_as_int(self, s: str) -> str: return f"strtol(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16)::decimal(38) - {CHECKSUM_OFFSET}" + def to_md5(self, s: str) -> str: + return f"md5({s})" + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: if coltype.rounds: timestamp = f"{value}::timestamp(6)" diff --git a/data_diff/databases/snowflake.py b/data_diff/databases/snowflake.py index 857e7c89..8ad34a37 100644 --- a/data_diff/databases/snowflake.py +++ b/data_diff/databases/snowflake.py @@ -76,6 +76,9 @@ def type_repr(self, t) -> str: def md5_as_int(self, s: str) -> str: return f"BITAND(md5_number_lower64({s}), {CHECKSUM_MASK}) - {CHECKSUM_OFFSET}" + def to_md5(self, s: str) -> str: + return f"md5_number_lower64({s})" + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: if coltype.rounds: timestamp = f"to_timestamp(round(date_part(epoch_nanosecond, convert_timezone('UTC', {value})::timestamp(9))/1000000000, {coltype.precision}))" diff --git a/data_diff/databases/vertica.py b/data_diff/databases/vertica.py index 51dc00fa..e561045f 100644 --- a/data_diff/databases/vertica.py +++ b/data_diff/databases/vertica.py @@ -36,6 +36,7 @@ def import_vertica(): return vertica_python +@attrs.define(frozen=False) class Dialect(BaseDialect): name = "Vertica" ROUNDS_ON_PREC_LOSS = True @@ -109,6 +110,9 @@ def current_timestamp(self) -> str: def md5_as_int(self, s: str) -> str: return f"CAST(HEX_TO_INTEGER(SUBSTRING(MD5({s}), {1 + MD5_HEXDIGITS - CHECKSUM_HEXDIGITS})) AS NUMERIC(38, 0)) - {CHECKSUM_OFFSET}" + def to_md5(self, s: str) -> str: + return f"MD5({s})" + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: if coltype.rounds: return f"TO_CHAR({value}::TIMESTAMP({coltype.precision}), 'YYYY-MM-DD HH24:MI:SS.US')" diff --git a/data_diff/diff_tables.py b/data_diff/diff_tables.py index 44daba34..66802426 100644 --- a/data_diff/diff_tables.py +++ b/data_diff/diff_tables.py @@ -208,6 +208,10 @@ def _diff_tables_wrapper(self, table1: TableSegment, table2: TableSegment, info_ event_json = create_start_event_json(options) run_as_daemon(send_event_json, event_json) + if table1.database.dialect.PREVENT_OVERFLOW_WHEN_CONCAT or table2.database.dialect.PREVENT_OVERFLOW_WHEN_CONCAT: + table1.database.dialect.enable_preventing_type_overflow() + table2.database.dialect.enable_preventing_type_overflow() + start = time.monotonic() error = None try: diff --git a/tests/test_query.py b/tests/test_query.py index 0625a75d..9b139471 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -76,6 +76,9 @@ def optimizer_hints(self, s: str): def md5_as_int(self, s: str) -> str: raise NotImplementedError + def to_md5(self, s: str) -> str: + raise NotImplementedError + def normalize_number(self, value: str, coltype: FractionalType) -> str: raise NotImplementedError From 6812670fe89bc7d1e4841079796608956706dc83 Mon Sep 17 00:00:00 2001 From: Ilia Pinchuk <pinchuk.ilia@gmail.com> Date: Sat, 4 Nov 2023 16:18:32 +0600 Subject: [PATCH 2/8] feat: create an instance of the dialect instead of class object --- data_diff/databases/base.py | 8 +++++++- data_diff/databases/bigquery.py | 4 ++-- data_diff/databases/clickhouse.py | 4 ++-- data_diff/databases/databricks.py | 4 ++-- data_diff/databases/duckdb.py | 4 ++-- data_diff/databases/mssql.py | 4 ++-- data_diff/databases/mysql.py | 4 ++-- data_diff/databases/oracle.py | 4 ++-- data_diff/databases/postgresql.py | 2 +- data_diff/databases/presto.py | 4 ++-- data_diff/databases/redshift.py | 3 ++- data_diff/databases/snowflake.py | 4 ++-- data_diff/databases/trino.py | 6 +++--- data_diff/databases/vertica.py | 4 ++-- 14 files changed, 33 insertions(+), 26 deletions(-) diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index 1c733d41..4e2c0771 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -906,6 +906,8 @@ class Database(abc.ABC): Instanciated using :meth:`~data_diff.connect` """ + DIALECT_CLASS: ClassVar[Type[BaseDialect]] = BaseDialect + SUPPORTS_ALPHANUMS: ClassVar[bool] = True SUPPORTS_UNIQUE_CONSTAINT: ClassVar[bool] = False CONNECT_URI_KWPARAMS: ClassVar[List[str]] = [] @@ -913,6 +915,7 @@ class Database(abc.ABC): default_schema: Optional[str] = None _interactive: bool = False is_closed: bool = False + _dialect: BaseDialect = None @property def name(self): @@ -1141,10 +1144,13 @@ def close(self): return super().close() @property - @abstractmethod def dialect(self) -> BaseDialect: "The dialect of the database. Used internally by Database, and also available publicly." + if not self._dialect: + self._dialect = self.DIALECT_CLASS() + return self._dialect + @property @abstractmethod def CONNECT_URI_HELP(self) -> str: diff --git a/data_diff/databases/bigquery.py b/data_diff/databases/bigquery.py index 976c7ad4..c278844a 100644 --- a/data_diff/databases/bigquery.py +++ b/data_diff/databases/bigquery.py @@ -1,5 +1,5 @@ import re -from typing import Any, List, Union +from typing import Any, ClassVar, List, Union, Type import attrs @@ -182,9 +182,9 @@ def normalize_struct(self, value: str, _coltype: Struct) -> str: @attrs.define(frozen=False, init=False, kw_only=True) class BigQuery(Database): + DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect CONNECT_URI_HELP = "bigquery://<project>/<dataset>" CONNECT_URI_PARAMS = ["dataset"] - dialect = Dialect() project: str dataset: str diff --git a/data_diff/databases/clickhouse.py b/data_diff/databases/clickhouse.py index 5e5c9453..7a63881a 100644 --- a/data_diff/databases/clickhouse.py +++ b/data_diff/databases/clickhouse.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, Type +from typing import Any, ClassVar, Dict, Optional, Type import attrs @@ -167,7 +167,7 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: @attrs.define(frozen=False, init=False, kw_only=True) class Clickhouse(ThreadedDatabase): - dialect = Dialect() + DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect CONNECT_URI_HELP = "clickhouse://<user>:<password>@<host>/<database>" CONNECT_URI_PARAMS = ["database?"] diff --git a/data_diff/databases/databricks.py b/data_diff/databases/databricks.py index 6e815c76..7aadac44 100644 --- a/data_diff/databases/databricks.py +++ b/data_diff/databases/databricks.py @@ -1,5 +1,5 @@ import math -from typing import Any, Dict, Sequence +from typing import Any, ClassVar, Dict, Sequence, Type import logging import attrs @@ -107,7 +107,7 @@ def normalize_boolean(self, value: str, _coltype: Boolean) -> str: @attrs.define(frozen=False, init=False, kw_only=True) class Databricks(ThreadedDatabase): - dialect = Dialect() + DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect CONNECT_URI_HELP = "databricks://:<access_token>@<server_hostname>/<http_path>" CONNECT_URI_PARAMS = ["catalog", "schema"] diff --git a/data_diff/databases/duckdb.py b/data_diff/databases/duckdb.py index a73055f1..99f22b23 100644 --- a/data_diff/databases/duckdb.py +++ b/data_diff/databases/duckdb.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Union +from typing import Any, ClassVar, Dict, Union, Type import attrs @@ -119,7 +119,7 @@ def normalize_boolean(self, value: str, _coltype: Boolean) -> str: @attrs.define(frozen=False, init=False, kw_only=True) class DuckDB(Database): - dialect = Dialect() + DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect SUPPORTS_UNIQUE_CONSTAINT = False # Temporary, until we implement it CONNECT_URI_HELP = "duckdb://<dbname>@<filepath>" CONNECT_URI_PARAMS = ["database", "dbpath"] diff --git a/data_diff/databases/mssql.py b/data_diff/databases/mssql.py index f92edca7..94665961 100644 --- a/data_diff/databases/mssql.py +++ b/data_diff/databases/mssql.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional +from typing import Any, ClassVar, Dict, Optional, Type import attrs @@ -157,7 +157,7 @@ def to_md5(self, s: str) -> str: @attrs.define(frozen=False, init=False, kw_only=True) class MsSQL(ThreadedDatabase): - dialect = Dialect() + DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect CONNECT_URI_HELP = "mssql://<user>:<password>@<host>/<database>/<schema>" CONNECT_URI_PARAMS = ["database", "schema"] diff --git a/data_diff/databases/mysql.py b/data_diff/databases/mysql.py index d45c31cc..a83be2b0 100644 --- a/data_diff/databases/mysql.py +++ b/data_diff/databases/mysql.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, ClassVar, Dict, Type import attrs @@ -120,7 +120,7 @@ def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: @attrs.define(frozen=False, init=False, kw_only=True) class MySQL(ThreadedDatabase): - dialect = Dialect() + DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect SUPPORTS_ALPHANUMS = False SUPPORTS_UNIQUE_CONSTAINT = True CONNECT_URI_HELP = "mysql://<user>:<password>@<host>/<database>" diff --git a/data_diff/databases/oracle.py b/data_diff/databases/oracle.py index b0233d93..c609f6ba 100644 --- a/data_diff/databases/oracle.py +++ b/data_diff/databases/oracle.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional +from typing import Any, ClassVar, Dict, List, Optional, Type import attrs @@ -164,7 +164,7 @@ def normalize_number(self, value: str, coltype: FractionalType) -> str: @attrs.define(frozen=False, init=False, kw_only=True) class Oracle(ThreadedDatabase): - dialect = Dialect() + DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect CONNECT_URI_HELP = "oracle://<user>:<password>@<host>/<database>" CONNECT_URI_PARAMS = ["database?"] diff --git a/data_diff/databases/postgresql.py b/data_diff/databases/postgresql.py index 8c63b261..08872960 100644 --- a/data_diff/databases/postgresql.py +++ b/data_diff/databases/postgresql.py @@ -122,7 +122,7 @@ def normalize_json(self, value: str, _coltype: JSON) -> str: @attrs.define(frozen=False, init=False, kw_only=True) class PostgreSQL(ThreadedDatabase): - dialect = PostgresqlDialect() + DIALECT_CLASS: ClassVar[Type[BaseDialect]] = PostgresqlDialect SUPPORTS_UNIQUE_CONSTAINT = True CONNECT_URI_HELP = "postgresql://<user>:<password>@<host>/<database>" CONNECT_URI_PARAMS = ["database?"] diff --git a/data_diff/databases/presto.py b/data_diff/databases/presto.py index cb6ae47f..eda7a4d2 100644 --- a/data_diff/databases/presto.py +++ b/data_diff/databases/presto.py @@ -1,6 +1,6 @@ from functools import partial import re -from typing import Any +from typing import Any, ClassVar, Type import attrs @@ -153,7 +153,7 @@ def normalize_boolean(self, value: str, _coltype: Boolean) -> str: @attrs.define(frozen=False, init=False, kw_only=True) class Presto(Database): - dialect = Dialect() + DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect CONNECT_URI_HELP = "presto://<user>@<host>/<catalog>/<schema>" CONNECT_URI_PARAMS = ["catalog", "schema"] diff --git a/data_diff/databases/redshift.py b/data_diff/databases/redshift.py index 6e13d495..857cfc63 100644 --- a/data_diff/databases/redshift.py +++ b/data_diff/databases/redshift.py @@ -12,6 +12,7 @@ TimestampTZ, ) from data_diff.databases.postgresql import ( + BaseDialect, PostgreSQL, MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, @@ -79,7 +80,7 @@ def normalize_json(self, value: str, _coltype: JSON) -> str: @attrs.define(frozen=False, init=False, kw_only=True) class Redshift(PostgreSQL): - dialect = Dialect() + DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect CONNECT_URI_HELP = "redshift://<user>:<password>@<host>/<database>" CONNECT_URI_PARAMS = ["database?"] diff --git a/data_diff/databases/snowflake.py b/data_diff/databases/snowflake.py index 8ad34a37..2d6751d0 100644 --- a/data_diff/databases/snowflake.py +++ b/data_diff/databases/snowflake.py @@ -1,4 +1,4 @@ -from typing import Any, Union, List +from typing import Any, ClassVar, Union, List, Type import logging import attrs @@ -96,7 +96,7 @@ def normalize_boolean(self, value: str, _coltype: Boolean) -> str: @attrs.define(frozen=False, init=False, kw_only=True) class Snowflake(Database): - dialect = Dialect() + DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect CONNECT_URI_HELP = "snowflake://<user>:<password>@<account>/<database>/<SCHEMA>?warehouse=<WAREHOUSE>" CONNECT_URI_PARAMS = ["database", "schema"] CONNECT_URI_KWPARAMS = ["warehouse"] diff --git a/data_diff/databases/trino.py b/data_diff/databases/trino.py index f0c95ee4..b76ba74b 100644 --- a/data_diff/databases/trino.py +++ b/data_diff/databases/trino.py @@ -1,11 +1,11 @@ -from typing import Any +from typing import Any, ClassVar, Type import attrs from data_diff.abcs.database_types import TemporalType, ColType_UUID from data_diff.databases import presto from data_diff.databases.base import import_helper -from data_diff.databases.base import TIMESTAMP_PRECISION_POS +from data_diff.databases.base import TIMESTAMP_PRECISION_POS, BaseDialect @import_helper("trino") @@ -34,7 +34,7 @@ def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: @attrs.define(frozen=False, init=False, kw_only=True) class Trino(presto.Presto): - dialect = Dialect() + DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect CONNECT_URI_HELP = "trino://<user>@<host>/<catalog>/<schema>" CONNECT_URI_PARAMS = ["catalog", "schema"] diff --git a/data_diff/databases/vertica.py b/data_diff/databases/vertica.py index e561045f..d12cefe6 100644 --- a/data_diff/databases/vertica.py +++ b/data_diff/databases/vertica.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List +from typing import Any, ClassVar, Dict, List, Type import attrs @@ -135,7 +135,7 @@ def normalize_boolean(self, value: str, _coltype: Boolean) -> str: @attrs.define(frozen=False, init=False, kw_only=True) class Vertica(ThreadedDatabase): - dialect = Dialect() + DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect CONNECT_URI_HELP = "vertica://<user>:<password>@<host>/<database>" CONNECT_URI_PARAMS = ["database?"] From b9d526f53ea008c054418197b12034703c78ec03 Mon Sep 17 00:00:00 2001 From: Ilia Pinchuk <pinchuk.ilia@gmail.com> Date: Sat, 4 Nov 2023 16:19:43 +0600 Subject: [PATCH 3/8] feat: rename to_md5 to md5_as_hex --- data_diff/databases/base.py | 4 ++-- data_diff/databases/bigquery.py | 2 +- data_diff/databases/clickhouse.py | 2 +- data_diff/databases/databricks.py | 2 +- data_diff/databases/duckdb.py | 2 +- data_diff/databases/mssql.py | 2 +- data_diff/databases/mysql.py | 2 +- data_diff/databases/oracle.py | 2 +- data_diff/databases/postgresql.py | 2 +- data_diff/databases/presto.py | 2 +- data_diff/databases/redshift.py | 2 +- data_diff/databases/snowflake.py | 2 +- data_diff/databases/vertica.py | 2 +- tests/test_query.py | 2 +- 14 files changed, 15 insertions(+), 15 deletions(-) diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index 4e2c0771..db6b97f0 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -403,7 +403,7 @@ def render_checksum(self, c: Compiler, elem: Checksum) -> str: def render_concat(self, c: Compiler, elem: Concat) -> str: if self._prevent_overflow_when_concat: items = [ - f"{self.compile(c, Code(self.to_md5(self.to_string(self.compile(c, expr)))))}" for expr in elem.exprs + f"{self.compile(c, Code(self.md5_as_hex(self.to_string(self.compile(c, expr)))))}" for expr in elem.exprs ] # We coalesce because on some DBs (e.g. MySQL) concat('a', NULL) is NULL @@ -787,7 +787,7 @@ def md5_as_int(self, s: str) -> str: "Provide SQL for computing md5 and returning an int" @abstractmethod - def to_md5(self, s: str) -> str: + def md5_as_hex(self, s: str) -> str: """Method to calculate MD5""" @abstractmethod diff --git a/data_diff/databases/bigquery.py b/data_diff/databases/bigquery.py index c278844a..26d8aec3 100644 --- a/data_diff/databases/bigquery.py +++ b/data_diff/databases/bigquery.py @@ -134,7 +134,7 @@ def parse_table_name(self, name: str) -> DbPath: def md5_as_int(self, s: str) -> str: return f"cast(cast( ('0x' || substr(TO_HEX(md5({s})), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS})) as int64) as numeric) - {CHECKSUM_OFFSET}" - def to_md5(self, s: str) -> str: + def md5_as_hex(self, s: str) -> str: return f"md5({s})" def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: diff --git a/data_diff/databases/clickhouse.py b/data_diff/databases/clickhouse.py index 7a63881a..13082504 100644 --- a/data_diff/databases/clickhouse.py +++ b/data_diff/databases/clickhouse.py @@ -105,7 +105,7 @@ def md5_as_int(self, s: str) -> str: f"reinterpretAsUInt128(reverse(unhex(lowerUTF8(substr(hex(MD5({s})), {substr_idx}))))) - {CHECKSUM_OFFSET}" ) - def to_md5(self, s: str) -> str: + def md5_as_hex(self, s: str) -> str: return f"hex(MD5({s}))" def normalize_number(self, value: str, coltype: FractionalType) -> str: diff --git a/data_diff/databases/databricks.py b/data_diff/databases/databricks.py index 7aadac44..19a1f103 100644 --- a/data_diff/databases/databricks.py +++ b/data_diff/databases/databricks.py @@ -82,7 +82,7 @@ def parse_table_name(self, name: str) -> DbPath: def md5_as_int(self, s: str) -> str: return f"cast(conv(substr(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16, 10) as decimal(38, 0)) - {CHECKSUM_OFFSET}" - def to_md5(self, s: str) -> str: + def md5_as_hex(self, s: str) -> str: return f"md5({s})" def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: diff --git a/data_diff/databases/duckdb.py b/data_diff/databases/duckdb.py index 99f22b23..6c65b16b 100644 --- a/data_diff/databases/duckdb.py +++ b/data_diff/databases/duckdb.py @@ -100,7 +100,7 @@ def current_timestamp(self) -> str: def md5_as_int(self, s: str) -> str: return f"('0x' || SUBSTRING(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS},{CHECKSUM_HEXDIGITS}))::BIGINT - {CHECKSUM_OFFSET}" - def to_md5(self, s: str) -> str: + def md5_as_hex(self, s: str) -> str: return f"md5({s})" def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: diff --git a/data_diff/databases/mssql.py b/data_diff/databases/mssql.py index 94665961..0cf752d3 100644 --- a/data_diff/databases/mssql.py +++ b/data_diff/databases/mssql.py @@ -151,7 +151,7 @@ def normalize_number(self, value: str, coltype: NumericType) -> str: def md5_as_int(self, s: str) -> str: return f"convert(bigint, convert(varbinary, '0x' + RIGHT(CONVERT(NVARCHAR(32), HashBytes('MD5', {s}), 2), {CHECKSUM_HEXDIGITS}), 1)) - {CHECKSUM_OFFSET}" - def to_md5(self, s: str) -> str: + def md5_as_hex(self, s: str) -> str: return f"HashBytes('MD5', {s})" diff --git a/data_diff/databases/mysql.py b/data_diff/databases/mysql.py index a83be2b0..651efe82 100644 --- a/data_diff/databases/mysql.py +++ b/data_diff/databases/mysql.py @@ -101,7 +101,7 @@ def set_timezone_to_utc(self) -> str: def md5_as_int(self, s: str) -> str: return f"conv(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16, 10) - {CHECKSUM_OFFSET}" - def to_md5(self, s: str) -> str: + def md5_as_hex(self, s: str) -> str: return f"md5({s})" def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: diff --git a/data_diff/databases/oracle.py b/data_diff/databases/oracle.py index c609f6ba..a3c97b07 100644 --- a/data_diff/databases/oracle.py +++ b/data_diff/databases/oracle.py @@ -137,7 +137,7 @@ def md5_as_int(self, s: str) -> str: # TODO: Find a way to use UTL_RAW.CAST_TO_BINARY_INTEGER ? return f"to_number(substr(standard_hash({s}, 'MD5'), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 'xxxxxxxxxxxxxxx') - {CHECKSUM_OFFSET}" - def to_md5(self, s: str) -> str: + def md5_as_hex(self, s: str) -> str: return f"standard_hash({s}, 'MD5'" def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: diff --git a/data_diff/databases/postgresql.py b/data_diff/databases/postgresql.py index 08872960..b4697fc9 100644 --- a/data_diff/databases/postgresql.py +++ b/data_diff/databases/postgresql.py @@ -98,7 +98,7 @@ def type_repr(self, t) -> str: def md5_as_int(self, s: str) -> str: return f"('x' || substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}))::bit({_CHECKSUM_BITSIZE})::bigint - {CHECKSUM_OFFSET}" - def to_md5(self, s: str) -> str: + def md5_as_hex(self, s: str) -> str: return f"md5({s})" def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: diff --git a/data_diff/databases/presto.py b/data_diff/databases/presto.py index eda7a4d2..ba1c7360 100644 --- a/data_diff/databases/presto.py +++ b/data_diff/databases/presto.py @@ -128,7 +128,7 @@ def current_timestamp(self) -> str: def md5_as_int(self, s: str) -> str: return f"cast(from_base(substr(to_hex(md5(to_utf8({s}))), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16) as decimal(38, 0)) - {CHECKSUM_OFFSET}" - def to_md5(self, s: str) -> str: + def md5_as_hex(self, s: str) -> str: return f"to_hex(md5(to_utf8({s})))" def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: diff --git a/data_diff/databases/redshift.py b/data_diff/databases/redshift.py index 857cfc63..7a621f57 100644 --- a/data_diff/databases/redshift.py +++ b/data_diff/databases/redshift.py @@ -48,7 +48,7 @@ def type_repr(self, t) -> str: def md5_as_int(self, s: str) -> str: return f"strtol(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16)::decimal(38) - {CHECKSUM_OFFSET}" - def to_md5(self, s: str) -> str: + def md5_as_hex(self, s: str) -> str: return f"md5({s})" def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: diff --git a/data_diff/databases/snowflake.py b/data_diff/databases/snowflake.py index 2d6751d0..3a5129b2 100644 --- a/data_diff/databases/snowflake.py +++ b/data_diff/databases/snowflake.py @@ -76,7 +76,7 @@ def type_repr(self, t) -> str: def md5_as_int(self, s: str) -> str: return f"BITAND(md5_number_lower64({s}), {CHECKSUM_MASK}) - {CHECKSUM_OFFSET}" - def to_md5(self, s: str) -> str: + def md5_as_hex(self, s: str) -> str: return f"md5_number_lower64({s})" def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: diff --git a/data_diff/databases/vertica.py b/data_diff/databases/vertica.py index d12cefe6..23f63acc 100644 --- a/data_diff/databases/vertica.py +++ b/data_diff/databases/vertica.py @@ -110,7 +110,7 @@ def current_timestamp(self) -> str: def md5_as_int(self, s: str) -> str: return f"CAST(HEX_TO_INTEGER(SUBSTRING(MD5({s}), {1 + MD5_HEXDIGITS - CHECKSUM_HEXDIGITS})) AS NUMERIC(38, 0)) - {CHECKSUM_OFFSET}" - def to_md5(self, s: str) -> str: + def md5_as_hex(self, s: str) -> str: return f"MD5({s})" def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: diff --git a/tests/test_query.py b/tests/test_query.py index 9b139471..2585c02e 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -76,7 +76,7 @@ def optimizer_hints(self, s: str): def md5_as_int(self, s: str) -> str: raise NotImplementedError - def to_md5(self, s: str) -> str: + def md5_as_hex(self, s: str) -> str: raise NotImplementedError def normalize_number(self, value: str, coltype: FractionalType) -> str: From c56e6e0da01a6a2fe15bf29bd0e4d4c90aface20 Mon Sep 17 00:00:00 2001 From: Ilia Pinchuk <pinchuk.ilia@gmail.com> Date: Sat, 4 Nov 2023 16:20:49 +0600 Subject: [PATCH 4/8] feat: fix a bug for md5_as_hex for oracle --- data_diff/databases/oracle.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/data_diff/databases/oracle.py b/data_diff/databases/oracle.py index a3c97b07..108782b3 100644 --- a/data_diff/databases/oracle.py +++ b/data_diff/databases/oracle.py @@ -138,7 +138,7 @@ def md5_as_int(self, s: str) -> str: return f"to_number(substr(standard_hash({s}, 'MD5'), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 'xxxxxxxxxxxxxxx') - {CHECKSUM_OFFSET}" def md5_as_hex(self, s: str) -> str: - return f"standard_hash({s}, 'MD5'" + return f"standard_hash({s}, 'MD5')" def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: # Cast is necessary for correct MD5 (trimming not enough) From 12388da9dea61ba4989e551f18d99492d9b239dd Mon Sep 17 00:00:00 2001 From: Ilia Pinchuk <pinchuk.ilia@gmail.com> Date: Wed, 8 Nov 2023 23:32:47 +0600 Subject: [PATCH 5/8] feat: make PREVENT_OVERFLOW_WHEN_CONCAT classvar --- data_diff/databases/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index db6b97f0..e81098be 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -199,12 +199,12 @@ def apply_query(callback: Callable[[str], Any], sql_code: Union[str, ThreadLocal class BaseDialect(abc.ABC): SUPPORTS_PRIMARY_KEY: ClassVar[bool] = False SUPPORTS_INDEXES: ClassVar[bool] = False + PREVENT_OVERFLOW_WHEN_CONCAT: ClassVar[bool] = False TYPE_CLASSES: ClassVar[Dict[str, Type[ColType]]] = {} PLACEHOLDER_TABLE = None # Used for Oracle # Some database do not support long string so concatenation might lead to type overflow - PREVENT_OVERFLOW_WHEN_CONCAT: bool = False _prevent_overflow_when_concat: bool = False From 8065192eaa2493cef5e5b287c71a0c8674ae845f Mon Sep 17 00:00:00 2001 From: Ilia Pinchuk <pinchuk.ilia@gmail.com> Date: Wed, 8 Nov 2023 23:35:20 +0600 Subject: [PATCH 6/8] feat: apply formatter --- data_diff/databases/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index e81098be..871c650d 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -403,7 +403,8 @@ def render_checksum(self, c: Compiler, elem: Checksum) -> str: def render_concat(self, c: Compiler, elem: Concat) -> str: if self._prevent_overflow_when_concat: items = [ - f"{self.compile(c, Code(self.md5_as_hex(self.to_string(self.compile(c, expr)))))}" for expr in elem.exprs + f"{self.compile(c, Code(self.md5_as_hex(self.to_string(self.compile(c, expr)))))}" + for expr in elem.exprs ] # We coalesce because on some DBs (e.g. MySQL) concat('a', NULL) is NULL From 98e847bc611f178d2a6ff3c3212f7e16e1540f61 Mon Sep 17 00:00:00 2001 From: Ilia Pinchuk <pinchuk.ilia@gmail.com> Date: Thu, 9 Nov 2023 00:01:59 +0600 Subject: [PATCH 7/8] feat: fix unit tests --- data_diff/databases/mssql.py | 2 +- data_diff/databases/mysql.py | 2 +- data_diff/databases/oracle.py | 2 +- data_diff/databases/postgresql.py | 2 +- tests/test_joindiff.py | 4 +++- 5 files changed, 7 insertions(+), 5 deletions(-) diff --git a/data_diff/databases/mssql.py b/data_diff/databases/mssql.py index 0cf752d3..8f5195ee 100644 --- a/data_diff/databases/mssql.py +++ b/data_diff/databases/mssql.py @@ -38,7 +38,7 @@ def import_mssql(): class Dialect(BaseDialect): name = "MsSQL" ROUNDS_ON_PREC_LOSS = True - SUPPORTS_PRIMARY_KEY = True + SUPPORTS_PRIMARY_KEY: ClassVar[bool] = True SUPPORTS_INDEXES = True TYPE_CLASSES = { # Timestamps diff --git a/data_diff/databases/mysql.py b/data_diff/databases/mysql.py index 651efe82..647388f2 100644 --- a/data_diff/databases/mysql.py +++ b/data_diff/databases/mysql.py @@ -40,7 +40,7 @@ def import_mysql(): class Dialect(BaseDialect): name = "MySQL" ROUNDS_ON_PREC_LOSS = True - SUPPORTS_PRIMARY_KEY = True + SUPPORTS_PRIMARY_KEY: ClassVar[bool] = True SUPPORTS_INDEXES = True TYPE_CLASSES = { # Dates diff --git a/data_diff/databases/oracle.py b/data_diff/databases/oracle.py index 108782b3..ab84f0b6 100644 --- a/data_diff/databases/oracle.py +++ b/data_diff/databases/oracle.py @@ -43,7 +43,7 @@ class Dialect( BaseDialect, ): name = "Oracle" - SUPPORTS_PRIMARY_KEY = True + SUPPORTS_PRIMARY_KEY: ClassVar[bool] = True SUPPORTS_INDEXES = True TYPE_CLASSES: Dict[str, type] = { "NUMBER": Decimal, diff --git a/data_diff/databases/postgresql.py b/data_diff/databases/postgresql.py index b4697fc9..4b9e945f 100644 --- a/data_diff/databases/postgresql.py +++ b/data_diff/databases/postgresql.py @@ -42,7 +42,7 @@ def import_postgresql(): class PostgresqlDialect(BaseDialect): name = "PostgreSQL" ROUNDS_ON_PREC_LOSS = True - SUPPORTS_PRIMARY_KEY = True + SUPPORTS_PRIMARY_KEY: ClassVar[bool] = True SUPPORTS_INDEXES = True TYPE_CLASSES: ClassVar[Dict[str, Type[ColType]]] = { diff --git a/tests/test_joindiff.py b/tests/test_joindiff.py index ed8a31b6..0f664c45 100644 --- a/tests/test_joindiff.py +++ b/tests/test_joindiff.py @@ -270,7 +270,9 @@ def test_null_pks(self): self.assertRaises(ValueError, list, x) -@test_each_database_in_list(d for d in TEST_DATABASES if d.dialect.SUPPORTS_PRIMARY_KEY and d.SUPPORTS_UNIQUE_CONSTAINT) +@test_each_database_in_list( + d for d in TEST_DATABASES if d.DIALECT_CLASS.SUPPORTS_PRIMARY_KEY and d.SUPPORTS_UNIQUE_CONSTAINT +) class TestUniqueConstraint(DiffTestCase): def setUp(self): super().setUp() From 842481fecca336bee0cc18b131f22f5774c6e788 Mon Sep 17 00:00:00 2001 From: Ilia Pinchuk <pinchuk.ilia@gmail.com> Date: Thu, 9 Nov 2023 02:23:07 +0600 Subject: [PATCH 8/8] feat: fix md5_as_hex for snowflake --- data_diff/databases/snowflake.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/data_diff/databases/snowflake.py b/data_diff/databases/snowflake.py index 3a5129b2..bedacd80 100644 --- a/data_diff/databases/snowflake.py +++ b/data_diff/databases/snowflake.py @@ -77,7 +77,7 @@ def md5_as_int(self, s: str) -> str: return f"BITAND(md5_number_lower64({s}), {CHECKSUM_MASK}) - {CHECKSUM_OFFSET}" def md5_as_hex(self, s: str) -> str: - return f"md5_number_lower64({s})" + return f"md5({s})" def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: if coltype.rounds: