Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Prevent type overflow #757

Merged
merged 8 commits into from
Nov 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 32 additions & 4 deletions data_diff/databases/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,10 +199,19 @@ def apply_query(callback: Callable[[str], Any], sql_code: Union[str, ThreadLocal
class BaseDialect(abc.ABC):
SUPPORTS_PRIMARY_KEY: ClassVar[bool] = False
SUPPORTS_INDEXES: ClassVar[bool] = False
PREVENT_OVERFLOW_WHEN_CONCAT: ClassVar[bool] = False
TYPE_CLASSES: ClassVar[Dict[str, Type[ColType]]] = {}

PLACEHOLDER_TABLE = None # Used for Oracle

# Some database do not support long string so concatenation might lead to type overflow

_prevent_overflow_when_concat: bool = False

def enable_preventing_type_overflow(self) -> None:
logger.info("Preventing type overflow when concatenation is enabled")
self._prevent_overflow_when_concat = True

def parse_table_name(self, name: str) -> DbPath:
"Parse the given table name into a DbPath"
return parse_table_name(name)
Expand Down Expand Up @@ -392,10 +401,19 @@ def render_checksum(self, c: Compiler, elem: Checksum) -> str:
return f"sum({md5})"

def render_concat(self, c: Compiler, elem: Concat) -> str:
if self._prevent_overflow_when_concat:
items = [
f"{self.compile(c, Code(self.md5_as_hex(self.to_string(self.compile(c, expr)))))}"
for expr in elem.exprs
]

# We coalesce because on some DBs (e.g. MySQL) concat('a', NULL) is NULL
items = [
f"coalesce({self.compile(c, Code(self.to_string(self.compile(c, expr))))}, '<null>')" for expr in elem.exprs
]
else:
items = [
f"coalesce({self.compile(c, Code(self.to_string(self.compile(c, expr))))}, '<null>')"
for expr in elem.exprs
]

assert items
if len(items) == 1:
return items[0]
Expand Down Expand Up @@ -769,6 +787,10 @@ def set_timezone_to_utc(self) -> str:
def md5_as_int(self, s: str) -> str:
"Provide SQL for computing md5 and returning an int"

@abstractmethod
def md5_as_hex(self, s: str) -> str:
"""Method to calculate MD5"""

@abstractmethod
def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
"""Creates an SQL expression, that converts 'value' to a normalized timestamp.
Expand Down Expand Up @@ -885,13 +907,16 @@ class Database(abc.ABC):
Instanciated using :meth:`~data_diff.connect`
"""

DIALECT_CLASS: ClassVar[Type[BaseDialect]] = BaseDialect

SUPPORTS_ALPHANUMS: ClassVar[bool] = True
SUPPORTS_UNIQUE_CONSTAINT: ClassVar[bool] = False
CONNECT_URI_KWPARAMS: ClassVar[List[str]] = []

default_schema: Optional[str] = None
_interactive: bool = False
is_closed: bool = False
_dialect: BaseDialect = None

@property
def name(self):
Expand Down Expand Up @@ -1120,10 +1145,13 @@ def close(self):
return super().close()

@property
@abstractmethod
def dialect(self) -> BaseDialect:
"The dialect of the database. Used internally by Database, and also available publicly."

if not self._dialect:
self._dialect = self.DIALECT_CLASS()
return self._dialect

@property
@abstractmethod
def CONNECT_URI_HELP(self) -> str:
Expand Down
7 changes: 5 additions & 2 deletions data_diff/databases/bigquery.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import re
from typing import Any, List, Union
from typing import Any, ClassVar, List, Union, Type

import attrs

Expand Down Expand Up @@ -134,6 +134,9 @@ def parse_table_name(self, name: str) -> DbPath:
def md5_as_int(self, s: str) -> str:
return f"cast(cast( ('0x' || substr(TO_HEX(md5({s})), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS})) as int64) as numeric) - {CHECKSUM_OFFSET}"

def md5_as_hex(self, s: str) -> str:
return f"md5({s})"

def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
if coltype.rounds:
timestamp = f"timestamp_micros(cast(round(unix_micros(cast({value} as timestamp))/1000000, {coltype.precision})*1000000 as int))"
Expand Down Expand Up @@ -179,9 +182,9 @@ def normalize_struct(self, value: str, _coltype: Struct) -> str:

@attrs.define(frozen=False, init=False, kw_only=True)
class BigQuery(Database):
DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect
CONNECT_URI_HELP = "bigquery://<project>/<dataset>"
CONNECT_URI_PARAMS = ["dataset"]
dialect = Dialect()

project: str
dataset: str
Expand Down
7 changes: 5 additions & 2 deletions data_diff/databases/clickhouse.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, Optional, Type
from typing import Any, ClassVar, Dict, Optional, Type

import attrs

Expand Down Expand Up @@ -105,6 +105,9 @@ def md5_as_int(self, s: str) -> str:
f"reinterpretAsUInt128(reverse(unhex(lowerUTF8(substr(hex(MD5({s})), {substr_idx}))))) - {CHECKSUM_OFFSET}"
)

def md5_as_hex(self, s: str) -> str:
return f"hex(MD5({s}))"

def normalize_number(self, value: str, coltype: FractionalType) -> str:
# If a decimal value has trailing zeros in a fractional part, when casting to string they are dropped.
# For example:
Expand Down Expand Up @@ -164,7 +167,7 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:

@attrs.define(frozen=False, init=False, kw_only=True)
class Clickhouse(ThreadedDatabase):
dialect = Dialect()
DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect
CONNECT_URI_HELP = "clickhouse://<user>:<password>@<host>/<database>"
CONNECT_URI_PARAMS = ["database?"]

Expand Down
7 changes: 5 additions & 2 deletions data_diff/databases/databricks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import math
from typing import Any, Dict, Sequence
from typing import Any, ClassVar, Dict, Sequence, Type
import logging

import attrs
Expand Down Expand Up @@ -82,6 +82,9 @@ def parse_table_name(self, name: str) -> DbPath:
def md5_as_int(self, s: str) -> str:
return f"cast(conv(substr(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16, 10) as decimal(38, 0)) - {CHECKSUM_OFFSET}"

def md5_as_hex(self, s: str) -> str:
return f"md5({s})"

def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
"""Databricks timestamp contains no more than 6 digits in precision"""

Expand All @@ -104,7 +107,7 @@ def normalize_boolean(self, value: str, _coltype: Boolean) -> str:

@attrs.define(frozen=False, init=False, kw_only=True)
class Databricks(ThreadedDatabase):
dialect = Dialect()
DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect
CONNECT_URI_HELP = "databricks://:<access_token>@<server_hostname>/<http_path>"
CONNECT_URI_PARAMS = ["catalog", "schema"]

Expand Down
7 changes: 5 additions & 2 deletions data_diff/databases/duckdb.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, Union
from typing import Any, ClassVar, Dict, Union, Type

import attrs

Expand Down Expand Up @@ -100,6 +100,9 @@ def current_timestamp(self) -> str:
def md5_as_int(self, s: str) -> str:
return f"('0x' || SUBSTRING(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS},{CHECKSUM_HEXDIGITS}))::BIGINT - {CHECKSUM_OFFSET}"

def md5_as_hex(self, s: str) -> str:
return f"md5({s})"

def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
# It's precision 6 by default. If precision is less than 6 -> we remove the trailing numbers.
if coltype.rounds and coltype.precision > 0:
Expand All @@ -116,7 +119,7 @@ def normalize_boolean(self, value: str, _coltype: Boolean) -> str:

@attrs.define(frozen=False, init=False, kw_only=True)
class DuckDB(Database):
dialect = Dialect()
DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect
SUPPORTS_UNIQUE_CONSTAINT = False # Temporary, until we implement it
CONNECT_URI_HELP = "duckdb://<dbname>@<filepath>"
CONNECT_URI_PARAMS = ["database", "dbpath"]
Expand Down
9 changes: 6 additions & 3 deletions data_diff/databases/mssql.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, Optional
from typing import Any, ClassVar, Dict, Optional, Type

import attrs

Expand Down Expand Up @@ -38,7 +38,7 @@ def import_mssql():
class Dialect(BaseDialect):
name = "MsSQL"
ROUNDS_ON_PREC_LOSS = True
SUPPORTS_PRIMARY_KEY = True
SUPPORTS_PRIMARY_KEY: ClassVar[bool] = True
SUPPORTS_INDEXES = True
TYPE_CLASSES = {
# Timestamps
Expand Down Expand Up @@ -151,10 +151,13 @@ def normalize_number(self, value: str, coltype: NumericType) -> str:
def md5_as_int(self, s: str) -> str:
return f"convert(bigint, convert(varbinary, '0x' + RIGHT(CONVERT(NVARCHAR(32), HashBytes('MD5', {s}), 2), {CHECKSUM_HEXDIGITS}), 1)) - {CHECKSUM_OFFSET}"

def md5_as_hex(self, s: str) -> str:
return f"HashBytes('MD5', {s})"


@attrs.define(frozen=False, init=False, kw_only=True)
class MsSQL(ThreadedDatabase):
dialect = Dialect()
DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect
CONNECT_URI_HELP = "mssql://<user>:<password>@<host>/<database>/<schema>"
CONNECT_URI_PARAMS = ["database", "schema"]

Expand Down
9 changes: 6 additions & 3 deletions data_diff/databases/mysql.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict
from typing import Any, ClassVar, Dict, Type

import attrs

Expand Down Expand Up @@ -40,7 +40,7 @@ def import_mysql():
class Dialect(BaseDialect):
name = "MySQL"
ROUNDS_ON_PREC_LOSS = True
SUPPORTS_PRIMARY_KEY = True
SUPPORTS_PRIMARY_KEY: ClassVar[bool] = True
SUPPORTS_INDEXES = True
TYPE_CLASSES = {
# Dates
Expand Down Expand Up @@ -101,6 +101,9 @@ def set_timezone_to_utc(self) -> str:
def md5_as_int(self, s: str) -> str:
return f"conv(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16, 10) - {CHECKSUM_OFFSET}"

def md5_as_hex(self, s: str) -> str:
return f"md5({s})"

def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
if coltype.rounds:
return self.to_string(f"cast( cast({value} as datetime({coltype.precision})) as datetime(6))")
Expand All @@ -117,7 +120,7 @@ def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:

@attrs.define(frozen=False, init=False, kw_only=True)
class MySQL(ThreadedDatabase):
dialect = Dialect()
DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect
SUPPORTS_ALPHANUMS = False
SUPPORTS_UNIQUE_CONSTAINT = True
CONNECT_URI_HELP = "mysql://<user>:<password>@<host>/<database>"
Expand Down
9 changes: 6 additions & 3 deletions data_diff/databases/oracle.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional
from typing import Any, ClassVar, Dict, List, Optional, Type

import attrs

Expand Down Expand Up @@ -43,7 +43,7 @@ class Dialect(
BaseDialect,
):
name = "Oracle"
SUPPORTS_PRIMARY_KEY = True
SUPPORTS_PRIMARY_KEY: ClassVar[bool] = True
SUPPORTS_INDEXES = True
TYPE_CLASSES: Dict[str, type] = {
"NUMBER": Decimal,
Expand Down Expand Up @@ -137,6 +137,9 @@ def md5_as_int(self, s: str) -> str:
# TODO: Find a way to use UTL_RAW.CAST_TO_BINARY_INTEGER ?
return f"to_number(substr(standard_hash({s}, 'MD5'), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 'xxxxxxxxxxxxxxx') - {CHECKSUM_OFFSET}"

def md5_as_hex(self, s: str) -> str:
return f"standard_hash({s}, 'MD5')"

def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
# Cast is necessary for correct MD5 (trimming not enough)
return f"CAST(TRIM({value}) AS VARCHAR(36))"
Expand All @@ -161,7 +164,7 @@ def normalize_number(self, value: str, coltype: FractionalType) -> str:

@attrs.define(frozen=False, init=False, kw_only=True)
class Oracle(ThreadedDatabase):
dialect = Dialect()
DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect
CONNECT_URI_HELP = "oracle://<user>:<password>@<host>/<database>"
CONNECT_URI_PARAMS = ["database?"]

Expand Down
7 changes: 5 additions & 2 deletions data_diff/databases/postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def import_postgresql():
class PostgresqlDialect(BaseDialect):
name = "PostgreSQL"
ROUNDS_ON_PREC_LOSS = True
SUPPORTS_PRIMARY_KEY = True
SUPPORTS_PRIMARY_KEY: ClassVar[bool] = True
SUPPORTS_INDEXES = True

TYPE_CLASSES: ClassVar[Dict[str, Type[ColType]]] = {
Expand Down Expand Up @@ -98,6 +98,9 @@ def type_repr(self, t) -> str:
def md5_as_int(self, s: str) -> str:
return f"('x' || substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}))::bit({_CHECKSUM_BITSIZE})::bigint - {CHECKSUM_OFFSET}"

def md5_as_hex(self, s: str) -> str:
return f"md5({s})"

def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
if coltype.rounds:
return f"to_char({value}::timestamp({coltype.precision}), 'YYYY-mm-dd HH24:MI:SS.US')"
Expand All @@ -119,7 +122,7 @@ def normalize_json(self, value: str, _coltype: JSON) -> str:

@attrs.define(frozen=False, init=False, kw_only=True)
class PostgreSQL(ThreadedDatabase):
dialect = PostgresqlDialect()
DIALECT_CLASS: ClassVar[Type[BaseDialect]] = PostgresqlDialect
SUPPORTS_UNIQUE_CONSTAINT = True
CONNECT_URI_HELP = "postgresql://<user>:<password>@<host>/<database>"
CONNECT_URI_PARAMS = ["database?"]
Expand Down
7 changes: 5 additions & 2 deletions data_diff/databases/presto.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from functools import partial
import re
from typing import Any
from typing import Any, ClassVar, Type

import attrs

Expand Down Expand Up @@ -128,6 +128,9 @@ def current_timestamp(self) -> str:
def md5_as_int(self, s: str) -> str:
return f"cast(from_base(substr(to_hex(md5(to_utf8({s}))), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16) as decimal(38, 0)) - {CHECKSUM_OFFSET}"

def md5_as_hex(self, s: str) -> str:
return f"to_hex(md5(to_utf8({s})))"

def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
# Trim doesn't work on CHAR type
return f"TRIM(CAST({value} AS VARCHAR))"
Expand All @@ -150,7 +153,7 @@ def normalize_boolean(self, value: str, _coltype: Boolean) -> str:

@attrs.define(frozen=False, init=False, kw_only=True)
class Presto(Database):
dialect = Dialect()
DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect
CONNECT_URI_HELP = "presto://<user>@<host>/<catalog>/<schema>"
CONNECT_URI_PARAMS = ["catalog", "schema"]

Expand Down
6 changes: 5 additions & 1 deletion data_diff/databases/redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
TimestampTZ,
)
from data_diff.databases.postgresql import (
BaseDialect,
PostgreSQL,
MD5_HEXDIGITS,
CHECKSUM_HEXDIGITS,
Expand Down Expand Up @@ -47,6 +48,9 @@ def type_repr(self, t) -> str:
def md5_as_int(self, s: str) -> str:
return f"strtol(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16)::decimal(38) - {CHECKSUM_OFFSET}"

def md5_as_hex(self, s: str) -> str:
return f"md5({s})"

def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
if coltype.rounds:
timestamp = f"{value}::timestamp(6)"
Expand Down Expand Up @@ -76,7 +80,7 @@ def normalize_json(self, value: str, _coltype: JSON) -> str:

@attrs.define(frozen=False, init=False, kw_only=True)
class Redshift(PostgreSQL):
dialect = Dialect()
DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect
CONNECT_URI_HELP = "redshift://<user>:<password>@<host>/<database>"
CONNECT_URI_PARAMS = ["database?"]

Expand Down
7 changes: 5 additions & 2 deletions data_diff/databases/snowflake.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Union, List
from typing import Any, ClassVar, Union, List, Type
import logging

import attrs
Expand Down Expand Up @@ -76,6 +76,9 @@ def type_repr(self, t) -> str:
def md5_as_int(self, s: str) -> str:
return f"BITAND(md5_number_lower64({s}), {CHECKSUM_MASK}) - {CHECKSUM_OFFSET}"

def md5_as_hex(self, s: str) -> str:
return f"md5({s})"

def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
if coltype.rounds:
timestamp = f"to_timestamp(round(date_part(epoch_nanosecond, convert_timezone('UTC', {value})::timestamp(9))/1000000000, {coltype.precision}))"
Expand All @@ -93,7 +96,7 @@ def normalize_boolean(self, value: str, _coltype: Boolean) -> str:

@attrs.define(frozen=False, init=False, kw_only=True)
class Snowflake(Database):
dialect = Dialect()
DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect
CONNECT_URI_HELP = "snowflake://<user>:<password>@<account>/<database>/<SCHEMA>?warehouse=<WAREHOUSE>"
CONNECT_URI_PARAMS = ["database", "schema"]
CONNECT_URI_KWPARAMS = ["warehouse"]
Expand Down
Loading