diff --git a/python/ray/data/BUILD.bazel b/python/ray/data/BUILD.bazel index 680255466b0c..af837f255b6c 100644 --- a/python/ray/data/BUILD.bazel +++ b/python/ray/data/BUILD.bazel @@ -700,6 +700,20 @@ py_test( ], ) +py_test( + name = "test_filter", + size = "medium", + srcs = ["tests/test_filter.py"], + tags = [ + "exclusive", + "team:data", + ], + deps = [ + ":conftest", + "//:ray_lib", + ], +) + py_test( name = "test_numpy", size = "medium", diff --git a/python/ray/data/_internal/arrow_block.py b/python/ray/data/_internal/arrow_block.py index 22a872c22fe1..5a64eb1c2bc2 100644 --- a/python/ray/data/_internal/arrow_block.py +++ b/python/ray/data/_internal/arrow_block.py @@ -39,6 +39,7 @@ U, ) from ray.data.context import DEFAULT_TARGET_MAX_BLOCK_SIZE, DataContext +from ray.data.expressions import Expr try: import pyarrow @@ -463,6 +464,19 @@ def iter_rows( for i in range(self.num_rows()): yield self._get_row(i) + def filter(self, predicate_expr: "Expr") -> "pyarrow.Table": + """Filter rows based on a predicate expression.""" + if self._table.num_rows == 0: + return self._table + + from ray.data._expression_evaluator import eval_expr + + # Evaluate the expression to get a boolean mask + mask = eval_expr(predicate_expr, self._table) + + # Use PyArrow's built-in filter method + return self._table.filter(mask) + class ArrowBlockColumnAccessor(BlockColumnAccessor): def __init__(self, col: Union["pyarrow.Array", "pyarrow.ChunkedArray"]): diff --git a/python/ray/data/_internal/logical/operators/map_operator.py b/python/ray/data/_internal/logical/operators/map_operator.py index 1d8ee48ad765..63cf35410237 100644 --- a/python/ray/data/_internal/logical/operators/map_operator.py +++ b/python/ray/data/_internal/logical/operators/map_operator.py @@ -1,7 +1,7 @@ import functools import inspect import logging -from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional +from typing import Any, Callable, Dict, Iterable, List, Optional from ray.data._internal.compute import ComputeStrategy, TaskPoolStrategy from ray.data._internal.logical.interfaces import LogicalOperator @@ -10,10 +10,6 @@ from ray.data.expressions import Expr from ray.data.preprocessor import Preprocessor -if TYPE_CHECKING: - import pyarrow as pa - - logger = logging.getLogger(__name__) @@ -235,20 +231,24 @@ class Filter(AbstractUDFMap): def __init__( self, input_op: LogicalOperator, + predicate_expr: Optional[Expr] = None, fn: Optional[UserDefinedFunction] = None, fn_args: Optional[Iterable[Any]] = None, fn_kwargs: Optional[Dict[str, Any]] = None, fn_constructor_args: Optional[Iterable[Any]] = None, fn_constructor_kwargs: Optional[Dict[str, Any]] = None, - filter_expr: Optional["pa.dataset.Expression"] = None, compute: Optional[ComputeStrategy] = None, ray_remote_args_fn: Optional[Callable[[], Dict[str, Any]]] = None, ray_remote_args: Optional[Dict[str, Any]] = None, ): - # Ensure exactly one of fn or filter_expr is provided - if not ((fn is None) ^ (filter_expr is None)): - raise ValueError("Exactly one of 'fn' or 'filter_expr' must be provided") - self._filter_expr = filter_expr + # Ensure exactly one of fn, or predicate_expr is provided + provided_params = sum([fn is not None, predicate_expr is not None]) + if provided_params != 1: + raise ValueError( + f"Exactly one of 'fn', or 'predicate_expr' must be provided (received fn={fn}, predicate_expr={predicate_expr})" + ) + + self._predicate_expr = predicate_expr super().__init__( "Filter", diff --git a/python/ray/data/_internal/pandas_block.py b/python/ray/data/_internal/pandas_block.py index 683bd70055cc..92ad48ea50b1 100644 --- a/python/ray/data/_internal/pandas_block.py +++ b/python/ray/data/_internal/pandas_block.py @@ -34,6 +34,7 @@ U, ) from ray.data.context import DataContext +from ray.data.expressions import Expr if TYPE_CHECKING: import pandas @@ -619,3 +620,17 @@ def iter_rows( yield row.as_pydict() else: yield row + + def filter(self, predicate_expr: "Expr") -> "pandas.DataFrame": + """Filter rows based on a predicate expression.""" + if self._table.empty: + return self._table + + # TODO: Move _expression_evaluator to _internal + from ray.data._expression_evaluator import eval_expr + + # Evaluate the expression to get a boolean mask + mask = eval_expr(predicate_expr, self._table) + + # Use pandas boolean indexing + return self._table[mask] diff --git a/python/ray/data/_internal/planner/plan_expression/expression_evaluator.py b/python/ray/data/_internal/planner/plan_expression/expression_evaluator.py index 938c2a2d21fc..df2a1c066c3c 100644 --- a/python/ray/data/_internal/planner/plan_expression/expression_evaluator.py +++ b/python/ray/data/_internal/planner/plan_expression/expression_evaluator.py @@ -5,6 +5,8 @@ import pyarrow.compute as pc import pyarrow.dataset as ds +from ray.data.expressions import ColumnExpr, Expr + logger = logging.getLogger(__name__) @@ -36,8 +38,29 @@ def get_filters(expression: str) -> ds.Expression: logger.exception(f"Error processing expression: {e}") raise + @staticmethod + def parse_native_expression(expression: str) -> "Expr": + """Parse and evaluate the expression to generate a Ray Data expression. + + Args: + expression: A string representing the filter expression to parse. + + Returns: + A Ray Data Expr object for filtering data. + + """ + try: + tree = ast.parse(expression, mode="eval") + return _ConvertToNativeExpressionVisitor().visit(tree.body) + except SyntaxError as e: + raise ValueError(f"Invalid syntax in the expression: {expression}") from e + except Exception as e: + logger.exception(f"Error processing expression: {e}") + raise + class _ConvertToArrowExpressionVisitor(ast.NodeVisitor): + # TODO: Deprecate this visitor after we remove string support in filter API. def visit_Compare(self, node: ast.Compare) -> ds.Expression: """Handle comparison operations (e.g., a == b, a < b, a in b). @@ -234,3 +257,151 @@ def visit_Call(self, node: ast.Call) -> ds.Expression: return function_map[func_name](*args) else: raise ValueError(f"Unsupported function: {func_name}") + + +class _ConvertToNativeExpressionVisitor(ast.NodeVisitor): + """AST visitor that converts string expressions to Ray Data expressions.""" + + def visit_Compare(self, node: ast.Compare) -> "Expr": + """Handle comparison operations (e.g., a == b, a < b, a in b).""" + from ray.data.expressions import BinaryExpr, Operation + + if len(node.ops) != 1 or len(node.comparators) != 1: + raise ValueError("Only simple binary comparisons are supported") + + left = self.visit(node.left) + right = self.visit(node.comparators[0]) + op = node.ops[0] + + # Map AST comparison operators to Ray Data operations + op_map = { + ast.Eq: Operation.EQ, + ast.NotEq: Operation.NE, + ast.Lt: Operation.LT, + ast.LtE: Operation.LE, + ast.Gt: Operation.GT, + ast.GtE: Operation.GE, + ast.In: Operation.IN, + ast.NotIn: Operation.NOT_IN, + } + + if type(op) not in op_map: + raise ValueError(f"Unsupported comparison operator: {type(op).__name__}") + + return BinaryExpr(op_map[type(op)], left, right) + + def visit_BoolOp(self, node: ast.BoolOp) -> "Expr": + """Handle logical operations (e.g., a and b, a or b).""" + from ray.data.expressions import BinaryExpr, Operation + + conditions = [self.visit(value) for value in node.values] + combined_expr = conditions[0] + + for condition in conditions[1:]: + if isinstance(node.op, ast.And): + combined_expr = BinaryExpr(Operation.AND, combined_expr, condition) + elif isinstance(node.op, ast.Or): + combined_expr = BinaryExpr(Operation.OR, combined_expr, condition) + else: + raise ValueError( + f"Unsupported logical operator: {type(node.op).__name__}" + ) + + return combined_expr + + def visit_UnaryOp(self, node: ast.UnaryOp) -> "Expr": + """Handle unary operations (e.g., not a, -5).""" + from ray.data.expressions import Operation, UnaryExpr, lit + + if isinstance(node.op, ast.Not): + operand = self.visit(node.operand) + return UnaryExpr(Operation.NOT, operand) + elif isinstance(node.op, ast.USub): + operand = self.visit(node.operand) + return operand * lit(-1) + else: + raise ValueError(f"Unsupported unary operator: {type(node.op).__name__}") + + def visit_Name(self, node: ast.Name) -> "Expr": + """Handle variable names (column references).""" + from ray.data.expressions import col + + return col(node.id) + + def visit_Constant(self, node: ast.Constant) -> "Expr": + """Handle constant values (numbers, strings, booleans).""" + from ray.data.expressions import lit + + return lit(node.value) + + def visit_List(self, node: ast.List) -> "Expr": + """Handle list literals.""" + from ray.data.expressions import LiteralExpr, lit + + # Visit all elements first + visited_elements = [self.visit(elt) for elt in node.elts] + + # Try to extract constant values for literal list + elements = [] + for elem in visited_elements: + if isinstance(elem, LiteralExpr): + elements.append(elem.value) + else: + # For compatibility with Arrow visitor, we need to support non-literals + # but Ray Data expressions may have limitations here + raise ValueError( + "List contains non-constant expressions. Ray Data expressions " + "currently only support lists of constant values." + ) + + return lit(elements) + + def visit_Attribute(self, node: ast.Attribute) -> "Expr": + """Handle attribute access (e.g., for nested column names).""" + from ray.data.expressions import col + + # For nested column names like "user.age", combine them with dots + if isinstance(node.value, ast.Name): + return col(f"{node.value.id}.{node.attr}") + elif isinstance(node.value, ast.Attribute): + # Recursively handle nested attributes + left_expr = self.visit(node.value) + if isinstance(left_expr, ColumnExpr): + return col(f"{left_expr._name}.{node.attr}") + + raise ValueError( + f"Unsupported attribute access: {node.attr}. Node details: {ast.dump(node)}" + ) + + def visit_Call(self, node: ast.Call) -> "Expr": + """Handle function calls for operations like is_null, is_not_null, is_nan.""" + from ray.data.expressions import BinaryExpr, Operation, UnaryExpr + + func_name = node.func.id if isinstance(node.func, ast.Name) else str(node.func) + + if func_name == "is_null": + if len(node.args) != 1: + raise ValueError("is_null() expects exactly one argument") + operand = self.visit(node.args[0]) + return UnaryExpr(Operation.IS_NULL, operand) + # Adding this conditional to keep it consistent with the current implementation, + # of carrying Pyarrow's semantic of `is_valid` + elif func_name == "is_valid" or func_name == "is_not_null": + if len(node.args) != 1: + raise ValueError(f"{func_name}() expects exactly one argument") + operand = self.visit(node.args[0]) + return UnaryExpr(Operation.IS_NOT_NULL, operand) + elif func_name == "is_nan": + if len(node.args) != 1: + raise ValueError("is_nan() expects exactly one argument") + operand = self.visit(node.args[0]) + # Use x != x pattern for NaN detection (NaN != NaN is True) + return BinaryExpr(Operation.NE, operand, operand) + elif func_name == "is_in": + if len(node.args) != 2: + raise ValueError("is_in() expects exactly two arguments") + left = self.visit(node.args[0]) + right = self.visit(node.args[1]) + return BinaryExpr(Operation.IN, left, right) + else: + raise ValueError(f"Unsupported function: {func_name}") diff --git a/python/ray/data/_internal/planner/plan_udf_map_op.py b/python/ray/data/_internal/planner/plan_udf_map_op.py index ccf7b713490d..0578881620ca 100644 --- a/python/ray/data/_internal/planner/plan_udf_map_op.py +++ b/python/ray/data/_internal/planner/plan_udf_map_op.py @@ -51,7 +51,6 @@ from ray.data._internal.output_buffer import OutputBlockSizeOption from ray.data._internal.util import _truncated_repr from ray.data.block import ( - BatchFormat, Block, BlockAccessor, CallableClass, @@ -218,26 +217,24 @@ def plan_filter_op( target_max_block_size=data_context.target_max_block_size, ) - expression = op._filter_expr + predicate_expr = op._predicate_expr compute = get_compute(op._compute) - if expression is not None: + if predicate_expr is not None: - def filter_batch_fn(block: "pa.Table") -> "pa.Table": - try: - return block.filter(expression) - except Exception as e: - _try_wrap_udf_exception(e) + def filter_block_fn( + blocks: Iterable[Block], ctx: TaskContext + ) -> Iterable[Block]: + for block in blocks: + block_accessor = BlockAccessor.for_block(block) + filtered_block = block_accessor.filter(predicate_expr) + yield filtered_block init_fn = None - transform_fn = BatchMapTransformFn( - _generate_transform_fn_for_map_batches(filter_batch_fn), - batch_size=None, - batch_format=BatchFormat.ARROW, - zero_copy_batch=True, + transform_fn = BlockMapTransformFn( + filter_block_fn, is_udf=True, output_block_size_option=output_block_size_option, ) - else: udf_is_callable_class = isinstance(op._fn, CallableClass) filter_fn, init_fn = _get_udf( diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 22c9631caf0d..b70cd4ea6afd 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -1391,7 +1391,7 @@ def duplicate_row(row: Dict[str, Any]) -> List[Dict[str, Any]]: def filter( self, fn: Optional[UserDefinedFunction[Dict[str, Any], bool]] = None, - expr: Optional[str] = None, + expr: Optional[Union[str, Expr]] = None, *, compute: Union[str, ComputeStrategy] = None, fn_args: Optional[Iterable[Any]] = None, @@ -1407,30 +1407,40 @@ def filter( ) -> "Dataset": """Filter out rows that don't satisfy the given predicate. - You can use either a function or a callable class or an expression string to + You can use either a function or a callable class or an expression to perform the transformation. For functions, Ray Data uses stateless Ray tasks. For classes, Ray Data uses stateful Ray actors. For more information, see :ref:`Stateful Transforms `. .. tip:: - If you use the `expr` parameter with a Python expression string, Ray Data + If you use the `expr` parameter with a predicate expression, Ray Data optimizes your filter with native Arrow interfaces. + .. deprecated:: + String expressions are deprecated and will be removed in a future version. + Use predicate expressions from `ray.data.expressions` instead. + Examples: >>> import ray + >>> from ray.data.expressions import col >>> ds = ray.data.range(100) + >>> # String expressions (deprecated - will warn) >>> ds.filter(expr="id <= 4").take_all() [{'id': 0}, {'id': 1}, {'id': 2}, {'id': 3}, {'id': 4}] + >>> # Using predicate expressions (preferred) + >>> ds.filter(expr=(col("id") > 10) & (col("id") < 20)).take_all() + [{'id': 11}, {'id': 12}, {'id': 13}, {'id': 14}, {'id': 15}, {'id': 16}, {'id': 17}, {'id': 18}, {'id': 19}] Time complexity: O(dataset size / parallelism) Args: fn: The predicate to apply to each row, or a class type that can be instantiated to create such a callable. - expr: An expression string needs to be a valid Python expression that - will be converted to ``pyarrow.dataset.Expression`` type. + expr: An expression that represents a predicate (boolean condition) for filtering. + Can be either a string expression (deprecated) or a predicate expression + from `ray.data.expressions`. fn_args: Positional arguments to pass to ``fn`` after the first argument. These arguments are top-level arguments to the underlying Ray task. fn_kwargs: Keyword arguments to pass to ``fn``. These arguments are @@ -1479,10 +1489,12 @@ def filter( :func:`ray.remote` for details. """ # Ensure exactly one of fn or expr is provided - resolved_expr = None - if not ((fn is None) ^ (expr is None)): + provided_params = sum([fn is not None, expr is not None]) + if provided_params != 1: raise ValueError("Exactly one of 'fn' or 'expr' must be provided.") - elif expr is not None: + + # Helper function to check for incompatible function parameters + def _check_fn_params_incompatible(param_type): if ( fn_args is not None or fn_kwargs is not None @@ -1490,57 +1502,95 @@ def filter( or fn_constructor_kwargs is not None ): raise ValueError( - "when 'expr' is used, 'fn_args/fn_kwargs' or 'fn_constructor_args/fn_constructor_kwargs' can not be used." + f"when '{param_type}' is used, 'fn_args/fn_kwargs' or 'fn_constructor_args/fn_constructor_kwargs' cannot be used." ) + + # Merge ray remote args early + ray_remote_args = merge_resources_to_ray_remote_args( + num_cpus, + num_gpus, + memory, + ray_remote_args, + ) + + # Initialize Filter operator arguments with proper types + input_op = self._logical_plan.dag + predicate_expr: Optional[Expr] = None + filter_fn: Optional[UserDefinedFunction] = None + filter_fn_args: Optional[Iterable[Any]] = None + filter_fn_kwargs: Optional[Dict[str, Any]] = None + filter_fn_constructor_args: Optional[Iterable[Any]] = None + filter_fn_constructor_kwargs: Optional[Dict[str, Any]] = None + filter_compute: Optional[ComputeStrategy] = None + + if expr is not None: + _check_fn_params_incompatible("expr") from ray.data._internal.compute import TaskPoolStrategy - from ray.data._internal.planner.plan_expression.expression_evaluator import ( # noqa: E501 - ExpressionEvaluator, - ) - # TODO: (srinathk) bind the expression to the actual schema. - # If fn is a string, convert it to a pyarrow.dataset.Expression - # Initialize ExpressionEvaluator with valid columns, if available - resolved_expr = ExpressionEvaluator.get_filters(expression=expr) + # Check if expr is a string (deprecated) or Expr object + if isinstance(expr, str): + warnings.warn( + "String expressions are deprecated and will be removed in a future version. " + "Use predicate expressions from ray.data.expressions instead. " + "For example: from ray.data.expressions import col; " + "ds.filter(expr=col('column_name') > 5)", + DeprecationWarning, + stacklevel=2, + ) + + from ray.data._internal.planner.plan_expression.expression_evaluator import ( # noqa: E501 + ExpressionEvaluator, + ) + + # TODO: (srinathk) bind the expression to the actual schema. + # If expr is a string, convert it to a pyarrow.dataset.Expression + # Initialize ExpressionEvaluator with valid columns, if available + # str -> Ray Data's Expression + predicate_expr = ExpressionEvaluator.parse_native_expression(expr) + else: + # expr is an Expr object (predicate expression) + predicate_expr = expr - compute = TaskPoolStrategy(size=concurrency) + filter_compute = TaskPoolStrategy(size=concurrency) else: warnings.warn( "Use 'expr' instead of 'fn' when possible for performant filters." ) - if callable(fn): - compute = get_compute_strategy( - fn=fn, - fn_constructor_args=fn_constructor_args, - compute=compute, - concurrency=concurrency, - ) - else: + if not callable(fn): raise ValueError( f"fn must be a UserDefinedFunction, but got " f"{type(fn).__name__} instead." ) - ray_remote_args = merge_resources_to_ray_remote_args( - num_cpus, - num_gpus, - memory, - ray_remote_args, - ) - plan = self._plan.copy() - op = Filter( - input_op=self._logical_plan.dag, - fn=fn, - fn_args=fn_args, - fn_kwargs=fn_kwargs, - fn_constructor_args=fn_constructor_args, - fn_constructor_kwargs=fn_constructor_kwargs, - filter_expr=resolved_expr, - compute=compute, + filter_fn = fn + filter_fn_args = fn_args + filter_fn_kwargs = fn_kwargs + filter_fn_constructor_args = fn_constructor_args + filter_fn_constructor_kwargs = fn_constructor_kwargs + filter_compute = get_compute_strategy( + fn=fn, + fn_constructor_args=fn_constructor_args, + compute=compute, + concurrency=concurrency, + ) + + # Create Filter operator with explicitly typed arguments + filter_op = Filter( + input_op=input_op, + predicate_expr=predicate_expr, + fn=filter_fn, + fn_args=filter_fn_args, + fn_kwargs=filter_fn_kwargs, + fn_constructor_args=filter_fn_constructor_args, + fn_constructor_kwargs=filter_fn_constructor_kwargs, + compute=filter_compute, ray_remote_args_fn=ray_remote_args_fn, ray_remote_args=ray_remote_args, ) - logical_plan = LogicalPlan(op, self.context) + + plan = self._plan.copy() + logical_plan = LogicalPlan(filter_op, self.context) return Dataset(plan, logical_plan) @PublicAPI(api_group=SSR_API_GROUP) @@ -1569,7 +1619,7 @@ def repartition( * When ``num_blocks`` and ``shuffle=True`` are specified Ray Data performs a full distributed shuffle producing exactly ``num_blocks`` blocks. * When ``num_blocks`` and ``shuffle=False`` are specified, Ray Data does NOT perform full shuffle, instead opting in for splitting and combining of the blocks attempting to minimize the necessary data movement (relative to full-blown shuffle). Exactly ``num_blocks`` will be produced. - * If ``target_num_rows_per_block`` is set (exclusive with ``num_blocks`` and ``shuffle``), streaming repartitioning will be executed, where blocks will be made to carry no more than ``target_num_rows_per_block``. Smaller blocks will be combined into bigger ones up to ``target_num_rows_per_block`` as well. + * If ``target_num_rows_per_block`` is set (exclusive with ``num_blocks`` and ``shuffle``), streaming repartitioning will be executed, where blocks will be made to carry no more than ``target_num_rows_per_block`` rows. Smaller blocks will be combined into bigger ones up to ``target_num_rows_per_block`` as well. .. image:: /data/images/dataset-shuffle.svg :align: center diff --git a/python/ray/data/tests/test_execution_optimizer.py b/python/ray/data/tests/test_execution_optimizer.py index ee5769720997..4a2fbae14098 100644 --- a/python/ray/data/tests/test_execution_optimizer.py +++ b/python/ray/data/tests/test_execution_optimizer.py @@ -307,7 +307,7 @@ def test_filter_operator(ray_start_regular_shared_2_cpus): read_op = get_parquet_read_logical_op() op = Filter( read_op, - lambda x: x, + fn=lambda x: x, ) plan = LogicalPlan(op, ctx) physical_op = planner.plan(plan).dag diff --git a/python/ray/data/tests/test_filter.py b/python/ray/data/tests/test_filter.py new file mode 100644 index 000000000000..c7d101745e17 --- /dev/null +++ b/python/ray/data/tests/test_filter.py @@ -0,0 +1,377 @@ +import pandas as pd +import pyarrow as pa +import pyarrow.parquet as pq +import pytest +from pkg_resources import parse_version + +import ray +from ray.data.expressions import col +from ray.data.tests.conftest import get_pyarrow_version +from ray.tests.conftest import * # noqa + + +def test_filter_mutex(ray_start_regular_shared, tmp_path): + """Test filter op.""" + + # Generate sample data + data = { + "sepal.length": [4.8, 5.1, 5.7, 6.3, 7.0], + "sepal.width": [3.0, 3.3, 3.5, 3.2, 2.8], + "petal.length": [1.4, 1.7, 4.2, 5.4, 6.1], + "petal.width": [0.2, 0.4, 1.5, 2.1, 2.4], + } + df = pd.DataFrame(data) + + # Define the path for the Parquet file in the tmp_path directory + parquet_file = tmp_path / "sample_data.parquet" + + # Write DataFrame to a Parquet file + table = pa.Table.from_pandas(df) + pq.write_table(table, parquet_file) + + # Load parquet dataset + parquet_ds = ray.data.read_parquet(str(parquet_file)) + + # Filter using lambda (UDF) + with pytest.raises( + ValueError, + ): + parquet_ds.filter( + fn=lambda r: r["sepal.length"] > 5.0, expr="sepal.length > 5.0" + ) + + with pytest.raises(ValueError, match="must be a UserDefinedFunction"): + parquet_ds.filter(fn="sepal.length > 5.0") + + +def test_filter_with_expressions(ray_start_regular_shared, tmp_path): + """Test filtering with expressions.""" + + # Generate sample data + data = { + "sepal.length": [4.8, 5.1, 5.7, 6.3, 7.0], + "sepal.width": [3.0, 3.3, 3.5, 3.2, 2.8], + "petal.length": [1.4, 1.7, 4.2, 5.4, 6.1], + "petal.width": [0.2, 0.4, 1.5, 2.1, 2.4], + } + df = pd.DataFrame(data) + + # Define the path for the Parquet file in the tmp_path directory + parquet_file = tmp_path / "sample_data.parquet" + + # Write DataFrame to a Parquet file + table = pa.Table.from_pandas(df) + pq.write_table(table, parquet_file) + + # Load parquet dataset + parquet_ds = ray.data.read_parquet(str(parquet_file)) + + # Filter using lambda (UDF) + filtered_udf_ds = parquet_ds.filter(lambda r: r["sepal.length"] > 5.0) + filtered_udf_data = filtered_udf_ds.to_pandas() + + # Filter using expressions + filtered_expr_ds = parquet_ds.filter(expr="sepal.length > 5.0") + filtered_expr_data = filtered_expr_ds.to_pandas() + + # Assert the filtered data is the same + assert set(filtered_udf_data["sepal.length"]) == set( + filtered_expr_data["sepal.length"] + ) + assert len(filtered_udf_data) == len(filtered_expr_data) + + # Verify correctness of filtered results: only rows with 'sepal.length' > 5.0 + assert all( + filtered_expr_data["sepal.length"] > 5.0 + ), "Filtered data contains rows with 'sepal.length' <= 5.0" + assert all( + filtered_udf_data["sepal.length"] > 5.0 + ), "UDF-filtered data contains rows with 'sepal.length' <= 5.0" + + +def test_filter_with_invalid_expression(ray_start_regular_shared, tmp_path): + """Test filtering with invalid expressions.""" + + # Generate sample data + data = { + "sepal.length": [4.8, 5.1, 5.7, 6.3, 7.0], + "sepal.width": [3.0, 3.3, 3.5, 3.2, 2.8], + "petal.length": [1.4, 1.7, 4.2, 5.4, 6.1], + "petal.width": [0.2, 0.4, 1.5, 2.1, 2.4], + } + df = pd.DataFrame(data) + + # Define the path for the Parquet file in the tmp_path directory + parquet_file = tmp_path / "sample_data.parquet" + + # Write DataFrame to a Parquet file + table = pa.Table.from_pandas(df) + pq.write_table(table, parquet_file) + + # Load parquet dataset + parquet_ds = ray.data.read_parquet(str(parquet_file)) + + with pytest.raises(ValueError, match="Invalid syntax in the expression"): + parquet_ds.filter(expr="fake_news super fake") + + fake_column_ds = parquet_ds.filter(expr="sepal_length_123 > 1") + with pytest.raises(KeyError): + fake_column_ds.to_pandas() + + +@pytest.mark.skipif( + get_pyarrow_version() < parse_version("20.0.0"), + reason="predicate expressions require PyArrow >= 20.0.0", +) +@pytest.mark.parametrize( + "predicate_expr, test_data, expected_indices, test_description", + [ + # Simple comparison filters + pytest.param( + col("age") >= 21, + [ + {"age": 20, "name": "Alice"}, + {"age": 21, "name": "Bob"}, + {"age": 25, "name": "Charlie"}, + {"age": 30, "name": "David"}, + ], + [1, 2, 3], # Indices of rows that should remain + "age_greater_equal_filter", + ), + pytest.param( + col("score") > 50, + [ + {"score": 30, "status": "fail"}, + {"score": 50, "status": "borderline"}, + {"score": 70, "status": "pass"}, + {"score": 90, "status": "excellent"}, + ], + [2, 3], + "score_greater_than_filter", + ), + pytest.param( + col("category") == "premium", + [ + {"category": "basic", "price": 10}, + {"category": "premium", "price": 50}, + {"category": "standard", "price": 25}, + {"category": "premium", "price": 75}, + ], + [1, 3], + "equality_string_filter", + ), + # Complex logical filters + pytest.param( + (col("age") >= 18) & (col("active")), + [ + {"age": 17, "active": True}, + {"age": 18, "active": False}, + {"age": 25, "active": True}, + {"age": 30, "active": True}, + ], + [2, 3], + "logical_and_filter", + ), + pytest.param( + (col("status") == "approved") | (col("priority") == "high"), + [ + {"status": "pending", "priority": "low"}, + {"status": "approved", "priority": "low"}, + {"status": "pending", "priority": "high"}, + {"status": "rejected", "priority": "high"}, + ], + [1, 2, 3], + "logical_or_filter", + ), + # Null handling filters + pytest.param( + col("value").is_not_null(), + [ + {"value": None, "id": 1}, + {"value": 0, "id": 2}, + {"value": None, "id": 3}, + {"value": 42, "id": 4}, + ], + [1, 3], + "not_null_filter", + ), + pytest.param( + col("name").is_null(), + [ + {"name": "Alice", "id": 1}, + {"name": None, "id": 2}, + {"name": "Bob", "id": 3}, + {"name": None, "id": 4}, + ], + [1, 3], + "is_null_filter", + ), + # Complex multi-condition filters + pytest.param( + col("value").is_not_null() & (col("value") > 0), + [ + {"value": None, "type": "missing"}, + {"value": -5, "type": "negative"}, + {"value": 0, "type": "zero"}, + {"value": 10, "type": "positive"}, + ], + [3], + "null_aware_positive_filter", + ), + # String operations + pytest.param( + col("name").is_not_null() & (col("name") != "excluded"), + [ + {"name": "included", "id": 1}, + {"name": "excluded", "id": 2}, + {"name": None, "id": 3}, + {"name": "allowed", "id": 4}, + ], + [0, 3], + "string_exclusion_filter", + ), + # Membership operations + pytest.param( + col("category").is_in(["A", "B"]), + [ + {"category": "A", "value": 1}, + {"category": "B", "value": 2}, + {"category": "C", "value": 3}, + {"category": "D", "value": 4}, + {"category": "A", "value": 5}, + ], + [0, 1, 4], + "membership_filter", + ), + # Negation operations + pytest.param( + ~(col("category") == "reject"), + [ + {"category": "accept", "id": 1}, + {"category": "reject", "id": 2}, + {"category": "pending", "id": 3}, + {"category": "reject", "id": 4}, + ], + [0, 2], + "negation_filter", + ), + # Nested complex expressions + pytest.param( + (col("score") >= 50) & (col("grade") != "F") & col("active"), + [ + {"score": 45, "grade": "F", "active": True}, + {"score": 55, "grade": "D", "active": True}, + {"score": 75, "grade": "B", "active": False}, + {"score": 85, "grade": "A", "active": True}, + ], + [1, 3], + "complex_nested_filter", + ), + ], +) +def test_filter_with_predicate_expressions( + ray_start_regular_shared, + predicate_expr, + test_data, + expected_indices, + test_description, +): + """Test filter() with Ray Data predicate expressions.""" + # Create dataset from test data + ds = ray.data.from_items(test_data) + + # Apply filter with predicate expression + filtered_ds = ds.filter(expr=predicate_expr) + + # Convert to list and verify results + result_data = filtered_ds.to_pandas().to_dict("records") + expected_data = [test_data[i] for i in expected_indices] + + # Use pandas testing for consistent comparison + result_df = pd.DataFrame(result_data) + expected_df = pd.DataFrame(expected_data) + + pd.testing.assert_frame_equal( + result_df.reset_index(drop=True), + expected_df.reset_index(drop=True), + check_dtype=False, + ) + + +@pytest.mark.skipif( + get_pyarrow_version() < parse_version("20.0.0"), + reason="predicate expressions require PyArrow >= 20.0.0", +) +def test_filter_predicate_expr_vs_function_consistency(ray_start_regular_shared): + """Test that predicate expressions produce the same results as equivalent functions.""" + test_data = [ + {"age": 20, "score": 85, "active": True}, + {"age": 25, "score": 45, "active": False}, + {"age": 30, "score": 95, "active": True}, + {"age": 18, "score": 60, "active": True}, + ] + + ds = ray.data.from_items(test_data) + + # Test simple comparison + predicate_result = ds.filter(expr=col("age") >= 21).to_pandas() + function_result = ds.filter(fn=lambda row: row["age"] >= 21).to_pandas() + pd.testing.assert_frame_equal(predicate_result, function_result, check_dtype=False) + + # Test complex logical expression + complex_predicate = (col("age") >= 21) & (col("score") > 80) & col("active") + predicate_result = ds.filter(expr=complex_predicate).to_pandas() + function_result = ds.filter( + fn=lambda row: row["age"] >= 21 and row["score"] > 80 and row["active"] + ).to_pandas() + pd.testing.assert_frame_equal(predicate_result, function_result, check_dtype=False) + + +@pytest.mark.skipif( + get_pyarrow_version() < parse_version("20.0.0"), + reason="predicate expressions require PyArrow >= 20.0.0", +) +def test_filter_predicate_with_different_block_formats(ray_start_regular_shared): + """Test that predicate expressions work with different block formats (pandas/arrow).""" + test_data = [ + {"category": "A", "value": 10}, + {"category": "B", "value": 20}, + {"category": "A", "value": 30}, + {"category": "C", "value": 40}, + ] + + # Test with different data sources that produce different block formats + + # From items (typically arrow) + ds_items = ray.data.from_items(test_data) + result_items = ds_items.filter(expr=col("category") == "A").to_pandas() + + # From pandas (pandas blocks) + df = pd.DataFrame(test_data) + ds_pandas = ray.data.from_pandas([df]) + result_pandas = ds_pandas.filter(expr=col("category") == "A").to_pandas() + + # Results should be identical (reset indices for comparison) + expected_df = pd.DataFrame( + [ + {"category": "A", "value": 10}, + {"category": "A", "value": 30}, + ] + ) + + pd.testing.assert_frame_equal( + result_items.reset_index(drop=True), + expected_df.reset_index(drop=True), + check_dtype=False, + ) + pd.testing.assert_frame_equal( + result_pandas.reset_index(drop=True), + expected_df.reset_index(drop=True), + check_dtype=False, + ) + + +if __name__ == "__main__": + import sys + + sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/data/tests/test_map.py b/python/ray/data/tests/test_map.py index 6e3889326d76..d55c36717117 100644 --- a/python/ray/data/tests/test_map.py +++ b/python/ray/data/tests/test_map.py @@ -686,119 +686,6 @@ def test_rename_columns_error_cases( assert str(exc_info.value) == expected_message -def test_filter_mutex( - ray_start_regular_shared, tmp_path, target_max_block_size_infinite_or_default -): - """Test filter op.""" - - # Generate sample data - data = { - "sepal.length": [4.8, 5.1, 5.7, 6.3, 7.0], - "sepal.width": [3.0, 3.3, 3.5, 3.2, 2.8], - "petal.length": [1.4, 1.7, 4.2, 5.4, 6.1], - "petal.width": [0.2, 0.4, 1.5, 2.1, 2.4], - } - df = pd.DataFrame(data) - - # Define the path for the Parquet file in the tmp_path directory - parquet_file = tmp_path / "sample_data.parquet" - - # Write DataFrame to a Parquet file - table = pa.Table.from_pandas(df) - pq.write_table(table, parquet_file) - - # Load parquet dataset - parquet_ds = ray.data.read_parquet(str(parquet_file)) - - # Filter using lambda (UDF) - with pytest.raises(ValueError, match="Exactly one of 'fn' or 'expr'"): - parquet_ds.filter( - fn=lambda r: r["sepal.length"] > 5.0, expr="sepal.length > 5.0" - ) - - with pytest.raises(ValueError, match="must be a UserDefinedFunction"): - parquet_ds.filter(fn="sepal.length > 5.0") - - -def test_filter_with_expressions( - ray_start_regular_shared, tmp_path, target_max_block_size_infinite_or_default -): - """Test filtering with expressions.""" - - # Generate sample data - data = { - "sepal.length": [4.8, 5.1, 5.7, 6.3, 7.0], - "sepal.width": [3.0, 3.3, 3.5, 3.2, 2.8], - "petal.length": [1.4, 1.7, 4.2, 5.4, 6.1], - "petal.width": [0.2, 0.4, 1.5, 2.1, 2.4], - } - df = pd.DataFrame(data) - - # Define the path for the Parquet file in the tmp_path directory - parquet_file = tmp_path / "sample_data.parquet" - - # Write DataFrame to a Parquet file - table = pa.Table.from_pandas(df) - pq.write_table(table, parquet_file) - - # Load parquet dataset - parquet_ds = ray.data.read_parquet(str(parquet_file)) - - # Filter using lambda (UDF) - filtered_udf_ds = parquet_ds.filter(lambda r: r["sepal.length"] > 5.0) - filtered_udf_data = filtered_udf_ds.to_pandas() - - # Filter using expressions - filtered_expr_ds = parquet_ds.filter(expr="sepal.length > 5.0") - filtered_expr_data = filtered_expr_ds.to_pandas() - - # Assert the filtered data is the same - assert set(filtered_udf_data["sepal.length"]) == set( - filtered_expr_data["sepal.length"] - ) - assert len(filtered_udf_data) == len(filtered_expr_data) - - # Verify correctness of filtered results: only rows with 'sepal.length' > 5.0 - assert all( - filtered_expr_data["sepal.length"] > 5.0 - ), "Filtered data contains rows with 'sepal.length' <= 5.0" - assert all( - filtered_udf_data["sepal.length"] > 5.0 - ), "UDF-filtered data contains rows with 'sepal.length' <= 5.0" - - -def test_filter_with_invalid_expression( - ray_start_regular_shared, tmp_path, target_max_block_size_infinite_or_default -): - """Test filtering with invalid expressions.""" - - # Generate sample data - data = { - "sepal.length": [4.8, 5.1, 5.7, 6.3, 7.0], - "sepal.width": [3.0, 3.3, 3.5, 3.2, 2.8], - "petal.length": [1.4, 1.7, 4.2, 5.4, 6.1], - "petal.width": [0.2, 0.4, 1.5, 2.1, 2.4], - } - df = pd.DataFrame(data) - - # Define the path for the Parquet file in the tmp_path directory - parquet_file = tmp_path / "sample_data.parquet" - - # Write DataFrame to a Parquet file - table = pa.Table.from_pandas(df) - pq.write_table(table, parquet_file) - - # Load parquet dataset - parquet_ds = ray.data.read_parquet(str(parquet_file)) - - with pytest.raises(ValueError, match="Invalid syntax in the expression"): - parquet_ds.filter(expr="fake_news super fake") - - fake_column_ds = parquet_ds.filter(expr="sepal_length_123 > 1") - with pytest.raises(UserCodeException): - fake_column_ds.to_pandas() - - def test_drop_columns( ray_start_regular_shared, tmp_path, target_max_block_size_infinite_or_default ):