Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions python/ray/data/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,43 @@ def exp(self) -> "UDFExpr":
self
)

# trigonometric helpers
def sin(self) -> "UDFExpr":
    """Return an expression that takes the sine of this expression.

    The input is interpreted as an angle in radians.

    Returns:
        A ``UDFExpr`` wrapping ``pc.sin`` with a declared float64 result type.
    """
    float64 = DataType.float64()
    sine_udf = _create_pyarrow_compute_udf(pc.sin, return_dtype=float64)
    return sine_udf(self)

def cos(self) -> "UDFExpr":
    """Return an expression that takes the cosine of this expression.

    The input is interpreted as an angle in radians.

    Returns:
        A ``UDFExpr`` wrapping ``pc.cos`` with a declared float64 result type.
    """
    float64 = DataType.float64()
    cosine_udf = _create_pyarrow_compute_udf(pc.cos, return_dtype=float64)
    return cosine_udf(self)

def tan(self) -> "UDFExpr":
    """Return an expression that takes the tangent of this expression.

    The input is interpreted as an angle in radians.

    Returns:
        A ``UDFExpr`` wrapping ``pc.tan`` with a declared float64 result type.
    """
    float64 = DataType.float64()
    tangent_udf = _create_pyarrow_compute_udf(pc.tan, return_dtype=float64)
    return tangent_udf(self)

def asin(self) -> "UDFExpr":
    """Return an expression that takes the arcsine (inverse sine) of this expression.

    The resulting angle is expressed in radians.

    Returns:
        A ``UDFExpr`` wrapping ``pc.asin`` with a declared float64 result type.
    """
    float64 = DataType.float64()
    arcsine_udf = _create_pyarrow_compute_udf(pc.asin, return_dtype=float64)
    return arcsine_udf(self)

def acos(self) -> "UDFExpr":
    """Return an expression that takes the arccosine (inverse cosine) of this expression.

    The resulting angle is expressed in radians.

    Returns:
        A ``UDFExpr`` wrapping ``pc.acos`` with a declared float64 result type.
    """
    float64 = DataType.float64()
    arccosine_udf = _create_pyarrow_compute_udf(pc.acos, return_dtype=float64)
    return arccosine_udf(self)

def atan(self) -> "UDFExpr":
    """Return an expression that takes the arctangent (inverse tangent) of this expression.

    The resulting angle is expressed in radians.

    Returns:
        A ``UDFExpr`` wrapping ``pc.atan`` with a declared float64 result type.
    """
    float64 = DataType.float64()
    arctangent_udf = _create_pyarrow_compute_udf(pc.atan, return_dtype=float64)
    return arctangent_udf(self)

@property
def list(self) -> "_ListNamespace":
"""Access list operations for this expression.
Expand Down
49 changes: 49 additions & 0 deletions python/ray/data/tests/test_with_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,55 @@ def test_with_column_logarithmic_operations(
assert rows_same(result_df, expected_df)


@pytest.mark.skipif(
    get_pyarrow_version() < parse_version("20.0.0"),
    reason="with_column requires PyArrow >= 20.0.0",
)
@pytest.mark.parametrize(
    "expr_factory, expected_fn",
    [
        pytest.param(lambda: col("value").sin(), math.sin, id="sin"),
        pytest.param(lambda: col("value").cos(), math.cos, id="cos"),
        pytest.param(lambda: col("value").tan(), math.tan, id="tan"),
        pytest.param(lambda: col("value").atan(), math.atan, id="atan"),
    ],
)
@pytest.mark.parametrize(
    "input_values",
    [
        pytest.param(
            [0.0, math.pi / 6, math.pi / 4, math.pi / 3, math.pi / 2],
            id="trigonometric_angles",
        ),
        pytest.param([0, 1, 2, 3], id="integer_angles"),
    ],
)
def test_with_column_trigonometric_operations(
    ray_start_regular_shared,
    expr_factory,
    expected_fn,
    input_values,
):
    """Test trigonometric expressions that accept any real input (sin, cos, tan, atan).

    Each expression is evaluated through ``with_column`` and compared against
    the matching ``math`` function applied element-wise to ``input_values``.
    asin and acos are covered by
    ``test_with_column_inverse_trigonometric_operations`` because they need
    inputs restricted to [-1, 1].
    """
    ds = ray.data.from_items([{"value": v} for v in input_values])
    expr = expr_factory()

    expected_values = [expected_fn(v) for v in input_values]
    result_df = ds.with_column("result", expr).to_pandas()
    expected_df = pd.DataFrame({"value": input_values, "result": expected_values})
    assert rows_same(result_df, expected_df)


@pytest.mark.skipif(
    get_pyarrow_version() < parse_version("20.0.0"),
    reason="with_column requires PyArrow >= 20.0.0",
)
@pytest.mark.parametrize(
    "expr_factory, expected_fn",
    [
        pytest.param(lambda: col("value").asin(), math.asin, id="asin"),
        pytest.param(lambda: col("value").acos(), math.acos, id="acos"),
    ],
)
def test_with_column_inverse_trigonometric_operations(
    ray_start_regular_shared,
    expr_factory,
    expected_fn,
):
    """Test domain-restricted inverse trigonometric expressions (asin, acos).

    Kept separate from the unrestricted-domain test so the inputs can be
    fixed to values the reference functions accept: ``math.asin`` and
    ``math.acos`` raise ``ValueError`` outside [-1, 1].
    """
    # Both endpoints, zero, and one interior point on each side of the domain.
    values = [-1.0, -0.5, 0.0, 0.5, 1.0]
    ds = ray.data.from_items([{"value": v} for v in values])
    expr = expr_factory()

    expected_values = [expected_fn(v) for v in values]
    result_df = ds.with_column("result", expr).to_pandas()
    expected_df = pd.DataFrame({"value": values, "result": expected_values})
    assert rows_same(result_df, expected_df)


@pytest.mark.skipif(
get_pyarrow_version() < parse_version("20.0.0"),
reason="with_column requires PyArrow >= 20.0.0",
Expand Down