# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This module integrates BigQuery built-in AI functions for use with Series/DataFrame objects,
such as AI.GENERATE_BOOL:
https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-ai-generate-bool"""

from __future__ import annotations

import json
from typing import Any, List, Literal, Mapping, Tuple

from bigframes import clients, dtypes, series
from bigframes.core import log_adapter
from bigframes.operations import ai_ops


@log_adapter.method_logger(custom_base_name="bigquery_ai")
def generate_bool(
    prompt: series.Series | List[str | series.Series] | Tuple[str | series.Series, ...],
    *,
    connection_id: str | None = None,
    endpoint: str | None = None,
    request_type: Literal["dedicated", "shared", "unspecified"] = "unspecified",
    model_params: Mapping[Any, Any] | None = None,
) -> series.Series:
    """
    Returns the AI analysis based on the prompt, which can be any combination of text and unstructured data.

    **Examples:**

        >>> import bigframes.pandas as bpd
        >>> import bigframes.bigquery as bbq
        >>> bpd.options.display.progress_bar = None
        >>> df = bpd.DataFrame({
        ...     "col_1": ["apple", "bear", "pear"],
        ...     "col_2": ["fruit", "animal", "animal"]
        ... })
        >>> bbq.ai_generate_bool((df["col_1"], " is a ", df["col_2"]))
        0    {'result': True, 'full_response': '{"candidate...
        1    {'result': True, 'full_response': '{"candidate...
        2    {'result': False, 'full_response': '{"candidat...
        dtype: struct<result: bool, full_response: string, status: string>[pyarrow]

        >>> bbq.ai_generate_bool((df["col_1"], " is a ", df["col_2"])).struct.field("result")
        0     True
        1     True
        2    False
        Name: result, dtype: boolean

        >>> model_params = {
        ...     "generation_config": {
        ...         "thinking_config": {
        ...             "thinking_budget": 0
        ...         }
        ...     }
        ... }
        >>> bbq.ai_generate_bool(
        ...     (df["col_1"], " is a ", df["col_2"]),
        ...     endpoint="gemini-2.5-pro",
        ...     model_params=model_params,
        ... ).struct.field("result")
        0     True
        1     True
        2    False
        Name: result, dtype: boolean

    Args:
        prompt (series.Series | List[str|series.Series] | Tuple[str|series.Series, ...]):
            A mixture of Series and string literals that specifies the prompt to send to the model.
        connection_id (str, optional):
            Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`.
            If not provided, the connection from the current session is used.
        endpoint (str, optional):
            Specifies the Vertex AI endpoint to use for the model. For example, `"gemini-2.5-flash"`. You can specify any
            generally available or preview Gemini model. If you specify the model name, BigQuery ML automatically identifies and
            uses the full endpoint of the model. If you don't specify an `endpoint` value, BigQuery ML selects a recent stable
            version of Gemini to use.
        request_type (Literal["dedicated", "shared", "unspecified"]):
            Specifies the type of inference request to send to the Gemini model. The request type determines what quota the request uses.
            * "dedicated": the function only uses Provisioned Throughput quota. It returns the error
              `Provisioned throughput is not purchased or is not active` if Provisioned Throughput quota isn't available.
            * "shared": the function only uses dynamic shared quota (DSQ), even if you have purchased Provisioned Throughput quota.
            * "unspecified": if you haven't purchased Provisioned Throughput quota, the function uses DSQ quota.
              If you have purchased Provisioned Throughput quota, the function uses the Provisioned Throughput quota first.
              If requests exceed the Provisioned Throughput quota, the overflow traffic uses DSQ quota.
        model_params (Mapping[Any, Any], optional):
            Provides additional parameters to the model. The `model_params` value must conform to the `generateContent` request body format.

    Returns:
        bigframes.series.Series: A new struct Series with the result data. The struct contains these fields:
        * "result": a BOOL value containing the model's response to the prompt. The result is None if the request fails or is filtered by responsible AI.
        * "full_response": a STRING value containing the JSON response from the projects.locations.endpoints.generateContent call to the model.
          The generated text is in the text element.
        * "status": a STRING value that contains the API response status for the corresponding row. This value is empty if the operation was successful.
    """

    prompt_context, series_list = _separate_context_and_series(prompt)
    # _separate_context_and_series raises if the prompt contains no Series.
    assert len(series_list) > 0

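    # None entries in prompt_context mark where each Series value is spliced into
    # the prompt; the Series themselves are passed to the operator as operands.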
    operator = ai_ops.AIGenerateBool(
        prompt_context=tuple(prompt_context),
        connection_id=_resolve_connection_id(series_list[0], connection_id),
        endpoint=endpoint,
        request_type=request_type,
        model_params=json.dumps(model_params) if model_params else None,
    )

    return series_list[0]._apply_nary_op(operator, series_list[1:])


def _separate_context_and_series(
    prompt: series.Series | List[str | series.Series] | Tuple[str | series.Series, ...],
) -> Tuple[List[str | None], List[series.Series]]:
    """
    Splits the prompt into two values: the prompt items with every Series replaced
    by None, and the list of the Series found in the prompt. The original item order
    is kept.

    For example:
        Input:  ("str1", series1, "str2", "str3", series2)
        Output: ["str1", None, "str2", "str3", None], [series1, series2]
    """
    if not isinstance(prompt, (list, tuple, series.Series)):
        raise ValueError(f"Unsupported prompt type: {type(prompt)}")

    if isinstance(prompt, series.Series):
        if prompt.dtype == dtypes.OBJ_REF_DTYPE:
            # Multimodal support: pass the blob Series as read URLs the model can access.
            return [None], [prompt.blob.read_url()]
        return [None], [prompt]

    prompt_context: List[str | None] = []
    series_list: List[series.Series] = []

    for item in prompt:
        if isinstance(item, str):
            prompt_context.append(item)

        elif isinstance(item, series.Series):
            prompt_context.append(None)

            if item.dtype == dtypes.OBJ_REF_DTYPE:
                # Multimodal support: pass the blob Series as read URLs.
                item = item.blob.read_url()
            series_list.append(item)

        else:
            raise TypeError(f"Unsupported type in prompt: {type(item)}")

    if not series_list:
        raise ValueError("Please provide at least one Series in the prompt")

    return prompt_context, series_list


def _resolve_connection_id(series: series.Series, connection_id: str | None):
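    """Canonicalizes the connection ID, falling back to the session's default
    BigQuery connection when none is provided."""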
    return clients.get_canonical_bq_connection_id(
        connection_id or series._session._bq_connection,
        series._session._project,
        series._session._location,
    )