Skip to content

Commit 1eed30a

Browse files
[Data] - Add to_pyarrow() to Expr (ray-project#57271)
<!-- Thank you for your contribution! Please review https://github.com/ray-project/ray/blob/master/CONTRIBUTING.rst before opening a pull request. --> <!-- Please add a reviewer to the assignee section when you create a PR. If you don't have access to it, we will shortly find a reviewer and assign them to your PR. --> ## Why are these changes needed? Adds `Expr.to_pyarrow()`, a converter from a Ray Data `Expr` to a PyArrow compute expression. ## Related issue number <!-- For example: "Closes ray-project#1234" --> ## Checks - [x] I've signed off every commit (by using the -s flag, i.e., `git commit -s`) in this PR. - [x] I've run pre-commit jobs to lint the changes in this PR. ([pre-commit setup](https://docs.ray.io/en/latest/ray-contribute/getting-involved.html#lint-and-formatting)) - [ ] I've included any doc changes needed for https://docs.ray.io/en/master/. - [ ] I've added any new APIs to the API Reference. For example, if I added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file. - [ ] I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/ - Testing Strategy - [x] Unit tests - [ ] Release tests - [ ] This PR is not tested :( --------- Signed-off-by: Goutam V. <goutam@anyscale.com> Signed-off-by: Goutam <goutam@anyscale.com> Signed-off-by: Future-Outlier <eric901201@gmail.com>
1 parent 76195e6 commit 1eed30a

File tree

2 files changed

+282
-0
lines changed

2 files changed

+282
-0
lines changed

python/ray/data/expressions.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from enum import Enum
77
from typing import Any, Callable, Dict, List, Union
88

9+
import pyarrow
10+
911
from ray.data.block import BatchColumn
1012
from ray.data.datatype import DataType
1113
from ray.util.annotations import DeveloperAPI, PublicAPI
@@ -59,6 +61,118 @@ class Operation(Enum):
5961
NOT_IN = "not_in"
6062

6163

64+
class _ExprVisitor(ABC):
    """Abstract visitor that routes a Ray Data expression to a typed handler.

    Subclasses supply one ``visit_*`` method per concrete expression class;
    ``visit`` selects the handler by testing the expression's type.
    """

    def visit(self, expr: "Expr") -> Any:
        """Dispatch ``expr`` to the handler that matches its concrete type.

        Args:
            expr: The expression node to process.

        Returns:
            Whatever the selected ``visit_*`` handler returns.

        Raises:
            TypeError: If ``expr`` is not an instance of any of the
                expression classes checked below.
        """
        # The checks run in a fixed order and the first match wins.
        if isinstance(expr, ColumnExpr):
            return self.visit_column(expr)
        if isinstance(expr, LiteralExpr):
            return self.visit_literal(expr)
        if isinstance(expr, BinaryExpr):
            return self.visit_binary(expr)
        if isinstance(expr, UnaryExpr):
            return self.visit_unary(expr)
        if isinstance(expr, AliasExpr):
            return self.visit_alias(expr)
        if isinstance(expr, UDFExpr):
            return self.visit_udf(expr)
        if isinstance(expr, DownloadExpr):
            return self.visit_download(expr)

        raise TypeError(f"Unsupported expression type for conversion: {type(expr)}")

    @abstractmethod
    def visit_column(self, expr: "ColumnExpr") -> Any:
        """Handle a ``ColumnExpr`` node."""

    @abstractmethod
    def visit_literal(self, expr: "LiteralExpr") -> Any:
        """Handle a ``LiteralExpr`` node."""

    @abstractmethod
    def visit_binary(self, expr: "BinaryExpr") -> Any:
        """Handle a ``BinaryExpr`` node."""

    @abstractmethod
    def visit_unary(self, expr: "UnaryExpr") -> Any:
        """Handle a ``UnaryExpr`` node."""

    @abstractmethod
    def visit_alias(self, expr: "AliasExpr") -> Any:
        """Handle an ``AliasExpr`` node."""

    @abstractmethod
    def visit_udf(self, expr: "UDFExpr") -> Any:
        """Handle a ``UDFExpr`` node."""

    @abstractmethod
    def visit_download(self, expr: "DownloadExpr") -> Any:
        """Handle a ``DownloadExpr`` node."""
112+
113+
114+
class _PyArrowExpressionVisitor(_ExprVisitor):
    """Visitor that converts Ray Data expressions to PyArrow compute expressions.

    Column references become ``pc.field``, literals become ``pc.scalar``, and
    unary/binary operations are looked up in ``_ARROW_EXPR_OPS_MAP``. Aliases
    are unwrapped to their inner expression; UDF and download expressions are
    rejected with ``TypeError``.
    """

    def visit_column(self, expr: "ColumnExpr") -> "pyarrow.compute.Expression":
        """Return a PyArrow field reference for the column named ``expr.name``."""
        import pyarrow.compute as pc

        return pc.field(expr.name)

    def visit_literal(self, expr: "LiteralExpr") -> "pyarrow.compute.Expression":
        """Return a PyArrow scalar expression wrapping ``expr.value``."""
        import pyarrow.compute as pc

        return pc.scalar(expr.value)

    def visit_binary(self, expr: "BinaryExpr") -> "pyarrow.compute.Expression":
        """Convert a binary operation to its PyArrow equivalent.

        ``IN``/``NOT_IN`` are translated to ``pc.is_in`` (inverted for
        ``NOT_IN``) and require a literal right operand. Every other operation
        is dispatched through ``_ARROW_EXPR_OPS_MAP``.

        Raises:
            ValueError: If the right operand of ``IN``/``NOT_IN`` is not a
                ``LiteralExpr``, or if the operation has no entry in
                ``_ARROW_EXPR_OPS_MAP``.
        """
        import pyarrow as pa
        import pyarrow.compute as pc

        if expr.op in (Operation.IN, Operation.NOT_IN):
            left = self.visit(expr.left)
            if not isinstance(expr.right, LiteralExpr):
                raise ValueError(
                    "is_in/not_in operations require the right operand to be a "
                    f"literal list, got {type(expr.right).__name__}."
                )

            right_value = expr.right.value
            # Treat any plain Python collection as the set of candidate values.
            # Previously only `list` was recognised, so a tuple or set was
            # wrapped as a single element and produced a nested value set.
            if isinstance(right_value, (list, tuple, set, frozenset)):
                value_set = pa.array(list(right_value))
            else:
                # A lone value is matched as a one-element value set.
                value_set = pa.array([right_value])

            membership = pc.is_in(left, value_set)
            if expr.op == Operation.NOT_IN:
                return pc.invert(membership)
            return membership

        left = self.visit(expr.left)
        right = self.visit(expr.right)
        # Imported inside the method, as in the original code
        # (presumably to avoid a circular import; confirm).
        from ray.data._expression_evaluator import _ARROW_EXPR_OPS_MAP

        arrow_fn = _ARROW_EXPR_OPS_MAP.get(expr.op)
        if arrow_fn is None:
            raise ValueError(f"Unsupported binary operation for PyArrow: {expr.op}")
        return arrow_fn(left, right)

    def visit_unary(self, expr: "UnaryExpr") -> "pyarrow.compute.Expression":
        """Convert a unary operation via ``_ARROW_EXPR_OPS_MAP``.

        Raises:
            ValueError: If the operation has no entry in ``_ARROW_EXPR_OPS_MAP``.
        """
        operand = self.visit(expr.operand)
        from ray.data._expression_evaluator import _ARROW_EXPR_OPS_MAP

        arrow_fn = _ARROW_EXPR_OPS_MAP.get(expr.op)
        if arrow_fn is None:
            raise ValueError(f"Unsupported unary operation for PyArrow: {expr.op}")
        return arrow_fn(operand)

    def visit_alias(self, expr: "AliasExpr") -> "pyarrow.compute.Expression":
        """Convert the aliased inner expression; the alias name is not carried over."""
        return self.visit(expr.expr)

    def visit_udf(self, expr: "UDFExpr") -> "pyarrow.compute.Expression":
        """Reject UDF expressions.

        Raises:
            TypeError: Always.
        """
        raise TypeError("UDF expressions cannot be converted to PyArrow expressions")

    def visit_download(self, expr: "DownloadExpr") -> "pyarrow.compute.Expression":
        """Reject download expressions.

        Raises:
            TypeError: Always.
        """
        raise TypeError(
            "Download expressions cannot be converted to PyArrow expressions"
        )
)
174+
175+
62176
@DeveloperAPI(stability="alpha")
63177
@dataclass(frozen=True)
64178
class Expr(ABC):
@@ -101,6 +215,18 @@ def structurally_equals(self, other: Any) -> bool:
101215
"""Compare two expression ASTs for structural equality."""
102216
raise NotImplementedError
103217

