Skip to content
Merged
3 changes: 2 additions & 1 deletion doc/source/data/api/expressions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,5 @@ instantiate them directly, but you may encounter them when working with expressi
Expr
ColumnExpr
LiteralExpr
BinaryExpr
BinaryExpr
AliasExpr
15 changes: 14 additions & 1 deletion python/ray/data/_expression_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from ray.data.block import DataBatch
from ray.data.expressions import (
AliasExpr,
BinaryExpr,
ColumnExpr,
Expr,
Expand Down Expand Up @@ -57,7 +58,16 @@ def _eval_expr_recursive(
if isinstance(expr, ColumnExpr):
return batch[expr.name]
if isinstance(expr, LiteralExpr):
return expr.value
# Broadcast literal value to match batch size
if isinstance(batch, pd.DataFrame):
# For pandas, create a Series with the literal value repeated
return pd.Series(expr.value, index=batch.index)
elif isinstance(batch, pa.Table):
# For Arrow, create an Array with the literal value repeated
return pa.array([expr.value] * len(batch))
else:
# Fallback for other batch types
return expr.value
if isinstance(expr, BinaryExpr):
return ops[expr.op](
_eval_expr_recursive(expr.left, batch, ops),
Expand All @@ -79,6 +89,9 @@ def _eval_expr_recursive(
)

return result
if isinstance(expr, AliasExpr):
# AliasExpr just evaluates its wrapped expression
return _eval_expr_recursive(expr.expr, batch, ops)
raise TypeError(f"Unsupported expression node: {type(expr).__name__}")


Expand Down
52 changes: 52 additions & 0 deletions python/ray/data/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,57 @@ def __or__(self, other: Any) -> "Expr":
"""Logical OR operator (|)."""
return self._bin(other, Operation.OR)

def alias(self, name: str) -> "AliasExpr":
"""Rename the expression.

This method allows you to assign a new name to an expression result.
This is particularly useful when you want to specify the output column name
directly within the expression rather than as a separate parameter.

Args:
name: The new name for the expression

Returns:
An AliasExpr that wraps this expression with the specified name

Example:
>>> from ray.data.expressions import col, lit
>>> # Create an aliased expression
>>> expr = (col("price") * col("quantity")).alias("total")
>>> # Can be used with Dataset operations that support named expressions
"""
return AliasExpr(expr=self, alias=name, data_type=self.data_type)


@DeveloperAPI(stability="alpha")
@dataclass(frozen=True, eq=False)
class AliasExpr(Expr):
"""Expression that represents renaming another expression.

This expression type wraps another expression and provides it with a new name.
When evaluated, it returns the same values as the wrapped expression but
allows the result to be assigned to a different column name.

Args:
expr: The expression to rename
alias: The new name for the expression

Example:
>>> from ray.data.expressions import col
>>> # Create an aliased expression
>>> expr = col("price").alias("product_price")
"""

expr: Expr
alias: str

def structurally_equals(self, other: Any) -> bool:
return (
isinstance(other, AliasExpr)
and self.expr.structurally_equals(other.expr)
and self.alias == other.alias
)


@DeveloperAPI(stability="alpha")
@dataclass(frozen=True, eq=False)
Expand Down Expand Up @@ -519,6 +570,7 @@ def download(uri_column_name: str) -> DownloadExpr:
"BinaryExpr",
"UDFExpr",
"udf",
"AliasExpr",
"DownloadExpr",
"col",
"lit",
Expand Down
66 changes: 66 additions & 0 deletions python/ray/data/tests/test_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,75 @@
# Commutative operations are not structurally equal
(col("a") + col("b"), col("b") + col("a"), False),
(lit(1) * col("c"), col("c") * lit(1), False),
# Alias expression tests
(col("a").alias("b"), col("a").alias("b"), True),
(col("a").alias("b"), col("a").alias("c"), False), # Different alias
(col("a").alias("b"), col("b").alias("b"), False), # Different column
((col("a") + 1).alias("result"), (col("a") + 1).alias("result"), True),
(
(col("a") + 1).alias("result"),
(col("a") + 2).alias("result"),
False,
), # Different expr
(col("a").alias("b"), col("a"), False), # Alias vs non-alias
]


# Parametrized test cases for alias functionality
ALIAS_TEST_CASES = [
# (expression, alias_name, expected_alias, should_match_original)
(col("price"), "product_price", "product_price", True),
(lit(42), "answer", "answer", True),
(col("a") + col("b"), "sum", "sum", True),
((col("price") * col("qty")) + lit(5), "total_with_fee", "total_with_fee", True),
(col("age") >= lit(18), "is_adult", "is_adult", True),
]


@pytest.mark.parametrize(
"expr, alias_name, expected_alias, should_match_original",
ALIAS_TEST_CASES,
ids=["col_alias", "lit_alias", "binary_alias", "complex_alias", "comparison_alias"],
)
def test_alias_functionality(expr, alias_name, expected_alias, should_match_original):
"""Test alias functionality with various expression types."""
import pandas as pd

from ray.data._expression_evaluator import eval_expr
from ray.data.expressions import AliasExpr

# Test alias creation
aliased_expr = expr.alias(alias_name)
assert isinstance(aliased_expr, AliasExpr)
assert aliased_expr.alias == expected_alias
assert aliased_expr.expr.structurally_equals(expr)

# Test data type preservation
assert aliased_expr.data_type == expr.data_type

# Test evaluation equivalence (if we can create test data)
if should_match_original:
test_data = pd.DataFrame(
{
"price": [10, 20],
"qty": [2, 3],
"a": [1, 2],
"b": [3, 4],
"age": [17, 25],
}
)
try:
original_result = eval_expr(expr, test_data)
aliased_result = eval_expr(aliased_expr, test_data)
if hasattr(original_result, "equals"): # For pandas Series
assert original_result.equals(aliased_result)
else: # For scalars
assert original_result == aliased_result
except (KeyError, TypeError):
# Skip evaluation test if columns don't exist in test data
pass


@pytest.mark.parametrize(
"expr1, expr2, expected",
STRUCTURAL_EQUALITY_TEST_CASES,
Expand Down
72 changes: 72 additions & 0 deletions python/ray/data/tests/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -2711,6 +2711,78 @@ def invalid_int_return(x: pa.Array) -> int:
assert "pandas.Series" in error_message and "numpy.ndarray" in error_message


@pytest.mark.skipif(
get_pyarrow_version() < parse_version("20.0.0"),
reason="with_column requires PyArrow >= 20.0.0",
)
@pytest.mark.parametrize(
"test_case",
[
{
"expr_factory": lambda: col("id").alias("new_id"),
"expected_columns": ["id", "new_id"],
"expected_value": 0, # First row id value
"test_name": "simple_column_alias",
},
{
"expr_factory": lambda: (col("id") + 1).alias("id_plus_one"),
"expected_columns": ["id", "id_plus_one"],
"expected_value": 1, # 0 + 1
"test_name": "arithmetic_expression_alias",
},
{
"expr_factory": lambda: (col("id") * 2 + 5).alias("transformed"),
"expected_columns": ["id", "transformed"],
"expected_value": 5, # 0 * 2 + 5
"test_name": "complex_expression_alias",
},
{
"expr_factory": lambda: lit(42).alias("constant"),
"expected_columns": ["id", "constant"],
"expected_value": 42,
"test_name": "literal_alias",
},
{
"expr_factory": lambda: (col("id") >= 0).alias("is_non_negative"),
"expected_columns": ["id", "is_non_negative"],
"expected_value": True, # 0 >= 0
"test_name": "comparison_alias",
},
],
ids=lambda test_case: test_case["test_name"],
)
def test_with_column_alias_expressions(
ray_start_regular_shared,
test_case,
):
"""Test that alias expressions work correctly with with_column."""
expr = test_case["expr_factory"]()
expected_columns = test_case["expected_columns"]
expected_value = test_case["expected_value"]

# Use alias to determine column name
column_name = expr.alias

# Apply the aliased expression
ds = ray.data.range(5).with_column(column_name, expr)

# Verify schema
result = ds.take(1)[0]
assert set(result.keys()) == set(expected_columns)

# Verify values
assert result["id"] == 0 # First row
assert result[column_name] == expected_value

# Verify the alias expression evaluates the same as the non-aliased version
# by comparing with a dataset that uses the same expression without alias
non_aliased_expr = expr.expr # Get the wrapped expression
ds_non_aliased = ray.data.range(5).with_column(column_name, non_aliased_expr)

result_non_aliased = ds_non_aliased.take(1)[0]
assert result[column_name] == result_non_aliased[column_name]


if __name__ == "__main__":
import sys

Expand Down