Skip to content

Commit 5cb5dd7

Browse files
[Data] - Add alias expression (ray-project#56550)
<!-- Thank you for your contribution! Please review https://github.com/ray-project/ray/blob/master/CONTRIBUTING.rst before opening a pull request. --> <!-- Please add a reviewer to the assignee section when you create a PR. If you don't have access to it, we will shortly find a reviewer and assign them to your PR. --> ## Why are these changes needed? This change adds support for renaming the resulting expressions via `alias()`. As a result of this change, we can eventually consolidate the Project operator to use only expressions instead of having to support `cols` and `cols_rename` in addition to expressions. ## Related issue number <!-- For example: "Closes ray-project#1234" --> ## Checks - [x] I've signed off every commit (by using the -s flag, i.e., `git commit -s`) in this PR. - [x] I've run `scripts/format.sh` to lint the changes in this PR. - [ ] I've included any doc changes needed for https://docs.ray.io/en/master/. - [ ] I've added any new APIs to the API Reference. For example, if I added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file. - [ ] I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/ - Testing Strategy - [x] Unit tests - [ ] Release tests - [ ] This PR is not tested :( --------- Signed-off-by: Goutam V. <goutam@anyscale.com> Signed-off-by: Future-Outlier <eric901201@gmail.com>
1 parent 15b6365 commit 5cb5dd7

File tree

7 files changed

+226
-14
lines changed

7 files changed

+226
-14
lines changed

python/ray/data/_expression_evaluator.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010

1111
from ray.data.block import DataBatch
1212
from ray.data.expressions import (
13+
AliasExpr,
1314
BinaryExpr,
1415
ColumnExpr,
1516
Expr,
@@ -165,6 +166,10 @@ def _eval_expr_recursive(
165166

166167
return result
167168

169+
if isinstance(expr, AliasExpr):
170+
# The renaming of the column is handled in the project op planner stage.
171+
return _eval_expr_recursive(expr.expr, batch, ops)
172+
168173
raise TypeError(f"Unsupported expression node: {type(expr).__name__}")
169174

170175

python/ray/data/_internal/arrow_block.py

Lines changed: 11 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -206,18 +206,21 @@ def column_names(self) -> List[str]:
206206
return self._table.column_names
207207

208208
def fill_column(self, name: str, value: Any) -> Block:
209-
assert name not in self._table.column_names
210-
211209
import pyarrow.compute as pc
212210

213-
if isinstance(value, pyarrow.Scalar):
214-
type = value.type
211+
# Check if value is array-like - if so, use upsert_column logic
212+
if isinstance(value, (pyarrow.Array, pyarrow.ChunkedArray)):
213+
return self.upsert_column(name, value)
215214
else:
216-
type = pyarrow.infer_type([value])
215+
# Scalar value - use original fill_column logic
216+
if isinstance(value, pyarrow.Scalar):
217+
type = value.type
218+
else:
219+
type = pyarrow.infer_type([value])
217220

218-
array = pyarrow.nulls(len(self._table), type=type)
219-
array = pc.fill_null(array, value)
220-
return self._table.append_column(name, array)
221+
array = pyarrow.nulls(len(self._table), type=type)
222+
array = pc.fill_null(array, value)
223+
return self._table.append_column(name, array)
221224

222225
@classmethod
223226
def from_bytes(cls, data: bytes) -> "ArrowBlockAccessor":

python/ray/data/_internal/pandas_block.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -289,8 +289,10 @@ def column_names(self) -> List[str]:
289289
return self._table.columns.tolist()
290290

291291
def fill_column(self, name: str, value: Any) -> Block:
292-
assert name not in self._table.columns
293-
292+
# Check if value is array-like - if so, use upsert_column logic
293+
if isinstance(value, (pd.Series, np.ndarray)):
294+
return self.upsert_column(name, value)
295+
# Scalar value - use original fill_column logic
294296
return self._table.assign(**{name: value})
295297

296298
@staticmethod

python/ray/data/_internal/planner/plan_udf_map_op.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -130,10 +130,14 @@ def _project_block(block: Block) -> Block:
130130
# Add/update with expression results
131131
result_block = block
132132
for name, expr in exprs.items():
133+
# Use expr.name if available, otherwise fall back to the dict key name
134+
actual_name = expr.name if expr.name is not None else name
133135
result = eval_expr(expr, result_block)
134136
result_block_accessor = BlockAccessor.for_block(result_block)
135-
result_block = result_block_accessor.upsert_column(name, result)
136-
137+
# fill_column handles both scalars and arrays
138+
result_block = result_block_accessor.fill_column(
139+
actual_name, result
140+
)
137141
block = result_block
138142

139143
# 2. (optional) column projection

python/ray/data/expressions.py

Lines changed: 60 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -86,6 +86,16 @@ class Expr(ABC):
8686

8787
data_type: DataType
8888

89+
@property
90+
def name(self) -> str | None:
91+
"""Get the name associated with this expression.
92+
93+
Returns:
94+
The name for expressions that have one (ColumnExpr, AliasExpr),
95+
None otherwise.
96+
"""
97+
return None
98+
8999
@abstractmethod
90100
def structurally_equals(self, other: Any) -> bool:
91101
"""Compare two expression ASTs for structural equality."""
@@ -208,6 +218,27 @@ def not_in(self, values: Union[List[Any], "Expr"]) -> "Expr":
208218
values = LiteralExpr(values)
209219
return self._bin(values, Operation.NOT_IN)
210220

221+
def alias(self, name: str) -> "Expr":
222+
"""Rename the expression.
223+
224+
This method allows you to assign a new name to an expression result.
225+
This is particularly useful when you want to specify the output column name
226+
directly within the expression rather than as a separate parameter.
227+
228+
Args:
229+
name: The new name for the expression
230+
231+
Returns:
232+
An AliasExpr that wraps this expression with the specified name
233+
234+
Example:
235+
>>> from ray.data.expressions import col, lit
236+
>>> # Create an expression with a new aliased name
237+
>>> expr = (col("price") * col("quantity")).alias("total")
238+
>>> # Can be used with Dataset operations that support named expressions
239+
"""
240+
return AliasExpr(data_type=self.data_type, expr=self, _name=name)
241+
211242

212243
@DeveloperAPI(stability="alpha")
213244
@dataclass(frozen=True, eq=False)
@@ -227,9 +258,14 @@ class ColumnExpr(Expr):
227258
>>> age_expr = col("age") # Creates ColumnExpr(name="age")
228259
"""
229260

230-
name: str
261+
_name: str
231262
data_type: DataType = field(default_factory=lambda: DataType(object), init=False)
232263

264+
@property
265+
def name(self) -> str:
266+
"""Get the column name."""
267+
return self._name
268+
233269
def structurally_equals(self, other: Any) -> bool:
234270
return isinstance(other, ColumnExpr) and self.name == other.name
235271

@@ -498,6 +534,27 @@ def structurally_equals(self, other: Any) -> bool:
498534
)
499535

500536

537+
@DeveloperAPI(stability="alpha")
538+
@dataclass(frozen=True, eq=False)
539+
class AliasExpr(Expr):
540+
"""Expression that represents an alias for an expression."""
541+
542+
expr: Expr
543+
_name: str
544+
545+
@property
546+
def name(self) -> str:
547+
"""Get the alias name."""
548+
return self._name
549+
550+
def structurally_equals(self, other: Any) -> bool:
551+
return (
552+
isinstance(other, AliasExpr)
553+
and self.expr.structurally_equals(other.expr)
554+
and self.name == other.name
555+
)
556+
557+
501558
@PublicAPI(stability="beta")
502559
def col(name: str) -> ColumnExpr:
503560
"""
@@ -603,8 +660,9 @@ def download(uri_column_name: str) -> DownloadExpr:
603660
"BinaryExpr",
604661
"UnaryExpr",
605662
"UDFExpr",
606-
"udf",
607663
"DownloadExpr",
664+
"AliasExpr",
665+
"udf",
608666
"col",
609667
"lit",
610668
"download",

python/ray/data/tests/test_expressions.py

Lines changed: 55 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -34,9 +34,64 @@
3434
# Commutative operations are not structurally equal
3535
(col("a") + col("b"), col("b") + col("a"), False),
3636
(lit(1) * col("c"), col("c") * lit(1), False),
37+
# Alias expression tests
38+
(col("a").alias("b"), col("a").alias("b"), True),
39+
(col("a").alias("b"), col("a").alias("c"), False), # Different alias
40+
(col("a").alias("b"), col("b").alias("b"), False), # Different column
41+
((col("a") + 1).alias("result"), (col("a") + 1).alias("result"), True),
42+
(
43+
(col("a") + 1).alias("result"),
44+
(col("a") + 2).alias("result"),
45+
False,
46+
), # Different expr
47+
(col("a").alias("b"), col("a"), False), # Alias vs non-alias
3748
]
3849

3950

51+
@pytest.mark.parametrize(
52+
"expr, alias_name, expected_alias",
53+
[
54+
# (expression, alias_name, expected_alias)
55+
(col("price"), "product_price", "product_price"),
56+
(lit(42), "answer", "answer"),
57+
(col("a") + col("b"), "sum", "sum"),
58+
((col("price") * col("qty")) + lit(5), "total_with_fee", "total_with_fee"),
59+
(col("age") >= lit(18), "is_adult", "is_adult"),
60+
],
61+
ids=["col_alias", "lit_alias", "binary_alias", "complex_alias", "comparison_alias"],
62+
)
63+
def test_alias_functionality(expr, alias_name, expected_alias):
64+
"""Test alias functionality with various expression types."""
65+
import pandas as pd
66+
67+
from ray.data._expression_evaluator import eval_expr
68+
69+
# Test alias creation
70+
aliased_expr = expr.alias(alias_name)
71+
assert aliased_expr.name == expected_alias
72+
assert aliased_expr.expr.structurally_equals(expr)
73+
74+
# Test data type preservation
75+
assert aliased_expr.data_type == expr.data_type
76+
77+
# Test evaluation equivalence
78+
test_data = pd.DataFrame(
79+
{
80+
"price": [10, 20],
81+
"qty": [2, 3],
82+
"a": [1, 2],
83+
"b": [3, 4],
84+
"age": [17, 25],
85+
}
86+
)
87+
original_result = eval_expr(expr, test_data)
88+
aliased_result = eval_expr(aliased_expr, test_data)
89+
if hasattr(original_result, "equals"): # For pandas Series
90+
assert original_result.equals(aliased_result)
91+
else: # For scalars
92+
assert original_result == aliased_result
93+
94+
4095
@pytest.mark.parametrize(
4196
"expr1, expr2, expected",
4297
STRUCTURAL_EQUALITY_TEST_CASES,

python/ray/data/tests/test_map.py

Lines changed: 85 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3248,6 +3248,91 @@ def test_with_column_filter_in_pipeline(ray_start_regular_shared):
32483248
pd.testing.assert_frame_equal(result_df, expected_df, check_dtype=False)
32493249

32503250

3251+
@pytest.mark.parametrize(
3252+
"expr_factory, expected_columns, alias_name, expected_values",
3253+
[
3254+
(
3255+
lambda: col("id").alias("new_id"),
3256+
["id", "new_id"],
3257+
"new_id",
3258+
[0, 1, 2, 3, 4], # Copy of id column
3259+
),
3260+
(
3261+
lambda: (col("id") + 1).alias("id_plus_one"),
3262+
["id", "id_plus_one"],
3263+
"id_plus_one",
3264+
[1, 2, 3, 4, 5], # id + 1
3265+
),
3266+
(
3267+
lambda: (col("id") * 2 + 5).alias("transformed"),
3268+
["id", "transformed"],
3269+
"transformed",
3270+
[5, 7, 9, 11, 13], # id * 2 + 5
3271+
),
3272+
(
3273+
lambda: lit(42).alias("constant"),
3274+
["id", "constant"],
3275+
"constant",
3276+
[42, 42, 42, 42, 42], # lit(42)
3277+
),
3278+
(
3279+
lambda: (col("id") >= 0).alias("is_non_negative"),
3280+
["id", "is_non_negative"],
3281+
"is_non_negative",
3282+
[True, True, True, True, True], # id >= 0
3283+
),
3284+
(
3285+
lambda: (col("id") + 1).alias("id"),
3286+
["id"], # Only one column since we're overwriting id
3287+
"id",
3288+
[1, 2, 3, 4, 5], # id + 1 replaces original id
3289+
),
3290+
],
3291+
ids=[
3292+
"col_alias",
3293+
"arithmetic_alias",
3294+
"complex_alias",
3295+
"literal_alias",
3296+
"comparison_alias",
3297+
"overwrite_existing_column",
3298+
],
3299+
)
3300+
def test_with_column_alias_expressions(
3301+
ray_start_regular_shared,
3302+
expr_factory,
3303+
expected_columns,
3304+
alias_name,
3305+
expected_values,
3306+
):
3307+
"""Test that alias expressions work correctly with with_column."""
3308+
expr = expr_factory()
3309+
3310+
# Verify the alias name matches what we expect
3311+
assert expr.name == alias_name
3312+
3313+
# Apply the aliased expression
3314+
ds = ray.data.range(5).with_column(alias_name, expr)
3315+
3316+
# Convert to pandas for comprehensive comparison
3317+
result_df = ds.to_pandas()
3318+
3319+
# Create expected DataFrame
3320+
expected_df = pd.DataFrame({"id": [0, 1, 2, 3, 4], alias_name: expected_values})
3321+
3322+
# Ensure column order matches expected_columns
3323+
expected_df = expected_df[expected_columns]
3324+
3325+
# Assert the entire DataFrame is equal
3326+
pd.testing.assert_frame_equal(result_df, expected_df)
3327+
# Verify the alias expression evaluates the same as the non-aliased version
3328+
non_aliased_expr = expr
3329+
ds_non_aliased = ray.data.range(5).with_column(alias_name, non_aliased_expr)
3330+
3331+
non_aliased_df = ds_non_aliased.to_pandas()
3332+
3333+
pd.testing.assert_frame_equal(result_df, non_aliased_df)
3334+
3335+
32513336
if __name__ == "__main__":
32523337
import sys
32533338

0 commit comments

Comments
 (0)