Skip to content

feat: support list output for managed function #1457

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged — 5 commits merged on Mar 8, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions bigframes/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -4193,11 +4193,13 @@ def apply(self, func, *, axis=0, args: typing.Tuple = (), **kwargs):
udf_input_dtypes = getattr(func, "input_dtypes")
if len(udf_input_dtypes) != len(self.columns):
raise ValueError(
f"Remote function takes {len(udf_input_dtypes)} arguments but DataFrame has {len(self.columns)} columns."
f"Bigframes bigquery function takes {len(udf_input_dtypes)}"
f" arguments but DataFrame has {len(self.columns)} columns."
)
if udf_input_dtypes != tuple(self.dtypes.to_list()):
raise ValueError(
f"Remote function takes arguments of types {udf_input_dtypes} but DataFrame dtypes are {tuple(self.dtypes)}."
f"Bigframes bigquery function takes arguments of types "
f"{udf_input_dtypes} but DataFrame dtypes are {tuple(self.dtypes)}."
)

series_list = [self[col] for col in self.columns]
Expand Down
39 changes: 30 additions & 9 deletions bigframes/operations/remote_function_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,18 @@ def expensive(self) -> bool:
return True

def output_type(self, *input_types):
# This property should be set to a valid Dtype by the @remote_function decorator or read_gbq_function method
# The output dtype should be set to a valid Dtype by @udf decorator,
# @remote_function decorator, or read_gbq_function method.

# This is for remote function.
if hasattr(self.func, "bigframes_bigquery_function_output_dtype"):
return self.func.bigframes_bigquery_function_output_dtype
else:
raise AttributeError("bigframes_bigquery_function_output_dtype not defined")

# This is for managed function.
if hasattr(self.func, "output_dtype"):
return self.func.output_dtype

raise AttributeError("output_dtype not defined")


@dataclasses.dataclass(frozen=True)
Expand All @@ -46,11 +53,18 @@ def expensive(self) -> bool:
return True

def output_type(self, *input_types):
# This property should be set to a valid Dtype by the @remote_function decorator or read_gbq_function method
# The output dtype should be set to a valid Dtype by @udf decorator,
# @remote_function decorator, or read_gbq_function method.

# This is for remote function.
if hasattr(self.func, "bigframes_bigquery_function_output_dtype"):
return self.func.bigframes_bigquery_function_output_dtype
else:
raise AttributeError("bigframes_bigquery_function_output_dtype not defined")

# This is for managed function.
if hasattr(self.func, "output_dtype"):
return self.func.output_dtype

raise AttributeError("output_dtype not defined")


@dataclasses.dataclass(frozen=True)
Expand All @@ -63,8 +77,15 @@ def expensive(self) -> bool:
return True

def output_type(self, *input_types):
# This property should be set to a valid Dtype by the @remote_function decorator or read_gbq_function method
# The output dtype should be set to a valid Dtype by @udf decorator,
# @remote_function decorator, or read_gbq_function method.

# This is for remote function.
if hasattr(self.func, "bigframes_bigquery_function_output_dtype"):
return self.func.bigframes_bigquery_function_output_dtype
else:
raise AttributeError("bigframes_bigquery_function_output_dtype not defined")

# This is for managed function.
if hasattr(self.func, "output_dtype"):
return self.func.output_dtype

raise AttributeError("output_dtype not defined")
156 changes: 156 additions & 0 deletions tests/system/large/functions/test_managed_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
# limitations under the License.

import pandas
import pyarrow
import pytest

import bigframes
from bigframes.functions import _function_session as bff_session
from bigframes.functions._utils import get_python_version
import bigframes.pandas as bpd
Expand Down Expand Up @@ -164,3 +166,157 @@ def func(x, y):
cleanup_function_assets(
session.bqclient, session.cloudfunctionsclient, managed_func
)


