Skip to content

Commit b3dbbcc

Browse files
authored
refactor: enable "astype" engine tests for the sqlglot compiler (#2107)
1 parent ca1e44c commit b3dbbcc

File tree

10 files changed

+417
-20
lines changed

10 files changed

+417
-20
lines changed

bigframes/core/compile/sqlglot/expressions/generic_ops.py

Lines changed: 100 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,54 @@
1616

1717
import sqlglot.expressions as sge
1818

19+
from bigframes import dtypes
1920
from bigframes import operations as ops
2021
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
2122
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
23+
from bigframes.core.compile.sqlglot.sqlglot_types import SQLGlotType
2224

2325
register_unary_op = scalar_compiler.scalar_op_compiler.register_unary_op
2426

2527

2628
@register_unary_op(ops.AsTypeOp, pass_op=True)
2729
def _(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression:
28-
# TODO: Support more types for casting, such as JSON, etc.
29-
return sge.Cast(this=expr.expr, to=op.to_type)
30+
from_type = expr.dtype
31+
to_type = op.to_type
32+
sg_to_type = SQLGlotType.from_bigframes_dtype(to_type)
33+
sg_expr = expr.expr
34+
35+
if to_type == dtypes.JSON_DTYPE:
36+
return _cast_to_json(expr, op)
37+
38+
if from_type == dtypes.JSON_DTYPE:
39+
return _cast_from_json(expr, op)
40+
41+
if to_type == dtypes.INT_DTYPE:
42+
result = _cast_to_int(expr, op)
43+
if result is not None:
44+
return result
45+
46+
if to_type == dtypes.FLOAT_DTYPE and from_type == dtypes.BOOL_DTYPE:
47+
sg_expr = _cast(sg_expr, "INT64", op.safe)
48+
return _cast(sg_expr, sg_to_type, op.safe)
49+
50+
if to_type == dtypes.BOOL_DTYPE:
51+
if from_type == dtypes.BOOL_DTYPE:
52+
return sg_expr
53+
else:
54+
return sge.NEQ(this=sg_expr, expression=sge.convert(0))
55+
56+
if to_type == dtypes.STRING_DTYPE:
57+
sg_expr = _cast(sg_expr, sg_to_type, op.safe)
58+
if from_type == dtypes.BOOL_DTYPE:
59+
sg_expr = sge.func("INITCAP", sg_expr)
60+
return sg_expr
61+
62+
if dtypes.is_time_like(to_type) and from_type == dtypes.INT_DTYPE:
63+
sg_expr = sge.func("TIMESTAMP_MICROS", sg_expr)
64+
return _cast(sg_expr, sg_to_type, op.safe)
65+
66+
return _cast(sg_expr, sg_to_type, op.safe)
3067

3168

3269
@register_unary_op(ops.hash_op)
@@ -53,3 +90,64 @@ def _(expr: TypedExpr, op: ops.MapOp) -> sge.Expression:
5390
@register_unary_op(ops.notnull_op)
5491
def _(expr: TypedExpr) -> sge.Expression:
5592
return sge.Not(this=sge.Is(this=expr.expr, expression=sge.Null()))
93+
94+
95+
# Helper functions
96+
def _cast_to_json(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression:
97+
from_type = expr.dtype
98+
sg_expr = expr.expr
99+
100+
if from_type == dtypes.STRING_DTYPE:
101+
func_name = "PARSE_JSON_IN_SAFE" if op.safe else "PARSE_JSON"
102+
return sge.func(func_name, sg_expr)
103+
if from_type in (dtypes.INT_DTYPE, dtypes.BOOL_DTYPE, dtypes.FLOAT_DTYPE):
104+
sg_expr = sge.Cast(this=sg_expr, to="STRING")
105+
return sge.func("PARSE_JSON", sg_expr)
106+
raise TypeError(f"Cannot cast from {from_type} to {dtypes.JSON_DTYPE}")
107+
108+
109+
def _cast_from_json(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression:
110+
to_type = op.to_type
111+
sg_expr = expr.expr
112+
func_name = ""
113+
if to_type == dtypes.INT_DTYPE:
114+
func_name = "INT64"
115+
elif to_type == dtypes.FLOAT_DTYPE:
116+
func_name = "FLOAT64"
117+
elif to_type == dtypes.BOOL_DTYPE:
118+
func_name = "BOOL"
119+
elif to_type == dtypes.STRING_DTYPE:
120+
func_name = "STRING"
121+
if func_name:
122+
func_name = "SAFE." + func_name if op.safe else func_name
123+
return sge.func(func_name, sg_expr)
124+
raise TypeError(f"Cannot cast from {dtypes.JSON_DTYPE} to {to_type}")
125+
126+
127+
def _cast_to_int(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression | None:
128+
from_type = expr.dtype
129+
sg_expr = expr.expr
130+
# Cannot cast DATETIME to INT directly so need to convert to TIMESTAMP first.
131+
if from_type == dtypes.DATETIME_DTYPE:
132+
sg_expr = _cast(sg_expr, "TIMESTAMP", op.safe)
133+
return sge.func("UNIX_MICROS", sg_expr)
134+
if from_type == dtypes.TIMESTAMP_DTYPE:
135+
return sge.func("UNIX_MICROS", sg_expr)
136+
if from_type == dtypes.TIME_DTYPE:
137+
return sge.func(
138+
"TIME_DIFF",
139+
_cast(sg_expr, "TIME", op.safe),
140+
sge.convert("00:00:00"),
141+
"MICROSECOND",
142+
)
143+
if from_type == dtypes.NUMERIC_DTYPE or from_type == dtypes.FLOAT_DTYPE:
144+
sg_expr = sge.func("TRUNC", sg_expr)
145+
return _cast(sg_expr, "INT64", op.safe)
146+
return None
147+
148+
149+
def _cast(expr: sge.Expression, to: str, safe: bool):
150+
if safe:
151+
return sge.TryCast(this=expr, to=to)
152+
else:
153+
return sge.Cast(this=expr, to=to)

tests/system/small/engines/test_generic_ops.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def apply_op(
5252
return new_arr
5353

5454

55-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
55+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
5656
def test_engines_astype_int(scalars_array_value: array_value.ArrayValue, engine):
5757
arr = apply_op(
5858
scalars_array_value,
@@ -63,7 +63,7 @@ def test_engines_astype_int(scalars_array_value: array_value.ArrayValue, engine)
6363
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
6464

6565

66-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
66+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
6767
def test_engines_astype_string_int(scalars_array_value: array_value.ArrayValue, engine):
6868
vals = ["1", "100", "-3"]
6969
arr, _ = scalars_array_value.compute_values(
@@ -78,7 +78,7 @@ def test_engines_astype_string_int(scalars_array_value: array_value.ArrayValue,
7878
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
7979

8080

81-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
81+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
8282
def test_engines_astype_float(scalars_array_value: array_value.ArrayValue, engine):
8383
arr = apply_op(
8484
scalars_array_value,
@@ -89,7 +89,7 @@ def test_engines_astype_float(scalars_array_value: array_value.ArrayValue, engin
8989
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
9090

9191

92-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
92+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
9393
def test_engines_astype_string_float(
9494
scalars_array_value: array_value.ArrayValue, engine
9595
):
@@ -106,7 +106,7 @@ def test_engines_astype_string_float(
106106
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
107107

108108

109-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
109+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
110110
def test_engines_astype_bool(scalars_array_value: array_value.ArrayValue, engine):
111111
arr = apply_op(
112112
scalars_array_value, ops.AsTypeOp(to_type=bigframes.dtypes.BOOL_DTYPE)
@@ -115,7 +115,7 @@ def test_engines_astype_bool(scalars_array_value: array_value.ArrayValue, engine
115115
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
116116

117117

118-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
118+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
119119
def test_engines_astype_string(scalars_array_value: array_value.ArrayValue, engine):
120120
# floats work slightly different with trailing zeroes rn
121121
arr = apply_op(
@@ -127,7 +127,7 @@ def test_engines_astype_string(scalars_array_value: array_value.ArrayValue, engi
127127
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
128128

129129

130-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
130+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
131131
def test_engines_astype_numeric(scalars_array_value: array_value.ArrayValue, engine):
132132
arr = apply_op(
133133
scalars_array_value,
@@ -138,7 +138,7 @@ def test_engines_astype_numeric(scalars_array_value: array_value.ArrayValue, eng
138138
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
139139

140140

141-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
141+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
142142
def test_engines_astype_string_numeric(
143143
scalars_array_value: array_value.ArrayValue, engine
144144
):
@@ -155,7 +155,7 @@ def test_engines_astype_string_numeric(
155155
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
156156

157157

158-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
158+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
159159
def test_engines_astype_date(scalars_array_value: array_value.ArrayValue, engine):
160160
arr = apply_op(
161161
scalars_array_value,
@@ -166,7 +166,7 @@ def test_engines_astype_date(scalars_array_value: array_value.ArrayValue, engine
166166
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
167167

168168

169-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
169+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
170170
def test_engines_astype_string_date(
171171
scalars_array_value: array_value.ArrayValue, engine
172172
):
@@ -183,7 +183,7 @@ def test_engines_astype_string_date(
183183
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
184184

185185

186-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
186+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
187187
def test_engines_astype_datetime(scalars_array_value: array_value.ArrayValue, engine):
188188
arr = apply_op(
189189
scalars_array_value,
@@ -194,7 +194,7 @@ def test_engines_astype_datetime(scalars_array_value: array_value.ArrayValue, en
194194
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
195195

196196

197-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
197+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
198198
def test_engines_astype_string_datetime(
199199
scalars_array_value: array_value.ArrayValue, engine
200200
):
@@ -211,7 +211,7 @@ def test_engines_astype_string_datetime(
211211
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
212212

213213

214-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
214+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
215215
def test_engines_astype_timestamp(scalars_array_value: array_value.ArrayValue, engine):
216216
arr = apply_op(
217217
scalars_array_value,
@@ -222,7 +222,7 @@ def test_engines_astype_timestamp(scalars_array_value: array_value.ArrayValue, e
222222
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
223223

224224

225-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
225+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
226226
def test_engines_astype_string_timestamp(
227227
scalars_array_value: array_value.ArrayValue, engine
228228
):
@@ -243,7 +243,7 @@ def test_engines_astype_string_timestamp(
243243
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
244244

245245

246-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
246+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
247247
def test_engines_astype_time(scalars_array_value: array_value.ArrayValue, engine):
248248
arr = apply_op(
249249
scalars_array_value,
@@ -254,7 +254,7 @@ def test_engines_astype_time(scalars_array_value: array_value.ArrayValue, engine
254254
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
255255

256256

257-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
257+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
258258
def test_engines_astype_from_json(scalars_array_value: array_value.ArrayValue, engine):
259259
exprs = [
260260
ops.AsTypeOp(to_type=bigframes.dtypes.INT_DTYPE).as_expr(
@@ -275,7 +275,7 @@ def test_engines_astype_from_json(scalars_array_value: array_value.ArrayValue, e
275275
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
276276

277277

278-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
278+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
279279
def test_engines_astype_to_json(scalars_array_value: array_value.ArrayValue, engine):
280280
exprs = [
281281
ops.AsTypeOp(to_type=bigframes.dtypes.JSON_DTYPE).as_expr(
@@ -298,7 +298,7 @@ def test_engines_astype_to_json(scalars_array_value: array_value.ArrayValue, eng
298298
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
299299

300300

301-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
301+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
302302
def test_engines_astype_timedelta(scalars_array_value: array_value.ArrayValue, engine):
303303
arr = apply_op(
304304
scalars_array_value,
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`bool_col` AS `bfcol_0`,
4+
`float64_col` AS `bfcol_1`
5+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
6+
), `bfcte_1` AS (
7+
SELECT
8+
*,
9+
`bfcol_0` AS `bfcol_2`,
10+
`bfcol_1` <> 0 AS `bfcol_3`,
11+
`bfcol_1` <> 0 AS `bfcol_4`
12+
FROM `bfcte_0`
13+
)
14+
SELECT
15+
`bfcol_2` AS `bool_col`,
16+
`bfcol_3` AS `float64_col`,
17+
`bfcol_4` AS `float64_w_safe`
18+
FROM `bfcte_1`
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`bool_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
CAST(CAST(`bfcol_0` AS INT64) AS FLOAT64) AS `bfcol_1`,
9+
CAST('1.34235e4' AS FLOAT64) AS `bfcol_2`,
10+
SAFE_CAST(SAFE_CAST(`bfcol_0` AS INT64) AS FLOAT64) AS `bfcol_3`
11+
FROM `bfcte_0`
12+
)
13+
SELECT
14+
`bfcol_1` AS `bool_col`,
15+
`bfcol_2` AS `str_const`,
16+
`bfcol_3` AS `bool_w_safe`
17+
FROM `bfcte_1`
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`json_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`json_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
INT64(`bfcol_0`) AS `bfcol_1`,
9+
FLOAT64(`bfcol_0`) AS `bfcol_2`,
10+
BOOL(`bfcol_0`) AS `bfcol_3`,
11+
STRING(`bfcol_0`) AS `bfcol_4`,
12+
SAFE.INT64(`bfcol_0`) AS `bfcol_5`
13+
FROM `bfcte_0`
14+
)
15+
SELECT
16+
`bfcol_1` AS `int64_col`,
17+
`bfcol_2` AS `float64_col`,
18+
`bfcol_3` AS `bool_col`,
19+
`bfcol_4` AS `string_col`,
20+
`bfcol_5` AS `int64_w_safe`
21+
FROM `bfcte_1`
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`datetime_col` AS `bfcol_0`,
4+
`numeric_col` AS `bfcol_1`,
5+
`float64_col` AS `bfcol_2`,
6+
`time_col` AS `bfcol_3`,
7+
`timestamp_col` AS `bfcol_4`
8+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
9+
), `bfcte_1` AS (
10+
SELECT
11+
*,
12+
UNIX_MICROS(CAST(`bfcol_0` AS TIMESTAMP)) AS `bfcol_5`,
13+
UNIX_MICROS(SAFE_CAST(`bfcol_0` AS TIMESTAMP)) AS `bfcol_6`,
14+
TIME_DIFF(CAST(`bfcol_3` AS TIME), '00:00:00', MICROSECOND) AS `bfcol_7`,
15+
TIME_DIFF(SAFE_CAST(`bfcol_3` AS TIME), '00:00:00', MICROSECOND) AS `bfcol_8`,
16+
UNIX_MICROS(`bfcol_4`) AS `bfcol_9`,
17+
CAST(TRUNC(`bfcol_1`) AS INT64) AS `bfcol_10`,
18+
CAST(TRUNC(`bfcol_2`) AS INT64) AS `bfcol_11`,
19+
SAFE_CAST(TRUNC(`bfcol_2`) AS INT64) AS `bfcol_12`,
20+
CAST('100' AS INT64) AS `bfcol_13`
21+
FROM `bfcte_0`
22+
)
23+
SELECT
24+
`bfcol_5` AS `datetime_col`,
25+
`bfcol_6` AS `datetime_w_safe`,
26+
`bfcol_7` AS `time_col`,
27+
`bfcol_8` AS `time_w_safe`,
28+
`bfcol_9` AS `timestamp_col`,
29+
`bfcol_10` AS `numeric_col`,
30+
`bfcol_11` AS `float64_col`,
31+
`bfcol_12` AS `float64_w_safe`,
32+
`bfcol_13` AS `str_const`
33+
FROM `bfcte_1`

0 commit comments

Comments
 (0)