|
1 | 1 | import logging
|
| 2 | +from typing import Hashable, MutableMapping, Type, Optional, Union, Dict |
| 3 | +from itertools import zip_longest |
| 4 | +from contextlib import suppress |
| 5 | +import weakref |
| 6 | +import dsnparse |
| 7 | +import toml |
2 | 8 |
|
3 |
| -from data_diff.sqeleton.databases import Connect |
| 9 | +from runtype import dataclass |
| 10 | +from typing_extensions import Self |
4 | 11 |
|
| 12 | +from data_diff.databases.base import Database, ThreadedDatabase |
5 | 13 | from data_diff.databases.postgresql import PostgreSQL
|
6 | 14 | from data_diff.databases.mysql import MySQL
|
7 | 15 | from data_diff.databases.oracle import Oracle
|
|
14 | 22 | from data_diff.databases.clickhouse import Clickhouse
|
15 | 23 | from data_diff.databases.vertica import Vertica
|
16 | 24 | from data_diff.databases.duckdb import DuckDB
|
17 |
| -from data_diff.databases.mssql import MsSql |
| 25 | +from data_diff.databases.mssql import MsSQL |
| 26 | + |
| 27 | + |
| 28 | +@dataclass |
| 29 | +class MatchUriPath: |
| 30 | + database_cls: Type[Database] |
| 31 | + |
| 32 | + def match_path(self, dsn): |
| 33 | + help_str = self.database_cls.CONNECT_URI_HELP |
| 34 | + params = self.database_cls.CONNECT_URI_PARAMS |
| 35 | + kwparams = self.database_cls.CONNECT_URI_KWPARAMS |
| 36 | + |
| 37 | + dsn_dict = dict(dsn.query) |
| 38 | + matches = {} |
| 39 | + for param, arg in zip_longest(params, dsn.paths): |
| 40 | + if param is None: |
| 41 | + raise ValueError(f"Too many parts to path. Expected format: {help_str}") |
| 42 | + |
| 43 | + optional = param.endswith("?") |
| 44 | + param = param.rstrip("?") |
| 45 | + |
| 46 | + if arg is None: |
| 47 | + try: |
| 48 | + arg = dsn_dict.pop(param) |
| 49 | + except KeyError: |
| 50 | + if not optional: |
| 51 | + raise ValueError(f"URI must specify '{param}'. Expected format: {help_str}") |
| 52 | + |
| 53 | + arg = None |
| 54 | + |
| 55 | + assert param and param not in matches |
| 56 | + matches[param] = arg |
| 57 | + |
| 58 | + for param in kwparams: |
| 59 | + try: |
| 60 | + arg = dsn_dict.pop(param) |
| 61 | + except KeyError: |
| 62 | + raise ValueError(f"URI must specify '{param}'. Expected format: {help_str}") |
| 63 | + |
| 64 | + assert param and arg and param not in matches, (param, arg, matches.keys()) |
| 65 | + matches[param] = arg |
| 66 | + |
| 67 | + for param, value in dsn_dict.items(): |
| 68 | + if param in matches: |
| 69 | + raise ValueError( |
| 70 | + f"Parameter '{param}' already provided as positional argument. Expected format: {help_str}" |
| 71 | + ) |
| 72 | + |
| 73 | + matches[param] = value |
| 74 | + |
| 75 | + return matches |
18 | 76 |
|
19 | 77 |
|
20 | 78 | DATABASE_BY_SCHEME = {
|
|
30 | 88 | "trino": Trino,
|
31 | 89 | "clickhouse": Clickhouse,
|
32 | 90 | "vertica": Vertica,
|
33 |
| - "mssql": MsSql, |
| 91 | + "mssql": MsSQL, |
34 | 92 | }
|
35 | 93 |
|
36 | 94 |
|
| 95 | +class Connect: |
| 96 | + """Provides methods for connecting to a supported database using a URL or connection dict.""" |
| 97 | + conn_cache: MutableMapping[Hashable, Database] |
| 98 | + |
| 99 | + def __init__(self, database_by_scheme: Dict[str, Database] = DATABASE_BY_SCHEME): |
| 100 | + self.database_by_scheme = database_by_scheme |
| 101 | + self.match_uri_path = {name: MatchUriPath(cls) for name, cls in database_by_scheme.items()} |
| 102 | + self.conn_cache = weakref.WeakValueDictionary() |
| 103 | + |
| 104 | + def for_databases(self, *dbs) -> Self: |
| 105 | + database_by_scheme = {k: db for k, db in self.database_by_scheme.items() if k in dbs} |
| 106 | + return type(self)(database_by_scheme) |
| 107 | + |
| 108 | + def connect_to_uri(self, db_uri: str, thread_count: Optional[int] = 1, **kwargs) -> Database: |
| 109 | + """Connect to the given database uri |
| 110 | +
|
| 111 | + thread_count determines the max number of worker threads per database, |
| 112 | + if relevant. None means no limit. |
| 113 | +
|
| 114 | + Parameters: |
| 115 | + db_uri (str): The URI for the database to connect |
| 116 | + thread_count (int, optional): Size of the threadpool. Ignored by cloud databases. (default: 1) |
| 117 | +
|
| 118 | + Note: For non-cloud databases, a low thread-pool size may be a performance bottleneck. |
| 119 | +
|
| 120 | + Supported schemes: |
| 121 | + - postgresql |
| 122 | + - mysql |
| 123 | + - oracle |
| 124 | + - snowflake |
| 125 | + - bigquery |
| 126 | + - redshift |
| 127 | + - presto |
| 128 | + - databricks |
| 129 | + - trino |
| 130 | + - clickhouse |
| 131 | + - vertica |
| 132 | + - duckdb |
| 133 | + """ |
| 134 | + |
| 135 | + dsn = dsnparse.parse(db_uri) |
| 136 | + if len(dsn.schemes) > 1: |
| 137 | + raise NotImplementedError("No support for multiple schemes") |
| 138 | + (scheme,) = dsn.schemes |
| 139 | + |
| 140 | + if scheme == "toml": |
| 141 | + toml_path = dsn.path or dsn.host |
| 142 | + database = dsn.fragment |
| 143 | + if not database: |
| 144 | + raise ValueError("Must specify a database name, e.g. 'toml://path#database'. ") |
| 145 | + with open(toml_path) as f: |
| 146 | + config = toml.load(f) |
| 147 | + try: |
| 148 | + conn_dict = config["database"][database] |
| 149 | + except KeyError: |
| 150 | + raise ValueError(f"Cannot find database config named '{database}'.") |
| 151 | + return self.connect_with_dict(conn_dict, thread_count, **kwargs) |
| 152 | + |
| 153 | + try: |
| 154 | + matcher = self.match_uri_path[scheme] |
| 155 | + except KeyError: |
| 156 | + raise NotImplementedError(f"Scheme '{scheme}' currently not supported") |
| 157 | + |
| 158 | + cls = matcher.database_cls |
| 159 | + |
| 160 | + if scheme == "databricks": |
| 161 | + assert not dsn.user |
| 162 | + kw = {} |
| 163 | + kw["access_token"] = dsn.password |
| 164 | + kw["http_path"] = dsn.path |
| 165 | + kw["server_hostname"] = dsn.host |
| 166 | + kw.update(dsn.query) |
| 167 | + elif scheme == "duckdb": |
| 168 | + kw = {} |
| 169 | + kw["filepath"] = dsn.dbname |
| 170 | + kw["dbname"] = dsn.user |
| 171 | + else: |
| 172 | + kw = matcher.match_path(dsn) |
| 173 | + |
| 174 | + if scheme == "bigquery": |
| 175 | + kw["project"] = dsn.host |
| 176 | + return cls(**kw, **kwargs) |
| 177 | + |
| 178 | + if scheme == "snowflake": |
| 179 | + kw["account"] = dsn.host |
| 180 | + assert not dsn.port |
| 181 | + kw["user"] = dsn.user |
| 182 | + kw["password"] = dsn.password |
| 183 | + else: |
| 184 | + if scheme == "oracle": |
| 185 | + kw["host"] = dsn.hostloc |
| 186 | + else: |
| 187 | + kw["host"] = dsn.host |
| 188 | + kw["port"] = dsn.port |
| 189 | + kw["user"] = dsn.user |
| 190 | + if dsn.password: |
| 191 | + kw["password"] = dsn.password |
| 192 | + |
| 193 | + kw = {k: v for k, v in kw.items() if v is not None} |
| 194 | + |
| 195 | + if issubclass(cls, ThreadedDatabase): |
| 196 | + db = cls(thread_count=thread_count, **kw, **kwargs) |
| 197 | + else: |
| 198 | + db = cls(**kw, **kwargs) |
| 199 | + |
| 200 | + return self._connection_created(db) |
| 201 | + |
| 202 | + def connect_with_dict(self, d, thread_count, **kwargs): |
| 203 | + d = dict(d) |
| 204 | + driver = d.pop("driver") |
| 205 | + try: |
| 206 | + matcher = self.match_uri_path[driver] |
| 207 | + except KeyError: |
| 208 | + raise NotImplementedError(f"Driver '{driver}' currently not supported") |
| 209 | + |
| 210 | + cls = matcher.database_cls |
| 211 | + if issubclass(cls, ThreadedDatabase): |
| 212 | + db = cls(thread_count=thread_count, **d, **kwargs) |
| 213 | + else: |
| 214 | + db = cls(**d, **kwargs) |
| 215 | + |
| 216 | + return self._connection_created(db) |
| 217 | + |
| 218 | + def _connection_created(self, db): |
| 219 | + "Nop function to be overridden by subclasses." |
| 220 | + return db |
| 221 | + |
| 222 | + def __call__( |
| 223 | + self, db_conf: Union[str, dict], thread_count: Optional[int] = 1, shared: bool = True, **kwargs |
| 224 | + ) -> Database: |
| 225 | + """Connect to a database using the given database configuration. |
| 226 | +
|
| 227 | + Configuration can be given either as a URI string, or as a dict of {option: value}. |
| 228 | +
|
| 229 | + The dictionary configuration uses the same keys as the TOML 'database' definition given with --conf. |
| 230 | +
|
| 231 | + thread_count determines the max number of worker threads per database, |
| 232 | + if relevant. None means no limit. |
| 233 | +
|
| 234 | + Parameters: |
| 235 | + db_conf (str | dict): The configuration for the database to connect. URI or dict. |
| 236 | + thread_count (int, optional): Size of the threadpool. Ignored by cloud databases. (default: 1) |
| 237 | + shared (bool): Whether to cache and return the same connection for the same db_conf. (default: True) |
| 238 | + bigquery_credentials (google.oauth2.credentials.Credentials): Custom Google oAuth2 credential for BigQuery. |
| 239 | + (default: None) |
| 240 | +
|
| 241 | + Note: For non-cloud databases, a low thread-pool size may be a performance bottleneck. |
| 242 | +
|
| 243 | + Supported drivers: |
| 244 | + - postgresql |
| 245 | + - mysql |
| 246 | + - oracle |
| 247 | + - snowflake |
| 248 | + - bigquery |
| 249 | + - redshift |
| 250 | + - presto |
| 251 | + - databricks |
| 252 | + - trino |
| 253 | + - clickhouse |
| 254 | + - vertica |
| 255 | +
|
| 256 | + Example: |
| 257 | + >>> connect("mysql://localhost/db") |
| 258 | + <data_diff.databases.mysql.MySQL object at ...> |
| 259 | + >>> connect({"driver": "mysql", "host": "localhost", "database": "db"}) |
| 260 | + <data_diff.databases.mysql.MySQL object at ...> |
| 261 | + """ |
| 262 | + cache_key = self.__make_cache_key(db_conf) |
| 263 | + if shared: |
| 264 | + with suppress(KeyError): |
| 265 | + conn = self.conn_cache[cache_key] |
| 266 | + if not conn.is_closed: |
| 267 | + return conn |
| 268 | + |
| 269 | + if isinstance(db_conf, str): |
| 270 | + conn = self.connect_to_uri(db_conf, thread_count, **kwargs) |
| 271 | + elif isinstance(db_conf, dict): |
| 272 | + conn = self.connect_with_dict(db_conf, thread_count, **kwargs) |
| 273 | + else: |
| 274 | + raise TypeError(f"db configuration must be a URI string or a dictionary. Instead got '{db_conf}'.") |
| 275 | + |
| 276 | + if shared: |
| 277 | + self.conn_cache[cache_key] = conn |
| 278 | + return conn |
| 279 | + |
| 280 | + def __make_cache_key(self, db_conf: Union[str, dict]) -> Hashable: |
| 281 | + if isinstance(db_conf, dict): |
| 282 | + return tuple(db_conf.items()) |
| 283 | + return db_conf |
| 284 | + |
| 285 | + |
37 | 286 | class Connect_SetUTC(Connect):
|
38 | 287 | """Provides methods for connecting to a supported database using a URL or connection dict.
|
39 | 288 |
|
|
0 commit comments