94 changes: 94 additions & 0 deletions bigframes/bigquery/_operations/ai.py
@@ -337,6 +337,100 @@ def generate_double(
return series_list[0]._apply_nary_op(operator, series_list[1:])


@log_adapter.method_logger(custom_base_name="bigquery_ai")
def if_(
prompt: PROMPT_TYPE,
*,
connection_id: str | None = None,
) -> series.Series:
"""
Evaluates the prompt to True or False. Compared to `ai.generate_bool()`, this function
is optimized so that not every row needs to be evaluated by the LLM.

**Examples:**
>>> import bigframes.pandas as bpd
>>> import bigframes.bigquery as bbq
>>> bpd.options.display.progress_bar = None
>>> us_state = bpd.Series(["Massachusetts", "Illinois", "Hawaii"])
>>> bbq.ai.if_((us_state, " has a city called Springfield"))
0 True
1 True
2 False
dtype: boolean

>>> us_state[bbq.ai.if_((us_state, " has a city called Springfield"))]
0 Massachusetts
1 Illinois
dtype: string

Args:
prompt (Series | List[str|Series] | Tuple[str|Series, ...]):
A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series
or pandas Series.
connection_id (str, optional):
Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`.
If not provided, the connection from the current session will be used.

Returns:
bigframes.series.Series: A new series of bools.
"""

prompt_context, series_list = _separate_context_and_series(prompt)
assert len(series_list) > 0

operator = ai_ops.AIIf(
prompt_context=tuple(prompt_context),
connection_id=_resolve_connection_id(series_list[0], connection_id),
)

return series_list[0]._apply_nary_op(operator, series_list[1:])
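A minimal usage sketch showing how `if_` composes with DataFrame filtering (the data and column name are hypothetical, and the default session connection is assumed):

```python
import bigframes.pandas as bpd
import bigframes.bigquery as bbq

# Hypothetical data: keep only the cities the model judges to be in Europe.
cities = bpd.DataFrame({"city": ["Paris", "Tokyo", "Berlin"]})

# ai.if_ returns a boolean Series aligned with its input, so it can be used
# directly as a boolean mask on the DataFrame.
european = cities[bbq.ai.if_((cities["city"], " is a city in Europe"))]
```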


@log_adapter.method_logger(custom_base_name="bigquery_ai")
def score(
prompt: PROMPT_TYPE,
*,
connection_id: str | None = None,
) -> series.Series:
"""
Computes a score, returned as a double, based on a rubric described in natural language.
The returned score has no fixed range. To get high-quality results, provide a scoring
rubric with examples in the prompt.

**Examples:**
>>> import bigframes.pandas as bpd
>>> import bigframes.bigquery as bbq
>>> bpd.options.display.progress_bar = None
>>> animal = bpd.Series(["Tiger", "Rabbit", "Blue Whale"])
>>> bbq.ai.score(("Rank the relative weights of ", animal, " on the scale from 1 to 3"))
0 2.0
1 1.0
2 3.0
dtype: Float64

Args:
prompt (Series | List[str|Series] | Tuple[str|Series, ...]):
A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series
or pandas Series.
connection_id (str, optional):
Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`.
If not provided, the connection from the current session will be used.

Returns:
bigframes.series.Series: A new series of double (float) values.
"""

prompt_context, series_list = _separate_context_and_series(prompt)
assert len(series_list) > 0

operator = ai_ops.AIScore(
prompt_context=tuple(prompt_context),
connection_id=_resolve_connection_id(series_list[0], connection_id),
)

return series_list[0]._apply_nary_op(operator, series_list[1:])
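A minimal sketch of calling `score` with an explicit rubric in the prompt, as the docstring recommends (the data and threshold are hypothetical; the default session connection is assumed):

```python
import bigframes.pandas as bpd
import bigframes.bigquery as bbq

# Hypothetical data: score support tickets by urgency.
tickets = bpd.Series([
    "Login page returns a 500 error for every user",
    "Typo in the footer copyright year",
])

# Spell out the rubric, with anchor examples, directly in the prompt.
rubric = (
    "Score the urgency of this support ticket from 1 to 5, where 1 means "
    "cosmetic (e.g. a typo) and 5 means a full outage: "
)

urgency = bbq.ai.score((rubric, tickets))

# The result is a float Series aligned with `tickets`, so the usual
# pandas-style filtering applies.
high_priority = tickets[urgency >= 4]
```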


def _separate_context_and_series(
prompt: PROMPT_TYPE,
) -> Tuple[List[str | None], List[series.Series]]:
18 changes: 18 additions & 0 deletions bigframes/core/compile/ibis_compiler/scalar_op_registry.py
@@ -2030,6 +2030,24 @@ def ai_generate_double(
).to_expr()


@scalar_op_compiler.register_nary_op(ops.AIIf, pass_op=True)
def ai_if(*values: ibis_types.Value, op: ops.AIIf) -> ibis_types.BooleanValue:

return ai_ops.AIIf(
_construct_prompt(values, op.prompt_context), # type: ignore
op.connection_id, # type: ignore
).to_expr()


@scalar_op_compiler.register_nary_op(ops.AIScore, pass_op=True)
def ai_score(*values: ibis_types.Value, op: ops.AIScore) -> ibis_types.NumericValue:

return ai_ops.AIScore(
_construct_prompt(values, op.prompt_context), # type: ignore
op.connection_id, # type: ignore
).to_expr()


def _construct_prompt(
col_refs: tuple[ibis_types.Value], prompt_context: tuple[str | None]
) -> ibis_types.StructValue:
25 changes: 21 additions & 4 deletions bigframes/core/compile/sqlglot/expressions/ai_ops.py
@@ -54,6 +54,20 @@ def _(*exprs: TypedExpr, op: ops.AIGenerateDouble) -> sge.Expression:
return sge.func("AI.GENERATE_DOUBLE", *args)


@register_nary_op(ops.AIIf, pass_op=True)
def _(*exprs: TypedExpr, op: ops.AIIf) -> sge.Expression:
args = [_construct_prompt(exprs, op.prompt_context)] + _construct_named_args(op)

return sge.func("AI.IF", *args)


@register_nary_op(ops.AIScore, pass_op=True)
def _(*exprs: TypedExpr, op: ops.AIScore) -> sge.Expression:
args = [_construct_prompt(exprs, op.prompt_context)] + _construct_named_args(op)

return sge.func("AI.SCORE", *args)


def _construct_prompt(
exprs: tuple[TypedExpr, ...], prompt_context: tuple[str | None, ...]
) -> sge.Kwarg:
@@ -83,10 +97,13 @@ def _construct_named_args(op: ops.NaryOp) -> list[sge.Kwarg]:
    if endpoint is not None:
        args.append(sge.Kwarg(this="endpoint", expression=sge.Literal.string(endpoint)))

-    request_type = typing.cast(str, op_args["request_type"]).upper()
-    args.append(
-        sge.Kwarg(this="request_type", expression=sge.Literal.string(request_type))
-    )
+    request_type = typing.cast(str, op_args.get("request_type", None))
+    if request_type is not None:
+        args.append(
+            sge.Kwarg(
+                this="request_type", expression=sge.Literal.string(request_type.upper())
+            )
+        )

model_params = typing.cast(str, op_args.get("model_params", None))
if model_params is not None:
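For context on the `_construct_named_args` change above: only the named arguments actually present on the op are rendered, so `AI.IF` and `AI.SCORE` calls carry just the prompt and the connection. A rough sqlglot sketch of the pattern (the prompt text and connection id are made up):

```python
import sqlglot.expressions as sge

# Only kwargs that are present become `name => value` arguments in the SQL.
args = [
    sge.Kwarg(this="prompt", expression=sge.convert("is this a fruit?")),
    sge.Kwarg(
        this="connection_id",
        expression=sge.Literal.string("myproject.us.myconnection"),
    ),
]

print(sge.func("AI.IF", *args).sql(dialect="bigquery"))
# Roughly renders as:
# AI.IF(prompt => 'is this a fruit?', connection_id => 'myproject.us.myconnection')
```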
4 changes: 4 additions & 0 deletions bigframes/operations/__init__.py
@@ -19,6 +19,8 @@
AIGenerateBool,
AIGenerateDouble,
AIGenerateInt,
AIIf,
AIScore,
)
from bigframes.operations.array_ops import (
ArrayIndexOp,
@@ -421,6 +423,8 @@
"AIGenerateBool",
"AIGenerateDouble",
"AIGenerateInt",
"AIIf",
"AIScore",
# Numpy ops mapping
"NUMPY_TO_BINOP",
"NUMPY_TO_OP",
22 changes: 22 additions & 0 deletions bigframes/operations/ai_ops.py
@@ -110,3 +110,25 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
)
)
)


@dataclasses.dataclass(frozen=True)
class AIIf(base_ops.NaryOp):
name: ClassVar[str] = "ai_if"

prompt_context: Tuple[str | None, ...]
connection_id: str

def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
return dtypes.BOOL_DTYPE


@dataclasses.dataclass(frozen=True)
class AIScore(base_ops.NaryOp):
name: ClassVar[str] = "ai_score"

prompt_context: Tuple[str | None, ...]
connection_id: str

def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
return dtypes.FLOAT_DTYPE
44 changes: 44 additions & 0 deletions tests/system/small/bigquery/test_ai.py
@@ -203,5 +203,49 @@ def test_ai_generate_double_multi_model(session):
)


def test_ai_if(session):
s1 = bpd.Series(["apple", "bear"], session=session)
s2 = bpd.Series(["fruit", "tree"], session=session)
prompt = (s1, " is a ", s2)

result = bbq.ai.if_(prompt)

assert _contains_no_nulls(result)
assert result.dtype == dtypes.BOOL_DTYPE


def test_ai_if_multi_model(session):
df = session.from_glob_path(
"gs://bigframes-dev-testing/a_multimodel/images/*", name="image"
)

result = bbq.ai.if_((df["image"], " contains an animal"))

assert _contains_no_nulls(result)
assert result.dtype == dtypes.BOOL_DTYPE


def test_ai_score(session):
s = bpd.Series(["Tiger", "Rabbit"], session=session)
prompt = ("Rank the relative weights of ", s, " on the scale from 1 to 3")

result = bbq.ai.score(prompt)

assert _contains_no_nulls(result)
assert result.dtype == dtypes.FLOAT_DTYPE


def test_ai_score_multi_model(session):
df = session.from_glob_path(
"gs://bigframes-dev-testing/a_multimodel/images/*", name="image"
)
prompt = ("Rank the liveliness of ", df["image"], "on the scale from 1 to 3")

result = bbq.ai.score(prompt)

assert _contains_no_nulls(result)
assert result.dtype == dtypes.FLOAT_DTYPE


def _contains_no_nulls(s: series.Series) -> bool:
return len(s) == s.count()
New snapshot file (out.sql for test_ai_if):
@@ -0,0 +1,16 @@
WITH `bfcte_0` AS (
SELECT
`string_col` AS `bfcol_0`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
SELECT
*,
AI.IF(
prompt => (`bfcol_0`, ' is the same as ', `bfcol_0`),
connection_id => 'bigframes-dev.us.bigframes-default-connection'
) AS `bfcol_1`
FROM `bfcte_0`
)
SELECT
`bfcol_1` AS `result`
FROM `bfcte_1`
New snapshot file (out.sql for test_ai_score):
@@ -0,0 +1,16 @@
WITH `bfcte_0` AS (
SELECT
`string_col` AS `bfcol_0`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
SELECT
*,
AI.SCORE(
prompt => (`bfcol_0`, ' is the same as ', `bfcol_0`),
connection_id => 'bigframes-dev.us.bigframes-default-connection'
) AS `bfcol_1`
FROM `bfcte_0`
)
SELECT
`bfcol_1` AS `result`
FROM `bfcte_1`
30 changes: 30 additions & 0 deletions tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py
@@ -199,3 +199,33 @@ def test_ai_generate_double_with_model_param(
)

snapshot.assert_match(sql, "out.sql")


def test_ai_if(scalar_types_df: dataframe.DataFrame, snapshot):
col_name = "string_col"

op = ops.AIIf(
prompt_context=(None, " is the same as ", None),
connection_id=CONNECTION_ID,
)

sql = utils._apply_unary_ops(
scalar_types_df, [op.as_expr(col_name, col_name)], ["result"]
)

snapshot.assert_match(sql, "out.sql")


def test_ai_score(scalar_types_df: dataframe.DataFrame, snapshot):
col_name = "string_col"

op = ops.AIScore(
prompt_context=(None, " is the same as ", None),
connection_id=CONNECTION_ID,
)

sql = utils._apply_unary_ops(
scalar_types_df, [op.as_expr(col_name, col_name)], ["result"]
)

snapshot.assert_match(sql, "out.sql")
@@ -1116,6 +1116,12 @@ def visit_AIGenerateInt(self, op, **kwargs):
def visit_AIGenerateDouble(self, op, **kwargs):
return sge.func("AI.GENERATE_DOUBLE", *self._compile_ai_args(**kwargs))

def visit_AIIf(self, op, **kwargs):
return sge.func("AI.IF", *self._compile_ai_args(**kwargs))

def visit_AIScore(self, op, **kwargs):
return sge.func("AI.SCORE", *self._compile_ai_args(**kwargs))

def _compile_ai_args(self, **kwargs):
args = []

28 changes: 28 additions & 0 deletions third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py
@@ -91,3 +91,31 @@ def dtype(self) -> dt.Struct:
("status", dt.string),
)
)


@public
class AIIf(Value):
"""Generate True/False based on the prompt"""

prompt: Value
connection_id: Value[dt.String]

shape = rlz.shape_like("prompt")

@attribute
    def dtype(self) -> dt.Boolean:
        return dt.bool


@public
class AIScore(Value):
"""Generate doubles based on the prompt"""

prompt: Value
connection_id: Value[dt.String]

shape = rlz.shape_like("prompt")

@attribute
    def dtype(self) -> dt.Float64:
        return dt.float64