2525
2626from bigframes import clients , dtypes , series , session
2727from bigframes .core import convert , log_adapter
28- from bigframes .operations import ai_ops
28+ from bigframes .operations import ai_ops , output_schemas
2929
3030PROMPT_TYPE = Union [
3131 series .Series ,
@@ -43,7 +43,7 @@ def generate(
4343 endpoint : str | None = None ,
4444 request_type : Literal ["dedicated" , "shared" , "unspecified" ] = "unspecified" ,
4545 model_params : Mapping [Any , Any ] | None = None ,
46- # TODO(b/446974666) Add output_schema parameter
46+ output_schema : Mapping [ str , str ] | None = None ,
4747) -> series .Series :
4848 """
4949 Returns the AI analysis based on the prompt, which can be any combination of text and unstructured data.
@@ -64,6 +64,14 @@ def generate(
6464 1 Ottawa\\ n
6565 Name: result, dtype: string
6666
67+ You get structured output when the `output_schema` parameter is set:
68+
69+ >>> animals = bpd.Series(["Rabbit", "Spider"])
70+ >>> bbq.ai.generate(animals, output_schema={"number_of_legs": "INT64", "is_herbivore": "BOOL"})
71+ 0 {'is_herbivore': True, 'number_of_legs': 4, 'f...
72+ 1 {'is_herbivore': False, 'number_of_legs': 8, '...
73+ dtype: struct<is_herbivore: bool, number_of_legs: int64, full_response: extension<dbjson<JSONArrowType>>, status: string>[pyarrow]
74+
6775 Args:
6876 prompt (Series | List[str|Series] | Tuple[str|Series, ...]):
6977 A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series
@@ -86,10 +94,14 @@ def generate(
8694 If requests exceed the Provisioned Throughput quota, the overflow traffic uses DSQ quota.
8795 model_params (Mapping[Any, Any]):
8896 Provides additional parameters to the model. The MODEL_PARAMS value must conform to the generateContent request body format.
97+ output_schema (Mapping[str, str]):
98+ A mapping value that specifies the schema of the output, in the form {field_name: data_type}. Supported data types include
99+ `STRING`, `INT64`, `FLOAT64`, `BOOL`, `ARRAY`, and `STRUCT`.
89100
90101 Returns:
91102 bigframes.series.Series: A new struct Series with the result data. The struct contains these fields:
92103 * "result": a STRING value containing the model's response to the prompt. The result is None if the request fails or is filtered by responsible AI.
104+ If you specify an output schema, the "result" field is replaced by the fields of your custom schema.
93105 * "full_response": a JSON value containing the response from the projects.locations.endpoints.generateContent call to the model.
94106 The generated text is in the text element.
95107 * "status": a STRING value that contains the API response status for the corresponding row. This value is empty if the operation was successful.
@@ -98,12 +110,22 @@ def generate(
98110 prompt_context , series_list = _separate_context_and_series (prompt )
99111 assert len (series_list ) > 0
100112
113+ if output_schema is None :
114+ output_schema_str = None
115+ else :
116+ output_schema_str = ", " .join (
117+ [f"{ name } { sql_type } " for name , sql_type in output_schema .items ()]
118+ )
119+ # Validate user input
120+ output_schemas .parse_sql_fields (output_schema_str )
121+
101122 operator = ai_ops .AIGenerate (
102123 prompt_context = tuple (prompt_context ),
103124 connection_id = _resolve_connection_id (series_list [0 ], connection_id ),
104125 endpoint = endpoint ,
105126 request_type = request_type ,
106127 model_params = json .dumps (model_params ) if model_params else None ,
128+ output_schema = output_schema_str ,
107129 )
108130
109131 return series_list [0 ]._apply_nary_op (operator , series_list [1 :])
0 commit comments