diff --git a/data_diff/sqeleton/databases/_connect.py b/data_diff/sqeleton/databases/_connect.py
index 04baa413..2d2314fa 100644
--- a/data_diff/sqeleton/databases/_connect.py
+++ b/data_diff/sqeleton/databases/_connect.py
@@ -106,7 +106,7 @@ def load_mixins(self, *abstract_mixins: AbstractMixin) -> Self:
         database_by_scheme = {k: db.load_mixins(*abstract_mixins) for k, db in self.database_by_scheme.items()}
         return type(self)(database_by_scheme)
 
-    def connect_to_uri(self, db_uri: str, thread_count: Optional[int] = 1) -> Database:
+    def connect_to_uri(self, db_uri: str, thread_count: Optional[int] = 1, **kwargs) -> Database:
         """Connect to the given database uri
 
         thread_count determines the max number of worker threads per database,
@@ -149,7 +149,7 @@ def connect_to_uri(self, db_uri: str, thread_count: Optional[int] = 1) -> Databa
                 conn_dict = config["database"][database]
             except KeyError:
                 raise ValueError(f"Cannot find database config named '{database}'.")
-            return self.connect_with_dict(conn_dict, thread_count)
+            return self.connect_with_dict(conn_dict, thread_count, **kwargs)
 
         try:
             matcher = self.match_uri_path[scheme]
@@ -174,7 +174,7 @@ def connect_to_uri(self, db_uri: str, thread_count: Optional[int] = 1) -> Databa
 
         if scheme == "bigquery":
             kw["project"] = dsn.host
-            return cls(**kw)
+            return cls(**kw, **kwargs)
 
         if scheme == "snowflake":
             kw["account"] = dsn.host
@@ -194,13 +194,13 @@ def connect_to_uri(self, db_uri: str, thread_count: Optional[int] = 1) -> Databa
         kw = {k: v for k, v in kw.items() if v is not None}
 
         if issubclass(cls, ThreadedDatabase):
-            db = cls(thread_count=thread_count, **kw)
+            db = cls(thread_count=thread_count, **kw, **kwargs)
         else:
-            db = cls(**kw)
+            db = cls(**kw, **kwargs)
 
         return self._connection_created(db)
 
-    def connect_with_dict(self, d, thread_count):
+    def connect_with_dict(self, d, thread_count, **kwargs):
         d = dict(d)
         driver = d.pop("driver")
         try:
@@ -210,9 +210,9 @@ def connect_with_dict(self, d, thread_count):
 
         cls = matcher.database_cls
         if issubclass(cls, ThreadedDatabase):
-            db = cls(thread_count=thread_count, **d)
+            db = cls(thread_count=thread_count, **d, **kwargs)
         else:
-            db = cls(**d)
+            db = cls(**d, **kwargs)
 
         return self._connection_created(db)
 
@@ -220,7 +220,9 @@ def _connection_created(self, db):
         "Nop function to be overridden by subclasses."
         return db
 
-    def __call__(self, db_conf: Union[str, dict], thread_count: Optional[int] = 1, shared: bool = True) -> Database:
+    def __call__(
+        self, db_conf: Union[str, dict], thread_count: Optional[int] = 1, shared: bool = True, **kwargs
+    ) -> Database:
         """Connect to a database using the given database configuration.
 
         Configuration can be given either as a URI string, or as a dict of {option: value}.
@@ -234,6 +236,8 @@ def __call__(self, db_conf: Union[str, dict], thread_count: Optional[int] = 1, s
             db_conf (str | dict): The configuration for the database to connect. URI or dict.
             thread_count (int, optional): Size of the threadpool. Ignored by cloud databases. (default: 1)
             shared (bool): Whether to cache and return the same connection for the same db_conf. (default: True)
+            bigquery_credentials (google.oauth2.credentials.Credentials): Custom Google OAuth2 credentials for BigQuery.
+                (default: None)
 
         Note:
             For non-cloud databases, a low thread-pool size may be a performance bottleneck.
@@ -263,9 +267,9 @@ def __call__(self, db_conf: Union[str, dict], thread_count: Optional[int] = 1, s
                         return conn
 
         if isinstance(db_conf, str):
-            conn = self.connect_to_uri(db_conf, thread_count)
+            conn = self.connect_to_uri(db_conf, thread_count, **kwargs)
         elif isinstance(db_conf, dict):
-            conn = self.connect_with_dict(db_conf, thread_count)
+            conn = self.connect_with_dict(db_conf, thread_count, **kwargs)
         else:
             raise TypeError(f"db configuration must be a URI string or a dictionary. Instead got '{db_conf}'.")
 
diff --git a/data_diff/sqeleton/databases/bigquery.py b/data_diff/sqeleton/databases/bigquery.py
index d54e541b..2c0a57ca 100644
--- a/data_diff/sqeleton/databases/bigquery.py
+++ b/data_diff/sqeleton/databases/bigquery.py
@@ -210,8 +210,8 @@ class BigQuery(Database):
     CONNECT_URI_PARAMS = ["dataset"]
     dialect = Dialect()
 
-    def __init__(self, project, *, dataset, **kw):
-        credentials = None
+    def __init__(self, project, *, dataset, bigquery_credentials=None, **kw):
+        credentials = bigquery_credentials
         bigquery = import_bigquery()
 
         keyfile = kw.pop("keyfile", None)
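
For context, a minimal usage sketch (not part of the patch) of what this change enables: extra keyword arguments given to the connect entrypoint are forwarded through __call__ -> connect_to_uri / connect_with_dict into the database constructor, so a caller can hand BigQuery a pre-built OAuth2 credentials object instead of relying on application-default credentials. It assumes the google-auth package is installed and that `connect` is the Connect instance exported by data_diff.sqeleton; the project, dataset, and token values are placeholders.

    # Usage sketch, assuming the patch above is applied.
    from google.oauth2 import credentials as oauth2_credentials

    from data_diff.sqeleton import connect  # assumed entrypoint; a Connect instance

    # Access token obtained elsewhere (e.g. from an OAuth2 flow); placeholder value here.
    creds = oauth2_credentials.Credentials(token="ya29.placeholder-access-token")

    # The new keyword is passed through **kwargs and lands in BigQuery.__init__
    # as `bigquery_credentials`, overriding the default credential discovery.
    db = connect("bigquery://my-project/my_dataset", bigquery_credentials=creds)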