Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 139 additions & 0 deletions python/ray/data/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from enum import Enum
from typing import Any, Callable, Dict, List, Union

import pyarrow

from ray.data.block import BatchColumn
from ray.data.datatype import DataType
from ray.util.annotations import DeveloperAPI, PublicAPI
Expand Down Expand Up @@ -59,6 +61,131 @@ class Operation(Enum):
NOT_IN = "not_in"


class _PyArrowExpressionVisitor:
    """Visitor that converts Ray Data expressions to PyArrow compute expressions.

    This follows the visitor pattern similar to ast.NodeVisitor, where a generic
    visit() method dispatches to type-specific visit methods.

    Alias wrappers are unwrapped (the alias name is not carried into the PyArrow
    expression), while UDF and download expressions are rejected with
    ``TypeError``.
    """

    def visit(self, expr: "Expr") -> "pyarrow.compute.Expression":
        """Visit an expression node and convert it to PyArrow.

        Args:
            expr: The expression to convert

        Returns:
            A PyArrow compute expression

        Raises:
            ValueError: If the operation is not supported by PyArrow
            TypeError: If the expression type cannot be converted to PyArrow
        """
        # Dispatch to the appropriate visit method based on expression type.
        # The order of the checks is kept stable on purpose.
        if isinstance(expr, ColumnExpr):
            return self.visit_column(expr)
        elif isinstance(expr, LiteralExpr):
            return self.visit_literal(expr)
        elif isinstance(expr, BinaryExpr):
            return self.visit_binary(expr)
        elif isinstance(expr, UnaryExpr):
            return self.visit_unary(expr)
        elif isinstance(expr, AliasExpr):
            return self.visit_alias(expr)
        elif isinstance(expr, UDFExpr):
            return self.visit_udf(expr)
        elif isinstance(expr, DownloadExpr):
            return self.visit_download(expr)
        else:
            raise TypeError(
                f"Unsupported expression type for PyArrow conversion: {type(expr)}"
            )

    def visit_column(self, expr: "ColumnExpr") -> "pyarrow.compute.Expression":
        """Convert a ColumnExpr to PyArrow field reference."""
        import pyarrow.compute as pc

        return pc.field(expr.name)

    def visit_literal(self, expr: "LiteralExpr") -> "pyarrow.compute.Expression":
        """Convert a LiteralExpr to PyArrow scalar."""
        import pyarrow.compute as pc

        return pc.scalar(expr.value)

    @staticmethod
    def _literal_value_set(value: Any) -> "pyarrow.Array":
        """Build the PyArrow value set for an IN / NOT_IN literal operand.

        Args:
            value: The Python value held by the right-hand ``LiteralExpr``.

        Returns:
            A PyArrow array holding the candidate values.
        """
        import pyarrow as pa

        # A plain Python collection *is* the set of candidates. Wrapping a tuple
        # or set in another list would turn it into a single nested element, so
        # those are normalised the same way a list is. Strings and other scalars
        # are a single candidate and get wrapped.
        if isinstance(value, (list, tuple, set, frozenset)):
            return pa.array(list(value))
        return pa.array([value])

    def visit_binary(self, expr: "BinaryExpr") -> "pyarrow.compute.Expression":
        """Convert a BinaryExpr to PyArrow binary operation.

        Args:
            expr: The binary expression to convert.

        Returns:
            A PyArrow compute expression applying the operation to both operands.

        Raises:
            ValueError: If the operation has no entry in the Arrow operations
                map, or if IN / NOT_IN is given a non-literal right operand.
        """
        import pyarrow.compute as pc

        # Special handling for IN and NOT_IN operations
        # PyArrow's is_in expects the value_set to be an array, not an expression
        if expr.op in (Operation.IN, Operation.NOT_IN):
            left = self.visit(expr.left)

            # Check the type of the right operand BEFORE visiting: visiting a
            # literal would turn it into a scalar expression, which is_in cannot
            # take as its value set.
            if not isinstance(expr.right, LiteralExpr):
                # PyArrow's is_in doesn't support expressions as the value_set
                # It requires an actual array of values
                raise ValueError(
                    f"is_in/not_in operations require the right operand to be a "
                    f"literal list, got {type(expr.right).__name__}. "
                    f"Column-to-column is_in is not supported in PyArrow expressions."
                )

            right = self._literal_value_set(expr.right.value)

            # Now apply the operation; NOT_IN is expressed as the negation of IN.
            result = pc.is_in(left, right)
            if expr.op == Operation.NOT_IN:
                result = pc.invert(result)
            return result

        # For all other operations, recursively visit both operands
        left = self.visit(expr.left)
        right = self.visit(expr.right)

        # Reuse the Arrow operations map from the evaluator for other operations.
        # Imported here rather than at module level, as in the original code.
        from ray.data._expression_evaluator import _ARROW_EXPR_OPS_MAP

        if expr.op in _ARROW_EXPR_OPS_MAP:
            return _ARROW_EXPR_OPS_MAP[expr.op](left, right)
        else:
            raise ValueError(f"Unsupported binary operation for PyArrow: {expr.op}")

    def visit_unary(self, expr: "UnaryExpr") -> "pyarrow.compute.Expression":
        """Convert a UnaryExpr to PyArrow unary operation.

        Raises:
            ValueError: If the operation has no entry in the Arrow operations map.
        """
        # Recursively visit operand
        operand = self.visit(expr.operand)

        from ray.data._expression_evaluator import _ARROW_EXPR_OPS_MAP

        if expr.op in _ARROW_EXPR_OPS_MAP:
            return _ARROW_EXPR_OPS_MAP[expr.op](operand)
        else:
            raise ValueError(f"Unsupported unary operation for PyArrow: {expr.op}")

    def visit_alias(self, expr: "AliasExpr") -> "pyarrow.compute.Expression":
        """Convert an AliasExpr by converting its inner expression."""
        # Only the wrapped expression is converted; the alias name is dropped.
        return self.visit(expr.expr)

    def visit_udf(self, expr: "UDFExpr") -> "pyarrow.compute.Expression":
        """UDF expressions cannot be converted to PyArrow."""
        raise TypeError("UDF expressions cannot be converted to PyArrow expressions")

    def visit_download(self, expr: "DownloadExpr") -> "pyarrow.compute.Expression":
        """Download expressions cannot be converted to PyArrow."""
        raise TypeError(
            "Download expressions cannot be converted to PyArrow expressions"
        )


