Skip to content

Commit 32502f4

Browse files
authored
feat: add ai.if_() and ai.score() to bigframes.bigquery package (#2132)
* feat: add ai.if_() and ai.score() to bigframes.bigquery package * deflake ai score doc test
1 parent c390da1 commit 32502f4

File tree

11 files changed

+299
-4
lines changed

11 files changed

+299
-4
lines changed

bigframes/bigquery/_operations/ai.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,100 @@ def generate_double(
337337
return series_list[0]._apply_nary_op(operator, series_list[1:])
338338

339339

340+
@log_adapter.method_logger(custom_base_name="bigquery_ai")
def if_(
    prompt: PROMPT_TYPE,
    *,
    connection_id: str | None = None,
) -> series.Series:
    """
    Evaluates the prompt to True or False for each row.

    Compared to `ai.generate_bool()`, this function provides optimization
    such that not all rows are evaluated with the LLM.

    **Examples:**

        >>> import bigframes.pandas as bpd
        >>> import bigframes.bigquery as bbq
        >>> bpd.options.display.progress_bar = None
        >>> us_state = bpd.Series(["Massachusetts", "Illinois", "Hawaii"])
        >>> bbq.ai.if_((us_state, " has a city called Springfield"))
        0     True
        1     True
        2    False
        dtype: boolean

        >>> us_state[bbq.ai.if_((us_state, " has a city called Springfield"))]
        0    Massachusetts
        1         Illinois
        dtype: string

    Args:
        prompt (Series | List[str|Series] | Tuple[str|Series, ...]):
            A mixture of Series and string literals that specifies the prompt
            to send to the model. The Series can be BigFrames Series or
            pandas Series.
        connection_id (str, optional):
            Specifies the connection to use to communicate with the model.
            For example, `myproject.us.myconnection`. If not provided, the
            connection from the current session will be used.

    Returns:
        bigframes.series.Series: A new series of bools.
    """
    context_parts, input_series = _separate_context_and_series(prompt)
    assert len(input_series) > 0

    # The first series drives the operation and is also what the connection
    # is resolved against; the remaining series are passed as extra operands.
    anchor = input_series[0]
    remaining = input_series[1:]

    if_op = ai_ops.AIIf(
        prompt_context=tuple(context_parts),
        connection_id=_resolve_connection_id(anchor, connection_id),
    )
    return anchor._apply_nary_op(if_op, remaining)
387+
388+
389+
@log_adapter.method_logger(custom_base_name="bigquery_ai")
def score(
    prompt: PROMPT_TYPE,
    *,
    connection_id: str | None = None,
) -> series.Series:
    """
    Computes a score from a rubric described in natural language.

    The result is a double value and there is no fixed range for the score
    returned. To get high quality results, provide a scoring rubric with
    examples in the prompt.

    **Examples:**

        >>> import bigframes.pandas as bpd
        >>> import bigframes.bigquery as bbq
        >>> bpd.options.display.progress_bar = None
        >>> animal = bpd.Series(["Tiger", "Rabbit", "Blue Whale"])
        >>> bbq.ai.score(("Rank the relative weights of ", animal, " on the scale from 1 to 3")) # doctest: +SKIP
        0    2.0
        1    1.0
        2    3.0
        dtype: Float64

    Args:
        prompt (Series | List[str|Series] | Tuple[str|Series, ...]):
            A mixture of Series and string literals that specifies the prompt
            to send to the model. The Series can be BigFrames Series or
            pandas Series.
        connection_id (str, optional):
            Specifies the connection to use to communicate with the model.
            For example, `myproject.us.myconnection`. If not provided, the
            connection from the current session will be used.

    Returns:
        bigframes.series.Series: A new series of double (float) values.
    """
    context_parts, input_series = _separate_context_and_series(prompt)
    assert len(input_series) > 0

    # The first series drives the operation and is also what the connection
    # is resolved against; the remaining series are passed as extra operands.
    anchor = input_series[0]
    remaining = input_series[1:]

    score_op = ai_ops.AIScore(
        prompt_context=tuple(context_parts),
        connection_id=_resolve_connection_id(anchor, connection_id),
    )
    return anchor._apply_nary_op(score_op, remaining)
432+
433+
340434
def _separate_context_and_series(
341435
prompt: PROMPT_TYPE,
342436
) -> Tuple[List[str | None], List[series.Series]]:

bigframes/core/compile/ibis_compiler/scalar_op_registry.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2030,6 +2030,24 @@ def ai_generate_double(
20302030
).to_expr()
20312031

20322032

2033+
@scalar_op_compiler.register_nary_op(ops.AIIf, pass_op=True)
def ai_if(*values: ibis_types.Value, op: ops.AIIf) -> ibis_types.BooleanValue:
    """Compile a BigFrames ``AIIf`` op into the ibis ``AIIf`` expression.

    Args:
        *values: Compiled operand columns referenced by the prompt.
        op: The op being compiled; supplies ``prompt_context`` and
            ``connection_id``.

    Returns:
        The ibis expression for the call. It is annotated as a boolean value
        (not a struct, unlike the ``AI.GENERATE_*`` compilers) because the
        op's declared output type is BOOL.
    """
    return ai_ops.AIIf(
        _construct_prompt(values, op.prompt_context),  # type: ignore
        op.connection_id,  # type: ignore
    ).to_expr()
2040+
2041+
2042+
@scalar_op_compiler.register_nary_op(ops.AIScore, pass_op=True)
def ai_score(*values: ibis_types.Value, op: ops.AIScore) -> ibis_types.FloatingValue:
    """Compile a BigFrames ``AIScore`` op into the ibis ``AIScore`` expression.

    Args:
        *values: Compiled operand columns referenced by the prompt.
        op: The op being compiled; supplies ``prompt_context`` and
            ``connection_id``.

    Returns:
        The ibis expression for the call. It is annotated as a floating-point
        value (not a struct, unlike the ``AI.GENERATE_*`` compilers) because
        the op's declared output type is FLOAT.
    """
    return ai_ops.AIScore(
        _construct_prompt(values, op.prompt_context),  # type: ignore
        op.connection_id,  # type: ignore
    ).to_expr()
2049+
2050+
20332051
def _construct_prompt(
20342052
col_refs: tuple[ibis_types.Value], prompt_context: tuple[str | None]
20352053
) -> ibis_types.StructValue:

bigframes/core/compile/sqlglot/expressions/ai_ops.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,20 @@ def _(*exprs: TypedExpr, op: ops.AIGenerateDouble) -> sge.Expression:
5454
return sge.func("AI.GENERATE_DOUBLE", *args)
5555

5656

57+
@register_nary_op(ops.AIIf, pass_op=True)
def _(*exprs: TypedExpr, op: ops.AIIf) -> sge.Expression:
    """Emit an ``AI.IF(...)`` call: the prompt argument, then the op's named arguments."""
    prompt_arg = _construct_prompt(exprs, op.prompt_context)
    named_args = _construct_named_args(op)
    return sge.func("AI.IF", prompt_arg, *named_args)
62+
63+
64+
@register_nary_op(ops.AIScore, pass_op=True)
def _(*exprs: TypedExpr, op: ops.AIScore) -> sge.Expression:
    """Emit an ``AI.SCORE(...)`` call: the prompt argument, then the op's named arguments."""
    prompt_arg = _construct_prompt(exprs, op.prompt_context)
    named_args = _construct_named_args(op)
    return sge.func("AI.SCORE", prompt_arg, *named_args)
69+
70+
5771
def _construct_prompt(
5872
exprs: tuple[TypedExpr, ...], prompt_context: tuple[str | None, ...]
5973
) -> sge.Kwarg:
@@ -83,10 +97,13 @@ def _construct_named_args(op: ops.NaryOp) -> list[sge.Kwarg]:
8397
if endpoit is not None:
8498
args.append(sge.Kwarg(this="endpoint", expression=sge.Literal.string(endpoit)))
8599

86-
request_type = typing.cast(str, op_args["request_type"]).upper()
87-
args.append(
88-
sge.Kwarg(this="request_type", expression=sge.Literal.string(request_type))
89-
)
100+
request_type = typing.cast(str, op_args.get("request_type", None))
101+
if request_type is not None:
102+
args.append(
103+
sge.Kwarg(
104+
this="request_type", expression=sge.Literal.string(request_type.upper())
105+
)
106+
)
90107

91108
model_params = typing.cast(str, op_args.get("model_params", None))
92109
if model_params is not None:

bigframes/operations/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
AIGenerateBool,
2020
AIGenerateDouble,
2121
AIGenerateInt,
22+
AIIf,
23+
AIScore,
2224
)
2325
from bigframes.operations.array_ops import (
2426
ArrayIndexOp,
@@ -421,6 +423,8 @@
421423
"AIGenerateBool",
422424
"AIGenerateDouble",
423425
"AIGenerateInt",
426+
"AIIf",
427+
"AIScore",
424428
# Numpy ops mapping
425429
"NUMPY_TO_BINOP",
426430
"NUMPY_TO_OP",

bigframes/operations/ai_ops.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,3 +110,25 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
110110
)
111111
)
112112
)
113+
114+
115+
@dataclasses.dataclass(frozen=True)
class AIIf(base_ops.NaryOp):
    """N-ary op describing an ``ai_if`` call; its result type is always BOOL."""

    name: ClassVar[str] = "ai_if"

    # Ordered prompt pieces: string literals, with None marking a position
    # that is filled by one of the op's input columns.
    prompt_context: Tuple[str | None, ...]
    # Connection used to reach the model.
    connection_id: str

    def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
        """Return the boolean dtype regardless of the operand types."""
        result_dtype = dtypes.BOOL_DTYPE
        return result_dtype
