Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 0a7d2cb

Browse files
authored
Merge pull request #710 from datafold/simplify-squash-sqeleton
Squash sqeleton into data_diff
2 parents 818fedb + 871c201 commit 0a7d2cb

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

69 files changed

+3641
-3934
lines changed

data_diff/__init__.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
from typing import Sequence, Tuple, Iterator, Optional, Union
22

3-
from data_diff.sqeleton.abcs import DbTime, DbPath
4-
3+
from data_diff.abcs.database_types import DbTime, DbPath
54
from data_diff.tracking import disable_tracking
6-
from data_diff.databases import connect
5+
from data_diff.databases._connect import connect
76
from data_diff.diff_tables import Algorithm
87
from data_diff.hashdiff_tables import HashDiffer, DEFAULT_BISECTION_THRESHOLD, DEFAULT_BISECTION_FACTOR
98
from data_diff.joindiff_tables import JoinDiffer, TABLE_WRITE_LIMIT

data_diff/__main__.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,16 @@
1212
from rich.logging import RichHandler
1313
import click
1414

15-
from data_diff.sqeleton.schema import create_schema
16-
from data_diff.sqeleton.queries.api import current_timestamp
15+
from data_diff.schema import create_schema
16+
from data_diff.queries.api import current_timestamp
1717

1818
from data_diff.dbt import dbt_diff
1919
from data_diff.utils import eval_name_template, remove_password_from_url, safezip, match_like, LogStatusHandler
2020
from data_diff.diff_tables import Algorithm
2121
from data_diff.hashdiff_tables import HashDiffer, DEFAULT_BISECTION_THRESHOLD, DEFAULT_BISECTION_FACTOR
2222
from data_diff.joindiff_tables import TABLE_WRITE_LIMIT, JoinDiffer
2323
from data_diff.table_segment import TableSegment
24-
from data_diff.databases import connect
24+
from data_diff.databases._connect import connect
2525
from data_diff.parse_time import parse_time_before, UNITS_STR, ParseError
2626
from data_diff.config import apply_config_from_file
2727
from data_diff.tracking import disable_tracking, set_entrypoint_name

data_diff/abcs/__init__.py

Whitespace-only changes.
File renamed without changes.

data_diff/sqeleton/abcs/database_types.py renamed to data_diff/abcs/database_types.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from runtype import dataclass
77
from typing_extensions import Self
88

9-
from data_diff.sqeleton.utils import ArithAlphanumeric, ArithUUID, Unknown
9+
from data_diff.utils import ArithAlphanumeric, ArithUUID, Unknown
1010

1111

1212
DbPath = Tuple[str, ...]

