Skip to content

Fix invalid schema for unions in ViewTables #15135

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Mar 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions datafusion/expr/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -776,6 +776,7 @@ impl LogicalPlanBuilder {
&missing_cols,
is_distinct,
)?;

let sort_plan = LogicalPlan::Sort(Sort {
expr: normalize_sorts(sorts, &plan)?,
input: Arc::new(plan),
Expand Down
20 changes: 15 additions & 5 deletions datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2812,14 +2812,16 @@ impl Union {
}
}

let mut name_counts: HashMap<String, usize> = HashMap::new();
let union_fields = (0..fields_count)
.map(|i| {
let fields = inputs
.iter()
.map(|input| input.schema().field(i))
.collect::<Vec<_>>();
let first_field = fields[0];
let name = first_field.name();
let base_name = first_field.name().to_string();

let data_type = if loose_types {
// TODO apply type coercion here, or document why it's better to defer
// temporarily use the data type from the left input and later rely on the analyzer to
Expand All @@ -2842,13 +2844,21 @@ impl Union {
)?
};
let nullable = fields.iter().any(|field| field.is_nullable());
let mut field = Field::new(name, data_type.clone(), nullable);

// Generate unique field name
let name = if let Some(count) = name_counts.get_mut(&base_name) {
*count += 1;
format!("{}_{}", base_name, count)
} else {
name_counts.insert(base_name.clone(), 0);
base_name
};

let mut field = Field::new(&name, data_type.clone(), nullable);
let field_metadata =
intersect_maps(fields.iter().map(|field| field.metadata()));
field.set_metadata(field_metadata);
// TODO reusing table reference from the first schema is probably wrong
let table_reference = first_schema.qualified_field(i).0.cloned();
Ok((table_reference, Arc::new(field)))
Ok((None, Arc::new(field)))
})
.collect::<Result<_>>()?;
let union_schema_metadata =
Expand Down
62 changes: 54 additions & 8 deletions datafusion/optimizer/src/analyzer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,10 @@ impl<'a> TypeCoercionRewriter<'a> {
/// Coerce the union’s inputs to a common schema compatible with all inputs.
/// This occurs after wildcard expansion and the coercion of the input expressions.
pub fn coerce_union(union_plan: Union) -> Result<LogicalPlan> {
let union_schema = Arc::new(coerce_union_schema(&union_plan.inputs)?);
let union_schema = Arc::new(coerce_union_schema_with_schema(
&union_plan.inputs,
&union_plan.schema,
)?);
let new_inputs = union_plan
.inputs
.into_iter()
Expand Down Expand Up @@ -930,7 +933,12 @@ fn coerce_case_expression(case: Case, schema: &DFSchema) -> Result<Case> {
/// This method presumes that the wildcard expansion is unneeded, or has already
/// been applied.
pub fn coerce_union_schema(inputs: &[Arc<LogicalPlan>]) -> Result<DFSchema> {
    // The schema of the first input seeds the coercion; only the remaining
    // inputs are handed over as plans to merge, so the first input is not
    // visited twice. Both index expressions panic when `inputs` is empty.
    coerce_union_schema_with_schema(&inputs[1..], inputs[0].schema())
}
fn coerce_union_schema_with_schema(
inputs: &[Arc<LogicalPlan>],
base_schema: &DFSchemaRef,
) -> Result<DFSchema> {
let mut union_datatypes = base_schema
.fields()
.iter()
Expand All @@ -949,7 +957,7 @@ pub fn coerce_union_schema(inputs: &[Arc<LogicalPlan>]) -> Result<DFSchema> {

let mut metadata = base_schema.metadata().clone();

for (i, plan) in inputs.iter().enumerate().skip(1) {
for (i, plan) in inputs.iter().enumerate() {
let plan_schema = plan.schema();
metadata.extend(plan_schema.metadata().clone());

Expand Down Expand Up @@ -989,15 +997,15 @@ pub fn coerce_union_schema(inputs: &[Arc<LogicalPlan>]) -> Result<DFSchema> {
}
}
let union_qualified_fields = izip!(
base_schema.iter(),
base_schema.fields(),
union_datatypes.into_iter(),
union_nullabilities,
union_field_meta.into_iter()
)
.map(|((qualifier, field), datatype, nullable, metadata)| {
.map(|(field, datatype, nullable, metadata)| {
let mut field = Field::new(field.name().clone(), datatype, nullable);
field.set_metadata(metadata);
(qualifier.cloned(), field.into())
(None, field.into())
})
.collect::<Vec<_>>();

Expand Down Expand Up @@ -1041,11 +1049,12 @@ mod test {
use std::sync::Arc;

use arrow::datatypes::DataType::Utf8;
use arrow::datatypes::{DataType, Field, TimeUnit};
use arrow::datatypes::{DataType, Field, Schema, TimeUnit};

use crate::analyzer::type_coercion::{
coerce_case_expression, TypeCoercion, TypeCoercionRewriter,
};
use crate::analyzer::Analyzer;
use crate::test::{assert_analyzed_plan_eq, assert_analyzed_plan_with_config_eq};
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{TransformedResult, TreeNode};
Expand All @@ -1057,9 +1066,10 @@ mod test {
cast, col, create_udaf, is_true, lit, AccumulatorFactoryFunction, AggregateUDF,
BinaryExpr, Case, ColumnarValue, Expr, ExprSchemable, Filter, LogicalPlan,
Operator, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
SimpleAggregateUDF, Subquery, Volatility,
SimpleAggregateUDF, Subquery, Union, Volatility,
};
use datafusion_functions_aggregate::average::AvgAccumulator;
use datafusion_sql::TableReference;

fn empty() -> Arc<LogicalPlan> {
Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
Expand Down Expand Up @@ -1090,6 +1100,42 @@ mod test {
assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)
}

#[test]
fn test_coerce_union() -> Result<()> {
    /// Builds an empty relation with a single non-nullable column `a` of
    /// `column_type`, qualified as `datafusion.test.foo`.
    fn foo_relation(column_type: DataType) -> Arc<LogicalPlan> {
        let arrow_schema = Schema::new(vec![Field::new("a", column_type, false)]);
        let qualified_schema = DFSchema::try_from_qualified_schema(
            TableReference::full("datafusion", "test", "foo"),
            &arrow_schema,
        )
        .unwrap();
        Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
            produce_one_row: false,
            schema: Arc::new(qualified_schema),
        }))
    }

    // Both union inputs carry the same qualifier and column name but differ in
    // type (Int32 on the left, Int64 on the right).
    let union_inputs = vec![
        foo_relation(DataType::Int32),
        foo_relation(DataType::Int64),
    ];
    let union_plan = LogicalPlan::Union(Union::try_new_with_loose_types(union_inputs)?);

    // Run type coercion over the union on its own first.
    let coerced_union = Analyzer::with_rules(vec![Arc::new(TypeCoercion::new())])
        .execute_and_check(union_plan, &ConfigOptions::default(), |_, _| {})?;

    // Then reference the union output by its unqualified name from a parent
    // projection and analyze the whole plan.
    let projection =
        Projection::try_new(vec![col("a")], Arc::new(coerced_union))?;
    let top_level_plan = LogicalPlan::Projection(projection);

    // The expected plan casts the left (Int32) input to Int64 under the union.
    let expected = "Projection: a\n Union\n Projection: CAST(datafusion.test.foo.a AS Int64) AS a\n EmptyRelation\n EmptyRelation";
    assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), top_level_plan, expected)
}

fn coerce_on_output_if_viewtype(plan: LogicalPlan, expected: &str) -> Result<()> {
let mut options = ConfigOptions::default();
options.optimizer.expand_views_at_output = true;
Expand Down
4 changes: 2 additions & 2 deletions datafusion/optimizer/src/propagate_empty_relation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ mod tests {

let plan = LogicalPlanBuilder::from(left).union(right)?.build()?;

let expected = "TableScan: test";
let expected = "Projection: a, b, c\n TableScan: test";
assert_together_optimized_plan(plan, expected, true)
}

Expand Down Expand Up @@ -406,7 +406,7 @@ mod tests {

let plan = LogicalPlanBuilder::from(left).union(right)?.build()?;

let expected = "TableScan: test";
let expected = "Projection: a, b, c\n TableScan: test";
assert_together_optimized_plan(plan, expected, true)
}

Expand Down
2 changes: 1 addition & 1 deletion datafusion/sqllogictest/test_files/limit.slt
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,7 @@ explain select * FROM (
----
logical_plan
01)Limit: skip=4, fetch=10
02)--Sort: ordered_table.c DESC NULLS FIRST, fetch=14
02)--Sort: c DESC NULLS FIRST, fetch=14
03)----Union
04)------Projection: CAST(ordered_table.c AS Int64) AS c
05)--------TableScan: ordered_table projection=[c]
Expand Down
4 changes: 2 additions & 2 deletions datafusion/sqllogictest/test_files/order.slt
Original file line number Diff line number Diff line change
Expand Up @@ -774,7 +774,7 @@ SELECT * FROM v
ORDER BY 1, 2;
----
logical_plan
01)Sort: u.m ASC NULLS LAST, u.t ASC NULLS LAST
01)Sort: m ASC NULLS LAST, t ASC NULLS LAST
02)--Union
03)----SubqueryAlias: u
04)------Projection: Int64(0) AS m, m0.t
Expand Down Expand Up @@ -1248,7 +1248,7 @@ order by d, c, a, a0, b
limit 2;
----
logical_plan
01)Sort: t1.d ASC NULLS LAST, t1.c ASC NULLS LAST, t1.a ASC NULLS LAST, t1.a0 ASC NULLS LAST, t1.b ASC NULLS LAST, fetch=2
01)Sort: d ASC NULLS LAST, c ASC NULLS LAST, a ASC NULLS LAST, a0 ASC NULLS LAST, b ASC NULLS LAST, fetch=2
02)--Union
03)----SubqueryAlias: t1
04)------Projection: ordered_table.b, ordered_table.c, ordered_table.a, Int32(NULL) AS a0, ordered_table.d
Expand Down
2 changes: 1 addition & 1 deletion datafusion/sqllogictest/test_files/type_coercion.slt
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ EXPLAIN SELECT a FROM (select 1 a) x GROUP BY 1
(SELECT a FROM (select 1.1 a) x GROUP BY 1) ORDER BY 1
----
logical_plan
01)Sort: x.a ASC NULLS LAST
01)Sort: a ASC NULLS LAST
02)--Union
03)----Projection: CAST(x.a AS Float64) AS a
04)------Aggregate: groupBy=[[x.a]], aggr=[[]]
Expand Down
65 changes: 60 additions & 5 deletions datafusion/sqllogictest/test_files/union.slt
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ query TT
EXPLAIN SELECT name FROM t1 UNION (SELECT name from t2 UNION SELECT name || '_new' from t2)
----
logical_plan
01)Aggregate: groupBy=[[t1.name]], aggr=[[]]
01)Aggregate: groupBy=[[name]], aggr=[[]]
02)--Union
03)----TableScan: t1 projection=[name]
04)----TableScan: t2 projection=[name]
Expand Down Expand Up @@ -411,7 +411,7 @@ query TT
explain SELECT c1, c9 FROM aggregate_test_100 UNION ALL SELECT c1, c3 FROM aggregate_test_100 ORDER BY c9 DESC LIMIT 5
----
logical_plan
01)Sort: aggregate_test_100.c9 DESC NULLS FIRST, fetch=5
01)Sort: c9 DESC NULLS FIRST, fetch=5
02)--Union
03)----Projection: aggregate_test_100.c1, CAST(aggregate_test_100.c9 AS Decimal128(20, 0)) AS c9
04)------TableScan: aggregate_test_100 projection=[c1, c9]
Expand Down Expand Up @@ -449,7 +449,7 @@ SELECT count(*) FROM (
----
logical_plan
01)Projection: count(Int64(1)) AS count(*)
02)--Aggregate: groupBy=[[t1.name]], aggr=[[count(Int64(1))]]
02)--Aggregate: groupBy=[[name]], aggr=[[count(Int64(1))]]
03)----Union
04)------Aggregate: groupBy=[[t1.name]], aggr=[[]]
05)--------TableScan: t1 projection=[name]
Expand Down Expand Up @@ -601,7 +601,7 @@ UNION ALL
ORDER BY c1
----
logical_plan
01)Sort: t1.c1 ASC NULLS LAST
01)Sort: c1 ASC NULLS LAST
02)--Union
03)----TableScan: t1 projection=[c1]
04)----Projection: t2.c1a AS c1
Expand Down Expand Up @@ -709,6 +709,25 @@ SELECT t1.v2, t1.v0 FROM t2 NATURAL JOIN t1
SELECT t1.v2, t1.v0 FROM t2 NATURAL JOIN t1 WHERE (t1.v2 IS NULL);
----