124+
125+
126+
@dataclasses.dataclass(frozen=True)
class AIScore(base_ops.NaryOp):
    """N-ary op describing an ``ai_score`` call; its result type is always FLOAT."""

    name: ClassVar[str] = "ai_score"

    # Ordered prompt pieces: string literals, with None marking a position
    # that is filled by one of the op's input columns.
    prompt_context: Tuple[str | None, ...]
    # Connection used to reach the model.
    connection_id: str

    def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
        """Return the float dtype regardless of the operand types."""
        result_dtype = dtypes.FLOAT_DTYPE
        return result_dtype

tests/system/small/bigquery/test_ai.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,5 +203,49 @@ def test_ai_generate_double_multi_model(session):
203203
)
204204

205205

206+
def test_ai_if(session):
    """`ai.if_` over two text series returns a null-free boolean series."""
    subjects = bpd.Series(["apple", "bear"], session=session)
    categories = bpd.Series(["fruit", "tree"], session=session)

    result = bbq.ai.if_((subjects, " is a ", categories))

    assert _contains_no_nulls(result)
    assert result.dtype == dtypes.BOOL_DTYPE
215+
216+
217+
def test_ai_if_multi_model(session):
    """`ai.if_` accepts an image column loaded from a GCS glob as a prompt operand."""
    images_df = session.from_glob_path(
        "gs://bigframes-dev-testing/a_multimodel/images/*", name="image"
    )
    prompt = (images_df["image"], " contains an animal")

    result = bbq.ai.if_(prompt)

    assert _contains_no_nulls(result)
    assert result.dtype == dtypes.BOOL_DTYPE
