Skip to content

Commit 0cfeb95

Browse files
authored
[Data] Compute Expressions-datetime (#58740)
## Description Completing the datetime namespace operations ## Related issues Related to #58674 ## Additional information --------- Signed-off-by: 400Ping <fourhundredping@gmail.com>
1 parent 907fdbb commit 0cfeb95

File tree

3 files changed

+193
-2
lines changed

3 files changed

+193
-2
lines changed

python/ray/data/expressions.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from ray.util.annotations import DeveloperAPI, PublicAPI
2424

2525
if TYPE_CHECKING:
26+
from ray.data.namespace_expressions.dt_namespace import _DatetimeNamespace
2627
from ray.data.namespace_expressions.list_namespace import _ListNamespace
2728
from ray.data.namespace_expressions.string_namespace import _StringNamespace
2829
from ray.data.namespace_expressions.struct_namespace import _StructNamespace
@@ -486,6 +487,13 @@ def struct(self) -> "_StructNamespace":
486487

487488
return _StructNamespace(self)
488489

490+
@property
491+
def dt(self) -> "_DatetimeNamespace":
492+
"""Access datetime operations for this expression."""
493+
from ray.data.namespace_expressions.dt_namespace import _DatetimeNamespace
494+
495+
return _DatetimeNamespace(self)
496+
489497
def _unalias(self) -> "Expr":
490498
return self
491499

@@ -1061,6 +1069,7 @@ def download(uri_column_name: str) -> DownloadExpr:
10611069
"_ListNamespace",
10621070
"_StringNamespace",
10631071
"_StructNamespace",
1072+
"_DatetimeNamespace",
10641073
]
10651074

10661075

@@ -1078,4 +1087,8 @@ def __getattr__(name: str):
10781087
from ray.data.namespace_expressions.struct_namespace import _StructNamespace
10791088

10801089
return _StructNamespace
1090+
elif name == "_DatetimeNamespace":
1091+
from ray.data.namespace_expressions.dt_namespace import _DatetimeNamespace
1092+
1093+
return _DatetimeNamespace
10811094
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
from __future__ import annotations
2+
3+
from dataclasses import dataclass
4+
from typing import TYPE_CHECKING, Callable, Literal
5+
6+
import pyarrow
7+
import pyarrow.compute as pc
8+
9+
from ray.data.datatype import DataType
10+
from ray.data.expressions import pyarrow_udf
11+
12+
if TYPE_CHECKING:
13+
from ray.data.expressions import Expr, UDFExpr
14+
15+
TemporalUnit = Literal[
16+
"year",
17+
"quarter",
18+
"month",
19+
"week",
20+
"day",
21+
"hour",
22+
"minute",
23+
"second",
24+
"millisecond",
25+
"microsecond",
26+
"nanosecond",
27+
]
28+
29+
30+
@dataclass
31+
class _DatetimeNamespace:
32+
"""Datetime namespace for operations on datetime-typed expression columns."""
33+
34+
_expr: "Expr"
35+
36+
def _unary_temporal_int(
37+
self, func: Callable[[pyarrow.Array], pyarrow.Array]
38+
) -> "UDFExpr":
39+
"""Helper for year/month/… that return int32."""
40+
41+
@pyarrow_udf(return_dtype=DataType.int32())
42+
def _udf(arr: pyarrow.Array) -> pyarrow.Array:
43+
return func(arr)
44+
45+
return _udf(self._expr)
46+
47+
# extractors
48+
49+
def year(self) -> "UDFExpr":
50+
"""Extract year component."""
51+
return self._unary_temporal_int(pc.year)
52+
53+
def month(self) -> "UDFExpr":
54+
"""Extract month component."""
55+
return self._unary_temporal_int(pc.month)
56+
57+
def day(self) -> "UDFExpr":
58+
"""Extract day component."""
59+
return self._unary_temporal_int(pc.day)
60+
61+
def hour(self) -> "UDFExpr":
62+
"""Extract hour component."""
63+
return self._unary_temporal_int(pc.hour)
64+
65+
def minute(self) -> "UDFExpr":
66+
"""Extract minute component."""
67+
return self._unary_temporal_int(pc.minute)
68+
69+
def second(self) -> "UDFExpr":
70+
"""Extract second component."""
71+
return self._unary_temporal_int(pc.second)
72+
73+
# formatting
74+
75+
def strftime(self, fmt: str) -> "UDFExpr":
76+
"""Format timestamps with a strftime pattern."""
77+
78+
@pyarrow_udf(return_dtype=DataType.string())
79+
def _format(arr: pyarrow.Array) -> pyarrow.Array:
80+
return pc.strftime(arr, format=fmt)
81+
82+
return _format(self._expr)
83+
84+
# rounding
85+
86+
def ceil(self, unit: TemporalUnit) -> "UDFExpr":
87+
"""Ceil timestamps to the next multiple of the given unit."""
88+
return_dtype = self._expr.data_type
89+
90+
@pyarrow_udf(return_dtype=return_dtype)
91+
def _ceil(arr: pyarrow.Array) -> pyarrow.Array:
92+
return pc.ceil_temporal(arr, multiple=1, unit=unit)
93+
94+
return _ceil(self._expr)
95+
96+
def floor(self, unit: TemporalUnit) -> "UDFExpr":
97+
"""Floor timestamps to the previous multiple of the given unit."""
98+
return_dtype = self._expr.data_type
99+
100+
@pyarrow_udf(return_dtype=return_dtype)
101+
def _floor(arr: pyarrow.Array) -> pyarrow.Array:
102+
return pc.floor_temporal(arr, multiple=1, unit=unit)
103+
104+
return _floor(self._expr)
105+
106+
def round(self, unit: TemporalUnit) -> "UDFExpr":
107+
"""Round timestamps to the nearest multiple of the given unit."""
108+
return_dtype = self._expr.data_type
109+
110+
@pyarrow_udf(return_dtype=return_dtype)
111+
def _round(arr: pyarrow.Array) -> pyarrow.Array:
112+
113+
return pc.round_temporal(arr, multiple=1, unit=unit)
114+
115+
return _round(self._expr)

