Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Fix the aftermath of refactoring #722

Merged
merged 8 commits into from
Oct 2, 2023
2 changes: 1 addition & 1 deletion data_diff/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ def _data_diff(

schemas = list(differ._thread_map(_get_schema, safezip(dbs, table_paths)))
schema1, schema2 = schemas = [
create_schema(db, table_path, schema, case_sensitive)
create_schema(db.name, table_path, schema, case_sensitive)
for db, table_path, schema in safezip(dbs, table_paths, schemas)
]

Expand Down
1 change: 1 addition & 0 deletions data_diff/databases/_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def match_path(self, dsn):

class Connect:
"""Provides methods for connecting to a supported database using a URL or connection dict."""

conn_cache: MutableMapping[Hashable, Database]

def __init__(self, database_by_scheme: Dict[str, Database] = DATABASE_BY_SCHEME):
Expand Down
90 changes: 47 additions & 43 deletions data_diff/databases/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,39 @@
from data_diff.queries.extras import ApplyFuncAndNormalizeAsString, Checksum, NormalizeAsString
from data_diff.utils import ArithString, is_uuid, join_iter, safezip
from data_diff.queries.api import Expr, table, Select, SKIP, Explain, Code, this
from data_diff.queries.ast_classes import Alias, BinOp, CaseWhen, Cast, Column, Commit, Concat, ConstantTable, Count, \
CreateTable, Cte, \
CurrentTimestamp, DropTable, Func, \
GroupBy, \
ITable, In, InsertToTable, IsDistinctFrom, \
Join, \
Param, \
Random, \
Root, TableAlias, TableOp, TablePath, \
TimeTravel, TruncateTable, UnaryOp, WhenThen, _ResolveColumn
from data_diff.queries.ast_classes import (
Alias,
BinOp,
CaseWhen,
Cast,
Column,
Commit,
Concat,
ConstantTable,
Count,
CreateTable,
Cte,
CurrentTimestamp,
DropTable,
Func,
GroupBy,
ITable,
In,
InsertToTable,
IsDistinctFrom,
Join,
Param,
Random,
Root,
TableAlias,
TableOp,
TablePath,
TimeTravel,
TruncateTable,
UnaryOp,
WhenThen,
_ResolveColumn,
)
from data_diff.abcs.database_types import (
Array,
Struct,
Expand Down Expand Up @@ -67,17 +90,11 @@ class CompileError(Exception):
pass


# TODO: LATER: Resolve the circular imports of databases-compiler-dialects:
# A database uses a compiler to render the SQL query.
# The compiler delegates to a dialect.
# The dialect renders the SQL.
# AS IS: The dialect requires the db to normalize table paths — leading to the back-dependency.
# TO BE: All the tables paths must be pre-normalized before SQL rendering.
# Also: c.database.is_autocommit in render_commit().
# After this, the Compiler can cease referring Database/Dialect at all,
# and be used only as a CompilingContext (a counter/data-bearing class).
# As a result, it becomes low-level util, and the circular dependency auto-resolves.
# Meanwhile, the easy fix is to simply move the Compiler here.
# TODO: remove once switched to attrs, where ForwardRef[]/strings are resolved.
class _RuntypeHackToFixCicularRefrencedDatabase:
dialect: "BaseDialect"


@dataclass
class Compiler(AbstractCompiler):
"""
Expand All @@ -90,7 +107,7 @@ class Compiler(AbstractCompiler):
# Database is needed to normalize tables. Dialect is needed for recursive compilations.
# In theory, it is many-to-many relations: e.g. a generic ODBC driver with multiple dialects.
# In practice, we currently bind the dialects to the specific database classes.
database: "Database"
database: _RuntypeHackToFixCicularRefrencedDatabase

in_select: bool = False # Compilation runtime flag
in_join: bool = False # Compilation runtime flag
Expand All @@ -102,7 +119,7 @@ class Compiler(AbstractCompiler):
_counter: List = field(default_factory=lambda: [0])

@property
def dialect(self) -> "Dialect":
def dialect(self) -> "BaseDialect":
return self.database.dialect

# TODO: DEPRECATED: Remove once the dialect is used directly in all places.
Expand Down Expand Up @@ -223,7 +240,6 @@ class BaseDialect(abc.ABC):
SUPPORTS_PRIMARY_KEY = False
SUPPORTS_INDEXES = False
TYPE_CLASSES: Dict[str, type] = {}
MIXINS = frozenset()

PLACEHOLDER_TABLE = None # Used for Oracle

Expand Down Expand Up @@ -414,7 +430,9 @@ def render_checksum(self, c: Compiler, elem: Checksum) -> str:

def render_concat(self, c: Compiler, elem: Concat) -> str:
# We coalesce because on some DBs (e.g. MySQL) concat('a', NULL) is NULL
items = [f"coalesce({self.compile(c, Code(self.to_string(self.compile(c, expr))))}, '<null>')" for expr in elem.exprs]
items = [
f"coalesce({self.compile(c, Code(self.to_string(self.compile(c, expr))))}, '<null>')" for expr in elem.exprs
]
assert items
if len(items) == 1:
return items[0]
Expand Down Expand Up @@ -559,17 +577,15 @@ def render_groupby(self, c: Compiler, elem: GroupBy) -> str:
columns=columns,
group_by_exprs=[Code(k) for k in keys],
having_exprs=elem.having_exprs,
)
),
)

keys_str = ", ".join(keys)
columns_str = ", ".join(self.compile(c, x) for x in columns)
having_str = (
" HAVING " + " AND ".join(map(compile_fn, elem.having_exprs)) if elem.having_exprs is not None else ""
)
select = (
f"SELECT {columns_str} FROM {self.compile(c.replace(in_select=True), elem.table)} GROUP BY {keys_str}{having_str}"
)
select = f"SELECT {columns_str} FROM {self.compile(c.replace(in_select=True), elem.table)} GROUP BY {keys_str}{having_str}"

if c.in_select:
select = f"({select}) {c.new_unique_name()}"
Expand Down Expand Up @@ -601,7 +617,7 @@ def render_timetravel(self, c: Compiler, elem: TimeTravel) -> str:
# TODO: why is it c.? why not self? time-trvelling is the dialect's thing, isnt't it?
c.time_travel(
elem.table, before=elem.before, timestamp=elem.timestamp, offset=elem.offset, statement=elem.statement
)
),
)

def render_createtable(self, c: Compiler, elem: CreateTable) -> str:
Expand Down Expand Up @@ -768,18 +784,6 @@ def _convert_db_precision_to_digits(self, p: int) -> int:
# See: https://en.wikipedia.org/wiki/Single-precision_floating-point_format
return math.floor(math.log(2**p, 10))

@classmethod
def load_mixins(cls, *abstract_mixins) -> Self:
"Load a list of mixins that implement the given abstract mixins"
mixins = {m for m in cls.MIXINS if issubclass(m, abstract_mixins)}

class _DialectWithMixins(cls, *mixins, *abstract_mixins):
pass

_DialectWithMixins.__name__ = cls.__name__
return _DialectWithMixins()


@property
@abstractmethod
def name(self) -> str:
Expand Down Expand Up @@ -822,7 +826,7 @@ def __getitem__(self, i):
return self.rows[i]


class Database(abc.ABC):
class Database(abc.ABC, _RuntypeHackToFixCicularRefrencedDatabase):
"""Base abstract class for databases.

Used for providing connection code and implementation specific SQL utilities.
Expand Down
5 changes: 3 additions & 2 deletions data_diff/databases/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,9 @@ def time_travel(
)


class Dialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue):
class Dialect(
BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue
):
name = "BigQuery"
ROUNDS_ON_PREC_LOSS = False # Technically BigQuery doesn't allow implicit rounding or truncation
TYPE_CLASSES = {
Expand All @@ -159,7 +161,6 @@ class Dialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Abstra
}
TYPE_ARRAY_RE = re.compile(r"ARRAY<(.+)>")
TYPE_STRUCT_RE = re.compile(r"STRUCT<(.+)>")
MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_TimeTravel, Mixin_RandomSample}

