Skip to content

Commit 3ac303a

Browse files
murray-ds and claude committed
refactor(segstats): make _rewrite_to_grouping_sets a void in-place mutator
The function mutated the sqlglot tree in place but also returned it. Following the usual "in-place -> return None" convention, it now returns None, and the sole caller uses the mutated tree directly.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent f6a9f4d commit 3ac303a

1 file changed

Lines changed: 6 additions & 11 deletions

File tree

openretailscience/segmentation/segstats.py

Lines changed: 6 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -102,21 +102,18 @@ def _rewrite_to_grouping_sets(
102102
grouping_sets: list[tuple[str, ...]],
103103
segment_col: list[str],
104104
flag_names: list[str],
105-
) -> exp.Select:
106-
"""Rewrite an ibis-compiled aggregation AST into ``GROUP BY GROUPING SETS`` with ``GROUPING()`` flags.
105+
) -> None:
106+
"""Rewrite an ibis-compiled aggregation AST in place into ``GROUP BY GROUPING SETS`` with ``GROUPING()`` flags.
107107
108108
Operates on the sqlglot AST ibis produced (so aggregate expressions, quoting, and dialect quirks
109109
are already correct); swaps the grouping clause and appends one ``GROUPING(col)`` flag column per
110-
segment column, used downstream to apply rollup labels. Mutates ``tree`` in place.
110+
segment column, used downstream to apply rollup labels.
111111
112112
Args:
113-
tree (exp.Select): The base aggregation SELECT, grouped by all of ``segment_col``.
113+
tree (exp.Select): The base aggregation SELECT, grouped by all of ``segment_col``. Mutated in place.
114114
grouping_sets (list[tuple[str, ...]]): The grouping sets to emit (tuples of column names).
115115
segment_col (list[str]): All segment columns, in the order they appear in the GROUP BY.
116116
flag_names (list[str]): Output names for the per-segment GROUPING() flag columns.
117-
118-
Returns:
119-
exp.Select: The rewritten SELECT (the same, mutated, ``tree``).
120117
"""
121118
resolved = [_resolve_group_key(k, tree.selects) for k in tree.args["group"].expressions]
122119
key_by_name = dict(zip(segment_col, resolved, strict=True))
@@ -126,7 +123,6 @@ def _rewrite_to_grouping_sets(
126123

127124
tuples = [exp.Tuple(expressions=[key_by_name[col].copy() for col in gs]) for gs in grouping_sets]
128125
tree.set("group", exp.Group(expressions=[exp.GroupingSets(expressions=tuples)]))
129-
return tree
130126

131127

132128
def cube(*columns: str) -> list[tuple[str, ...]]:
@@ -875,9 +871,8 @@ def _execute_grouping_sets_native(
875871
tree = compiler.to_sqlglot(base.unbind())
876872
# to_sqlglot returns a list of statements on some backends (e.g. BigQuery scalar UDFs); the SELECT is last.
877873
tree = tree[-1] if isinstance(tree, list) else tree
878-
native_sql = _rewrite_to_grouping_sets(tree, grouping_sets, segment_col, flag_names).sql(
879-
dialect=compiler.dialect,
880-
)
874+
_rewrite_to_grouping_sets(tree, grouping_sets, segment_col, flag_names)
875+
native_sql = tree.sql(dialect=compiler.dialect)
881876

882877
# Pass the known schema so con.sql skips inference (it mis-types float aggregates on e.g. Oracle).
883878
out_schema = ibis.schema({**base.schema(), **dict.fromkeys(flag_names, "int32")})

0 commit comments

Comments
 (0)