diff --git a/bigframes/functions/_function_client.py b/bigframes/functions/_function_client.py index 00d7244d59..8a591f6916 100644 --- a/bigframes/functions/_function_client.py +++ b/bigframes/functions/_function_client.py @@ -196,6 +196,7 @@ def provision_bq_managed_function( name, packages, is_row_processor, + bq_connection_id, *, capture_references=False, ): @@ -273,12 +274,21 @@ def provision_bq_managed_function( udf_code = textwrap.dedent(inspect.getsource(func)) udf_code = udf_code[udf_code.index("def") :] + with_connection_clause = ( + ( + f"WITH CONNECTION `{self._gcp_project_id}.{self._bq_location}.{bq_connection_id}`" + ) + if bq_connection_id + else "" + ) + create_function_ddl = ( textwrap.dedent( f""" CREATE OR REPLACE FUNCTION {persistent_func_id}({','.join(bq_function_args)}) RETURNS {bq_function_return_type} LANGUAGE python + {with_connection_clause} OPTIONS ({managed_function_options_str}) AS r''' __UDF_PLACE_HOLDER__ diff --git a/bigframes/functions/_function_session.py b/bigframes/functions/_function_session.py index b42ddefbee..4f4d306317 100644 --- a/bigframes/functions/_function_session.py +++ b/bigframes/functions/_function_session.py @@ -807,9 +807,13 @@ def udf( bq_location, _ = _utils.get_remote_function_locations(bigquery_client.location) - # A connection is required for BQ managed function. - bq_connection_id = self._resolve_bigquery_connection_id( - session, dataset_ref, bq_location, bigquery_connection + # A connection is optional for BQ managed function. 
+ bq_connection_id = ( + self._resolve_bigquery_connection_id( + session, dataset_ref, bq_location, bigquery_connection + ) + if bigquery_connection + else None ) bq_connection_manager = session.bqconnectionmanager @@ -907,6 +911,7 @@ def wrapper(func): name=name, packages=packages, is_row_processor=is_row_processor, + bq_connection_id=bq_connection_id, ) # TODO(shobs): Find a better way to support udfs with param named diff --git a/tests/system/conftest.py b/tests/system/conftest.py index 398ee8a6b2..ce984f5ce4 100644 --- a/tests/system/conftest.py +++ b/tests/system/conftest.py @@ -185,8 +185,13 @@ def session_tokyo(tokyo_location: str) -> Generator[bigframes.Session, None, Non @pytest.fixture(scope="session") -def bq_connection(bigquery_client: bigquery.Client) -> str: - return f"{bigquery_client.project}.{bigquery_client.location}.bigframes-rf-conn" +def bq_connection_name() -> str: + return "bigframes-rf-conn" + + +@pytest.fixture(scope="session") +def bq_connection(bigquery_client: bigquery.Client, bq_connection_name: str) -> str: + return f"{bigquery_client.project}.{bigquery_client.location}.{bq_connection_name}" @pytest.fixture(scope="session", autouse=True) diff --git a/tests/system/large/functions/test_managed_function.py b/tests/system/large/functions/test_managed_function.py index eabafd96fb..831ab71be7 100644 --- a/tests/system/large/functions/test_managed_function.py +++ b/tests/system/large/functions/test_managed_function.py @@ -166,10 +166,7 @@ def featurize(x: int) -> list[float]: cleanup_function_assets(featurize, session.bqclient, ignore_failures=False) -def test_managed_function_series_apply( - session, - scalars_dfs, -): +def test_managed_function_series_apply(session, scalars_dfs): try: @session.udf() @@ -504,7 +501,10 @@ def test_managed_function_dataframe_apply_axis_1_array_output(session): try: - @session.udf(input_types=[int, float, str], output_type=list[str]) + @session.udf( + input_types=[int, float, str], + output_type=list[str], + ) 
def foo(x, y, z): return [str(x), str(y), z] @@ -587,3 +587,41 @@ def foo(x, y, z): finally: # Clean up the gcp assets created for the managed function. cleanup_function_assets(foo, session.bqclient, ignore_failures=False) + + +@pytest.mark.parametrize( + "connection_fixture", + [ + "bq_connection_name", + "bq_connection", + ], +) +def test_managed_function_with_connection( + session, scalars_dfs, request, connection_fixture +): + try: + bigquery_connection = request.getfixturevalue(connection_fixture) + + @session.udf(bigquery_connection=bigquery_connection) + def foo(x: int) -> int: + return x + 10 + + # Function should still work normally. + assert foo(-2) == 8 + + scalars_df, scalars_pandas_df = scalars_dfs + + bf_result_col = scalars_df["int64_too"].apply(foo) + bf_result = ( + scalars_df["int64_too"].to_frame().assign(result=bf_result_col).to_pandas() + ) + + pd_result_col = scalars_pandas_df["int64_too"].apply(foo) + pd_result = ( + scalars_pandas_df["int64_too"].to_frame().assign(result=pd_result_col) + ) + + pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False) + finally: + # Clean up the gcp assets created for the managed function. + cleanup_function_assets(foo, session.bqclient, ignore_failures=False)