@pytest.mark.parametrize(
    "array_dtype",
    [
        bool,
        int,
        float,
        str,
    ],
)
@pytest.mark.skipif(
    get_python_version() not in bff_session._MANAGED_FUNC_PYTHON_VERSIONS,
    reason=f"Supported version: {bff_session._MANAGED_FUNC_PYTHON_VERSIONS}",
)
def test_managed_function_array_output(session, scalars_dfs, dataset_id, array_dtype):
    """A managed function returning ``list[array_dtype]`` matches pandas apply."""
    try:

        @session.udf(dataset=dataset_id)
        def featurize(x: int) -> list[array_dtype]:  # type: ignore
            return [array_dtype(i) for i in [x, x + 1, x + 2]]

        bf_df, pd_df = scalars_dfs

        # Compute the expected result locally with plain pandas first.
        expected = pd_df["int64_too"].apply(featurize)

        # Then run the same function as a managed BigQuery function.
        actual = bf_df["int64_too"].apply(featurize).to_pandas()

        # Ignore any dtype disparity.
        pandas.testing.assert_series_equal(expected, actual, check_dtype=False)

    finally:
        # Clean up the gcp assets created for the managed function.
        cleanup_function_assets(
            featurize, session.bqclient, session.cloudfunctionsclient
        )


@pytest.mark.skipif(
    get_python_version() not in bff_session._MANAGED_FUNC_PYTHON_VERSIONS,
    reason=f"Supported version: {bff_session._MANAGED_FUNC_PYTHON_VERSIONS}",
)
def test_managed_function_binop_array_output(session, scalars_dfs, dataset_id):
    """A binary managed function with list output works via Series.combine."""
    try:

        def func(x, y):
            return [len(x), abs(y % 4)]

        managed_func = session.udf(
            input_types=[str, int],
            output_type=list[int],
            dataset=dataset_id,
        )(func)

        scalars_df, scalars_pandas_df = scalars_dfs

        # Drop NULL rows so the local and remote inputs agree.
        scalars_df = scalars_df.dropna()
        scalars_pandas_df = scalars_pandas_df.dropna()

        pd_result = scalars_pandas_df["string_col"].combine(
            scalars_pandas_df["int64_col"], func
        )
        bf_result = (
            scalars_df["string_col"]
            .combine(scalars_df["int64_col"], managed_func)
            .to_pandas()
        )

        pandas.testing.assert_series_equal(bf_result, pd_result, check_dtype=False)
    finally:
        # Clean up the gcp assets created for the managed function.
        cleanup_function_assets(
            managed_func, session.bqclient, session.cloudfunctionsclient
        )


@pytest.mark.skipif(
    get_python_version() not in bff_session._MANAGED_FUNC_PYTHON_VERSIONS,
    reason=f"Supported version: {bff_session._MANAGED_FUNC_PYTHON_VERSIONS}",
)
def test_manage_function_df_apply_axis_1_array_output(session):
    """DataFrame.apply(axis=1) with a managed function that returns a list.

    Verifies argument-count and dtype validation errors, the function's
    attached metadata (input/output dtypes), and the row-wise result.
    """
    # NOTE(review): "manage" in the test name looks like a typo for
    # "managed"; renaming would change the collected pytest id, so it is
    # only flagged here.
    bf_df = bigframes.dataframe.DataFrame(
        {
            "Id": [1, 2, 3],
            "Age": [22.5, 23, 23.5],
            "Name": ["alpha", "beta", "gamma"],
        }
    )

    # The column dtypes the managed function below is declared against.
    expected_dtypes = (
        bigframes.dtypes.INT_DTYPE,
        bigframes.dtypes.FLOAT_DTYPE,
        bigframes.dtypes.STRING_DTYPE,
    )

    # Assert the dataframe dtypes.
    assert tuple(bf_df.dtypes) == expected_dtypes

    try:

        @session.udf(input_types=[int, float, str], output_type=list[str])
        def foo(x, y, z):
            return [str(x), str(y), z]

        # Metadata attached by the @udf decorator: per-column (not
        # row-processor) semantics, declared input dtypes, and a
        # list<string> Arrow output dtype.
        assert getattr(foo, "is_row_processor") is False
        assert getattr(foo, "input_dtypes") == expected_dtypes
        assert getattr(foo, "output_dtype") == pandas.ArrowDtype(
            pyarrow.list_(
                bigframes.dtypes.bigframes_dtype_to_arrow_dtype(
                    bigframes.dtypes.STRING_DTYPE
                )
            )
        )

        # Fails to apply on dataframe with incompatible number of columns.
        with pytest.raises(
            ValueError,
            match="^Bigframes bigquery function takes 3 arguments but DataFrame has 2 columns\\.$",
        ):
            bf_df[["Id", "Age"]].apply(foo, axis=1)
        with pytest.raises(
            ValueError,
            match="^Bigframes bigquery function takes 3 arguments but DataFrame has 4 columns\\.$",
        ):
            bf_df.assign(Country="lalaland").apply(foo, axis=1)

        # Fails to apply on dataframe with incompatible column datatypes.
        with pytest.raises(
            ValueError,
            match="^Bigframes bigquery function takes arguments of types .* but DataFrame dtypes are .*",
        ):
            bf_df.assign(Age=bf_df["Age"].astype("Int64")).apply(foo, axis=1)

        # Successfully applies to dataframe with matching number of columns
        # and their datatypes.
        bf_result = bf_df.apply(foo, axis=1).to_pandas()

        # Since this scenario is not pandas-like, let's handcraft the
        # expected result.
        expected_result = pandas.Series(
            [
                ["1", "22.5", "alpha"],
                ["2", "23", "beta"],
                ["3", "23.5", "gamma"],
            ]
        )

        pandas.testing.assert_series_equal(
            expected_result, bf_result, check_dtype=False, check_index_type=False
        )

    finally:
        # Clean up the gcp assets created for the managed function.
        cleanup_function_assets(foo, session.bqclient, session.cloudfunctionsclient)
