Skip to content

Commit 70d6562

Browse files
sycaitswast
andauthored
feat: Add ai_generate_bool to the bigframes.bigquery package (#2060)
* feat: Add ai_generate_bool to the bigframes.bigquery package * fix stuffs * Fix format * fix doc format * fix format * fix code * expose ai module and rename the function * add ai module to doc * fix test * fix test * Update bigframes/bigquery/_operations/ai.py Co-authored-by: Tim Sweña (Swast) <[email protected]> --------- Co-authored-by: Tim Sweña (Swast) <[email protected]>
1 parent 3b46a0d commit 70d6562

File tree

13 files changed

+420
-2
lines changed

13 files changed

+420
-2
lines changed

bigframes/bigquery/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import sys
2020

21+
from bigframes.bigquery._operations import ai
2122
from bigframes.bigquery._operations.approx_agg import approx_top_count
2223
from bigframes.bigquery._operations.array import (
2324
array_agg,
@@ -98,7 +99,7 @@
9899
struct,
99100
]
100101

101-
__all__ = [f.__name__ for f in _functions]
102+
__all__ = [f.__name__ for f in _functions] + ["ai"]
102103

103104
_module = sys.modules[__name__]
104105
for f in _functions:
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""This module integrates BigQuery built-in AI functions for use with Series/DataFrame objects,
16+
such as AI.GENERATE_BOOL:
17+
https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-ai-generate-bool"""
18+
19+
from __future__ import annotations
20+
21+
import json
22+
from typing import Any, List, Literal, Mapping, Tuple
23+
24+
from bigframes import clients, dtypes, series
25+
from bigframes.core import log_adapter
26+
from bigframes.operations import ai_ops
27+
28+
29+
@log_adapter.method_logger(custom_base_name="bigquery_ai")
30+
def generate_bool(
31+
prompt: series.Series | List[str | series.Series] | Tuple[str | series.Series, ...],
32+
*,
33+
connection_id: str | None = None,
34+
endpoint: str | None = None,
35+
request_type: Literal["dedicated", "shared", "unspecified"] = "unspecified",
36+
model_params: Mapping[Any, Any] | None = None,
37+
) -> series.Series:
38+
"""
39+
Returns the AI analysis based on the prompt, which can be any combination of text and unstructured data.
40+
41+
**Examples:**
42+
43+
>>> import bigframes.pandas as bpd
44+
>>> import bigframes.bigquery as bbq
45+
>>> bpd.options.display.progress_bar = None
46+
>>> df = bpd.DataFrame({
47+
... "col_1": ["apple", "bear", "pear"],
48+
... "col_2": ["fruit", "animal", "animal"]
49+
... })
50+
>>> bbq.ai_generate_bool((df["col_1"], " is a ", df["col_2"]))
51+
0 {'result': True, 'full_response': '{"candidate...
52+
1 {'result': True, 'full_response': '{"candidate...
53+
2 {'result': False, 'full_response': '{"candidat...
54+
dtype: struct<result: bool, full_response: string, status: string>[pyarrow]
55+
56+
>>> bbq.ai_generate_bool((df["col_1"], " is a ", df["col_2"])).struct.field("result")
57+
0 True
58+
1 True
59+
2 False
60+
Name: result, dtype: boolean
61+
62+
>>> model_params = {
63+
... "generation_config": {
64+
... "thinking_config": {
65+
... "thinking_budget": 0
66+
... }
67+
... }
68+
... }
69+
>>> bbq.ai_generate_bool(
70+
... (df["col_1"], " is a ", df["col_2"]),
71+
... endpoint="gemini-2.5-pro",
72+
... model_params=model_params,
73+
... ).struct.field("result")
74+
0 True
75+
1 True
76+
2 False
77+
Name: result, dtype: boolean
78+
79+
Args:
80+
prompt (series.Series | List[str|series.Series] | Tuple[str|series.Series, ...]):
81+
A mixture of Series and string literals that specifies the prompt to send to the model.
82+
connection_id (str, optional):
83+
Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`.
84+
If not provided, the connection from the current session will be used.
85+
endpoint (str, optional):
86+
Specifies the Vertex AI endpoint to use for the model. For example `"gemini-2.5-flash"`. You can specify any
87+
generally available or preview Gemini model. If you specify the model name, BigQuery ML automatically identifies and
88+
uses the full endpoint of the model. If you don't specify an ENDPOINT value, BigQuery ML selects a recent stable
89+
version of Gemini to use.
90+
request_type (Literal["dedicated", "shared", "unspecified"]):
91+
Specifies the type of inference request to send to the Gemini model. The request type determines what quota the request uses.
92+
* "dedicated": function only uses Provisioned Throughput quota. The function returns the error Provisioned throughput is not
93+
purchased or is not active if Provisioned Throughput quota isn't available.
94+
* "shared": the function only uses dynamic shared quota (DSQ), even if you have purchased Provisioned Throughput quota.
95+
* "unspecified": If you haven't purchased Provisioned Throughput quota, the function uses DSQ quota.
96+
If you have purchased Provisioned Throughput quota, the function uses the Provisioned Throughput quota first.
97+
If requests exceed the Provisioned Throughput quota, the overflow traffic uses DSQ quota.
98+
model_params (Mapping[Any, Any]):
99+
Provides additional parameters to the model. The MODEL_PARAMS value must conform to the generateContent request body format.
100+
101+
Returns:
102+
bigframes.series.Series: A new struct Series with the result data. The struct contains these fields:
103+
* "result": a BOOL value containing the model's response to the prompt. The result is None if the request fails or is filtered by responsible AI.
104+
* "full_response": a STRING value containing the JSON response from the projects.locations.endpoints.generateContent call to the model.
105+
The generated text is in the text element.
106+
* "status": a STRING value that contains the API response status for the corresponding row. This value is empty if the operation was successful.
107+
"""
108+
109+
prompt_context, series_list = _separate_context_and_series(prompt)
110+
assert len(series_list) > 0
111+
112+
operator = ai_ops.AIGenerateBool(
113+
prompt_context=tuple(prompt_context),
114+
connection_id=_resolve_connection_id(series_list[0], connection_id),
115+
endpoint=endpoint,
116+
request_type=request_type,
117+
model_params=json.dumps(model_params) if model_params else None,
118+
)
119+
120+
return series_list[0]._apply_nary_op(operator, series_list[1:])
121+
122+
123+
def _separate_context_and_series(
124+
prompt: series.Series | List[str | series.Series] | Tuple[str | series.Series, ...],
125+
) -> Tuple[List[str | None], List[series.Series]]:
126+
"""
127+
Returns the two values. The first value is the prompt with all series replaced by None. The second value is all the series
128+
in the prompt. The original item order is kept.
129+
For example:
130+
Input: ("str1", series1, "str2", "str3", series2)
131+
Output: ["str1", None, "str2", "str3", None], [series1, series2]
132+
"""
133+
if not isinstance(prompt, (list, tuple, series.Series)):
134+
raise ValueError(f"Unsupported prompt type: {type(prompt)}")
135+
136+
if isinstance(prompt, series.Series):
137+
if prompt.dtype == dtypes.OBJ_REF_DTYPE:
138+
# Multi-model support
139+
return [None], [prompt.blob.read_url()]
140+
return [None], [prompt]
141+
142+
prompt_context: List[str | None] = []
143+
series_list: List[series.Series] = []
144+
145+
for item in prompt:
146+
if isinstance(item, str):
147+
prompt_context.append(item)
148+
149+
elif isinstance(item, series.Series):
150+
prompt_context.append(None)
151+
152+
if item.dtype == dtypes.OBJ_REF_DTYPE:
153+
# Multi-model support
154+
item = item.blob.read_url()
155+
series_list.append(item)
156+
157+
else:
158+
raise TypeError(f"Unsupported type in prompt: {type(item)}")
159+
160+
if not series_list:
161+
raise ValueError("Please provide at least one Series in the prompt")
162+
163+
return prompt_context, series_list
164+
165+
166+
def _resolve_connection_id(series: series.Series, connection_id: str | None):
167+
return clients.get_canonical_bq_connection_id(
168+
connection_id or series._session._bq_connection,
169+
series._session._project,
170+
series._session._location,
171+
)

bigframes/core/compile/ibis_compiler/scalar_op_registry.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717
import functools
1818
import typing
1919

20+
from bigframes_vendored import ibis
2021
import bigframes_vendored.ibis.expr.api as ibis_api
2122
import bigframes_vendored.ibis.expr.datatypes as ibis_dtypes
23+
import bigframes_vendored.ibis.expr.operations.ai_ops as ai_ops
2224
import bigframes_vendored.ibis.expr.operations.generic as ibis_generic
2325
import bigframes_vendored.ibis.expr.operations.udf as ibis_udf
2426
import bigframes_vendored.ibis.expr.types as ibis_types
@@ -1963,6 +1965,30 @@ def struct_op_impl(
19631965
return ibis_types.struct(data)
19641966

19651967

1968+
@scalar_op_compiler.register_nary_op(ops.AIGenerateBool, pass_op=True)
1969+
def ai_generate_bool(
1970+
*values: ibis_types.Value, op: ops.AIGenerateBool
1971+
) -> ibis_types.StructValue:
1972+
1973+
prompt: dict[str, ibis_types.Value | str] = {}
1974+
column_ref_idx = 0
1975+
1976+
for idx, elem in enumerate(op.prompt_context):
1977+
if elem is None:
1978+
prompt[f"_field_{idx + 1}"] = values[column_ref_idx]
1979+
column_ref_idx += 1
1980+
else:
1981+
prompt[f"_field_{idx + 1}"] = elem
1982+
1983+
return ai_ops.AIGenerateBool(
1984+
ibis.struct(prompt), # type: ignore
1985+
op.connection_id, # type: ignore
1986+
op.endpoint, # type: ignore
1987+
op.request_type.upper(), # type: ignore
1988+
op.model_params, # type: ignore
1989+
).to_expr()
1990+
1991+
19661992
@scalar_op_compiler.register_nary_op(ops.RowKey, pass_op=True)
19671993
def rowkey_op_impl(*values: ibis_types.Value, op: ops.RowKey) -> ibis_types.Value:
19681994
return bigframes.core.compile.default_ordering.gen_row_key(values)

bigframes/operations/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import annotations
1616

17+
from bigframes.operations.ai_ops import AIGenerateBool
1718
from bigframes.operations.array_ops import (
1819
ArrayIndexOp,
1920
ArrayReduceOp,
@@ -408,6 +409,8 @@
408409
"geo_x_op",
409410
"geo_y_op",
410411
"GeoStDistanceOp",
412+
# AI ops
413+
"AIGenerateBool",
411414
# Numpy ops mapping
412415
"NUMPY_TO_BINOP",
413416
"NUMPY_TO_OP",

bigframes/operations/ai_ops.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import dataclasses
18+
from typing import ClassVar, Literal, Tuple
19+
20+
import pandas as pd
21+
import pyarrow as pa
22+
23+
from bigframes import dtypes
24+
from bigframes.operations import base_ops
25+
26+
27+
@dataclasses.dataclass(frozen=True)
28+
class AIGenerateBool(base_ops.NaryOp):
29+
name: ClassVar[str] = "ai_generate_bool"
30+
31+
# None are the placeholders for column references.
32+
prompt_context: Tuple[str | None, ...]
33+
connection_id: str
34+
endpoint: str | None
35+
request_type: Literal["dedicated", "shared", "unspecified"]
36+
model_params: str | None
37+
38+
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
39+
return pd.ArrowDtype(
40+
pa.struct(
41+
(
42+
pa.field("result", pa.bool_()),
43+
pa.field("full_response", pa.string()),
44+
pa.field("status", pa.string()),
45+
)
46+
)
47+
)
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
bigframes.bigquery.ai
2+
=============================
3+
4+
.. automodule:: bigframes.bigquery._operations.ai
5+
:members:
6+
:inherited-members:
7+
:undoc-members:

docs/reference/bigframes.bigquery/index.rst

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,9 @@ BigQuery Built-in Functions
55

66
.. automodule:: bigframes.bigquery
77
:members:
8-
:inherited-members:
98
:undoc-members:
9+
10+
.. toctree::
11+
:maxdepth: 2
12+
13+
ai

docs/templates/toc.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,8 @@
218218
- items:
219219
- name: BigQuery built-in functions
220220
uid: bigframes.bigquery
221+
- name: BigQuery AI Functions
222+
uid: bigframes.bigquery.ai
221223
name: bigframes.bigquery
222224
- items:
223225
- name: GeoSeries
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pandas as pd
16+
import pandas.testing
17+
18+
import bigframes.bigquery as bbq
19+
20+
21+
def test_ai_generate_bool_multi_model(session):
22+
df = session.from_glob_path(
23+
"gs://bigframes-dev-testing/a_multimodel/images/*", name="image"
24+
)
25+
26+
result = bbq.ai.generate_bool((df["image"], " contains an animal")).struct.field(
27+
"result"
28+
)
29+
30+
pandas.testing.assert_series_equal(
31+
result.to_pandas(),
32+
pd.Series([True, True, False, False, False], name="result"),
33+
check_dtype=False,
34+
check_index=False,
35+
)

0 commit comments

Comments
 (0)