data_diff/sqeleton/abcs/mixins.py renamed to data_diff/abcs/mixins.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from abc import ABC, abstractmethod
2-
from data_diff.sqeleton.abcs.database_types import (
2+
from data_diff.abcs.database_types import (
33
Array,
44
TemporalType,
55
FractionalType,
@@ -10,7 +10,7 @@
1010
JSON,
1111
Struct,
1212
)
13-
from data_diff.sqeleton.abcs.compiler import Compilable
13+
from data_diff.abcs.compiler import Compilable
1414

1515

1616
class AbstractMixin(ABC):

data_diff/sqeleton/bound_exprs.py renamed to data_diff/bound_exprs.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@
77
from runtype import dataclass
88
from typing_extensions import Self
99

10-
from data_diff.sqeleton.abcs import AbstractDatabase, AbstractCompiler
11-
from data_diff.sqeleton.queries.ast_classes import ExprNode, ITable, TablePath, Compilable
12-
from data_diff.sqeleton.queries.api import table
13-
from data_diff.sqeleton.schema import create_schema
10+
from data_diff.abcs.database_types import AbstractDatabase
11+
from data_diff.abcs.compiler import AbstractCompiler
12+
from data_diff.queries.ast_classes import ExprNode, TablePath, Compilable
13+
from data_diff.queries.api import table
14+
from data_diff.schema import create_schema
1415

1516

1617
@dataclass
@@ -80,8 +81,8 @@ def bound_table(database: AbstractDatabase, table_path: Union[TablePath, str, tu
8081
# Database.table = bound_table
8182

8283
# def test():
83-
# from data_diff.sqeleton. import connect
84-
# from data_diff.sqeleton.queries.api import table
84+
# from data_diff import connect
85+
# from data_diff.queries.api import table
8586
# d = connect("mysql://erez:qweqwe123@localhost/erez")
8687
# t = table(('Rating',))
8788

data_diff/databases/__init__.py

+16-17
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,16 @@
1-
from data_diff.sqeleton.databases import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, QueryError, ConnectError
2-
3-
from data_diff.databases.postgresql import PostgreSQL
4-
from data_diff.databases.mysql import MySQL
5-
from data_diff.databases.oracle import Oracle
6-
from data_diff.databases.snowflake import Snowflake
7-
from data_diff.databases.bigquery import BigQuery
8-
from data_diff.databases.redshift import Redshift
9-
from data_diff.databases.presto import Presto
10-
from data_diff.databases.databricks import Databricks
11-
from data_diff.databases.trino import Trino
12-
from data_diff.databases.clickhouse import Clickhouse
13-
from data_diff.databases.vertica import Vertica
14-
from data_diff.databases.duckdb import DuckDB
15-
from data_diff.databases.mssql import MsSql
16-
17-
from data_diff.databases._connect import connect
1+
from data_diff.databases.base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, QueryError, ConnectError, BaseDialect, Database
2+
from data_diff.databases._connect import connect as connect
3+
from data_diff.databases._connect import Connect as Connect
4+
from data_diff.databases.postgresql import PostgreSQL as PostgreSQL
5+
from data_diff.databases.mysql import MySQL as MySQL
6+
from data_diff.databases.oracle import Oracle as Oracle
7+
from data_diff.databases.snowflake import Snowflake as Snowflake
8+
from data_diff.databases.bigquery import BigQuery as BigQuery
9+
from data_diff.databases.redshift import Redshift as Redshift
10+
from data_diff.databases.presto import Presto as Presto
11+
from data_diff.databases.databricks import Databricks as Databricks
12+
from data_diff.databases.trino import Trino as Trino
13+
from data_diff.databases.clickhouse import Clickhouse as Clickhouse
14+
from data_diff.databases.vertica import Vertica as Vertica
15+
from data_diff.databases.duckdb import DuckDB as DuckDB
16+
from data_diff.databases.mssql import MsSQL as MsSQL

data_diff/databases/_connect.py

+252-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,15 @@
11
import logging
2+
from typing import Hashable, MutableMapping, Type, Optional, Union, Dict
3+
from itertools import zip_longest
4+
from contextlib import suppress
5+
import weakref
6+
import dsnparse
7+
import toml
28

3-
from data_diff.sqeleton.databases import Connect
9+
from runtype import dataclass
10+
from typing_extensions import Self
411

12+
from data_diff.databases.base import Database, ThreadedDatabase
513
from data_diff.databases.postgresql import PostgreSQL
614
from data_diff.databases.mysql import MySQL
715
from data_diff.databases.oracle import Oracle
@@ -14,7 +22,57 @@
1422
from data_diff.databases.clickhouse import Clickhouse
1523
from data_diff.databases.vertica import Vertica
1624
from data_diff.databases.duckdb import DuckDB
17-
from data_diff.databases.mssql import MsSql
25+
from data_diff.databases.mssql import MsSQL
26+
27+
28+
@dataclass
29+
class MatchUriPath:
30+
database_cls: Type[Database]
31+
32+
def match_path(self, dsn):
33+
help_str = self.database_cls.CONNECT_URI_HELP
34+
params = self.database_cls.CONNECT_URI_PARAMS
35+
kwparams = self.database_cls.CONNECT_URI_KWPARAMS
36+
37+
dsn_dict = dict(dsn.query)
38+
matches = {}
39+
for param, arg in zip_longest(params, dsn.paths):
40+
if param is None:
41+
raise ValueError(f"Too many parts to path. Expected format: {help_str}")
42+
43+
optional = param.endswith("?")
44+
param = param.rstrip("?")
45+
46+
if arg is None:
47+
try:
48+
arg = dsn_dict.pop(param)
49+
except KeyError:
50+
if not optional:
51+
raise ValueError(f"URI must specify '{param}'. Expected format: {help_str}")
52+
53+
arg = None
54+
55+
assert param and param not in matches
56+
matches[param] = arg
57+
58+
for param in kwparams:
59+
try:
60+
arg = dsn_dict.pop(param)
61+
except KeyError:
62+
raise ValueError(f"URI must specify '{param}'. Expected format: {help_str}")
63+
64+
assert param and arg and param not in matches, (param, arg, matches.keys())
65+
matches[param] = arg
66+
67+
for param, value in dsn_dict.items():
68+
if param in matches:
69+
raise ValueError(
70+
f"Parameter '{param}' already provided as positional argument. Expected format: {help_str}"
71+
)
72+
73+
matches[param] = value
74+
75+
return matches
1876

1977

2078
DATABASE_BY_SCHEME = {
@@ -30,10 +88,201 @@
3088
"trino": Trino,
3189
"clickhouse": Clickhouse,
3290
"vertica": Vertica,
33-
"mssql": MsSql,
91+
"mssql": MsSQL,
3492
}
3593

3694

95+
class Connect:
96+
"""Provides methods for connecting to a supported database using a URL or connection dict."""
97+
conn_cache: MutableMapping[Hashable, Database]
98+
99+
def __init__(self, database_by_scheme: Dict[str, Database] = DATABASE_BY_SCHEME):
100+
self.database_by_scheme = database_by_scheme
101+
self.match_uri_path = {name: MatchUriPath(cls) for name, cls in database_by_scheme.items()}
102+
self.conn_cache = weakref.WeakValueDictionary()
103+
104+
def for_databases(self, *dbs) -> Self:
105+
database_by_scheme = {k: db for k, db in self.database_by_scheme.items() if k in dbs}
106+
return type(self)(database_by_scheme)
107+
108+
def connect_to_uri(self, db_uri: str, thread_count: Optional[int] = 1, **kwargs) -> Database:
109+
"""Connect to the given database uri
110+
111+
thread_count determines the max number of worker threads per database,
112+
if relevant. None means no limit.
113+
114+
Parameters:
115+
db_uri (str): The URI for the database to connect
116+
thread_count (int, optional): Size of the threadpool. Ignored by cloud databases. (default: 1)
117+
118+
Note: For non-cloud databases, a low thread-pool size may be a performance bottleneck.
119+
120+
Supported schemes:
121+
- postgresql
122+
- mysql
123+
- oracle
124+
- snowflake
125+
- bigquery
126+
- redshift
127+
- presto
128+
- databricks
129+
- trino
130+
- clickhouse
131+
- vertica
132+
- duckdb
133+
"""
134+
135+
dsn = dsnparse.parse(db_uri)
136+
if len(dsn.schemes) > 1:
137+
raise NotImplementedError("No support for multiple schemes")
138+
(scheme,) = dsn.schemes
139+
140+
if scheme == "toml":
141+
toml_path = dsn.path or dsn.host
142+
database = dsn.fragment
143+
if not database:
144+
raise ValueError("Must specify a database name, e.g. 'toml://path#database'. ")
145+
with open(toml_path) as f:
146+
config = toml.load(f)
147+
try:
148+
conn_dict = config["database"][database]
149+
except KeyError:
150+
raise ValueError(f"Cannot find database config named '{database}'.")
151+
return self.connect_with_dict(conn_dict, thread_count, **kwargs)
152+
153+
try:
154+
matcher = self.match_uri_path[scheme]
155+
except KeyError:
156+
raise NotImplementedError(f"Scheme '{scheme}' currently not supported")
157+
158+
cls = matcher.database_cls
159+
160+
if scheme == "databricks":
161+
assert not dsn.user
162+
kw = {}
163+
kw["access_token"] = dsn.password
164+
kw["http_path"] = dsn.path
165+
kw["server_hostname"] = dsn.host
166+
kw.update(dsn.query)
167+
elif scheme == "duckdb":
168+
kw = {}
169+
kw["filepath"] = dsn.dbname
170+
kw["dbname"] = dsn.user
171+
else:
172+
kw = matcher.match_path(dsn)
173+
174+
if scheme == "bigquery":
175+
kw["project"] = dsn.host
176+
return cls(**kw, **kwargs)
177+
178+
if scheme == "snowflake":
179+
kw["account"] = dsn.host
180+
assert not dsn.port
181+
kw["user"] = dsn.user
182+
kw["password"] = dsn.password
183+
else:
184+
if scheme == "oracle":
185+
kw["host"] = dsn.hostloc
186+
else:
187+
kw["host"] = dsn.host
188+
kw["port"] = dsn.port
189+
kw["user"] = dsn.user
190+
if dsn.password:
191+
kw["password"] = dsn.password
192+
193+
kw = {k: v for k, v in kw.items() if v is not None}
194+
195+
if issubclass(cls, ThreadedDatabase):
196+
db = cls(thread_count=thread_count, **kw, **kwargs)
197+
else:
198+
db = cls(**kw, **kwargs)
199+
200+
return self._connection_created(db)
201+
202+
def connect_with_dict(self, d, thread_count, **kwargs):
203+
d = dict(d)
204+
driver = d.pop("driver")
205+
try:
206+
matcher = self.match_uri_path[driver]
207+
except KeyError:
208+
raise NotImplementedError(f"Driver '{driver}' currently not supported")
209+
210+
cls = matcher.database_cls
211+
if issubclass(cls, ThreadedDatabase):
212+
db = cls(thread_count=thread_count, **d, **kwargs)
213+
else:
214+
db = cls(**d, **kwargs)
215+
216+
return self._connection_created(db)
217+
218+
def _connection_created(self, db):
219+
"Nop function to be overridden by subclasses."
220+
return db
221+
222+
def __call__(
223+
self, db_conf: Union[str, dict], thread_count: Optional[int] = 1, shared: bool = True, **kwargs
224+
) -> Database:
225+
"""Connect to a database using the given database configuration.
226+
227+
Configuration can be given either as a URI string, or as a dict of {option: value}.
228+
229+
The dictionary configuration uses the same keys as the TOML 'database' definition given with --conf.
230+
231+
thread_count determines the max number of worker threads per database,
232+
if relevant. None means no limit.
233+
234+
Parameters:
235+
db_conf (str | dict): The configuration for the database to connect. URI or dict.
236+
thread_count (int, optional): Size of the threadpool. Ignored by cloud databases. (default: 1)
237+
shared (bool): Whether to cache and return the same connection for the same db_conf. (default: True)
238+
bigquery_credentials (google.oauth2.credentials.Credentials): Custom Google oAuth2 credential for BigQuery.
239+
(default: None)
240+
241+
Note: For non-cloud databases, a low thread-pool size may be a performance bottleneck.
242+
243+
Supported drivers:
244+
- postgresql
245+
- mysql
246+
- oracle
247+
- snowflake
248+
- bigquery
249+
- redshift
250+
- presto
251+
- databricks
252+
- trino
253+
- clickhouse
254+
- vertica
255+
256+
Example:
257+
>>> connect("mysql://localhost/db")
258+
<data_diff.databases.mysql.MySQL object at ...>
259+
>>> connect({"driver": "mysql", "host": "localhost", "database": "db"})
260+
<data_diff.databases.mysql.MySQL object at ...>
261+
"""
262+
cache_key = self.__make_cache_key(db_conf)
263+
if shared:
264+
with suppress(KeyError):
265+
conn = self.conn_cache[cache_key]
266+
if not conn.is_closed:
267+
return conn
268+
269+
if isinstance(db_conf, str):
270+
conn = self.connect_to_uri(db_conf, thread_count, **kwargs)
271+
elif isinstance(db_conf, dict):
272+
conn = self.connect_with_dict(db_conf, thread_count, **kwargs)
273+
else:
274+
raise TypeError(f"db configuration must be a URI string or a dictionary. Instead got '{db_conf}'.")
275+
276+
if shared:
277+
self.conn_cache[cache_key] = conn
278+
return conn
279+
280+
def __make_cache_key(self, db_conf: Union[str, dict]) -> Hashable:
281+
if isinstance(db_conf, dict):
282+
return tuple(db_conf.items())
283+
return db_conf
284+
285+
37286
class Connect_SetUTC(Connect):
38287
"""Provides methods for connecting to a supported database using a URL or connection dict.
39288

0 commit comments

Comments
 (0)