def random(self) -> str:
return "RAND()"
Expand Down
1 change: 0 additions & 1 deletion data_diff/databases/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,6 @@ class Dialect(BaseDialect, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, A
"DateTime64": Timestamp,
"Bool": Boolean,
}
MIXINS = {Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample}

def quote(self, s: str) -> str:
return f'"{s}"'
Expand Down
1 change: 0 additions & 1 deletion data_diff/databases/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ class Dialect(BaseDialect, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, A
# Boolean
"BOOLEAN": Boolean,
}
MIXINS = {Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample}

def quote(self, s: str):
return f"`{s}`"
Expand Down
5 changes: 3 additions & 2 deletions data_diff/databases/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,13 @@ def random_sample_ratio_approx(self, tbl: ITable, ratio: float) -> ITable:
return code("SELECT * FROM ({tbl}) USING SAMPLE {percent}%;", tbl=tbl, percent=int(100 * ratio))


class Dialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue):
class Dialect(
BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue
):
name = "DuckDB"
ROUNDS_ON_PREC_LOSS = False
SUPPORTS_PRIMARY_KEY = True
SUPPORTS_INDEXES = True
MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample}

TYPE_CLASSES = {
# Timestamps
Expand Down
12 changes: 9 additions & 3 deletions data_diff/databases/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,15 @@ def md5_as_int(self, s: str) -> str:
return f"convert(bigint, convert(varbinary, '0x' + RIGHT(CONVERT(NVARCHAR(32), HashBytes('MD5', {s}), 2), {CHECKSUM_HEXDIGITS}), 1))"


class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue):
class Dialect(
BaseDialect,
Mixin_Schema,
Mixin_OptimizerHints,
Mixin_MD5,
Mixin_NormalizeValue,
AbstractMixin_MD5,
AbstractMixin_NormalizeValue,
):
name = "MsSQL"
ROUNDS_ON_PREC_LOSS = True
SUPPORTS_PRIMARY_KEY = True
Expand Down Expand Up @@ -98,8 +106,6 @@ class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints, Mixin_MD5, Mixin_
"json": JSON,
}

