diff --git a/bigframes/series.py b/bigframes/series.py index da2f3f07c4..7196341c5d 100644 --- a/bigframes/series.py +++ b/bigframes/series.py @@ -25,6 +25,7 @@ import typing from typing import ( Any, + Callable, cast, Iterable, List, @@ -2330,7 +2331,7 @@ def _throw_if_index_contains_duplicates( def map( self, - arg: typing.Union[Mapping, Series], + arg: typing.Union[Mapping, Series, Callable], na_action: Optional[str] = None, *, verify_integrity: bool = False, @@ -2352,6 +2353,7 @@ def map( ) map_df = map_df.set_index("keys") elif callable(arg): + # This is for remote function and managed funtion. return self.apply(arg) else: # Mirroring pandas, call the uncallable object diff --git a/tests/system/large/functions/test_managed_function.py b/tests/system/large/functions/test_managed_function.py index dd08ed17d9..e74bc8579f 100644 --- a/tests/system/large/functions/test_managed_function.py +++ b/tests/system/large/functions/test_managed_function.py @@ -1245,7 +1245,7 @@ def the_sum(s): cleanup_function_assets(the_sum_mf, session.bqclient, ignore_failures=False) -def test_managed_function_series_where_mask(session, dataset_id, scalars_dfs): +def test_managed_function_series_where_mask_map(session, dataset_id, scalars_dfs): try: # The return type has to be bool type for callable where condition. @@ -1286,6 +1286,13 @@ def _is_positive(s): # Ignore any dtype difference. pandas.testing.assert_series_equal(bf_result, pd_result, check_dtype=False) + # Test series.map method. + bf_result = bf_int64_filtered.map(is_positive_mf).to_pandas() + pd_result = pd_int64_filtered.map(_is_positive) + + # Ignore any dtype difference. + pandas.testing.assert_series_equal(bf_result, pd_result, check_dtype=False) + finally: # Clean up the gcp assets created for the managed function. cleanup_function_assets(is_positive_mf, session.bqclient, ignore_failures=False)