|
28 | 28 | from bigframes.operations import ai_ops, output_schemas |
29 | 29 |
|
30 | 30 | PROMPT_TYPE = Union[ |
| 31 | + str, |
31 | 32 | series.Series, |
32 | 33 | pd.Series, |
33 | 34 | List[Union[str, series.Series, pd.Series]], |
@@ -73,7 +74,7 @@ def generate( |
73 | 74 | dtype: struct<is_herbivore: bool, number_of_legs: int64, full_response: extension<dbjson<JSONArrowType>>, status: string>[pyarrow] |
74 | 75 |
|
75 | 76 | Args: |
76 | | - prompt (Series | List[str|Series] | Tuple[str|Series, ...]): |
| 77 | + prompt (str | Series | List[str|Series] | Tuple[str|Series, ...]): |
77 | 78 | A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series |
78 | 79 | or pandas Series. |
79 | 80 | connection_id (str, optional): |
@@ -165,7 +166,7 @@ def generate_bool( |
165 | 166 | Name: result, dtype: boolean |
166 | 167 |
|
167 | 168 | Args: |
168 | | - prompt (Series | List[str|Series] | Tuple[str|Series, ...]): |
| 169 | + prompt (str | Series | List[str|Series] | Tuple[str|Series, ...]): |
169 | 170 | A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series |
170 | 171 | or pandas Series. |
171 | 172 | connection_id (str, optional): |
@@ -240,7 +241,7 @@ def generate_int( |
240 | 241 | Name: result, dtype: Int64 |
241 | 242 |
|
242 | 243 | Args: |
243 | | - prompt (Series | List[str|Series] | Tuple[str|Series, ...]): |
| 244 | + prompt (str | Series | List[str|Series] | Tuple[str|Series, ...]): |
244 | 245 | A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series |
245 | 246 | or pandas Series. |
246 | 247 | connection_id (str, optional): |
@@ -315,7 +316,7 @@ def generate_double( |
315 | 316 | Name: result, dtype: Float64 |
316 | 317 |
|
317 | 318 | Args: |
318 | | - prompt (Series | List[str|Series] | Tuple[str|Series, ...]): |
| 319 | + prompt (str | Series | List[str|Series] | Tuple[str|Series, ...]): |
319 | 320 | A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series |
320 | 321 | or pandas Series. |
321 | 322 | connection_id (str, optional): |
@@ -386,7 +387,7 @@ def if_( |
386 | 387 | dtype: string |
387 | 388 |
|
388 | 389 | Args: |
389 | | - prompt (Series | List[str|Series] | Tuple[str|Series, ...]): |
| 390 | + prompt (str | Series | List[str|Series] | Tuple[str|Series, ...]): |
390 | 391 | A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series |
391 | 392 | or pandas Series. |
392 | 393 | connection_id (str, optional): |
@@ -433,7 +434,7 @@ def classify( |
433 | 434 | [2 rows x 2 columns] |
434 | 435 |
|
435 | 436 | Args: |
436 | | - input (Series | List[str|Series] | Tuple[str|Series, ...]): |
| 437 | + input (str | Series | List[str|Series] | Tuple[str|Series, ...]): |
437 | 438 | A mixture of Series and string literals that specifies the input to send to the model. The Series can be BigFrames Series |
438 | 439 | or pandas Series. |
439 | 440 | categories (tuple[str, ...] | list[str]): |
@@ -482,7 +483,7 @@ def score( |
482 | 483 | dtype: Float64 |
483 | 484 |
|
484 | 485 | Args: |
485 | | - prompt (Series | List[str|Series] | Tuple[str|Series, ...]): |
| 486 | + prompt (str | Series | List[str|Series] | Tuple[str|Series, ...]): |
486 | 487 | A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series |
487 | 488 | or pandas Series. |
488 | 489 | connection_id (str, optional): |
@@ -514,9 +515,12 @@ def _separate_context_and_series( |
514 | 515 | Input: ("str1", series1, "str2", "str3", series2) |
515 | 516 | Output: ["str1", None, "str2", "str3", None], [series1, series2] |
516 | 517 | """ |
517 | | - if not isinstance(prompt, (list, tuple, series.Series)): |
| 518 | + if not isinstance(prompt, (str, list, tuple, series.Series)): |
518 | 519 | raise ValueError(f"Unsupported prompt type: {type(prompt)}") |
519 | 520 |
|
| 521 | + if isinstance(prompt, str): |
| 522 | + return [None], [series.Series([prompt])] |
| 523 | + |
520 | 524 | if isinstance(prompt, series.Series): |
521 | 525 | if prompt.dtype == dtypes.OBJ_REF_DTYPE: |
522 | 526 | # Multi-model support |
|
0 commit comments