From fbbffea3efdc968265b4651c36b35cc904f23224 Mon Sep 17 00:00:00 2001 From: Roman Borschel Date: Tue, 1 Apr 2025 17:35:20 +0200 Subject: [PATCH 1/9] Add support for MSSQL IF/ELSE statements. These are syntactically quite different from the already supported IF ... THEN ... ELSEIF ... END IF statements. Hence IfStatement is now an enum with two variants and statement parsing is overridden for the MSSQL dialect in order to parse IF statements differently for MSSQL. Thereby fix spans for if/case AST nodes by including start/end tokens, if present. --- src/ast/mod.rs | 203 ++++++++++++++++++++++++++------------ src/ast/spans.rs | 77 +++++++++------ src/dialect/mssql.rs | 97 ++++++++++++++++++ src/parser/mod.rs | 56 ++++++----- tests/sqlparser_common.rs | 106 +++++++++++++++++--- tests/sqlparser_mssql.rs | 103 ++++++++++++++++++- 6 files changed, 508 insertions(+), 134 deletions(-) diff --git a/src/ast/mod.rs b/src/ast/mod.rs index f187df995..81e6aa867 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -37,7 +37,8 @@ use serde::{Deserialize, Serialize}; #[cfg(feature = "visitor")] use sqlparser_derive::{Visit, VisitMut}; -use crate::tokenizer::Span; +use crate::keywords::Keyword; +use crate::tokenizer::{Span, Token, TokenWithSpan}; pub use self::data_type::{ ArrayElemTypeDef, BinaryLength, CharLengthUnits, CharacterLength, DataType, EnumMember, @@ -2118,20 +2119,23 @@ pub enum Password { #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] pub struct CaseStatement { + /// The `CASE` token that starts the statement. + pub case_token: TokenWithSpan, pub match_expr: Option, pub when_blocks: Vec, - pub else_block: Option>, - /// TRUE if the statement ends with `END CASE` (vs `END`). - pub has_end_case: bool, + pub else_block: Option, + /// The last token of the statement (`END` or `CASE`). + pub end_case_token: TokenWithSpan, } impl fmt::Display for CaseStatement { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let CaseStatement { + case_token: _, match_expr, when_blocks, else_block, - has_end_case, + end_case_token, } = self; write!(f, "CASE")?; @@ -2145,13 +2149,15 @@ impl fmt::Display for CaseStatement { } if let Some(else_block) = else_block { - write!(f, " ELSE ")?; - format_statement_list(f, else_block)?; + write!(f, " {else_block}")?; } write!(f, " END")?; - if *has_end_case { - write!(f, " CASE")?; + + if let Token::Word(w) = &end_case_token.token { + if w.keyword == Keyword::CASE { + write!(f, " CASE")?; + } } Ok(()) @@ -2159,102 +2165,173 @@ impl fmt::Display for CaseStatement { } /// An `IF` statement. -/// -/// Examples: -/// ```sql -/// IF TRUE THEN -/// SELECT 1; -/// SELECT 2; -/// ELSEIF TRUE THEN -/// SELECT 3; -/// ELSE -/// SELECT 4; -/// END IF -/// ``` -/// -/// [BigQuery](https://cloud.google.com/bigquery/docs/reference/standard-sql/procedural-language#if) -/// [Snowflake](https://docs.snowflake.com/en/sql-reference/snowflake-scripting/if) #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] -pub struct IfStatement { - pub if_block: ConditionalStatements, - pub elseif_blocks: Vec, - pub else_block: Option>, +pub enum IfStatement { + /// An `IF ... THEN [ELSE[IF] ...] END IF` statement. + /// + /// Example: + /// ```sql + /// IF TRUE THEN + /// SELECT 1; + /// SELECT 2; + /// ELSEIF TRUE THEN + /// SELECT 3; + /// ELSE + /// SELECT 4; + /// END IF + /// ``` + /// + /// [BigQuery](https://cloud.google.com/bigquery/docs/reference/standard-sql/procedural-language#if) + /// [Snowflake](https://docs.snowflake.com/en/sql-reference/snowflake-scripting/if) + IfThenElseEnd { + /// The `IF` token that starts the statement. + if_token: TokenWithSpan, + if_block: ConditionalStatements, + elseif_blocks: Vec, + else_block: Option, + /// The `IF` token that ends the statement. + end_if_token: TokenWithSpan, + }, + /// An MSSQL `IF ... ELSE ...` statement. + /// + /// Example: + /// ```sql + /// IF 1=1 SELECT 1 ELSE SELECT 2 + /// ``` + /// + /// [MSSQL](https://learn.microsoft.com/en-us/sql/t-sql/language-elements/if-else-transact-sql?view=sql-server-ver16) + MsSqlIfElse { + if_token: TokenWithSpan, + condition: Expr, + if_statements: MsSqlIfStatements, + else_statements: Option, + }, } impl fmt::Display for IfStatement { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let IfStatement { - if_block, - elseif_blocks, - else_block, - } = self; + match self { + IfStatement::IfThenElseEnd { + if_token: _, + if_block, + elseif_blocks, + else_block, + end_if_token: _, + } => { + write!(f, "{if_block}")?; - write!(f, "{if_block}")?; + if !elseif_blocks.is_empty() { + write!(f, " {}", display_separated(elseif_blocks, " "))?; + } - if !elseif_blocks.is_empty() { - write!(f, " {}", display_separated(elseif_blocks, " "))?; - } + if let Some(else_block) = else_block { + write!(f, " {else_block}")?; + } - if let Some(else_block) = else_block { - write!(f, " ELSE ")?; - format_statement_list(f, else_block)?; - } + write!(f, " END IF")?; + + Ok(()) + } + IfStatement::MsSqlIfElse { + if_token: _, + condition, + if_statements, + else_statements, + } => { + write!(f, "IF {condition} {if_statements}")?; - write!(f, " END IF")?; + if let Some(els) = else_statements { + write!(f, " ELSE {els}")?; + } - Ok(()) + Ok(()) + } + } } } -/// Represents a type of [ConditionalStatements] +/// (MSSQL) Either a single [Statement] or a block of statements +/// enclosed in `BEGIN` and `END`. #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] -pub enum ConditionalStatementKind { - /// `WHEN THEN ` - When, - /// `IF THEN ` - If, - /// `ELSEIF THEN ` - ElseIf, +pub enum MsSqlIfStatements { + /// A single statement. + Single(Box), + /// ```sql + /// A logical block of statements. + /// + /// BEGIN + /// ; + /// ; + /// ... + /// END + /// ``` + Block { + begin_token: TokenWithSpan, + statements: Vec, + end_token: TokenWithSpan, + }, +} + +impl fmt::Display for MsSqlIfStatements { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + MsSqlIfStatements::Single(stmt) => stmt.fmt(f), + MsSqlIfStatements::Block { statements, .. } => { + write!(f, "BEGIN ")?; + format_statement_list(f, statements)?; + write!(f, " END") + } + } + } } /// A block within a [Statement::Case] or [Statement::If]-like statement /// -/// Examples: +/// Example 1: /// ```sql /// WHEN EXISTS(SELECT 1) THEN SELECT 1; +/// ``` /// +/// Example 2: +/// ```sql /// IF TRUE THEN SELECT 1; SELECT 2; /// ``` +/// +/// Example 3: +/// ```sql +/// ELSE SELECT 1; SELECT 2; +/// ``` #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] pub struct ConditionalStatements { - /// The condition expression. - pub condition: Expr, + /// The start token of the conditional (`WHEN`, `IF`, `ELSEIF` or `ELSE`). + pub start_token: TokenWithSpan, + /// The condition expression. `None` for `ELSE` statements. + pub condition: Option, /// Statement list of the `THEN` clause. pub statements: Vec, - pub kind: ConditionalStatementKind, } impl fmt::Display for ConditionalStatements { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let ConditionalStatements { - condition: expr, + start_token, + condition, statements, - kind, } = self; - let kind = match kind { - ConditionalStatementKind::When => "WHEN", - ConditionalStatementKind::If => "IF", - ConditionalStatementKind::ElseIf => "ELSEIF", - }; + let keyword = &start_token.token; - write!(f, "{kind} {expr} THEN")?; + if let Some(expr) = condition { + write!(f, "{keyword} {expr} THEN")?; + } else { + write!(f, "{keyword}")?; + } if !statements.is_empty() { write!(f, " ")?; diff --git a/src/ast/spans.rs b/src/ast/spans.rs index 11770d1bc..d6d8f5683 100644 --- a/src/ast/spans.rs +++ b/src/ast/spans.rs @@ -30,13 +30,14 @@ use super::{ FunctionArgumentClause, FunctionArgumentList, FunctionArguments, GroupByExpr, HavingBound, IfStatement, IlikeSelectItem, Insert, Interpolate, InterpolateExpr, Join, JoinConstraint, JoinOperator, JsonPath, JsonPathElem, LateralView, LimitClause, MatchRecognizePattern, Measure, - NamedWindowDefinition, ObjectName, ObjectNamePart, Offset, OnConflict, OnConflictAction, - OnInsert, OrderBy, OrderByExpr, OrderByKind, Partition, PivotValueSource, ProjectionSelect, - Query, RaiseStatement, RaiseStatementValue, ReferentialAction, RenameSelectItem, - ReplaceSelectElement, ReplaceSelectItem, Select, SelectInto, SelectItem, SetExpr, SqlOption, - Statement, Subscript, SymbolDefinition, TableAlias, TableAliasColumnDef, TableConstraint, - TableFactor, TableObject, TableOptionsClustered, TableWithJoins, UpdateTableFromKind, Use, - Value, Values, ViewColumnDef, WildcardAdditionalOptions, With, WithFill, + MsSqlIfStatements, NamedWindowDefinition, ObjectName, ObjectNamePart, Offset, OnConflict, + OnConflictAction, OnInsert, OrderBy, OrderByExpr, OrderByKind, Partition, PivotValueSource, + ProjectionSelect, Query, RaiseStatement, RaiseStatementValue, ReferentialAction, + RenameSelectItem, ReplaceSelectElement, ReplaceSelectItem, Select, SelectInto, SelectItem, + SetExpr, SqlOption, Statement, Subscript, SymbolDefinition, TableAlias, TableAliasColumnDef, + TableConstraint, TableFactor, TableObject, TableOptionsClustered, TableWithJoins, + UpdateTableFromKind, Use, Value, Values, ViewColumnDef, WildcardAdditionalOptions, With, + WithFill, }; /// Given an iterator of spans, return the [Span::union] of all spans. @@ -739,47 +740,63 @@ impl Spanned for CreateIndex { impl Spanned for CaseStatement { fn span(&self) -> Span { let CaseStatement { - match_expr, - when_blocks, - else_block, - has_end_case: _, + case_token, + end_case_token, + .. } = self; - union_spans( - match_expr - .iter() - .map(|e| e.span()) - .chain(when_blocks.iter().map(|b| b.span())) - .chain(else_block.iter().flat_map(|e| e.iter().map(|s| s.span()))), - ) + union_spans([case_token.span, end_case_token.span].into_iter()) } } impl Spanned for IfStatement { fn span(&self) -> Span { - let IfStatement { - if_block, - elseif_blocks, - else_block, - } = self; + match self { + IfStatement::IfThenElseEnd { + if_token, + end_if_token, + .. + } => union_spans([if_token.span, end_if_token.span].into_iter()), + IfStatement::MsSqlIfElse { + if_token, + if_statements, + else_statements, + .. + } => union_spans( + [if_token.span, if_statements.span()] + .into_iter() + .chain(else_statements.as_ref().into_iter().map(|s| s.span())), + ), + } + } +} - union_spans( - iter::once(if_block.span()) - .chain(elseif_blocks.iter().map(|b| b.span())) - .chain(else_block.iter().flat_map(|e| e.iter().map(|s| s.span()))), - ) +impl Spanned for MsSqlIfStatements { + fn span(&self) -> Span { + match self { + MsSqlIfStatements::Single(s) => s.span(), + MsSqlIfStatements::Block { + begin_token, + end_token, + .. + } => union_spans([begin_token.span, end_token.span].into_iter()), + } } } impl Spanned for ConditionalStatements { fn span(&self) -> Span { let ConditionalStatements { + start_token, condition, statements, - kind: _, } = self; - union_spans(iter::once(condition.span()).chain(statements.iter().map(|s| s.span()))) + union_spans( + iter::once(start_token.span) + .chain(condition.as_ref().map(|c| c.span()).into_iter()) + .chain(statements.iter().map(|s| s.span())), + ) } } diff --git a/src/dialect/mssql.rs b/src/dialect/mssql.rs index 18a963a4b..67fdccd65 100644 --- a/src/dialect/mssql.rs +++ b/src/dialect/mssql.rs @@ -15,7 +15,13 @@ // specific language governing permissions and limitations // under the License. +use crate::ast::{IfStatement, MsSqlIfStatements, Statement}; use crate::dialect::Dialect; +use crate::keywords::{self, Keyword}; +use crate::parser::{Parser, ParserError}; +use crate::tokenizer::Token; + +const RESERVED_FOR_COLUMN_ALIAS: &[Keyword] = &[Keyword::IF, Keyword::ELSE]; /// A [`Dialect`] for [Microsoft SQL Server](https://www.microsoft.com/en-us/sql-server/) #[derive(Debug)] @@ -106,4 +112,95 @@ impl Dialect for MsSqlDialect { fn supports_object_name_double_dot_notation(&self) -> bool { true } + + fn is_column_alias(&self, kw: &Keyword, _parser: &mut Parser) -> bool { + !keywords::RESERVED_FOR_COLUMN_ALIAS.contains(kw) && !RESERVED_FOR_COLUMN_ALIAS.contains(kw) + } + + fn parse_statement(&self, parser: &mut Parser) -> Option> { + if parser.peek_keyword(Keyword::IF) { + Some(self.parse_if_stmt(parser)) + } else { + None + } + } +} + +impl MsSqlDialect { + /// ```sql + /// IF boolean_expression + /// { sql_statement | statement_block } + /// [ ELSE + /// { sql_statement | statement_block } ] + /// ``` + fn parse_if_stmt(&self, parser: &mut Parser) -> Result { + let if_token = parser.expect_keyword(Keyword::IF)?; + + let condition = parser.parse_expr()?; + + let if_statements; + if parser.peek_keyword(Keyword::BEGIN) { + let begin_token = parser.expect_keyword(Keyword::BEGIN)?; + let statements = self.parse_statement_list(parser, Some(Keyword::END))?; + let end_token = parser.expect_keyword(Keyword::END)?; + if_statements = MsSqlIfStatements::Block { + begin_token, + statements, + end_token, + }; + } else { + let stmt = parser.parse_statement()?; + if_statements = MsSqlIfStatements::Single(Box::new(stmt)); + } + + let mut else_statements = None; + if parser.parse_keyword(Keyword::ELSE) { + if parser.peek_keyword(Keyword::BEGIN) { + let begin_token = parser.expect_keyword(Keyword::BEGIN)?; + let statements = self.parse_statement_list(parser, Some(Keyword::END))?; + let end_token = parser.expect_keyword(Keyword::END)?; + else_statements = Some(MsSqlIfStatements::Block { + begin_token, + statements, + end_token, + }); + } else { + let stmt = parser.parse_statement()?; + else_statements = Some(MsSqlIfStatements::Single(Box::new(stmt))); + } + } + + Ok(Statement::If(IfStatement::MsSqlIfElse { + if_token, + condition, + if_statements, + else_statements, + })) + } + + /// Parse a sequence of statements, optionally separated by semicolon. + /// + /// Stops parsing when reaching EOF or the given keyword. + fn parse_statement_list( + &self, + parser: &mut Parser, + terminal_keyword: Option, + ) -> Result, ParserError> { + let mut stmts = Vec::new(); + loop { + if let Token::EOF = parser.peek_token_ref().token { + break; + } + if let Some(term) = terminal_keyword { + if parser.peek_keyword(term) { + break; + } + } + stmts.push(parser.parse_statement()?); + while let Token::SemiColon = parser.peek_token_ref().token { + parser.advance_token(); + } + } + Ok(stmts) + } } diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 2b61529ff..f5c62b630 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -631,7 +631,7 @@ impl<'a> Parser<'a> { /// /// See [Statement::Case] pub fn parse_case_stmt(&mut self) -> Result { - self.expect_keyword_is(Keyword::CASE)?; + let case_token = self.expect_keyword(Keyword::CASE)?; let match_expr = if self.peek_keyword(Keyword::WHEN) { None @@ -641,26 +641,26 @@ impl<'a> Parser<'a> { self.expect_keyword_is(Keyword::WHEN)?; let when_blocks = self.parse_keyword_separated(Keyword::WHEN, |parser| { - parser.parse_conditional_statements( - ConditionalStatementKind::When, - &[Keyword::WHEN, Keyword::ELSE, Keyword::END], - ) + parser.parse_conditional_statements(&[Keyword::WHEN, Keyword::ELSE, Keyword::END]) })?; let else_block = if self.parse_keyword(Keyword::ELSE) { - Some(self.parse_statement_list(&[Keyword::END])?) + Some(self.parse_conditional_statements(&[Keyword::END])?) } else { None }; - self.expect_keyword_is(Keyword::END)?; - let has_end_case = self.parse_keyword(Keyword::CASE); + let mut end_case_token = self.expect_keyword(Keyword::END)?; + if self.peek_keyword(Keyword::CASE) { + end_case_token = self.expect_keyword(Keyword::CASE)?; + } Ok(Statement::Case(CaseStatement { + case_token, match_expr, when_blocks, else_block, - has_end_case, + end_case_token, })) } @@ -668,35 +668,33 @@ impl<'a> Parser<'a> { /// /// See [Statement::If] pub fn parse_if_stmt(&mut self) -> Result { - self.expect_keyword_is(Keyword::IF)?; - let if_block = self.parse_conditional_statements( - ConditionalStatementKind::If, - &[Keyword::ELSE, Keyword::ELSEIF, Keyword::END], - )?; + let if_token = self.expect_keyword(Keyword::IF)?; + let if_block = + self.parse_conditional_statements(&[Keyword::ELSE, Keyword::ELSEIF, Keyword::END])?; let elseif_blocks = if self.parse_keyword(Keyword::ELSEIF) { self.parse_keyword_separated(Keyword::ELSEIF, |parser| { - parser.parse_conditional_statements( - ConditionalStatementKind::ElseIf, - &[Keyword::ELSEIF, Keyword::ELSE, Keyword::END], - ) + parser.parse_conditional_statements(&[Keyword::ELSEIF, Keyword::ELSE, Keyword::END]) })? } else { vec![] }; let else_block = if self.parse_keyword(Keyword::ELSE) { - Some(self.parse_statement_list(&[Keyword::END])?) + Some(self.parse_conditional_statements(&[Keyword::END])?) } else { None }; - self.expect_keywords(&[Keyword::END, Keyword::IF])?; + self.expect_keyword_is(Keyword::END)?; + let end_if_token = self.expect_keyword(Keyword::IF)?; - Ok(Statement::If(IfStatement { + Ok(Statement::If(IfStatement::IfThenElseEnd { + if_token, if_block, elseif_blocks, else_block, + end_if_token, })) } @@ -709,17 +707,25 @@ impl<'a> Parser<'a> { /// ``` fn parse_conditional_statements( &mut self, - kind: ConditionalStatementKind, terminal_keywords: &[Keyword], ) -> Result { - let condition = self.parse_expr()?; - self.expect_keyword_is(Keyword::THEN)?; + let start_token = self.get_current_token().clone(); + + let condition = match &start_token.token { + Token::Word(w) if w.keyword == Keyword::ELSE => None, + _ => { + let expr = self.parse_expr()?; + self.expect_keyword_is(Keyword::THEN)?; + Some(expr) + } + }; + let statements = self.parse_statement_list(terminal_keywords)?; Ok(ConditionalStatements { + start_token, condition, statements, - kind, }) } diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 795dae4b3..5914a3dea 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -14215,9 +14215,12 @@ fn parse_case_statement() { }; assert_eq!(Some(Expr::value(number("1"))), stmt.match_expr); - assert_eq!(Expr::value(number("2")), stmt.when_blocks[0].condition); + assert_eq!( + Some(Expr::value(number("2"))), + stmt.when_blocks[0].condition + ); assert_eq!(2, stmt.when_blocks[0].statements.len()); - assert_eq!(1, stmt.else_block.unwrap().len()); + assert_eq!(1, stmt.else_block.unwrap().statements.len()); verified_stmt(concat!( "CASE 1", @@ -14260,17 +14263,35 @@ fn parse_case_statement() { ); } +#[test] +fn test_case_statement_span() { + let sql = "CASE 1 WHEN 2 THEN SELECT 1; SELECT 2; ELSE SELECT 3; END CASE"; + let mut parser = Parser::new(&GenericDialect {}).try_with_sql(sql).unwrap(); + assert_eq!( + parser.parse_statement().unwrap().span(), + Span::new(Location::new(1, 1), Location::new(1, sql.len() as u64 + 1)) + ); +} + #[test] fn parse_if_statement() { + let dialects = all_dialects_except(|d| d.is::()); + let sql = "IF 1 THEN SELECT 1; ELSEIF 2 THEN SELECT 2; ELSE SELECT 3; END IF"; - let Statement::If(stmt) = verified_stmt(sql) else { + let Statement::If(IfStatement::IfThenElseEnd { + if_block, + elseif_blocks, + else_block, + .. + }) = dialects.verified_stmt(sql) + else { unreachable!() }; - assert_eq!(Expr::value(number("1")), stmt.if_block.condition); - assert_eq!(Expr::value(number("2")), stmt.elseif_blocks[0].condition); - assert_eq!(1, stmt.else_block.unwrap().len()); + assert_eq!(Some(Expr::value(number("1"))), if_block.condition); + assert_eq!(Some(Expr::value(number("2"))), elseif_blocks[0].condition); + assert_eq!(1, else_block.unwrap().statements.len()); - verified_stmt(concat!( + dialects.verified_stmt(concat!( "IF 1 THEN", " SELECT 1;", " SELECT 2;", @@ -14286,7 +14307,7 @@ fn parse_if_statement() { " SELECT 9;", " END IF" )); - verified_stmt(concat!( + dialects.verified_stmt(concat!( "IF 1 THEN", " SELECT 1;", " SELECT 2;", @@ -14295,7 +14316,7 @@ fn parse_if_statement() { " SELECT 4;", " END IF" )); - verified_stmt(concat!( + dialects.verified_stmt(concat!( "IF 1 THEN", " SELECT 1;", " SELECT 2;", @@ -14305,22 +14326,79 @@ fn parse_if_statement() { " SELECT 4;", " END IF" )); - verified_stmt(concat!("IF 1 THEN", " SELECT 1;", " SELECT 2;", " END IF")); - verified_stmt(concat!( + dialects.verified_stmt(concat!("IF 1 THEN", " SELECT 1;", " SELECT 2;", " END IF")); + dialects.verified_stmt(concat!( "IF (1) THEN", " SELECT 1;", " SELECT 2;", " END IF" )); - verified_stmt("IF 1 THEN END IF"); - verified_stmt("IF 1 THEN SELECT 1; ELSEIF 1 THEN END IF"); + dialects.verified_stmt("IF 1 THEN END IF"); + dialects.verified_stmt("IF 1 THEN SELECT 1; ELSEIF 1 THEN END IF"); assert_eq!( ParserError::ParserError("Expected: IF, found: EOF".to_string()), - parse_sql_statements("IF 1 THEN SELECT 1; ELSEIF 1 THEN SELECT 2; END").unwrap_err() + dialects + .parse_sql_statements("IF 1 THEN SELECT 1; ELSEIF 1 THEN SELECT 2; END") + .unwrap_err() ); } +#[test] +fn test_if_statement_span() { + let sql = "IF 1=1 THEN SELECT 1; ELSEIF 1=2 THEN SELECT 2; ELSE SELECT 3; END IF"; + let mut parser = Parser::new(&GenericDialect {}).try_with_sql(sql).unwrap(); + assert_eq!( + parser.parse_statement().unwrap().span(), + Span::new(Location::new(1, 1), Location::new(1, sql.len() as u64 + 1)) + ); +} + +#[test] +fn test_if_statement_multiline_span() { + let sql_line1 = "IF 1 = 1 THEN SELECT 1;"; + let sql_line2 = "ELSEIF 1 = 2 THEN SELECT 2;"; + let sql_line3 = "ELSE SELECT 3;"; + let sql_line4 = "END IF"; + let sql = [sql_line1, sql_line2, sql_line3, sql_line4].join("\n"); + let mut parser = Parser::new(&GenericDialect {}).try_with_sql(&sql).unwrap(); + assert_eq!( + parser.parse_statement().unwrap().span(), + Span::new( + Location::new(1, 1), + Location::new(4, sql_line4.len() as u64 + 1) + ) + ); +} + +#[test] +fn test_conditional_statement_span() { + let sql = "IF 1=1 THEN SELECT 1; ELSEIF 1=2 THEN SELECT 2; ELSE SELECT 3; END IF"; + let mut parser = Parser::new(&GenericDialect {}).try_with_sql(sql).unwrap(); + match parser.parse_statement().unwrap() { + Statement::If(IfStatement::IfThenElseEnd { + if_block, + elseif_blocks, + else_block, + .. + }) => { + assert_eq!( + Span::new(Location::new(1, 1), Location::new(1, 21)), + if_block.span() + ); + assert_eq!( + Span::new(Location::new(1, 23), Location::new(1, 47)), + elseif_blocks[0].span() + ); + assert_eq!( + Span::new(Location::new(1, 49), Location::new(1, 62)), + else_block.unwrap().span() + ); + } + stmt => panic!("Unexpected statement: {:?}", stmt), + } +} + #[test] fn parse_raise_statement() { let sql = "RAISE USING MESSAGE = 42"; diff --git a/tests/sqlparser_mssql.rs b/tests/sqlparser_mssql.rs index 2bfc38a6a..4c13760c3 100644 --- a/tests/sqlparser_mssql.rs +++ b/tests/sqlparser_mssql.rs @@ -23,7 +23,7 @@ mod test_utils; use helpers::attached_token::AttachedToken; -use sqlparser::tokenizer::Span; +use sqlparser::tokenizer::{Location, Span}; use test_utils::*; use sqlparser::ast::DataType::{Int, Text, Varbinary}; @@ -31,7 +31,7 @@ use sqlparser::ast::DeclareAssignment::MsSqlAssignment; use sqlparser::ast::Value::SingleQuotedString; use sqlparser::ast::*; use sqlparser::dialect::{GenericDialect, MsSqlDialect}; -use sqlparser::parser::ParserError; +use sqlparser::parser::{Parser, ParserError}; #[test] fn parse_mssql_identifiers() { @@ -1857,6 +1857,104 @@ fn parse_mssql_set_session_value() { ms().verified_stmt("SET ANSI_NULLS, ANSI_PADDING ON"); } +#[test] +fn parse_mssql_if_else() { + // Simple statements and blocks + ms().verified_stmt("IF 1 = 1 SELECT '1' ELSE SELECT '2'"); + ms().verified_stmt("IF 1 = 1 BEGIN SET @A = 1; END ELSE SET @A = 2"); + ms().verified_stmt( + "IF DATENAME(weekday, GETDATE()) IN (N'Saturday', N'Sunday') SELECT 'Weekend' ELSE SELECT 'Weekday'" + ); + ms().verified_stmt( + "IF (SELECT COUNT(*) FROM a.b WHERE c LIKE 'x%') > 1 SELECT 'yes' ELSE SELECT 'No'", + ); + + // Multiple statements + let stmts = ms() + .parse_sql_statements("DECLARE @A INT; IF 1=1 BEGIN SET @A = 1 END ELSE SET @A = 2") + .unwrap(); + match &stmts[..] { + [Statement::Declare { .. }, Statement::If(stmt)] => { + assert_eq!( + stmt.to_string(), + "IF 1 = 1 BEGIN SET @A = 1; END ELSE SET @A = 2" + ); + } + _ => panic!("Unexpected statements: {:?}", stmts), + } +} + +#[test] +fn test_mssql_if_else_span() { + let sql = "IF 1 = 1 SELECT '1' ELSE SELECT '2'"; + let mut parser = Parser::new(&MsSqlDialect {}).try_with_sql(sql).unwrap(); + assert_eq!( + parser.parse_statement().unwrap().span(), + Span::new(Location::new(1, 1), Location::new(1, sql.len() as u64 + 1)) + ); +} + +#[test] +fn test_mssql_if_else_multiline_span() { + let sql_line1 = "IF 1 = 1"; + let sql_line2 = "SELECT '1'"; + let sql_line3 = "ELSE SELECT '2'"; + let sql = [sql_line1, sql_line2, sql_line3].join("\n"); + let mut parser = Parser::new(&MsSqlDialect {}).try_with_sql(&sql).unwrap(); + assert_eq!( + parser.parse_statement().unwrap().span(), + Span::new( + Location::new(1, 1), + Location::new(3, sql_line3.len() as u64 + 1) + ) + ); +} + +#[test] +fn test_mssql_if_statements_span() { + // Simple statements + let mut sql = "IF 1 = 1 SELECT '1' ELSE SELECT '2'"; + let mut parser = Parser::new(&MsSqlDialect {}).try_with_sql(sql).unwrap(); + match parser.parse_statement().unwrap() { + Statement::If(IfStatement::MsSqlIfElse { + if_statements, + else_statements: Some(else_statements), + .. + }) => { + assert_eq!( + if_statements.span(), + Span::new(Location::new(1, 10), Location::new(1, 20)) + ); + assert_eq!( + else_statements.span(), + Span::new(Location::new(1, 26), Location::new(1, 36)) + ); + } + stmt => panic!("Unexpected statement: {:?}", stmt), + } + + // Blocks + sql = "IF 1 = 1 BEGIN SET @A = 1; END ELSE BEGIN SET @A = 2 END"; + parser = Parser::new(&MsSqlDialect {}).try_with_sql(sql).unwrap(); + match parser.parse_statement().unwrap() { + Statement::If(IfStatement::MsSqlIfElse { + if_statements, + else_statements: Some(else_statements), + .. + }) => { + assert_eq!( + if_statements.span(), + Span::new(Location::new(1, 10), Location::new(1, 31)) + ); + assert_eq!( + else_statements.span(), + Span::new(Location::new(1, 37), Location::new(1, 57)) + ); + } + stmt => panic!("Unexpected statement: {:?}", stmt), + } +} + #[test] fn parse_mssql_varbinary_max_length() { let sql = "CREATE TABLE example (var_binary_col VARBINARY(MAX))"; @@ -1918,6 +2016,7 @@ fn parse_mssql_table_identifier_with_default_schema() { fn ms() -> TestedDialects { TestedDialects::new(vec![Box::new(MsSqlDialect {})]) } + fn ms_and_generic() -> TestedDialects { TestedDialects::new(vec![Box::new(MsSqlDialect {}), Box::new(GenericDialect {})]) } From e610c8b3e48b85573147363042a08bfe6642328d Mon Sep 17 00:00:00 2001 From: Roman Borschel Date: Wed, 2 Apr 2025 17:39:26 +0200 Subject: [PATCH 2/9] Wrap tokens attached to AST nodes in AttachedToken. --- src/ast/mod.rs | 26 +++++++++++++------------- src/ast/spans.rs | 28 ++++++++++++++-------------- src/dialect/mssql.rs | 11 ++++++----- src/parser/mod.rs | 10 +++++----- 4 files changed, 38 insertions(+), 37 deletions(-) diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 81e6aa867..89e3ef7ea 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -38,7 +38,7 @@ use serde::{Deserialize, Serialize}; use sqlparser_derive::{Visit, VisitMut}; use crate::keywords::Keyword; -use crate::tokenizer::{Span, Token, TokenWithSpan}; +use crate::tokenizer::{Span, Token}; pub use self::data_type::{ ArrayElemTypeDef, BinaryLength, CharLengthUnits, CharacterLength, DataType, EnumMember, @@ -2120,12 +2120,12 @@ pub enum Password { #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] pub struct CaseStatement { /// The `CASE` token that starts the statement. - pub case_token: TokenWithSpan, + pub case_token: AttachedToken, pub match_expr: Option, pub when_blocks: Vec, pub else_block: Option, /// The last token of the statement (`END` or `CASE`). - pub end_case_token: TokenWithSpan, + pub end_case_token: AttachedToken, } impl fmt::Display for CaseStatement { @@ -2135,7 +2135,7 @@ impl fmt::Display for CaseStatement { match_expr, when_blocks, else_block, - end_case_token, + end_case_token: AttachedToken(end), } = self; write!(f, "CASE")?; @@ -2154,7 +2154,7 @@ impl fmt::Display for CaseStatement { write!(f, " END")?; - if let Token::Word(w) = &end_case_token.token { + if let Token::Word(w) = &end.token { if w.keyword == Keyword::CASE { write!(f, " CASE")?; } @@ -2187,12 +2187,12 @@ pub enum IfStatement { /// [Snowflake](https://docs.snowflake.com/en/sql-reference/snowflake-scripting/if) IfThenElseEnd { /// The `IF` token that starts the statement. - if_token: TokenWithSpan, + if_token: AttachedToken, if_block: ConditionalStatements, elseif_blocks: Vec, else_block: Option, /// The `IF` token that ends the statement. - end_if_token: TokenWithSpan, + end_if_token: AttachedToken, }, /// An MSSQL `IF ... ELSE ...` statement. /// @@ -2203,7 +2203,7 @@ pub enum IfStatement { /// /// [MSSQL](https://learn.microsoft.com/en-us/sql/t-sql/language-elements/if-else-transact-sql?view=sql-server-ver16) MsSqlIfElse { - if_token: TokenWithSpan, + if_token: AttachedToken, condition: Expr, if_statements: MsSqlIfStatements, else_statements: Option, @@ -2270,9 +2270,9 @@ pub enum MsSqlIfStatements { /// END /// ``` Block { - begin_token: TokenWithSpan, + begin_token: AttachedToken, statements: Vec, - end_token: TokenWithSpan, + end_token: AttachedToken, }, } @@ -2310,7 +2310,7 @@ impl fmt::Display for MsSqlIfStatements { #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] pub struct ConditionalStatements { /// The start token of the conditional (`WHEN`, `IF`, `ELSEIF` or `ELSE`). - pub start_token: TokenWithSpan, + pub start_token: AttachedToken, /// The condition expression. `None` for `ELSE` statements. pub condition: Option, /// Statement list of the `THEN` clause. @@ -2320,12 +2320,12 @@ pub struct ConditionalStatements { impl fmt::Display for ConditionalStatements { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let ConditionalStatements { - start_token, + start_token: AttachedToken(start), condition, statements, } = self; - let keyword = &start_token.token; + let keyword = &start.token; if let Some(expr) = condition { write!(f, "{keyword} {expr} THEN")?; diff --git a/src/ast/spans.rs b/src/ast/spans.rs index d6d8f5683..f4685e960 100644 --- a/src/ast/spans.rs +++ b/src/ast/spans.rs @@ -21,7 +21,7 @@ use core::iter; use crate::tokenizer::Span; use super::{ - dcl::SecondaryRoles, value::ValueWithSpan, AccessExpr, AlterColumnOperation, + dcl::SecondaryRoles, value::ValueWithSpan, AccessExpr, AlterColumnOperation, AttachedToken, AlterIndexOperation, AlterTableOperation, Array, Assignment, AssignmentTarget, CaseStatement, CloseCursor, ClusteredIndex, ColumnDef, ColumnOption, ColumnOptionDef, ConditionalStatements, ConflictTarget, ConnectBy, ConstraintCharacteristics, CopySource, CreateIndex, CreateTable, @@ -740,12 +740,12 @@ impl Spanned for CreateIndex { impl Spanned for CaseStatement { fn span(&self) -> Span { let CaseStatement { - case_token, - end_case_token, + case_token: AttachedToken(start), + end_case_token: AttachedToken(end), .. } = self; - union_spans([case_token.span, end_case_token.span].into_iter()) + union_spans([start.span, end.span].into_iter()) } } @@ -753,17 +753,17 @@ impl Spanned for IfStatement { fn span(&self) -> Span { match self { IfStatement::IfThenElseEnd { - if_token, - end_if_token, + if_token: AttachedToken(start), + end_if_token: AttachedToken(end), .. - } => union_spans([if_token.span, end_if_token.span].into_iter()), + } => union_spans([start.span, end.span].into_iter()), IfStatement::MsSqlIfElse { - if_token, + if_token: AttachedToken(start), if_statements, else_statements, .. } => union_spans( - [if_token.span, if_statements.span()] + [start.span, if_statements.span()] .into_iter() .chain(else_statements.as_ref().into_iter().map(|s| s.span())), ), @@ -776,10 +776,10 @@ impl Spanned for MsSqlIfStatements { match self { MsSqlIfStatements::Single(s) => s.span(), MsSqlIfStatements::Block { - begin_token, - end_token, + begin_token: AttachedToken(start), + end_token: AttachedToken(end), .. - } => union_spans([begin_token.span, end_token.span].into_iter()), + } => union_spans([start.span, end.span].into_iter()), } } } @@ -787,13 +787,13 @@ impl Spanned for MsSqlIfStatements { impl Spanned for ConditionalStatements { fn span(&self) -> Span { let ConditionalStatements { - start_token, + start_token: AttachedToken(start), condition, statements, } = self; union_spans( - iter::once(start_token.span) + iter::once(start.span) .chain(condition.as_ref().map(|c| c.span()).into_iter()) .chain(statements.iter().map(|s| s.span())), ) diff --git a/src/dialect/mssql.rs b/src/dialect/mssql.rs index 67fdccd65..e9d5532b4 100644 --- a/src/dialect/mssql.rs +++ b/src/dialect/mssql.rs @@ -16,6 +16,7 @@ // under the License. use crate::ast::{IfStatement, MsSqlIfStatements, Statement}; +use crate::ast::helpers::attached_token::AttachedToken; use crate::dialect::Dialect; use crate::keywords::{self, Keyword}; use crate::parser::{Parser, ParserError}; @@ -144,9 +145,9 @@ impl MsSqlDialect { let statements = self.parse_statement_list(parser, Some(Keyword::END))?; let end_token = parser.expect_keyword(Keyword::END)?; if_statements = MsSqlIfStatements::Block { - begin_token, + begin_token: AttachedToken(begin_token), statements, - end_token, + end_token: AttachedToken(end_token), }; } else { let stmt = parser.parse_statement()?; @@ -160,9 +161,9 @@ impl MsSqlDialect { let statements = self.parse_statement_list(parser, Some(Keyword::END))?; let end_token = parser.expect_keyword(Keyword::END)?; else_statements = Some(MsSqlIfStatements::Block { - begin_token, + begin_token: AttachedToken(begin_token), statements, - end_token, + end_token: AttachedToken(end_token), }); } else { let stmt = parser.parse_statement()?; @@ -171,7 +172,7 @@ impl MsSqlDialect { } Ok(Statement::If(IfStatement::MsSqlIfElse { - if_token, + if_token: AttachedToken(if_token), condition, if_statements, else_statements, diff --git a/src/parser/mod.rs b/src/parser/mod.rs index f5c62b630..47ad5c8a0 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -656,11 +656,11 @@ impl<'a> Parser<'a> { } Ok(Statement::Case(CaseStatement { - case_token, + case_token: AttachedToken(case_token), match_expr, when_blocks, else_block, - end_case_token, + end_case_token: AttachedToken(end_case_token), })) } @@ -690,11 +690,11 @@ impl<'a> Parser<'a> { let end_if_token = self.expect_keyword(Keyword::IF)?; Ok(Statement::If(IfStatement::IfThenElseEnd { - if_token, + if_token: AttachedToken(if_token), if_block, elseif_blocks, else_block, - end_if_token, + end_if_token: AttachedToken(end_if_token), })) } @@ -723,7 +723,7 @@ impl<'a> Parser<'a> { let statements = self.parse_statement_list(terminal_keywords)?; Ok(ConditionalStatements { - start_token, + start_token: AttachedToken(start_token), condition, statements, }) From 54d2325b983d782523109786973e9ce10e6e91d6 Mon Sep 17 00:00:00 2001 From: Roman Borschel Date: Wed, 2 Apr 2025 17:51:10 +0200 Subject: [PATCH 3/9] Address cargo fmt style errors. --- src/ast/spans.rs | 34 +++++++++++++++++----------------- src/dialect/mssql.rs | 2 +- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/src/ast/spans.rs b/src/ast/spans.rs index f4685e960..ff9545cdb 100644 --- a/src/ast/spans.rs +++ b/src/ast/spans.rs @@ -21,23 +21,23 @@ use core::iter; use crate::tokenizer::Span; use super::{ - dcl::SecondaryRoles, value::ValueWithSpan, AccessExpr, AlterColumnOperation, AttachedToken, - AlterIndexOperation, AlterTableOperation, Array, Assignment, AssignmentTarget, CaseStatement, - CloseCursor, ClusteredIndex, ColumnDef, ColumnOption, ColumnOptionDef, ConditionalStatements, - ConflictTarget, ConnectBy, ConstraintCharacteristics, CopySource, CreateIndex, CreateTable, - CreateTableOptions, Cte, Delete, DoUpdate, ExceptSelectItem, ExcludeSelectItem, Expr, - ExprWithAlias, Fetch, FromTable, Function, FunctionArg, FunctionArgExpr, - FunctionArgumentClause, FunctionArgumentList, FunctionArguments, GroupByExpr, HavingBound, - IfStatement, IlikeSelectItem, Insert, Interpolate, InterpolateExpr, Join, JoinConstraint, - JoinOperator, JsonPath, JsonPathElem, LateralView, LimitClause, MatchRecognizePattern, Measure, - MsSqlIfStatements, NamedWindowDefinition, ObjectName, ObjectNamePart, Offset, OnConflict, - OnConflictAction, OnInsert, OrderBy, OrderByExpr, OrderByKind, Partition, PivotValueSource, - ProjectionSelect, Query, RaiseStatement, RaiseStatementValue, ReferentialAction, - RenameSelectItem, ReplaceSelectElement, ReplaceSelectItem, Select, SelectInto, SelectItem, - SetExpr, SqlOption, Statement, Subscript, SymbolDefinition, TableAlias, TableAliasColumnDef, - TableConstraint, TableFactor, TableObject, TableOptionsClustered, TableWithJoins, - UpdateTableFromKind, Use, Value, Values, ViewColumnDef, WildcardAdditionalOptions, With, - WithFill, + dcl::SecondaryRoles, value::ValueWithSpan, AccessExpr, AlterColumnOperation, + AlterIndexOperation, AlterTableOperation, Array, Assignment, AssignmentTarget, AttachedToken, + CaseStatement, CloseCursor, ClusteredIndex, ColumnDef, ColumnOption, ColumnOptionDef, + ConditionalStatements, ConflictTarget, ConnectBy, ConstraintCharacteristics, CopySource, + CreateIndex, CreateTable, CreateTableOptions, Cte, Delete, DoUpdate, ExceptSelectItem, + ExcludeSelectItem, Expr, ExprWithAlias, Fetch, FromTable, Function, FunctionArg, + FunctionArgExpr, FunctionArgumentClause, FunctionArgumentList, FunctionArguments, GroupByExpr, + HavingBound, IfStatement, IlikeSelectItem, Insert, Interpolate, InterpolateExpr, Join, + JoinConstraint, JoinOperator, JsonPath, JsonPathElem, LateralView, LimitClause, + MatchRecognizePattern, Measure, MsSqlIfStatements, NamedWindowDefinition, ObjectName, + ObjectNamePart, Offset, OnConflict, OnConflictAction, OnInsert, OrderBy, OrderByExpr, + OrderByKind, Partition, PivotValueSource, ProjectionSelect, Query, RaiseStatement, + RaiseStatementValue, ReferentialAction, RenameSelectItem, ReplaceSelectElement, + ReplaceSelectItem, Select, SelectInto, SelectItem, SetExpr, SqlOption, Statement, Subscript, + SymbolDefinition, TableAlias, TableAliasColumnDef, TableConstraint, TableFactor, TableObject, + TableOptionsClustered, TableWithJoins, UpdateTableFromKind, Use, Value, Values, ViewColumnDef, + WildcardAdditionalOptions, With, WithFill, }; /// Given an iterator of spans, return the [Span::union] of all spans. diff --git a/src/dialect/mssql.rs b/src/dialect/mssql.rs index e9d5532b4..9f85c3fa5 100644 --- a/src/dialect/mssql.rs +++ b/src/dialect/mssql.rs @@ -15,8 +15,8 @@ // specific language governing permissions and limitations // under the License. -use crate::ast::{IfStatement, MsSqlIfStatements, Statement}; use crate::ast::helpers::attached_token::AttachedToken; +use crate::ast::{IfStatement, MsSqlIfStatements, Statement}; use crate::dialect::Dialect; use crate::keywords::{self, Keyword}; use crate::parser::{Parser, ParserError}; From 1d74a269c4e099edb09e5078d582f5fe53525f6a Mon Sep 17 00:00:00 2001 From: Roman Borschel Date: Fri, 4 Apr 2025 21:28:35 +0200 Subject: [PATCH 4/9] Avoid dialect-specific AST nodes. --- src/ast/mod.rs | 235 ++++++++++++++++++-------------------- src/ast/spans.rs | 88 +++++++------- src/dialect/mssql.rs | 67 ++++++++--- src/parser/mod.rs | 42 ++++--- tests/sqlparser_common.rs | 10 +- tests/sqlparser_mssql.rs | 38 +++--- 6 files changed, 254 insertions(+), 226 deletions(-) diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 89e3ef7ea..dad29d17e 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -2122,8 +2122,8 @@ pub struct CaseStatement { /// The `CASE` token that starts the statement. pub case_token: AttachedToken, pub match_expr: Option, - pub when_blocks: Vec, - pub else_block: Option, + pub when_blocks: Vec, + pub else_block: Option, /// The last token of the statement (`END` or `CASE`). pub end_case_token: AttachedToken, } @@ -2165,127 +2165,60 @@ impl fmt::Display for CaseStatement { } /// An `IF` statement. +/// +/// Example (BigQuery or Snowflake): +/// ```sql +/// IF TRUE THEN +/// SELECT 1; +/// SELECT 2; +/// ELSEIF TRUE THEN +/// SELECT 3; +/// ELSE +/// SELECT 4; +/// END IF +/// ``` +/// [BigQuery](https://cloud.google.com/bigquery/docs/reference/standard-sql/procedural-language#if) +/// [Snowflake](https://docs.snowflake.com/en/sql-reference/snowflake-scripting/if) +/// +/// Example (MSSQL): +/// ```sql +/// IF 1=1 SELECT 1 ELSE SELECT 2 +/// ``` +/// [MSSQL](https://learn.microsoft.com/en-us/sql/t-sql/language-elements/if-else-transact-sql?view=sql-server-ver16) #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] -pub enum IfStatement { - /// An `IF ... THEN [ELSE[IF] ...] END IF` statement. - /// - /// Example: - /// ```sql - /// IF TRUE THEN - /// SELECT 1; - /// SELECT 2; - /// ELSEIF TRUE THEN - /// SELECT 3; - /// ELSE - /// SELECT 4; - /// END IF - /// ``` - /// - /// [BigQuery](https://cloud.google.com/bigquery/docs/reference/standard-sql/procedural-language#if) - /// [Snowflake](https://docs.snowflake.com/en/sql-reference/snowflake-scripting/if) - IfThenElseEnd { - /// The `IF` token that starts the statement. - if_token: AttachedToken, - if_block: ConditionalStatements, - elseif_blocks: Vec, - else_block: Option, - /// The `IF` token that ends the statement. - end_if_token: AttachedToken, - }, - /// An MSSQL `IF ... ELSE ...` statement. - /// - /// Example: - /// ```sql - /// IF 1=1 SELECT 1 ELSE SELECT 2 - /// ``` - /// - /// [MSSQL](https://learn.microsoft.com/en-us/sql/t-sql/language-elements/if-else-transact-sql?view=sql-server-ver16) - MsSqlIfElse { - if_token: AttachedToken, - condition: Expr, - if_statements: MsSqlIfStatements, - else_statements: Option, - }, +pub struct IfStatement { + pub if_block: ConditionalStatementBlock, + pub elseif_blocks: Vec, + pub else_block: Option, + pub end_token: Option, } impl fmt::Display for IfStatement { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - IfStatement::IfThenElseEnd { - if_token: _, - if_block, - elseif_blocks, - else_block, - end_if_token: _, - } => { - write!(f, "{if_block}")?; - - if !elseif_blocks.is_empty() { - write!(f, " {}", display_separated(elseif_blocks, " "))?; - } - - if let Some(else_block) = else_block { - write!(f, " {else_block}")?; - } - - write!(f, " END IF")?; - - Ok(()) - } - IfStatement::MsSqlIfElse { - if_token: _, - condition, - if_statements, - else_statements, - } => { - write!(f, "IF {condition} {if_statements}")?; + let IfStatement { + if_block, + elseif_blocks, + else_block, + end_token, + } = self; - if let Some(els) = else_statements { - write!(f, " ELSE {els}")?; - } + write!(f, "{if_block}")?; - Ok(()) - } + for elseif_block in elseif_blocks { + write!(f, " {elseif_block}")?; } - } -} -/// (MSSQL) Either a single [Statement] or a block of statements -/// enclosed in `BEGIN` and `END`. -#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] -pub enum MsSqlIfStatements { - /// A single statement. - Single(Box), - /// ```sql - /// A logical block of statements. - /// - /// BEGIN - /// ; - /// ; - /// ... - /// END - /// ``` - Block { - begin_token: AttachedToken, - statements: Vec, - end_token: AttachedToken, - }, -} + if let Some(else_block) = else_block { + write!(f, " {else_block}")?; + } -impl fmt::Display for MsSqlIfStatements { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - MsSqlIfStatements::Single(stmt) => stmt.fmt(f), - MsSqlIfStatements::Block { statements, .. } => { - write!(f, "BEGIN ")?; - format_statement_list(f, statements)?; - write!(f, " END") - } + if let Some(AttachedToken(end_token)) = end_token { + write!(f, " END {end_token}")?; } + + Ok(()) } } @@ -2308,40 +2241,88 @@ impl fmt::Display for MsSqlIfStatements { #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] -pub struct ConditionalStatements { - /// The start token of the conditional (`WHEN`, `IF`, `ELSEIF` or `ELSE`). +pub struct ConditionalStatementBlock { pub start_token: AttachedToken, - /// The condition expression. `None` for `ELSE` statements. pub condition: Option, - /// Statement list of the `THEN` clause. - pub statements: Vec, + pub then_token: Option, + pub conditional_statements: ConditionalStatements, } -impl fmt::Display for ConditionalStatements { +impl ConditionalStatementBlock { + pub fn statements(&self) -> &Vec { + self.conditional_statements.statements() + } +} + +impl fmt::Display for ConditionalStatementBlock { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let ConditionalStatements { - start_token: AttachedToken(start), + let ConditionalStatementBlock { + start_token: AttachedToken(start_token), condition, - statements, + then_token, + conditional_statements, } = self; - let keyword = &start.token; + write!(f, "{start_token}")?; - if let Some(expr) = condition { - write!(f, "{keyword} {expr} THEN")?; - } else { - write!(f, "{keyword}")?; + if let Some(condition) = condition { + write!(f, " {condition}")?; + } + + if then_token.is_some() { + write!(f, " THEN")?; } - if !statements.is_empty() { - write!(f, " ")?; - format_statement_list(f, statements)?; + if conditional_statements.statements().len() > 0 { + write!(f, " {conditional_statements}")?; } Ok(()) } } +/// A list of statements in a [ConditionalStatementBlock]. +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub enum ConditionalStatements { + /// SELECT 1; SELECT 2; SELECT 3; ... + Sequence { statements: Vec }, + /// BEGIN SELECT 1; SELECT 2; SELECT 3; ... END + BeginEnd { + begin_token: AttachedToken, + statements: Vec, + end_token: AttachedToken, + }, +} + +impl ConditionalStatements { + pub fn statements(&self) -> &Vec { + match self { + ConditionalStatements::Sequence { statements } => statements, + ConditionalStatements::BeginEnd { statements, .. } => statements, + } + } +} + +impl fmt::Display for ConditionalStatements { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + ConditionalStatements::Sequence { statements } => { + if statements.len() > 0 { + format_statement_list(f, statements)?; + } + Ok(()) + } + ConditionalStatements::BeginEnd { statements, .. } => { + write!(f, "BEGIN ")?; + format_statement_list(f, statements)?; + write!(f, " END") + } + } + } +} + /// A `RAISE` statement. /// /// Examples: diff --git a/src/ast/spans.rs b/src/ast/spans.rs index ff9545cdb..9e7bd51bc 100644 --- a/src/ast/spans.rs +++ b/src/ast/spans.rs @@ -24,19 +24,19 @@ use super::{ dcl::SecondaryRoles, value::ValueWithSpan, AccessExpr, AlterColumnOperation, AlterIndexOperation, AlterTableOperation, Array, Assignment, AssignmentTarget, AttachedToken, CaseStatement, CloseCursor, ClusteredIndex, ColumnDef, ColumnOption, ColumnOptionDef, - ConditionalStatements, ConflictTarget, ConnectBy, ConstraintCharacteristics, CopySource, - CreateIndex, CreateTable, CreateTableOptions, Cte, Delete, DoUpdate, ExceptSelectItem, - ExcludeSelectItem, Expr, ExprWithAlias, Fetch, FromTable, Function, FunctionArg, - FunctionArgExpr, FunctionArgumentClause, FunctionArgumentList, FunctionArguments, GroupByExpr, - HavingBound, IfStatement, IlikeSelectItem, Insert, Interpolate, InterpolateExpr, Join, - JoinConstraint, JoinOperator, JsonPath, JsonPathElem, LateralView, LimitClause, - MatchRecognizePattern, Measure, MsSqlIfStatements, NamedWindowDefinition, ObjectName, - ObjectNamePart, Offset, OnConflict, OnConflictAction, OnInsert, OrderBy, OrderByExpr, - OrderByKind, Partition, PivotValueSource, ProjectionSelect, Query, RaiseStatement, - RaiseStatementValue, ReferentialAction, RenameSelectItem, ReplaceSelectElement, - ReplaceSelectItem, Select, SelectInto, SelectItem, SetExpr, SqlOption, Statement, Subscript, - SymbolDefinition, TableAlias, TableAliasColumnDef, TableConstraint, TableFactor, TableObject, - TableOptionsClustered, TableWithJoins, UpdateTableFromKind, Use, Value, Values, ViewColumnDef, + ConditionalStatementBlock, ConditionalStatements, ConflictTarget, ConnectBy, + ConstraintCharacteristics, CopySource, CreateIndex, CreateTable, CreateTableOptions, Cte, + Delete, DoUpdate, ExceptSelectItem, ExcludeSelectItem, Expr, ExprWithAlias, Fetch, FromTable, + Function, FunctionArg, FunctionArgExpr, FunctionArgumentClause, FunctionArgumentList, + FunctionArguments, GroupByExpr, HavingBound, IfStatement, IlikeSelectItem, Insert, Interpolate, + InterpolateExpr, Join, JoinConstraint, JoinOperator, JsonPath, JsonPathElem, LateralView, + LimitClause, MatchRecognizePattern, Measure, NamedWindowDefinition, ObjectName, ObjectNamePart, + Offset, OnConflict, OnConflictAction, OnInsert, OrderBy, OrderByExpr, OrderByKind, Partition, + PivotValueSource, ProjectionSelect, Query, RaiseStatement, RaiseStatementValue, + ReferentialAction, RenameSelectItem, ReplaceSelectElement, ReplaceSelectItem, Select, + SelectInto, SelectItem, SetExpr, SqlOption, Statement, Subscript, SymbolDefinition, TableAlias, + TableAliasColumnDef, TableConstraint, TableFactor, TableObject, TableOptionsClustered, + TableWithJoins, UpdateTableFromKind, Use, Value, Values, ViewColumnDef, WildcardAdditionalOptions, With, WithFill, }; @@ -751,31 +751,34 @@ impl Spanned for CaseStatement { impl Spanned for IfStatement { fn span(&self) -> Span { - match self { - IfStatement::IfThenElseEnd { - if_token: AttachedToken(start), - end_if_token: AttachedToken(end), - .. - } => union_spans([start.span, end.span].into_iter()), - IfStatement::MsSqlIfElse { - if_token: AttachedToken(start), - if_statements, - else_statements, - .. - } => union_spans( - [start.span, if_statements.span()] - .into_iter() - .chain(else_statements.as_ref().into_iter().map(|s| s.span())), - ), - } + let IfStatement { + if_block, + elseif_blocks, + else_block, + end_token, + } = self; + + union_spans( + iter::once(if_block.span()) + .chain(elseif_blocks.iter().map(|b| b.span())) + .chain(else_block.as_ref().map(|b| b.span()).into_iter()) + .chain( + end_token + .as_ref() + .map(|AttachedToken(t)| t.span) + .into_iter(), + ), + ) } } -impl Spanned for MsSqlIfStatements { +impl Spanned for ConditionalStatements { fn span(&self) -> Span { match self { - MsSqlIfStatements::Single(s) => s.span(), - MsSqlIfStatements::Block { + ConditionalStatements::Sequence { statements } => { + union_spans(statements.iter().map(|s| s.span())) + } + ConditionalStatements::BeginEnd { begin_token: AttachedToken(start), end_token: AttachedToken(end), .. @@ -784,18 +787,25 @@ impl Spanned for MsSqlIfStatements { } } -impl Spanned for ConditionalStatements { +impl Spanned for ConditionalStatementBlock { fn span(&self) -> Span { - let ConditionalStatements { - start_token: AttachedToken(start), + let ConditionalStatementBlock { + start_token: AttachedToken(start_token), condition, - statements, + then_token, + conditional_statements, } = self; union_spans( - iter::once(start.span) + iter::once(start_token.span) .chain(condition.as_ref().map(|c| c.span()).into_iter()) - .chain(statements.iter().map(|s| s.span())), + .chain( + then_token + .as_ref() + .map(|AttachedToken(t)| t.span) + .into_iter(), + ) + .chain(iter::once(conditional_statements.span())), ) } } diff --git a/src/dialect/mssql.rs b/src/dialect/mssql.rs index 9f85c3fa5..31886d7af 100644 --- a/src/dialect/mssql.rs +++ b/src/dialect/mssql.rs @@ -16,7 +16,7 @@ // under the License. use crate::ast::helpers::attached_token::AttachedToken; -use crate::ast::{IfStatement, MsSqlIfStatements, Statement}; +use crate::ast::{ConditionalStatementBlock, ConditionalStatements, IfStatement, Statement}; use crate::dialect::Dialect; use crate::keywords::{self, Keyword}; use crate::parser::{Parser, ParserError}; @@ -139,43 +139,72 @@ impl MsSqlDialect { let condition = parser.parse_expr()?; - let if_statements; + let if_block; if parser.peek_keyword(Keyword::BEGIN) { let begin_token = parser.expect_keyword(Keyword::BEGIN)?; let statements = self.parse_statement_list(parser, Some(Keyword::END))?; let end_token = parser.expect_keyword(Keyword::END)?; - if_statements = MsSqlIfStatements::Block { - begin_token: AttachedToken(begin_token), - statements, - end_token: AttachedToken(end_token), + if_block = ConditionalStatementBlock { + start_token: AttachedToken(if_token), + condition: Some(condition), + then_token: None, + conditional_statements: ConditionalStatements::BeginEnd { + begin_token: AttachedToken(begin_token), + statements, + end_token: AttachedToken(end_token), + }, }; } else { let stmt = parser.parse_statement()?; - if_statements = MsSqlIfStatements::Single(Box::new(stmt)); + if_block = ConditionalStatementBlock { + start_token: AttachedToken(if_token), + condition: Some(condition), + then_token: None, + conditional_statements: ConditionalStatements::Sequence { + statements: vec![stmt], + }, + }; + } + + while let Token::SemiColon = parser.peek_token_ref().token { + parser.advance_token(); } - let mut else_statements = None; - if parser.parse_keyword(Keyword::ELSE) { + let mut else_block = None; + if parser.peek_keyword(Keyword::ELSE) { + let else_token = parser.expect_keyword(Keyword::ELSE)?; if parser.peek_keyword(Keyword::BEGIN) { let begin_token = parser.expect_keyword(Keyword::BEGIN)?; let statements = self.parse_statement_list(parser, Some(Keyword::END))?; let end_token = parser.expect_keyword(Keyword::END)?; - else_statements = Some(MsSqlIfStatements::Block { - begin_token: AttachedToken(begin_token), - statements, - end_token: AttachedToken(end_token), + else_block = Some(ConditionalStatementBlock { + start_token: AttachedToken(else_token), + condition: None, + then_token: None, + conditional_statements: ConditionalStatements::BeginEnd { + begin_token: AttachedToken(begin_token), + statements, + end_token: AttachedToken(end_token), + }, }); } else { let stmt = parser.parse_statement()?; - else_statements = Some(MsSqlIfStatements::Single(Box::new(stmt))); + else_block = Some(ConditionalStatementBlock { + start_token: AttachedToken(else_token), + condition: None, + then_token: None, + conditional_statements: ConditionalStatements::Sequence { + statements: vec![stmt], + }, + }); } } - Ok(Statement::If(IfStatement::MsSqlIfElse { - if_token: AttachedToken(if_token), - condition, - if_statements, - else_statements, + Ok(Statement::If(IfStatement { + if_block, + else_block, + elseif_blocks: Vec::new(), + end_token: None, })) } diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 47ad5c8a0..4720058a4 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -641,11 +641,11 @@ impl<'a> Parser<'a> { self.expect_keyword_is(Keyword::WHEN)?; let when_blocks = self.parse_keyword_separated(Keyword::WHEN, |parser| { - parser.parse_conditional_statements(&[Keyword::WHEN, Keyword::ELSE, Keyword::END]) + parser.parse_conditional_statement_block(&[Keyword::WHEN, Keyword::ELSE, Keyword::END]) })?; let else_block = if self.parse_keyword(Keyword::ELSE) { - Some(self.parse_conditional_statements(&[Keyword::END])?) + Some(self.parse_conditional_statement_block(&[Keyword::END])?) } else { None }; @@ -668,33 +668,39 @@ impl<'a> Parser<'a> { /// /// See [Statement::If] pub fn parse_if_stmt(&mut self) -> Result { - let if_token = self.expect_keyword(Keyword::IF)?; - let if_block = - self.parse_conditional_statements(&[Keyword::ELSE, Keyword::ELSEIF, Keyword::END])?; + self.expect_keyword_is(Keyword::IF)?; + let if_block = self.parse_conditional_statement_block(&[ + Keyword::ELSE, + Keyword::ELSEIF, + Keyword::END, + ])?; let elseif_blocks = if self.parse_keyword(Keyword::ELSEIF) { self.parse_keyword_separated(Keyword::ELSEIF, |parser| { - parser.parse_conditional_statements(&[Keyword::ELSEIF, Keyword::ELSE, Keyword::END]) + parser.parse_conditional_statement_block(&[ + Keyword::ELSEIF, + Keyword::ELSE, + Keyword::END, + ]) })? } else { vec![] }; let else_block = if self.parse_keyword(Keyword::ELSE) { - Some(self.parse_conditional_statements(&[Keyword::END])?) + Some(self.parse_conditional_statement_block(&[Keyword::END])?) } else { None }; self.expect_keyword_is(Keyword::END)?; - let end_if_token = self.expect_keyword(Keyword::IF)?; + let end_token = self.expect_keyword(Keyword::IF)?; - Ok(Statement::If(IfStatement::IfThenElseEnd { - if_token: AttachedToken(if_token), + Ok(Statement::If(IfStatement { if_block, elseif_blocks, else_block, - end_if_token: AttachedToken(end_if_token), + end_token: Some(AttachedToken(end_token)), })) } @@ -705,27 +711,29 @@ impl<'a> Parser<'a> { /// ```sql /// IF condition THEN statement1; statement2; /// ``` - fn parse_conditional_statements( + fn parse_conditional_statement_block( &mut self, terminal_keywords: &[Keyword], - ) -> Result { - let start_token = self.get_current_token().clone(); + ) -> Result { + let start_token = self.get_current_token().clone(); // self.expect_keyword(keyword)?; + let mut then_token = None; let condition = match &start_token.token { Token::Word(w) if w.keyword == Keyword::ELSE => None, _ => { let expr = self.parse_expr()?; - self.expect_keyword_is(Keyword::THEN)?; + then_token = Some(AttachedToken(self.expect_keyword(Keyword::THEN)?)); Some(expr) } }; let statements = self.parse_statement_list(terminal_keywords)?; - Ok(ConditionalStatements { + Ok(ConditionalStatementBlock { start_token: AttachedToken(start_token), condition, - statements, + then_token, + conditional_statements: ConditionalStatements::Sequence { statements }, }) } diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 5914a3dea..66215de86 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -14219,8 +14219,8 @@ fn parse_case_statement() { Some(Expr::value(number("2"))), stmt.when_blocks[0].condition ); - assert_eq!(2, stmt.when_blocks[0].statements.len()); - assert_eq!(1, stmt.else_block.unwrap().statements.len()); + assert_eq!(2, stmt.when_blocks[0].statements().len()); + assert_eq!(1, stmt.else_block.unwrap().statements().len()); verified_stmt(concat!( "CASE 1", @@ -14278,7 +14278,7 @@ fn parse_if_statement() { let dialects = all_dialects_except(|d| d.is::()); let sql = "IF 1 THEN SELECT 1; ELSEIF 2 THEN SELECT 2; ELSE SELECT 3; END IF"; - let Statement::If(IfStatement::IfThenElseEnd { + let Statement::If(IfStatement { if_block, elseif_blocks, else_block, @@ -14289,7 +14289,7 @@ fn parse_if_statement() { }; assert_eq!(Some(Expr::value(number("1"))), if_block.condition); assert_eq!(Some(Expr::value(number("2"))), elseif_blocks[0].condition); - assert_eq!(1, else_block.unwrap().statements.len()); + assert_eq!(1, else_block.unwrap().statements().len()); dialects.verified_stmt(concat!( "IF 1 THEN", @@ -14376,7 +14376,7 @@ fn test_conditional_statement_span() { let sql = "IF 1=1 THEN SELECT 1; ELSEIF 1=2 THEN SELECT 2; ELSE SELECT 3; END IF"; let mut parser = Parser::new(&GenericDialect {}).try_with_sql(sql).unwrap(); match parser.parse_statement().unwrap() { - Statement::If(IfStatement::IfThenElseEnd { + Statement::If(IfStatement { if_block, elseif_blocks, else_block, diff --git a/tests/sqlparser_mssql.rs b/tests/sqlparser_mssql.rs index 4c13760c3..c25f614a5 100644 --- a/tests/sqlparser_mssql.rs +++ b/tests/sqlparser_mssql.rs @@ -1860,13 +1860,13 @@ fn parse_mssql_set_session_value() { #[test] fn parse_mssql_if_else() { // Simple statements and blocks - ms().verified_stmt("IF 1 = 1 SELECT '1' ELSE SELECT '2'"); - ms().verified_stmt("IF 1 = 1 BEGIN SET @A = 1; END ELSE SET @A = 2"); + ms().verified_stmt("IF 1 = 1 SELECT '1'; ELSE SELECT '2';"); + ms().verified_stmt("IF 1 = 1 BEGIN SET @A = 1; END ELSE SET @A = 2;"); ms().verified_stmt( - "IF DATENAME(weekday, GETDATE()) IN (N'Saturday', N'Sunday') SELECT 'Weekend' ELSE SELECT 'Weekday'" + "IF DATENAME(weekday, GETDATE()) IN (N'Saturday', N'Sunday') SELECT 'Weekend'; ELSE SELECT 'Weekday';" ); ms().verified_stmt( - "IF (SELECT COUNT(*) FROM a.b WHERE c LIKE 'x%') > 1 SELECT 'yes' ELSE SELECT 'No'", + "IF (SELECT COUNT(*) FROM a.b WHERE c LIKE 'x%') > 1 SELECT 'yes'; ELSE SELECT 'No';", ); // Multiple statements @@ -1877,7 +1877,7 @@ fn parse_mssql_if_else() { [Statement::Declare { .. }, Statement::If(stmt)] => { assert_eq!( stmt.to_string(), - "IF 1 = 1 BEGIN SET @A = 1; END ELSE SET @A = 2" + "IF 1 = 1 BEGIN SET @A = 1; END ELSE SET @A = 2;" ); } _ => panic!("Unexpected statements: {:?}", stmts), @@ -1916,18 +1916,18 @@ fn test_mssql_if_statements_span() { let mut sql = "IF 1 = 1 SELECT '1' ELSE SELECT '2'"; let mut parser = Parser::new(&MsSqlDialect {}).try_with_sql(sql).unwrap(); match parser.parse_statement().unwrap() { - Statement::If(IfStatement::MsSqlIfElse { - if_statements, - else_statements: Some(else_statements), + Statement::If(IfStatement { + if_block, + else_block: Some(else_block), .. }) => { assert_eq!( - if_statements.span(), - Span::new(Location::new(1, 10), Location::new(1, 20)) + if_block.span(), + Span::new(Location::new(1, 1), Location::new(1, 20)) ); assert_eq!( - else_statements.span(), - Span::new(Location::new(1, 26), Location::new(1, 36)) + else_block.span(), + Span::new(Location::new(1, 21), Location::new(1, 36)) ); } stmt => panic!("Unexpected statement: {:?}", stmt), @@ -1937,18 +1937,18 @@ fn test_mssql_if_statements_span() { sql = "IF 1 = 1 BEGIN SET @A = 1; END ELSE BEGIN SET @A = 2 END"; parser = Parser::new(&MsSqlDialect {}).try_with_sql(sql).unwrap(); match parser.parse_statement().unwrap() { - Statement::If(IfStatement::MsSqlIfElse { - if_statements, - else_statements: Some(else_statements), + Statement::If(IfStatement { + if_block, + else_block: Some(else_block), .. }) => { assert_eq!( - if_statements.span(), - Span::new(Location::new(1, 10), Location::new(1, 31)) + if_block.span(), + Span::new(Location::new(1, 1), Location::new(1, 31)) ); assert_eq!( - else_statements.span(), - Span::new(Location::new(1, 37), Location::new(1, 57)) + else_block.span(), + Span::new(Location::new(1, 32), Location::new(1, 57)) ); } stmt => panic!("Unexpected statement: {:?}", stmt), From 90acac011d6a459a1e0ac188f8a67b002d92b550 Mon Sep 17 00:00:00 2001 From: Roman Borschel Date: Fri, 4 Apr 2025 21:46:08 +0200 Subject: [PATCH 5/9] Address linter errors. --- src/ast/mod.rs | 4 ++-- src/ast/spans.rs | 18 ++++-------------- src/dialect/mssql.rs | 13 ++++++------- 3 files changed, 12 insertions(+), 23 deletions(-) diff --git a/src/ast/mod.rs b/src/ast/mod.rs index dad29d17e..7421267eb 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -2273,7 +2273,7 @@ impl fmt::Display for ConditionalStatementBlock { write!(f, " THEN")?; } - if conditional_statements.statements().len() > 0 { + if !conditional_statements.statements().is_empty() { write!(f, " {conditional_statements}")?; } @@ -2309,7 +2309,7 @@ impl fmt::Display for ConditionalStatements { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { ConditionalStatements::Sequence { statements } => { - if statements.len() > 0 { + if !statements.is_empty() { format_statement_list(f, statements)?; } Ok(()) diff --git a/src/ast/spans.rs b/src/ast/spans.rs index 9e7bd51bc..68001e4c0 100644 --- a/src/ast/spans.rs +++ b/src/ast/spans.rs @@ -761,13 +761,8 @@ impl Spanned for IfStatement { union_spans( iter::once(if_block.span()) .chain(elseif_blocks.iter().map(|b| b.span())) - .chain(else_block.as_ref().map(|b| b.span()).into_iter()) - .chain( - end_token - .as_ref() - .map(|AttachedToken(t)| t.span) - .into_iter(), - ), + .chain(else_block.as_ref().map(|b| b.span())) + .chain(end_token.as_ref().map(|AttachedToken(t)| t.span)), ) } } @@ -798,13 +793,8 @@ impl Spanned for ConditionalStatementBlock { union_spans( iter::once(start_token.span) - .chain(condition.as_ref().map(|c| c.span()).into_iter()) - .chain( - then_token - .as_ref() - .map(|AttachedToken(t)| t.span) - .into_iter(), - ) + .chain(condition.as_ref().map(|c| c.span())) + .chain(then_token.as_ref().map(|AttachedToken(t)| t.span)) .chain(iter::once(conditional_statements.span())), ) } diff --git a/src/dialect/mssql.rs b/src/dialect/mssql.rs index 31886d7af..522234efe 100644 --- a/src/dialect/mssql.rs +++ b/src/dialect/mssql.rs @@ -139,12 +139,11 @@ impl MsSqlDialect { let condition = parser.parse_expr()?; - let if_block; - if parser.peek_keyword(Keyword::BEGIN) { + let if_block = if parser.peek_keyword(Keyword::BEGIN) { let begin_token = parser.expect_keyword(Keyword::BEGIN)?; let statements = self.parse_statement_list(parser, Some(Keyword::END))?; let end_token = parser.expect_keyword(Keyword::END)?; - if_block = ConditionalStatementBlock { + ConditionalStatementBlock { start_token: AttachedToken(if_token), condition: Some(condition), then_token: None, @@ -153,18 +152,18 @@ impl MsSqlDialect { statements, end_token: AttachedToken(end_token), }, - }; + } } else { let stmt = parser.parse_statement()?; - if_block = ConditionalStatementBlock { + ConditionalStatementBlock { start_token: AttachedToken(if_token), condition: Some(condition), then_token: None, conditional_statements: ConditionalStatements::Sequence { statements: vec![stmt], }, - }; - } + } + }; while let Token::SemiColon = parser.peek_token_ref().token { parser.advance_token(); From 6a0deb62a14e2650b96a563da37505fcc02d23e6 Mon Sep 17 00:00:00 2001 From: Roman Borschel Date: Fri, 4 Apr 2025 22:06:32 +0200 Subject: [PATCH 6/9] Fix no-std build. --- src/dialect/mssql.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/dialect/mssql.rs b/src/dialect/mssql.rs index 522234efe..8abb46eba 100644 --- a/src/dialect/mssql.rs +++ b/src/dialect/mssql.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +#[cfg(not(feature = "std"))] +use alloc::{vec, vec::Vec}; use crate::ast::helpers::attached_token::AttachedToken; use crate::ast::{ConditionalStatementBlock, ConditionalStatements, IfStatement, Statement}; use crate::dialect::Dialect; From 647c84b0f7479b3a8332d5f70a07c9c144448b85 Mon Sep 17 00:00:00 2001 From: Roman Borschel Date: Fri, 4 Apr 2025 22:08:32 +0200 Subject: [PATCH 7/9] One more linter fix. --- src/dialect/mssql.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/dialect/mssql.rs b/src/dialect/mssql.rs index 8abb46eba..d86d68a20 100644 --- a/src/dialect/mssql.rs +++ b/src/dialect/mssql.rs @@ -15,14 +15,14 @@ // specific language governing permissions and limitations // under the License. -#[cfg(not(feature = "std"))] -use alloc::{vec, vec::Vec}; use crate::ast::helpers::attached_token::AttachedToken; use crate::ast::{ConditionalStatementBlock, ConditionalStatements, IfStatement, Statement}; use crate::dialect::Dialect; use crate::keywords::{self, Keyword}; use crate::parser::{Parser, ParserError}; use crate::tokenizer::Token; +#[cfg(not(feature = "std"))] +use alloc::{vec, vec::Vec}; const RESERVED_FOR_COLUMN_ALIAS: &[Keyword] = &[Keyword::IF, Keyword::ELSE]; From 535793d18c814afbaab7dd4effd2a4e29893911d Mon Sep 17 00:00:00 2001 From: Roman Borschel Date: Sat, 5 Apr 2025 20:06:20 +0200 Subject: [PATCH 8/9] Use exhaustive let-match in Spanned impls. --- src/ast/spans.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/ast/spans.rs b/src/ast/spans.rs index 68001e4c0..d253f8914 100644 --- a/src/ast/spans.rs +++ b/src/ast/spans.rs @@ -741,8 +741,10 @@ impl Spanned for CaseStatement { fn span(&self) -> Span { let CaseStatement { case_token: AttachedToken(start), + match_expr: _, + when_blocks: _, + else_block: _, end_case_token: AttachedToken(end), - .. } = self; union_spans([start.span, end.span].into_iter()) @@ -775,8 +777,8 @@ impl Spanned for ConditionalStatements { } ConditionalStatements::BeginEnd { begin_token: AttachedToken(start), + statements: _, end_token: AttachedToken(end), - .. } => union_spans([start.span, end.span].into_iter()), } } From 6a6e561932eb5d7761fce8a937a197477f04b99f Mon Sep 17 00:00:00 2001 From: Roman Borschel Date: Sat, 5 Apr 2025 20:12:49 +0200 Subject: [PATCH 9/9] Fix (my own) lint error from a previous PR. --- tests/sqlparser_redshift.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/sqlparser_redshift.rs b/tests/sqlparser_redshift.rs index c75abe16f..060e3853d 100644 --- a/tests/sqlparser_redshift.rs +++ b/tests/sqlparser_redshift.rs @@ -395,5 +395,5 @@ fn test_parse_nested_quoted_identifier() { #[test] fn parse_extract_single_quotes() { let sql = "SELECT EXTRACT('month' FROM my_timestamp) FROM my_table"; - redshift().verified_stmt(&sql); + redshift().verified_stmt(sql); }