feat: Allow jobs to be run in a different project #1180

Merged: 3 commits, Apr 23, 2025
48 changes: 31 additions & 17 deletions sqlalchemy_bigquery/base.py
@@ -27,7 +27,7 @@

from google import auth
import google.api_core.exceptions
from google.cloud.bigquery import dbapi
from google.cloud.bigquery import dbapi, ConnectionProperty
from google.cloud.bigquery.table import (
RangePartitioning,
TableReference,
@@ -61,6 +61,7 @@
from .parse_url import parse_url
from . import _helpers, _struct, _types
import sqlalchemy_bigquery_vendored.sqlalchemy.postgresql.base as vendored_postgresql
from google.cloud.bigquery import QueryJobConfig

# Illegal characters are intended to be all characters that are not explicitly
# allowed as part of the flexible column names.
Expand Down Expand Up @@ -1080,6 +1081,7 @@ def __init__(
self,
arraysize=5000,
credentials_path=None,
billing_project_id=None,
location=None,
credentials_info=None,
credentials_base64=None,
@@ -1092,6 +1094,8 @@ def __init__(
self.credentials_path = credentials_path
self.credentials_info = credentials_info
self.credentials_base64 = credentials_base64
self.project_id = None
self.billing_project_id = billing_project_id
self.location = location
self.identifier_preparer = self.preparer(self)
self.dataset_id = None
@@ -1114,15 +1118,20 @@ def _build_formatted_table_id(table):
"""Build '<dataset_id>.<table_id>' string using given table."""
return "{}.{}".format(table.reference.dataset_id, table.table_id)

@staticmethod
def _add_default_dataset_to_job_config(job_config, project_id, dataset_id):
# If dataset_id is set, then we know the job_config isn't None
if dataset_id:
# If project_id is missing, use default project_id for the current environment
def create_job_config(self, provided_config: QueryJobConfig):
project_id = self.project_id
if self.dataset_id is None and project_id == self.billing_project_id:
return provided_config
job_config = provided_config or QueryJobConfig()
if project_id != self.billing_project_id:
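# The dataset_project_id connection property makes datasets resolve against
# the data project, while the job itself runs in, and is billed to,
# billing_project_id.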
job_config.connection_properties = [
ConnectionProperty(key="dataset_project_id", value=project_id)
]
if self.dataset_id:
if not project_id:
_, project_id = auth.default()

job_config.default_dataset = "{}.{}".format(project_id, dataset_id)
job_config.default_dataset = "{}.{}".format(project_id, self.dataset_id)
return job_config

def do_execute(self, cursor, statement, parameters, context=None):
kwargs = {}
@@ -1132,13 +1141,13 @@ def do_execute(self, cursor, statement, parameters, context=None):

def create_connect_args(self, url):
(
project_id,
self.project_id,
location,
dataset_id,
arraysize,
credentials_path,
credentials_base64,
default_query_job_config,
provided_job_config,
list_tables_page_size,
user_supplied_client,
) = parse_url(url)
@@ -1149,9 +1158,9 @@ def create_connect_args(self, url):
self.credentials_path = credentials_path or self.credentials_path
self.credentials_base64 = credentials_base64 or self.credentials_base64
self.dataset_id = dataset_id
self._add_default_dataset_to_job_config(
default_query_job_config, project_id, dataset_id
)
self.billing_project_id = self.billing_project_id or self.project_id

default_query_job_config = self.create_job_config(provided_job_config)

if user_supplied_client:
# The user is expected to supply a client with
@@ -1162,10 +1171,14 @@ def _get_table_or_view_names(self, connection, item_types, schema=None):
credentials_path=self.credentials_path,
credentials_info=self.credentials_info,
credentials_base64=self.credentials_base64,
project_id=project_id,
project_id=self.billing_project_id,
location=self.location,
default_query_job_config=default_query_job_config,
)
# If the user specified `bigquery://` we need to set the project_id
# from the client
self.project_id = self.project_id or client.project
self.billing_project_id = self.billing_project_id or client.project
return ([], {"client": client})

def _get_table_or_view_names(self, connection, item_types, schema=None):
@@ -1177,7 +1190,7 @@ def _get_table_or_view_names(self, connection, item_types, schema=None):
)

client = connection.connection._client
datasets = client.list_datasets()
datasets = client.list_datasets(self.project_id)

result = []
for dataset in datasets:
@@ -1278,7 +1291,8 @@ def _get_table(self, connection, table_name, schema=None):

client = connection.connection._client

table_ref = self._table_reference(schema, table_name, client.project)
table_ref = self._table_reference(schema, table_name, self.project_id)
try:
table = client.get_table(table_ref)
except NotFound:
@@ -1332,7 +1346,7 @@ def get_schema_names(self, connection, **kw):
if isinstance(connection, Engine):
connection = connection.connect()

datasets = connection.connection._client.list_datasets()
datasets = connection.connection._client.list_datasets(self.project_id)
return [d.dataset_id for d in datasets]

def get_table_names(self, connection, schema=None, **kw):
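
Taken together, the base.py changes track two projects: the data project parsed from the connection URL (self.project_id) and the project that runs and pays for query jobs (billing_project_id, falling back to the data project or the client's default project). When the two differ, create_job_config attaches the dataset_project_id connection property so partially qualified table names still resolve against the data project. A minimal usage sketch follows; the billing project name is a placeholder, and the public usa_names dataset is the same one the new system tests use.

from sqlalchemy import create_engine, text

# Data lives in bigquery-public-data; the query job runs in, and is billed to,
# "my-billing-project" (a placeholder for your own project).
engine = create_engine(
    "bigquery://bigquery-public-data/usa_names",
    billing_project_id="my-billing-project",
)

with engine.connect() as conn:
    # The bare table name resolves through the URL's project and dataset.
    rows = conn.execute(
        text("SELECT DISTINCT state FROM usa_1910_2013 LIMIT 10")
    ).fetchall()
    print([row[0] for row in rows])
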
107 changes: 107 additions & 0 deletions tests/system/test_sqlalchemy_bigquery_remote.py
@@ -0,0 +1,107 @@
# Copyright (c) 2017 The sqlalchemy-bigquery Authors
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

# -*- coding: utf-8 -*-

from sqlalchemy.engine import create_engine
from sqlalchemy.exc import DatabaseError
from sqlalchemy.schema import Table, MetaData
import pytest
import sqlalchemy
import google.api_core.exceptions as core_exceptions


EXPECTED_STATES = ["AL", "CA", "FL", "KY"]

REMOTE_TESTS = [
("bigquery-public-data", "bigquery-public-data.usa_names.usa_1910_2013"),
("bigquery-public-data", "usa_names.usa_1910_2013"),
("bigquery-public-data/usa_names", "bigquery-public-data.usa_names.usa_1910_2013"),
("bigquery-public-data/usa_names", "usa_1910_2013"),
("bigquery-public-data/usa_names", "usa_names.usa_1910_2013"),
]


@pytest.fixture(scope="session")
def engine_using_remote_dataset(bigquery_client):
engine = create_engine(
"bigquery://bigquery-public-data/usa_names",
billing_project_id=bigquery_client.project,
echo=True,
)
return engine


def test_remote_tables_list(engine_using_remote_dataset):
tables = sqlalchemy.inspect(engine_using_remote_dataset).get_table_names()
assert "usa_1910_2013" in tables


@pytest.mark.parametrize(
["urlpath", "table_name"],
REMOTE_TESTS,
ids=[f"test_engine_remote_sql_{x}" for x in range(len(REMOTE_TESTS))],
)
def test_engine_remote_sql(bigquery_client, urlpath, table_name):
engine = create_engine(
f"bigquery://{urlpath}", billing_project_id=bigquery_client.project, echo=True
)
with engine.connect() as conn:
rows = conn.execute(
sqlalchemy.text(f"SELECT DISTINCT(state) FROM `{table_name}`")
).fetchall()
states = set(map(lambda row: row[0], rows))
assert set(EXPECTED_STATES).issubset(states)


@pytest.mark.parametrize(
["urlpath", "table_name"],
REMOTE_TESTS,
ids=[f"test_engine_remote_table_{x}" for x in range(len(REMOTE_TESTS))],
)
def test_engine_remote_table(bigquery_client, urlpath, table_name):
engine = create_engine(
f"bigquery://{urlpath}", billing_project_id=bigquery_client.project, echo=True
)
with engine.connect() as conn:
table = Table(table_name, MetaData(), autoload_with=engine)
prepared = sqlalchemy.select(
sqlalchemy.distinct(table.c.state)
).set_label_style(sqlalchemy.LABEL_STYLE_TABLENAME_PLUS_COL)
rows = conn.execute(prepared).fetchall()
states = set(map(lambda row: row[0], rows))
assert set(EXPECTED_STATES).issubset(states)


@pytest.mark.parametrize(
["urlpath", "table_name"],
REMOTE_TESTS,
ids=[f"test_engine_remote_table_fail_{x}" for x in range(len(REMOTE_TESTS))],
)
def test_engine_remote_table_fail(urlpath, table_name):
engine = create_engine(f"bigquery://{urlpath}", echo=True)
with pytest.raises(
(DatabaseError, core_exceptions.Forbidden), match="Access Denied"
):
with engine.connect() as conn:
table = Table(table_name, MetaData(), autoload_with=engine)
prepared = sqlalchemy.select(
sqlalchemy.distinct(table.c.state)
).set_label_style(sqlalchemy.LABEL_STYLE_TABLENAME_PLUS_COL)
conn.execute(prepared).fetchall()
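
The *_fail variant exercises the path where no billing_project_id is supplied: the job would have to be created in the data project named in the URL, which the test credentials cannot do, so BigQuery rejects it. A minimal sketch of that failure, assuming credentials without job-creation rights on bigquery-public-data:

import google.api_core.exceptions as core_exceptions
from sqlalchemy import create_engine, text
from sqlalchemy.exc import DatabaseError

engine = create_engine("bigquery://bigquery-public-data/usa_names", echo=True)
try:
    with engine.connect() as conn:
        conn.execute(text("SELECT 1")).fetchall()
except (DatabaseError, core_exceptions.Forbidden) as exc:
    # Expected: an "Access Denied" error, because jobs cannot be created in the
    # public project without permission.
    print(exc)
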
21 changes: 13 additions & 8 deletions tests/unit/fauxdbi.py
@@ -327,10 +327,12 @@ def _fix_pickled(self, row):
pickle.loads(v.encode("latin1"))
# \x80\x04 is latin-1 encoded prefix for Pickle protocol 4.
if isinstance(v, str) and v[:2] == "\x80\x04" and v[-1] == "."
else pickle.loads(base64.b16decode(v))
# 8004 is base64 encoded prefix for Pickle protocol 4.
if isinstance(v, str) and v[:4] == "8004" and v[-2:] == "2E"
else v
else (
pickle.loads(base64.b16decode(v))
# 8004 is base64 encoded prefix for Pickle protocol 4.
if isinstance(v, str) and v[:4] == "8004" and v[-2:] == "2E"
else v
)
)
for d, v in zip(self.description, row)
]
@@ -355,7 +357,10 @@ def __getattr__(self, name):
class FauxClient:
def __init__(self, project_id=None, default_query_job_config=None, *args, **kw):
if project_id is None:
if default_query_job_config is not None:
if (
default_query_job_config is not None
and default_query_job_config.default_dataset
):
project_id = default_query_job_config.default_dataset.project
else:
project_id = "authproj" # we would still have gotten it from auth.
@@ -469,10 +474,10 @@ def get_table(self, table_ref):
else:
raise google.api_core.exceptions.NotFound(table_ref)

def list_datasets(self):
def list_datasets(self, project="myproject"):
return [
google.cloud.bigquery.Dataset("myproject.mydataset"),
google.cloud.bigquery.Dataset("myproject.yourdataset"),
google.cloud.bigquery.Dataset(f"{project}.mydataset"),
google.cloud.bigquery.Dataset(f"{project}.yourdataset"),
]

def list_tables(self, dataset, page_size):
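
FauxClient.list_datasets now accepts a project argument, mirroring the real google.cloud.bigquery.Client.list_datasets, so dataset listing can target the data project instead of the billing project. A small sketch of what this enables (project names are placeholders):

import sqlalchemy
from sqlalchemy import create_engine

engine = create_engine(
    "bigquery://bigquery-public-data",        # data project from the URL
    billing_project_id="my-billing-project",  # placeholder billing project
)
# get_schema_names() now lists datasets in the data project, not the billing project.
print(sqlalchemy.inspect(engine).get_schema_names())
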
6 changes: 6 additions & 0 deletions tests/unit/test_engine.py
@@ -27,6 +27,12 @@ def test_engine_dataset_but_no_project(faux_conn):
assert conn.connection._client.project == "authproj"


def test_engine_dataset_with_billing_project(faux_conn):
engine = sqlalchemy.create_engine("bigquery://foo", billing_project_id="bar")
conn = engine.connect()
assert conn.connection._client.project == "bar"


def test_engine_no_dataset_no_project(faux_conn):
engine = sqlalchemy.create_engine("bigquery://")
conn = engine.connect()
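
For completeness, a sketch of what the new unit test implies about the two project attributes; "foo" and "bar" are the placeholder values from the test above, and the dialect attribute access is an illustration rather than documented API:

engine = sqlalchemy.create_engine("bigquery://foo", billing_project_id="bar")
conn = engine.connect()
# Jobs run in, and are billed to, the billing project...
assert conn.connection._client.project == "bar"
# ...while the URL's project remains the data project used for table references.
assert engine.dialect.project_id == "foo"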