Skip to content

Commit ef0b0b7

Browse files
feat: add output_schema parameter to ai.generate() (#2139)
* feat: add output_schema to ai.generate() * fix lint * fix lint * fix test * fix mypy * fix lint * code optimization * fix tests * support case-insensitive type parsing * fix test * fix: Fix row count local execution bug (#2133) * fix: join on, how args are now positional (#2140) --------- Co-authored-by: TrevorBergeron <[email protected]>
1 parent fa4e46f commit ef0b0b7

File tree

10 files changed

+328
-12
lines changed

10 files changed

+328
-12
lines changed

bigframes/bigquery/_operations/ai.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
from bigframes import clients, dtypes, series, session
2727
from bigframes.core import convert, log_adapter
28-
from bigframes.operations import ai_ops
28+
from bigframes.operations import ai_ops, output_schemas
2929

3030
PROMPT_TYPE = Union[
3131
series.Series,
@@ -43,7 +43,7 @@ def generate(
4343
endpoint: str | None = None,
4444
request_type: Literal["dedicated", "shared", "unspecified"] = "unspecified",
4545
model_params: Mapping[Any, Any] | None = None,
46-
# TODO(b/446974666) Add output_schema parameter
46+
output_schema: Mapping[str, str] | None = None,
4747
) -> series.Series:
4848
"""
4949
Returns the AI analysis based on the prompt, which can be any combination of text and unstructured data.
@@ -64,6 +64,14 @@ def generate(
6464
1 Ottawa\\n
6565
Name: result, dtype: string
6666
67+
You get structured output when the `output_schema` parameter is set:
68+
69+
>>> animals = bpd.Series(["Rabbit", "Spider"])
70+
>>> bbq.ai.generate(animals, output_schema={"number_of_legs": "INT64", "is_herbivore": "BOOL"})
71+
0 {'is_herbivore': True, 'number_of_legs': 4, 'f...
72+
1 {'is_herbivore': False, 'number_of_legs': 8, '...
73+
dtype: struct<is_herbivore: bool, number_of_legs: int64, full_response: extension<dbjson<JSONArrowType>>, status: string>[pyarrow]
74+
6775
Args:
6876
prompt (Series | List[str|Series] | Tuple[str|Series, ...]):
6977
A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series
@@ -86,10 +94,14 @@ def generate(
8694
If requests exceed the Provisioned Throughput quota, the overflow traffic uses DSQ quota.
8795
model_params (Mapping[Any, Any]):
8896
Provides additional parameters to the model. The MODEL_PARAMS value must conform to the generateContent request body format.
97+
output_schema (Mapping[str, str]):
98+
A mapping value that specifies the schema of the output, in the form {field_name: data_type}. Supported data types include
99+
`STRING`, `INT64`, `FLOAT64`, `BOOL`, `ARRAY`, and `STRUCT`.
89100
90101
Returns:
91102
bigframes.series.Series: A new struct Series with the result data. The struct contains these fields:
92103
* "result": a STRING value containing the model's response to the prompt. The result is None if the request fails or is filtered by responsible AI.
104+
If you specify an output schema then result is replaced by your custom schema.
93105
* "full_response": a JSON value containing the response from the projects.locations.endpoints.generateContent call to the model.
94106
The generated text is in the text element.
95107
* "status": a STRING value that contains the API response status for the corresponding row. This value is empty if the operation was successful.
@@ -98,12 +110,22 @@ def generate(
98110
prompt_context, series_list = _separate_context_and_series(prompt)
99111
assert len(series_list) > 0
100112

113+
if output_schema is None:
114+
output_schema_str = None
115+
else:
116+
output_schema_str = ", ".join(
117+
[f"{name} {sql_type}" for name, sql_type in output_schema.items()]
118+
)
119+
# Validate user input
120+
output_schemas.parse_sql_fields(output_schema_str)
121+
101122
operator = ai_ops.AIGenerate(
102123
prompt_context=tuple(prompt_context),
103124
connection_id=_resolve_connection_id(series_list[0], connection_id),
104125
endpoint=endpoint,
105126
request_type=request_type,
106127
model_params=json.dumps(model_params) if model_params else None,
128+
output_schema=output_schema_str,
107129
)
108130

109131
return series_list[0]._apply_nary_op(operator, series_list[1:])

bigframes/core/compile/ibis_compiler/scalar_op_registry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1985,6 +1985,7 @@ def ai_generate(
19851985
op.endpoint, # type: ignore
19861986
op.request_type.upper(), # type: ignore
19871987
op.model_params, # type: ignore
1988+
op.output_schema, # type: ignore
19881989
).to_expr()
19891990

19901991

bigframes/core/compile/sqlglot/expressions/ai_ops.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from __future__ import annotations
1616

1717
from dataclasses import asdict
18-
import typing
1918

2019
import sqlglot.expressions as sge
2120

@@ -105,24 +104,24 @@ def _construct_named_args(op: ops.NaryOp) -> list[sge.Kwarg]:
105104

106105
op_args = asdict(op)
107106

108-
connection_id = typing.cast(str, op_args["connection_id"])
107+
connection_id = op_args["connection_id"]
109108
args.append(
110109
sge.Kwarg(this="connection_id", expression=sge.Literal.string(connection_id))
111110
)
112111

113-
endpoit = typing.cast(str, op_args.get("endpoint", None))
112+
endpoit = op_args.get("endpoint", None)
114113
if endpoit is not None:
115114
args.append(sge.Kwarg(this="endpoint", expression=sge.Literal.string(endpoit)))
116115

117-
request_type = typing.cast(str, op_args.get("request_type", None))
116+
request_type = op_args.get("request_type", None)
118117
if request_type is not None:
119118
args.append(
120119
sge.Kwarg(
121120
this="request_type", expression=sge.Literal.string(request_type.upper())
122121
)
123122
)
124123

125-
model_params = typing.cast(str, op_args.get("model_params", None))
124+
model_params = op_args.get("model_params", None)
126125
if model_params is not None:
127126
args.append(
128127
sge.Kwarg(
@@ -133,4 +132,13 @@ def _construct_named_args(op: ops.NaryOp) -> list[sge.Kwarg]:
133132
)
134133
)
135134

135+
output_schema = op_args.get("output_schema", None)
136+
if output_schema is not None:
137+
args.append(
138+
sge.Kwarg(
139+
this="output_schema",
140+
expression=sge.Literal.string(output_schema),
141+
)
142+
)
143+
136144
return args

bigframes/operations/ai_ops.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import pyarrow as pa
2222

2323
from bigframes import dtypes
24-
from bigframes.operations import base_ops
24+
from bigframes.operations import base_ops, output_schemas
2525

2626

2727
@dataclasses.dataclass(frozen=True)
@@ -33,12 +33,18 @@ class AIGenerate(base_ops.NaryOp):
3333
endpoint: str | None
3434
request_type: Literal["dedicated", "shared", "unspecified"]
3535
model_params: str | None
36+
output_schema: str | None
3637

3738
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
39+
if self.output_schema is None:
40+
output_fields = (pa.field("result", pa.string()),)
41+
else:
42+
output_fields = output_schemas.parse_sql_fields(self.output_schema)
43+
3844
return pd.ArrowDtype(
3945
pa.struct(
4046
(
41-
pa.field("result", pa.string()),
47+
*output_fields,
4248
pa.field("full_response", dtypes.JSON_ARROW_TYPE),
4349
pa.field("status", pa.string()),
4450
)
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pyarrow as pa
16+
17+
18+
def parse_sql_type(sql: str) -> pa.DataType:
    """
    Parses a SQL type string to its PyArrow equivalent.

    Matching is case-insensitive for the type keywords, while struct field
    names keep their original casing.

    For example:
        "STRING" -> pa.string()
        "ARRAY<INT64>" -> pa.list_(pa.int64())
        "STRUCT<x ARRAY<FLOAT64>, y BOOL>" -> pa.struct(
            (
                pa.field("x", pa.list_(pa.float64())),
                pa.field("y", pa.bool_()),
            )
        )

    Raises:
        ValueError: if the SQL type is not supported.
    """
    sql = sql.strip()
    # Hoist the case-normalized form so it is computed once, not per branch.
    upper_sql = sql.upper()

    # Scalar types resolve with a single dict lookup.
    scalar_types = {
        "STRING": pa.string(),
        "INT64": pa.int64(),
        "FLOAT64": pa.float64(),
        "BOOL": pa.bool_(),
    }
    if upper_sql in scalar_types:
        return scalar_types[upper_sql]

    if upper_sql.startswith("ARRAY<") and sql.endswith(">"):
        # Slice with the original string so the element type keeps its casing.
        inner_type = sql[len("ARRAY<") : -1]
        return pa.list_(parse_sql_type(inner_type))

    if upper_sql.startswith("STRUCT<") and sql.endswith(">"):
        inner_fields = parse_sql_fields(sql[len("STRUCT<") : -1])
        return pa.struct(inner_fields)

    raise ValueError(f"Unsupported SQL type: {sql}")
55+
56+
57+
def parse_sql_fields(sql: str) -> tuple[pa.Field, ...]:
    """
    Parses a comma-separated list of "name TYPE" SQL field declarations into
    PyArrow fields, sorted by field name for a deterministic output order.

    Commas nested inside ARRAY<...> or STRUCT<...> type parameters are not
    treated as field separators.

    Raises:
        ValueError: if any field declaration is invalid.
    """
    sql = sql.strip()

    start_idx = 0
    nested_depth = 0
    # NOTE: pa.field is a factory function; the values it returns are
    # pa.Field instances, which is the correct annotation here.
    fields: list[pa.Field] = []

    for end_idx, c in enumerate(sql):
        if c == "<":
            nested_depth += 1
        elif c == ">":
            nested_depth -= 1
        elif c == "," and nested_depth == 0:
            # Top-level comma: everything since the last split is one field.
            fields.append(parse_sql_field(sql[start_idx:end_idx]))
            start_idx = end_idx + 1

    # The last field has no trailing comma, so append it explicitly.
    fields.append(parse_sql_field(sql[start_idx:]))

    return tuple(sorted(fields, key=lambda f: f.name))
80+
81+
82+
def parse_sql_field(sql: str) -> pa.Field:
    """
    Parses a single "name TYPE" SQL field declaration into a PyArrow field.

    The name and type may be separated by any run of whitespace, not just a
    single space character.

    Raises:
        ValueError: if the declaration has no name/type separator.
    """
    sql = sql.strip()

    # Split on the first whitespace run: the name comes first, the remainder
    # (which may itself contain spaces, e.g. "STRUCT<a INT64>") is the type.
    parts = sql.split(maxsplit=1)

    if len(parts) != 2:
        raise ValueError(f"Invalid struct field: {sql}")

    name, type_sql = parts
    return pa.field(name, parse_sql_type(type_sql))

tests/system/small/bigquery/test_ai.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,41 @@ def test_ai_generate(session):
8787
)
8888

8989

90+
def test_ai_generate_with_output_schema(session):
91+
country = bpd.Series(["Japan", "Canada"], session=session)
92+
prompt = ("Describe ", country)
93+
94+
result = bbq.ai.generate(
95+
prompt,
96+
endpoint="gemini-2.5-flash",
97+
output_schema={"population": "INT64", "is_in_north_america": "bool"},
98+
)
99+
100+
assert _contains_no_nulls(result)
101+
assert result.dtype == pd.ArrowDtype(
102+
pa.struct(
103+
(
104+
pa.field("is_in_north_america", pa.bool_()),
105+
pa.field("population", pa.int64()),
106+
pa.field("full_response", dtypes.JSON_ARROW_TYPE),
107+
pa.field("status", pa.string()),
108+
)
109+
)
110+
)
111+
112+
113+
def test_ai_generate_with_invalid_output_schema_raise_error(session):
114+
country = bpd.Series(["Japan", "Canada"], session=session)
115+
prompt = ("Describe ", country)
116+
117+
with pytest.raises(ValueError):
118+
bbq.ai.generate(
119+
prompt,
120+
endpoint="gemini-2.5-flash",
121+
output_schema={"population": "INT64", "is_in_north_america": "JSON"},
122+
)
123+
124+
90125
def test_ai_generate_bool(session):
91126
s1 = bpd.Series(["apple", "bear"], session=session)
92127
s2 = bpd.Series(["fruit", "tree"], session=session)
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`string_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
AI.GENERATE(
9+
prompt => (`bfcol_0`, ' is the same as ', `bfcol_0`),
10+
connection_id => 'bigframes-dev.us.bigframes-default-connection',
11+
endpoint => 'gemini-2.5-flash',
12+
request_type => 'SHARED',
13+
output_schema => 'x INT64, y FLOAT64'
14+
) AS `bfcol_1`
15+
FROM `bfcte_0`
16+
)
17+
SELECT
18+
`bfcol_1` AS `result`
19+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,26 @@ def test_ai_generate(scalar_types_df: dataframe.DataFrame, snapshot):
3636
endpoint="gemini-2.5-flash",
3737
request_type="shared",
3838
model_params=None,
39+
output_schema=None,
40+
)
41+
42+
sql = utils._apply_unary_ops(
43+
scalar_types_df, [op.as_expr(col_name, col_name)], ["result"]
44+
)
45+
46+
snapshot.assert_match(sql, "out.sql")
47+
48+
49+
def test_ai_generate_with_output_schema(scalar_types_df: dataframe.DataFrame, snapshot):
50+
col_name = "string_col"
51+
52+
op = ops.AIGenerate(
53+
prompt_context=(None, " is the same as ", None),
54+
connection_id=CONNECTION_ID,
55+
endpoint="gemini-2.5-flash",
56+
request_type="shared",
57+
model_params=None,
58+
output_schema="x INT64, y FLOAT64",
3959
)
4060

4161
sql = utils._apply_unary_ops(
@@ -59,6 +79,7 @@ def test_ai_generate_with_model_param(scalar_types_df: dataframe.DataFrame, snap
5979
endpoint=None,
6080
request_type="shared",
6181
model_params=json.dumps(dict()),
82+
output_schema=None,
6283
)
6384

6485
sql = utils._apply_unary_ops(

0 commit comments

Comments
 (0)