Skip to content

feat: support bigquery connection in managed function #1554

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions bigframes/functions/_function_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ def provision_bq_managed_function(
name,
packages,
is_row_processor,
bq_connection_id,
*,
capture_references=False,
):
Expand Down Expand Up @@ -273,12 +274,21 @@ def provision_bq_managed_function(
udf_code = textwrap.dedent(inspect.getsource(func))
udf_code = udf_code[udf_code.index("def") :]

with_connection_clause = (
(
f"WITH CONNECTION `{self._gcp_project_id}.{self._bq_location}.{self._bq_connection_id}`"
)
if bq_connection_id
else ""
)

create_function_ddl = (
textwrap.dedent(
f"""
CREATE OR REPLACE FUNCTION {persistent_func_id}({','.join(bq_function_args)})
RETURNS {bq_function_return_type}
LANGUAGE python
{with_connection_clause}
OPTIONS ({managed_function_options_str})
AS r'''
__UDF_PLACE_HOLDER__
Expand Down
11 changes: 8 additions & 3 deletions bigframes/functions/_function_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -807,9 +807,13 @@ def udf(

bq_location, _ = _utils.get_remote_function_locations(bigquery_client.location)

# A connection is required for BQ managed function.
bq_connection_id = self._resolve_bigquery_connection_id(
session, dataset_ref, bq_location, bigquery_connection
# A connection is optional for BQ managed function.
bq_connection_id = (
self._resolve_bigquery_connection_id(
session, dataset_ref, bq_location, bigquery_connection
)
if bigquery_connection
else None
)

bq_connection_manager = session.bqconnectionmanager
Expand Down Expand Up @@ -907,6 +911,7 @@ def wrapper(func):
name=name,
packages=packages,
is_row_processor=is_row_processor,
bq_connection_id=bq_connection_id,
)

# TODO(shobs): Find a better way to support udfs with param named
Expand Down
9 changes: 7 additions & 2 deletions tests/system/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,13 @@ def session_tokyo(tokyo_location: str) -> Generator[bigframes.Session, None, Non


@pytest.fixture(scope="session")
def bq_connection(bigquery_client: bigquery.Client) -> str:
return f"{bigquery_client.project}.{bigquery_client.location}.bigframes-rf-conn"
def bq_connection_name() -> str:
return "bigframes-rf-conn"


@pytest.fixture(scope="session")
def bq_connection(bigquery_client: bigquery.Client, bq_connection_name: str) -> str:
return f"{bigquery_client.project}.{bigquery_client.location}.{bq_connection_name}"


@pytest.fixture(scope="session", autouse=True)
Expand Down
48 changes: 43 additions & 5 deletions tests/system/large/functions/test_managed_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,10 +166,7 @@ def featurize(x: int) -> list[float]:
cleanup_function_assets(featurize, session.bqclient, ignore_failures=False)


def test_managed_function_series_apply(
session,
scalars_dfs,
):
def test_managed_function_series_apply(session, scalars_dfs):
try:

@session.udf()
Expand Down Expand Up @@ -504,7 +501,10 @@ def test_managed_function_dataframe_apply_axis_1_array_output(session):

try:

@session.udf(input_types=[int, float, str], output_type=list[str])
@session.udf(
input_types=[int, float, str],
output_type=list[str],
)
def foo(x, y, z):
return [str(x), str(y), z]

Expand Down Expand Up @@ -587,3 +587,41 @@ def foo(x, y, z):
finally:
# Clean up the gcp assets created for the managed function.
cleanup_function_assets(foo, session.bqclient, ignore_failures=False)


@pytest.mark.parametrize(
"connection_fixture",
[
"bq_connection_name",
"bq_connection",
],
)
def test_managed_function_with_connection(
session, scalars_dfs, request, connection_fixture
):
try:
bigquery_connection = request.getfixturevalue(connection_fixture)

@session.udf(bigquery_connection=bigquery_connection)
def foo(x: int) -> int:
return x + 10

# Function should still work normally.
assert foo(-2) == 8

scalars_df, scalars_pandas_df = scalars_dfs

bf_result_col = scalars_df["int64_too"].apply(foo)
bf_result = (
scalars_df["int64_too"].to_frame().assign(result=bf_result_col).to_pandas()
)

pd_result_col = scalars_pandas_df["int64_too"].apply(foo)
pd_result = (
scalars_pandas_df["int64_too"].to_frame().assign(result=pd_result_col)
)

pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)
finally:
# Clean up the gcp assets created for the managed function.
cleanup_function_assets(foo, session.bqclient, ignore_failures=False)