diff --git a/src/dialect/generic_sql.rs b/src/dialect/generic_sql.rs index e187794a2..3788a749c 100644 --- a/src/dialect/generic_sql.rs +++ b/src/dialect/generic_sql.rs @@ -11,6 +11,8 @@ impl Dialect for GenericSqlDialect { STORED, CSV, PARQUET, LOCATION, WITH, WITHOUT, HEADER, ROW, // SQL types CHAR, CHARACTER, VARYING, LARGE, OBJECT, VARCHAR, CLOB, BINARY, VARBINARY, BLOB, FLOAT, REAL, DOUBLE, PRECISION, INT, INTEGER, SMALLINT, BIGINT, NUMERIC, DECIMAL, DEC, + BOOLEAN, DATE, TIME, TIMESTAMP, CASE, WHEN, THEN, ELSE, END, JOIN, LEFT, RIGHT, FULL, + CROSS, OUTER, INNER, NATURAL, ON, USING, BOOLEAN, DATE, TIME, TIMESTAMP, CASE, WHEN, THEN, ELSE, END, LIKE, ]; } diff --git a/src/dialect/postgresql.rs b/src/dialect/postgresql.rs index 0b9d09f9b..20da4bef2 100644 --- a/src/dialect/postgresql.rs +++ b/src/dialect/postgresql.rs @@ -14,7 +14,7 @@ impl Dialect for PostgreSqlDialect { DOUBLE, PRECISION, INT, INTEGER, SMALLINT, BIGINT, NUMERIC, DECIMAL, DEC, BOOLEAN, DATE, TIME, TIMESTAMP, VALUES, DEFAULT, ZONE, REGCLASS, TEXT, BYTEA, TRUE, FALSE, COPY, STDIN, PRIMARY, KEY, UNIQUE, UUID, ADD, CONSTRAINT, FOREIGN, REFERENCES, CASE, WHEN, - THEN, ELSE, END, LIKE, + THEN, ELSE, END, JOIN, LEFT, RIGHT, FULL, CROSS, OUTER, INNER, NATURAL, ON, USING, ]; } diff --git a/src/sqlast/mod.rs b/src/sqlast/mod.rs index 973a3cbc1..1d8328d5a 100644 --- a/src/sqlast/mod.rs +++ b/src/sqlast/mod.rs @@ -75,6 +75,8 @@ pub enum ASTNode { projection: Vec, /// FROM relation: Option>, + // JOIN + joins: Vec, /// WHERE selection: Option>, /// ORDER BY @@ -189,6 +191,7 @@ impl ToString for ASTNode { ASTNode::SQLSelect { projection, relation, + joins, selection, order_by, group_by, @@ -206,6 +209,9 @@ impl ToString for ASTNode { if let Some(relation) = relation { s += &format!(" FROM {}", relation.as_ref().to_string()); } + for join in joins { + s += &join.to_string(); + } if let Some(selection) = selection { s += &format!(" WHERE {}", selection.as_ref().to_string()); } @@ -408,3 +414,72 @@ impl ToString for SQLColumnDef { s } } + +#[derive(Debug, Clone, PartialEq)] +pub struct Join { + pub relation: ASTNode, + pub join_operator: JoinOperator, +} + +impl ToString for Join { + fn to_string(&self) -> String { + fn prefix(constraint: &JoinConstraint) -> String { + match constraint { + JoinConstraint::Natural => "NATURAL ".to_string(), + _ => "".to_string(), + } + } + fn suffix(constraint: &JoinConstraint) -> String { + match constraint { + JoinConstraint::On(expr) => format!("ON {}", expr.to_string()), + JoinConstraint::Using(attrs) => format!("USING({})", attrs.join(", ")), + _ => "".to_string(), + } + } + match &self.join_operator { + JoinOperator::Inner(constraint) => format!( + " {}JOIN {} {}", + prefix(constraint), + self.relation.to_string(), + suffix(constraint) + ), + JoinOperator::Cross => format!(" CROSS JOIN {}", self.relation.to_string()), + JoinOperator::Implicit => format!(", {}", self.relation.to_string()), + JoinOperator::LeftOuter(constraint) => format!( + " {}LEFT JOIN {} {}", + prefix(constraint), + self.relation.to_string(), + suffix(constraint) + ), + JoinOperator::RightOuter(constraint) => format!( + " {}RIGHT JOIN {} {}", + prefix(constraint), + self.relation.to_string(), + suffix(constraint) + ), + JoinOperator::FullOuter(constraint) => format!( + " {}FULL JOIN {} {}", + prefix(constraint), + self.relation.to_string(), + suffix(constraint) + ), + } + } +} + +#[derive(Debug, Clone, PartialEq)] +pub enum JoinOperator { + Inner(JoinConstraint), + LeftOuter(JoinConstraint), + RightOuter(JoinConstraint), + FullOuter(JoinConstraint), + Implicit, + Cross, +} + +#[derive(Debug, Clone, PartialEq)] +pub enum JoinConstraint { + On(ASTNode), + Using(Vec), + Natural, +} diff --git a/src/sqlparser.rs b/src/sqlparser.rs index a9a78aea7..06ad60d1c 100644 --- a/src/sqlparser.rs +++ b/src/sqlparser.rs @@ -486,6 +486,18 @@ impl Parser { true } + pub fn expect_keyword(&mut self, expected: &'static str) -> Result<(), ParserError> { + if self.parse_keyword(expected) { + Ok(()) + } else { + parser_err!(format!( + "Expected keyword {}, found {:?}", + expected, + self.peek_token() + )) + } + } + //TODO: this function is inconsistent and sometimes returns bool and sometimes fails /// Consume the next token if it matches the expected token, otherwise return an error @@ -1105,11 +1117,12 @@ impl Parser { pub fn parse_select(&mut self) -> Result { let projection = self.parse_expr_list()?; - let relation: Option> = if self.parse_keyword("FROM") { - //TODO: add support for JOIN - Some(Box::new(self.parse_expr(0)?)) + let (relation, joins): (Option>, Vec) = if self.parse_keyword("FROM") { + let relation = Some(Box::new(self.parse_expr(0)?)); + let joins = self.parse_joins()?; + (relation, joins) } else { - None + (None, vec![]) }; let selection = if self.parse_keyword("WHERE") { @@ -1155,6 +1168,7 @@ impl Parser { projection, selection, relation, + joins, limit, order_by, group_by, @@ -1163,6 +1177,131 @@ impl Parser { } } + fn parse_join_constraint(&mut self, natural: bool) -> Result { + if natural { + Ok(JoinConstraint::Natural) + } else if self.parse_keyword("ON") { + let constraint = self.parse_expr(0)?; + Ok(JoinConstraint::On(constraint)) + } else if self.parse_keyword("USING") { + if self.consume_token(&Token::LParen)? { + let attributes = self + .parse_expr_list()? + .into_iter() + .map(|ast_node| match ast_node { + ASTNode::SQLIdentifier(ident) => Ok(ident), + unexpected => { + parser_err!(format!("Expected identifier, found {:?}", unexpected)) + } + }) + .collect::, ParserError>>()?; + + if self.consume_token(&Token::RParen)? { + Ok(JoinConstraint::Using(attributes)) + } else { + parser_err!(format!("Expected token ')', found {:?}", self.peek_token())) + } + } else { + parser_err!(format!("Expected token '(', found {:?}", self.peek_token())) + } + } else { + parser_err!(format!( + "Unexpected token after JOIN: {:?}", + self.peek_token() + )) + } + } + + fn parse_joins(&mut self) -> Result, ParserError> { + let mut joins = vec![]; + loop { + let natural = match &self.peek_token() { + Some(Token::Comma) => { + self.next_token(); + let relation = self.parse_expr(0)?; + let join = Join { + relation, + join_operator: JoinOperator::Implicit, + }; + joins.push(join); + continue; + } + Some(Token::Keyword(kw)) if kw == "CROSS" => { + self.next_token(); + self.expect_keyword("JOIN")?; + let relation = self.parse_expr(0)?; + let join = Join { + relation, + join_operator: JoinOperator::Cross, + }; + joins.push(join); + continue; + } + Some(Token::Keyword(kw)) if kw == "NATURAL" => { + self.next_token(); + true + } + Some(_) => false, + None => return Ok(joins), + }; + + let join = match &self.peek_token() { + Some(Token::Keyword(kw)) if kw == "INNER" => { + self.next_token(); + self.expect_keyword("JOIN")?; + Join { + relation: self.parse_expr(0)?, + join_operator: JoinOperator::Inner(self.parse_join_constraint(natural)?), + } + } + Some(Token::Keyword(kw)) if kw == "JOIN" => { + self.next_token(); + Join { + relation: self.parse_expr(0)?, + join_operator: JoinOperator::Inner(self.parse_join_constraint(natural)?), + } + } + Some(Token::Keyword(kw)) if kw == "LEFT" => { + self.next_token(); + self.parse_keyword("OUTER"); + self.expect_keyword("JOIN")?; + Join { + relation: self.parse_expr(0)?, + join_operator: JoinOperator::LeftOuter( + self.parse_join_constraint(natural)?, + ), + } + } + Some(Token::Keyword(kw)) if kw == "RIGHT" => { + self.next_token(); + self.parse_keyword("OUTER"); + self.expect_keyword("JOIN")?; + Join { + relation: self.parse_expr(0)?, + join_operator: JoinOperator::RightOuter( + self.parse_join_constraint(natural)?, + ), + } + } + Some(Token::Keyword(kw)) if kw == "FULL" => { + self.next_token(); + self.parse_keyword("OUTER"); + self.expect_keyword("JOIN")?; + Join { + relation: self.parse_expr(0)?, + join_operator: JoinOperator::FullOuter( + self.parse_join_constraint(natural)?, + ), + } + } + _ => break, + }; + joins.push(join); + } + + Ok(joins) + } + /// Parse an INSERT statement pub fn parse_insert(&mut self) -> Result { self.parse_keyword("INTO"); diff --git a/tests/sqlparser_generic.rs b/tests/sqlparser_generic.rs index a63926eec..737caaf5b 100644 --- a/tests/sqlparser_generic.rs +++ b/tests/sqlparser_generic.rs @@ -476,6 +476,150 @@ fn parse_delete_with_semi_colon() { } } +#[test] +fn parse_implicit_join() { + let sql = "SELECT * FROM t1,t2"; + + match parse_sql(sql) { + ASTNode::SQLSelect { joins, .. } => { + assert_eq!(joins.len(), 1); + assert_eq!( + joins[0], + Join { + relation: ASTNode::SQLIdentifier("t2".to_string()), + join_operator: JoinOperator::Implicit + } + ) + } + _ => assert!(false), + } +} + +#[test] +fn parse_cross_join() { + let sql = "SELECT * FROM t1 CROSS JOIN t2"; + + match parse_sql(sql) { + ASTNode::SQLSelect { joins, .. } => { + assert_eq!(joins.len(), 1); + assert_eq!( + joins[0], + Join { + relation: ASTNode::SQLIdentifier("t2".to_string()), + join_operator: JoinOperator::Cross + } + ) + } + _ => assert!(false), + } +} + +#[test] +fn parse_joins_on() { + fn join_with_constraint( + relation: impl Into, + f: impl Fn(JoinConstraint) -> JoinOperator, + ) -> Join { + Join { + relation: ASTNode::SQLIdentifier(relation.into()), + join_operator: f(JoinConstraint::On(ASTNode::SQLBinaryExpr { + left: Box::new(ASTNode::SQLIdentifier("c1".into())), + op: SQLOperator::Eq, + right: Box::new(ASTNode::SQLIdentifier("c2".into())), + })), + } + } + assert_eq!( + joins_from(verified("SELECT * FROM t1 JOIN t2 ON c1 = c2")), + vec![join_with_constraint("t2", JoinOperator::Inner)] + ); + assert_eq!( + joins_from(verified("SELECT * FROM t1 LEFT JOIN t2 ON c1 = c2")), + vec![join_with_constraint("t2", JoinOperator::LeftOuter)] + ); + assert_eq!( + joins_from(verified("SELECT * FROM t1 RIGHT JOIN t2 ON c1 = c2")), + vec![join_with_constraint("t2", JoinOperator::RightOuter)] + ); + assert_eq!( + joins_from(verified("SELECT * FROM t1 FULL JOIN t2 ON c1 = c2")), + vec![join_with_constraint("t2", JoinOperator::FullOuter)] + ); +} + +#[test] +fn parse_joins_using() { + fn join_with_constraint( + relation: impl Into, + f: impl Fn(JoinConstraint) -> JoinOperator, + ) -> Join { + Join { + relation: ASTNode::SQLIdentifier(relation.into()), + join_operator: f(JoinConstraint::Using(vec!["c1".into()])), + } + } + + assert_eq!( + joins_from(verified("SELECT * FROM t1 JOIN t2 USING(c1)")), + vec![join_with_constraint("t2", JoinOperator::Inner)] + ); + assert_eq!( + joins_from(verified("SELECT * FROM t1 LEFT JOIN t2 USING(c1)")), + vec![join_with_constraint("t2", JoinOperator::LeftOuter)] + ); + assert_eq!( + joins_from(verified("SELECT * FROM t1 RIGHT JOIN t2 USING(c1)")), + vec![join_with_constraint("t2", JoinOperator::RightOuter)] + ); + assert_eq!( + joins_from(verified("SELECT * FROM t1 FULL JOIN t2 USING(c1)")), + vec![join_with_constraint("t2", JoinOperator::FullOuter)] + ); +} + +#[test] +fn parse_complex_join() { + let sql = "SELECT c1, c2 FROM t1, t4 JOIN t2 ON t2.c = t1.c LEFT JOIN t3 USING(q, c) WHERE t4.c = t1.c"; + assert_eq!(sql, parse_sql(sql).to_string()); +} + +#[test] +fn parse_join_syntax_variants() { + fn parses_to(from: &str, to: &str) { + assert_eq!(to, &parse_sql(from).to_string()) + } + + parses_to( + "SELECT c1 FROM t1 INNER JOIN t2 USING(c1)", + "SELECT c1 FROM t1 JOIN t2 USING(c1)", + ); + parses_to( + "SELECT c1 FROM t1 LEFT OUTER JOIN t2 USING(c1)", + "SELECT c1 FROM t1 LEFT JOIN t2 USING(c1)", + ); + parses_to( + "SELECT c1 FROM t1 RIGHT OUTER JOIN t2 USING(c1)", + "SELECT c1 FROM t1 RIGHT JOIN t2 USING(c1)", + ); + parses_to( + "SELECT c1 FROM t1 FULL OUTER JOIN t2 USING(c1)", + "SELECT c1 FROM t1 FULL JOIN t2 USING(c1)", + ); +} + +fn verified(query: &str) -> ASTNode { + let ast = parse_sql(query); + assert_eq!(query, &ast.to_string()); + ast +} + +fn joins_from(ast: ASTNode) -> Vec { + match ast { + ASTNode::SQLSelect { joins, .. } => joins, + _ => panic!("Expected SELECT"), + } +} + fn parse_sql(sql: &str) -> ASTNode { let dialect = GenericSqlDialect {}; let mut tokenizer = Tokenizer::new(&dialect, &sql); diff --git a/tests/sqlparser_postgres.rs b/tests/sqlparser_postgres.rs index e731e56f7..b62888a1d 100644 --- a/tests/sqlparser_postgres.rs +++ b/tests/sqlparser_postgres.rs @@ -717,6 +717,150 @@ fn parse_function_now() { assert_eq!(sql, ast.to_string()); } +#[test] +fn parse_implicit_join() { + let sql = "SELECT * FROM t1, t2"; + + match verified(sql) { + ASTNode::SQLSelect { joins, .. } => { + assert_eq!(joins.len(), 1); + assert_eq!( + joins[0], + Join { + relation: ASTNode::SQLIdentifier("t2".to_string()), + join_operator: JoinOperator::Implicit + } + ) + } + _ => assert!(false), + } +} + +#[test] +fn parse_cross_join() { + let sql = "SELECT * FROM t1 CROSS JOIN t2"; + + match verified(sql) { + ASTNode::SQLSelect { joins, .. } => { + assert_eq!(joins.len(), 1); + assert_eq!( + joins[0], + Join { + relation: ASTNode::SQLIdentifier("t2".to_string()), + join_operator: JoinOperator::Cross + } + ) + } + _ => assert!(false), + } +} + +#[test] +fn parse_joins_on() { + fn join_with_constraint( + relation: impl Into, + f: impl Fn(JoinConstraint) -> JoinOperator, + ) -> Join { + Join { + relation: ASTNode::SQLIdentifier(relation.into()), + join_operator: f(JoinConstraint::On(ASTNode::SQLBinaryExpr { + left: Box::new(ASTNode::SQLIdentifier("c1".into())), + op: SQLOperator::Eq, + right: Box::new(ASTNode::SQLIdentifier("c2".into())), + })), + } + } + assert_eq!( + joins_from(verified("SELECT * FROM t1 JOIN t2 ON c1 = c2")), + vec![join_with_constraint("t2", JoinOperator::Inner)] + ); + assert_eq!( + joins_from(verified("SELECT * FROM t1 LEFT JOIN t2 ON c1 = c2")), + vec![join_with_constraint("t2", JoinOperator::LeftOuter)] + ); + assert_eq!( + joins_from(verified("SELECT * FROM t1 RIGHT JOIN t2 ON c1 = c2")), + vec![join_with_constraint("t2", JoinOperator::RightOuter)] + ); + assert_eq!( + joins_from(verified("SELECT * FROM t1 FULL JOIN t2 ON c1 = c2")), + vec![join_with_constraint("t2", JoinOperator::FullOuter)] + ); +} + +#[test] +fn parse_joins_using() { + fn join_with_constraint( + relation: impl Into, + f: impl Fn(JoinConstraint) -> JoinOperator, + ) -> Join { + Join { + relation: ASTNode::SQLIdentifier(relation.into()), + join_operator: f(JoinConstraint::Using(vec!["c1".into()])), + } + } + + assert_eq!( + joins_from(verified("SELECT * FROM t1 JOIN t2 USING(c1)")), + vec![join_with_constraint("t2", JoinOperator::Inner)] + ); + assert_eq!( + joins_from(verified("SELECT * FROM t1 LEFT JOIN t2 USING(c1)")), + vec![join_with_constraint("t2", JoinOperator::LeftOuter)] + ); + assert_eq!( + joins_from(verified("SELECT * FROM t1 RIGHT JOIN t2 USING(c1)")), + vec![join_with_constraint("t2", JoinOperator::RightOuter)] + ); + assert_eq!( + joins_from(verified("SELECT * FROM t1 FULL JOIN t2 USING(c1)")), + vec![join_with_constraint("t2", JoinOperator::FullOuter)] + ); +} + +#[test] +fn parse_join_syntax_variants() { + fn parses_to(from: &str, to: &str) { + assert_eq!(to, &parse_sql(from).to_string()) + } + + parses_to( + "SELECT c1 FROM t1 INNER JOIN t2 USING(c1)", + "SELECT c1 FROM t1 JOIN t2 USING(c1)", + ); + parses_to( + "SELECT c1 FROM t1 LEFT OUTER JOIN t2 USING(c1)", + "SELECT c1 FROM t1 LEFT JOIN t2 USING(c1)", + ); + parses_to( + "SELECT c1 FROM t1 RIGHT OUTER JOIN t2 USING(c1)", + "SELECT c1 FROM t1 RIGHT JOIN t2 USING(c1)", + ); + parses_to( + "SELECT c1 FROM t1 FULL OUTER JOIN t2 USING(c1)", + "SELECT c1 FROM t1 FULL JOIN t2 USING(c1)", + ); +} + +#[test] +fn parse_complex_join() { + let sql = "SELECT c1, c2 FROM t1, t4 JOIN t2 ON t2.c = t1.c LEFT JOIN t3 USING(q, c) WHERE t4.c = t1.c"; + assert_eq!(sql, parse_sql(sql).to_string()); +} + +fn verified(query: &str) -> ASTNode { + let ast = parse_sql(query); + assert_eq!(query, &ast.to_string()); + ast +} + +fn joins_from(ast: ASTNode) -> Vec { + match ast { + ASTNode::SQLSelect { joins, .. } => joins, + _ => panic!("Expected SELECT"), + } +} + fn parse_sql(sql: &str) -> ASTNode { debug!("sql: {}", sql); let mut parser = parser(sql);