
Commit 6b8154c

feat: add ai.generate_double to bigframes.bigquery package (#2111)
* feat: add ai.generate_double to bigframes.bigquery package
* fix lint
* fix doctest
1 parent 1fc563c commit 6b8154c

11 files changed: 267 additions, 2 deletions

bigframes/bigquery/_operations/ai.py

Lines changed: 75 additions & 0 deletions
@@ -188,6 +188,81 @@ def generate_int(
     return series_list[0]._apply_nary_op(operator, series_list[1:])
 
 
+@log_adapter.method_logger(custom_base_name="bigquery_ai")
+def generate_double(
+    prompt: PROMPT_TYPE,
+    *,
+    connection_id: str | None = None,
+    endpoint: str | None = None,
+    request_type: Literal["dedicated", "shared", "unspecified"] = "unspecified",
+    model_params: Mapping[Any, Any] | None = None,
+) -> series.Series:
+    """
+    Returns the AI analysis based on the prompt, which can be any combination of text and unstructured data.
+
+    **Examples:**
+
+        >>> import bigframes.pandas as bpd
+        >>> import bigframes.bigquery as bbq
+        >>> bpd.options.display.progress_bar = None
+        >>> animal = bpd.Series(["Kangaroo", "Rabbit", "Spider"])
+        >>> bbq.ai.generate_double(("How many legs does a ", animal, " have?"))
+        0    {'result': 2.0, 'full_response': '{"candidates...
+        1    {'result': 4.0, 'full_response': '{"candidates...
+        2    {'result': 8.0, 'full_response': '{"candidates...
+        dtype: struct<result: double, full_response: extension<dbjson<JSONArrowType>>, status: string>[pyarrow]
+
+        >>> bbq.ai.generate_double(("How many legs does a ", animal, " have?")).struct.field("result")
+        0    2.0
+        1    4.0
+        2    8.0
+        Name: result, dtype: Float64
+
+    Args:
+        prompt (Series | List[str|Series] | Tuple[str|Series, ...]):
+            A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be
+            BigFrames Series or pandas Series.
+        connection_id (str, optional):
+            Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`.
+            If not provided, the connection from the current session will be used.
+        endpoint (str, optional):
+            Specifies the Vertex AI endpoint to use for the model. For example, `"gemini-2.5-flash"`. You can specify any
+            generally available or preview Gemini model. If you specify the model name, BigQuery ML automatically identifies and
+            uses the full endpoint of the model. If you don't specify an ENDPOINT value, BigQuery ML selects a recent stable
+            version of Gemini to use.
+        request_type (Literal["dedicated", "shared", "unspecified"]):
+            Specifies the type of inference request to send to the Gemini model. The request type determines what quota the request uses.
+            * "dedicated": the function only uses Provisioned Throughput quota. The function returns the error
+              `Provisioned throughput is not purchased or is not active` if Provisioned Throughput quota isn't available.
+            * "shared": the function only uses dynamic shared quota (DSQ), even if you have purchased Provisioned Throughput quota.
+            * "unspecified": if you haven't purchased Provisioned Throughput quota, the function uses DSQ quota.
+              If you have purchased Provisioned Throughput quota, the function uses the Provisioned Throughput quota first.
+              If requests exceed the Provisioned Throughput quota, the overflow traffic uses DSQ quota.
+        model_params (Mapping[Any, Any]):
+            Provides additional parameters to the model. The MODEL_PARAMS value must conform to the generateContent request body format.
+
+    Returns:
+        bigframes.series.Series: A new struct Series with the result data. The struct contains these fields:
+            * "result": a DOUBLE value containing the model's response to the prompt. The result is None if the request fails or is filtered by responsible AI.
+            * "full_response": a JSON value containing the response from the projects.locations.endpoints.generateContent call to the model.
+              The generated text is in the text element.
+            * "status": a STRING value that contains the API response status for the corresponding row. This value is empty if the operation was successful.
+    """
+
+    prompt_context, series_list = _separate_context_and_series(prompt)
+    assert len(series_list) > 0
+
+    operator = ai_ops.AIGenerateDouble(
+        prompt_context=tuple(prompt_context),
+        connection_id=_resolve_connection_id(series_list[0], connection_id),
+        endpoint=endpoint,
+        request_type=request_type,
+        model_params=json.dumps(model_params) if model_params else None,
+    )
+
+    return series_list[0]._apply_nary_op(operator, series_list[1:])
+
+
 def _separate_context_and_series(
     prompt: PROMPT_TYPE,
 ) -> Tuple[List[str | None], List[series.Series]]:
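The docstring above says `model_params` must conform to the generateContent request body format and that the returned struct exposes `result`, `full_response`, and `status` fields. A minimal usage sketch of the new function, assuming a session with a configured BigQuery connection; the prompt text and the `generationConfig`/`temperature` keys are illustrative and not part of this diff:

import bigframes.pandas as bpd
import bigframes.bigquery as bbq

# Illustrative prompt: ask for one numeric answer per row.
city = bpd.Series(["Tokyo", "Reykjavik", "Nairobi"])

response = bbq.ai.generate_double(
    ("What is the average annual temperature in ", city, " in degrees Celsius?"),
    endpoint="gemini-2.5-flash",
    request_type="shared",
    # Assumed generateContent request-body keys; the function serializes this mapping with json.dumps.
    model_params={"generationConfig": {"temperature": 0}},
)

# "result" holds the parsed DOUBLE; "status" is empty on success.
print(response.struct.field("result"))
print(response.struct.field("status"))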

bigframes/core/compile/ibis_compiler/scalar_op_registry.py

Lines changed: 15 additions & 1 deletion
@@ -1986,7 +1986,7 @@ def ai_generate_bool(
 
 @scalar_op_compiler.register_nary_op(ops.AIGenerateInt, pass_op=True)
 def ai_generate_int(
-    *values: ibis_types.Value, op: ops.AIGenerateBool
+    *values: ibis_types.Value, op: ops.AIGenerateInt
 ) -> ibis_types.StructValue:
 
     return ai_ops.AIGenerateInt(
@@ -1998,6 +1998,20 @@ def ai_generate_int(
     ).to_expr()
 
 
+@scalar_op_compiler.register_nary_op(ops.AIGenerateDouble, pass_op=True)
+def ai_generate_double(
+    *values: ibis_types.Value, op: ops.AIGenerateDouble
+) -> ibis_types.StructValue:
+
+    return ai_ops.AIGenerateDouble(
+        _construct_prompt(values, op.prompt_context),  # type: ignore
+        op.connection_id,  # type: ignore
+        op.endpoint,  # type: ignore
+        op.request_type.upper(),  # type: ignore
+        op.model_params,  # type: ignore
+    ).to_expr()
+
+
 def _construct_prompt(
     col_refs: tuple[ibis_types.Value], prompt_context: tuple[str | None]
 ) -> ibis_types.StructValue:

bigframes/core/compile/sqlglot/expressions/ai_ops.py

Lines changed: 7 additions & 0 deletions
@@ -40,6 +40,13 @@ def _(*exprs: TypedExpr, op: ops.AIGenerateInt) -> sge.Expression:
     return sge.func("AI.GENERATE_INT", *args)
 
 
+@register_nary_op(ops.AIGenerateDouble, pass_op=True)
+def _(*exprs: TypedExpr, op: ops.AIGenerateDouble) -> sge.Expression:
+    args = [_construct_prompt(exprs, op.prompt_context)] + _construct_named_args(op)
+
+    return sge.func("AI.GENERATE_DOUBLE", *args)
+
+
 def _construct_prompt(
     exprs: tuple[TypedExpr, ...], prompt_context: tuple[str | None, ...]
 ) -> sge.Kwarg:

bigframes/operations/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -14,7 +14,7 @@
 
 from __future__ import annotations
 
-from bigframes.operations.ai_ops import AIGenerateBool, AIGenerateInt
+from bigframes.operations.ai_ops import AIGenerateBool, AIGenerateDouble, AIGenerateInt
 from bigframes.operations.array_ops import (
     ArrayIndexOp,
     ArrayReduceOp,
@@ -413,6 +413,7 @@
     "GeoStDistanceOp",
     # AI ops
     "AIGenerateBool",
+    "AIGenerateDouble",
     "AIGenerateInt",
     # Numpy ops mapping
     "NUMPY_TO_BINOP",

bigframes/operations/ai_ops.py

Lines changed: 22 additions & 0 deletions
@@ -66,3 +66,25 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
                 )
             )
         )
+
+
+@dataclasses.dataclass(frozen=True)
+class AIGenerateDouble(base_ops.NaryOp):
+    name: ClassVar[str] = "ai_generate_double"
+
+    prompt_context: Tuple[str | None, ...]
+    connection_id: str
+    endpoint: str | None
+    request_type: Literal["dedicated", "shared", "unspecified"]
+    model_params: str | None
+
+    def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
+        return pd.ArrowDtype(
+            pa.struct(
+                (
+                    pa.field("result", pa.float64()),
+                    pa.field("full_response", dtypes.JSON_ARROW_TYPE),
+                    pa.field("status", pa.string()),
+                )
+            )
+        )

tests/system/small/bigquery/test_ai.py

Lines changed: 39 additions & 0 deletions
@@ -146,5 +146,44 @@ def test_ai_generate_int_multi_model(session):
     )
 
 
+def test_ai_generate_double(session):
+    s = bpd.Series(["Cat"], session=session)
+    prompt = ("How many legs does a ", s, " have?")
+
+    result = bbq.ai.generate_double(prompt, endpoint="gemini-2.5-flash")
+
+    assert _contains_no_nulls(result)
+    assert result.dtype == pd.ArrowDtype(
+        pa.struct(
+            (
+                pa.field("result", pa.float64()),
+                pa.field("full_response", dtypes.JSON_ARROW_TYPE),
+                pa.field("status", pa.string()),
+            )
+        )
+    )
+
+
+def test_ai_generate_double_multi_model(session):
+    df = session.from_glob_path(
+        "gs://bigframes-dev-testing/a_multimodel/images/*", name="image"
+    )
+
+    result = bbq.ai.generate_double(
+        ("How many animals are there in the picture ", df["image"])
+    )
+
+    assert _contains_no_nulls(result)
+    assert result.dtype == pd.ArrowDtype(
+        pa.struct(
+            (
+                pa.field("result", pa.float64()),
+                pa.field("full_response", dtypes.JSON_ARROW_TYPE),
+                pa.field("status", pa.string()),
+            )
+        )
+    )
+
+
 def _contains_no_nulls(s: series.Series) -> bool:
     return len(s) == s.count()
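The system tests above only assert that no row is null; in practice the `status` field of the returned struct can be used to keep just the rows the model answered. A short sketch under the same assumptions as the doctest in `ai.py` (variable names are illustrative; the filtering mirrors pandas-style boolean indexing):

import bigframes.pandas as bpd
import bigframes.bigquery as bbq

animal = bpd.Series(["Kangaroo", "Rabbit", "Spider"])
response = bbq.ai.generate_double(("How many legs does a ", animal, " have?"))

# Per the docstring, "status" is empty when the request succeeded.
succeeded = response.struct.field("status") == ""
legs = response.struct.field("result")[succeeded]
print(legs)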
Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+WITH `bfcte_0` AS (
+  SELECT
+    `string_col` AS `bfcol_0`
+  FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
+), `bfcte_1` AS (
+  SELECT
+    *,
+    AI.GENERATE_DOUBLE(
+      prompt => (`bfcol_0`, ' is the same as ', `bfcol_0`),
+      connection_id => 'test_connection_id',
+      endpoint => 'gemini-2.5-flash',
+      request_type => 'SHARED'
+    ) AS `bfcol_1`
+  FROM `bfcte_0`
+)
+SELECT
+  `bfcol_1` AS `result`
+FROM `bfcte_1`
Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+WITH `bfcte_0` AS (
+  SELECT
+    `string_col` AS `bfcol_0`
+  FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
+), `bfcte_1` AS (
+  SELECT
+    *,
+    AI.GENERATE_DOUBLE(
+      prompt => (`bfcol_0`, ' is the same as ', `bfcol_0`),
+      connection_id => 'test_connection_id',
+      request_type => 'SHARED',
+      model_params => JSON '{}'
+    ) AS `bfcol_1`
+  FROM `bfcte_0`
+)
+SELECT
+  `bfcol_1` AS `result`
+FROM `bfcte_1`

tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py

Lines changed: 45 additions & 0 deletions
@@ -111,3 +111,48 @@ def test_ai_generate_int_with_model_param(
     )
 
     snapshot.assert_match(sql, "out.sql")
+
+
+def test_ai_generate_double(scalar_types_df: dataframe.DataFrame, snapshot):
+    col_name = "string_col"
+
+    op = ops.AIGenerateDouble(
+        # The prompt does not make semantic sense but we only care about syntax correctness.
+        prompt_context=(None, " is the same as ", None),
+        connection_id="test_connection_id",
+        endpoint="gemini-2.5-flash",
+        request_type="shared",
+        model_params=None,
+    )
+
+    sql = utils._apply_unary_ops(
+        scalar_types_df, [op.as_expr(col_name, col_name)], ["result"]
+    )
+
+    snapshot.assert_match(sql, "out.sql")
+
+
+def test_ai_generate_double_with_model_param(
+    scalar_types_df: dataframe.DataFrame, snapshot
+):
+    if version.Version(sqlglot.__version__) < version.Version("25.18.0"):
+        pytest.skip(
+            "Skip test because SQLGLot cannot compile model params to JSON at this version."
+        )
+
+    col_name = "string_col"
+
+    op = ops.AIGenerateDouble(
+        # The prompt does not make semantic sense but we only care about syntax correctness.
+        prompt_context=(None, " is the same as ", None),
+        connection_id="test_connection_id",
+        endpoint=None,
+        request_type="shared",
+        model_params=json.dumps(dict()),
+    )
+
+    sql = utils._apply_unary_ops(
+        scalar_types_df, [op.as_expr(col_name, col_name)], ["result"]
+    )
+
+    snapshot.assert_match(sql, "out.sql")

third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -1110,6 +1110,9 @@ def visit_AIGenerateBool(self, op, **kwargs):
     def visit_AIGenerateInt(self, op, **kwargs):
         return sge.func("AI.GENERATE_INT", *self._compile_ai_args(**kwargs))
 
+    def visit_AIGenerateDouble(self, op, **kwargs):
+        return sge.func("AI.GENERATE_DOUBLE", *self._compile_ai_args(**kwargs))
+
     def _compile_ai_args(self, **kwargs):
         args = []
 