Skip to content

Commit 8d2d495

Browse files
authored
Preserve the name of grouping sets in SimplifyExpressions (#14888)
Whenever we use `recompute_schema` or `with_exprs_and_inputs`, this ensures that we obtain the same schema.
1 parent f5b7aff commit 8d2d495

File tree

1 file changed

+36
-12
lines changed

1 file changed

+36
-12
lines changed

datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,13 @@
1919
2020
use std::sync::Arc;
2121

22-
use datafusion_common::tree_node::Transformed;
22+
use datafusion_common::tree_node::{Transformed, TreeNode};
2323
use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result};
2424
use datafusion_expr::execution_props::ExecutionProps;
2525
use datafusion_expr::logical_plan::LogicalPlan;
2626
use datafusion_expr::simplify::SimplifyContext;
2727
use datafusion_expr::utils::merge_schema;
28+
use datafusion_expr::Expr;
2829

2930
use crate::optimizer::ApplyOrder;
3031
use crate::utils::NamePreserver;
@@ -122,14 +123,21 @@ impl SimplifyExpressions {
122123

123124
// Preserve expression names to avoid changing the schema of the plan.
124125
let name_preserver = NamePreserver::new(&plan);
125-
plan.map_expressions(|e| {
126-
let original_name = name_preserver.save(&e);
127-
let new_e = simplifier
128-
.simplify(e)
129-
.map(|expr| original_name.restore(expr))?;
126+
let mut rewrite_expr = |expr: Expr| {
127+
let name = name_preserver.save(&expr);
128+
let expr = simplifier.simplify(expr)?;
130129
// TODO it would be nice to have a way to know if the expression was simplified
131130
// or not. For now conservatively return Transformed::yes
132-
Ok(Transformed::yes(new_e))
131+
Ok(Transformed::yes(name.restore(expr)))
132+
};
133+
134+
plan.map_expressions(|expr| {
135+
// Preserve the aliasing of grouping sets.
136+
if let Expr::GroupingSet(_) = &expr {
137+
expr.map_children(&mut rewrite_expr)
138+
} else {
139+
rewrite_expr(expr)
140+
}
133141
})
134142
}
135143
}
@@ -151,11 +159,7 @@ mod tests {
151159
use crate::optimizer::Optimizer;
152160
use datafusion_expr::logical_plan::builder::table_scan_with_filters;
153161
use datafusion_expr::logical_plan::table_scan;
154-
use datafusion_expr::{
155-
and, binary_expr, col, lit, logical_plan::builder::LogicalPlanBuilder, Expr,
156-
ExprSchemable, JoinType,
157-
};
158-
use datafusion_expr::{or, BinaryExpr, Cast, Operator};
162+
use datafusion_expr::*;
159163
use datafusion_functions_aggregate::expr_fn::{max, min};
160164

161165
use crate::test::{assert_fields_eq, test_table_scan_with_name};
@@ -743,4 +747,24 @@ mod tests {
743747

744748
assert_optimized_plan_eq(plan, expected)
745749
}
750+
751+
#[test]
752+
fn simplify_grouping_sets() -> Result<()> {
753+
let table_scan = test_table_scan();
754+
let plan = LogicalPlanBuilder::from(table_scan)
755+
.aggregate(
756+
[grouping_set(vec![
757+
vec![(lit(42).alias("prev") + lit(1)).alias("age"), col("a")],
758+
vec![col("a").or(col("b")).and(lit(1).lt(lit(0))).alias("cond")],
759+
vec![col("d").alias("e"), (lit(1) + lit(2))],
760+
])],
761+
[] as [Expr; 0],
762+
)?
763+
.build()?;
764+
765+
let expected = "Aggregate: groupBy=[[GROUPING SETS ((Int32(43) AS age, test.a), (Boolean(false) AS cond), (test.d AS e, Int32(3) AS Int32(1) + Int32(2)))]], aggr=[[]]\
766+
\n TableScan: test";
767+
768+
assert_optimized_plan_eq(plan, expected)
769+
}
746770
}

0 commit comments

Comments
 (0)