Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 076fec2

Browse files
authored
Merge pull request #722 from datafold/fix-aftermath-of-refactoring
Fix the aftermath of refactoring
2 parents 3949a27 + c36ff0a commit 076fec2

28 files changed

+154
-110
lines changed

data_diff/__main__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -461,7 +461,7 @@ def _data_diff(
461461

462462
schemas = list(differ._thread_map(_get_schema, safezip(dbs, table_paths)))
463463
schema1, schema2 = schemas = [
464-
create_schema(db, table_path, schema, case_sensitive)
464+
create_schema(db.name, table_path, schema, case_sensitive)
465465
for db, table_path, schema in safezip(dbs, table_paths, schemas)
466466
]
467467

data_diff/databases/_connect.py

+1
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ def match_path(self, dsn):
9494

9595
class Connect:
9696
"""Provides methods for connecting to a supported database using a URL or connection dict."""
97+
9798
conn_cache: MutableMapping[Hashable, Database]
9899

99100
def __init__(self, database_by_scheme: Dict[str, Database] = DATABASE_BY_SCHEME):

data_diff/databases/base.py

+47-43
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,39 @@
2121
from data_diff.queries.extras import ApplyFuncAndNormalizeAsString, Checksum, NormalizeAsString
2222
from data_diff.utils import ArithString, is_uuid, join_iter, safezip
2323
from data_diff.queries.api import Expr, table, Select, SKIP, Explain, Code, this
24-
from data_diff.queries.ast_classes import Alias, BinOp, CaseWhen, Cast, Column, Commit, Concat, ConstantTable, Count, \
25-
CreateTable, Cte, \
26-
CurrentTimestamp, DropTable, Func, \
27-
GroupBy, \
28-
ITable, In, InsertToTable, IsDistinctFrom, \
29-
Join, \
30-
Param, \
31-
Random, \
32-
Root, TableAlias, TableOp, TablePath, \
33-
TimeTravel, TruncateTable, UnaryOp, WhenThen, _ResolveColumn
24+
from data_diff.queries.ast_classes import (
25+
Alias,
26+
BinOp,
27+
CaseWhen,
28+
Cast,
29+
Column,
30+
Commit,
31+
Concat,
32+
ConstantTable,
33+
Count,
34+
CreateTable,
35+
Cte,
36+
CurrentTimestamp,
37+
DropTable,
38+
Func,
39+
GroupBy,
40+
ITable,
41+
In,
42+
InsertToTable,
43+
IsDistinctFrom,
44+
Join,
45+
Param,
46+
Random,
47+
Root,
48+
TableAlias,
49+
TableOp,
50+
TablePath,
51+
TimeTravel,
52+
TruncateTable,
53+
UnaryOp,
54+
WhenThen,
55+
_ResolveColumn,
56+
)
3457
from data_diff.abcs.database_types import (
3558
Array,
3659
Struct,
@@ -67,17 +90,11 @@ class CompileError(Exception):
6790
pass
6891

6992

70-
# TODO: LATER: Resolve the circular imports of databases-compiler-dialects:
71-
# A database uses a compiler to render the SQL query.
72-
# The compiler delegates to a dialect.
73-
# The dialect renders the SQL.
74-
# AS IS: The dialect requires the db to normalize table paths — leading to the back-dependency.
75-
# TO BE: All the tables paths must be pre-normalized before SQL rendering.
76-
# Also: c.database.is_autocommit in render_commit().
77-
# After this, the Compiler can cease referring Database/Dialect at all,
78-
# and be used only as a CompilingContext (a counter/data-bearing class).
79-
# As a result, it becomes low-level util, and the circular dependency auto-resolves.
80-
# Meanwhile, the easy fix is to simply move the Compiler here.
93+
# TODO: remove once switched to attrs, where ForwardRef[]/strings are resolved.
94+
class _RuntypeHackToFixCicularRefrencedDatabase:
95+
dialect: "BaseDialect"
96+
97+
8198
@dataclass
8299
class Compiler(AbstractCompiler):
83100
"""
@@ -90,7 +107,7 @@ class Compiler(AbstractCompiler):
90107
# Database is needed to normalize tables. Dialect is needed for recursive compilations.
91108
# In theory, it is a many-to-many relation: e.g. a generic ODBC driver with multiple dialects.
92109
# In practice, we currently bind the dialects to the specific database classes.
93-
database: "Database"
110+
database: _RuntypeHackToFixCicularRefrencedDatabase
94111

95112
in_select: bool = False # Compilation runtime flag
96113
in_join: bool = False # Compilation runtime flag
@@ -102,7 +119,7 @@ class Compiler(AbstractCompiler):
102119
_counter: List = field(default_factory=lambda: [0])
103120

104121
@property
105-
def dialect(self) -> "Dialect":
122+
def dialect(self) -> "BaseDialect":
106123
return self.database.dialect
107124

108125
# TODO: DEPRECATED: Remove once the dialect is used directly in all places.
@@ -223,7 +240,6 @@ class BaseDialect(abc.ABC):
223240
SUPPORTS_PRIMARY_KEY = False
224241
SUPPORTS_INDEXES = False
225242
TYPE_CLASSES: Dict[str, type] = {}
226-
MIXINS = frozenset()
227243

228244
PLACEHOLDER_TABLE = None # Used for Oracle
229245

@@ -414,7 +430,9 @@ def render_checksum(self, c: Compiler, elem: Checksum) -> str:
414430

415431
def render_concat(self, c: Compiler, elem: Concat) -> str:
416432
# We coalesce because on some DBs (e.g. MySQL) concat('a', NULL) is NULL
417-
items = [f"coalesce({self.compile(c, Code(self.to_string(self.compile(c, expr))))}, '<null>')" for expr in elem.exprs]
433+
items = [
434+
f"coalesce({self.compile(c, Code(self.to_string(self.compile(c, expr))))}, '<null>')" for expr in elem.exprs
435+
]
418436
assert items
419437
if len(items) == 1:
420438
return items[0]
@@ -559,17 +577,15 @@ def render_groupby(self, c: Compiler, elem: GroupBy) -> str:
559577
columns=columns,
560578
group_by_exprs=[Code(k) for k in keys],
561579
having_exprs=elem.having_exprs,
562-
)
580+
),
563581
)
564582

565583
keys_str = ", ".join(keys)
566584
columns_str = ", ".join(self.compile(c, x) for x in columns)
567585
having_str = (
568586
" HAVING " + " AND ".join(map(compile_fn, elem.having_exprs)) if elem.having_exprs is not None else ""
569587
)
570-
select = (
571-
f"SELECT {columns_str} FROM {self.compile(c.replace(in_select=True), elem.table)} GROUP BY {keys_str}{having_str}"
572-
)
588+
select = f"SELECT {columns_str} FROM {self.compile(c.replace(in_select=True), elem.table)} GROUP BY {keys_str}{having_str}"
573589

574590
if c.in_select:
575591
select = f"({select}) {c.new_unique_name()}"
@@ -601,7 +617,7 @@ def render_timetravel(self, c: Compiler, elem: TimeTravel) -> str:
601617
# TODO: why is it c.? why not self? time-travelling is the dialect's thing, isn't it?
602618
c.time_travel(
603619
elem.table, before=elem.before, timestamp=elem.timestamp, offset=elem.offset, statement=elem.statement
604-
)
620+
),
605621
)
606622

607623
def render_createtable(self, c: Compiler, elem: CreateTable) -> str:
@@ -768,18 +784,6 @@ def _convert_db_precision_to_digits(self, p: int) -> int:
768784
# See: https://en.wikipedia.org/wiki/Single-precision_floating-point_format
769785
return math.floor(math.log(2**p, 10))
770786

771-
@classmethod
772-
def load_mixins(cls, *abstract_mixins) -> Self:
773-
"Load a list of mixins that implement the given abstract mixins"
774-
mixins = {m for m in cls.MIXINS if issubclass(m, abstract_mixins)}
775-
776-
class _DialectWithMixins(cls, *mixins, *abstract_mixins):
777-
pass
778-
779-
_DialectWithMixins.__name__ = cls.__name__
780-
return _DialectWithMixins()
781-
782-
783787
@property
784788
@abstractmethod
785789
def name(self) -> str:
@@ -822,7 +826,7 @@ def __getitem__(self, i):
822826
return self.rows[i]
823827

824828

825-
class Database(abc.ABC):
829+
class Database(abc.ABC, _RuntypeHackToFixCicularRefrencedDatabase):
826830
"""Base abstract class for databases.
827831
828832
Used for providing connection code and implementation specific SQL utilities.

data_diff/databases/bigquery.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,9 @@ def time_travel(
139139
)
140140

141141

142-
class Dialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue):
142+
class Dialect(
143+
BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue
144+
):
143145
name = "BigQuery"
144146
ROUNDS_ON_PREC_LOSS = False # Technically BigQuery doesn't allow implicit rounding or truncation
145147
TYPE_CLASSES = {
@@ -159,7 +161,6 @@ class Dialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Abstra
159161
}
160162
TYPE_ARRAY_RE = re.compile(r"ARRAY<(.+)>")
161163
TYPE_STRUCT_RE = re.compile(r"STRUCT<(.+)>")
162-
MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_TimeTravel, Mixin_RandomSample}
163164