@DeveloperAPI(stability="alpha")
@dataclass(frozen=True)
class Expr(ABC):
Expand Down Expand Up @@ -101,6 +228,18 @@ def structurally_equals(self, other: Any) -> bool:
"""Compare two expression ASTs for structural equality."""
raise NotImplementedError

def to_pyarrow(self) -> "pyarrow.compute.Expression":
    """Translate this expression tree into a PyArrow compute expression.

    The work is delegated to a fresh ``_PyArrowExpressionVisitor``, which is
    handed this expression as the node to visit.

    Returns:
        The PyArrow compute expression produced by the visitor for this
        Ray Data expression.

    Raises:
        ValueError: If the expression contains operations not supported by
            PyArrow.
        TypeError: If the expression type cannot be converted to PyArrow.
    """
    converter = _PyArrowExpressionVisitor()
    return converter.visit(self)

def _bin(self, other: Any, op: Operation) -> "Expr":
"""Create a binary expression with the given operation.

Expand Down
156 changes: 156 additions & 0 deletions python/ray/data/tests/test_expressions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import pyarrow as pa
import pyarrow.compute as pc
import pytest

from ray.data.expressions import (
Expand Down Expand Up @@ -262,6 +264,160 @@ def test_complex_boolean_expressions(self):
assert very_complex.op == Operation.AND


class TestToPyArrow:
    """Test conversion of Ray Data expressions to PyArrow compute expressions."""

    @pytest.mark.parametrize(
        "ray_expr, equivalent_pyarrow_expr, description",
        [
            # Basic expressions
            (col("age"), lambda: pc.field("age"), "column reference"),
            (lit(42), lambda: pc.scalar(42), "integer literal"),
            (lit("hello"), lambda: pc.scalar("hello"), "string literal"),
            # Arithmetic operations
            (
                col("x") + 5,
                lambda: pc.add(pc.field("x"), pc.scalar(5)),
                "addition",
            ),
            (
                col("x") * 2,
                lambda: pc.multiply(pc.field("x"), pc.scalar(2)),
                "multiplication",
            ),
            # Comparison operations
            (
                col("age") > 18,
                lambda: pc.greater(pc.field("age"), pc.scalar(18)),
                "greater than",
            ),
            (
                col("status") == "active",
                lambda: pc.equal(pc.field("status"), pc.scalar("active")),
                "equality",
            ),
            # Boolean operations
            (
                (col("age") > 18) & (col("age") < 65),
                lambda: pc.and_kleene(
                    pc.greater(pc.field("age"), pc.scalar(18)),
                    pc.less(pc.field("age"), pc.scalar(65)),
                ),
                "logical AND",
            ),
            (
                ~(col("active")),
                lambda: pc.invert(pc.field("active")),
                "logical NOT",
            ),
            # Unary operations
            (
                col("value").is_null(),
                lambda: pc.is_null(pc.field("value")),
                "is_null check",
            ),
            # In operations
            (
                col("status").is_in(["active", "pending"]),
                lambda: pc.is_in(pc.field("status"), pa.array(["active", "pending"])),
                "is_in with list",
            ),
            # Complex nested expressions
            (
                (col("price") * col("quantity")) + col("tax"),
                lambda: pc.add(
                    pc.multiply(pc.field("price"), pc.field("quantity")),
                    pc.field("tax"),
                ),
                "nested arithmetic",
            ),
            # Alias expressions (should unwrap to inner expression)
            (
                (col("x") + 5).alias("result"),
                lambda: pc.add(pc.field("x"), pc.scalar(5)),
                "aliased expression",
            ),
        ],
        ids=[
            "col",
            "int_lit",
            "str_lit",
            "add",
            "mul",
            "gt",
            "eq",
            "and",
            "not",
            "is_null",
            "is_in",
            "nested",
            "alias",
        ],
    )
    def test_to_pyarrow_equivalence(
        self, ray_expr, equivalent_pyarrow_expr, description
    ):
        """Test that Ray Data expressions convert to equivalent PyArrow expressions.

        Each case pairs a Ray Data expression with a factory for the
        hand-written PyArrow expression it should correspond to. Both must be
        PyArrow expressions. When the reference expression can be used as a
        dataset filter, the converted expression must filter the sample data to
        exactly the same rows; otherwise only the type checks apply.
        """
        import pyarrow.dataset as ds

        # Convert Ray expression to PyArrow
        converted = ray_expr.to_pyarrow()
        expected = equivalent_pyarrow_expr()

        # Both should be PyArrow expressions
        assert isinstance(converted, pc.Expression)
        assert isinstance(expected, pc.Expression)

        # Verify they produce the same results on sample data
        test_data = pa.table(
            {
                "age": [15, 25, 45, 70],
                "x": [1, 2, 3, 4],
                "price": [10.0, 20.0, 30.0, 40.0],
                "quantity": [2, 3, 1, 5],
                "tax": [1.0, 2.0, 3.0, 4.0],
                "status": ["active", "pending", "inactive", "active"],
                "value": [1, None, 3, None],
                "active": [True, False, True, False],
            }
        )

        dataset = ds.dataset(test_data)

        # Probe filterability with the hand-written reference expression only.
        # Keeping the converted expression out of this try block means an error
        # raised by the conversion result is reported instead of being swallowed.
        try:
            result_expected = dataset.scanner(filter=expected).to_table()
        except (TypeError, pa.lib.ArrowInvalid, pa.lib.ArrowNotImplementedError):
            # For non-boolean expressions there is no filter result to compare,
            # so the type checks above are all that is verified.
            return

        # For boolean expressions, compare filter results
        result_converted = dataset.scanner(filter=converted).to_table()
        assert result_converted.equals(
            result_expected
        ), f"Expressions produce different results for {description}"

    def test_to_pyarrow_unsupported_expressions(self):
        """Test that unsupported expression types raise appropriate errors."""
        from ray.data.datatype import DataType
        from ray.data.expressions import UDFExpr

        def dummy_fn(x):
            return x

        udf_expr = UDFExpr(
            fn=dummy_fn,
            args=[col("x")],
            kwargs={},
            data_type=DataType(int),
        )

        with pytest.raises(TypeError, match="UDF expressions cannot be converted"):
            udf_expr.to_pyarrow()


if __name__ == "__main__":
import sys

Expand Down