-
Notifications
You must be signed in to change notification settings - Fork 7.3k
[Data] [2/n] - Add predicate expression support for dataset.filter #56716
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 12 commits
8db6b11
c664348
27be0de
096051c
2d84e77
d250215
7ff4f9e
7758d67
735faa6
b85fb1a
9884ff8
1f6872c
86fb533
633d5e2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -34,6 +34,7 @@ | |
| U, | ||
| ) | ||
| from ray.data.context import DataContext | ||
| from ray.data.expressions import Expr | ||
|
|
||
| if TYPE_CHECKING: | ||
| import pandas | ||
|
|
@@ -619,3 +620,17 @@ def iter_rows( | |
| yield row.as_pydict() | ||
| else: | ||
| yield row | ||
|
|
||
| def filter(self, predicate_expr: "Expr") -> "pandas.DataFrame": | ||
| """Filter rows based on a predicate expression.""" | ||
| if self._table.empty: | ||
| return self._table | ||
|
|
||
| # TODO: Move _expression_evaluator to _internal | ||
| from ray.data._expression_evaluator import eval_expr | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's move this to
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'll make this a TODO just to keep the change cleaner
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, let's do in a follow-up. But let's do it right away |
||
|
|
||
| # Evaluate the expression to get a boolean mask | ||
| mask = eval_expr(predicate_expr, self._table) | ||
|
|
||
| # Use pandas boolean indexing | ||
| return self._table[mask] | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,6 +5,8 @@ | |
| import pyarrow.compute as pc | ||
| import pyarrow.dataset as ds | ||
|
|
||
| from ray.data.expressions import Expr | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
|
|
@@ -36,8 +38,29 @@ def get_filters(expression: str) -> ds.Expression: | |
| logger.exception(f"Error processing expression: {e}") | ||
| raise | ||
|
|
||
| @staticmethod | ||
| def get_ray_data_expression(expression: str) -> "Expr": | ||
goutamvenkat-anyscale marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| """Parse and evaluate the expression to generate a Ray Data expression. | ||
|
|
||
| Args: | ||
| expression: A string representing the filter expression to parse. | ||
|
|
||
| Returns: | ||
| A Ray Data Expr object for filtering data. | ||
|
|
||
| """ | ||
| try: | ||
| tree = ast.parse(expression, mode="eval") | ||
| return _ConvertToRayDataExpressionVisitor().visit(tree.body) | ||
| except SyntaxError as e: | ||
| raise ValueError(f"Invalid syntax in the expression: {expression}") from e | ||
| except Exception as e: | ||
| logger.exception(f"Error processing expression: {e}") | ||
| raise | ||
|
|
||
|
|
||
| class _ConvertToArrowExpressionVisitor(ast.NodeVisitor): | ||
| # TODO: Deprecate this visitor after we remove string support in filter API. | ||
| def visit_Compare(self, node: ast.Compare) -> ds.Expression: | ||
| """Handle comparison operations (e.g., a == b, a < b, a in b). | ||
|
|
||
|
|
@@ -234,3 +257,150 @@ def visit_Call(self, node: ast.Call) -> ds.Expression: | |
| return function_map[func_name](*args) | ||
| else: | ||
| raise ValueError(f"Unsupported function: {func_name}") | ||
|
|
||
|
|
||
| class _ConvertToRayDataExpressionVisitor(ast.NodeVisitor): | ||
goutamvenkat-anyscale marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| """AST visitor that converts string expressions to Ray Data expressions.""" | ||
|
|
||
| def visit_Compare(self, node: ast.Compare) -> "Expr": | ||
| """Handle comparison operations (e.g., a == b, a < b, a in b).""" | ||
| from ray.data.expressions import BinaryExpr, Operation | ||
|
|
||
| if len(node.ops) != 1 or len(node.comparators) != 1: | ||
| raise ValueError("Only simple binary comparisons are supported") | ||
|
|
||
| left = self.visit(node.left) | ||
| right = self.visit(node.comparators[0]) | ||
| op = node.ops[0] | ||
|
|
||
| # Map AST comparison operators to Ray Data operations | ||
| op_map = { | ||
| ast.Eq: Operation.EQ, | ||
| ast.NotEq: Operation.NE, | ||
| ast.Lt: Operation.LT, | ||
| ast.LtE: Operation.LE, | ||
| ast.Gt: Operation.GT, | ||
| ast.GtE: Operation.GE, | ||
| ast.In: Operation.IN, | ||
| ast.NotIn: Operation.NOT_IN, | ||
| } | ||
|
|
||
| if type(op) not in op_map: | ||
| raise ValueError(f"Unsupported comparison operator: {type(op).__name__}") | ||
|
|
||
| return BinaryExpr(op_map[type(op)], left, right) | ||
|
|
||
| def visit_BoolOp(self, node: ast.BoolOp) -> "Expr": | ||
| """Handle logical operations (e.g., a and b, a or b).""" | ||
| from ray.data.expressions import BinaryExpr, Operation | ||
|
|
||
| conditions = [self.visit(value) for value in node.values] | ||
| combined_expr = conditions[0] | ||
|
|
||
| for condition in conditions[1:]: | ||
| if isinstance(node.op, ast.And): | ||
| combined_expr = BinaryExpr(Operation.AND, combined_expr, condition) | ||
| elif isinstance(node.op, ast.Or): | ||
| combined_expr = BinaryExpr(Operation.OR, combined_expr, condition) | ||
| else: | ||
| raise ValueError( | ||
| f"Unsupported logical operator: {type(node.op).__name__}" | ||
| ) | ||
|
|
||
| return combined_expr | ||
|
|
||
| def visit_UnaryOp(self, node: ast.UnaryOp) -> "Expr": | ||
| """Handle unary operations (e.g., not a, -5).""" | ||
| from ray.data.expressions import Operation, UnaryExpr, lit | ||
|
|
||
| if isinstance(node.op, ast.Not): | ||
| operand = self.visit(node.operand) | ||
| return UnaryExpr(Operation.NOT, operand) | ||
| elif isinstance(node.op, ast.USub): | ||
| # For negative numbers, treat as literal | ||
| if isinstance(node.operand, ast.Constant): | ||
| return lit(-node.operand.value) | ||
| else: | ||
| raise ValueError("Unary minus only supported for constant values") | ||
goutamvenkat-anyscale marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| else: | ||
| raise ValueError(f"Unsupported unary operator: {type(node.op).__name__}") | ||
|
|
||
| def visit_Name(self, node: ast.Name) -> "Expr": | ||
| """Handle variable names (column references).""" | ||
| from ray.data.expressions import col | ||
|
|
||
| return col(node.id) | ||
|
|
||
| def visit_Constant(self, node: ast.Constant) -> "Expr": | ||
| """Handle constant values (numbers, strings, booleans).""" | ||
| from ray.data.expressions import lit | ||
|
|
||
| return lit(node.value) | ||
|
|
||
| def visit_List(self, node: ast.List) -> "Expr": | ||
| """Handle list literals.""" | ||
| from ray.data.expressions import LiteralExpr, lit | ||
|
|
||
| # Visit all elements first | ||
| visited_elements = [self.visit(elt) for elt in node.elts] | ||
|
|
||
| # Try to extract constant values for literal list | ||
| elements = [] | ||
| for elem in visited_elements: | ||
| if isinstance(elem, LiteralExpr): | ||
| elements.append(elem.value) | ||
| else: | ||
| # For compatibility with Arrow visitor, we need to support non-literals | ||
| # but Ray Data expressions may have limitations here | ||
| raise ValueError( | ||
| "List contains non-constant expressions. Ray Data expressions " | ||
| "currently only support lists of constant values for 'in' operations." | ||
| ) | ||
|
Comment on lines
347
to
355
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Wait, you don't know if this list is gonna be used in
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good point |
||
|
|
||
| return lit(elements) | ||
|
|
||
| def visit_Attribute(self, node: ast.Attribute) -> "Expr": | ||
| """Handle attribute access (e.g., for nested column names).""" | ||
| from ray.data.expressions import col | ||
|
|
||
| # For nested column names like "user.age", combine them with dots | ||
| if isinstance(node.value, ast.Name): | ||
| return col(f"{node.value.id}.{node.attr}") | ||
| elif isinstance(node.value, ast.Attribute): | ||
| # Recursively handle nested attributes | ||
| left_expr = self.visit(node.value) | ||
| if hasattr(left_expr, "_name"): # ColumnExpr | ||
goutamvenkat-anyscale marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| return col(f"{left_expr._name}.{node.attr}") | ||
|
|
||
| raise ValueError(f"Unsupported attribute access: {node.attr}") | ||
|
||
|
|
||
| def visit_Call(self, node: ast.Call) -> "Expr": | ||
| """Handle function calls for operations like is_null, is_not_null, is_nan.""" | ||
| from ray.data.expressions import BinaryExpr, Operation, UnaryExpr | ||
|
|
||
| func_name = node.func.id if isinstance(node.func, ast.Name) else str(node.func) | ||
|
|
||
| if func_name == "is_null": | ||
| if len(node.args) != 1: | ||
| raise ValueError("is_null() expects exactly one argument") | ||
| operand = self.visit(node.args[0]) | ||
| return UnaryExpr(Operation.IS_NULL, operand) | ||
| elif func_name == "is_valid" or func_name == "is_not_null": | ||
goutamvenkat-anyscale marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| if len(node.args) != 1: | ||
| raise ValueError(f"{func_name}() expects exactly one argument") | ||
| operand = self.visit(node.args[0]) | ||
| return UnaryExpr(Operation.IS_NOT_NULL, operand) | ||
| elif func_name == "is_nan": | ||
| if len(node.args) != 1: | ||
| raise ValueError("is_nan() expects exactly one argument") | ||
| operand = self.visit(node.args[0]) | ||
| # Use x != x pattern for NaN detection (NaN != NaN is True) | ||
| return BinaryExpr(Operation.NE, operand, operand) | ||
| elif func_name == "is_in": | ||
| if len(node.args) != 2: | ||
| raise ValueError("is_in() expects exactly two arguments") | ||
| left = self.visit(node.args[0]) | ||
| right = self.visit(node.args[1]) | ||
| return BinaryExpr(Operation.IN, left, right) | ||
| else: | ||
| raise ValueError(f"Unsupported function: {func_name}") | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -51,7 +51,6 @@ | |
| from ray.data._internal.output_buffer import OutputBlockSizeOption | ||
| from ray.data._internal.util import _truncated_repr | ||
| from ray.data.block import ( | ||
| BatchFormat, | ||
| Block, | ||
| BlockAccessor, | ||
| CallableClass, | ||
|
|
@@ -218,26 +217,24 @@ def plan_filter_op( | |
| target_max_block_size=data_context.target_max_block_size, | ||
| ) | ||
|
|
||
| expression = op._filter_expr | ||
| predicate_expr = op._predicate_expr | ||
| compute = get_compute(op._compute) | ||
| if expression is not None: | ||
| if predicate_expr is not None: | ||
|
|
||
| def filter_batch_fn(block: "pa.Table") -> "pa.Table": | ||
| try: | ||
| return block.filter(expression) | ||
| except Exception as e: | ||
| _try_wrap_udf_exception(e) | ||
| def filter_block_fn( | ||
| blocks: Iterable[Block], ctx: TaskContext | ||
| ) -> Iterable[Block]: | ||
| for block in blocks: | ||
| block_accessor = BlockAccessor.for_block(block) | ||
| filtered_block = block_accessor.filter(predicate_expr) | ||
| yield filtered_block | ||
|
|
||
| init_fn = None | ||
| transform_fn = BatchMapTransformFn( | ||
| _generate_transform_fn_for_map_batches(filter_batch_fn), | ||
| batch_size=None, | ||
| batch_format=BatchFormat.ARROW, | ||
| zero_copy_batch=True, | ||
| transform_fn = BlockMapTransformFn( | ||
| filter_block_fn, | ||
| is_udf=True, | ||
| output_block_size_option=output_block_size_option, | ||
| ) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Bug: Import Error and Inconsistent Exception HandlingThe
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ... Ofc it's imported... |
||
|
|
||
| else: | ||
| udf_is_callable_class = isinstance(op._fn, CallableClass) | ||
| filter_fn, init_fn = _get_udf( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Bug: Missing Dependencies in New Test Target
The new
test_filterpy_testtarget inBUILD.bazelis missing itsdepssection. This preventstest_filter.pyfrom resolving imports forrayandray.datamodules, which are typically provided by":conftest"and"//:ray_lib"dependencies, as seen in other test targets.