164165
def random(self) -> str:
165166
return "RAND()"

data_diff/databases/clickhouse.py

-1
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,6 @@ class Dialect(BaseDialect, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, A
125125
"DateTime64": Timestamp,
126126
"Bool": Boolean,
127127
}
128-
MIXINS = {Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample}
129128

130129
def quote(self, s: str) -> str:
131130
return f'"{s}"'

data_diff/databases/databricks.py

-1
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,6 @@ class Dialect(BaseDialect, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, A
7979
# Boolean
8080
"BOOLEAN": Boolean,
8181
}
82-
MIXINS = {Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample}
8382

8483
def quote(self, s: str):
8584
return f"`{s}`"

data_diff/databases/duckdb.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -68,12 +68,13 @@ def random_sample_ratio_approx(self, tbl: ITable, ratio: float) -> ITable:
6868
return code("SELECT * FROM ({tbl}) USING SAMPLE {percent}%;", tbl=tbl, percent=int(100 * ratio))
6969

7070

71-
class Dialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue):
71+
class Dialect(
72+
BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue
73+
):
7274
name = "DuckDB"
7375
ROUNDS_ON_PREC_LOSS = False
7476
SUPPORTS_PRIMARY_KEY = True
7577
SUPPORTS_INDEXES = True
76-
MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample}
7778

