Skip to content

Commit 1a75092

Browse files
lovasoaayman-sigma
authored andcommitted
Replace parallel condition/result vectors with single CaseWhen vector in Expr::Case (apache#1733)
1 parent 03bfb07 commit 1a75092

File tree

5 files changed

+160
-50
lines changed

5 files changed

+160
-50
lines changed

src/ast/mod.rs

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -600,6 +600,22 @@ pub enum CeilFloorKind {
600600
Scale(Value),
601601
}
602602

603+
/// A WHEN clause in a CASE expression containing both
604+
/// the condition and its corresponding result
605+
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
606+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
607+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
608+
pub struct CaseWhen {
609+
pub condition: Expr,
610+
pub result: Expr,
611+
}
612+
613+
impl fmt::Display for CaseWhen {
614+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
615+
write!(f, "WHEN {} THEN {}", self.condition, self.result)
616+
}
617+
}
618+
603619
/// An SQL expression of any type.
604620
///
605621
/// # Semantics / Type Checking
@@ -925,8 +941,7 @@ pub enum Expr {
925941
/// <https://jakewheat.github.io/sql-overview/sql-2011-foundation-grammar.html#simple-when-clause>
926942
Case {
927943
operand: Option<Box<Expr>>,
928-
conditions: Vec<Expr>,
929-
results: Vec<Expr>,
944+
conditions: Vec<CaseWhen>,
930945
else_result: Option<Box<Expr>>,
931946
},
932947
/// An exists expression `[ NOT ] EXISTS(SELECT ...)`, used in expressions like
@@ -1639,17 +1654,15 @@ impl fmt::Display for Expr {
16391654
Expr::Case {
16401655
operand,
16411656
conditions,
1642-
results,
16431657
else_result,
16441658
} => {
16451659
write!(f, "CASE")?;
16461660
if let Some(operand) = operand {
16471661
write!(f, " {operand}")?;
16481662
}
1649-
for (c, r) in conditions.iter().zip(results) {
1650-
write!(f, " WHEN {c} THEN {r}")?;
1663+
for when in conditions {
1664+
write!(f, " {when}")?;
16511665
}
1652-
16531666
if let Some(else_result) = else_result {
16541667
write!(f, " ELSE {else_result}")?;
16551668
}

src/ast/spans.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1455,15 +1455,15 @@ impl Spanned for Expr {
14551455
Expr::Case {
14561456
operand,
14571457
conditions,
1458-
results,
14591458
else_result,
14601459
} => union_spans(
14611460
operand
14621461
.as_ref()
14631462
.map(|i| i.span())
14641463
.into_iter()
1465-
.chain(conditions.iter().map(|i| i.span()))
1466-
.chain(results.iter().map(|i| i.span()))
1464+
.chain(conditions.iter().flat_map(|case_when| {
1465+
[case_when.condition.span(), case_when.result.span()]
1466+
}))
14671467
.chain(else_result.as_ref().map(|i| i.span())),
14681468
),
14691469
Expr::Exists { subquery, .. } => subquery.span(),

src/parser/mod.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2065,11 +2065,11 @@ impl<'a> Parser<'a> {
20652065
self.expect_keyword_is(Keyword::WHEN)?;
20662066
}
20672067
let mut conditions = vec![];
2068-
let mut results = vec![];
20692068
loop {
2070-
conditions.push(self.parse_expr()?);
2069+
let condition = self.parse_expr()?;
20712070
self.expect_keyword_is(Keyword::THEN)?;
2072-
results.push(self.parse_expr()?);
2071+
let result = self.parse_expr()?;
2072+
conditions.push(CaseWhen { condition, result });
20732073
if !self.parse_keyword(Keyword::WHEN) {
20742074
break;
20752075
}
@@ -2083,7 +2083,6 @@ impl<'a> Parser<'a> {
20832083
Ok(Expr::Case {
20842084
operand,
20852085
conditions,
2086-
results,
20872086
else_result,
20882087
})
20892088
}

tests/sqlparser_common.rs

Lines changed: 70 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -6696,22 +6696,26 @@ fn parse_searched_case_expr() {
66966696
&Case {
66976697
operand: None,
66986698
conditions: vec![
6699-
IsNull(Box::new(Identifier(Ident::new("bar")))),
6700-
BinaryOp {
6701-
left: Box::new(Identifier(Ident::new("bar"))),
6702-
op: Eq,
6703-
right: Box::new(Expr::Value(number("0"))),
6699+
CaseWhen {
6700+
condition: IsNull(Box::new(Identifier(Ident::new("bar")))),
6701+
result: Expr::Value(Value::SingleQuotedString("null".to_string())),
67046702
},
6705-
BinaryOp {
6706-
left: Box::new(Identifier(Ident::new("bar"))),
6707-
op: GtEq,
6708-
right: Box::new(Expr::Value(number("0"))),
6703+
CaseWhen {
6704+
condition: BinaryOp {
6705+
left: Box::new(Identifier(Ident::new("bar"))),
6706+
op: Eq,
6707+
right: Box::new(Expr::Value(number("0"))),
6708+
},
6709+
result: Expr::Value(Value::SingleQuotedString("=0".to_string())),
6710+
},
6711+
CaseWhen {
6712+
condition: BinaryOp {
6713+
left: Box::new(Identifier(Ident::new("bar"))),
6714+
op: GtEq,
6715+
right: Box::new(Expr::Value(number("0"))),
6716+
},
6717+
result: Expr::Value(Value::SingleQuotedString(">=0".to_string())),
67096718
},
6710-
],
6711-
results: vec![
6712-
Expr::Value(Value::SingleQuotedString("null".to_string())),
6713-
Expr::Value(Value::SingleQuotedString("=0".to_string())),
6714-
Expr::Value(Value::SingleQuotedString(">=0".to_string())),
67156719
],
67166720
else_result: Some(Box::new(Expr::Value(Value::SingleQuotedString(
67176721
"<0".to_string()
@@ -6730,8 +6734,10 @@ fn parse_simple_case_expr() {
67306734
assert_eq!(
67316735
&Case {
67326736
operand: Some(Box::new(Identifier(Ident::new("foo")))),
6733-
conditions: vec![Expr::Value(number("1"))],
6734-
results: vec![Expr::Value(Value::SingleQuotedString("Y".to_string()))],
6737+
conditions: vec![CaseWhen {
6738+
condition: Expr::Value(number("1")),
6739+
result: Expr::Value(Value::SingleQuotedString("Y".to_string())),
6740+
}],
67356741
else_result: Some(Box::new(Expr::Value(Value::SingleQuotedString(
67366742
"N".to_string()
67376743
)))),
@@ -13905,6 +13911,31 @@ fn test_trailing_commas_in_from() {
1390513911
);
1390613912
}
1390713913

13914+
#[test]
13915+
#[cfg(feature = "visitor")]
13916+
fn test_visit_order() {
13917+
let sql = "SELECT CASE a WHEN 1 THEN 2 WHEN 3 THEN 4 ELSE 5 END";
13918+
let stmt = verified_stmt(sql);
13919+
let mut visited = vec![];
13920+
sqlparser::ast::visit_expressions(&stmt, |expr| {
13921+
visited.push(expr.to_string());
13922+
core::ops::ControlFlow::<()>::Continue(())
13923+
});
13924+
13925+
assert_eq!(
13926+
visited,
13927+
[
13928+
"CASE a WHEN 1 THEN 2 WHEN 3 THEN 4 ELSE 5 END",
13929+
"a",
13930+
"1",
13931+
"2",
13932+
"3",
13933+
"4",
13934+
"5"
13935+
]
13936+
);
13937+
}
13938+
1390813939
#[test]
1390913940
fn test_lambdas() {
1391013941
let dialects = all_dialects_where(|d| d.supports_lambda_functions());
@@ -13932,28 +13963,30 @@ fn test_lambdas() {
1393213963
body: Box::new(Expr::Case {
1393313964
operand: None,
1393413965
conditions: vec![
13935-
Expr::BinaryOp {
13936-
left: Box::new(Expr::Identifier(Ident::new("p1"))),
13937-
op: BinaryOperator::Eq,
13938-
right: Box::new(Expr::Identifier(Ident::new("p2")))
13966+
CaseWhen {
13967+
condition: Expr::BinaryOp {
13968+
left: Box::new(Expr::Identifier(Ident::new("p1"))),
13969+
op: BinaryOperator::Eq,
13970+
right: Box::new(Expr::Identifier(Ident::new("p2")))
13971+
},
13972+
result: Expr::Value(number("0"))
1393913973
},
13940-
Expr::BinaryOp {
13941-
left: Box::new(call(
13942-
"reverse",
13943-
[Expr::Identifier(Ident::new("p1"))]
13944-
)),
13945-
op: BinaryOperator::Lt,
13946-
right: Box::new(call(
13947-
"reverse",
13948-
[Expr::Identifier(Ident::new("p2"))]
13949-
))
13950-
}
13951-
],
13952-
results: vec![
13953-
Expr::Value(number("0")),
13954-
Expr::UnaryOp {
13955-
op: UnaryOperator::Minus,
13956-
expr: Box::new(Expr::Value(number("1")))
13974+
CaseWhen {
13975+
condition: Expr::BinaryOp {
13976+
left: Box::new(call(
13977+
"reverse",
13978+
[Expr::Identifier(Ident::new("p1"))]
13979+
)),
13980+
op: BinaryOperator::Lt,
13981+
right: Box::new(call(
13982+
"reverse",
13983+
[Expr::Identifier(Ident::new("p2"))]
13984+
))
13985+
},
13986+
result: Expr::UnaryOp {
13987+
op: UnaryOperator::Minus,
13988+
expr: Box::new(Expr::Value(number("1")))
13989+
}
1395713990
}
1395813991
],
1395913992
else_result: Some(Box::new(Expr::Value(number("1"))))

tests/sqlparser_databricks.rs

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,71 @@ fn test_databricks_exists() {
8383
);
8484
}
8585

86+
#[test]
87+
fn test_databricks_lambdas() {
88+
#[rustfmt::skip]
89+
let sql = concat!(
90+
"SELECT array_sort(array('Hello', 'World'), ",
91+
"(p1, p2) -> CASE WHEN p1 = p2 THEN 0 ",
92+
"WHEN reverse(p1) < reverse(p2) THEN -1 ",
93+
"ELSE 1 END)",
94+
);
95+
pretty_assertions::assert_eq!(
96+
SelectItem::UnnamedExpr(call(
97+
"array_sort",
98+
[
99+
call(
100+
"array",
101+
[
102+
Expr::Value(Value::SingleQuotedString("Hello".to_owned())),
103+
Expr::Value(Value::SingleQuotedString("World".to_owned()))
104+
]
105+
),
106+
Expr::Lambda(LambdaFunction {
107+
params: OneOrManyWithParens::Many(vec![Ident::new("p1"), Ident::new("p2")]),
108+
body: Box::new(Expr::Case {
109+
operand: None,
110+
conditions: vec![
111+
CaseWhen {
112+
condition: Expr::BinaryOp {
113+
left: Box::new(Expr::Identifier(Ident::new("p1"))),
114+
op: BinaryOperator::Eq,
115+
right: Box::new(Expr::Identifier(Ident::new("p2")))
116+
},
117+
result: Expr::Value(number("0"))
118+
},
119+
CaseWhen {
120+
condition: Expr::BinaryOp {
121+
left: Box::new(call(
122+
"reverse",
123+
[Expr::Identifier(Ident::new("p1"))]
124+
)),
125+
op: BinaryOperator::Lt,
126+
right: Box::new(call(
127+
"reverse",
128+
[Expr::Identifier(Ident::new("p2"))]
129+
)),
130+
},
131+
result: Expr::UnaryOp {
132+
op: UnaryOperator::Minus,
133+
expr: Box::new(Expr::Value(number("1")))
134+
}
135+
},
136+
],
137+
else_result: Some(Box::new(Expr::Value(number("1"))))
138+
})
139+
})
140+
]
141+
)),
142+
databricks().verified_only_select(sql).projection[0]
143+
);
144+
145+
databricks().verified_expr(
146+
"map_zip_with(map(1, 'a', 2, 'b'), map(1, 'x', 2, 'y'), (k, v1, v2) -> concat(v1, v2))",
147+
);
148+
databricks().verified_expr("transform(array(1, 2, 3), x -> x + 1)");
149+
}
150+
86151
#[test]
87152
fn test_values_clause() {
88153
let values = Values {

0 commit comments

Comments
 (0)