1515import sys
1616
1717import pandas as pd
18- import pandas . testing
18+ import pyarrow as pa
1919import pytest
2020
21+ from bigframes import series
2122import bigframes .bigquery as bbq
2223import bigframes .pandas as bpd
2324
@@ -27,15 +28,17 @@ def test_ai_generate_bool(session):
2728 s2 = bpd .Series (["fruit" , "tree" ], session = session )
2829 prompt = (s1 , " is a " , s2 )
2930
30- result = bbq .ai .generate_bool (prompt , endpoint = "gemini-2.5-flash" ).struct .field (
31- "result"
32- )
31+ result = bbq .ai .generate_bool (prompt , endpoint = "gemini-2.5-flash" )
3332
34- pandas .testing .assert_series_equal (
35- result .to_pandas (),
36- pd .Series ([True , False ], name = "result" ),
37- check_dtype = False ,
38- check_index = False ,
33+ assert _contains_no_nulls (result )
34+ assert result .dtype == pd .ArrowDtype (
35+ pa .struct (
36+ (
37+ pa .field ("result" , pa .bool_ ()),
38+ pa .field ("full_response" , pa .string ()),
39+ pa .field ("status" , pa .string ()),
40+ )
41+ )
3942 )
4043
4144
@@ -52,11 +55,38 @@ def test_ai_generate_bool_with_model_params(session):
5255
5356 result = bbq .ai .generate_bool (
5457 prompt , endpoint = "gemini-2.5-flash" , model_params = model_params
55- ).struct .field ("result" )
58+ )
59+
60+ assert _contains_no_nulls (result )
61+ assert result .dtype == pd .ArrowDtype (
62+ pa .struct (
63+ (
64+ pa .field ("result" , pa .bool_ ()),
65+ pa .field ("full_response" , pa .string ()),
66+ pa .field ("status" , pa .string ()),
67+ )
68+ )
69+ )
70+
5671
57- pandas .testing .assert_series_equal (
58- result .to_pandas (),
59- pd .Series ([True , False ], name = "result" ),
60- check_dtype = False ,
61- check_index = False ,
72+ def test_ai_generate_bool_multi_model (session ):
73+ df = session .from_glob_path (
74+ "gs://bigframes-dev-testing/a_multimodel/images/*" , name = "image"
6275 )
76+
77+ result = bbq .ai .generate_bool ((df ["image" ], " contains an animal" ))
78+
79+ assert _contains_no_nulls (result )
80+ assert result .dtype == pd .ArrowDtype (
81+ pa .struct (
82+ (
83+ pa .field ("result" , pa .bool_ ()),
84+ pa .field ("full_response" , pa .string ()),
85+ pa .field ("status" , pa .string ()),
86+ )
87+ )
88+ )
89+
90+
91+ def _contains_no_nulls (s : series .Series ) -> bool :
92+ return len (s ) == s .count ()
0 commit comments