Skip to content

Commit 56e5033

Browse files
authored
feat: add ai.classify() to bigframes.bigquery package (#2137)
1 parent eca22ee commit 56e5033

File tree

10 files changed

+189
-25
lines changed

10 files changed

+189
-25
lines changed

bigframes/bigquery/_operations/ai.py

Lines changed: 74 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -348,20 +348,20 @@ def if_(
348348
provides optimization such that not all rows are evaluated with the LLM.
349349
350350
**Examples:**
351-
>>> import bigframes.pandas as bpd
352-
>>> import bigframes.bigquery as bbq
353-
>>> bpd.options.display.progress_bar = None
354-
>>> us_state = bpd.Series(["Massachusetts", "Illinois", "Hawaii"])
355-
>>> bbq.ai.if_((us_state, " has a city called Springfield"))
356-
0 True
357-
1 True
358-
2 False
359-
dtype: boolean
360-
361-
>>> us_state[bbq.ai.if_((us_state, " has a city called Springfield"))]
362-
0 Massachusetts
363-
1 Illinois
364-
dtype: string
351+
>>> import bigframes.pandas as bpd
352+
>>> import bigframes.bigquery as bbq
353+
>>> bpd.options.display.progress_bar = None
354+
>>> us_state = bpd.Series(["Massachusetts", "Illinois", "Hawaii"])
355+
>>> bbq.ai.if_((us_state, " has a city called Springfield"))
356+
0 True
357+
1 True
358+
2 False
359+
dtype: boolean
360+
361+
>>> us_state[bbq.ai.if_((us_state, " has a city called Springfield"))]
362+
0 Massachusetts
363+
1 Illinois
364+
dtype: string
365365
366366
Args:
367367
prompt (Series | List[str|Series] | Tuple[str|Series, ...]):
@@ -386,6 +386,56 @@ def if_(
386386
return series_list[0]._apply_nary_op(operator, series_list[1:])
387387

388388

389+
@log_adapter.method_logger(custom_base_name="bigquery_ai")
390+
def classify(
391+
input: PROMPT_TYPE,
392+
categories: tuple[str, ...] | list[str],
393+
*,
394+
connection_id: str | None = None,
395+
) -> series.Series:
396+
"""
397+
Classifies a given input into one of the specified categories. It will always return one of the provided categories best fit the prompt input.
398+
399+
**Examples:**
400+
401+
>>> import bigframes.pandas as bpd
402+
>>> import bigframes.bigquery as bbq
403+
>>> bpd.options.display.progress_bar = None
404+
>>> df = bpd.DataFrame({'creature': ['Cat', 'Salmon']})
405+
>>> df['type'] = bbq.ai.classify(df['creature'], ['Mammal', 'Fish'])
406+
>>> df
407+
creature type
408+
0 Cat Mammal
409+
1 Salmon Fish
410+
<BLANKLINE>
411+
[2 rows x 2 columns]
412+
413+
Args:
414+
input (Series | List[str|Series] | Tuple[str|Series, ...]):
415+
A mixture of Series and string literals that specifies the input to send to the model. The Series can be BigFrames Series
416+
or pandas Series.
417+
categories (tuple[str, ...] | list[str]):
418+
Categories to classify the input into.
419+
connection_id (str, optional):
420+
Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`.
421+
If not provided, the connection from the current session will be used.
422+
423+
Returns:
424+
bigframes.series.Series: A new series of strings.
425+
"""
426+
427+
prompt_context, series_list = _separate_context_and_series(input)
428+
assert len(series_list) > 0
429+
430+
operator = ai_ops.AIClassify(
431+
prompt_context=tuple(prompt_context),
432+
categories=tuple(categories),
433+
connection_id=_resolve_connection_id(series_list[0], connection_id),
434+
)
435+
436+
return series_list[0]._apply_nary_op(operator, series_list[1:])
437+
438+
389439
@log_adapter.method_logger(custom_base_name="bigquery_ai")
390440
def score(
391441
prompt: PROMPT_TYPE,
@@ -398,15 +448,16 @@ def score(
398448
rubric with examples in the prompt.
399449
400450
**Examples:**
401-
>>> import bigframes.pandas as bpd
402-
>>> import bigframes.bigquery as bbq
403-
>>> bpd.options.display.progress_bar = None
404-
>>> animal = bpd.Series(["Tiger", "Rabbit", "Blue Whale"])
405-
>>> bbq.ai.score(("Rank the relative weights of ", animal, " on the scale from 1 to 3")) # doctest: +SKIP
406-
0 2.0
407-
1 1.0
408-
2 3.0
409-
dtype: Float64
451+
452+
>>> import bigframes.pandas as bpd
453+
>>> import bigframes.bigquery as bbq
454+
>>> bpd.options.display.progress_bar = None
455+
>>> animal = bpd.Series(["Tiger", "Rabbit", "Blue Whale"])
456+
>>> bbq.ai.score(("Rank the relative weights of ", animal, " on the scale from 1 to 3")) # doctest: +SKIP
457+
0 2.0
458+
1 1.0
459+
2 3.0
460+
dtype: Float64
410461
411462
Args:
412463
prompt (Series | List[str|Series] | Tuple[str|Series, ...]):

bigframes/core/compile/ibis_compiler/scalar_op_registry.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2039,6 +2039,18 @@ def ai_if(*values: ibis_types.Value, op: ops.AIIf) -> ibis_types.StructValue:
20392039
).to_expr()
20402040

20412041

2042+
@scalar_op_compiler.register_nary_op(ops.AIClassify, pass_op=True)
2043+
def ai_classify(
2044+
*values: ibis_types.Value, op: ops.AIClassify
2045+
) -> ibis_types.StructValue:
2046+
2047+
return ai_ops.AIClassify(
2048+
_construct_prompt(values, op.prompt_context), # type: ignore
2049+
op.categories, # type: ignore
2050+
op.connection_id, # type: ignore
2051+
).to_expr()
2052+
2053+
20422054
@scalar_op_compiler.register_nary_op(ops.AIScore, pass_op=True)
20432055
def ai_score(*values: ibis_types.Value, op: ops.AIScore) -> ibis_types.StructValue:
20442056

bigframes/core/compile/sqlglot/expressions/ai_ops.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,21 @@ def _(*exprs: TypedExpr, op: ops.AIIf) -> sge.Expression:
6161
return sge.func("AI.IF", *args)
6262

6363

64+
@register_nary_op(ops.AIClassify, pass_op=True)
65+
def _(*exprs: TypedExpr, op: ops.AIClassify) -> sge.Expression:
66+
category_literals = [sge.Literal.string(cat) for cat in op.categories]
67+
categories_arg = sge.Kwarg(
68+
this="categories", expression=sge.array(*category_literals)
69+
)
70+
71+
args = [
72+
_construct_prompt(exprs, op.prompt_context, param_name="input"),
73+
categories_arg,
74+
] + _construct_named_args(op)
75+
76+
return sge.func("AI.CLASSIFY", *args)
77+
78+
6479
@register_nary_op(ops.AIScore, pass_op=True)
6580
def _(*exprs: TypedExpr, op: ops.AIScore) -> sge.Expression:
6681
args = [_construct_prompt(exprs, op.prompt_context)] + _construct_named_args(op)
@@ -69,7 +84,9 @@ def _(*exprs: TypedExpr, op: ops.AIScore) -> sge.Expression:
6984

7085

7186
def _construct_prompt(
72-
exprs: tuple[TypedExpr, ...], prompt_context: tuple[str | None, ...]
87+
exprs: tuple[TypedExpr, ...],
88+
prompt_context: tuple[str | None, ...],
89+
param_name: str = "prompt",
7390
) -> sge.Kwarg:
7491
prompt: list[str | sge.Expression] = []
7592
column_ref_idx = 0
@@ -80,7 +97,7 @@ def _construct_prompt(
8097
else:
8198
prompt.append(sge.Literal.string(elem))
8299

83-
return sge.Kwarg(this="prompt", expression=sge.Tuple(expressions=prompt))
100+
return sge.Kwarg(this=param_name, expression=sge.Tuple(expressions=prompt))
84101

85102

86103
def _construct_named_args(op: ops.NaryOp) -> list[sge.Kwarg]:

bigframes/operations/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from __future__ import annotations
1616

1717
from bigframes.operations.ai_ops import (
18+
AIClassify,
1819
AIGenerate,
1920
AIGenerateBool,
2021
AIGenerateDouble,
@@ -419,6 +420,7 @@
419420
"geo_y_op",
420421
"GeoStDistanceOp",
421422
# AI ops
423+
"AIClassify",
422424
"AIGenerate",
423425
"AIGenerateBool",
424426
"AIGenerateDouble",

bigframes/operations/ai_ops.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,18 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
123123
return dtypes.BOOL_DTYPE
124124

125125

126+
@dataclasses.dataclass(frozen=True)
127+
class AIClassify(base_ops.NaryOp):
128+
name: ClassVar[str] = "ai_classify"
129+
130+
prompt_context: Tuple[str | None, ...]
131+
categories: tuple[str, ...]
132+
connection_id: str
133+
134+
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
135+
return dtypes.STRING_DTYPE
136+
137+
126138
@dataclasses.dataclass(frozen=True)
127139
class AIScore(base_ops.NaryOp):
128140
name: ClassVar[str] = "ai_score"

tests/system/small/bigquery/test_ai.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,27 @@ def test_ai_if_multi_model(session):
225225
assert result.dtype == dtypes.BOOL_DTYPE
226226

227227

228+
def test_ai_classify(session):
229+
s = bpd.Series(["cat", "orchid"], session=session)
230+
bpd.options.display.repr_mode = "deferred"
231+
232+
result = bbq.ai.classify(s, ["animal", "plant"])
233+
234+
assert _contains_no_nulls(result)
235+
assert result.dtype == dtypes.STRING_DTYPE
236+
237+
238+
def test_ai_classify_multi_model(session):
239+
df = session.from_glob_path(
240+
"gs://bigframes-dev-testing/a_multimodel/images/*", name="image"
241+
)
242+
243+
result = bbq.ai.classify(df["image"], ["photo", "cartoon"])
244+
245+
assert _contains_no_nulls(result)
246+
assert result.dtype == dtypes.STRING_DTYPE
247+
248+
228249
def test_ai_score(session):
229250
s = bpd.Series(["Tiger", "Rabbit"], session=session)
230251
prompt = ("Rank the relative weights of ", s, " on the scale from 1 to 3")
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`string_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
AI.CLASSIFY(
9+
input => (`bfcol_0`),
10+
categories => ['greeting', 'rejection'],
11+
connection_id => 'bigframes-dev.us.bigframes-default-connection'
12+
) AS `bfcol_1`
13+
FROM `bfcte_0`
14+
)
15+
SELECT
16+
`bfcol_1` AS `result`
17+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,20 @@ def test_ai_if(scalar_types_df: dataframe.DataFrame, snapshot):
216216
snapshot.assert_match(sql, "out.sql")
217217

218218

219+
def test_ai_classify(scalar_types_df: dataframe.DataFrame, snapshot):
220+
col_name = "string_col"
221+
222+
op = ops.AIClassify(
223+
prompt_context=(None,),
224+
categories=("greeting", "rejection"),
225+
connection_id=CONNECTION_ID,
226+
)
227+
228+
sql = utils._apply_unary_ops(scalar_types_df, [op.as_expr(col_name)], ["result"])
229+
230+
snapshot.assert_match(sql, "out.sql")
231+
232+
219233
def test_ai_score(scalar_types_df: dataframe.DataFrame, snapshot):
220234
col_name = "string_col"
221235

third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1119,6 +1119,9 @@ def visit_AIGenerateDouble(self, op, **kwargs):
11191119
def visit_AIIf(self, op, **kwargs):
11201120
return sge.func("AI.IF", *self._compile_ai_args(**kwargs))
11211121

1122+
def visit_AIClassify(self, op, **kwargs):
1123+
return sge.func("AI.CLASSIFY", *self._compile_ai_args(**kwargs))
1124+
11221125
def visit_AIScore(self, op, **kwargs):
11231126
return sge.func("AI.SCORE", *self._compile_ai_args(**kwargs))
11241127

third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,21 @@ def dtype(self) -> dt.Struct:
107107
return dt.bool
108108

109109

110+
@public
111+
class AIClassify(Value):
112+
"""Generate True/False based on the prompt"""
113+
114+
input: Value
115+
categories: Value[dt.Array[dt.String]]
116+
connection_id: Value[dt.String]
117+
118+
shape = rlz.shape_like("input")
119+
120+
@attribute
121+
def dtype(self) -> dt.Struct:
122+
return dt.string
123+
124+
110125
@public
111126
class AIScore(Value):
112127
"""Generate doubles based on the prompt"""

0 commit comments

Comments
 (0)