
Commit 3810452

feat: add ai.generate() to bigframes.bigquery module (#2128)
* feat: add ai.generate() to bigframes.bigquery module
* fix doc
* fix doc test
1 parent 8035e01 commit 3810452

17 files changed: +256 / -14 lines


bigframes/bigquery/_operations/ai.py

Lines changed: 74 additions & 0 deletions
@@ -35,6 +35,80 @@
 ]
 
 
+@log_adapter.method_logger(custom_base_name="bigquery_ai")
+def generate(
+    prompt: PROMPT_TYPE,
+    *,
+    connection_id: str | None = None,
+    endpoint: str | None = None,
+    request_type: Literal["dedicated", "shared", "unspecified"] = "unspecified",
+    model_params: Mapping[Any, Any] | None = None,
+    # TODO(b/446974666) Add output_schema parameter
+) -> series.Series:
+    """
+    Returns the AI analysis based on the prompt, which can be any combination of text and unstructured data.
+
+    **Examples:**
+
+        >>> import bigframes.pandas as bpd
+        >>> import bigframes.bigquery as bbq
+        >>> bpd.options.display.progress_bar = None
+        >>> country = bpd.Series(["Japan", "Canada"])
+        >>> bbq.ai.generate(("What's the capital city of ", country, " one word only"))
+        0    {'result': 'Tokyo\\n', 'full_response': '{"cand...
+        1    {'result': 'Ottawa\\n', 'full_response': '{"can...
+        dtype: struct<result: string, full_response: extension<dbjson<JSONArrowType>>, status: string>[pyarrow]
+
+        >>> bbq.ai.generate(("What's the capital city of ", country, " one word only")).struct.field("result")
+        0    Tokyo\\n
+        1    Ottawa\\n
+        Name: result, dtype: string
+
+    Args:
+        prompt (Series | List[str|Series] | Tuple[str|Series, ...]):
+            A mixture of Series and string literals that specifies the prompt to send to the model. The Series
+            can be BigFrames Series or pandas Series.
+        connection_id (str, optional):
+            Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`.
+            If not provided, the connection from the current session will be used.
+        endpoint (str, optional):
+            Specifies the Vertex AI endpoint to use for the model. For example, `"gemini-2.5-flash"`. You can specify
+            any generally available or preview Gemini model. If you specify the model name, BigQuery ML automatically
+            identifies and uses the full endpoint of the model. If you don't specify an ENDPOINT value, BigQuery ML
+            selects a recent stable version of Gemini to use.
+        request_type (Literal["dedicated", "shared", "unspecified"]):
+            Specifies the type of inference request to send to the Gemini model. The request type determines what
+            quota the request uses.
+            * "dedicated": the function only uses Provisioned Throughput quota. The function returns the error
+              "Provisioned throughput is not purchased or is not active" if Provisioned Throughput quota isn't available.
+            * "shared": the function only uses dynamic shared quota (DSQ), even if you have purchased Provisioned
+              Throughput quota.
+            * "unspecified": if you haven't purchased Provisioned Throughput quota, the function uses DSQ quota.
+              If you have purchased Provisioned Throughput quota, the function uses the Provisioned Throughput quota
+              first. If requests exceed the Provisioned Throughput quota, the overflow traffic uses DSQ quota.
+        model_params (Mapping[Any, Any]):
+            Provides additional parameters to the model. The MODEL_PARAMS value must conform to the generateContent
+            request body format.
+
+    Returns:
+        bigframes.series.Series: A new struct Series with the result data. The struct contains these fields:
+            * "result": a STRING value containing the model's response to the prompt. The result is None if the
+              request fails or is filtered by responsible AI.
+            * "full_response": a JSON value containing the response from the projects.locations.endpoints.generateContent
+              call to the model. The generated text is in the text element.
+            * "status": a STRING value that contains the API response status for the corresponding row. This value
+              is empty if the operation was successful.
+    """
+
+    prompt_context, series_list = _separate_context_and_series(prompt)
+    assert len(series_list) > 0
+
+    operator = ai_ops.AIGenerate(
+        prompt_context=tuple(prompt_context),
+        connection_id=_resolve_connection_id(series_list[0], connection_id),
+        endpoint=endpoint,
+        request_type=request_type,
+        model_params=json.dumps(model_params) if model_params else None,
+    )
+
+    return series_list[0]._apply_nary_op(operator, series_list[1:])
+
+
 @log_adapter.method_logger(custom_base_name="bigquery_ai")
 def generate_bool(
     prompt: PROMPT_TYPE,
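
For orientation, a minimal usage sketch of the new API. The model_params payload below is an illustrative assumption following the generateContent request body format, not something this commit prescribes:

import bigframes.pandas as bpd
import bigframes.bigquery as bbq

country = bpd.Series(["Japan", "Canada"])

# endpoint, request_type, and model_params are all optional keyword arguments.
result = bbq.ai.generate(
    ("What's the capital city of ", country, "? one word only"),
    endpoint="gemini-2.5-flash",
    request_type="shared",
    model_params={"generationConfig": {"temperature": 0}},  # illustrative only
)

# The result is a struct Series; extract the text answers.
answers = result.struct.field("result")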

bigframes/core/compile/ibis_compiler/scalar_op_registry.py

Lines changed: 14 additions & 0 deletions
@@ -1974,6 +1974,20 @@ def struct_op_impl(
     return ibis_types.struct(data)
 
 
+@scalar_op_compiler.register_nary_op(ops.AIGenerate, pass_op=True)
+def ai_generate(
+    *values: ibis_types.Value, op: ops.AIGenerate
+) -> ibis_types.StructValue:
+
+    return ai_ops.AIGenerate(
+        _construct_prompt(values, op.prompt_context),  # type: ignore
+        op.connection_id,  # type: ignore
+        op.endpoint,  # type: ignore
+        op.request_type.upper(),  # type: ignore
+        op.model_params,  # type: ignore
+    ).to_expr()
+
+
 @scalar_op_compiler.register_nary_op(ops.AIGenerateBool, pass_op=True)
 def ai_generate_bool(
     *values: ibis_types.Value, op: ops.AIGenerateBool

bigframes/core/compile/sqlglot/expressions/ai_ops.py

Lines changed: 7 additions & 0 deletions
@@ -26,6 +26,13 @@
 register_nary_op = scalar_compiler.scalar_op_compiler.register_nary_op
 
 
+@register_nary_op(ops.AIGenerate, pass_op=True)
+def _(*exprs: TypedExpr, op: ops.AIGenerate) -> sge.Expression:
+    args = [_construct_prompt(exprs, op.prompt_context)] + _construct_named_args(op)
+
+    return sge.func("AI.GENERATE", *args)
+
+
 @register_nary_op(ops.AIGenerateBool, pass_op=True)
 def _(*exprs: TypedExpr, op: ops.AIGenerateBool) -> sge.Expression:
     args = [_construct_prompt(exprs, op.prompt_context)] + _construct_named_args(op)

bigframes/operations/__init__.py

Lines changed: 7 additions & 1 deletion
@@ -14,7 +14,12 @@
 
 from __future__ import annotations
 
-from bigframes.operations.ai_ops import AIGenerateBool, AIGenerateDouble, AIGenerateInt
+from bigframes.operations.ai_ops import (
+    AIGenerate,
+    AIGenerateBool,
+    AIGenerateDouble,
+    AIGenerateInt,
+)
 from bigframes.operations.array_ops import (
     ArrayIndexOp,
     ArrayReduceOp,

@@ -412,6 +417,7 @@
     "geo_y_op",
     "GeoStDistanceOp",
     # AI ops
+    "AIGenerate",
     "AIGenerateBool",
     "AIGenerateDouble",
     "AIGenerateInt",

bigframes/operations/ai_ops.py

Lines changed: 22 additions & 0 deletions
@@ -24,6 +24,28 @@
 from bigframes.operations import base_ops
 
 
+@dataclasses.dataclass(frozen=True)
+class AIGenerate(base_ops.NaryOp):
+    name: ClassVar[str] = "ai_generate"
+
+    prompt_context: Tuple[str | None, ...]
+    connection_id: str
+    endpoint: str | None
+    request_type: Literal["dedicated", "shared", "unspecified"]
+    model_params: str | None
+
+    def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
+        return pd.ArrowDtype(
+            pa.struct(
+                (
+                    pa.field("result", pa.string()),
+                    pa.field("full_response", dtypes.JSON_ARROW_TYPE),
+                    pa.field("status", pa.string()),
+                )
+            )
+        )
+
+
 @dataclasses.dataclass(frozen=True)
 class AIGenerateBool(base_ops.NaryOp):
     name: ClassVar[str] = "ai_generate_bool"
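
For readers unfamiliar with Arrow struct dtypes, a standalone sketch of the schema that AIGenerate.output_type declares, built with plain pandas/pyarrow. JSON_ARROW_TYPE is a bigframes extension type, so a plain string field stands in for it here:

import pandas as pd
import pyarrow as pa

# Mirrors the struct declared by AIGenerate.output_type; in bigframes the
# "full_response" field uses dtypes.JSON_ARROW_TYPE rather than pa.string().
result_dtype = pd.ArrowDtype(
    pa.struct(
        (
            pa.field("result", pa.string()),
            pa.field("full_response", pa.string()),
            pa.field("status", pa.string()),
        )
    )
)
print(result_dtype)  # struct<result: string, full_response: string, status: string>[pyarrow]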

tests/system/small/bigquery/test_ai.py

Lines changed: 18 additions & 0 deletions
@@ -69,6 +69,24 @@ def test_ai_function_compile_model_params(session):
     )
 
 
+def test_ai_generate(session):
+    country = bpd.Series(["Japan", "Canada"], session=session)
+    prompt = ("What's the capital city of ", country, "? one word only")
+
+    result = bbq.ai.generate(prompt, endpoint="gemini-2.5-flash")
+
+    assert _contains_no_nulls(result)
+    assert result.dtype == pd.ArrowDtype(
+        pa.struct(
+            (
+                pa.field("result", pa.string()),
+                pa.field("full_response", dtypes.JSON_ARROW_TYPE),
+                pa.field("status", pa.string()),
+            )
+        )
+    )
+
+
 def test_ai_generate_bool(session):
     s1 = bpd.Series(["apple", "bear"], session=session)
     s2 = bpd.Series(["fruit", "tree"], session=session)
Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+WITH `bfcte_0` AS (
+  SELECT
+    `string_col` AS `bfcol_0`
+  FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
+), `bfcte_1` AS (
+  SELECT
+    *,
+    AI.GENERATE(
+      prompt => (`bfcol_0`, ' is the same as ', `bfcol_0`),
+      connection_id => 'bigframes-dev.us.bigframes-default-connection',
+      endpoint => 'gemini-2.5-flash',
+      request_type => 'SHARED'
+    ) AS `bfcol_1`
+  FROM `bfcte_0`
+)
+SELECT
+  `bfcol_1` AS `result`
+FROM `bfcte_1`
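
The snapshot above is the new sqlglot compiler output for AI.GENERATE (the snapshot file's path was not captured in this excerpt). Roughly, it corresponds to a call of the following shape; loading the table this way is an assumption about the test setup, not part of this commit:

import bigframes.pandas as bpd
import bigframes.bigquery as bbq

# The table name comes from the snapshot's FROM clause.
df = bpd.read_gbq("bigframes-dev.sqlglot_test.scalar_types")

# df["string_col"] maps to `bfcol_0`; endpoint and request_type match the named
# arguments rendered in the AI.GENERATE call above.
result = bbq.ai.generate(
    (df["string_col"], " is the same as ", df["string_col"]),
    endpoint="gemini-2.5-flash",
    request_type="shared",
)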

tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_bool/out.sql

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@ WITH `bfcte_0` AS (
     *,
     AI.GENERATE_BOOL(
       prompt => (`bfcol_0`, ' is the same as ', `bfcol_0`),
-      connection_id => 'test_connection_id',
+      connection_id => 'bigframes-dev.us.bigframes-default-connection',
       endpoint => 'gemini-2.5-flash',
       request_type => 'SHARED'
     ) AS `bfcol_1`

tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_bool_with_model_param/out.sql

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@ WITH `bfcte_0` AS (
     *,
     AI.GENERATE_BOOL(
       prompt => (`bfcol_0`, ' is the same as ', `bfcol_0`),
-      connection_id => 'test_connection_id',
+      connection_id => 'bigframes-dev.us.bigframes-default-connection',
       request_type => 'SHARED',
       model_params => JSON '{}'
     ) AS `bfcol_1`

tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_double/out.sql

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@ WITH `bfcte_0` AS (
     *,
     AI.GENERATE_DOUBLE(
       prompt => (`bfcol_0`, ' is the same as ', `bfcol_0`),
-      connection_id => 'test_connection_id',
+      connection_id => 'bigframes-dev.us.bigframes-default-connection',
       endpoint => 'gemini-2.5-flash',
       request_type => 'SHARED'
     ) AS `bfcol_1`