18 changes: 9 additions & 9 deletions tests/system/large/functions/test_remote_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -2085,19 +2085,19 @@ def foo(x, y, z):
# Fails to apply on dataframe with incompatible number of columns
with pytest.raises(
ValueError,
match="^Remote function takes 3 arguments but DataFrame has 2 columns\\.$",
match="^Bigframes bigquery function takes 3 arguments but DataFrame has 2 columns\\.$",
):
bf_df[["Id", "Age"]].apply(foo, axis=1)
with pytest.raises(
ValueError,
match="^Remote function takes 3 arguments but DataFrame has 4 columns\\.$",
match="^Bigframes bigquery function takes 3 arguments but DataFrame has 4 columns\\.$",
):
bf_df.assign(Country="lalaland").apply(foo, axis=1)

# Fails to apply on dataframe with incompatible column datatypes
with pytest.raises(
ValueError,
match="^Remote function takes arguments of types .* but DataFrame dtypes are .*",
match="^Bigframes bigquery function takes arguments of types .* but DataFrame dtypes are .*",
):
bf_df.assign(Age=bf_df["Age"].astype("Int64")).apply(foo, axis=1)

Expand Down Expand Up @@ -2171,19 +2171,19 @@ def foo(x, y, z):
# Fails to apply on dataframe with incompatible number of columns
with pytest.raises(
ValueError,
match="^Remote function takes 3 arguments but DataFrame has 2 columns\\.$",
match="^Bigframes bigquery function takes 3 arguments but DataFrame has 2 columns\\.$",
):
bf_df[["Id", "Age"]].apply(foo, axis=1)
with pytest.raises(
ValueError,
match="^Remote function takes 3 arguments but DataFrame has 4 columns\\.$",
match="^Bigframes bigquery function takes 3 arguments but DataFrame has 4 columns\\.$",
):
bf_df.assign(Country="lalaland").apply(foo, axis=1)

# Fails to apply on dataframe with incompatible column datatypes
with pytest.raises(
ValueError,
match="^Remote function takes arguments of types .* but DataFrame dtypes are .*",
match="^Bigframes bigquery function takes arguments of types .* but DataFrame dtypes are .*",
):
bf_df.assign(Age=bf_df["Age"].astype("Int64")).apply(foo, axis=1)

Expand Down Expand Up @@ -2240,19 +2240,19 @@ def foo(x):
# Fails to apply on dataframe with incompatible number of columns
with pytest.raises(
ValueError,
match="^Remote function takes 1 arguments but DataFrame has 0 columns\\.$",
match="^Bigframes bigquery function takes 1 arguments but DataFrame has 0 columns\\.$",
):
bf_df[[]].apply(foo, axis=1)
with pytest.raises(
ValueError,
match="^Remote function takes 1 arguments but DataFrame has 2 columns\\.$",
match="^Bigframes bigquery function takes 1 arguments but DataFrame has 2 columns\\.$",
):
bf_df.assign(Country="lalaland").apply(foo, axis=1)

# Fails to apply on dataframe with incompatible column datatypes
with pytest.raises(
ValueError,
match="^Remote function takes arguments of types .* but DataFrame dtypes are .*",
match="^Bigframes bigquery function takes arguments of types .* but DataFrame dtypes are .*",
):
bf_df.assign(Id=bf_df["Id"].astype("Float64")).apply(foo, axis=1)

Expand Down
Loading