From 8db6b110ac0fbedcd34b8c6e85232fbc351d31bc Mon Sep 17 00:00:00 2001 From: "Goutam V." Date: Thu, 18 Sep 2025 13:32:41 -0700 Subject: [PATCH 01/10] [Data] [2/2] - Add predicate expression support for dataset.filter Signed-off-by: Goutam V. --- python/ray/data/_internal/arrow_block.py | 11 + .../logical/operators/map_operator.py | 42 ++- python/ray/data/_internal/pandas_block.py | 11 + .../data/_internal/planner/plan_udf_map_op.py | 17 + python/ray/data/dataset.py | 82 +++-- python/ray/data/tests/test_map.py | 331 +++++++++++++++++- 6 files changed, 463 insertions(+), 31 deletions(-) diff --git a/python/ray/data/_internal/arrow_block.py b/python/ray/data/_internal/arrow_block.py index 14a11b8b0fab..a8eca21d484e 100644 --- a/python/ray/data/_internal/arrow_block.py +++ b/python/ray/data/_internal/arrow_block.py @@ -39,6 +39,7 @@ U, ) from ray.data.context import DEFAULT_TARGET_MAX_BLOCK_SIZE, DataContext +from ray.data.expressions import Expr try: import pyarrow @@ -449,6 +450,16 @@ def iter_rows( for i in range(self.num_rows()): yield self._get_row(i) + def filter(self, predicate_expr: "Expr") -> "pyarrow.Table": + """Filter rows based on a predicate expression.""" + from ray.data._expression_evaluator import eval_expr + + # Evaluate the expression to get a boolean mask + mask = eval_expr(predicate_expr, self._table) + + # Use PyArrow's built-in filter method + return self._table.filter(mask) + class ArrowBlockColumnAccessor(BlockColumnAccessor): def __init__(self, col: Union["pyarrow.Array", "pyarrow.ChunkedArray"]): diff --git a/python/ray/data/_internal/logical/operators/map_operator.py b/python/ray/data/_internal/logical/operators/map_operator.py index 6b1bcefba5ad..5285979d0444 100644 --- a/python/ray/data/_internal/logical/operators/map_operator.py +++ b/python/ray/data/_internal/logical/operators/map_operator.py @@ -224,20 +224,42 @@ class Filter(AbstractUDFMap): def __init__( self, input_op: LogicalOperator, - fn: Optional[UserDefinedFunction] = None, - fn_args: Optional[Iterable[Any]] = None, - fn_kwargs: Optional[Dict[str, Any]] = None, - fn_constructor_args: Optional[Iterable[Any]] = None, - fn_constructor_kwargs: Optional[Dict[str, Any]] = None, - filter_expr: Optional["pa.dataset.Expression"] = None, + fn: Optional[ + UserDefinedFunction + ] = None, # TODO: Deprecate this parameter in favor of predicate_expr + fn_args: Optional[ + Iterable[Any] + ] = None, # TODO: Deprecate this parameter in favor of predicate_expr + fn_kwargs: Optional[ + Dict[str, Any] + ] = None, # TODO: Deprecate this parameter in favor of predicate_expr + fn_constructor_args: Optional[ + Iterable[Any] + ] = None, # TODO: Deprecate this parameter in favor of predicate_expr + fn_constructor_kwargs: Optional[ + Dict[str, Any] + ] = None, # TODO: Deprecate this parameter in favor of predicate_expr + predicate_expr: Optional[Expr] = None, + filter_expr: Optional[ + "pa.dataset.Expression" + ] = None, # TODO: Deprecate this parameter in favor of predicate_expr compute: Optional[ComputeStrategy] = None, ray_remote_args_fn: Optional[Callable[[], Dict[str, Any]]] = None, ray_remote_args: Optional[Dict[str, Any]] = None, ): - # Ensure exactly one of fn or filter_expr is provided - if not ((fn is None) ^ (filter_expr is None)): - raise ValueError("Exactly one of 'fn' or 'filter_expr' must be provided") - self._filter_expr = filter_expr + # Ensure exactly one of fn, filter_expr, or predicate_expr is provided + provided_params = sum( + [fn is not None, filter_expr is not None, predicate_expr is not None] + ) + if provided_params != 1: + raise ValueError( + "Exactly one of 'fn', 'filter_expr', or 'predicate_expr' must be provided" + ) + + self._filter_expr = ( + filter_expr # TODO: Deprecate this parameter in favor of predicate_expr + ) + self._predicate_expr = predicate_expr super().__init__( "Filter", diff --git a/python/ray/data/_internal/pandas_block.py b/python/ray/data/_internal/pandas_block.py index ff08af3c0622..2ea54b554d9d 100644 --- a/python/ray/data/_internal/pandas_block.py +++ b/python/ray/data/_internal/pandas_block.py @@ -34,6 +34,7 @@ U, ) from ray.data.context import DataContext +from ray.data.expressions import Expr if TYPE_CHECKING: import pandas @@ -607,3 +608,13 @@ def iter_rows( yield row.as_pydict() else: yield row + + def filter(self, predicate_expr: "Expr") -> "pandas.DataFrame": + """Filter rows based on a predicate expression.""" + from ray.data._expression_evaluator import eval_expr + + # Evaluate the expression to get a boolean mask + mask = eval_expr(predicate_expr, self._table) + + # Use pandas boolean indexing + return self._table[mask] diff --git a/python/ray/data/_internal/planner/plan_udf_map_op.py b/python/ray/data/_internal/planner/plan_udf_map_op.py index 3809bcafd8a5..72d5f217aa69 100644 --- a/python/ray/data/_internal/planner/plan_udf_map_op.py +++ b/python/ray/data/_internal/planner/plan_udf_map_op.py @@ -201,6 +201,7 @@ def plan_filter_op( input_physical_dag = physical_children[0] expression = op._filter_expr + predicate_expr = op._predicate_expr compute = get_compute(op._compute) if expression is not None: @@ -217,6 +218,22 @@ def filter_batch_fn(block: "pa.Table") -> "pa.Table": batch_format="pyarrow", zero_copy_batch=True, ) + elif predicate_expr is not None: + # Ray Data expression path using BlockAccessor + def filter_block_fn(block: Block) -> Block: + try: + block_accessor = BlockAccessor.for_block(block) + if not block_accessor.num_rows(): + return block + return block_accessor.filter(predicate_expr) + + except Exception as e: + _try_wrap_udf_exception(e) + + transform_fn = _generate_transform_fn_for_map_block(filter_block_fn) + map_transformer = _create_map_transformer_for_block_based_map_op( + transform_fn, + ) else: udf_is_callable_class = isinstance(op._fn, CallableClass) filter_fn, init_fn = _get_udf( diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 9c792eeb95ec..0582dd75a64b 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -1392,8 +1392,13 @@ def duplicate_row(row: Dict[str, Any]) -> List[Dict[str, Any]]: @PublicAPI(api_group=BT_API_GROUP) def filter( self, - fn: Optional[UserDefinedFunction[Dict[str, Any], bool]] = None, - expr: Optional[str] = None, + fn: Optional[ + UserDefinedFunction[Dict[str, Any], bool] + ] = None, # TODO: Deprecate this parameter in favor of predicate + expr: Optional[ + str + ] = None, # TODO: Deprecate this parameter in favor of predicate + predicate: Optional[Expr] = None, *, compute: Union[str, ComputeStrategy] = None, fn_args: Optional[Iterable[Any]] = None, @@ -1419,9 +1424,13 @@ def filter( Examples: >>> import ray + >>> from ray.data.expressions import col >>> ds = ray.data.range(100) - >>> ds.filter(expr="id <= 4").take_all() + >>> ds.filter(expr="id <= 4").take_all() # Will be deprecated in the future [{'id': 0}, {'id': 1}, {'id': 2}, {'id': 3}, {'id': 4}] + >>> # Using predicate expressions + >>> ds.filter(predicate=(col("id") > 10) & (col("id") < 20)).take_all() + [{'id': 11}, {'id': 12}, {'id': 13}, {'id': 14}, {'id': 15}, {'id': 16}, {'id': 17}, {'id': 18}, {'id': 19}] Time complexity: O(dataset size / parallelism) @@ -1429,7 +1438,8 @@ def filter( fn: The predicate to apply to each row, or a class type that can be instantiated to create such a callable. expr: An expression string needs to be a valid Python expression that - will be converted to ``pyarrow.dataset.Expression`` type. + will be converted to ``pyarrow.dataset.Expression`` type. (Soon to be deprecated) + predicate: An expression that represents a predicate (boolean condition) for filtering. fn_args: Positional arguments to pass to ``fn`` after the first argument. These arguments are top-level arguments to the underlying Ray task. fn_kwargs: Keyword arguments to pass to ``fn``. These arguments are @@ -1473,9 +1483,33 @@ def filter( :func:`ray.remote` for details. """ # Ensure exactly one of fn or expr is provided - resolved_expr = None - if not ((fn is None) ^ (expr is None)): - raise ValueError("Exactly one of 'fn' or 'expr' must be provided.") + # Ensure exactly one of fn, expr, or predicate is provided + provided_params = sum([fn is not None, expr is not None, predicate is not None]) + if provided_params != 1: + raise ValueError( + "Exactly one of 'fn', 'expr', or 'predicate' must be provided." + ) + if predicate is not None: + if ( + fn_args is not None + or fn_kwargs is not None + or fn_constructor_args is not None + or fn_constructor_kwargs is not None + ): + raise ValueError( + "when 'predicate' is used, 'fn_args/fn_kwargs' or 'fn_constructor_args/fn_constructor_kwargs' cannot be used." + ) + from ray.data._internal.compute import TaskPoolStrategy + + compute = TaskPoolStrategy(size=concurrency) + # Create Filter operator with predicate expression + filter_op = Filter( + input_op=self._logical_plan.dag, + predicate_expr=predicate, + compute=compute, + ray_remote_args_fn=ray_remote_args_fn, + ray_remote_args=ray_remote_args, + ) elif expr is not None: if ( fn_args is not None @@ -1497,6 +1531,14 @@ def filter( resolved_expr = ExpressionEvaluator.get_filters(expression=expr) compute = TaskPoolStrategy(size=concurrency) + # Create Filter operator with string expression + filter_op = Filter( + input_op=self._logical_plan.dag, + filter_expr=resolved_expr, + compute=compute, + ray_remote_args_fn=ray_remote_args_fn, + ray_remote_args=ray_remote_args, + ) else: warnings.warn( "Use 'expr' instead of 'fn' when possible for performant filters." @@ -1514,21 +1556,21 @@ def filter( f"fn must be a UserDefinedFunction, but got " f"{type(fn).__name__} instead." ) + # Create Filter operator with function + filter_op = Filter( + input_op=self._logical_plan.dag, + fn=fn, + fn_args=fn_args, + fn_kwargs=fn_kwargs, + fn_constructor_args=fn_constructor_args, + fn_constructor_kwargs=fn_constructor_kwargs, + compute=compute, + ray_remote_args_fn=ray_remote_args_fn, + ray_remote_args=ray_remote_args, + ) plan = self._plan.copy() - op = Filter( - input_op=self._logical_plan.dag, - fn=fn, - fn_args=fn_args, - fn_kwargs=fn_kwargs, - fn_constructor_args=fn_constructor_args, - fn_constructor_kwargs=fn_constructor_kwargs, - filter_expr=resolved_expr, - compute=compute, - ray_remote_args_fn=ray_remote_args_fn, - ray_remote_args=ray_remote_args, - ) - logical_plan = LogicalPlan(op, self.context) + logical_plan = LogicalPlan(filter_op, self.context) return Dataset(plan, logical_plan) @PublicAPI(api_group=SSR_API_GROUP) diff --git a/python/ray/data/tests/test_map.py b/python/ray/data/tests/test_map.py index e90d027d242a..3905e24066ce 100644 --- a/python/ray/data/tests/test_map.py +++ b/python/ray/data/tests/test_map.py @@ -714,7 +714,10 @@ def test_filter_mutex( parquet_ds = ray.data.read_parquet(str(parquet_file)) # Filter using lambda (UDF) - with pytest.raises(ValueError, match="Exactly one of 'fn' or 'expr'"): + with pytest.raises( + ValueError, + match="Exactly one of 'fn', 'expr', or 'predicate' must be provided.", + ): parquet_ds.filter( fn=lambda r: r["sepal.length"] > 5.0, expr="sepal.length > 5.0" ) @@ -802,6 +805,332 @@ def test_filter_with_invalid_expression( fake_column_ds.to_pandas() +@pytest.mark.skipif( + get_pyarrow_version() < parse_version("20.0.0"), + reason="predicate expressions require PyArrow >= 20.0.0", +) +@pytest.mark.parametrize( + "predicate_expr, test_data, expected_indices, test_description", + [ + # Simple comparison filters + pytest.param( + col("age") >= 21, + [ + {"age": 20, "name": "Alice"}, + {"age": 21, "name": "Bob"}, + {"age": 25, "name": "Charlie"}, + {"age": 30, "name": "David"}, + ], + [1, 2, 3], # Indices of rows that should remain + "age_greater_equal_filter", + ), + pytest.param( + col("score") > 50, + [ + {"score": 30, "status": "fail"}, + {"score": 50, "status": "borderline"}, + {"score": 70, "status": "pass"}, + {"score": 90, "status": "excellent"}, + ], + [2, 3], + "score_greater_than_filter", + ), + pytest.param( + col("category") == "premium", + [ + {"category": "basic", "price": 10}, + {"category": "premium", "price": 50}, + {"category": "standard", "price": 25}, + {"category": "premium", "price": 75}, + ], + [1, 3], + "equality_string_filter", + ), + # Complex logical filters + pytest.param( + (col("age") >= 18) & (col("active")), + [ + {"age": 17, "active": True}, + {"age": 18, "active": False}, + {"age": 25, "active": True}, + {"age": 30, "active": True}, + ], + [2, 3], + "logical_and_filter", + ), + pytest.param( + (col("status") == "approved") | (col("priority") == "high"), + [ + {"status": "pending", "priority": "low"}, + {"status": "approved", "priority": "low"}, + {"status": "pending", "priority": "high"}, + {"status": "rejected", "priority": "high"}, + ], + [1, 2, 3], + "logical_or_filter", + ), + # Null handling filters + pytest.param( + col("value").is_not_null(), + [ + {"value": None, "id": 1}, + {"value": 0, "id": 2}, + {"value": None, "id": 3}, + {"value": 42, "id": 4}, + ], + [1, 3], + "not_null_filter", + ), + pytest.param( + col("name").is_null(), + [ + {"name": "Alice", "id": 1}, + {"name": None, "id": 2}, + {"name": "Bob", "id": 3}, + {"name": None, "id": 4}, + ], + [1, 3], + "is_null_filter", + ), + # Complex multi-condition filters + pytest.param( + col("value").is_not_null() & (col("value") > 0), + [ + {"value": None, "type": "missing"}, + {"value": -5, "type": "negative"}, + {"value": 0, "type": "zero"}, + {"value": 10, "type": "positive"}, + ], + [3], + "null_aware_positive_filter", + ), + # String operations + pytest.param( + col("name").is_not_null() & (col("name") != "excluded"), + [ + {"name": "included", "id": 1}, + {"name": "excluded", "id": 2}, + {"name": None, "id": 3}, + {"name": "allowed", "id": 4}, + ], + [0, 3], + "string_exclusion_filter", + ), + # Membership operations + pytest.param( + col("category").is_in(["A", "B"]), + [ + {"category": "A", "value": 1}, + {"category": "B", "value": 2}, + {"category": "C", "value": 3}, + {"category": "D", "value": 4}, + {"category": "A", "value": 5}, + ], + [0, 1, 4], + "membership_filter", + ), + # Negation operations + pytest.param( + ~(col("category") == "reject"), + [ + {"category": "accept", "id": 1}, + {"category": "reject", "id": 2}, + {"category": "pending", "id": 3}, + {"category": "reject", "id": 4}, + ], + [0, 2], + "negation_filter", + ), + # Nested complex expressions + pytest.param( + (col("score") >= 50) & (col("grade") != "F") & col("active"), + [ + {"score": 45, "grade": "F", "active": True}, + {"score": 55, "grade": "D", "active": True}, + {"score": 75, "grade": "B", "active": False}, + {"score": 85, "grade": "A", "active": True}, + ], + [1, 3], + "complex_nested_filter", + ), + ], +) +def test_filter_with_predicate_expressions( + ray_start_regular_shared, + predicate_expr, + test_data, + expected_indices, + test_description, + target_max_block_size_infinite_or_default, +): + """Test filter() with Ray Data predicate expressions.""" + # Create dataset from test data + ds = ray.data.from_items(test_data) + + # Apply filter with predicate expression + filtered_ds = ds.filter(predicate=predicate_expr) + + # Convert to list and verify results + result_data = filtered_ds.to_pandas().to_dict("records") + expected_data = [test_data[i] for i in expected_indices] + + # Use pandas testing for consistent comparison + result_df = pd.DataFrame(result_data) + expected_df = pd.DataFrame(expected_data) + + pd.testing.assert_frame_equal( + result_df.reset_index(drop=True), + expected_df.reset_index(drop=True), + check_dtype=False, + ) + + +@pytest.mark.skipif( + get_pyarrow_version() < parse_version("20.0.0"), + reason="predicate expressions require PyArrow >= 20.0.0", +) +def test_filter_predicate_expr_vs_function_consistency( + ray_start_regular_shared, target_max_block_size_infinite_or_default +): + """Test that predicate expressions produce the same results as equivalent functions.""" + test_data = [ + {"age": 20, "score": 85, "active": True}, + {"age": 25, "score": 45, "active": False}, + {"age": 30, "score": 95, "active": True}, + {"age": 18, "score": 60, "active": True}, + ] + + ds = ray.data.from_items(test_data) + + # Test simple comparison + predicate_result = ds.filter(predicate=col("age") >= 21).to_pandas() + function_result = ds.filter(fn=lambda row: row["age"] >= 21).to_pandas() + pd.testing.assert_frame_equal(predicate_result, function_result, check_dtype=False) + + # Test complex logical expression + complex_predicate = (col("age") >= 21) & (col("score") > 80) & col("active") + predicate_result = ds.filter(predicate=complex_predicate).to_pandas() + function_result = ds.filter( + fn=lambda row: row["age"] >= 21 and row["score"] > 80 and row["active"] + ).to_pandas() + pd.testing.assert_frame_equal(predicate_result, function_result, check_dtype=False) + + +@pytest.mark.skipif( + get_pyarrow_version() < parse_version("20.0.0"), + reason="predicate expressions require PyArrow >= 20.0.0", +) +@pytest.mark.parametrize( + "filter_args, expected_error_match", + [ + # Test that exactly one parameter must be provided + pytest.param( + {}, + "Exactly one of 'fn', 'expr', or 'predicate' must be provided", + id="no_parameters", + ), + pytest.param( + {"fn": lambda x: True, "predicate": col("x") > 0}, + "Exactly one of 'fn', 'expr', or 'predicate' must be provided", + id="fn_and_predicate", + ), + pytest.param( + {"expr": "x > 0", "predicate": col("x") > 0}, + "Exactly one of 'fn', 'expr', or 'predicate' must be provided", + id="expr_and_predicate", + ), + pytest.param( + {"fn": lambda x: True, "expr": "x > 0", "predicate": col("x") > 0}, + "Exactly one of 'fn', 'expr', or 'predicate' must be provided", + id="all_three_parameters", + ), + pytest.param( + {"fn": lambda x: True, "expr": "x > 0"}, + "Exactly one of 'fn', 'expr', or 'predicate' must be provided", + id="fn_and_expr", + ), + # Test that predicate is incompatible with function-specific parameters + pytest.param( + {"predicate": col("x") > 0, "fn_args": [1, 2]}, + "when 'predicate' is used, 'fn_args/fn_kwargs' or 'fn_constructor_args/fn_constructor_kwargs' cannot be used", + id="predicate_with_fn_args", + ), + pytest.param( + {"predicate": col("x") > 0, "fn_kwargs": {"key": "value"}}, + "when 'predicate' is used, 'fn_args/fn_kwargs' or 'fn_constructor_args/fn_constructor_kwargs' cannot be used", + id="predicate_with_fn_kwargs", + ), + pytest.param( + {"predicate": col("x") > 0, "fn_constructor_args": [1, 2]}, + "when 'predicate' is used, 'fn_args/fn_kwargs' or 'fn_constructor_args/fn_constructor_kwargs' cannot be used", + id="predicate_with_fn_constructor_args", + ), + pytest.param( + {"predicate": col("x") > 0, "fn_constructor_kwargs": {"key": "value"}}, + "when 'predicate' is used, 'fn_args/fn_kwargs' or 'fn_constructor_args/fn_constructor_kwargs' cannot be used", + id="predicate_with_fn_constructor_kwargs", + ), + ], +) +def test_filter_predicate_parameter_validation( + ray_start_regular_shared, + target_max_block_size_infinite_or_default, + filter_args, + expected_error_match, +): + """Test that filter() properly validates predicate parameter usage.""" + ds = ray.data.from_items([{"x": 1}, {"x": 2}]) + + with pytest.raises(ValueError, match=expected_error_match): + ds.filter(**filter_args) + + +@pytest.mark.skipif( + get_pyarrow_version() < parse_version("20.0.0"), + reason="predicate expressions require PyArrow >= 20.0.0", +) +def test_filter_predicate_with_different_block_formats( + ray_start_regular_shared, target_max_block_size_infinite_or_default +): + """Test that predicate expressions work with different block formats (pandas/arrow).""" + test_data = [ + {"category": "A", "value": 10}, + {"category": "B", "value": 20}, + {"category": "A", "value": 30}, + {"category": "C", "value": 40}, + ] + + # Test with different data sources that produce different block formats + + # From items (typically arrow) + ds_items = ray.data.from_items(test_data) + result_items = ds_items.filter(predicate=col("category") == "A").to_pandas() + + # From pandas (pandas blocks) + df = pd.DataFrame(test_data) + ds_pandas = ray.data.from_pandas([df]) + result_pandas = ds_pandas.filter(predicate=col("category") == "A").to_pandas() + + # Results should be identical (reset indices for comparison) + expected_df = pd.DataFrame( + [ + {"category": "A", "value": 10}, + {"category": "A", "value": 30}, + ] + ) + + pd.testing.assert_frame_equal( + result_items.reset_index(drop=True), + expected_df.reset_index(drop=True), + check_dtype=False, + ) + pd.testing.assert_frame_equal( + result_pandas.reset_index(drop=True), + expected_df.reset_index(drop=True), + check_dtype=False, + ) + + def test_drop_columns( ray_start_regular_shared, tmp_path, target_max_block_size_infinite_or_default ): From 27be0de2d549b18c5187005ba354d20a67c596b7 Mon Sep 17 00:00:00 2001 From: "Goutam V." Date: Fri, 19 Sep 2025 13:41:50 -0700 Subject: [PATCH 02/10] Fix doclint + address comments Signed-off-by: Goutam V. --- .../data/_internal/planner/plan_udf_map_op.py | 1 + python/ray/data/dataset.py | 37 +++++++++---------- python/ray/data/tests/test_map.py | 1 - 3 files changed, 18 insertions(+), 21 deletions(-) diff --git a/python/ray/data/_internal/planner/plan_udf_map_op.py b/python/ray/data/_internal/planner/plan_udf_map_op.py index 16951dc52803..a976c5b62f1e 100644 --- a/python/ray/data/_internal/planner/plan_udf_map_op.py +++ b/python/ray/data/_internal/planner/plan_udf_map_op.py @@ -247,6 +247,7 @@ def filter_block_fn(block: Block) -> Block: except Exception as e: _try_wrap_udf_exception(e) + init_fn = None transform_fn = BatchMapTransformFn( _generate_transform_fn_for_map_batches(filter_block_fn), batch_size=None, diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 20c7a265c214..362a478aacc6 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -1488,14 +1488,15 @@ def filter( Ray (e.g., num_gpus=1 to request GPUs for the map tasks). See :func:`ray.remote` for details. """ - # Ensure exactly one of fn or expr is provided # Ensure exactly one of fn, expr, or predicate is provided provided_params = sum([fn is not None, expr is not None, predicate is not None]) if provided_params != 1: raise ValueError( "Exactly one of 'fn', 'expr', or 'predicate' must be provided." ) - if predicate is not None: + + # Helper function to check for incompatible function parameters + def _check_fn_params_incompatible(param_type): if ( fn_args is not None or fn_kwargs is not None @@ -1503,8 +1504,11 @@ def filter( or fn_constructor_kwargs is not None ): raise ValueError( - "when 'predicate' is used, 'fn_args/fn_kwargs' or 'fn_constructor_args/fn_constructor_kwargs' cannot be used." + f"when '{param_type}' is used, 'fn_args/fn_kwargs' or 'fn_constructor_args/fn_constructor_kwargs' cannot be used." ) + + if predicate is not None: + _check_fn_params_incompatible("predicate") from ray.data._internal.compute import TaskPoolStrategy compute = TaskPoolStrategy(size=concurrency) @@ -1517,15 +1521,7 @@ def filter( ray_remote_args=ray_remote_args, ) elif expr is not None: - if ( - fn_args is not None - or fn_kwargs is not None - or fn_constructor_args is not None - or fn_constructor_kwargs is not None - ): - raise ValueError( - "when 'expr' is used, 'fn_args/fn_kwargs' or 'fn_constructor_args/fn_constructor_kwargs' can not be used." - ) + _check_fn_params_incompatible("expr") from ray.data._internal.compute import TaskPoolStrategy from ray.data._internal.planner.plan_expression.expression_evaluator import ( # noqa: E501 ExpressionEvaluator, @@ -1550,18 +1546,19 @@ def filter( "Use 'expr' instead of 'fn' when possible for performant filters." ) - if callable(fn): - compute = get_compute_strategy( - fn=fn, - fn_constructor_args=fn_constructor_args, - compute=compute, - concurrency=concurrency, - ) - else: + if not callable(fn): raise ValueError( f"fn must be a UserDefinedFunction, but got " f"{type(fn).__name__} instead." ) + + compute = get_compute_strategy( + fn=fn, + fn_constructor_args=fn_constructor_args, + compute=compute, + concurrency=concurrency, + ) + # Create Filter operator with function filter_op = Filter( input_op=self._logical_plan.dag, diff --git a/python/ray/data/tests/test_map.py b/python/ray/data/tests/test_map.py index 19751e29dacf..aec5bb42276a 100644 --- a/python/ray/data/tests/test_map.py +++ b/python/ray/data/tests/test_map.py @@ -713,7 +713,6 @@ def test_filter_mutex( # Filter using lambda (UDF) with pytest.raises( ValueError, - match="Exactly one of 'fn', 'expr', or 'predicate' must be provided.", ): parquet_ds.filter( fn=lambda r: r["sepal.length"] > 5.0, expr="sepal.length > 5.0" From 096051ca07b8b14c850d3a244dee054361062835 Mon Sep 17 00:00:00 2001 From: "Goutam V." Date: Wed, 24 Sep 2025 13:37:31 -0700 Subject: [PATCH 03/10] Address comments Signed-off-by: Goutam V. --- python/ray/data/_internal/arrow_block.py | 3 + .../logical/operators/map_operator.py | 38 +++--- python/ray/data/_internal/pandas_block.py | 4 + .../data/_internal/planner/plan_udf_map_op.py | 52 +++++---- python/ray/data/dataset.py | 108 ++++++++++-------- python/ray/data/tests/test_map.py | 79 +------------ 6 files changed, 115 insertions(+), 169 deletions(-) diff --git a/python/ray/data/_internal/arrow_block.py b/python/ray/data/_internal/arrow_block.py index f16c41bf3e85..55d474e8c4e8 100644 --- a/python/ray/data/_internal/arrow_block.py +++ b/python/ray/data/_internal/arrow_block.py @@ -463,6 +463,9 @@ def iter_rows( def filter(self, predicate_expr: "Expr") -> "pyarrow.Table": """Filter rows based on a predicate expression.""" + if self._table.num_rows == 0: + return self._table + from ray.data._expression_evaluator import eval_expr # Evaluate the expression to get a boolean mask diff --git a/python/ray/data/_internal/logical/operators/map_operator.py b/python/ray/data/_internal/logical/operators/map_operator.py index 5285979d0444..632f347ca9df 100644 --- a/python/ray/data/_internal/logical/operators/map_operator.py +++ b/python/ray/data/_internal/logical/operators/map_operator.py @@ -1,6 +1,7 @@ import functools import inspect import logging +import warnings from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional from ray.data._internal.compute import ComputeStrategy, TaskPoolStrategy @@ -224,29 +225,26 @@ class Filter(AbstractUDFMap): def __init__( self, input_op: LogicalOperator, - fn: Optional[ - UserDefinedFunction - ] = None, # TODO: Deprecate this parameter in favor of predicate_expr - fn_args: Optional[ - Iterable[Any] - ] = None, # TODO: Deprecate this parameter in favor of predicate_expr - fn_kwargs: Optional[ - Dict[str, Any] - ] = None, # TODO: Deprecate this parameter in favor of predicate_expr - fn_constructor_args: Optional[ - Iterable[Any] - ] = None, # TODO: Deprecate this parameter in favor of predicate_expr - fn_constructor_kwargs: Optional[ - Dict[str, Any] - ] = None, # TODO: Deprecate this parameter in favor of predicate_expr predicate_expr: Optional[Expr] = None, - filter_expr: Optional[ - "pa.dataset.Expression" - ] = None, # TODO: Deprecate this parameter in favor of predicate_expr + fn: Optional[UserDefinedFunction] = None, + fn_args: Optional[Iterable[Any]] = None, + fn_kwargs: Optional[Dict[str, Any]] = None, + fn_constructor_args: Optional[Iterable[Any]] = None, + fn_constructor_kwargs: Optional[Dict[str, Any]] = None, + filter_expr: Optional["pa.dataset.Expression"] = None, compute: Optional[ComputeStrategy] = None, ray_remote_args_fn: Optional[Callable[[], Dict[str, Any]]] = None, ray_remote_args: Optional[Dict[str, Any]] = None, ): + # Deprecation warning for filter_expr + if filter_expr is not None: + warnings.warn( + "The 'filter_expr' parameter is deprecated and will be removed in a future version. " + "Use 'predicate_expr' instead.", + FutureWarning, + stacklevel=2, + ) + # Ensure exactly one of fn, filter_expr, or predicate_expr is provided provided_params = sum( [fn is not None, filter_expr is not None, predicate_expr is not None] @@ -256,9 +254,7 @@ def __init__( "Exactly one of 'fn', 'filter_expr', or 'predicate_expr' must be provided" ) - self._filter_expr = ( - filter_expr # TODO: Deprecate this parameter in favor of predicate_expr - ) + self._filter_expr = filter_expr self._predicate_expr = predicate_expr super().__init__( diff --git a/python/ray/data/_internal/pandas_block.py b/python/ray/data/_internal/pandas_block.py index 4b8eaca9a5f5..29f5274246b4 100644 --- a/python/ray/data/_internal/pandas_block.py +++ b/python/ray/data/_internal/pandas_block.py @@ -613,6 +613,10 @@ def iter_rows( def filter(self, predicate_expr: "Expr") -> "pandas.DataFrame": """Filter rows based on a predicate expression.""" + if self._table.empty: + return self._table + + # TODO: Move _expression_evaluator to _internal from ray.data._expression_evaluator import eval_expr # Evaluate the expression to get a boolean mask diff --git a/python/ray/data/_internal/planner/plan_udf_map_op.py b/python/ray/data/_internal/planner/plan_udf_map_op.py index a976c5b62f1e..d1590dba0bfa 100644 --- a/python/ray/data/_internal/planner/plan_udf_map_op.py +++ b/python/ray/data/_internal/planner/plan_udf_map_op.py @@ -51,7 +51,6 @@ from ray.data._internal.output_buffer import OutputBlockSizeOption from ray.data._internal.util import _truncated_repr from ray.data.block import ( - BatchFormat, Block, BlockAccessor, CallableClass, @@ -219,40 +218,43 @@ def plan_filter_op( compute = get_compute(op._compute) if expression is not None: - def filter_batch_fn(block: "pa.Table") -> "pa.Table": - try: - return block.filter(expression) - except Exception as e: - _try_wrap_udf_exception(e) + def filter_block_fn( + blocks: Iterable[Block], ctx: TaskContext + ) -> Iterable[Block]: + for block in blocks: + try: + # Convert block to Arrow table and apply expression filter + if isinstance(block, pa.Table): + filtered_block = block.filter(expression) + else: + # Convert to Arrow first if needed + block_accessor = BlockAccessor.for_block(block) + arrow_block = block_accessor.to_arrow() + filtered_block = arrow_block.filter(expression) + yield filtered_block + except Exception as e: + _try_wrap_udf_exception(e) init_fn = None - transform_fn = BatchMapTransformFn( - _generate_transform_fn_for_map_batches(filter_batch_fn), - batch_size=None, - batch_format=BatchFormat.ARROW, - zero_copy_batch=True, + transform_fn = BlockMapTransformFn( + filter_block_fn, is_udf=True, output_block_size_option=output_block_size_option, ) elif predicate_expr is not None: - # Ray Data expression path using BlockAccessor - def filter_block_fn(block: Block) -> Block: - try: - block_accessor = BlockAccessor.for_block(block) - if not block_accessor.num_rows(): - return block - return block_accessor.filter(predicate_expr) - except Exception as e: - _try_wrap_udf_exception(e) + def filter_block_fn( + blocks: Iterable[Block], ctx: TaskContext + ) -> Iterable[Block]: + for block in blocks: + block_accessor = BlockAccessor.for_block(block) + filtered_block = block_accessor.filter(predicate_expr) + yield filtered_block init_fn = None - transform_fn = BatchMapTransformFn( - _generate_transform_fn_for_map_batches(filter_block_fn), - batch_size=None, - batch_format=BatchFormat.ARROW, - zero_copy_batch=True, + transform_fn = BlockMapTransformFn( + filter_block_fn, is_udf=True, output_block_size_option=output_block_size_option, ) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 362a478aacc6..b29d7f73f894 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -1392,11 +1392,8 @@ def filter( self, fn: Optional[ UserDefinedFunction[Dict[str, Any], bool] - ] = None, # TODO: Deprecate this parameter in favor of predicate - expr: Optional[ - str - ] = None, # TODO: Deprecate this parameter in favor of predicate - predicate: Optional[Expr] = None, + ] = None, # TODO: Deprecate this parameter in favor of expr + expr: Optional[Union[str, Expr]] = None, *, compute: Union[str, ComputeStrategy] = None, fn_args: Optional[Iterable[Any]] = None, @@ -1412,25 +1409,30 @@ def filter( ) -> "Dataset": """Filter out rows that don't satisfy the given predicate. - You can use either a function or a callable class or an expression string to + You can use either a function or a callable class or an expression to perform the transformation. For functions, Ray Data uses stateless Ray tasks. For classes, Ray Data uses stateful Ray actors. For more information, see :ref:`Stateful Transforms `. .. tip:: - If you use the `expr` parameter with a Python expression string, Ray Data + If you use the `expr` parameter with a predicate expression, Ray Data optimizes your filter with native Arrow interfaces. + .. deprecated:: + String expressions are deprecated and will be removed in a future version. + Use predicate expressions from `ray.data.expressions` instead. + Examples: >>> import ray >>> from ray.data.expressions import col >>> ds = ray.data.range(100) - >>> ds.filter(expr="id <= 4").take_all() # Will be deprecated in the future + >>> # String expressions (deprecated - will warn) + >>> ds.filter(expr="id <= 4").take_all() [{'id': 0}, {'id': 1}, {'id': 2}, {'id': 3}, {'id': 4}] - >>> # Using predicate expressions - >>> ds.filter(predicate=(col("id") > 10) & (col("id") < 20)).take_all() + >>> # Using predicate expressions (preferred) + >>> ds.filter(expr=(col("id") > 10) & (col("id") < 20)).take_all() [{'id': 11}, {'id': 12}, {'id': 13}, {'id': 14}, {'id': 15}, {'id': 16}, {'id': 17}, {'id': 18}, {'id': 19}] Time complexity: O(dataset size / parallelism) @@ -1438,9 +1440,9 @@ def filter( Args: fn: The predicate to apply to each row, or a class type that can be instantiated to create such a callable. - expr: An expression string needs to be a valid Python expression that - will be converted to ``pyarrow.dataset.Expression`` type. (Soon to be deprecated) - predicate: An expression that represents a predicate (boolean condition) for filtering. + expr: An expression that represents a predicate (boolean condition) for filtering. + Can be either a string expression (deprecated) or a predicate expression + from `ray.data.expressions`. fn_args: Positional arguments to pass to ``fn`` after the first argument. These arguments are top-level arguments to the underlying Ray task. fn_kwargs: Keyword arguments to pass to ``fn``. These arguments are @@ -1488,12 +1490,10 @@ def filter( Ray (e.g., num_gpus=1 to request GPUs for the map tasks). See :func:`ray.remote` for details. """ - # Ensure exactly one of fn, expr, or predicate is provided - provided_params = sum([fn is not None, expr is not None, predicate is not None]) + # Ensure exactly one of fn or expr is provided + provided_params = sum([fn is not None, expr is not None]) if provided_params != 1: - raise ValueError( - "Exactly one of 'fn', 'expr', or 'predicate' must be provided." - ) + raise ValueError("Exactly one of 'fn' or 'expr' must be provided.") # Helper function to check for incompatible function parameters def _check_fn_params_incompatible(param_type): @@ -1507,40 +1507,50 @@ def _check_fn_params_incompatible(param_type): f"when '{param_type}' is used, 'fn_args/fn_kwargs' or 'fn_constructor_args/fn_constructor_kwargs' cannot be used." ) - if predicate is not None: - _check_fn_params_incompatible("predicate") - from ray.data._internal.compute import TaskPoolStrategy - - compute = TaskPoolStrategy(size=concurrency) - # Create Filter operator with predicate expression - filter_op = Filter( - input_op=self._logical_plan.dag, - predicate_expr=predicate, - compute=compute, - ray_remote_args_fn=ray_remote_args_fn, - ray_remote_args=ray_remote_args, - ) - elif expr is not None: + if expr is not None: _check_fn_params_incompatible("expr") from ray.data._internal.compute import TaskPoolStrategy - from ray.data._internal.planner.plan_expression.expression_evaluator import ( # noqa: E501 - ExpressionEvaluator, - ) - # TODO: (srinathk) bind the expression to the actual schema. - # If fn is a string, convert it to a pyarrow.dataset.Expression - # Initialize ExpressionEvaluator with valid columns, if available - resolved_expr = ExpressionEvaluator.get_filters(expression=expr) + # Check if expr is a string (deprecated) or Expr object + if isinstance(expr, str): + warnings.warn( + "String expressions are deprecated and will be removed in a future version. " + "Use predicate expressions from ray.data.expressions instead. " + "For example: from ray.data.expressions import col; " + "ds.filter(expr=col('column_name') > 5)", + DeprecationWarning, + stacklevel=2, + ) - compute = TaskPoolStrategy(size=concurrency) - # Create Filter operator with string expression - filter_op = Filter( - input_op=self._logical_plan.dag, - filter_expr=resolved_expr, - compute=compute, - ray_remote_args_fn=ray_remote_args_fn, - ray_remote_args=ray_remote_args, - ) + from ray.data._internal.planner.plan_expression.expression_evaluator import ( # noqa: E501 + ExpressionEvaluator, + ) + + # TODO: (srinathk) bind the expression to the actual schema. + # If expr is a string, convert it to a pyarrow.dataset.Expression + # Initialize ExpressionEvaluator with valid columns, if available + resolved_expr = ExpressionEvaluator.get_filters(expression=expr) + + compute = TaskPoolStrategy(size=concurrency) + # Create Filter operator with string expression + filter_op = Filter( + input_op=self._logical_plan.dag, + filter_expr=resolved_expr, + compute=compute, + ray_remote_args_fn=ray_remote_args_fn, + ray_remote_args=ray_remote_args, + ) + else: + # expr is an Expr object (predicate expression) + compute = TaskPoolStrategy(size=concurrency) + # Create Filter operator with predicate expression + filter_op = Filter( + input_op=self._logical_plan.dag, + predicate_expr=expr, + compute=compute, + ray_remote_args_fn=ray_remote_args_fn, + ray_remote_args=ray_remote_args, + ) else: warnings.warn( "Use 'expr' instead of 'fn' when possible for performant filters." @@ -1608,7 +1618,7 @@ def repartition( * When ``num_blocks`` and ``shuffle=True`` are specified Ray Data performs a full distributed shuffle producing exactly ``num_blocks`` blocks. * When ``num_blocks`` and ``shuffle=False`` are specified, Ray Data does NOT perform full shuffle, instead opting in for splitting and combining of the blocks attempting to minimize the necessary data movement (relative to full-blown shuffle). Exactly ``num_blocks`` will be produced. - * If ``target_num_rows_per_block`` is set (exclusive with ``num_blocks`` and ``shuffle``), streaming repartitioning will be executed, where blocks will be made to carry no more than ``target_num_rows_per_block``. Smaller blocks will be combined into bigger ones up to ``target_num_rows_per_block`` as well. + * If ``target_num_rows_per_block`` is set (exclusive with ``num_blocks`` and ``shuffle``), streaming repartitioning will be executed, where blocks will be made to carry no more than ``target_num_rows_per_block`` rows. Smaller blocks will be combined into bigger ones up to ``target_num_rows_per_block`` as well. .. image:: /data/images/dataset-shuffle.svg :align: center diff --git a/python/ray/data/tests/test_map.py b/python/ray/data/tests/test_map.py index aec5bb42276a..8520d38b6564 100644 --- a/python/ray/data/tests/test_map.py +++ b/python/ray/data/tests/test_map.py @@ -964,7 +964,7 @@ def test_filter_with_predicate_expressions( ds = ray.data.from_items(test_data) # Apply filter with predicate expression - filtered_ds = ds.filter(predicate=predicate_expr) + filtered_ds = ds.filter(expr=predicate_expr) # Convert to list and verify results result_data = filtered_ds.to_pandas().to_dict("records") @@ -999,88 +999,19 @@ def test_filter_predicate_expr_vs_function_consistency( ds = ray.data.from_items(test_data) # Test simple comparison - predicate_result = ds.filter(predicate=col("age") >= 21).to_pandas() + predicate_result = ds.filter(expr=col("age") >= 21).to_pandas() function_result = ds.filter(fn=lambda row: row["age"] >= 21).to_pandas() pd.testing.assert_frame_equal(predicate_result, function_result, check_dtype=False) # Test complex logical expression complex_predicate = (col("age") >= 21) & (col("score") > 80) & col("active") - predicate_result = ds.filter(predicate=complex_predicate).to_pandas() + predicate_result = ds.filter(expr=complex_predicate).to_pandas() function_result = ds.filter( fn=lambda row: row["age"] >= 21 and row["score"] > 80 and row["active"] ).to_pandas() pd.testing.assert_frame_equal(predicate_result, function_result, check_dtype=False) -@pytest.mark.skipif( - get_pyarrow_version() < parse_version("20.0.0"), - reason="predicate expressions require PyArrow >= 20.0.0", -) -@pytest.mark.parametrize( - "filter_args, expected_error_match", - [ - # Test that exactly one parameter must be provided - pytest.param( - {}, - "Exactly one of 'fn', 'expr', or 'predicate' must be provided", - id="no_parameters", - ), - pytest.param( - {"fn": lambda x: True, "predicate": col("x") > 0}, - "Exactly one of 'fn', 'expr', or 'predicate' must be provided", - id="fn_and_predicate", - ), - pytest.param( - {"expr": "x > 0", "predicate": col("x") > 0}, - "Exactly one of 'fn', 'expr', or 'predicate' must be provided", - id="expr_and_predicate", - ), - pytest.param( - {"fn": lambda x: True, "expr": "x > 0", "predicate": col("x") > 0}, - "Exactly one of 'fn', 'expr', or 'predicate' must be provided", - id="all_three_parameters", - ), - pytest.param( - {"fn": lambda x: True, "expr": "x > 0"}, - "Exactly one of 'fn', 'expr', or 'predicate' must be provided", - id="fn_and_expr", - ), - # Test that predicate is incompatible with function-specific parameters - pytest.param( - {"predicate": col("x") > 0, "fn_args": [1, 2]}, - "when 'predicate' is used, 'fn_args/fn_kwargs' or 'fn_constructor_args/fn_constructor_kwargs' cannot be used", - id="predicate_with_fn_args", - ), - pytest.param( - {"predicate": col("x") > 0, "fn_kwargs": {"key": "value"}}, - "when 'predicate' is used, 'fn_args/fn_kwargs' or 'fn_constructor_args/fn_constructor_kwargs' cannot be used", - id="predicate_with_fn_kwargs", - ), - pytest.param( - {"predicate": col("x") > 0, "fn_constructor_args": [1, 2]}, - "when 'predicate' is used, 'fn_args/fn_kwargs' or 'fn_constructor_args/fn_constructor_kwargs' cannot be used", - id="predicate_with_fn_constructor_args", - ), - pytest.param( - {"predicate": col("x") > 0, "fn_constructor_kwargs": {"key": "value"}}, - "when 'predicate' is used, 'fn_args/fn_kwargs' or 'fn_constructor_args/fn_constructor_kwargs' cannot be used", - id="predicate_with_fn_constructor_kwargs", - ), - ], -) -def test_filter_predicate_parameter_validation( - ray_start_regular_shared, - target_max_block_size_infinite_or_default, - filter_args, - expected_error_match, -): - """Test that filter() properly validates predicate parameter usage.""" - ds = ray.data.from_items([{"x": 1}, {"x": 2}]) - - with pytest.raises(ValueError, match=expected_error_match): - ds.filter(**filter_args) - - @pytest.mark.skipif( get_pyarrow_version() < parse_version("20.0.0"), reason="predicate expressions require PyArrow >= 20.0.0", @@ -1100,12 +1031,12 @@ def test_filter_predicate_with_different_block_formats( # From items (typically arrow) ds_items = ray.data.from_items(test_data) - result_items = ds_items.filter(predicate=col("category") == "A").to_pandas() + result_items = ds_items.filter(expr=col("category") == "A").to_pandas() # From pandas (pandas blocks) df = pd.DataFrame(test_data) ds_pandas = ray.data.from_pandas([df]) - result_pandas = ds_pandas.filter(predicate=col("category") == "A").to_pandas() + result_pandas = ds_pandas.filter(expr=col("category") == "A").to_pandas() # Results should be identical (reset indices for comparison) expected_df = pd.DataFrame( From 2d84e77dbdc0b3fc24b746fd44cd3e74ae1b3fec Mon Sep 17 00:00:00 2001 From: "Goutam V." Date: Wed, 24 Sep 2025 15:30:07 -0700 Subject: [PATCH 04/10] fix test Signed-off-by: Goutam V. --- python/ray/data/tests/test_execution_optimizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/data/tests/test_execution_optimizer.py b/python/ray/data/tests/test_execution_optimizer.py index 1b8c91474c82..4d1dcb1ae4b4 100644 --- a/python/ray/data/tests/test_execution_optimizer.py +++ b/python/ray/data/tests/test_execution_optimizer.py @@ -301,7 +301,7 @@ def test_filter_operator(ray_start_regular_shared_2_cpus): read_op = get_parquet_read_logical_op() op = Filter( read_op, - lambda x: x, + fn=lambda x: x, ) plan = LogicalPlan(op, ctx) physical_op = planner.plan(plan).dag From 7ff4f9e6ab6867b2ed9040a81c7fb5dfeffd7e30 Mon Sep 17 00:00:00 2001 From: "Goutam V." Date: Fri, 26 Sep 2025 11:33:56 -0700 Subject: [PATCH 05/10] Respond to comments Signed-off-by: Goutam V. --- python/ray/data/BUILD.bazel | 10 + python/ray/data/dataset.py | 69 ++--- python/ray/data/tests/test_filter.py | 388 +++++++++++++++++++++++++++ python/ray/data/tests/test_map.py | 372 ------------------------- 4 files changed, 436 insertions(+), 403 deletions(-) create mode 100644 python/ray/data/tests/test_filter.py diff --git a/python/ray/data/BUILD.bazel b/python/ray/data/BUILD.bazel index 680255466b0c..a1c34b5e0a1d 100644 --- a/python/ray/data/BUILD.bazel +++ b/python/ray/data/BUILD.bazel @@ -700,6 +700,16 @@ py_test( ], ) +py_test( + name = "test_filter", + size = "medium", + srcs = ["tests/test_filter.py"], + tags = [ + "exclusive", + "team:data", + ], +) + py_test( name = "test_numpy", size = "medium", diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index a787ab5551a2..6291efeb0a95 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -1507,6 +1507,21 @@ def _check_fn_params_incompatible(param_type): f"when '{param_type}' is used, 'fn_args/fn_kwargs' or 'fn_constructor_args/fn_constructor_kwargs' cannot be used." ) + # Merge ray remote args early + ray_remote_args = merge_resources_to_ray_remote_args( + num_cpus, + num_gpus, + memory, + ray_remote_args, + ) + + # Prepare common Filter operator arguments + filter_kwargs = { + "input_op": self._logical_plan.dag, + "ray_remote_args_fn": ray_remote_args_fn, + "ray_remote_args": ray_remote_args, + } + if expr is not None: _check_fn_params_incompatible("expr") from ray.data._internal.compute import TaskPoolStrategy @@ -1529,27 +1544,24 @@ def _check_fn_params_incompatible(param_type): # TODO: (srinathk) bind the expression to the actual schema. # If expr is a string, convert it to a pyarrow.dataset.Expression # Initialize ExpressionEvaluator with valid columns, if available + # TODO: Rewrite the converter in Ray Data's Expression system. resolved_expr = ExpressionEvaluator.get_filters(expression=expr) compute = TaskPoolStrategy(size=concurrency) - # Create Filter operator with string expression - filter_op = Filter( - input_op=self._logical_plan.dag, - filter_expr=resolved_expr, - compute=compute, - ray_remote_args_fn=ray_remote_args_fn, - ray_remote_args=ray_remote_args, + filter_kwargs.update( + { + "filter_expr": resolved_expr, + "compute": compute, + } ) else: # expr is an Expr object (predicate expression) compute = TaskPoolStrategy(size=concurrency) - # Create Filter operator with predicate expression - filter_op = Filter( - input_op=self._logical_plan.dag, - predicate_expr=expr, - compute=compute, - ray_remote_args_fn=ray_remote_args_fn, - ray_remote_args=ray_remote_args, + filter_kwargs.update( + { + "predicate_expr": expr, + "compute": compute, + } ) else: warnings.warn( @@ -1569,25 +1581,20 @@ def _check_fn_params_incompatible(param_type): concurrency=concurrency, ) - # Create Filter operator with function - filter_op = Filter( - input_op=self._logical_plan.dag, - fn=fn, - fn_args=fn_args, - fn_kwargs=fn_kwargs, - fn_constructor_args=fn_constructor_args, - fn_constructor_kwargs=fn_constructor_kwargs, - compute=compute, - ray_remote_args_fn=ray_remote_args_fn, - ray_remote_args=ray_remote_args, + filter_kwargs.update( + { + "fn": fn, + "fn_args": fn_args, + "fn_kwargs": fn_kwargs, + "fn_constructor_args": fn_constructor_args, + "fn_constructor_kwargs": fn_constructor_kwargs, + "compute": compute, + } ) - ray_remote_args = merge_resources_to_ray_remote_args( - num_cpus, - num_gpus, - memory, - ray_remote_args, - ) + # Create Filter operator with consolidated arguments + filter_op = Filter(**filter_kwargs) + plan = self._plan.copy() logical_plan = LogicalPlan(filter_op, self.context) return Dataset(plan, logical_plan) diff --git a/python/ray/data/tests/test_filter.py b/python/ray/data/tests/test_filter.py new file mode 100644 index 000000000000..3a6a44f211e0 --- /dev/null +++ b/python/ray/data/tests/test_filter.py @@ -0,0 +1,388 @@ +import pandas as pd +import pyarrow as pa +import pyarrow.parquet as pq +import pytest +from pkg_resources import parse_version + +import ray +from ray.data.exceptions import UserCodeException +from ray.data.expressions import col +from ray.data.tests.conftest import get_pyarrow_version + + +def test_filter_mutex( + ray_start_regular_shared, tmp_path, target_max_block_size_infinite_or_default +): + """Test filter op.""" + + # Generate sample data + data = { + "sepal.length": [4.8, 5.1, 5.7, 6.3, 7.0], + "sepal.width": [3.0, 3.3, 3.5, 3.2, 2.8], + "petal.length": [1.4, 1.7, 4.2, 5.4, 6.1], + "petal.width": [0.2, 0.4, 1.5, 2.1, 2.4], + } + df = pd.DataFrame(data) + + # Define the path for the Parquet file in the tmp_path directory + parquet_file = tmp_path / "sample_data.parquet" + + # Write DataFrame to a Parquet file + table = pa.Table.from_pandas(df) + pq.write_table(table, parquet_file) + + # Load parquet dataset + parquet_ds = ray.data.read_parquet(str(parquet_file)) + + # Filter using lambda (UDF) + with pytest.raises( + ValueError, + ): + parquet_ds.filter( + fn=lambda r: r["sepal.length"] > 5.0, expr="sepal.length > 5.0" + ) + + with pytest.raises(ValueError, match="must be a UserDefinedFunction"): + parquet_ds.filter(fn="sepal.length > 5.0") + + +def test_filter_with_expressions( + ray_start_regular_shared, tmp_path, target_max_block_size_infinite_or_default +): + """Test filtering with expressions.""" + + # Generate sample data + data = { + "sepal.length": [4.8, 5.1, 5.7, 6.3, 7.0], + "sepal.width": [3.0, 3.3, 3.5, 3.2, 2.8], + "petal.length": [1.4, 1.7, 4.2, 5.4, 6.1], + "petal.width": [0.2, 0.4, 1.5, 2.1, 2.4], + } + df = pd.DataFrame(data) + + # Define the path for the Parquet file in the tmp_path directory + parquet_file = tmp_path / "sample_data.parquet" + + # Write DataFrame to a Parquet file + table = pa.Table.from_pandas(df) + pq.write_table(table, parquet_file) + + # Load parquet dataset + parquet_ds = ray.data.read_parquet(str(parquet_file)) + + # Filter using lambda (UDF) + filtered_udf_ds = parquet_ds.filter(lambda r: r["sepal.length"] > 5.0) + filtered_udf_data = filtered_udf_ds.to_pandas() + + # Filter using expressions + filtered_expr_ds = parquet_ds.filter(expr="sepal.length > 5.0") + filtered_expr_data = filtered_expr_ds.to_pandas() + + # Assert the filtered data is the same + assert set(filtered_udf_data["sepal.length"]) == set( + filtered_expr_data["sepal.length"] + ) + assert len(filtered_udf_data) == len(filtered_expr_data) + + # Verify correctness of filtered results: only rows with 'sepal.length' > 5.0 + assert all( + filtered_expr_data["sepal.length"] > 5.0 + ), "Filtered data contains rows with 'sepal.length' <= 5.0" + assert all( + filtered_udf_data["sepal.length"] > 5.0 + ), "UDF-filtered data contains rows with 'sepal.length' <= 5.0" + + +def test_filter_with_invalid_expression( + ray_start_regular_shared, tmp_path, target_max_block_size_infinite_or_default +): + """Test filtering with invalid expressions.""" + + # Generate sample data + data = { + "sepal.length": [4.8, 5.1, 5.7, 6.3, 7.0], + "sepal.width": [3.0, 3.3, 3.5, 3.2, 2.8], + "petal.length": [1.4, 1.7, 4.2, 5.4, 6.1], + "petal.width": [0.2, 0.4, 1.5, 2.1, 2.4], + } + df = pd.DataFrame(data) + + # Define the path for the Parquet file in the tmp_path directory + parquet_file = tmp_path / "sample_data.parquet" + + # Write DataFrame to a Parquet file + table = pa.Table.from_pandas(df) + pq.write_table(table, parquet_file) + + # Load parquet dataset + parquet_ds = ray.data.read_parquet(str(parquet_file)) + + with pytest.raises(ValueError, match="Invalid syntax in the expression"): + parquet_ds.filter(expr="fake_news super fake") + + fake_column_ds = parquet_ds.filter(expr="sepal_length_123 > 1") + with pytest.raises(UserCodeException): + fake_column_ds.to_pandas() + + +@pytest.mark.skipif( + get_pyarrow_version() < parse_version("20.0.0"), + reason="predicate expressions require PyArrow >= 20.0.0", +) +@pytest.mark.parametrize( + "predicate_expr, test_data, expected_indices, test_description", + [ + # Simple comparison filters + pytest.param( + col("age") >= 21, + [ + {"age": 20, "name": "Alice"}, + {"age": 21, "name": "Bob"}, + {"age": 25, "name": "Charlie"}, + {"age": 30, "name": "David"}, + ], + [1, 2, 3], # Indices of rows that should remain + "age_greater_equal_filter", + ), + pytest.param( + col("score") > 50, + [ + {"score": 30, "status": "fail"}, + {"score": 50, "status": "borderline"}, + {"score": 70, "status": "pass"}, + {"score": 90, "status": "excellent"}, + ], + [2, 3], + "score_greater_than_filter", + ), + pytest.param( + col("category") == "premium", + [ + {"category": "basic", "price": 10}, + {"category": "premium", "price": 50}, + {"category": "standard", "price": 25}, + {"category": "premium", "price": 75}, + ], + [1, 3], + "equality_string_filter", + ), + # Complex logical filters + pytest.param( + (col("age") >= 18) & (col("active")), + [ + {"age": 17, "active": True}, + {"age": 18, "active": False}, + {"age": 25, "active": True}, + {"age": 30, "active": True}, + ], + [2, 3], + "logical_and_filter", + ), + pytest.param( + (col("status") == "approved") | (col("priority") == "high"), + [ + {"status": "pending", "priority": "low"}, + {"status": "approved", "priority": "low"}, + {"status": "pending", "priority": "high"}, + {"status": "rejected", "priority": "high"}, + ], + [1, 2, 3], + "logical_or_filter", + ), + # Null handling filters + pytest.param( + col("value").is_not_null(), + [ + {"value": None, "id": 1}, + {"value": 0, "id": 2}, + {"value": None, "id": 3}, + {"value": 42, "id": 4}, + ], + [1, 3], + "not_null_filter", + ), + pytest.param( + col("name").is_null(), + [ + {"name": "Alice", "id": 1}, + {"name": None, "id": 2}, + {"name": "Bob", "id": 3}, + {"name": None, "id": 4}, + ], + [1, 3], + "is_null_filter", + ), + # Complex multi-condition filters + pytest.param( + col("value").is_not_null() & (col("value") > 0), + [ + {"value": None, "type": "missing"}, + {"value": -5, "type": "negative"}, + {"value": 0, "type": "zero"}, + {"value": 10, "type": "positive"}, + ], + [3], + "null_aware_positive_filter", + ), + # String operations + pytest.param( + col("name").is_not_null() & (col("name") != "excluded"), + [ + {"name": "included", "id": 1}, + {"name": "excluded", "id": 2}, + {"name": None, "id": 3}, + {"name": "allowed", "id": 4}, + ], + [0, 3], + "string_exclusion_filter", + ), + # Membership operations + pytest.param( + col("category").is_in(["A", "B"]), + [ + {"category": "A", "value": 1}, + {"category": "B", "value": 2}, + {"category": "C", "value": 3}, + {"category": "D", "value": 4}, + {"category": "A", "value": 5}, + ], + [0, 1, 4], + "membership_filter", + ), + # Negation operations + pytest.param( + ~(col("category") == "reject"), + [ + {"category": "accept", "id": 1}, + {"category": "reject", "id": 2}, + {"category": "pending", "id": 3}, + {"category": "reject", "id": 4}, + ], + [0, 2], + "negation_filter", + ), + # Nested complex expressions + pytest.param( + (col("score") >= 50) & (col("grade") != "F") & col("active"), + [ + {"score": 45, "grade": "F", "active": True}, + {"score": 55, "grade": "D", "active": True}, + {"score": 75, "grade": "B", "active": False}, + {"score": 85, "grade": "A", "active": True}, + ], + [1, 3], + "complex_nested_filter", + ), + ], +) +def test_filter_with_predicate_expressions( + ray_start_regular_shared, + predicate_expr, + test_data, + expected_indices, + test_description, + target_max_block_size_infinite_or_default, +): + """Test filter() with Ray Data predicate expressions.""" + # Create dataset from test data + ds = ray.data.from_items(test_data) + + # Apply filter with predicate expression + filtered_ds = ds.filter(expr=predicate_expr) + + # Convert to list and verify results + result_data = filtered_ds.to_pandas().to_dict("records") + expected_data = [test_data[i] for i in expected_indices] + + # Use pandas testing for consistent comparison + result_df = pd.DataFrame(result_data) + expected_df = pd.DataFrame(expected_data) + + pd.testing.assert_frame_equal( + result_df.reset_index(drop=True), + expected_df.reset_index(drop=True), + check_dtype=False, + ) + + +@pytest.mark.skipif( + get_pyarrow_version() < parse_version("20.0.0"), + reason="predicate expressions require PyArrow >= 20.0.0", +) +def test_filter_predicate_expr_vs_function_consistency( + ray_start_regular_shared, target_max_block_size_infinite_or_default +): + """Test that predicate expressions produce the same results as equivalent functions.""" + test_data = [ + {"age": 20, "score": 85, "active": True}, + {"age": 25, "score": 45, "active": False}, + {"age": 30, "score": 95, "active": True}, + {"age": 18, "score": 60, "active": True}, + ] + + ds = ray.data.from_items(test_data) + + # Test simple comparison + predicate_result = ds.filter(expr=col("age") >= 21).to_pandas() + function_result = ds.filter(fn=lambda row: row["age"] >= 21).to_pandas() + pd.testing.assert_frame_equal(predicate_result, function_result, check_dtype=False) + + # Test complex logical expression + complex_predicate = (col("age") >= 21) & (col("score") > 80) & col("active") + predicate_result = ds.filter(expr=complex_predicate).to_pandas() + function_result = ds.filter( + fn=lambda row: row["age"] >= 21 and row["score"] > 80 and row["active"] + ).to_pandas() + pd.testing.assert_frame_equal(predicate_result, function_result, check_dtype=False) + + +@pytest.mark.skipif( + get_pyarrow_version() < parse_version("20.0.0"), + reason="predicate expressions require PyArrow >= 20.0.0", +) +def test_filter_predicate_with_different_block_formats( + ray_start_regular_shared, target_max_block_size_infinite_or_default +): + """Test that predicate expressions work with different block formats (pandas/arrow).""" + test_data = [ + {"category": "A", "value": 10}, + {"category": "B", "value": 20}, + {"category": "A", "value": 30}, + {"category": "C", "value": 40}, + ] + + # Test with different data sources that produce different block formats + + # From items (typically arrow) + ds_items = ray.data.from_items(test_data) + result_items = ds_items.filter(expr=col("category") == "A").to_pandas() + + # From pandas (pandas blocks) + df = pd.DataFrame(test_data) + ds_pandas = ray.data.from_pandas([df]) + result_pandas = ds_pandas.filter(expr=col("category") == "A").to_pandas() + + # Results should be identical (reset indices for comparison) + expected_df = pd.DataFrame( + [ + {"category": "A", "value": 10}, + {"category": "A", "value": 30}, + ] + ) + + pd.testing.assert_frame_equal( + result_items.reset_index(drop=True), + expected_df.reset_index(drop=True), + check_dtype=False, + ) + pd.testing.assert_frame_equal( + result_pandas.reset_index(drop=True), + expected_df.reset_index(drop=True), + check_dtype=False, + ) + + +if __name__ == "__main__": + import sys + + sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/data/tests/test_map.py b/python/ray/data/tests/test_map.py index 1d907b6a9a5e..d55c36717117 100644 --- a/python/ray/data/tests/test_map.py +++ b/python/ray/data/tests/test_map.py @@ -686,378 +686,6 @@ def test_rename_columns_error_cases( assert str(exc_info.value) == expected_message -def test_filter_mutex( - ray_start_regular_shared, tmp_path, target_max_block_size_infinite_or_default -): - """Test filter op.""" - - # Generate sample data - data = { - "sepal.length": [4.8, 5.1, 5.7, 6.3, 7.0], - "sepal.width": [3.0, 3.3, 3.5, 3.2, 2.8], - "petal.length": [1.4, 1.7, 4.2, 5.4, 6.1], - "petal.width": [0.2, 0.4, 1.5, 2.1, 2.4], - } - df = pd.DataFrame(data) - - # Define the path for the Parquet file in the tmp_path directory - parquet_file = tmp_path / "sample_data.parquet" - - # Write DataFrame to a Parquet file - table = pa.Table.from_pandas(df) - pq.write_table(table, parquet_file) - - # Load parquet dataset - parquet_ds = ray.data.read_parquet(str(parquet_file)) - - # Filter using lambda (UDF) - with pytest.raises( - ValueError, - ): - parquet_ds.filter( - fn=lambda r: r["sepal.length"] > 5.0, expr="sepal.length > 5.0" - ) - - with pytest.raises(ValueError, match="must be a UserDefinedFunction"): - parquet_ds.filter(fn="sepal.length > 5.0") - - -def test_filter_with_expressions( - ray_start_regular_shared, tmp_path, target_max_block_size_infinite_or_default -): - """Test filtering with expressions.""" - - # Generate sample data - data = { - "sepal.length": [4.8, 5.1, 5.7, 6.3, 7.0], - "sepal.width": [3.0, 3.3, 3.5, 3.2, 2.8], - "petal.length": [1.4, 1.7, 4.2, 5.4, 6.1], - "petal.width": [0.2, 0.4, 1.5, 2.1, 2.4], - } - df = pd.DataFrame(data) - - # Define the path for the Parquet file in the tmp_path directory - parquet_file = tmp_path / "sample_data.parquet" - - # Write DataFrame to a Parquet file - table = pa.Table.from_pandas(df) - pq.write_table(table, parquet_file) - - # Load parquet dataset - parquet_ds = ray.data.read_parquet(str(parquet_file)) - - # Filter using lambda (UDF) - filtered_udf_ds = parquet_ds.filter(lambda r: r["sepal.length"] > 5.0) - filtered_udf_data = filtered_udf_ds.to_pandas() - - # Filter using expressions - filtered_expr_ds = parquet_ds.filter(expr="sepal.length > 5.0") - filtered_expr_data = filtered_expr_ds.to_pandas() - - # Assert the filtered data is the same - assert set(filtered_udf_data["sepal.length"]) == set( - filtered_expr_data["sepal.length"] - ) - assert len(filtered_udf_data) == len(filtered_expr_data) - - # Verify correctness of filtered results: only rows with 'sepal.length' > 5.0 - assert all( - filtered_expr_data["sepal.length"] > 5.0 - ), "Filtered data contains rows with 'sepal.length' <= 5.0" - assert all( - filtered_udf_data["sepal.length"] > 5.0 - ), "UDF-filtered data contains rows with 'sepal.length' <= 5.0" - - -def test_filter_with_invalid_expression( - ray_start_regular_shared, tmp_path, target_max_block_size_infinite_or_default -): - """Test filtering with invalid expressions.""" - - # Generate sample data - data = { - "sepal.length": [4.8, 5.1, 5.7, 6.3, 7.0], - "sepal.width": [3.0, 3.3, 3.5, 3.2, 2.8], - "petal.length": [1.4, 1.7, 4.2, 5.4, 6.1], - "petal.width": [0.2, 0.4, 1.5, 2.1, 2.4], - } - df = pd.DataFrame(data) - - # Define the path for the Parquet file in the tmp_path directory - parquet_file = tmp_path / "sample_data.parquet" - - # Write DataFrame to a Parquet file - table = pa.Table.from_pandas(df) - pq.write_table(table, parquet_file) - - # Load parquet dataset - parquet_ds = ray.data.read_parquet(str(parquet_file)) - - with pytest.raises(ValueError, match="Invalid syntax in the expression"): - parquet_ds.filter(expr="fake_news super fake") - - fake_column_ds = parquet_ds.filter(expr="sepal_length_123 > 1") - with pytest.raises(UserCodeException): - fake_column_ds.to_pandas() - - -@pytest.mark.skipif( - get_pyarrow_version() < parse_version("20.0.0"), - reason="predicate expressions require PyArrow >= 20.0.0", -) -@pytest.mark.parametrize( - "predicate_expr, test_data, expected_indices, test_description", - [ - # Simple comparison filters - pytest.param( - col("age") >= 21, - [ - {"age": 20, "name": "Alice"}, - {"age": 21, "name": "Bob"}, - {"age": 25, "name": "Charlie"}, - {"age": 30, "name": "David"}, - ], - [1, 2, 3], # Indices of rows that should remain - "age_greater_equal_filter", - ), - pytest.param( - col("score") > 50, - [ - {"score": 30, "status": "fail"}, - {"score": 50, "status": "borderline"}, - {"score": 70, "status": "pass"}, - {"score": 90, "status": "excellent"}, - ], - [2, 3], - "score_greater_than_filter", - ), - pytest.param( - col("category") == "premium", - [ - {"category": "basic", "price": 10}, - {"category": "premium", "price": 50}, - {"category": "standard", "price": 25}, - {"category": "premium", "price": 75}, - ], - [1, 3], - "equality_string_filter", - ), - # Complex logical filters - pytest.param( - (col("age") >= 18) & (col("active")), - [ - {"age": 17, "active": True}, - {"age": 18, "active": False}, - {"age": 25, "active": True}, - {"age": 30, "active": True}, - ], - [2, 3], - "logical_and_filter", - ), - pytest.param( - (col("status") == "approved") | (col("priority") == "high"), - [ - {"status": "pending", "priority": "low"}, - {"status": "approved", "priority": "low"}, - {"status": "pending", "priority": "high"}, - {"status": "rejected", "priority": "high"}, - ], - [1, 2, 3], - "logical_or_filter", - ), - # Null handling filters - pytest.param( - col("value").is_not_null(), - [ - {"value": None, "id": 1}, - {"value": 0, "id": 2}, - {"value": None, "id": 3}, - {"value": 42, "id": 4}, - ], - [1, 3], - "not_null_filter", - ), - pytest.param( - col("name").is_null(), - [ - {"name": "Alice", "id": 1}, - {"name": None, "id": 2}, - {"name": "Bob", "id": 3}, - {"name": None, "id": 4}, - ], - [1, 3], - "is_null_filter", - ), - # Complex multi-condition filters - pytest.param( - col("value").is_not_null() & (col("value") > 0), - [ - {"value": None, "type": "missing"}, - {"value": -5, "type": "negative"}, - {"value": 0, "type": "zero"}, - {"value": 10, "type": "positive"}, - ], - [3], - "null_aware_positive_filter", - ), - # String operations - pytest.param( - col("name").is_not_null() & (col("name") != "excluded"), - [ - {"name": "included", "id": 1}, - {"name": "excluded", "id": 2}, - {"name": None, "id": 3}, - {"name": "allowed", "id": 4}, - ], - [0, 3], - "string_exclusion_filter", - ), - # Membership operations - pytest.param( - col("category").is_in(["A", "B"]), - [ - {"category": "A", "value": 1}, - {"category": "B", "value": 2}, - {"category": "C", "value": 3}, - {"category": "D", "value": 4}, - {"category": "A", "value": 5}, - ], - [0, 1, 4], - "membership_filter", - ), - # Negation operations - pytest.param( - ~(col("category") == "reject"), - [ - {"category": "accept", "id": 1}, - {"category": "reject", "id": 2}, - {"category": "pending", "id": 3}, - {"category": "reject", "id": 4}, - ], - [0, 2], - "negation_filter", - ), - # Nested complex expressions - pytest.param( - (col("score") >= 50) & (col("grade") != "F") & col("active"), - [ - {"score": 45, "grade": "F", "active": True}, - {"score": 55, "grade": "D", "active": True}, - {"score": 75, "grade": "B", "active": False}, - {"score": 85, "grade": "A", "active": True}, - ], - [1, 3], - "complex_nested_filter", - ), - ], -) -def test_filter_with_predicate_expressions( - ray_start_regular_shared, - predicate_expr, - test_data, - expected_indices, - test_description, - target_max_block_size_infinite_or_default, -): - """Test filter() with Ray Data predicate expressions.""" - # Create dataset from test data - ds = ray.data.from_items(test_data) - - # Apply filter with predicate expression - filtered_ds = ds.filter(expr=predicate_expr) - - # Convert to list and verify results - result_data = filtered_ds.to_pandas().to_dict("records") - expected_data = [test_data[i] for i in expected_indices] - - # Use pandas testing for consistent comparison - result_df = pd.DataFrame(result_data) - expected_df = pd.DataFrame(expected_data) - - pd.testing.assert_frame_equal( - result_df.reset_index(drop=True), - expected_df.reset_index(drop=True), - check_dtype=False, - ) - - -@pytest.mark.skipif( - get_pyarrow_version() < parse_version("20.0.0"), - reason="predicate expressions require PyArrow >= 20.0.0", -) -def test_filter_predicate_expr_vs_function_consistency( - ray_start_regular_shared, target_max_block_size_infinite_or_default -): - """Test that predicate expressions produce the same results as equivalent functions.""" - test_data = [ - {"age": 20, "score": 85, "active": True}, - {"age": 25, "score": 45, "active": False}, - {"age": 30, "score": 95, "active": True}, - {"age": 18, "score": 60, "active": True}, - ] - - ds = ray.data.from_items(test_data) - - # Test simple comparison - predicate_result = ds.filter(expr=col("age") >= 21).to_pandas() - function_result = ds.filter(fn=lambda row: row["age"] >= 21).to_pandas() - pd.testing.assert_frame_equal(predicate_result, function_result, check_dtype=False) - - # Test complex logical expression - complex_predicate = (col("age") >= 21) & (col("score") > 80) & col("active") - predicate_result = ds.filter(expr=complex_predicate).to_pandas() - function_result = ds.filter( - fn=lambda row: row["age"] >= 21 and row["score"] > 80 and row["active"] - ).to_pandas() - pd.testing.assert_frame_equal(predicate_result, function_result, check_dtype=False) - - -@pytest.mark.skipif( - get_pyarrow_version() < parse_version("20.0.0"), - reason="predicate expressions require PyArrow >= 20.0.0", -) -def test_filter_predicate_with_different_block_formats( - ray_start_regular_shared, target_max_block_size_infinite_or_default -): - """Test that predicate expressions work with different block formats (pandas/arrow).""" - test_data = [ - {"category": "A", "value": 10}, - {"category": "B", "value": 20}, - {"category": "A", "value": 30}, - {"category": "C", "value": 40}, - ] - - # Test with different data sources that produce different block formats - - # From items (typically arrow) - ds_items = ray.data.from_items(test_data) - result_items = ds_items.filter(expr=col("category") == "A").to_pandas() - - # From pandas (pandas blocks) - df = pd.DataFrame(test_data) - ds_pandas = ray.data.from_pandas([df]) - result_pandas = ds_pandas.filter(expr=col("category") == "A").to_pandas() - - # Results should be identical (reset indices for comparison) - expected_df = pd.DataFrame( - [ - {"category": "A", "value": 10}, - {"category": "A", "value": 30}, - ] - ) - - pd.testing.assert_frame_equal( - result_items.reset_index(drop=True), - expected_df.reset_index(drop=True), - check_dtype=False, - ) - pd.testing.assert_frame_equal( - result_pandas.reset_index(drop=True), - expected_df.reset_index(drop=True), - check_dtype=False, - ) - - def test_drop_columns( ray_start_regular_shared, tmp_path, target_max_block_size_infinite_or_default ): From 735faa6de0dff97de3e77447fe3ec9c4d391749d Mon Sep 17 00:00:00 2001 From: "Goutam V." Date: Fri, 26 Sep 2025 17:13:31 -0700 Subject: [PATCH 06/10] Cleanup + test fix Signed-off-by: Goutam V. --- .../logical/operators/map_operator.py | 26 +-- .../plan_expression/expression_evaluator.py | 170 ++++++++++++++++++ .../data/_internal/planner/plan_udf_map_op.py | 29 +-- python/ray/data/dataset.py | 6 +- python/ray/data/tests/test_filter.py | 4 +- 5 files changed, 180 insertions(+), 55 deletions(-) diff --git a/python/ray/data/_internal/logical/operators/map_operator.py b/python/ray/data/_internal/logical/operators/map_operator.py index b5b48bd46454..0a734e841abe 100644 --- a/python/ray/data/_internal/logical/operators/map_operator.py +++ b/python/ray/data/_internal/logical/operators/map_operator.py @@ -1,8 +1,7 @@ import functools import inspect import logging -import warnings -from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional +from typing import Any, Callable, Dict, Iterable, List, Optional from ray.data._internal.compute import ComputeStrategy, TaskPoolStrategy from ray.data._internal.logical.interfaces import LogicalOperator @@ -11,10 +10,6 @@ from ray.data.expressions import Expr from ray.data.preprocessor import Preprocessor -if TYPE_CHECKING: - import pyarrow as pa - - logger = logging.getLogger(__name__) @@ -242,30 +237,17 @@ def __init__( fn_kwargs: Optional[Dict[str, Any]] = None, fn_constructor_args: Optional[Iterable[Any]] = None, fn_constructor_kwargs: Optional[Dict[str, Any]] = None, - filter_expr: Optional["pa.dataset.Expression"] = None, compute: Optional[ComputeStrategy] = None, ray_remote_args_fn: Optional[Callable[[], Dict[str, Any]]] = None, ray_remote_args: Optional[Dict[str, Any]] = None, ): - # Deprecation warning for filter_expr - if filter_expr is not None: - warnings.warn( - "The 'filter_expr' parameter is deprecated and will be removed in a future version. " - "Use 'predicate_expr' instead.", - FutureWarning, - stacklevel=2, - ) - - # Ensure exactly one of fn, filter_expr, or predicate_expr is provided - provided_params = sum( - [fn is not None, filter_expr is not None, predicate_expr is not None] - ) + # Ensure exactly one of fn, or predicate_expr is provided + provided_params = sum([fn is not None, predicate_expr is not None]) if provided_params != 1: raise ValueError( - "Exactly one of 'fn', 'filter_expr', or 'predicate_expr' must be provided" + "Exactly one of 'fn', or 'predicate_expr' must be provided" ) - self._filter_expr = filter_expr self._predicate_expr = predicate_expr super().__init__( diff --git a/python/ray/data/_internal/planner/plan_expression/expression_evaluator.py b/python/ray/data/_internal/planner/plan_expression/expression_evaluator.py index 938c2a2d21fc..304c79c4c05c 100644 --- a/python/ray/data/_internal/planner/plan_expression/expression_evaluator.py +++ b/python/ray/data/_internal/planner/plan_expression/expression_evaluator.py @@ -5,6 +5,8 @@ import pyarrow.compute as pc import pyarrow.dataset as ds +from ray.data.expressions import Expr + logger = logging.getLogger(__name__) @@ -36,8 +38,29 @@ def get_filters(expression: str) -> ds.Expression: logger.exception(f"Error processing expression: {e}") raise + @staticmethod + def get_ray_data_expression(expression: str) -> "Expr": + """Parse and evaluate the expression to generate a Ray Data expression. + + Args: + expression: A string representing the filter expression to parse. + + Returns: + A Ray Data Expr object for filtering data. + + """ + try: + tree = ast.parse(expression, mode="eval") + return _ConvertToRayDataExpressionVisitor().visit(tree.body) + except SyntaxError as e: + raise ValueError(f"Invalid syntax in the expression: {expression}") from e + except Exception as e: + logger.exception(f"Error processing expression: {e}") + raise + class _ConvertToArrowExpressionVisitor(ast.NodeVisitor): + # TODO: Deprecate this visitor after we remove string support in filter API. def visit_Compare(self, node: ast.Compare) -> ds.Expression: """Handle comparison operations (e.g., a == b, a < b, a in b). @@ -234,3 +257,150 @@ def visit_Call(self, node: ast.Call) -> ds.Expression: return function_map[func_name](*args) else: raise ValueError(f"Unsupported function: {func_name}") + + +class _ConvertToRayDataExpressionVisitor(ast.NodeVisitor): + """AST visitor that converts string expressions to Ray Data expressions.""" + + def visit_Compare(self, node: ast.Compare) -> "Expr": + """Handle comparison operations (e.g., a == b, a < b, a in b).""" + from ray.data.expressions import BinaryExpr, Operation + + if len(node.ops) != 1 or len(node.comparators) != 1: + raise ValueError("Only simple binary comparisons are supported") + + left = self.visit(node.left) + right = self.visit(node.comparators[0]) + op = node.ops[0] + + # Map AST comparison operators to Ray Data operations + op_map = { + ast.Eq: Operation.EQ, + ast.NotEq: Operation.NE, + ast.Lt: Operation.LT, + ast.LtE: Operation.LE, + ast.Gt: Operation.GT, + ast.GtE: Operation.GE, + ast.In: Operation.IN, + ast.NotIn: Operation.NOT_IN, + } + + if type(op) not in op_map: + raise ValueError(f"Unsupported comparison operator: {type(op).__name__}") + + return BinaryExpr(op_map[type(op)], left, right) + + def visit_BoolOp(self, node: ast.BoolOp) -> "Expr": + """Handle logical operations (e.g., a and b, a or b).""" + from ray.data.expressions import BinaryExpr, Operation + + conditions = [self.visit(value) for value in node.values] + combined_expr = conditions[0] + + for condition in conditions[1:]: + if isinstance(node.op, ast.And): + combined_expr = BinaryExpr(Operation.AND, combined_expr, condition) + elif isinstance(node.op, ast.Or): + combined_expr = BinaryExpr(Operation.OR, combined_expr, condition) + else: + raise ValueError( + f"Unsupported logical operator: {type(node.op).__name__}" + ) + + return combined_expr + + def visit_UnaryOp(self, node: ast.UnaryOp) -> "Expr": + """Handle unary operations (e.g., not a, -5).""" + from ray.data.expressions import Operation, UnaryExpr, lit + + if isinstance(node.op, ast.Not): + operand = self.visit(node.operand) + return UnaryExpr(Operation.NOT, operand) + elif isinstance(node.op, ast.USub): + # For negative numbers, treat as literal + if isinstance(node.operand, ast.Constant): + return lit(-node.operand.value) + else: + raise ValueError("Unary minus only supported for constant values") + else: + raise ValueError(f"Unsupported unary operator: {type(node.op).__name__}") + + def visit_Name(self, node: ast.Name) -> "Expr": + """Handle variable names (column references).""" + from ray.data.expressions import col + + return col(node.id) + + def visit_Constant(self, node: ast.Constant) -> "Expr": + """Handle constant values (numbers, strings, booleans).""" + from ray.data.expressions import lit + + return lit(node.value) + + def visit_List(self, node: ast.List) -> "Expr": + """Handle list literals.""" + from ray.data.expressions import LiteralExpr, lit + + # Visit all elements first + visited_elements = [self.visit(elt) for elt in node.elts] + + # Try to extract constant values for literal list + elements = [] + for elem in visited_elements: + if isinstance(elem, LiteralExpr): + elements.append(elem.value) + else: + # For compatibility with Arrow visitor, we need to support non-literals + # but Ray Data expressions may have limitations here + raise ValueError( + "List contains non-constant expressions. Ray Data expressions " + "currently only support lists of constant values for 'in' operations." + ) + + return lit(elements) + + def visit_Attribute(self, node: ast.Attribute) -> "Expr": + """Handle attribute access (e.g., for nested column names).""" + from ray.data.expressions import col + + # For nested column names like "user.age", combine them with dots + if isinstance(node.value, ast.Name): + return col(f"{node.value.id}.{node.attr}") + elif isinstance(node.value, ast.Attribute): + # Recursively handle nested attributes + left_expr = self.visit(node.value) + if hasattr(left_expr, "_name"): # ColumnExpr + return col(f"{left_expr._name}.{node.attr}") + + raise ValueError(f"Unsupported attribute access: {node.attr}") + + def visit_Call(self, node: ast.Call) -> "Expr": + """Handle function calls for operations like is_null, is_not_null, is_nan.""" + from ray.data.expressions import BinaryExpr, Operation, UnaryExpr + + func_name = node.func.id if isinstance(node.func, ast.Name) else str(node.func) + + if func_name == "is_null": + if len(node.args) != 1: + raise ValueError("is_null() expects exactly one argument") + operand = self.visit(node.args[0]) + return UnaryExpr(Operation.IS_NULL, operand) + elif func_name == "is_valid" or func_name == "is_not_null": + if len(node.args) != 1: + raise ValueError(f"{func_name}() expects exactly one argument") + operand = self.visit(node.args[0]) + return UnaryExpr(Operation.IS_NOT_NULL, operand) + elif func_name == "is_nan": + if len(node.args) != 1: + raise ValueError("is_nan() expects exactly one argument") + operand = self.visit(node.args[0]) + # Use x != x pattern for NaN detection (NaN != NaN is True) + return BinaryExpr(Operation.NE, operand, operand) + elif func_name == "is_in": + if len(node.args) != 2: + raise ValueError("is_in() expects exactly two arguments") + left = self.visit(node.args[0]) + right = self.visit(node.args[1]) + return BinaryExpr(Operation.IN, left, right) + else: + raise ValueError(f"Unsupported function: {func_name}") diff --git a/python/ray/data/_internal/planner/plan_udf_map_op.py b/python/ray/data/_internal/planner/plan_udf_map_op.py index 01504d737702..0578881620ca 100644 --- a/python/ray/data/_internal/planner/plan_udf_map_op.py +++ b/python/ray/data/_internal/planner/plan_udf_map_op.py @@ -217,36 +217,9 @@ def plan_filter_op( target_max_block_size=data_context.target_max_block_size, ) - expression = op._filter_expr predicate_expr = op._predicate_expr compute = get_compute(op._compute) - if expression is not None: - - def filter_block_fn( - blocks: Iterable[Block], ctx: TaskContext - ) -> Iterable[Block]: - for block in blocks: - try: - # Convert block to Arrow table and apply expression filter - if isinstance(block, pa.Table): - filtered_block = block.filter(expression) - else: - # Convert to Arrow first if needed - block_accessor = BlockAccessor.for_block(block) - arrow_block = block_accessor.to_arrow() - filtered_block = arrow_block.filter(expression) - yield filtered_block - except Exception as e: - _try_wrap_udf_exception(e) - - init_fn = None - transform_fn = BlockMapTransformFn( - filter_block_fn, - is_udf=True, - output_block_size_option=output_block_size_option, - ) - - elif predicate_expr is not None: + if predicate_expr is not None: def filter_block_fn( blocks: Iterable[Block], ctx: TaskContext diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 3ad7cd4502c3..203e41e81d68 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -1544,13 +1544,13 @@ def _check_fn_params_incompatible(param_type): # TODO: (srinathk) bind the expression to the actual schema. # If expr is a string, convert it to a pyarrow.dataset.Expression # Initialize ExpressionEvaluator with valid columns, if available - # TODO: Rewrite the converter in Ray Data's Expression system. - resolved_expr = ExpressionEvaluator.get_filters(expression=expr) + # str -> Ray Data's Expression + resolved_expr = ExpressionEvaluator.get_ray_data_expression(expr) compute = TaskPoolStrategy(size=concurrency) filter_kwargs.update( { - "filter_expr": resolved_expr, + "predicate_expr": resolved_expr, "compute": compute, } ) diff --git a/python/ray/data/tests/test_filter.py b/python/ray/data/tests/test_filter.py index 3a6a44f211e0..b2bed51f2d90 100644 --- a/python/ray/data/tests/test_filter.py +++ b/python/ray/data/tests/test_filter.py @@ -5,9 +5,9 @@ from pkg_resources import parse_version import ray -from ray.data.exceptions import UserCodeException from ray.data.expressions import col from ray.data.tests.conftest import get_pyarrow_version +from ray.tests.conftest import * # noqa def test_filter_mutex( @@ -121,7 +121,7 @@ def test_filter_with_invalid_expression( parquet_ds.filter(expr="fake_news super fake") fake_column_ds = parquet_ds.filter(expr="sepal_length_123 > 1") - with pytest.raises(UserCodeException): + with pytest.raises(KeyError): fake_column_ds.to_pandas() From 9884ff8836803833776b54ce28bfa1e009a3c71b Mon Sep 17 00:00:00 2001 From: "Goutam V." Date: Sun, 28 Sep 2025 22:52:58 -0700 Subject: [PATCH 07/10] Fix test Signed-off-by: Goutam V. --- python/ray/data/tests/test_filter.py | 21 +++++---------------- 1 file changed, 5 insertions(+), 16 deletions(-) diff --git a/python/ray/data/tests/test_filter.py b/python/ray/data/tests/test_filter.py index b2bed51f2d90..c7d101745e17 100644 --- a/python/ray/data/tests/test_filter.py +++ b/python/ray/data/tests/test_filter.py @@ -10,9 +10,7 @@ from ray.tests.conftest import * # noqa -def test_filter_mutex( - ray_start_regular_shared, tmp_path, target_max_block_size_infinite_or_default -): +def test_filter_mutex(ray_start_regular_shared, tmp_path): """Test filter op.""" # Generate sample data @@ -46,9 +44,7 @@ def test_filter_mutex( parquet_ds.filter(fn="sepal.length > 5.0") -def test_filter_with_expressions( - ray_start_regular_shared, tmp_path, target_max_block_size_infinite_or_default -): +def test_filter_with_expressions(ray_start_regular_shared, tmp_path): """Test filtering with expressions.""" # Generate sample data @@ -93,9 +89,7 @@ def test_filter_with_expressions( ), "UDF-filtered data contains rows with 'sepal.length' <= 5.0" -def test_filter_with_invalid_expression( - ray_start_regular_shared, tmp_path, target_max_block_size_infinite_or_default -): +def test_filter_with_invalid_expression(ray_start_regular_shared, tmp_path): """Test filtering with invalid expressions.""" # Generate sample data @@ -281,7 +275,6 @@ def test_filter_with_predicate_expressions( test_data, expected_indices, test_description, - target_max_block_size_infinite_or_default, ): """Test filter() with Ray Data predicate expressions.""" # Create dataset from test data @@ -309,9 +302,7 @@ def test_filter_with_predicate_expressions( get_pyarrow_version() < parse_version("20.0.0"), reason="predicate expressions require PyArrow >= 20.0.0", ) -def test_filter_predicate_expr_vs_function_consistency( - ray_start_regular_shared, target_max_block_size_infinite_or_default -): +def test_filter_predicate_expr_vs_function_consistency(ray_start_regular_shared): """Test that predicate expressions produce the same results as equivalent functions.""" test_data = [ {"age": 20, "score": 85, "active": True}, @@ -340,9 +331,7 @@ def test_filter_predicate_expr_vs_function_consistency( get_pyarrow_version() < parse_version("20.0.0"), reason="predicate expressions require PyArrow >= 20.0.0", ) -def test_filter_predicate_with_different_block_formats( - ray_start_regular_shared, target_max_block_size_infinite_or_default -): +def test_filter_predicate_with_different_block_formats(ray_start_regular_shared): """Test that predicate expressions work with different block formats (pandas/arrow).""" test_data = [ {"category": "A", "value": 10}, From 1f6872c0bd163ee704c0f29ef97051248f1464b3 Mon Sep 17 00:00:00 2001 From: "Goutam V." Date: Sun, 28 Sep 2025 23:10:26 -0700 Subject: [PATCH 08/10] Fix bazel file Signed-off-by: Goutam V. --- python/ray/data/BUILD.bazel | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/ray/data/BUILD.bazel b/python/ray/data/BUILD.bazel index a1c34b5e0a1d..af837f255b6c 100644 --- a/python/ray/data/BUILD.bazel +++ b/python/ray/data/BUILD.bazel @@ -708,6 +708,10 @@ py_test( "exclusive", "team:data", ], + deps = [ + ":conftest", + "//:ray_lib", + ], ) py_test( From 86fb5337f25e7cc884c7cde8685906dbac6a5508 Mon Sep 17 00:00:00 2001 From: "Goutam V." Date: Tue, 30 Sep 2025 00:30:38 -0700 Subject: [PATCH 09/10] Address comments Signed-off-by: Goutam V. --- .../logical/operators/map_operator.py | 2 +- .../plan_expression/expression_evaluator.py | 13 ++-- python/ray/data/dataset.py | 68 +++++++++---------- 3 files changed, 40 insertions(+), 43 deletions(-) diff --git a/python/ray/data/_internal/logical/operators/map_operator.py b/python/ray/data/_internal/logical/operators/map_operator.py index 0a734e841abe..63cf35410237 100644 --- a/python/ray/data/_internal/logical/operators/map_operator.py +++ b/python/ray/data/_internal/logical/operators/map_operator.py @@ -245,7 +245,7 @@ def __init__( provided_params = sum([fn is not None, predicate_expr is not None]) if provided_params != 1: raise ValueError( - "Exactly one of 'fn', or 'predicate_expr' must be provided" + f"Exactly one of 'fn', or 'predicate_expr' must be provided (received fn={fn}, predicate_expr={predicate_expr})" ) self._predicate_expr = predicate_expr diff --git a/python/ray/data/_internal/planner/plan_expression/expression_evaluator.py b/python/ray/data/_internal/planner/plan_expression/expression_evaluator.py index 304c79c4c05c..eb39cbeaef1f 100644 --- a/python/ray/data/_internal/planner/plan_expression/expression_evaluator.py +++ b/python/ray/data/_internal/planner/plan_expression/expression_evaluator.py @@ -5,7 +5,7 @@ import pyarrow.compute as pc import pyarrow.dataset as ds -from ray.data.expressions import Expr +from ray.data.expressions import ColumnExpr, Expr logger = logging.getLogger(__name__) @@ -39,7 +39,7 @@ def get_filters(expression: str) -> ds.Expression: raise @staticmethod - def get_ray_data_expression(expression: str) -> "Expr": + def parse_native_expression(expression: str) -> "Expr": """Parse and evaluate the expression to generate a Ray Data expression. Args: @@ -51,7 +51,7 @@ def get_ray_data_expression(expression: str) -> "Expr": """ try: tree = ast.parse(expression, mode="eval") - return _ConvertToRayDataExpressionVisitor().visit(tree.body) + return _ConvertToNativeExpressionVisitor().visit(tree.body) except SyntaxError as e: raise ValueError(f"Invalid syntax in the expression: {expression}") from e except Exception as e: @@ -259,7 +259,7 @@ def visit_Call(self, node: ast.Call) -> ds.Expression: raise ValueError(f"Unsupported function: {func_name}") -class _ConvertToRayDataExpressionVisitor(ast.NodeVisitor): +class _ConvertToNativeExpressionVisitor(ast.NodeVisitor): """AST visitor that converts string expressions to Ray Data expressions.""" def visit_Compare(self, node: ast.Compare) -> "Expr": @@ -354,7 +354,7 @@ def visit_List(self, node: ast.List) -> "Expr": # but Ray Data expressions may have limitations here raise ValueError( "List contains non-constant expressions. Ray Data expressions " - "currently only support lists of constant values for 'in' operations." + "currently only support lists of constant values." ) return lit(elements) @@ -369,7 +369,7 @@ def visit_Attribute(self, node: ast.Attribute) -> "Expr": elif isinstance(node.value, ast.Attribute): # Recursively handle nested attributes left_expr = self.visit(node.value) - if hasattr(left_expr, "_name"): # ColumnExpr + if isinstance(left_expr, ColumnExpr): return col(f"{left_expr._name}.{node.attr}") raise ValueError(f"Unsupported attribute access: {node.attr}") @@ -385,6 +385,7 @@ def visit_Call(self, node: ast.Call) -> "Expr": raise ValueError("is_null() expects exactly one argument") operand = self.visit(node.args[0]) return UnaryExpr(Operation.IS_NULL, operand) + # Adding this conditional to keep it consistent with the existing implementation. elif func_name == "is_valid" or func_name == "is_not_null": if len(node.args) != 1: raise ValueError(f"{func_name}() expects exactly one argument") diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 203e41e81d68..b7d16a33ebb1 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -1515,12 +1515,15 @@ def _check_fn_params_incompatible(param_type): ray_remote_args, ) - # Prepare common Filter operator arguments - filter_kwargs = { - "input_op": self._logical_plan.dag, - "ray_remote_args_fn": ray_remote_args_fn, - "ray_remote_args": ray_remote_args, - } + # Initialize Filter operator arguments with proper types + input_op = self._logical_plan.dag + predicate_expr: Optional[Expr] = None + filter_fn: Optional[UserDefinedFunction] = None + filter_fn_args: Optional[Iterable[Any]] = None + filter_fn_kwargs: Optional[Dict[str, Any]] = None + filter_fn_constructor_args: Optional[Iterable[Any]] = None + filter_fn_constructor_kwargs: Optional[Dict[str, Any]] = None + filter_compute: Optional[ComputeStrategy] = None if expr is not None: _check_fn_params_incompatible("expr") @@ -1545,24 +1548,12 @@ def _check_fn_params_incompatible(param_type): # If expr is a string, convert it to a pyarrow.dataset.Expression # Initialize ExpressionEvaluator with valid columns, if available # str -> Ray Data's Expression - resolved_expr = ExpressionEvaluator.get_ray_data_expression(expr) - - compute = TaskPoolStrategy(size=concurrency) - filter_kwargs.update( - { - "predicate_expr": resolved_expr, - "compute": compute, - } - ) + predicate_expr = ExpressionEvaluator.parse_native_expression(expr) else: # expr is an Expr object (predicate expression) - compute = TaskPoolStrategy(size=concurrency) - filter_kwargs.update( - { - "predicate_expr": expr, - "compute": compute, - } - ) + predicate_expr = expr + + filter_compute = TaskPoolStrategy(size=concurrency) else: warnings.warn( "Use 'expr' instead of 'fn' when possible for performant filters." @@ -1574,26 +1565,31 @@ def _check_fn_params_incompatible(param_type): f"{type(fn).__name__} instead." ) - compute = get_compute_strategy( + filter_fn = fn + filter_fn_args = fn_args + filter_fn_kwargs = fn_kwargs + filter_fn_constructor_args = fn_constructor_args + filter_fn_constructor_kwargs = fn_constructor_kwargs + filter_compute = get_compute_strategy( fn=fn, fn_constructor_args=fn_constructor_args, compute=compute, concurrency=concurrency, ) - filter_kwargs.update( - { - "fn": fn, - "fn_args": fn_args, - "fn_kwargs": fn_kwargs, - "fn_constructor_args": fn_constructor_args, - "fn_constructor_kwargs": fn_constructor_kwargs, - "compute": compute, - } - ) - - # Create Filter operator with consolidated arguments - filter_op = Filter(**filter_kwargs) + # Create Filter operator with explicitly typed arguments + filter_op = Filter( + input_op=input_op, + predicate_expr=predicate_expr, + fn=filter_fn, + fn_args=filter_fn_args, + fn_kwargs=filter_fn_kwargs, + fn_constructor_args=filter_fn_constructor_args, + fn_constructor_kwargs=filter_fn_constructor_kwargs, + compute=filter_compute, + ray_remote_args_fn=ray_remote_args_fn, + ray_remote_args=ray_remote_args, + ) plan = self._plan.copy() logical_plan = LogicalPlan(filter_op, self.context) From 633d5e2e34a3075856e66c012ec477efdfb71995 Mon Sep 17 00:00:00 2001 From: "Goutam V." Date: Tue, 30 Sep 2025 16:50:00 -0700 Subject: [PATCH 10/10] Comments Signed-off-by: Goutam V. --- .../plan_expression/expression_evaluator.py | 14 +++++++------- python/ray/data/dataset.py | 4 +--- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/python/ray/data/_internal/planner/plan_expression/expression_evaluator.py b/python/ray/data/_internal/planner/plan_expression/expression_evaluator.py index eb39cbeaef1f..df2a1c066c3c 100644 --- a/python/ray/data/_internal/planner/plan_expression/expression_evaluator.py +++ b/python/ray/data/_internal/planner/plan_expression/expression_evaluator.py @@ -317,11 +317,8 @@ def visit_UnaryOp(self, node: ast.UnaryOp) -> "Expr": operand = self.visit(node.operand) return UnaryExpr(Operation.NOT, operand) elif isinstance(node.op, ast.USub): - # For negative numbers, treat as literal - if isinstance(node.operand, ast.Constant): - return lit(-node.operand.value) - else: - raise ValueError("Unary minus only supported for constant values") + operand = self.visit(node.operand) + return operand * lit(-1) else: raise ValueError(f"Unsupported unary operator: {type(node.op).__name__}") @@ -372,7 +369,9 @@ def visit_Attribute(self, node: ast.Attribute) -> "Expr": if isinstance(left_expr, ColumnExpr): return col(f"{left_expr._name}.{node.attr}") - raise ValueError(f"Unsupported attribute access: {node.attr}") + raise ValueError( + f"Unsupported attribute access: {node.attr}. Node details: {ast.dump(node)}" + ) def visit_Call(self, node: ast.Call) -> "Expr": """Handle function calls for operations like is_null, is_not_null, is_nan.""" @@ -385,7 +384,8 @@ def visit_Call(self, node: ast.Call) -> "Expr": raise ValueError("is_null() expects exactly one argument") operand = self.visit(node.args[0]) return UnaryExpr(Operation.IS_NULL, operand) - # Adding this conditional to keep it consistent with the existing implementation. + # Adding this conditional to keep it consistent with the current implementation, + # of carrying Pyarrow's semantic of `is_valid` elif func_name == "is_valid" or func_name == "is_not_null": if len(node.args) != 1: raise ValueError(f"{func_name}() expects exactly one argument") diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index b7d16a33ebb1..b70cd4ea6afd 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -1390,9 +1390,7 @@ def duplicate_row(row: Dict[str, Any]) -> List[Dict[str, Any]]: @PublicAPI(api_group=BT_API_GROUP) def filter( self, - fn: Optional[ - UserDefinedFunction[Dict[str, Any], bool] - ] = None, # TODO: Deprecate this parameter in favor of expr + fn: Optional[UserDefinedFunction[Dict[str, Any], bool]] = None, expr: Optional[Union[str, Expr]] = None, *, compute: Union[str, ComputeStrategy] = None,