Skip to content

Commit 5613e44

Browse files
authored
feat: implement cos, sin, and log operations for polars compiler (#2170)
* feat: implement cos, sin, and log operations for polars compiler * fix domain for log * update snapshot * revert sqrt change * revert sqrt change
1 parent 118c265 commit 5613e44

File tree

7 files changed

+104
-17
lines changed

7 files changed

+104
-17
lines changed

bigframes/core/compile/polars/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
# polars shouldn't be needed at import time, as register is a no-op if polars
2525
# isn't installed.
2626
import bigframes.core.compile.polars.operations.generic_ops # noqa: F401
27+
import bigframes.core.compile.polars.operations.numeric_ops # noqa: F401
2728
import bigframes.core.compile.polars.operations.struct_ops # noqa: F401
2829

2930
try:
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""
16+
BigFrames -> Polars compilation for the operations in bigframes.operations.numeric_ops.
17+
18+
Please keep implementations in sequential order by op name.
19+
"""
20+
21+
from __future__ import annotations
22+
23+
from typing import TYPE_CHECKING
24+
25+
import bigframes.core.compile.polars.compiler as polars_compiler
26+
from bigframes.operations import numeric_ops
27+
28+
if TYPE_CHECKING:
29+
import polars as pl
30+
31+
32+
@polars_compiler.register_op(numeric_ops.CosOp)
33+
def cos_op_impl(
34+
compiler: polars_compiler.PolarsExpressionCompiler,
35+
op: numeric_ops.CosOp, # type: ignore
36+
input: pl.Expr,
37+
) -> pl.Expr:
38+
return input.cos()
39+
40+
41+
@polars_compiler.register_op(numeric_ops.LnOp)
42+
def ln_op_impl(
43+
compiler: polars_compiler.PolarsExpressionCompiler,
44+
op: numeric_ops.LnOp, # type: ignore
45+
input: pl.Expr,
46+
) -> pl.Expr:
47+
import polars as pl
48+
49+
return pl.when(input <= 0).then(float("nan")).otherwise(input.log())
50+
51+
52+
@polars_compiler.register_op(numeric_ops.Log10Op)
53+
def log10_op_impl(
54+
compiler: polars_compiler.PolarsExpressionCompiler,
55+
op: numeric_ops.Log10Op, # type: ignore
56+
input: pl.Expr,
57+
) -> pl.Expr:
58+
import polars as pl
59+
60+
return pl.when(input <= 0).then(float("nan")).otherwise(input.log(base=10))
61+
62+
63+
@polars_compiler.register_op(numeric_ops.Log1pOp)
64+
def log1p_op_impl(
65+
compiler: polars_compiler.PolarsExpressionCompiler,
66+
op: numeric_ops.Log1pOp, # type: ignore
67+
input: pl.Expr,
68+
) -> pl.Expr:
69+
import polars as pl
70+
71+
return pl.when(input <= -1).then(float("nan")).otherwise((input + 1).log())
72+
73+
74+
@polars_compiler.register_op(numeric_ops.SinOp)
75+
def sin_op_impl(
76+
compiler: polars_compiler.PolarsExpressionCompiler,
77+
op: numeric_ops.SinOp, # type: ignore
78+
input: pl.Expr,
79+
) -> pl.Expr:
80+
return input.sin()
81+
82+
83+
@polars_compiler.register_op(numeric_ops.SqrtOp)
84+
def sqrt_op_impl(
85+
compiler: polars_compiler.PolarsExpressionCompiler,
86+
op: numeric_ops.SqrtOp, # type: ignore
87+
input: pl.Expr,
88+
) -> pl.Expr:
89+
import polars as pl
90+
91+
return pl.when(input < 0).then(float("nan")).otherwise(input.sqrt())

bigframes/core/compile/sqlglot/expressions/numeric_ops.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def _(expr: TypedExpr) -> sge.Expression:
158158
return sge.Case(
159159
ifs=[
160160
sge.If(
161-
this=expr.expr < sge.convert(0),
161+
this=expr.expr <= sge.convert(0),
162162
true=constants._NAN,
163163
)
164164
],
@@ -171,7 +171,7 @@ def _(expr: TypedExpr) -> sge.Expression:
171171
return sge.Case(
172172
ifs=[
173173
sge.If(
174-
this=expr.expr < sge.convert(0),
174+
this=expr.expr <= sge.convert(0),
175175
true=constants._NAN,
176176
)
177177
],
@@ -184,7 +184,7 @@ def _(expr: TypedExpr) -> sge.Expression:
184184
return sge.Case(
185185
ifs=[
186186
sge.If(
187-
this=expr.expr < sge.convert(-1),
187+
this=expr.expr <= sge.convert(-1),
188188
true=constants._NAN,
189189
)
190190
],

tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_ln/out.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ WITH `bfcte_0` AS (
55
), `bfcte_1` AS (
66
SELECT
77
*,
8-
CASE WHEN `bfcol_0` < 0 THEN CAST('NaN' AS FLOAT64) ELSE LN(`bfcol_0`) END AS `bfcol_1`
8+
CASE WHEN `bfcol_0` <= 0 THEN CAST('NaN' AS FLOAT64) ELSE LN(`bfcol_0`) END AS `bfcol_1`
99
FROM `bfcte_0`
1010
)
1111
SELECT

tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_log10/out.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ WITH `bfcte_0` AS (
55
), `bfcte_1` AS (
66
SELECT
77
*,
8-
CASE WHEN `bfcol_0` < 0 THEN CAST('NaN' AS FLOAT64) ELSE LOG(10, `bfcol_0`) END AS `bfcol_1`
8+
CASE WHEN `bfcol_0` <= 0 THEN CAST('NaN' AS FLOAT64) ELSE LOG(10, `bfcol_0`) END AS `bfcol_1`
99
FROM `bfcte_0`
1010
)
1111
SELECT

tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_log1p/out.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ WITH `bfcte_0` AS (
55
), `bfcte_1` AS (
66
SELECT
77
*,
8-
CASE WHEN `bfcol_0` < -1 THEN CAST('NaN' AS FLOAT64) ELSE LN(1 + `bfcol_0`) END AS `bfcol_1`
8+
CASE WHEN `bfcol_0` <= -1 THEN CAST('NaN' AS FLOAT64) ELSE LN(1 + `bfcol_0`) END AS `bfcol_1`
99
FROM `bfcte_0`
1010
)
1111
SELECT

tests/unit/test_series_polars.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4622,20 +4622,15 @@ def test_apply_lambda(scalars_dfs, col, lambda_):
46224622
)
46234623

46244624

4625-
@pytest.mark.skip(
4626-
reason="NotImplementedError: Polars compiler hasn't implemented log()"
4627-
)
46284625
@pytest.mark.parametrize(
46294626
("ufunc",),
46304627
[
4631-
pytest.param(numpy.log),
4632-
pytest.param(numpy.sqrt),
4633-
pytest.param(numpy.sin),
4634-
],
4635-
ids=[
4636-
"log",
4637-
"sqrt",
4638-
"sin",
4628+
pytest.param(numpy.cos, id="cos"),
4629+
pytest.param(numpy.log, id="log"),
4630+
pytest.param(numpy.log10, id="log10"),
4631+
pytest.param(numpy.log1p, id="log1p"),
4632+
pytest.param(numpy.sqrt, id="sqrt"),
4633+
pytest.param(numpy.sin, id="sin"),
46394634
],
46404635
)
46414636
def test_apply_numpy_ufunc(scalars_dfs, ufunc):

0 commit comments

Comments
 (0)