Skip to content

Commit 102d363

Browse files
authored
feat: support passing list of values to bigframes.core.sql.simple_literal (#1641)
1 parent 53fc25b commit 102d363

File tree

2 files changed

+64 -1 lines changed

2 files changed

+64 -1 lines changed

bigframes/core/sql.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,9 @@
4343

4444

4545
### Writing SQL Values (literals, column references, table references, etc.)
46-
def simple_literal(value: bytes | str | int | bool | float | datetime.datetime | None):
46+
def simple_literal(
47+
value: bytes | str | int | bool | float | datetime.datetime | list | None,
48+
):
4749
"""Return quoted input string."""
4850

4951
# https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#literals
@@ -80,6 +82,10 @@ def simple_literal(value: bytes | str | int | bool | float | datetime.datetime |
8082
elif isinstance(value, decimal.Decimal):
8183
# TODO: disambiguate BIGNUMERIC based on scale and/or precision
8284
return f"CAST('{str(value)}' AS NUMERIC)"
85+
elif isinstance(value, list):
86+
simple_literals = [simple_literal(i) for i in value]
87+
return f"[{', '.join(simple_literals)}]"
88+
8389
else:
8490
raise ValueError(f"Cannot produce literal for {value}")
8591

tests/unit/core/test_sql.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,63 @@ def test_simple_literal(value, expected_pattern):
7373
assert re.match(expected_pattern, got) is not None
7474

7575

76+
@pytest.mark.parametrize(
77+
("value", "expected"),
78+
(
79+
# Try to have some list of literals for each scalar data type:
80+
# https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types
81+
([None, None], "[NULL, NULL]"),
82+
([True, False], "[True, False]"),
83+
(
84+
[b"\x01\x02\x03ABC", b"\x01\x02\x03ABC"],
85+
"[b'\\x01\\x02\\x03ABC', b'\\x01\\x02\\x03ABC']",
86+
),
87+
(
88+
[datetime.date(2025, 1, 1), datetime.date(2025, 1, 1)],
89+
"[DATE('2025-01-01'), DATE('2025-01-01')]",
90+
),
91+
(
92+
[datetime.datetime(2025, 1, 2, 3, 45, 6, 789123)],
93+
"[DATETIME('2025-01-02T03:45:06.789123')]",
94+
),
95+
(
96+
[shapely.Point(0, 1), shapely.Point(0, 2)],
97+
"[ST_GEOGFROMTEXT('POINT (0 1)'), ST_GEOGFROMTEXT('POINT (0 2)')]",
98+
),
99+
# TODO: INTERVAL type (e.g. from dateutil.relativedelta)
100+
# TODO: JSON type (TBD what Python object that would correspond to)
101+
([123, 456], "[123, 456]"),
102+
(
103+
[decimal.Decimal("123.75"), decimal.Decimal("456.78")],
104+
"[CAST('123.75' AS NUMERIC), CAST('456.78' AS NUMERIC)]",
105+
),
106+
# TODO: support BIGNUMERIC by looking at precision/scale of the DECIMAL
107+
([123.75, 456.78], "[123.75, 456.78]"),
108+
# TODO: support RANGE type
109+
(["abc", "def"], "['abc', 'def']"),
110+
# TODO: support STRUCT type (possibly another method?)
111+
(
112+
[datetime.time(12, 34, 56, 789123), datetime.time(11, 25, 56, 789123)],
113+
"[TIME(DATETIME('1970-01-01 12:34:56.789123')), TIME(DATETIME('1970-01-01 11:25:56.789123'))]",
114+
),
115+
(
116+
[
117+
datetime.datetime(
118+
2025, 1, 2, 3, 45, 6, 789123, tzinfo=datetime.timezone.utc
119+
),
120+
datetime.datetime(
121+
2025, 2, 1, 4, 45, 6, 789123, tzinfo=datetime.timezone.utc
122+
),
123+
],
124+
"[TIMESTAMP('2025-01-02T03:45:06.789123+00:00'), TIMESTAMP('2025-02-01T04:45:06.789123+00:00')]",
125+
),
126+
),
127+
)
128+
def test_simple_literal_w_list(value: list, expected: str):
129+
got = sql.simple_literal(value)
130+
assert got == expected
131+
132+
76133
def test_create_vector_search_sql_simple():
77134
result_query = sql.create_vector_search_sql(
78135
sql_string="SELECT embedding FROM my_embeddings_table WHERE id = 1",

0 commit comments

Comments (0)