@@ -106,7 +106,7 @@ def load_mixins(self, *abstract_mixins: AbstractMixin) -> Self:
106
106
database_by_scheme = {k : db .load_mixins (* abstract_mixins ) for k , db in self .database_by_scheme .items ()}
107
107
return type (self )(database_by_scheme )
108
108
109
- def connect_to_uri (self , db_uri : str , thread_count : Optional [int ] = 1 ) -> Database :
109
+ def connect_to_uri (self , db_uri : str , thread_count : Optional [int ] = 1 , ** kwargs ) -> Database :
110
110
"""Connect to the given database uri
111
111
112
112
thread_count determines the max number of worker threads per database,
@@ -149,7 +149,7 @@ def connect_to_uri(self, db_uri: str, thread_count: Optional[int] = 1) -> Databa
149
149
conn_dict = config ["database" ][database ]
150
150
except KeyError :
151
151
raise ValueError (f"Cannot find database config named '{ database } '." )
152
- return self .connect_with_dict (conn_dict , thread_count )
152
+ return self .connect_with_dict (conn_dict , thread_count , ** kwargs )
153
153
154
154
try :
155
155
matcher = self .match_uri_path [scheme ]
@@ -174,7 +174,7 @@ def connect_to_uri(self, db_uri: str, thread_count: Optional[int] = 1) -> Databa
174
174
175
175
if scheme == "bigquery" :
176
176
kw ["project" ] = dsn .host
177
- return cls (** kw )
177
+ return cls (** kw , ** kwargs )
178
178
179
179
if scheme == "snowflake" :
180
180
kw ["account" ] = dsn .host
@@ -194,13 +194,13 @@ def connect_to_uri(self, db_uri: str, thread_count: Optional[int] = 1) -> Databa
194
194
kw = {k : v for k , v in kw .items () if v is not None }
195
195
196
196
if issubclass (cls , ThreadedDatabase ):
197
- db = cls (thread_count = thread_count , ** kw )
197
+ db = cls (thread_count = thread_count , ** kw , ** kwargs )
198
198
else :
199
- db = cls (** kw )
199
+ db = cls (** kw , ** kwargs )
200
200
201
201
return self ._connection_created (db )
202
202
203
- def connect_with_dict (self , d , thread_count ):
203
+ def connect_with_dict (self , d , thread_count , ** kwargs ):
204
204
d = dict (d )
205
205
driver = d .pop ("driver" )
206
206
try :
@@ -210,17 +210,19 @@ def connect_with_dict(self, d, thread_count):
210
210
211
211
cls = matcher .database_cls
212
212
if issubclass (cls , ThreadedDatabase ):
213
- db = cls (thread_count = thread_count , ** d )
213
+ db = cls (thread_count = thread_count , ** d , ** kwargs )
214
214
else :
215
- db = cls (** d )
215
+ db = cls (** d , ** kwargs )
216
216
217
217
return self ._connection_created (db )
218
218
219
219
def _connection_created (self , db ):
220
220
"Nop function to be overridden by subclasses."
221
221
return db
222
222
223
- def __call__ (self , db_conf : Union [str , dict ], thread_count : Optional [int ] = 1 , shared : bool = True ) -> Database :
223
+ def __call__ (
224
+ self , db_conf : Union [str , dict ], thread_count : Optional [int ] = 1 , shared : bool = True , ** kwargs
225
+ ) -> Database :
224
226
"""Connect to a database using the given database configuration.
225
227
226
228
Configuration can be given either as a URI string, or as a dict of {option: value}.
@@ -234,6 +236,8 @@ def __call__(self, db_conf: Union[str, dict], thread_count: Optional[int] = 1, s
234
236
db_conf (str | dict): The configuration for the database to connect. URI or dict.
235
237
thread_count (int, optional): Size of the threadpool. Ignored by cloud databases. (default: 1)
236
238
shared (bool): Whether to cache and return the same connection for the same db_conf. (default: True)
239
+ bigquery_credentials (google.oauth2.credentials.Credentials): Custom Google oAuth2 credential for BigQuery.
240
+ (default: None)
237
241
238
242
Note: For non-cloud databases, a low thread-pool size may be a performance bottleneck.
239
243
@@ -263,9 +267,9 @@ def __call__(self, db_conf: Union[str, dict], thread_count: Optional[int] = 1, s
263
267
return conn
264
268
265
269
if isinstance (db_conf , str ):
266
- conn = self .connect_to_uri (db_conf , thread_count )
270
+ conn = self .connect_to_uri (db_conf , thread_count , ** kwargs )
267
271
elif isinstance (db_conf , dict ):
268
- conn = self .connect_with_dict (db_conf , thread_count )
272
+ conn = self .connect_with_dict (db_conf , thread_count , ** kwargs )
269
273
else :
270
274
raise TypeError (f"db configuration must be a URI string or a dictionary. Instead got '{ db_conf } '." )
271
275
0 commit comments