python/ray/data/tests/test_namespace_expressions.py

Lines changed: 65 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
convenient access to PyArrow compute functions through the expression API.
55
"""
66

7+
import datetime
78
from typing import Any
89

910
import pandas as pd
@@ -457,7 +458,9 @@ def test_struct_nested_field(self, dataset_format):
457458
)
458459
items_data = [
459460
{"user": {"name": "Alice", "address": {"city": "NYC", "zip": "10001"}}},
460-
{"user": {"name": "Bob", "address": {"city": "LA", "zip": "90001"}}},
461+
{
462+
"user": {"name": "Bob", "address": {"city": "LA", "zip": "90001"}},
463+
},
461464
]
462465
ds = _create_dataset(items_data, dataset_format, arrow_table)
463466

@@ -504,7 +507,9 @@ def test_struct_nested_bracket(self, dataset_format):
504507
)
505508
items_data = [
506509
{"user": {"name": "Alice", "address": {"city": "NYC", "zip": "10001"}}},
507-
{"user": {"name": "Bob", "address": {"city": "LA", "zip": "90001"}}},
510+
{
511+
"user": {"name": "Bob", "address": {"city": "LA", "zip": "90001"}},
512+
},
508513
]
509514
ds = _create_dataset(items_data, dataset_format, arrow_table)
510515

@@ -523,6 +528,64 @@ def test_struct_nested_bracket(self, dataset_format):
523528
assert rows_same(result, expected)
524529

525530

531+
# ──────────────────────────────────────
532+
# Datetime Namespace Tests
533+
# ──────────────────────────────────────
534+
535+
536+
def test_datetime_namespace_all_operations(ray_start_regular):
537+
"""Test all datetime namespace operations on a datetime column."""
538+
539+
ts = datetime.datetime(2024, 1, 2, 10, 30, 0)
540+
541+
ds = ray.data.from_items([{"ts": ts}])
542+
543+
result_ds = ds.select(
544+
[
545+
col("ts").dt.year().alias("year"),
546+
col("ts").dt.month().alias("month"),
547+
col("ts").dt.day().alias("day"),
548+
col("ts").dt.hour().alias("hour"),
549+
col("ts").dt.minute().alias("minute"),
550+
col("ts").dt.second().alias("second"),
551+
col("ts").dt.strftime("%Y-%m-%d").alias("date_str"),
552+
col("ts").dt.floor("day").alias("ts_floor"),
553+
col("ts").dt.ceil("day").alias("ts_ceil"),
554+
col("ts").dt.round("day").alias("ts_round"),
555+
]
556+
)
557+
558+
actual = result_ds.to_pandas()
559+
560+
expected = pd.DataFrame(
561+
[
562+
{
563+
"year": 2024,
564+
"month": 1,
565+
"day": 2,
566+
"hour": 10,
567+
"minute": 30,
568+
"second": 0,
569+
"date_str": "2024-01-02",
570+
"ts_floor": datetime.datetime(2024, 1, 2, 0, 0, 0),
571+
"ts_ceil": datetime.datetime(2024, 1, 3, 0, 0, 0),
572+
"ts_round": datetime.datetime(2024, 1, 3, 0, 0, 0),
573+
}
574+
]
575+
)
576+
577+
assert rows_same(actual, expected)
578+
579+
580+
def test_dt_namespace_invalid_dtype_raises(ray_start_regular):
581+
"""Test that dt namespace on non-datetime column raises an error."""
582+
583+
ds = ray.data.from_items([{"value": 1}])
584+
585+
with pytest.raises(Exception):
586+
ds.select(col("value").dt.year()).to_pandas()
587+
588+
526589
# ──────────────────────────────────────
527590
# Integration Tests
528591
# ──────────────────────────────────────

0 commit comments

Comments
 (0)