Skip to content
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions python/ray/data/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -700,6 +700,20 @@ py_test(
],
)

# Bazel target for the Dataset.filter tests; mirrors the sibling py_test
# targets (e.g. test_numpy): exclusive data-team test with conftest fixtures
# and the Ray library as deps.
py_test(
name = "test_filter",
size = "medium",
srcs = ["tests/test_filter.py"],
tags = [
"exclusive",
"team:data",
],
deps = [
":conftest",
"//:ray_lib",
],
)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug (disputed): Missing Dependencies in New Test Target

This automated finding appears to be a false positive: the new test_filter py_test target in BUILD.bazel does declare a deps section containing ":conftest" and "//:ray_lib", matching the other test targets in this file.

Fix in Cursor Fix in Web


py_test(
name = "test_numpy",
size = "medium",
Expand Down
14 changes: 14 additions & 0 deletions python/ray/data/_internal/arrow_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
U,
)
from ray.data.context import DEFAULT_TARGET_MAX_BLOCK_SIZE, DataContext
from ray.data.expressions import Expr

try:
import pyarrow
Expand Down Expand Up @@ -463,6 +464,19 @@ def iter_rows(
for i in range(self.num_rows()):
yield self._get_row(i)

def filter(self, predicate_expr: "Expr") -> "pyarrow.Table":
    """Return a table containing only the rows where ``predicate_expr`` holds.

    Args:
        predicate_expr: A Ray Data boolean expression evaluated per row.

    Returns:
        A new ``pyarrow.Table`` with the non-matching rows removed.
    """
    # Fast path: an empty table has nothing to filter.
    if self._table.num_rows == 0:
        return self._table

    from ray.data._expression_evaluator import eval_expr

    # Evaluate the predicate into a boolean mask, then let PyArrow's
    # native filter kernel select the matching rows.
    row_mask = eval_expr(predicate_expr, self._table)
    return self._table.filter(row_mask)


class ArrowBlockColumnAccessor(BlockColumnAccessor):
def __init__(self, col: Union["pyarrow.Array", "pyarrow.ChunkedArray"]):
Expand Down
20 changes: 10 additions & 10 deletions python/ray/data/_internal/logical/operators/map_operator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import functools
import inspect
import logging
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional
from typing import Any, Callable, Dict, Iterable, List, Optional

from ray.data._internal.compute import ComputeStrategy, TaskPoolStrategy
from ray.data._internal.logical.interfaces import LogicalOperator
Expand All @@ -10,10 +10,6 @@
from ray.data.expressions import Expr
from ray.data.preprocessor import Preprocessor

if TYPE_CHECKING:
import pyarrow as pa


logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -235,20 +231,24 @@ class Filter(AbstractUDFMap):
def __init__(
self,
input_op: LogicalOperator,
predicate_expr: Optional[Expr] = None,
fn: Optional[UserDefinedFunction] = None,
fn_args: Optional[Iterable[Any]] = None,
fn_kwargs: Optional[Dict[str, Any]] = None,
fn_constructor_args: Optional[Iterable[Any]] = None,
fn_constructor_kwargs: Optional[Dict[str, Any]] = None,
filter_expr: Optional["pa.dataset.Expression"] = None,
compute: Optional[ComputeStrategy] = None,
ray_remote_args_fn: Optional[Callable[[], Dict[str, Any]]] = None,
ray_remote_args: Optional[Dict[str, Any]] = None,
):
# Ensure exactly one of fn or filter_expr is provided
if not ((fn is None) ^ (filter_expr is None)):
raise ValueError("Exactly one of 'fn' or 'filter_expr' must be provided")
self._filter_expr = filter_expr
# Ensure exactly one of fn, or predicate_expr is provided
provided_params = sum([fn is not None, predicate_expr is not None])
if provided_params != 1:
raise ValueError(
f"Exactly one of 'fn', or 'predicate_expr' must be provided (received fn={fn}, predicate_expr={predicate_expr})"
)

self._predicate_expr = predicate_expr

super().__init__(
"Filter",
Expand Down
15 changes: 15 additions & 0 deletions python/ray/data/_internal/pandas_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
U,
)
from ray.data.context import DataContext
from ray.data.expressions import Expr

if TYPE_CHECKING:
import pandas
Expand Down Expand Up @@ -619,3 +620,17 @@ def iter_rows(
yield row.as_pydict()
else:
yield row

def filter(self, predicate_expr: "Expr") -> "pandas.DataFrame":
    """Return a DataFrame containing only the rows where ``predicate_expr`` holds.

    Args:
        predicate_expr: A Ray Data boolean expression evaluated per row.

    Returns:
        A new ``pandas.DataFrame`` restricted to the matching rows.
    """
    # Fast path: an empty frame has nothing to filter.
    if self._table.empty:
        return self._table

    # TODO: Move _expression_evaluator to _internal
    from ray.data._expression_evaluator import eval_expr

    # Evaluate the predicate into a boolean mask, then select the matching
    # rows via standard pandas boolean indexing.
    row_mask = eval_expr(predicate_expr, self._table)
    return self._table[row_mask]
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import pyarrow.compute as pc
import pyarrow.dataset as ds

from ray.data.expressions import ColumnExpr, Expr

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -36,8 +38,29 @@ def get_filters(expression: str) -> ds.Expression:
logger.exception(f"Error processing expression: {e}")
raise

@staticmethod
def parse_native_expression(expression: str) -> "Expr":
    """Parse a filter string into the equivalent Ray Data expression.

    Args:
        expression: A string representing the filter expression to parse.

    Returns:
        A Ray Data Expr object for filtering data.

    Raises:
        ValueError: If ``expression`` is not valid Python syntax.
    """
    try:
        parsed = ast.parse(expression, mode="eval")
        return _ConvertToNativeExpressionVisitor().visit(parsed.body)
    except SyntaxError as e:
        # Surface bad user input as a ValueError with the offending text.
        raise ValueError(f"Invalid syntax in the expression: {expression}") from e
    except Exception as e:
        # Anything else is unexpected: log with traceback, then propagate.
        logger.exception(f"Error processing expression: {e}")
        raise


class _ConvertToArrowExpressionVisitor(ast.NodeVisitor):
# TODO: Deprecate this visitor after we remove string support in filter API.
def visit_Compare(self, node: ast.Compare) -> ds.Expression:
"""Handle comparison operations (e.g., a == b, a < b, a in b).

Expand Down Expand Up @@ -234,3 +257,151 @@ def visit_Call(self, node: ast.Call) -> ds.Expression:
return function_map[func_name](*args)
else:
raise ValueError(f"Unsupported function: {func_name}")


class _ConvertToNativeExpressionVisitor(ast.NodeVisitor):
    """AST visitor that converts string expressions to Ray Data expressions.

    Each ``visit_*`` method maps one Python AST node kind onto the equivalent
    Ray Data ``Expr`` node. Imports from ``ray.data.expressions`` are done
    lazily inside each method (matching the module's existing convention).
    """

    def visit_Compare(self, node: ast.Compare) -> "Expr":
        """Handle comparison operations (e.g., a == b, a < b, a in b).

        Raises:
            ValueError: For chained comparisons (a < b < c) or unsupported
                comparison operators.
        """
        from ray.data.expressions import BinaryExpr, Operation

        if len(node.ops) != 1 or len(node.comparators) != 1:
            raise ValueError("Only simple binary comparisons are supported")

        left = self.visit(node.left)
        right = self.visit(node.comparators[0])
        op = node.ops[0]

        # Map AST comparison operators to Ray Data operations.
        op_map = {
            ast.Eq: Operation.EQ,
            ast.NotEq: Operation.NE,
            ast.Lt: Operation.LT,
            ast.LtE: Operation.LE,
            ast.Gt: Operation.GT,
            ast.GtE: Operation.GE,
            ast.In: Operation.IN,
            ast.NotIn: Operation.NOT_IN,
        }

        if type(op) not in op_map:
            raise ValueError(f"Unsupported comparison operator: {type(op).__name__}")

        return BinaryExpr(op_map[type(op)], left, right)

    def visit_BoolOp(self, node: ast.BoolOp) -> "Expr":
        """Handle logical operations (e.g., a and b, a or b)."""
        from ray.data.expressions import BinaryExpr, Operation

        conditions = [self.visit(value) for value in node.values]
        combined_expr = conditions[0]

        # Fold the remaining operands left-to-right into one binary tree:
        # (a and b and c) -> AND(AND(a, b), c).
        for condition in conditions[1:]:
            if isinstance(node.op, ast.And):
                combined_expr = BinaryExpr(Operation.AND, combined_expr, condition)
            elif isinstance(node.op, ast.Or):
                combined_expr = BinaryExpr(Operation.OR, combined_expr, condition)
            else:
                raise ValueError(
                    f"Unsupported logical operator: {type(node.op).__name__}"
                )

        return combined_expr

    def visit_UnaryOp(self, node: ast.UnaryOp) -> "Expr":
        """Handle unary operations (e.g., not a, -5)."""
        from ray.data.expressions import Operation, UnaryExpr, lit

        if isinstance(node.op, ast.Not):
            operand = self.visit(node.operand)
            return UnaryExpr(Operation.NOT, operand)
        elif isinstance(node.op, ast.USub):
            # For negative numbers, treat as literal.
            if isinstance(node.operand, ast.Constant):
                return lit(-node.operand.value)
            else:
                raise ValueError("Unary minus only supported for constant values")
        else:
            raise ValueError(f"Unsupported unary operator: {type(node.op).__name__}")

    def visit_Name(self, node: ast.Name) -> "Expr":
        """Handle variable names (column references)."""
        from ray.data.expressions import col

        return col(node.id)

    def visit_Constant(self, node: ast.Constant) -> "Expr":
        """Handle constant values (numbers, strings, booleans)."""
        from ray.data.expressions import lit

        return lit(node.value)

    def visit_List(self, node: ast.List) -> "Expr":
        """Handle list literals.

        Raises:
            ValueError: If any element is not a constant (e.g. a column
                reference inside the list).
        """
        from ray.data.expressions import LiteralExpr, lit

        # Visit all elements first.
        visited_elements = [self.visit(elt) for elt in node.elts]

        # Try to extract constant values for literal list.
        elements = []
        for elem in visited_elements:
            if isinstance(elem, LiteralExpr):
                elements.append(elem.value)
            else:
                # For compatibility with Arrow visitor, we need to support non-literals
                # but Ray Data expressions may have limitations here
                raise ValueError(
                    "List contains non-constant expressions. Ray Data expressions "
                    "currently only support lists of constant values."
                )

        return lit(elements)

    def visit_Attribute(self, node: ast.Attribute) -> "Expr":
        """Handle attribute access (e.g., for nested column names)."""
        from ray.data.expressions import col

        # For nested column names like "user.age", combine them with dots.
        if isinstance(node.value, ast.Name):
            return col(f"{node.value.id}.{node.attr}")
        elif isinstance(node.value, ast.Attribute):
            # Recursively handle nested attributes.
            left_expr = self.visit(node.value)
            if isinstance(left_expr, ColumnExpr):
                return col(f"{left_expr._name}.{node.attr}")

        # Include the full AST dump so unsupported node shapes are debuggable
        # (the attribute name alone is not enough context to diagnose).
        raise ValueError(f"Unsupported attribute access: {ast.dump(node)}")

    def visit_Call(self, node: ast.Call) -> "Expr":
        """Handle function calls for operations like is_null, is_not_null, is_nan."""
        from ray.data.expressions import BinaryExpr, Operation, UnaryExpr

        func_name = node.func.id if isinstance(node.func, ast.Name) else str(node.func)

        if func_name == "is_null":
            if len(node.args) != 1:
                raise ValueError("is_null() expects exactly one argument")
            operand = self.visit(node.args[0])
            return UnaryExpr(Operation.IS_NULL, operand)
        # Adding this conditional to keep it consistent with the existing implementation.
        elif func_name == "is_valid" or func_name == "is_not_null":
            if len(node.args) != 1:
                raise ValueError(f"{func_name}() expects exactly one argument")
            operand = self.visit(node.args[0])
            return UnaryExpr(Operation.IS_NOT_NULL, operand)
        elif func_name == "is_nan":
            if len(node.args) != 1:
                raise ValueError("is_nan() expects exactly one argument")
            operand = self.visit(node.args[0])
            # Use x != x pattern for NaN detection (NaN != NaN is True)
            return BinaryExpr(Operation.NE, operand, operand)
        elif func_name == "is_in":
            if len(node.args) != 2:
                raise ValueError("is_in() expects exactly two arguments")
            left = self.visit(node.args[0])
            right = self.visit(node.args[1])
            return BinaryExpr(Operation.IN, left, right)
        else:
            raise ValueError(f"Unsupported function: {func_name}")
25 changes: 11 additions & 14 deletions python/ray/data/_internal/planner/plan_udf_map_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@
from ray.data._internal.output_buffer import OutputBlockSizeOption
from ray.data._internal.util import _truncated_repr
from ray.data.block import (
BatchFormat,
Block,
BlockAccessor,
CallableClass,
Expand Down Expand Up @@ -218,26 +217,24 @@ def plan_filter_op(
target_max_block_size=data_context.target_max_block_size,
)

expression = op._filter_expr
predicate_expr = op._predicate_expr
compute = get_compute(op._compute)
if expression is not None:
if predicate_expr is not None:

def filter_batch_fn(block: "pa.Table") -> "pa.Table":
try:
return block.filter(expression)
except Exception as e:
_try_wrap_udf_exception(e)
def filter_block_fn(
blocks: Iterable[Block], ctx: TaskContext
) -> Iterable[Block]:
for block in blocks:
block_accessor = BlockAccessor.for_block(block)
filtered_block = block_accessor.filter(predicate_expr)
yield filtered_block

init_fn = None
transform_fn = BatchMapTransformFn(
_generate_transform_fn_for_map_batches(filter_batch_fn),
batch_size=None,
batch_format=BatchFormat.ARROW,
zero_copy_batch=True,
transform_fn = BlockMapTransformFn(
filter_block_fn,
is_udf=True,
output_block_size_option=output_block_size_option,
)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug (disputed): Import Error and Inconsistent Exception Handling

The claim that BlockMapTransformFn is unimported appears to be a false positive (the author confirms it is imported elsewhere in the file). The remaining point stands: the predicate-expression filter path does not wrap user-code exceptions with _try_wrap_udf_exception, so error reporting is inconsistent with the string-expression path.

Fix in Cursor Fix in Web

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

... Ofc it's imported...


else:
udf_is_callable_class = isinstance(op._fn, CallableClass)
filter_fn, init_fn = _get_udf(
Expand Down
Loading