From 79ee615a8129280e50c7971d51cb8f5960722f5d Mon Sep 17 00:00:00 2001 From: jialuo Date: Thu, 27 Mar 2025 19:16:07 +0000 Subject: [PATCH 1/3] feat: support bigquery connection in managed function --- bigframes/functions/_function_client.py | 9 ++++ bigframes/functions/_function_session.py | 3 ++ .../large/functions/test_managed_function.py | 42 +++++++++++++------ 3 files changed, 42 insertions(+), 12 deletions(-) diff --git a/bigframes/functions/_function_client.py b/bigframes/functions/_function_client.py index 00d7244d59..bc00524458 100644 --- a/bigframes/functions/_function_client.py +++ b/bigframes/functions/_function_client.py @@ -196,6 +196,7 @@ def provision_bq_managed_function( name, packages, is_row_processor, + bq_connection_id, *, capture_references=False, ): @@ -273,12 +274,20 @@ def provision_bq_managed_function( udf_code = textwrap.dedent(inspect.getsource(func)) udf_code = udf_code[udf_code.index("def") :] + bq_connection_str = "" + if bq_connection_id: + bq_connection_str = ( + f"WITH CONNECTION `{self._gcp_project_id}.{self._bq_location}." + f"{self._bq_connection_id}`" + ) + create_function_ddl = ( textwrap.dedent( f""" CREATE OR REPLACE FUNCTION {persistent_func_id}({','.join(bq_function_args)}) RETURNS {bq_function_return_type} LANGUAGE python + {bq_connection_str} OPTIONS ({managed_function_options_str}) AS r''' __UDF_PLACE_HOLDER__ diff --git a/bigframes/functions/_function_session.py b/bigframes/functions/_function_session.py index b42ddefbee..397185f697 100644 --- a/bigframes/functions/_function_session.py +++ b/bigframes/functions/_function_session.py @@ -907,6 +907,9 @@ def wrapper(func): name=name, packages=packages, is_row_processor=is_row_processor, + # If bigquery_connection is not provided, bq_connection_id can + # be ignored in provision_bq_managed_function. + bq_connection_id=bq_connection_id if bigquery_connection else None, ) # TODO(shobs): Find a better way to support udfs with param named diff --git a/tests/system/large/functions/test_managed_function.py b/tests/system/large/functions/test_managed_function.py index eabafd96fb..a6648bf381 100644 --- a/tests/system/large/functions/test_managed_function.py +++ b/tests/system/large/functions/test_managed_function.py @@ -29,12 +29,23 @@ ) +@pytest.fixture(scope="module") +def bq_cf_connection() -> str: + """Pre-created BQ connection in the test project in US location, used to + invoke cloud function. + + $ bq show --connection --location=us --project_id=PROJECT_ID bigframes-rf-conn + """ + return "bigframes-rf-conn" + + def test_managed_function_multiply_with_ibis( session, scalars_table_id, bigquery_client, ibis_client, dataset_id, + bq_cf_connection, ): try: @@ -43,6 +54,7 @@ def test_managed_function_multiply_with_ibis( input_types=[int, int], output_type=int, dataset=dataset_id, + bigquery_connection=bq_cf_connection, ) def multiply(x, y): return x * y @@ -126,10 +138,12 @@ def stringify(x): cleanup_function_assets(stringify, bigquery_client, ignore_failures=False) -def test_managed_function_array_output(session, scalars_dfs, dataset_id): +def test_managed_function_array_output( + session, scalars_dfs, dataset_id, bq_cf_connection +): try: - @session.udf(dataset=dataset_id) + @session.udf(dataset=dataset_id, bigquery_connection=bq_cf_connection) def featurize(x: int) -> list[float]: return [float(i) for i in [x, x + 1, x + 2]] @@ -166,13 +180,10 @@ def featurize(x: int) -> list[float]: cleanup_function_assets(featurize, session.bqclient, ignore_failures=False) -def test_managed_function_series_apply( - session, - scalars_dfs, -): +def test_managed_function_series_apply(session, scalars_dfs, bq_cf_connection): try: - @session.udf() + @session.udf(bigquery_connection=bq_cf_connection) def foo(x: int) -> bytes: return bytes(abs(x)) @@ -252,7 +263,7 @@ def foo_list(x: int) -> list[float]: cleanup_function_assets(foo_list, session.bqclient, ignore_failures=False) -def test_managed_function_series_combine(session, scalars_dfs): +def test_managed_function_series_combine(session, scalars_dfs, bq_cf_connection): try: # This function is deliberately written to not work with NA input. def add(x: int, y: int) -> int: @@ -267,7 +278,7 @@ def add(x: int, y: int) -> int: # make sure there are NA values in the test column. assert any([pandas.isna(val) for val in bf_df[int_col_name_with_nulls]]) - add_managed_func = session.udf()(add) + add_managed_func = session.udf(bigquery_connection=bq_cf_connection)(add) # with nulls in the series the managed function application would fail. with pytest.raises( @@ -373,7 +384,7 @@ def add_list(x: int, y: int) -> list[int]: ) -def test_managed_function_dataframe_map(session, scalars_dfs): +def test_managed_function_dataframe_map(session, scalars_dfs, bq_cf_connection): try: def add_one(x): @@ -382,6 +393,7 @@ def add_one(x): mf_add_one = session.udf( input_types=[int], output_type=int, + bigquery_connection=bq_cf_connection, )(add_one) scalars_df, scalars_pandas_df = scalars_dfs @@ -484,7 +496,9 @@ def add_ints(x, y): cleanup_function_assets(add_ints_mf, session.bqclient, ignore_failures=False) -def test_managed_function_dataframe_apply_axis_1_array_output(session): +def test_managed_function_dataframe_apply_axis_1_array_output( + session, bq_cf_connection +): bf_df = bigframes.dataframe.DataFrame( { "Id": [1, 2, 3], @@ -504,7 +518,11 @@ def test_managed_function_dataframe_apply_axis_1_array_output(session): try: - @session.udf(input_types=[int, float, str], output_type=list[str]) + @session.udf( + input_types=[int, float, str], + output_type=list[str], + bigquery_connection=bq_cf_connection, + ) def foo(x, y, z): return [str(x), str(y), z] From b5038ce9db73975cde4746455ac46029bdafe73f Mon Sep 17 00:00:00 2001 From: Shobhit Singh Date: Fri, 28 Mar 2025 06:05:49 +0000 Subject: [PATCH 2/3] simplify a bit the intended change --- bigframes/functions/_function_client.py | 13 +-- bigframes/functions/_function_session.py | 12 ++- tests/system/conftest.py | 9 +- .../large/functions/test_managed_function.py | 82 ++++++++++++------- 4 files changed, 73 insertions(+), 43 deletions(-) diff --git a/bigframes/functions/_function_client.py b/bigframes/functions/_function_client.py index bc00524458..8a591f6916 100644 --- a/bigframes/functions/_function_client.py +++ b/bigframes/functions/_function_client.py @@ -274,12 +274,13 @@ def provision_bq_managed_function( udf_code = textwrap.dedent(inspect.getsource(func)) udf_code = udf_code[udf_code.index("def") :] - bq_connection_str = "" - if bq_connection_id: - bq_connection_str = ( - f"WITH CONNECTION `{self._gcp_project_id}.{self._bq_location}." - f"{self._bq_connection_id}`" + with_connection_clause = ( + ( + f"WITH CONNECTION `{self._gcp_project_id}.{self._bq_location}.{self._bq_connection_id}`" ) + if bq_connection_id + else "" + ) create_function_ddl = ( textwrap.dedent( @@ -287,7 +288,7 @@ def provision_bq_managed_function( CREATE OR REPLACE FUNCTION {persistent_func_id}({','.join(bq_function_args)}) RETURNS {bq_function_return_type} LANGUAGE python - {bq_connection_str} + {with_connection_clause} OPTIONS ({managed_function_options_str}) AS r''' __UDF_PLACE_HOLDER__ diff --git a/bigframes/functions/_function_session.py b/bigframes/functions/_function_session.py index 397185f697..36acac32c7 100644 --- a/bigframes/functions/_function_session.py +++ b/bigframes/functions/_function_session.py @@ -807,9 +807,13 @@ def udf( bq_location, _ = _utils.get_remote_function_locations(bigquery_client.location) - # A connection is required for BQ managed function. - bq_connection_id = self._resolve_bigquery_connection_id( - session, dataset_ref, bq_location, bigquery_connection + # A connection is optional for BQ managed function. + bq_connection_id = ( + self._resolve_bigquery_connection_id( + session, dataset_ref, bq_location, bigquery_connection + ) + if bigquery_connection + else None ) bq_connection_manager = session.bqconnectionmanager @@ -909,7 +913,7 @@ def wrapper(func): is_row_processor=is_row_processor, # If bigquery_connection is not provided, bq_connection_id can # be ignored in provision_bq_managed_function. - bq_connection_id=bq_connection_id if bigquery_connection else None, + bq_connection_id=bq_connection_id, ) # TODO(shobs): Find a better way to support udfs with param named diff --git a/tests/system/conftest.py b/tests/system/conftest.py index 398ee8a6b2..ce984f5ce4 100644 --- a/tests/system/conftest.py +++ b/tests/system/conftest.py @@ -185,8 +185,13 @@ def session_tokyo(tokyo_location: str) -> Generator[bigframes.Session, None, Non @pytest.fixture(scope="session") -def bq_connection(bigquery_client: bigquery.Client) -> str: - return f"{bigquery_client.project}.{bigquery_client.location}.bigframes-rf-conn" +def bq_connection_name() -> str: + return "bigframes-rf-conn" + + +@pytest.fixture(scope="session") +def bq_connection(bigquery_client: bigquery.Client, bq_connection_name: str) -> str: + return f"{bigquery_client.project}.{bigquery_client.location}.{bq_connection_name}" @pytest.fixture(scope="session", autouse=True) diff --git a/tests/system/large/functions/test_managed_function.py b/tests/system/large/functions/test_managed_function.py index a6648bf381..aabd0f438f 100644 --- a/tests/system/large/functions/test_managed_function.py +++ b/tests/system/large/functions/test_managed_function.py @@ -22,21 +22,11 @@ import bigframes.pandas as bpd from tests.system.utils import cleanup_function_assets -# TODO(shobs): restore these tests after the managed udf cleanup issue is -# resolved in the test project -pytestmark = pytest.mark.skip( - reason="temporarily disable to debug managed udf cleanup in the test project" -) - - -@pytest.fixture(scope="module") -def bq_cf_connection() -> str: - """Pre-created BQ connection in the test project in US location, used to - invoke cloud function. - - $ bq show --connection --location=us --project_id=PROJECT_ID bigframes-rf-conn - """ - return "bigframes-rf-conn" +# # TODO(shobs): restore these tests after the managed udf cleanup issue is +# # resolved in the test project +# pytestmark = pytest.mark.skip( +# reason="temporarily disable to debug managed udf cleanup in the test project" +# ) def test_managed_function_multiply_with_ibis( @@ -45,7 +35,6 @@ def test_managed_function_multiply_with_ibis( bigquery_client, ibis_client, dataset_id, - bq_cf_connection, ): try: @@ -54,7 +43,6 @@ def test_managed_function_multiply_with_ibis( input_types=[int, int], output_type=int, dataset=dataset_id, - bigquery_connection=bq_cf_connection, ) def multiply(x, y): return x * y @@ -138,12 +126,10 @@ def stringify(x): cleanup_function_assets(stringify, bigquery_client, ignore_failures=False) -def test_managed_function_array_output( - session, scalars_dfs, dataset_id, bq_cf_connection -): +def test_managed_function_array_output(session, scalars_dfs, dataset_id): try: - @session.udf(dataset=dataset_id, bigquery_connection=bq_cf_connection) + @session.udf(dataset=dataset_id) def featurize(x: int) -> list[float]: return [float(i) for i in [x, x + 1, x + 2]] @@ -180,10 +166,10 @@ def featurize(x: int) -> list[float]: cleanup_function_assets(featurize, session.bqclient, ignore_failures=False) -def test_managed_function_series_apply(session, scalars_dfs, bq_cf_connection): +def test_managed_function_series_apply(session, scalars_dfs): try: - @session.udf(bigquery_connection=bq_cf_connection) + @session.udf() def foo(x: int) -> bytes: return bytes(abs(x)) @@ -263,7 +249,7 @@ def foo_list(x: int) -> list[float]: cleanup_function_assets(foo_list, session.bqclient, ignore_failures=False) -def test_managed_function_series_combine(session, scalars_dfs, bq_cf_connection): +def test_managed_function_series_combine(session, scalars_dfs): try: # This function is deliberately written to not work with NA input. def add(x: int, y: int) -> int: @@ -278,7 +264,7 @@ def add(x: int, y: int) -> int: # make sure there are NA values in the test column. assert any([pandas.isna(val) for val in bf_df[int_col_name_with_nulls]]) - add_managed_func = session.udf(bigquery_connection=bq_cf_connection)(add) + add_managed_func = session.udf()(add) # with nulls in the series the managed function application would fail. with pytest.raises( @@ -384,7 +370,7 @@ def add_list(x: int, y: int) -> list[int]: ) -def test_managed_function_dataframe_map(session, scalars_dfs, bq_cf_connection): +def test_managed_function_dataframe_map(session, scalars_dfs): try: def add_one(x): @@ -393,7 +379,6 @@ def add_one(x): mf_add_one = session.udf( input_types=[int], output_type=int, - bigquery_connection=bq_cf_connection, )(add_one) scalars_df, scalars_pandas_df = scalars_dfs @@ -496,9 +481,7 @@ def add_ints(x, y): cleanup_function_assets(add_ints_mf, session.bqclient, ignore_failures=False) -def test_managed_function_dataframe_apply_axis_1_array_output( - session, bq_cf_connection -): +def test_managed_function_dataframe_apply_axis_1_array_output(session): bf_df = bigframes.dataframe.DataFrame( { "Id": [1, 2, 3], @@ -521,7 +504,6 @@ def test_managed_function_dataframe_apply_axis_1_array_output( @session.udf( input_types=[int, float, str], output_type=list[str], - bigquery_connection=bq_cf_connection, ) def foo(x, y, z): return [str(x), str(y), z] @@ -605,3 +587,41 @@ def foo(x, y, z): finally: # Clean up the gcp assets created for the managed function. cleanup_function_assets(foo, session.bqclient, ignore_failures=False) + + +@pytest.mark.parametrize( + "connection_fixture", + [ + "bq_connection_name", + "bq_connection", + ], +) +def test_managed_function_with_connection( + session, scalars_dfs, request, connection_fixture +): + try: + bigquery_connection = request.getfixturevalue(connection_fixture) + + @session.udf(bigquery_connection=bigquery_connection) + def foo(x: int) -> int: + return x + 10 + + # Function should still work normally. + assert foo(-2) == 8 + + scalars_df, scalars_pandas_df = scalars_dfs + + bf_result_col = scalars_df["int64_too"].apply(foo) + bf_result = ( + scalars_df["int64_too"].to_frame().assign(result=bf_result_col).to_pandas() + ) + + pd_result_col = scalars_pandas_df["int64_too"].apply(foo) + pd_result = ( + scalars_pandas_df["int64_too"].to_frame().assign(result=pd_result_col) + ) + + pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False) + finally: + # Clean up the gcp assets created for the managed function. + cleanup_function_assets(foo, session.bqclient, ignore_failures=False) From 237b1cdb2a664d064c38931b7e492b16fea52552 Mon Sep 17 00:00:00 2001 From: Shobhit Singh Date: Fri, 28 Mar 2025 06:09:06 +0000 Subject: [PATCH 3/3] restore pytestmark, remove a comment --- bigframes/functions/_function_session.py | 2 -- tests/system/large/functions/test_managed_function.py | 10 +++++----- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/bigframes/functions/_function_session.py b/bigframes/functions/_function_session.py index 36acac32c7..4f4d306317 100644 --- a/bigframes/functions/_function_session.py +++ b/bigframes/functions/_function_session.py @@ -911,8 +911,6 @@ def wrapper(func): name=name, packages=packages, is_row_processor=is_row_processor, - # If bigquery_connection is not provided, bq_connection_id can - # be ignored in provision_bq_managed_function. bq_connection_id=bq_connection_id, ) diff --git a/tests/system/large/functions/test_managed_function.py b/tests/system/large/functions/test_managed_function.py index aabd0f438f..831ab71be7 100644 --- a/tests/system/large/functions/test_managed_function.py +++ b/tests/system/large/functions/test_managed_function.py @@ -22,11 +22,11 @@ import bigframes.pandas as bpd from tests.system.utils import cleanup_function_assets -# # TODO(shobs): restore these tests after the managed udf cleanup issue is -# # resolved in the test project -# pytestmark = pytest.mark.skip( -# reason="temporarily disable to debug managed udf cleanup in the test project" -# ) +# TODO(shobs): restore these tests after the managed udf cleanup issue is +# resolved in the test project +pytestmark = pytest.mark.skip( + reason="temporarily disable to debug managed udf cleanup in the test project" +) def test_managed_function_multiply_with_ibis(