218+
def to_pyarrow(self) -> "pyarrow.compute.Expression":
    """Translate this expression tree into a PyArrow compute expression.

    The work is delegated to a fresh ``_PyArrowExpressionVisitor``.

    Returns:
        A PyArrow compute expression equivalent to this Ray Data expression.

    Raises:
        ValueError: If the expression contains operations not supported by PyArrow.
        TypeError: If the expression type cannot be converted to PyArrow.
    """
    converter = _PyArrowExpressionVisitor()
    return converter.visit(self)
229+
104230
def _bin(self, other: Any, op: Operation) -> "Expr":
105231
"""Create a binary expression with the given operation.
106232

python/ray/data/tests/test_expressions.py

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import pyarrow as pa
2+
import pyarrow.compute as pc
13
import pytest
24

35
from ray.data.expressions import (
@@ -262,6 +264,160 @@ def test_complex_boolean_expressions(self):
262264
assert very_complex.op == Operation.AND
263265

264266

267+
class TestToPyArrow:
    """Test conversion of Ray Data expressions to PyArrow compute expressions."""

    @staticmethod
    def _filter_or_none(dataset, expression):
        """Apply ``expression`` as a scanner filter over ``dataset``.

        Returns:
            The filtered table, or ``None`` when PyArrow rejects the
            expression as a filter (for example because it is not boolean).
        """
        try:
            return dataset.scanner(filter=expression).to_table()
        except (TypeError, pa.lib.ArrowInvalid, pa.lib.ArrowNotImplementedError):
            return None

    @pytest.mark.parametrize(
        "ray_expr, equivalent_pyarrow_expr, description",
        [
            # Basic expressions
            (col("age"), lambda: pc.field("age"), "column reference"),
            (lit(42), lambda: pc.scalar(42), "integer literal"),
            (lit("hello"), lambda: pc.scalar("hello"), "string literal"),
            # Arithmetic operations
            (
                col("x") + 5,
                lambda: pc.add(pc.field("x"), pc.scalar(5)),
                "addition",
            ),
            (
                col("x") * 2,
                lambda: pc.multiply(pc.field("x"), pc.scalar(2)),
                "multiplication",
            ),
            # Comparison operations
            (
                col("age") > 18,
                lambda: pc.greater(pc.field("age"), pc.scalar(18)),
                "greater than",
            ),
            (
                col("status") == "active",
                lambda: pc.equal(pc.field("status"), pc.scalar("active")),
                "equality",
            ),
            # Boolean operations
            (
                (col("age") > 18) & (col("age") < 65),
                lambda: pc.and_kleene(
                    pc.greater(pc.field("age"), pc.scalar(18)),
                    pc.less(pc.field("age"), pc.scalar(65)),
                ),
                "logical AND",
            ),
            (
                ~(col("active")),
                lambda: pc.invert(pc.field("active")),
                "logical NOT",
            ),
            # Unary operations
            (
                col("value").is_null(),
                lambda: pc.is_null(pc.field("value")),
                "is_null check",
            ),
            # In operations
            (
                col("status").is_in(["active", "pending"]),
                lambda: pc.is_in(pc.field("status"), pa.array(["active", "pending"])),
                "is_in with list",
            ),
            # Complex nested expressions
            (
                (col("price") * col("quantity")) + col("tax"),
                lambda: pc.add(
                    pc.multiply(pc.field("price"), pc.field("quantity")),
                    pc.field("tax"),
                ),
                "nested arithmetic",
            ),
            # Alias expressions (should unwrap to inner expression)
            (
                (col("x") + 5).alias("result"),
                lambda: pc.add(pc.field("x"), pc.scalar(5)),
                "aliased expression",
            ),
        ],
        ids=[
            "col",
            "int_lit",
            "str_lit",
            "add",
            "mul",
            "gt",
            "eq",
            "and",
            "not",
            "is_null",
            "is_in",
            "nested",
            "alias",
        ],
    )
    def test_to_pyarrow_equivalence(
        self, ray_expr, equivalent_pyarrow_expr, description
    ):
        """Test that Ray Data expressions convert to equivalent PyArrow expressions.

        This test documents the expected PyArrow expression for each Ray Data
        expression. Both expressions are applied as filters to the same sample
        data: they must either both be accepted as filters and select the same
        rows, or both be rejected (the non-boolean cases).
        """
        import pyarrow.dataset as ds

        # Convert Ray expression to PyArrow
        converted = ray_expr.to_pyarrow()
        expected = equivalent_pyarrow_expr()

        # Both should be PyArrow expressions
        assert isinstance(converted, pc.Expression)
        assert isinstance(expected, pc.Expression)

        # Sample data covering every column referenced by the parametrized cases
        test_data = pa.table(
            {
                "age": [15, 25, 45, 70],
                "x": [1, 2, 3, 4],
                "price": [10.0, 20.0, 30.0, 40.0],
                "quantity": [2, 3, 1, 5],
                "tax": [1.0, 2.0, 3.0, 4.0],
                "status": ["active", "pending", "inactive", "active"],
                "value": [1, None, 3, None],
                "active": [True, False, True, False],
            }
        )

        dataset = ds.dataset(test_data)

        # Evaluate each expression on its own so that a failure of the
        # converted expression cannot mask a working expected expression
        # (or the reverse).
        result_converted = self._filter_or_none(dataset, converted)
        result_expected = self._filter_or_none(dataset, expected)

        assert (result_converted is None) == (
            result_expected is None
        ), f"Only one of the expressions is usable as a filter for {description}"

        if result_expected is not None:
            assert result_converted.equals(
                result_expected
            ), f"Expressions produce different results for {description}"

    def test_to_pyarrow_unsupported_expressions(self):
        """Test that unsupported expression types raise appropriate errors."""
        from ray.data.datatype import DataType
        from ray.data.expressions import UDFExpr

        def dummy_fn(x):
            return x

        udf_expr = UDFExpr(
            fn=dummy_fn,
            args=[col("x")],
            kwargs={},
            data_type=DataType(int),
        )

        with pytest.raises(TypeError, match="UDF expressions cannot be converted"):
            udf_expr.to_pyarrow()
419+
420+
265421
if __name__ == "__main__":
266422
import sys
267423

0 commit comments

Comments
 (0)