226+
227+
228+
def test_ai_score(session):
    """`ai.score` over a text series returns a null-free float series."""
    animals = bpd.Series(["Tiger", "Rabbit"], session=session)

    result = bbq.ai.score(
        ("Rank the relative weights of ", animals, " on the scale from 1 to 3")
    )

    assert _contains_no_nulls(result)
    assert result.dtype == dtypes.FLOAT_DTYPE
236+
237+
238+
def test_ai_score_multi_model(session):
    """`ai.score` accepts an image column loaded from a GCS glob as a prompt operand."""
    df = session.from_glob_path(
        "gs://bigframes-dev-testing/a_multimodel/images/*", name="image"
    )
    # The trailing fragment starts with a space so it is not glued to the
    # preceding operand, matching the prompt used in `test_ai_score`.
    prompt = ("Rank the liveliness of ", df["image"], " on the scale from 1 to 3")

    result = bbq.ai.score(prompt)

    assert _contains_no_nulls(result)
    assert result.dtype == dtypes.FLOAT_DTYPE
248+
249+
206250
def _contains_no_nulls(s: series.Series) -> bool:
207251
return len(s) == s.count()
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`string_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
AI.IF(
9+
prompt => (`bfcol_0`, ' is the same as ', `bfcol_0`),
10+
connection_id => 'bigframes-dev.us.bigframes-default-connection'
11+
) AS `bfcol_1`
12+
FROM `bfcte_0`
13+
)
14+
SELECT
15+
`bfcol_1` AS `result`
16+
FROM `bfcte_1`
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`string_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
AI.SCORE(
9+
prompt => (`bfcol_0`, ' is the same as ', `bfcol_0`),
10+
connection_id => 'bigframes-dev.us.bigframes-default-connection'
11+
) AS `bfcol_1`
12+
FROM `bfcte_0`
13+
)
14+
SELECT
15+
`bfcol_1` AS `result`
16+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,3 +199,33 @@ def test_ai_generate_double_with_model_param(
199199
)
200200

