Skip to content

Commit 396d8a9

Browse files
authored
Enhance partition_by to support strings (#1191)
1 parent dbd2c65 commit 396d8a9

File tree

3 files changed

+87
-2
lines changed

3 files changed

+87
-2
lines changed

src/datachain/lib/dc/datachain.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import orjson
2222
import sqlalchemy
2323
from pydantic import BaseModel
24+
from sqlalchemy.sql.elements import ColumnElement
2425
from tqdm import tqdm
2526

2627
from datachain import semver
@@ -806,11 +807,35 @@ def agg_sum(
806807
chain.save("new_dataset")
807808
```
808809
"""
810+
# Convert string partition_by parameters to Column objects
811+
processed_partition_by = partition_by
812+
if partition_by is not None:
813+
if isinstance(partition_by, (str, Function, ColumnElement)):
814+
list_partition_by = [partition_by]
815+
else:
816+
list_partition_by = list(partition_by)
817+
818+
processed_partition_columns: list[ColumnElement] = []
819+
for col in list_partition_by:
820+
if isinstance(col, str):
821+
col_db_name = ColumnMeta.to_db_name(col)
822+
col_type = self.signals_schema.get_column_type(col_db_name)
823+
column = Column(col_db_name, python_to_sql(col_type))
824+
processed_partition_columns.append(column)
825+
elif isinstance(col, Function):
826+
column = col.get_column(self.signals_schema)
827+
processed_partition_columns.append(column)
828+
else:
829+
# Assume it's already a ColumnElement
830+
processed_partition_columns.append(col)
831+
832+
processed_partition_by = processed_partition_columns
833+
809834
udf_obj = self._udf_to_obj(Aggregator, func, params, output, signal_map)
810835
return self._evolve(
811836
query=self._query.generate(
812837
udf_obj.to_udf_wrapper(),
813-
partition_by=partition_by,
838+
partition_by=processed_partition_by,
814839
**self._settings.to_dict(),
815840
),
816841
signal_schema=udf_obj.output,

src/datachain/query/dataset.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,10 @@
8282
INSERT_BATCH_SIZE = 10000
8383

8484
PartitionByType = Union[
85-
Function, ColumnElement, Sequence[Union[Function, ColumnElement]]
85+
str,
86+
Function,
87+
ColumnElement,
88+
Sequence[Union[str, Function, ColumnElement]],
8689
]
8790
JoinPredicateType = Union[str, ColumnClause, ColumnElement]
8891
DatasetDependencyType = tuple["DatasetRecord", str]

tests/unit/lib/test_datachain.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3595,3 +3595,60 @@ def test_save_create_project_not_allowed(test_session, allow_create_project):
35953595
dc.read_values(fib=[1, 1, 2, 3, 5, 8], session=test_session).save(
35963596
"dev.numbers.fibonacci"
35973597
)
3598+
3599+
3600+
def test_agg_partition_by_string_notation(test_session):
    """Check that `agg` accepts a plain column-name string as `partition_by`."""

    class _ImageGroup(BaseModel):
        name: str
        size: int

    def func(key, val) -> Iterator[tuple[File, _ImageGroup]]:
        # `key` and `val` arrive as the collected values of one partition,
        # so the joined key doubles as a readable group label.
        label = "-".join(key)
        yield File(path=label), _ImageGroup(name=label, size=sum(val))

    source = dc.read_values(
        key=["n1", "n2", "n1"],
        val=[1, 5, 9],
        session=test_session,
    )

    # The column is named by a bare string here rather than by C("key").
    grouped = source.agg(x=func, partition_by="key")

    names = grouped.order_by("x_1.name").to_values("x_1.name")
    assert names == ["n1-n1", "n2"]

    sizes = grouped.order_by("x_1.size").to_values("x_1.size")
    assert sizes == [5, 10]
3623+
3624+
3625+
def test_agg_partition_by_string_sequence(test_session):
    """Test that agg method supports sequence of strings for partition_by.

    The fixture rows form three distinct (key1, key2) pairs, so each partition
    holds exactly one row. The expected group names and sums are therefore
    fully determined and are asserted exactly, not just counted.
    """

    class _ImageGroup(BaseModel):
        name: str
        size: int

    def func(key1, key2, val) -> Iterator[tuple[File, _ImageGroup]]:
        # All rows in a partition share the same key1/key2, so the first
        # element of each is enough to label the group.
        n = f"{key1[0]}-{key2[0]}"
        v = sum(val)
        yield File(path=n), _ImageGroup(name=n, size=v)

    key1_values = ["a", "a", "b"]
    key2_values = ["x", "y", "x"]
    values = [1, 5, 9]

    # Test using sequence of strings (NEW functionality)
    ds = dc.read_values(
        key1=key1_values, key2=key2_values, val=values, session=test_session
    ).agg(
        x=func,
        partition_by=["key1", "key2"],  # Sequence of strings
    )

    result_names = ds.order_by("x_1.name").to_values("x_1.name")
    result_sizes = ds.order_by("x_1.size").to_values("x_1.size")

    # Should have 3 partitions: (a,x) -> 1, (a,y) -> 5, (b,x) -> 9.
    # Checking the values (not only the count) catches a regression where
    # only one of the two columns is actually used for partitioning.
    assert result_names == ["a-x", "a-y", "b-x"]
    assert result_sizes == [1, 5, 9]

0 commit comments

Comments
 (0)