@@ -102,21 +102,18 @@ def _rewrite_to_grouping_sets(
102102 grouping_sets : list [tuple [str , ...]],
103103 segment_col : list [str ],
104104 flag_names : list [str ],
105- ) -> exp . Select :
106- """Rewrite an ibis-compiled aggregation AST into ``GROUP BY GROUPING SETS`` with ``GROUPING()`` flags.
105+ ) -> None :
106+ """Rewrite an ibis-compiled aggregation AST in place into ``GROUP BY GROUPING SETS`` with ``GROUPING()`` flags.
107107
108108 Operates on the sqlglot AST ibis produced (so aggregate expressions, quoting, and dialect quirks
109109 are already correct); swaps the grouping clause and appends one ``GROUPING(col)`` flag column per
110- segment column, used downstream to apply rollup labels. Mutates ``tree`` in place.
110+ segment column, used downstream to apply rollup labels.
111111
112112 Args:
113- tree (exp.Select): The base aggregation SELECT, grouped by all of ``segment_col``.
113+ tree (exp.Select): The base aggregation SELECT, grouped by all of ``segment_col``. Mutated in place.
114114 grouping_sets (list[tuple[str, ...]]): The grouping sets to emit (tuples of column names).
115115 segment_col (list[str]): All segment columns, in the order they appear in the GROUP BY.
116116 flag_names (list[str]): Output names for the per-segment GROUPING() flag columns.
117-
118- Returns:
119- exp.Select: The rewritten SELECT (the same, mutated, ``tree``).
120117 """
121118 resolved = [_resolve_group_key (k , tree .selects ) for k in tree .args ["group" ].expressions ]
122119 key_by_name = dict (zip (segment_col , resolved , strict = True ))
@@ -126,7 +123,6 @@ def _rewrite_to_grouping_sets(
126123
127124 tuples = [exp .Tuple (expressions = [key_by_name [col ].copy () for col in gs ]) for gs in grouping_sets ]
128125 tree .set ("group" , exp .Group (expressions = [exp .GroupingSets (expressions = tuples )]))
129- return tree
130126
131127
132128def cube (* columns : str ) -> list [tuple [str , ...]]:
@@ -875,9 +871,8 @@ def _execute_grouping_sets_native(
875871 tree = compiler .to_sqlglot (base .unbind ())
876872 # to_sqlglot returns a list of statements on some backends (e.g. BigQuery scalar UDFs); the SELECT is last.
877873 tree = tree [- 1 ] if isinstance (tree , list ) else tree
878- native_sql = _rewrite_to_grouping_sets (tree , grouping_sets , segment_col , flag_names ).sql (
879- dialect = compiler .dialect ,
880- )
874+ _rewrite_to_grouping_sets (tree , grouping_sets , segment_col , flag_names )
875+ native_sql = tree .sql (dialect = compiler .dialect )
881876
882877 # Pass the known schema so con.sql skips inference (it mis-types float aggregates on e.g. Oracle).
883878 out_schema = ibis .schema ({** base .schema (), ** dict .fromkeys (flag_names , "int32" )})
0 commit comments