19
19
20
20
use std:: sync:: Arc ;
21
21
22
- use datafusion_common:: tree_node:: Transformed ;
22
+ use datafusion_common:: tree_node:: { Transformed , TreeNode } ;
23
23
use datafusion_common:: { DFSchema , DFSchemaRef , DataFusionError , Result } ;
24
24
use datafusion_expr:: execution_props:: ExecutionProps ;
25
25
use datafusion_expr:: logical_plan:: LogicalPlan ;
26
26
use datafusion_expr:: simplify:: SimplifyContext ;
27
27
use datafusion_expr:: utils:: merge_schema;
28
+ use datafusion_expr:: Expr ;
28
29
29
30
use crate :: optimizer:: ApplyOrder ;
30
31
use crate :: utils:: NamePreserver ;
@@ -122,14 +123,21 @@ impl SimplifyExpressions {
122
123
123
124
// Preserve expression names to avoid changing the schema of the plan.
124
125
let name_preserver = NamePreserver :: new ( & plan) ;
125
- plan. map_expressions ( |e| {
126
- let original_name = name_preserver. save ( & e) ;
127
- let new_e = simplifier
128
- . simplify ( e)
129
- . map ( |expr| original_name. restore ( expr) ) ?;
126
+ let mut rewrite_expr = |expr : Expr | {
127
+ let name = name_preserver. save ( & expr) ;
128
+ let expr = simplifier. simplify ( expr) ?;
130
129
// TODO it would be nice to have a way to know if the expression was simplified
131
130
// or not. For now conservatively return Transformed::yes
132
- Ok ( Transformed :: yes ( new_e) )
131
+ Ok ( Transformed :: yes ( name. restore ( expr) ) )
132
+ } ;
133
+
134
+ plan. map_expressions ( |expr| {
135
+ // Preserve the aliasing of grouping sets.
136
+ if let Expr :: GroupingSet ( _) = & expr {
137
+ expr. map_children ( & mut rewrite_expr)
138
+ } else {
139
+ rewrite_expr ( expr)
140
+ }
133
141
} )
134
142
}
135
143
}
@@ -151,11 +159,7 @@ mod tests {
151
159
use crate :: optimizer:: Optimizer ;
152
160
use datafusion_expr:: logical_plan:: builder:: table_scan_with_filters;
153
161
use datafusion_expr:: logical_plan:: table_scan;
154
- use datafusion_expr:: {
155
- and, binary_expr, col, lit, logical_plan:: builder:: LogicalPlanBuilder , Expr ,
156
- ExprSchemable , JoinType ,
157
- } ;
158
- use datafusion_expr:: { or, BinaryExpr , Cast , Operator } ;
162
+ use datafusion_expr:: * ;
159
163
use datafusion_functions_aggregate:: expr_fn:: { max, min} ;
160
164
161
165
use crate :: test:: { assert_fields_eq, test_table_scan_with_name} ;
@@ -743,4 +747,24 @@ mod tests {
743
747
744
748
assert_optimized_plan_eq ( plan, expected)
745
749
}
750
+
751
+ #[ test]
752
+ fn simplify_grouping_sets ( ) -> Result < ( ) > {
753
+ let table_scan = test_table_scan ( ) ;
754
+ let plan = LogicalPlanBuilder :: from ( table_scan)
755
+ . aggregate (
756
+ [ grouping_set ( vec ! [
757
+ vec![ ( lit( 42 ) . alias( "prev" ) + lit( 1 ) ) . alias( "age" ) , col( "a" ) ] ,
758
+ vec![ col( "a" ) . or( col( "b" ) ) . and( lit( 1 ) . lt( lit( 0 ) ) ) . alias( "cond" ) ] ,
759
+ vec![ col( "d" ) . alias( "e" ) , ( lit( 1 ) + lit( 2 ) ) ] ,
760
+ ] ) ] ,
761
+ [ ] as [ Expr ; 0 ] ,
762
+ ) ?
763
+ . build ( ) ?;
764
+
765
+ let expected = "Aggregate: groupBy=[[GROUPING SETS ((Int32(43) AS age, test.a), (Boolean(false) AS cond), (test.d AS e, Int32(3) AS Int32(1) + Int32(2)))]], aggr=[[]]\
766
+ \n TableScan: test";
767
+
768
+ assert_optimized_plan_eq ( plan, expected)
769
+ }
746
770
}
0 commit comments