Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 60c3d68

Browse files
dlawinSung Won Chung
authored and
Sung Won Chung
committed
Support MSSQL for cross-database diffs
1 parent 91fe04a commit 60c3d68

18 files changed

+364
-47
lines changed

data_diff/databases/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,6 @@
1212
from .clickhouse import Clickhouse
1313
from .vertica import Vertica
1414
from .duckdb import DuckDB
15+
from .mssql import MsSql
1516

1617
from ._connect import connect

data_diff/databases/_connect.py

+2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from .clickhouse import Clickhouse
1515
from .vertica import Vertica
1616
from .duckdb import DuckDB
17+
from .mssql import MsSql
1718

1819

1920
DATABASE_BY_SCHEME = {
@@ -29,6 +30,7 @@
2930
"trino": Trino,
3031
"clickhouse": Clickhouse,
3132
"vertica": Vertica,
33+
"mssql": MsSql
3234
}
3335

3436

data_diff/databases/mssql.py

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from data_diff.sqeleton.databases import mssql
2+
from .base import DatadiffDialect
3+
4+
5+
class Dialect(mssql.Dialect, mssql.Mixin_MD5, mssql.Mixin_NormalizeValue, DatadiffDialect):
6+
pass
7+
8+
9+
class MsSql(mssql.MsSQL):
10+
dialect = Dialect()

data_diff/joindiff_tables.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from runtype import dataclass
1212

13-
from data_diff.sqeleton.databases import Database, MySQL, BigQuery, Presto, Oracle, Snowflake, DbPath
13+
from data_diff.sqeleton.databases import Database, MsSQL, MySQL, BigQuery, Presto, Oracle, Snowflake, DbPath
1414
from data_diff.sqeleton.abcs import NumericType
1515
from data_diff.sqeleton.queries import (
1616
table,
@@ -25,9 +25,10 @@
2525
leftjoin,
2626
rightjoin,
2727
this,
28+
when,
2829
Compiler,
2930
)
30-
from data_diff.sqeleton.queries.ast_classes import Concat, Count, Expr, Random, TablePath, Code, ITable
31+
from data_diff.sqeleton.queries.ast_classes import Concat, Count, Expr, Func, Random, TablePath, Code, ITable
3132
from data_diff.sqeleton.queries.extras import NormalizeAsString
3233

3334
from .info_tree import InfoTree
@@ -82,6 +83,12 @@ def _outerjoin(db: Database, a: ITable, b: ITable, keys1: List[str], keys2: List
8283

8384
is_exclusive_a = and_(b[k] == None for k in keys2)
8485
is_exclusive_b = and_(a[k] == None for k in keys1)
86+
87+
if isinstance(db, MsSQL):
88+
# There is no "IS NULL" or "ISNULL()" as expressions, only as conditions.
89+
is_exclusive_a = when(is_exclusive_a).then(1).else_(0)
90+
is_exclusive_b = when(is_exclusive_b).then(1).else_(0)
91+
8592
if isinstance(db, Oracle):
8693
is_exclusive_a = bool_to_int(is_exclusive_a)
8794
is_exclusive_b = bool_to_int(is_exclusive_b)
@@ -342,7 +349,7 @@ def _count_diff_per_column(self, db, diff_rows, cols, is_diff_cols):
342349
self.stats["diff_counts"] = diff_counts
343350

344351
def _sample_and_count_exclusive(self, db, diff_rows, a_cols, b_cols):
345-
if isinstance(db, Oracle):
352+
if isinstance(db, (Oracle, MsSQL)):
346353
exclusive_rows_query = diff_rows.where((this.is_exclusive_a == 1) | (this.is_exclusive_b == 1))
347354
else:
348355
exclusive_rows_query = diff_rows.where(this.is_exclusive_a | this.is_exclusive_b)

data_diff/sqeleton/abcs/database_types.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,15 @@ def current_timestamp(self) -> str:
216216
"Provide SQL for returning the current timestamp, aka now"
217217

218218
@abstractmethod
219-
def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None):
219+
def current_database(self) -> str:
220+
"Provide SQL for returning the current default database."
221+
222+
@abstractmethod
223+
def current_schema(self) -> str:
224+
"Provide SQL for returning the current default schema."
225+
226+
@abstractmethod
227+
def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None, has_order_by: Optional[bool] = None) -> str:
220228
"Provide SQL fragment for limit and offset inside a select"
221229

222230
@abstractmethod

data_diff/sqeleton/databases/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,6 @@
1414
from .clickhouse import Clickhouse
1515
from .vertica import Vertica
1616
from .duckdb import DuckDB
17+
from .mssql import MsSQL
1718

1819
connect = Connect()

data_diff/sqeleton/databases/_connect.py

+2
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from .clickhouse import Clickhouse
2222
from .vertica import Vertica
2323
from .duckdb import DuckDB
24+
from .mssql import MsSQL
2425

2526

2627
@dataclass
@@ -86,6 +87,7 @@ def match_path(self, dsn):
8687
"trino": Trino,
8788
"clickhouse": Clickhouse,
8889
"vertica": Vertica,
90+
"mssql": MsSQL
8991
}
9092

9193

data_diff/sqeleton/databases/base.py

+15-3
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ class BaseDialect(AbstractDialect):
155155

156156
PLACEHOLDER_TABLE = None # Used for Oracle
157157

158-
def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None):
158+
def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None, has_order_by: Optional[bool] = None) -> str:
159159
if offset:
160160
raise NotImplementedError("No support for OFFSET in query")
161161

@@ -182,6 +182,12 @@ def random(self) -> str:
182182
def current_timestamp(self) -> str:
183183
return "current_timestamp()"
184184

185+
def current_database(self) -> str:
186+
return "current_database()"
187+
188+
def current_schema(self) -> str:
189+
return "current_schema()"
190+
185191
def explain_as_text(self, query: str) -> str:
186192
return f"EXPLAIN {query}"
187193

@@ -518,7 +524,12 @@ def _query_cursor(self, c, sql_code: str) -> QueryResult:
518524
c.execute(sql_code)
519525
if sql_code.lower().startswith(("select", "explain", "show")):
520526
columns = [col[0] for col in c.description]
521-
return QueryResult(c.fetchall(), columns)
527+
528+
# TODO FIXME pyodbc.Row seems to be causing a pydantic error
529+
# [ConstantTable] Attribute 'rows' expected value of type Sequence[Sequence[Any]]
530+
fetched = c.fetchall()
531+
result = QueryResult(fetched, columns)
532+
return result
522533
except Exception as _e:
523534
# logger.exception(e)
524535
# logger.error(f'Caused by SQL: {sql_code}')
@@ -590,7 +601,8 @@ def is_autocommit(self) -> bool:
590601
return False
591602

592603

593-
CHECKSUM_HEXDIGITS = 15 # Must be 15 or lower, otherwise SUM() overflows
604+
# TODO FYI mssql md5_as_int currently requires this to be reduced
605+
CHECKSUM_HEXDIGITS = 14 # Must be 15 or lower, otherwise SUM() overflows
594606
MD5_HEXDIGITS = 32
595607

596608
_CHECKSUM_BITSIZE = CHECKSUM_HEXDIGITS << 2

0 commit comments

Comments
 (0)