MIXINS = {Mixin_Schema, Mixin_NormalizeValue, Mixin_RandomSample}

def quote(self, s: str):
return f"[{s}]"

Expand Down
11 changes: 9 additions & 2 deletions data_diff/databases/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,15 @@ def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
return f"TRIM(CAST({value} AS char))"


class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue):
class Dialect(
BaseDialect,
Mixin_Schema,
Mixin_OptimizerHints,
Mixin_MD5,
Mixin_NormalizeValue,
AbstractMixin_MD5,
AbstractMixin_NormalizeValue,
):
name = "MySQL"
ROUNDS_ON_PREC_LOSS = True
SUPPORTS_PRIMARY_KEY = True
Expand Down Expand Up @@ -91,7 +99,6 @@ class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints, Mixin_MD5, Mixin_
# Boolean
"boolean": Boolean,
}
MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample}

def quote(self, s: str):
return f"`{s}`"
Expand Down
11 changes: 9 additions & 2 deletions data_diff/databases/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,15 @@ def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable:
)


class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue):
class Dialect(
BaseDialect,
Mixin_Schema,
Mixin_OptimizerHints,
Mixin_MD5,
Mixin_NormalizeValue,
AbstractMixin_MD5,
AbstractMixin_NormalizeValue,
):
name = "Oracle"
SUPPORTS_PRIMARY_KEY = True
SUPPORTS_INDEXES = True
Expand All @@ -96,7 +104,6 @@ class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints, Mixin_MD5, Mixin_
}
ROUNDS_ON_PREC_LOSS = True
PLACEHOLDER_TABLE = "DUAL"
MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample}

def quote(self, s: str):
return f'"{s}"'
Expand Down
5 changes: 3 additions & 2 deletions data_diff/databases/postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,13 @@ def normalize_json(self, value: str, _coltype: JSON) -> str:
return f"{value}::text"


class PostgresqlDialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue):
class PostgresqlDialect(
BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue
):
name = "PostgreSQL"
ROUNDS_ON_PREC_LOSS = True
SUPPORTS_PRIMARY_KEY = True
SUPPORTS_INDEXES = True
MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample}

TYPE_CLASSES = {
# Timestamps
Expand Down
5 changes: 3 additions & 2 deletions data_diff/databases/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ def normalize_boolean(self, value: str, _coltype: Boolean) -> str:
return self.to_string(f"cast ({value} as int)")


class Dialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue):
class Dialect(
BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue
):
name = "Presto"
ROUNDS_ON_PREC_LOSS = True
TYPE_CLASSES = {
Expand All @@ -94,7 +96,6 @@ class Dialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Abstra
# Boolean
"boolean": Boolean,
}
MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample}

def explain_as_text(self, query: str) -> str:
return f"EXPLAIN (FORMAT TEXT) {query}"
Expand Down
5 changes: 3 additions & 2 deletions data_diff/databases/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,9 @@ def time_travel(
return code(f"{{table}} {at_or_before}({key} => {{value}})", table=table, value=value)


class Dialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue):
class Dialect(
BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue
):
name = "Snowflake"
ROUNDS_ON_PREC_LOSS = False
TYPE_CLASSES = {
Expand All @@ -121,7 +123,6 @@ class Dialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Abstra
# Boolean
"BOOLEAN": Boolean,
}
MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_TimeTravel, Mixin_RandomSample}

def explain_as_text(self, query: str) -> str:
return f"EXPLAIN USING TEXT {query}"
Expand Down
5 changes: 3 additions & 2 deletions data_diff/databases/vertica.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,9 @@ def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable:
)


class Dialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue):
class Dialect(
BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue
):
name = "Vertica"
ROUNDS_ON_PREC_LOSS = True

Expand All @@ -96,7 +98,6 @@ class Dialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Abstra
# Boolean
"boolean": Boolean,
}
MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample}

def quote(self, s: str):
return f'"{s}"'
Expand Down
Loading