Skip to content

Improve support for cursors for SQL Server #1831

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
May 2, 2025
90 changes: 86 additions & 4 deletions src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2226,7 +2226,33 @@ impl fmt::Display for IfStatement {
}
}

/// A block within a [Statement::Case] or [Statement::If]-like statement
/// A `WHILE` statement.
///
/// Example:
/// ```sql
/// WHILE @@FETCH_STATUS = 0
/// BEGIN
/// FETCH NEXT FROM c1 INTO @var1, @var2;
/// END
/// ```
///
/// [MsSql](https://learn.microsoft.com/en-us/sql/t-sql/language-elements/while-transact-sql)
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct WhileStatement {
pub while_block: ConditionalStatementBlock,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't absolutely need a WhileStatement struct; we could be doing Statement::While(ConditionalStatementBlock) instead. I'm following the example of CASE & IF, which also do it this way.

}

impl fmt::Display for WhileStatement {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let WhileStatement { while_block } = self;
write!(f, "{while_block}")?;
Ok(())
}
}

/// A block within a [Statement::Case] or [Statement::If] or [Statement::While]-like statement
///
/// Example 1:
/// ```sql
Expand All @@ -2242,6 +2268,14 @@ impl fmt::Display for IfStatement {
/// ```sql
/// ELSE SELECT 1; SELECT 2;
/// ```
///
/// Example 4:
/// ```sql
/// WHILE @@FETCH_STATUS = 0
/// BEGIN
/// FETCH NEXT FROM c1 INTO @var1, @var2;
/// END
/// ```
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
Expand Down Expand Up @@ -2981,6 +3015,8 @@ pub enum Statement {
Case(CaseStatement),
/// An `IF` statement.
If(IfStatement),
/// A `WHILE` statement.
While(WhileStatement),
/// A `RAISE` statement.
Raise(RaiseStatement),
/// ```sql
Expand Down Expand Up @@ -3032,6 +3068,11 @@ pub enum Statement {
partition: Option<Box<Expr>>,
},
/// ```sql
/// OPEN cursor_name
/// ```
/// Opens a cursor.
Open(OpenStatement),
/// ```sql
/// CLOSE
/// ```
/// Closes the portal underlying an open cursor.
Expand Down Expand Up @@ -3403,6 +3444,7 @@ pub enum Statement {
/// Cursor name
name: Ident,
direction: FetchDirection,
position: FetchPosition,
/// Optional, It's possible to fetch rows form cursor to the table
into: Option<ObjectName>,
},
Expand Down Expand Up @@ -4225,11 +4267,10 @@ impl fmt::Display for Statement {
Statement::Fetch {
name,
direction,
position,
into,
} => {
write!(f, "FETCH {direction} ")?;

write!(f, "IN {name}")?;
write!(f, "FETCH {direction} {position} {name}")?;
Comment on lines -4230 to +4273
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could probably be write!(f, "{position} {name}")?;, not sure what the pro/con is on that


if let Some(into) = into {
write!(f, " INTO {into}")?;
Expand Down Expand Up @@ -4319,6 +4360,9 @@ impl fmt::Display for Statement {
Statement::If(stmt) => {
write!(f, "{stmt}")
}
Statement::While(stmt) => {
write!(f, "{stmt}")
}
Statement::Raise(stmt) => {
write!(f, "{stmt}")
}
Expand Down Expand Up @@ -4488,6 +4532,7 @@ impl fmt::Display for Statement {
Ok(())
}
Statement::Delete(delete) => write!(f, "{delete}"),
Statement::Open(open) => write!(f, "{open}"),
Statement::Close { cursor } => {
write!(f, "CLOSE {cursor}")?;

Expand Down Expand Up @@ -6162,6 +6207,28 @@ impl fmt::Display for FetchDirection {
}
}

/// The "position" for a FETCH statement.
///
/// [MsSql](https://learn.microsoft.com/en-us/sql/t-sql/language-elements/fetch-transact-sql)
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum FetchPosition {
From,
In,
}

impl fmt::Display for FetchPosition {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
FetchPosition::From => f.write_str("FROM")?,
FetchPosition::In => f.write_str("IN")?,
};

Ok(())
}
}

/// A privilege on a database object (table, sequence, etc.).
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
Expand Down Expand Up @@ -9316,6 +9383,21 @@ pub enum ReturnStatementValue {
Expr(Expr),
}

/// Represents an `OPEN` statement.
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct OpenStatement {
/// Cursor name
pub cursor_name: Ident,
}

impl fmt::Display for OpenStatement {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "OPEN {}", self.cursor_name)
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
31 changes: 24 additions & 7 deletions src/ast/spans.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@ use super::{
FunctionArguments, GroupByExpr, HavingBound, IfStatement, IlikeSelectItem, Insert, Interpolate,
InterpolateExpr, Join, JoinConstraint, JoinOperator, JsonPath, JsonPathElem, LateralView,
LimitClause, MatchRecognizePattern, Measure, NamedWindowDefinition, ObjectName, ObjectNamePart,
Offset, OnConflict, OnConflictAction, OnInsert, OrderBy, OrderByExpr, OrderByKind, Partition,
PivotValueSource, ProjectionSelect, Query, RaiseStatement, RaiseStatementValue,
ReferentialAction, RenameSelectItem, ReplaceSelectElement, ReplaceSelectItem, Select,
SelectInto, SelectItem, SetExpr, SqlOption, Statement, Subscript, SymbolDefinition, TableAlias,
TableAliasColumnDef, TableConstraint, TableFactor, TableObject, TableOptionsClustered,
TableWithJoins, UpdateTableFromKind, Use, Value, Values, ViewColumnDef,
WildcardAdditionalOptions, With, WithFill,
Offset, OnConflict, OnConflictAction, OnInsert, OpenStatement, OrderBy, OrderByExpr,
OrderByKind, Partition, PivotValueSource, ProjectionSelect, Query, RaiseStatement,
RaiseStatementValue, ReferentialAction, RenameSelectItem, ReplaceSelectElement,
ReplaceSelectItem, Select, SelectInto, SelectItem, SetExpr, SqlOption, Statement, Subscript,
SymbolDefinition, TableAlias, TableAliasColumnDef, TableConstraint, TableFactor, TableObject,
TableOptionsClustered, TableWithJoins, UpdateTableFromKind, Use, Value, Values, ViewColumnDef,
WhileStatement, WildcardAdditionalOptions, With, WithFill,
};

/// Given an iterator of spans, return the [Span::union] of all spans.
Expand Down Expand Up @@ -338,6 +338,7 @@ impl Spanned for Statement {
} => source.span(),
Statement::Case(stmt) => stmt.span(),
Statement::If(stmt) => stmt.span(),
Statement::While(stmt) => stmt.span(),
Statement::Raise(stmt) => stmt.span(),
Statement::Call(function) => function.span(),
Statement::Copy {
Expand All @@ -364,6 +365,7 @@ impl Spanned for Statement {
from_query: _,
partition: _,
} => Span::empty(),
Statement::Open(open) => open.span(),
Statement::Close { cursor } => match cursor {
CloseCursor::All => Span::empty(),
CloseCursor::Specific { name } => name.span,
Expand Down Expand Up @@ -774,6 +776,14 @@ impl Spanned for IfStatement {
}
}

impl Spanned for WhileStatement {
fn span(&self) -> Span {
let WhileStatement { while_block } = self;

while_block.span()
}
}

impl Spanned for ConditionalStatements {
fn span(&self) -> Span {
match self {
Expand Down Expand Up @@ -2295,6 +2305,13 @@ impl Spanned for BeginEndStatements {
}
}

impl Spanned for OpenStatement {
fn span(&self) -> Span {
let OpenStatement { cursor_name } = self;
cursor_name.span
}
}

#[cfg(test)]
pub mod tests {
use crate::dialect::{Dialect, GenericDialect, SnowflakeDialect};
Expand Down
2 changes: 2 additions & 0 deletions src/keywords.rs
Original file line number Diff line number Diff line change
Expand Up @@ -981,6 +981,7 @@ define_keywords!(
WHEN,
WHENEVER,
WHERE,
WHILE,
WIDTH_BUCKET,
WINDOW,
WITH,
Expand Down Expand Up @@ -1064,6 +1065,7 @@ pub const RESERVED_FOR_TABLE_ALIAS: &[Keyword] = &[
Keyword::SAMPLE,
Keyword::TABLESAMPLE,
Keyword::FROM,
Keyword::OPEN,
];

/// Can't be used as a column alias, so that `SELECT <expr> alias`
Expand Down
70 changes: 63 additions & 7 deletions src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,10 @@ impl<'a> Parser<'a> {
self.prev_token();
self.parse_if_stmt()
}
Keyword::WHILE => {
self.prev_token();
self.parse_while()
}
Keyword::RAISE => {
self.prev_token();
self.parse_raise_stmt()
Expand Down Expand Up @@ -570,6 +574,10 @@ impl<'a> Parser<'a> {
Keyword::ALTER => self.parse_alter(),
Keyword::CALL => self.parse_call(),
Keyword::COPY => self.parse_copy(),
Keyword::OPEN => {
self.prev_token();
self.parse_open()
}
Keyword::CLOSE => self.parse_close(),
Keyword::SET => self.parse_set(),
Keyword::SHOW => self.parse_show(),
Expand Down Expand Up @@ -700,8 +708,18 @@ impl<'a> Parser<'a> {
}))
}

/// Parse a `WHILE` statement.
///
/// See [Statement::While]
fn parse_while(&mut self) -> Result<Statement, ParserError> {
self.expect_keyword_is(Keyword::WHILE)?;
let while_block = self.parse_conditional_statement_block(&[Keyword::END])?;

Ok(Statement::While(WhileStatement { while_block }))
}

/// Parses an expression and associated list of statements
/// belonging to a conditional statement like `IF` or `WHEN`.
/// belonging to a conditional statement like `IF` or `WHEN` or `WHILE`.
///
/// Example:
/// ```sql
Expand All @@ -716,20 +734,36 @@ impl<'a> Parser<'a> {

let condition = match &start_token.token {
Token::Word(w) if w.keyword == Keyword::ELSE => None,
Token::Word(w) if w.keyword == Keyword::WHILE => {
let expr = self.parse_expr()?;
Some(expr)
}
_ => {
let expr = self.parse_expr()?;
then_token = Some(AttachedToken(self.expect_keyword(Keyword::THEN)?));
Some(expr)
}
};

let statements = self.parse_statement_list(terminal_keywords)?;
let conditional_statements = if self.peek_keyword(Keyword::BEGIN) {
let begin_token = self.expect_keyword(Keyword::BEGIN)?;
let statements = self.parse_statement_list(terminal_keywords)?;
let end_token = self.expect_keyword(Keyword::END)?;
Comment on lines +749 to +751
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We seem to have this pattern upcoming in a few places, like #1810 maybe it would be good to pull it out into a method and reuse it both here and the preexisting usage here? We can probably do so in the former PR instead

ConditionalStatements::BeginEnd(BeginEndStatements {
begin_token: AttachedToken(begin_token),
statements,
end_token: AttachedToken(end_token),
})
} else {
let statements = self.parse_statement_list(terminal_keywords)?;
ConditionalStatements::Sequence { statements }
};

Ok(ConditionalStatementBlock {
start_token: AttachedToken(start_token),
condition,
then_token,
conditional_statements: ConditionalStatements::Sequence { statements },
conditional_statements,
})
}

Expand Down Expand Up @@ -4448,11 +4482,16 @@ impl<'a> Parser<'a> {
) -> Result<Vec<Statement>, ParserError> {
let mut values = vec![];
loop {
if let Token::Word(w) = &self.peek_nth_token_ref(0).token {
if w.quote_style.is_none() && terminal_keywords.contains(&w.keyword) {
break;
match &self.peek_nth_token_ref(0).token {
Token::EOF => break,
Token::Word(w) => {
if w.quote_style.is_none() && terminal_keywords.contains(&w.keyword) {
break;
}
}
_ => {}
}

values.push(self.parse_statement()?);
self.expect_token(&Token::SemiColon)?;
}
Expand Down Expand Up @@ -6609,7 +6648,15 @@ impl<'a> Parser<'a> {
}
};

self.expect_one_of_keywords(&[Keyword::FROM, Keyword::IN])?;
let position = if self.peek_keyword(Keyword::FROM) {
self.expect_keyword(Keyword::FROM)?;
FetchPosition::From
} else if self.peek_keyword(Keyword::IN) {
self.expect_keyword(Keyword::IN)?;
FetchPosition::In
} else {
return parser_err!("Expected FROM or IN", self.peek_token().span.start);
};

let name = self.parse_identifier()?;

Expand All @@ -6622,6 +6669,7 @@ impl<'a> Parser<'a> {
Ok(Statement::Fetch {
name,
direction,
position,
into,
})
}
Expand Down Expand Up @@ -8735,6 +8783,14 @@ impl<'a> Parser<'a> {
})
}

/// Parse [Statement::Open]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe I missed it, we seem to be lacking test cases for the open statement feature?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's part of test_mssql_cursor, but I'll make a separate test function just for OPEN for clarity

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done 👍

fn parse_open(&mut self) -> Result<Statement, ParserError> {
self.expect_keyword(Keyword::OPEN)?;
Ok(Statement::Open(OpenStatement {
cursor_name: self.parse_identifier()?,
}))
}

pub fn parse_close(&mut self) -> Result<Statement, ParserError> {
let cursor = if self.parse_keyword(Keyword::ALL) {
CloseCursor::All
Expand Down
Loading