Skip to content

Commit 7600001

Browse files
authored
feat: support string literal inputs for AI functions (#2152)
* feat: support string literal inputs for AI functions
* polish code
* update pydoc
1 parent 1f434fb commit 7600001

File tree

2 files changed

+35
-8
lines changed

2 files changed

+35
-8
lines changed

bigframes/bigquery/_operations/ai.py

Lines changed: 12 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,7 @@
2828
from bigframes.operations import ai_ops, output_schemas
2929

3030
PROMPT_TYPE = Union[
31+
str,
3132
series.Series,
3233
pd.Series,
3334
List[Union[str, series.Series, pd.Series]],
@@ -73,7 +74,7 @@ def generate(
7374
dtype: struct<is_herbivore: bool, number_of_legs: int64, full_response: extension<dbjson<JSONArrowType>>, status: string>[pyarrow]
7475
7576
Args:
76-
prompt (Series | List[str|Series] | Tuple[str|Series, ...]):
77+
prompt (str | Series | List[str|Series] | Tuple[str|Series, ...]):
7778
A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series
7879
or pandas Series.
7980
connection_id (str, optional):
@@ -165,7 +166,7 @@ def generate_bool(
165166
Name: result, dtype: boolean
166167
167168
Args:
168-
prompt (Series | List[str|Series] | Tuple[str|Series, ...]):
169+
prompt (str | Series | List[str|Series] | Tuple[str|Series, ...]):
169170
A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series
170171
or pandas Series.
171172
connection_id (str, optional):
@@ -240,7 +241,7 @@ def generate_int(
240241
Name: result, dtype: Int64
241242
242243
Args:
243-
prompt (Series | List[str|Series] | Tuple[str|Series, ...]):
244+
prompt (str | Series | List[str|Series] | Tuple[str|Series, ...]):
244245
A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series
245246
or pandas Series.
246247
connection_id (str, optional):
@@ -315,7 +316,7 @@ def generate_double(
315316
Name: result, dtype: Float64
316317
317318
Args:
318-
prompt (Series | List[str|Series] | Tuple[str|Series, ...]):
319+
prompt (str | Series | List[str|Series] | Tuple[str|Series, ...]):
319320
A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series
320321
or pandas Series.
321322
connection_id (str, optional):
@@ -386,7 +387,7 @@ def if_(
386387
dtype: string
387388
388389
Args:
389-
prompt (Series | List[str|Series] | Tuple[str|Series, ...]):
390+
prompt (str | Series | List[str|Series] | Tuple[str|Series, ...]):
390391
A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series
391392
or pandas Series.
392393
connection_id (str, optional):
@@ -433,7 +434,7 @@ def classify(
433434
[2 rows x 2 columns]
434435
435436
Args:
436-
input (Series | List[str|Series] | Tuple[str|Series, ...]):
437+
input (str | Series | List[str|Series] | Tuple[str|Series, ...]):
437438
A mixture of Series and string literals that specifies the input to send to the model. The Series can be BigFrames Series
438439
or pandas Series.
439440
categories (tuple[str, ...] | list[str]):
@@ -482,7 +483,7 @@ def score(
482483
dtype: Float64
483484
484485
Args:
485-
prompt (Series | List[str|Series] | Tuple[str|Series, ...]):
486+
prompt (str | Series | List[str|Series] | Tuple[str|Series, ...]):
486487
A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series
487488
or pandas Series.
488489
connection_id (str, optional):
@@ -514,9 +515,12 @@ def _separate_context_and_series(
514515
Input: ("str1", series1, "str2", "str3", series2)
515516
Output: ["str1", None, "str2", "str3", None], [series1, series2]
516517
"""
517-
if not isinstance(prompt, (list, tuple, series.Series)):
518+
if not isinstance(prompt, (str, list, tuple, series.Series)):
518519
raise ValueError(f"Unsupported prompt type: {type(prompt)}")
519520

521+
if isinstance(prompt, str):
522+
return [None], [series.Series([prompt])]
523+
520524
if isinstance(prompt, series.Series):
521525
if prompt.dtype == dtypes.OBJ_REF_DTYPE:
522526
# Multi-modal support

tests/system/small/bigquery/test_ai.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from unittest import mock
16+
1517
from packaging import version
1618
import pandas as pd
1719
import pyarrow as pa
@@ -42,6 +44,27 @@ def test_ai_function_pandas_input(session):
4244
)
4345

4446

47+
def test_ai_function_string_input(session):
48+
with mock.patch(
49+
"bigframes.core.global_session.get_global_session"
50+
) as mock_get_session:
51+
mock_get_session.return_value = session
52+
prompt = "Is apple a fruit?"
53+
54+
result = bbq.ai.generate_bool(prompt, endpoint="gemini-2.5-flash")
55+
56+
assert _contains_no_nulls(result)
57+
assert result.dtype == pd.ArrowDtype(
58+
pa.struct(
59+
(
60+
pa.field("result", pa.bool_()),
61+
pa.field("full_response", dtypes.JSON_ARROW_TYPE),
62+
pa.field("status", pa.string()),
63+
)
64+
)
65+
)
66+
67+
4568
def test_ai_function_compile_model_params(session):
4669
if version.Version(sqlglot.__version__) < version.Version("25.18.0"):
4770
pytest.skip(

0 commit comments

Comments
 (0)