Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit fef52d0

Browse files
authored
Merge pull request #680 from sar009/679
Adding support for custom bigquery client credentials
2 parents 74c53ff + 506c7ae commit fef52d0

File tree

2 files changed

+17
-13
lines changed

2 files changed

+17
-13
lines changed

data_diff/sqeleton/databases/_connect.py

+15-11
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def load_mixins(self, *abstract_mixins: AbstractMixin) -> Self:
106106
database_by_scheme = {k: db.load_mixins(*abstract_mixins) for k, db in self.database_by_scheme.items()}
107107
return type(self)(database_by_scheme)
108108

109-
def connect_to_uri(self, db_uri: str, thread_count: Optional[int] = 1) -> Database:
109+
def connect_to_uri(self, db_uri: str, thread_count: Optional[int] = 1, **kwargs) -> Database:
110110
"""Connect to the given database uri
111111
112112
thread_count determines the max number of worker threads per database,
@@ -149,7 +149,7 @@ def connect_to_uri(self, db_uri: str, thread_count: Optional[int] = 1) -> Databa
149149
conn_dict = config["database"][database]
150150
except KeyError:
151151
raise ValueError(f"Cannot find database config named '{database}'.")
152-
return self.connect_with_dict(conn_dict, thread_count)
152+
return self.connect_with_dict(conn_dict, thread_count, **kwargs)
153153

154154
try:
155155
matcher = self.match_uri_path[scheme]
@@ -174,7 +174,7 @@ def connect_to_uri(self, db_uri: str, thread_count: Optional[int] = 1) -> Databa
174174

175175
if scheme == "bigquery":
176176
kw["project"] = dsn.host
177-
return cls(**kw)
177+
return cls(**kw, **kwargs)
178178

179179
if scheme == "snowflake":
180180
kw["account"] = dsn.host
@@ -194,13 +194,13 @@ def connect_to_uri(self, db_uri: str, thread_count: Optional[int] = 1) -> Databa
194194
kw = {k: v for k, v in kw.items() if v is not None}
195195

196196
if issubclass(cls, ThreadedDatabase):
197-
db = cls(thread_count=thread_count, **kw)
197+
db = cls(thread_count=thread_count, **kw, **kwargs)
198198
else:
199-
db = cls(**kw)
199+
db = cls(**kw, **kwargs)
200200

201201
return self._connection_created(db)
202202

203-
def connect_with_dict(self, d, thread_count):
203+
def connect_with_dict(self, d, thread_count, **kwargs):
204204
d = dict(d)
205205
driver = d.pop("driver")
206206
try:
@@ -210,17 +210,19 @@ def connect_with_dict(self, d, thread_count):
210210

211211
cls = matcher.database_cls
212212
if issubclass(cls, ThreadedDatabase):
213-
db = cls(thread_count=thread_count, **d)
213+
db = cls(thread_count=thread_count, **d, **kwargs)
214214
else:
215-
db = cls(**d)
215+
db = cls(**d, **kwargs)
216216

217217
return self._connection_created(db)
218218

219219
def _connection_created(self, db):
220220
"Nop function to be overridden by subclasses."
221221
return db
222222

223-
def __call__(self, db_conf: Union[str, dict], thread_count: Optional[int] = 1, shared: bool = True) -> Database:
223+
def __call__(
224+
self, db_conf: Union[str, dict], thread_count: Optional[int] = 1, shared: bool = True, **kwargs
225+
) -> Database:
224226
"""Connect to a database using the given database configuration.
225227
226228
Configuration can be given either as a URI string, or as a dict of {option: value}.
@@ -234,6 +236,8 @@ def __call__(self, db_conf: Union[str, dict], thread_count: Optional[int] = 1, s
234236
db_conf (str | dict): The configuration for the database to connect. URI or dict.
235237
thread_count (int, optional): Size of the threadpool. Ignored by cloud databases. (default: 1)
236238
shared (bool): Whether to cache and return the same connection for the same db_conf. (default: True)
239+
bigquery_credentials (google.oauth2.credentials.Credentials): Custom Google OAuth2 credentials for BigQuery.
240+
(default: None)
237241
238242
Note: For non-cloud databases, a low thread-pool size may be a performance bottleneck.
239243
@@ -263,9 +267,9 @@ def __call__(self, db_conf: Union[str, dict], thread_count: Optional[int] = 1, s
263267
return conn
264268

265269
if isinstance(db_conf, str):
266-
conn = self.connect_to_uri(db_conf, thread_count)
270+
conn = self.connect_to_uri(db_conf, thread_count, **kwargs)
267271
elif isinstance(db_conf, dict):
268-
conn = self.connect_with_dict(db_conf, thread_count)
272+
conn = self.connect_with_dict(db_conf, thread_count, **kwargs)
269273
else:
270274
raise TypeError(f"db configuration must be a URI string or a dictionary. Instead got '{db_conf}'.")
271275

data_diff/sqeleton/databases/bigquery.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -210,8 +210,8 @@ class BigQuery(Database):
210210
CONNECT_URI_PARAMS = ["dataset"]
211211
dialect = Dialect()
212212

213-
def __init__(self, project, *, dataset, **kw):
214-
credentials = None
213+
def __init__(self, project, *, dataset, bigquery_credentials=None, **kw):
214+
credentials = bigquery_credentials
215215
bigquery = import_bigquery()
216216

217217
keyfile = kw.pop("keyfile", None)

0 commit comments

Comments
 (0)