query IR
SELECT t1.v0, t2.v0 FROM t1,t2
UNION ALL
SELECT t1.v0, t2.v0 FROM t1,t2
ORDER BY v0;
----
-1493773377 0.280145772929
-1493773377 0.280145772929
-1229445667 0.280145772929
-1229445667 0.280145772929
1541512604 0.280145772929
1541512604 0.280145772929
NULL 0.280145772929
NULL 0.280145772929
NULL 0.280145772929
NULL 0.280145772929
NULL 0.280145772929
NULL 0.280145772929

statement ok
CREATE TABLE t3 (
id INT
Expand Down Expand Up @@ -814,7 +833,7 @@ UNION ALL
ORDER BY c1
----
logical_plan
01)Sort: aggregate_test_100.c1 ASC NULLS LAST
01)Sort: c1 ASC NULLS LAST
02)--Union
03)----Filter: aggregate_test_100.c1 = Utf8("a")
04)------TableScan: aggregate_test_100 projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13], partial_filters=[aggregate_test_100.c1 = Utf8("a")]
Expand Down Expand Up @@ -860,3 +879,39 @@ FROM (
GROUP BY combined
----
AB


# Test union in view
statement ok
CREATE TABLE u1 (x INT, y INT);

statement ok
INSERT INTO u1 VALUES (3, 3), (3, 3), (1, 1);

statement ok
CREATE TABLE u2 (y BIGINT, z BIGINT);

statement ok
INSERT INTO u2 VALUES (20, 20), (40, 40);

statement ok
CREATE VIEW v1 AS
SELECT y FROM u1 UNION ALL SELECT y FROM u2 ORDER BY y;

query I
SELECT * FROM (SELECT y FROM u1 UNION ALL SELECT y FROM u2) ORDER BY y;
----
1
3
3
20
40

query I
SELECT * FROM v1;
----
1
3
3
20
40
12 changes: 6 additions & 6 deletions datafusion/sqllogictest/test_files/union_by_name.slt
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,13 @@ INSERT INTO t2 VALUES (2, 2), (4, 4);

# Test binding
query I
SELECT t1.x FROM t1 UNION BY NAME SELECT x FROM t1 ORDER BY t1.x;
SELECT t1.x FROM t1 UNION BY NAME SELECT x FROM t1 ORDER BY x;
----
1
3

query I
SELECT t1.x FROM t1 UNION ALL BY NAME SELECT x FROM t1 ORDER BY t1.x;
SELECT t1.x FROM t1 UNION ALL BY NAME SELECT x FROM t1 ORDER BY x;
----
1
1
Expand All @@ -70,13 +70,13 @@ SELECT t1.x FROM t1 UNION ALL BY NAME SELECT x FROM t1 ORDER BY t1.x;
3

query I
SELECT x FROM t1 UNION BY NAME SELECT x FROM t1 ORDER BY t1.x;
SELECT x FROM t1 UNION BY NAME SELECT x FROM t1 ORDER BY x;
----
1
3

query I
SELECT x FROM t1 UNION ALL BY NAME SELECT x FROM t1 ORDER BY t1.x;
SELECT x FROM t1 UNION ALL BY NAME SELECT x FROM t1 ORDER BY x;
----
1
1
Expand Down Expand Up @@ -124,8 +124,8 @@ NULL 3

# Ambiguous name

statement error DataFusion error: Schema error: No field named t1.x. Valid fields are a, b.
SELECT x AS a FROM t1 UNION BY NAME SELECT x AS b FROM t1 ORDER BY t1.x;
statement error DataFusion error: Schema error: No field named x. Valid fields are a, b.
SELECT x AS a FROM t1 UNION BY NAME SELECT x AS b FROM t1 ORDER BY x;

query II
(SELECT y FROM t1 UNION ALL SELECT x FROM t1) UNION BY NAME (SELECT z FROM t2 UNION ALL SELECT y FROM t2) ORDER BY y, z;
Expand Down