7879
TYPE_CLASSES = {
7980
# Timestamps

data_diff/databases/mssql.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,15 @@ def md5_as_int(self, s: str) -> str:
5858
return f"convert(bigint, convert(varbinary, '0x' + RIGHT(CONVERT(NVARCHAR(32), HashBytes('MD5', {s}), 2), {CHECKSUM_HEXDIGITS}), 1))"
5959

6060

61-
class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue):
61+
class Dialect(
62+
BaseDialect,
63+
Mixin_Schema,
64+
Mixin_OptimizerHints,
65+
Mixin_MD5,
66+
Mixin_NormalizeValue,
67+
AbstractMixin_MD5,
68+
AbstractMixin_NormalizeValue,
69+
):
6270
name = "MsSQL"
6371
ROUNDS_ON_PREC_LOSS = True
6472
SUPPORTS_PRIMARY_KEY = True
@@ -98,8 +106,6 @@ class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints, Mixin_MD5, Mixin_
98106
"json": JSON,
99107
}
100108

101-
MIXINS = {Mixin_Schema, Mixin_NormalizeValue, Mixin_RandomSample}
102-
103109
def quote(self, s: str):
104110
return f"[{s}]"
105111

data_diff/databases/mysql.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,15 @@ def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
6060
return f"TRIM(CAST({value} AS char))"
6161

6262

