@@ -348,20 +348,20 @@ def if_(
348348 provides optimization such that not all rows are evaluated with the LLM.
349349
350350 **Examples:**
351- >>> import bigframes.pandas as bpd
352- >>> import bigframes.bigquery as bbq
353- >>> bpd.options.display.progress_bar = None
354- >>> us_state = bpd.Series(["Massachusetts", "Illinois", "Hawaii"])
355- >>> bbq.ai.if_((us_state, " has a city called Springfield"))
356- 0 True
357- 1 True
358- 2 False
359- dtype: boolean
360-
361- >>> us_state[bbq.ai.if_((us_state, " has a city called Springfield"))]
362- 0 Massachusetts
363- 1 Illinois
364- dtype: string
351+ >>> import bigframes.pandas as bpd
352+ >>> import bigframes.bigquery as bbq
353+ >>> bpd.options.display.progress_bar = None
354+ >>> us_state = bpd.Series(["Massachusetts", "Illinois", "Hawaii"])
355+ >>> bbq.ai.if_((us_state, " has a city called Springfield"))
356+ 0 True
357+ 1 True
358+ 2 False
359+ dtype: boolean
360+
361+ >>> us_state[bbq.ai.if_((us_state, " has a city called Springfield"))]
362+ 0 Massachusetts
363+ 1 Illinois
364+ dtype: string
365365
366366 Args:
367367 prompt (Series | List[str|Series] | Tuple[str|Series, ...]):
@@ -386,6 +386,56 @@ def if_(
386386 return series_list [0 ]._apply_nary_op (operator , series_list [1 :])
387387
388388
389+ @log_adapter .method_logger (custom_base_name = "bigquery_ai" )
390+ def classify (
391+ input : PROMPT_TYPE ,
392+ categories : tuple [str , ...] | list [str ],
393+ * ,
394+ connection_id : str | None = None ,
395+ ) -> series .Series :
396+ """
397+ Classifies a given input into one of the specified categories. It will always return one of the provided categories best fit the prompt input.
398+
399+ **Examples:**
400+
401+ >>> import bigframes.pandas as bpd
402+ >>> import bigframes.bigquery as bbq
403+ >>> bpd.options.display.progress_bar = None
404+ >>> df = bpd.DataFrame({'creature': ['Cat', 'Salmon']})
405+ >>> df['type'] = bbq.ai.classify(df['creature'], ['Mammal', 'Fish'])
406+ >>> df
407+ creature type
408+ 0 Cat Mammal
409+ 1 Salmon Fish
410+ <BLANKLINE>
411+ [2 rows x 2 columns]
412+
413+ Args:
414+ input (Series | List[str|Series] | Tuple[str|Series, ...]):
415+ A mixture of Series and string literals that specifies the input to send to the model. The Series can be BigFrames Series
416+ or pandas Series.
417+ categories (tuple[str, ...] | list[str]):
418+ Categories to classify the input into.
419+ connection_id (str, optional):
420+ Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`.
421+ If not provided, the connection from the current session will be used.
422+
423+ Returns:
424+ bigframes.series.Series: A new series of strings.
425+ """
426+
427+ prompt_context , series_list = _separate_context_and_series (input )
428+ assert len (series_list ) > 0
429+
430+ operator = ai_ops .AIClassify (
431+ prompt_context = tuple (prompt_context ),
432+ categories = tuple (categories ),
433+ connection_id = _resolve_connection_id (series_list [0 ], connection_id ),
434+ )
435+
436+ return series_list [0 ]._apply_nary_op (operator , series_list [1 :])
437+
438+
389439@log_adapter .method_logger (custom_base_name = "bigquery_ai" )
390440def score (
391441 prompt : PROMPT_TYPE ,
@@ -398,15 +448,16 @@ def score(
398448 rubric with examples in the prompt.
399449
400450 **Examples:**
401- >>> import bigframes.pandas as bpd
402- >>> import bigframes.bigquery as bbq
403- >>> bpd.options.display.progress_bar = None
404- >>> animal = bpd.Series(["Tiger", "Rabbit", "Blue Whale"])
405- >>> bbq.ai.score(("Rank the relative weights of ", animal, " on the scale from 1 to 3")) # doctest: +SKIP
406- 0 2.0
407- 1 1.0
408- 2 3.0
409- dtype: Float64
451+
452+ >>> import bigframes.pandas as bpd
453+ >>> import bigframes.bigquery as bbq
454+ >>> bpd.options.display.progress_bar = None
455+ >>> animal = bpd.Series(["Tiger", "Rabbit", "Blue Whale"])
456+ >>> bbq.ai.score(("Rank the relative weights of ", animal, " on the scale from 1 to 3")) # doctest: +SKIP
457+ 0 2.0
458+ 1 1.0
459+ 2 3.0
460+ dtype: Float64
410461
411462 Args:
412463 prompt (Series | List[str|Series] | Tuple[str|Series, ...]):
0 commit comments