Skip to content

Commit 289c170

Browse files
committed
Stricter parsing for subqueries (3/4)
This makes the parser more strict when handling SELECTs nested somewhere in the main statement: 1) instead of accepting SELECT anywhere in the expression where an operand was expected, we only accept it inside parens. (I've added a test for the currently supported syntax, <scalar subquery> in ANSI SQL terms) 2) instead of accepting any expression in the derived table context: `FROM ( ... )` - we only look for a SELECT subquery there. Due to #1, I had to swith the 'ansi' test from invoking the expression parser to the statement parser.
1 parent d4ca09d commit 289c170

File tree

4 files changed

+48
-21
lines changed

4 files changed

+48
-21
lines changed

src/sqlast/mod.rs

+4-3
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,9 @@ pub enum ASTNode {
7777
relation: Box<ASTNode>, // SQLNested or SQLCompoundIdentifier
7878
alias: Option<SQLIdent>,
7979
},
80-
/// SELECT
81-
SQLSelect(SQLSelect),
80+
/// A parenthesized subquery `(SELECT ...)`, used in expression like
81+
/// `SELECT (subquery) AS x` or `WHERE (subquery) = x`
82+
SQLSubquery(SQLSelect),
8283
}
8384

8485
impl ToString for ASTNode {
@@ -139,7 +140,7 @@ impl ToString for ASTNode {
139140
relation.to_string()
140141
}
141142
}
142-
ASTNode::SQLSelect(s) => s.to_string(),
143+
ASTNode::SQLSubquery(s) => format!("({})", s.to_string()),
143144
}
144145
}
145146
}

src/sqlparser.rs

+10-5
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,6 @@ impl Parser {
158158
match self.next_token() {
159159
Some(t) => match t {
160160
Token::SQLWord(w) => match w.keyword.as_ref() {
161-
"SELECT" => Ok(ASTNode::SQLSelect(self.parse_select()?)),
162161
"TRUE" | "FALSE" | "NULL" => {
163162
self.prev_token();
164163
self.parse_sql_value()
@@ -197,9 +196,13 @@ impl Parser {
197196
self.parse_sql_value()
198197
}
199198
Token::LParen => {
200-
let expr = self.parse_expr()?;
199+
let expr = if self.parse_keyword("SELECT") {
200+
ASTNode::SQLSubquery(self.parse_select()?)
201+
} else {
202+
ASTNode::SQLNested(Box::new(self.parse_expr()?))
203+
};
201204
self.expect_token(&Token::RParen)?;
202-
Ok(ASTNode::SQLNested(Box::new(expr)))
205+
Ok(expr)
203206
}
204207
_ => parser_err!(format!(
205208
"Prefix parser expected a keyword but found {:?}",
@@ -1184,8 +1187,10 @@ impl Parser {
11841187
/// A table name or a parenthesized subquery, followed by optional `[AS] alias`
11851188
pub fn parse_table_factor(&mut self) -> Result<ASTNode, ParserError> {
11861189
let relation = if self.consume_token(&Token::LParen) {
1187-
self.prev_token();
1188-
self.parse_subexpr(0)? /* TBD (3) */
1190+
self.expect_keyword("SELECT")?;
1191+
let subquery = self.parse_select()?;
1192+
self.expect_token(&Token::RParen)?;
1193+
ASTNode::SQLSubquery(subquery)
11891194
} else {
11901195
self.parse_compound_identifier(&Token::Period)?
11911196
};

tests/sqlparser_ansi.rs

+4-13
Original file line numberDiff line numberDiff line change
@@ -4,25 +4,16 @@ extern crate sqlparser;
44
use sqlparser::dialect::AnsiSqlDialect;
55
use sqlparser::sqlast::*;
66
use sqlparser::sqlparser::*;
7-
use sqlparser::sqltokenizer::*;
87

98
#[test]
109
fn parse_simple_select() {
1110
let sql = String::from("SELECT id, fname, lname FROM customer WHERE id = 1");
12-
let ast = parse_sql_expr(&sql);
13-
match ast {
14-
ASTNode::SQLSelect(SQLSelect { projection, .. }) => {
11+
let ast = Parser::parse_sql(&AnsiSqlDialect {}, sql).unwrap();
12+
assert_eq!(1, ast.len());
13+
match ast.first().unwrap() {
14+
SQLStatement::SQLSelect(SQLSelect { projection, .. }) => {
1515
assert_eq!(3, projection.len());
1616
}
1717
_ => assert!(false),
1818
}
1919
}
20-
21-
fn parse_sql_expr(sql: &str) -> ASTNode {
22-
let dialect = AnsiSqlDialect {};
23-
let mut tokenizer = Tokenizer::new(&dialect, &sql);
24-
let tokens = tokenizer.tokenize().unwrap();
25-
let mut parser = Parser::new(tokens);
26-
let ast = parser.parse_expr().unwrap();
27-
ast
28-
}

tests/sqlparser_generic.rs

+30
Original file line numberDiff line numberDiff line change
@@ -664,6 +664,13 @@ fn parse_join_syntax_variants() {
664664
);
665665
}
666666

667+
#[test]
668+
fn parse_derived_tables() {
669+
let sql = "SELECT a.x, b.y FROM (SELECT x FROM foo) AS a CROSS JOIN (SELECT y FROM bar) AS b";
670+
let _ = verified_select_stmt(sql);
671+
//TODO: add assertions
672+
}
673+
667674
#[test]
668675
fn parse_multiple_statements() {
669676
fn test_with(sql1: &str, sql2_kw: &str, sql2_rest: &str) {
@@ -695,6 +702,29 @@ fn parse_multiple_statements() {
695702
assert_eq!(0, res.unwrap().len());
696703
}
697704

705+
#[test]
706+
fn parse_scalar_subqueries() {
707+
use self::ASTNode::*;
708+
let sql = "(SELECT 1) + (SELECT 2)";
709+
match verified_expr(sql) {
710+
SQLBinaryExpr {
711+
op: SQLOperator::Plus, ..
712+
//left: box SQLSubquery { .. },
713+
//right: box SQLSubquery { .. },
714+
} => assert!(true),
715+
_ => assert!(false),
716+
};
717+
}
718+
719+
#[test]
720+
fn parse_invalid_subquery_without_parens() {
721+
let res = parse_sql_statements("SELECT SELECT 1 FROM bar WHERE 1=1 FROM baz");
722+
assert_eq!(
723+
ParserError::ParserError("Expected end of statement, found: 1".to_string()),
724+
res.unwrap_err()
725+
);
726+
}
727+
698728
fn only<'a, T>(v: &'a Vec<T>) -> &'a T {
699729
assert_eq!(1, v.len());
700730
v.first().unwrap()

0 commit comments

Comments
 (0)