@@ -73,6 +73,63 @@ def test_simple_literal(value, expected_pattern):
73
73
assert re .match (expected_pattern , got ) is not None
74
74
75
75
76
+ @pytest .mark .parametrize (
77
+ ("value" , "expected" ),
78
+ (
79
+ # Try to have some list of literals for each scalar data type:
80
+ # https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types
81
+ ([None , None ], "[NULL, NULL]" ),
82
+ ([True , False ], "[True, False]" ),
83
+ (
84
+ [b"\x01 \x02 \x03 ABC" , b"\x01 \x02 \x03 ABC" ],
85
+ "[b'\\ x01\\ x02\\ x03ABC', b'\\ x01\\ x02\\ x03ABC']" ,
86
+ ),
87
+ (
88
+ [datetime .date (2025 , 1 , 1 ), datetime .date (2025 , 1 , 1 )],
89
+ "[DATE('2025-01-01'), DATE('2025-01-01')]" ,
90
+ ),
91
+ (
92
+ [datetime .datetime (2025 , 1 , 2 , 3 , 45 , 6 , 789123 )],
93
+ "[DATETIME('2025-01-02T03:45:06.789123')]" ,
94
+ ),
95
+ (
96
+ [shapely .Point (0 , 1 ), shapely .Point (0 , 2 )],
97
+ "[ST_GEOGFROMTEXT('POINT (0 1)'), ST_GEOGFROMTEXT('POINT (0 2)')]" ,
98
+ ),
99
+ # TODO: INTERVAL type (e.g. from dateutil.relativedelta)
100
+ # TODO: JSON type (TBD what Python object that would correspond to)
101
+ ([123 , 456 ], "[123, 456]" ),
102
+ (
103
+ [decimal .Decimal ("123.75" ), decimal .Decimal ("456.78" )],
104
+ "[CAST('123.75' AS NUMERIC), CAST('456.78' AS NUMERIC)]" ,
105
+ ),
106
+ # TODO: support BIGNUMERIC by looking at precision/scale of the DECIMAL
107
+ ([123.75 , 456.78 ], "[123.75, 456.78]" ),
108
+ # TODO: support RANGE type
109
+ (["abc" , "def" ], "['abc', 'def']" ),
110
+ # TODO: support STRUCT type (possibly another method?)
111
+ (
112
+ [datetime .time (12 , 34 , 56 , 789123 ), datetime .time (11 , 25 , 56 , 789123 )],
113
+ "[TIME(DATETIME('1970-01-01 12:34:56.789123')), TIME(DATETIME('1970-01-01 11:25:56.789123'))]" ,
114
+ ),
115
+ (
116
+ [
117
+ datetime .datetime (
118
+ 2025 , 1 , 2 , 3 , 45 , 6 , 789123 , tzinfo = datetime .timezone .utc
119
+ ),
120
+ datetime .datetime (
121
+ 2025 , 2 , 1 , 4 , 45 , 6 , 789123 , tzinfo = datetime .timezone .utc
122
+ ),
123
+ ],
124
+ "[TIMESTAMP('2025-01-02T03:45:06.789123+00:00'), TIMESTAMP('2025-02-01T04:45:06.789123+00:00')]" ,
125
+ ),
126
+ ),
127
+ )
128
+ def test_simple_literal_w_list (value : list , expected : str ):
129
+ got = sql .simple_literal (value )
130
+ assert got == expected
131
+
132
+
76
133
def test_create_vector_search_sql_simple ():
77
134
result_query = sql .create_vector_search_sql (
78
135
sql_string = "SELECT embedding FROM my_embeddings_table WHERE id = 1" ,
0 commit comments