From 11fc17eecc8416a7f56f6bbfd434deceba7953bc Mon Sep 17 00:00:00 2001 From: Sarad Mohanan Date: Tue, 15 Aug 2023 00:23:26 +0530 Subject: [PATCH 1/2] adding support for custom bigquery client credentials Signed-off-by: Sarad Mohanan --- data_diff/sqeleton/databases/_connect.py | 24 +++++++++++++----------- data_diff/sqeleton/databases/bigquery.py | 4 ++-- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/data_diff/sqeleton/databases/_connect.py b/data_diff/sqeleton/databases/_connect.py index 04baa413..300a40f3 100644 --- a/data_diff/sqeleton/databases/_connect.py +++ b/data_diff/sqeleton/databases/_connect.py @@ -106,7 +106,7 @@ def load_mixins(self, *abstract_mixins: AbstractMixin) -> Self: database_by_scheme = {k: db.load_mixins(*abstract_mixins) for k, db in self.database_by_scheme.items()} return type(self)(database_by_scheme) - def connect_to_uri(self, db_uri: str, thread_count: Optional[int] = 1) -> Database: + def connect_to_uri(self, db_uri: str, thread_count: Optional[int] = 1, **kwargs) -> Database: """Connect to the given database uri thread_count determines the max number of worker threads per database, @@ -149,7 +149,7 @@ def connect_to_uri(self, db_uri: str, thread_count: Optional[int] = 1) -> Databa conn_dict = config["database"][database] except KeyError: raise ValueError(f"Cannot find database config named '{database}'.") - return self.connect_with_dict(conn_dict, thread_count) + return self.connect_with_dict(conn_dict, thread_count, **kwargs) try: matcher = self.match_uri_path[scheme] @@ -174,7 +174,7 @@ def connect_to_uri(self, db_uri: str, thread_count: Optional[int] = 1) -> Databa if scheme == "bigquery": kw["project"] = dsn.host - return cls(**kw) + return cls(**kw, **kwargs) if scheme == "snowflake": kw["account"] = dsn.host @@ -194,13 +194,13 @@ def connect_to_uri(self, db_uri: str, thread_count: Optional[int] = 1) -> Databa kw = {k: v for k, v in kw.items() if v is not None} if issubclass(cls, ThreadedDatabase): - db = 
cls(thread_count=thread_count, **kw) + db = cls(thread_count=thread_count, **kw, **kwargs) else: - db = cls(**kw) + db = cls(**kw, **kwargs) return self._connection_created(db) - def connect_with_dict(self, d, thread_count): + def connect_with_dict(self, d, thread_count, **kwargs): d = dict(d) driver = d.pop("driver") try: @@ -210,9 +210,9 @@ def connect_with_dict(self, d, thread_count): cls = matcher.database_cls if issubclass(cls, ThreadedDatabase): - db = cls(thread_count=thread_count, **d) + db = cls(thread_count=thread_count, **d, **kwargs) else: - db = cls(**d) + db = cls(**d, **kwargs) return self._connection_created(db) @@ -220,7 +220,9 @@ def _connection_created(self, db): "Nop function to be overridden by subclasses." return db - def __call__(self, db_conf: Union[str, dict], thread_count: Optional[int] = 1, shared: bool = True) -> Database: + def __call__( + self, db_conf: Union[str, dict], thread_count: Optional[int] = 1, shared: bool = True, **kwargs + ) -> Database: """Connect to a database using the given database configuration. Configuration can be given either as a URI string, or as a dict of {option: value}. @@ -263,9 +265,9 @@ def __call__(self, db_conf: Union[str, dict], thread_count: Optional[int] = 1, s return conn if isinstance(db_conf, str): - conn = self.connect_to_uri(db_conf, thread_count) + conn = self.connect_to_uri(db_conf, thread_count, **kwargs) elif isinstance(db_conf, dict): - conn = self.connect_with_dict(db_conf, thread_count) + conn = self.connect_with_dict(db_conf, thread_count, **kwargs) else: raise TypeError(f"db configuration must be a URI string or a dictionary. 
Instead got '{db_conf}'.") diff --git a/data_diff/sqeleton/databases/bigquery.py b/data_diff/sqeleton/databases/bigquery.py index d54e541b..2c0a57ca 100644 --- a/data_diff/sqeleton/databases/bigquery.py +++ b/data_diff/sqeleton/databases/bigquery.py @@ -210,8 +210,8 @@ class BigQuery(Database): CONNECT_URI_PARAMS = ["dataset"] dialect = Dialect() - def __init__(self, project, *, dataset, **kw): - credentials = None + def __init__(self, project, *, dataset, bigquery_credentials=None, **kw): + credentials = bigquery_credentials bigquery = import_bigquery() keyfile = kw.pop("keyfile", None) From ea069924aa5ebeb3caec2aecb278ccc51e482b38 Mon Sep 17 00:00:00 2001 From: Sarad Mohanan Date: Tue, 15 Aug 2023 00:30:11 +0530 Subject: [PATCH 2/2] updating docs Signed-off-by: Sarad Mohanan --- data_diff/sqeleton/databases/_connect.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/data_diff/sqeleton/databases/_connect.py b/data_diff/sqeleton/databases/_connect.py index 300a40f3..2d2314fa 100644 --- a/data_diff/sqeleton/databases/_connect.py +++ b/data_diff/sqeleton/databases/_connect.py @@ -236,6 +236,8 @@ def __call__( db_conf (str | dict): The configuration for the database to connect. URI or dict. thread_count (int, optional): Size of the threadpool. Ignored by cloud databases. (default: 1) shared (bool): Whether to cache and return the same connection for the same db_conf. (default: True) + bigquery_credentials (google.oauth2.credentials.Credentials): Custom Google OAuth2 credentials for BigQuery. + (default: None) Note: For non-cloud databases, a low thread-pool size may be a performance bottleneck.