63-
class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue):
63+
class Dialect(
64+
BaseDialect,
65+
Mixin_Schema,
66+
Mixin_OptimizerHints,
67+
Mixin_MD5,
68+
Mixin_NormalizeValue,
69+
AbstractMixin_MD5,
70+
AbstractMixin_NormalizeValue,
71+
):
6472
name = "MySQL"
6573
ROUNDS_ON_PREC_LOSS = True
6674
SUPPORTS_PRIMARY_KEY = True
@@ -91,7 +99,6 @@ class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints, Mixin_MD5, Mixin_
9199
# Boolean
92100
"boolean": Boolean,
93101
}
94-
MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample}
95102

96103
def quote(self, s: str):
97104
return f"`{s}`"

data_diff/databases/oracle.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,15 @@ def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable:
8080
)
8181

8282

83-
class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue):
83+
class Dialect(
84+
BaseDialect,
85+
Mixin_Schema,
86+
Mixin_OptimizerHints,
87+
Mixin_MD5,
88+
Mixin_NormalizeValue,
89+
AbstractMixin_MD5,
90+
AbstractMixin_NormalizeValue,
91+
):
8492
name = "Oracle"
8593
SUPPORTS_PRIMARY_KEY = True
8694
SUPPORTS_INDEXES = True
@@ -96,7 +104,6 @@ class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints, Mixin_MD5, Mixin_
96104
}
97105
ROUNDS_ON_PREC_LOSS = True
98106
PLACEHOLDER_TABLE = "DUAL"
99-
MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample}
100107

101108
def quote(self, s: str):
102109
return f'"{s}"'

data_diff/databases/postgresql.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,13 @@ def normalize_json(self, value: str, _coltype: JSON) -> str:
6060
return f"{value}::text"
6161

6262

63-
class PostgresqlDialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue):
63+
class PostgresqlDialect(
64+
BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue
65+
):
6466
name = "PostgreSQL"
6567
ROUNDS_ON_PREC_LOSS = True
6668
SUPPORTS_PRIMARY_KEY = True
6769
SUPPORTS_INDEXES = True
68-
MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample}
6970

7071
TYPE_CLASSES = {
7172
# Timestamps

data_diff/databases/presto.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,9 @@ def normalize_boolean(self, value: str, _coltype: Boolean) -> str:
7676
return self.to_string(f"cast ({value} as int)")
7777

7878

79-
class Dialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue):
79+
class Dialect(
80+
BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue
81+
):
8082
name = "Presto"
8183
ROUNDS_ON_PREC_LOSS = True
8284
TYPE_CLASSES = {
@@ -94,7 +96,6 @@ class Dialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Abstra
9496
# Boolean
9597
"boolean": Boolean,
9698
}
97-
MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample}
9899

99100
def explain_as_text(self, query: str) -> str:
100101
return f"EXPLAIN (FORMAT TEXT) {query}"

data_diff/databases/snowflake.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,9 @@ def time_travel(
104104
return code(f"{{table}} {at_or_before}({key} => {{value}})", table=table, value=value)
105105

106106

107-
class Dialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue):
107+
class Dialect(
108+
BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue
109+
):
108110
name = "Snowflake"
109111
ROUNDS_ON_PREC_LOSS = False
110112
TYPE_CLASSES = {
@@ -121,7 +123,6 @@ class Dialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Abstra
121123
# Boolean
122124
"BOOLEAN": Boolean,
123125
}
124-
MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_TimeTravel, Mixin_RandomSample}
125126

126127
def explain_as_text(self, query: str) -> str:
127128
return f"EXPLAIN USING TEXT {query}"

data_diff/databases/vertica.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,9 @@ def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable:
7878
)
7979

8080

81-
class Dialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue):
81+
class Dialect(
82+
BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue
83+
):
8284
name = "Vertica"
8385
ROUNDS_ON_PREC_LOSS = True
8486

@@ -96,7 +98,6 @@ class Dialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Abstra
9698
# Boolean
9799
"boolean": Boolean,
98100
}
99-
MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample}
100101

101102
def quote(self, s: str):
102103
return f'"{s}"'

0 commit comments

Comments
 (0)