201201
snapshot.assert_match(sql, "out.sql")
202+
203+
204+
def test_ai_if(scalar_types_df: dataframe.DataFrame, snapshot):
    """Snapshot the SQL produced for an `AIIf` op that uses one column twice."""
    column = "string_col"
    if_op = ops.AIIf(
        prompt_context=(None, " is the same as ", None),
        connection_id=CONNECTION_ID,
    )
    # Both None slots of the prompt context are filled by the same column.
    expression = if_op.as_expr(column, column)

    sql = utils._apply_unary_ops(scalar_types_df, [expression], ["result"])

    snapshot.assert_match(sql, "out.sql")
217+
218+
219+
def test_ai_score(scalar_types_df: dataframe.DataFrame, snapshot):
    """Snapshot the SQL produced for an `AIScore` op that uses one column twice."""
    column = "string_col"
    score_op = ops.AIScore(
        prompt_context=(None, " is the same as ", None),
        connection_id=CONNECTION_ID,
    )
    # Both None slots of the prompt context are filled by the same column.
    expression = score_op.as_expr(column, column)

    sql = utils._apply_unary_ops(scalar_types_df, [expression], ["result"])

    snapshot.assert_match(sql, "out.sql")

third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1116,6 +1116,12 @@ def visit_AIGenerateInt(self, op, **kwargs):
11161116
def visit_AIGenerateDouble(self, op, **kwargs):
    # Emit AI.GENERATE_DOUBLE(...) with the arguments built by the shared helper.
    ai_args = self._compile_ai_args(**kwargs)
    return sge.func("AI.GENERATE_DOUBLE", *ai_args)
11181118

1119+
def visit_AIIf(self, op, **kwargs):
    # Emit AI.IF(...) with the arguments built by the shared helper.
    ai_args = self._compile_ai_args(**kwargs)
    return sge.func("AI.IF", *ai_args)
1121+
1122+
def visit_AIScore(self, op, **kwargs):
    # Emit AI.SCORE(...) with the arguments built by the shared helper.
    ai_args = self._compile_ai_args(**kwargs)
    return sge.func("AI.SCORE", *ai_args)
1124+
11191125
def _compile_ai_args(self, **kwargs):
11201126
args = []
11211127

0 commit comments

Comments
 (0)