Skip to content

Add CREATE TRIGGER support for SQL Server #1810

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 39 additions & 9 deletions src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2380,11 +2380,16 @@ impl fmt::Display for BeginEndStatements {
end_token: AttachedToken(end_token),
} = self;

write!(f, "{begin_token} ")?;
if begin_token.token != Token::EOF {
write!(f, "{begin_token} ")?;
}
if !statements.is_empty() {
format_statement_list(f, statements)?;
}
write!(f, " {end_token}")
if end_token.token != Token::EOF {
write!(f, " {end_token}")?;
}
Ok(())
}
}

Expand Down Expand Up @@ -3729,7 +3734,12 @@ pub enum Statement {
/// ```
///
/// Postgres: <https://www.postgresql.org/docs/current/sql-createtrigger.html>
/// SQL Server: <https://learn.microsoft.com/en-us/sql/t-sql/statements/create-trigger-transact-sql>
CreateTrigger {
/// True if this is a `CREATE OR ALTER TRIGGER` statement
///
/// [MsSql](https://learn.microsoft.com/en-us/sql/t-sql/statements/create-trigger-transact-sql?view=sql-server-ver16#arguments)
or_alter: bool,
/// The `OR REPLACE` clause is used to re-create the trigger if it already exists.
///
/// Example:
Expand Down Expand Up @@ -3790,7 +3800,9 @@ pub enum Statement {
/// Triggering conditions
condition: Option<Expr>,
/// Execute logic block
exec_body: TriggerExecBody,
exec_body: Option<TriggerExecBody>,
/// For SQL dialects with statement(s) for a body
statements: Option<ConditionalStatements>,
/// The characteristic of the trigger, which include whether the trigger is `DEFERRABLE`, `INITIALLY DEFERRED`, or `INITIALLY IMMEDIATE`,
characteristics: Option<ConstraintCharacteristics>,
},
Expand Down Expand Up @@ -4587,6 +4599,7 @@ impl fmt::Display for Statement {
}
Statement::CreateFunction(create_function) => create_function.fmt(f),
Statement::CreateTrigger {
or_alter,
or_replace,
is_constraint,
name,
Expand All @@ -4599,19 +4612,30 @@ impl fmt::Display for Statement {
condition,
include_each,
exec_body,
statements,
characteristics,
} => {
write!(
f,
"CREATE {or_replace}{is_constraint}TRIGGER {name} {period}",
"CREATE {or_alter}{or_replace}{is_constraint}TRIGGER {name} ",
or_alter = if *or_alter { "OR ALTER " } else { "" },
or_replace = if *or_replace { "OR REPLACE " } else { "" },
is_constraint = if *is_constraint { "CONSTRAINT " } else { "" },
)?;

if !events.is_empty() {
write!(f, " {}", display_separated(events, " OR "))?;
if exec_body.is_some() {
write!(f, "{period}")?;
if !events.is_empty() {
write!(f, " {}", display_separated(events, " OR "))?;
}
write!(f, " ON {table_name}")?;
} else {
write!(f, "ON {table_name}")?;
write!(f, " {period}")?;
if !events.is_empty() {
write!(f, " {}", display_separated(events, ", "))?;
}
}
write!(f, " ON {table_name}")?;

if let Some(referenced_table_name) = referenced_table_name {
write!(f, " FROM {referenced_table_name}")?;
Expand All @@ -4627,13 +4651,19 @@ impl fmt::Display for Statement {

if *include_each {
write!(f, " FOR EACH {trigger_object}")?;
} else {
} else if exec_body.is_some() {
write!(f, " FOR {trigger_object}")?;
}
if let Some(condition) = condition {
write!(f, " WHEN {condition}")?;
}
write!(f, " EXECUTE {exec_body}")
if let Some(exec_body) = exec_body {
write!(f, " EXECUTE {exec_body}")?;
}
if let Some(statements) = statements {
write!(f, " AS {statements}")?;
}
Ok(())
}
Statement::DropTrigger {
if_exists,
Expand Down
2 changes: 2 additions & 0 deletions src/ast/trigger.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ impl fmt::Display for TriggerEvent {
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum TriggerPeriod {
For,
After,
Before,
InsteadOf,
Expand All @@ -118,6 +119,7 @@ pub enum TriggerPeriod {
impl fmt::Display for TriggerPeriod {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
TriggerPeriod::For => write!(f, "FOR"),
TriggerPeriod::After => write!(f, "AFTER"),
TriggerPeriod::Before => write!(f, "BEFORE"),
TriggerPeriod::InsteadOf => write!(f, "INSTEAD OF"),
Expand Down
46 changes: 46 additions & 0 deletions src/dialect/mssql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
use crate::ast::helpers::attached_token::AttachedToken;
use crate::ast::{
BeginEndStatements, ConditionalStatementBlock, ConditionalStatements, IfStatement, Statement,
TriggerObject,
};
use crate::dialect::Dialect;
use crate::keywords::{self, Keyword};
Expand Down Expand Up @@ -125,6 +126,15 @@ impl Dialect for MsSqlDialect {
fn parse_statement(&self, parser: &mut Parser) -> Option<Result<Statement, ParserError>> {
if parser.peek_keyword(Keyword::IF) {
Some(self.parse_if_stmt(parser))
} else if parser.parse_keywords(&[Keyword::CREATE, Keyword::TRIGGER]) {
Some(self.parse_create_trigger(parser, false))
} else if parser.parse_keywords(&[
Keyword::CREATE,
Keyword::OR,
Keyword::ALTER,
Keyword::TRIGGER,
]) {
Some(self.parse_create_trigger(parser, true))
} else {
None
}
Expand Down Expand Up @@ -215,6 +225,42 @@ impl MsSqlDialect {
}))
}

/// Parse `CREATE TRIGGER` for [MsSql]
///
/// [MsSql]: https://learn.microsoft.com/en-us/sql/t-sql/statements/create-trigger-transact-sql
fn parse_create_trigger(
&self,
parser: &mut Parser,
or_alter: bool,
) -> Result<Statement, ParserError> {
let name = parser.parse_object_name(false)?;
parser.expect_keyword_is(Keyword::ON)?;
let table_name = parser.parse_object_name(false)?;
let period = parser.parse_trigger_period()?;
let events = parser.parse_comma_separated(Parser::parse_trigger_event)?;

parser.expect_keyword_is(Keyword::AS)?;
let statements = Some(parser.parse_conditional_statements(&[Keyword::END])?);

Ok(Statement::CreateTrigger {
or_alter,
or_replace: false,
is_constraint: false,
name,
period,
events,
table_name,
referenced_table_name: None,
referencing: Vec::new(),
trigger_object: TriggerObject::Statement,
include_each: false,
condition: None,
exec_body: None,
statements,
characteristics: None,
})
}

/// Parse a sequence of statements, optionally separated by semicolon.
///
/// Stops parsing when reaching EOF or the given keyword.
Expand Down
43 changes: 30 additions & 13 deletions src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -745,26 +745,38 @@ impl<'a> Parser<'a> {
}
};

let conditional_statements = self.parse_conditional_statements(terminal_keywords)?;

Ok(ConditionalStatementBlock {
start_token: AttachedToken(start_token),
condition,
then_token,
conditional_statements,
})
}

/// Parse a BEGIN/END block or a sequence of statements
/// This could be inside of a conditional (IF, CASE, WHILE etc.) or an object body defined optionally BEGIN/END and one or more statements.
pub(crate) fn parse_conditional_statements(
&mut self,
terminal_keywords: &[Keyword],
) -> Result<ConditionalStatements, ParserError> {
let conditional_statements = if self.peek_keyword(Keyword::BEGIN) {
let begin_token = self.expect_keyword(Keyword::BEGIN)?;
let statements = self.parse_statement_list(terminal_keywords)?;
let end_token = self.expect_keyword(Keyword::END)?;

ConditionalStatements::BeginEnd(BeginEndStatements {
begin_token: AttachedToken(begin_token),
statements,
end_token: AttachedToken(end_token),
})
} else {
let statements = self.parse_statement_list(terminal_keywords)?;
ConditionalStatements::Sequence { statements }
ConditionalStatements::Sequence {
statements: self.parse_statement_list(terminal_keywords)?,
}
};

Ok(ConditionalStatementBlock {
start_token: AttachedToken(start_token),
condition,
then_token,
conditional_statements,
})
Ok(conditional_statements)
}

/// Parse a `RAISE` statement.
Expand Down Expand Up @@ -4614,9 +4626,9 @@ impl<'a> Parser<'a> {
} else if self.parse_keyword(Keyword::FUNCTION) {
self.parse_create_function(or_alter, or_replace, temporary)
} else if self.parse_keyword(Keyword::TRIGGER) {
self.parse_create_trigger(or_replace, false)
self.parse_create_trigger(or_alter, or_replace, false)
} else if self.parse_keywords(&[Keyword::CONSTRAINT, Keyword::TRIGGER]) {
self.parse_create_trigger(or_replace, true)
self.parse_create_trigger(or_alter, or_replace, true)
} else if self.parse_keyword(Keyword::MACRO) {
self.parse_create_macro(or_replace, temporary)
} else if self.parse_keyword(Keyword::SECRET) {
Expand Down Expand Up @@ -5314,10 +5326,11 @@ impl<'a> Parser<'a> {

pub fn parse_create_trigger(
&mut self,
or_alter: bool,
or_replace: bool,
is_constraint: bool,
) -> Result<Statement, ParserError> {
if !dialect_of!(self is PostgreSqlDialect | GenericDialect | MySqlDialect) {
if !dialect_of!(self is PostgreSqlDialect | GenericDialect | MySqlDialect | MsSqlDialect) {
self.prev_token();
return self.expected("an object type after CREATE", self.peek_token());
}
Expand Down Expand Up @@ -5363,6 +5376,7 @@ impl<'a> Parser<'a> {
let exec_body = self.parse_trigger_exec_body()?;

Ok(Statement::CreateTrigger {
or_alter,
or_replace,
is_constraint,
name,
Expand All @@ -5374,18 +5388,21 @@ impl<'a> Parser<'a> {
trigger_object,
include_each,
condition,
exec_body,
exec_body: Some(exec_body),
statements: None,
characteristics,
})
}

pub fn parse_trigger_period(&mut self) -> Result<TriggerPeriod, ParserError> {
Ok(
match self.expect_one_of_keywords(&[
Keyword::FOR,
Keyword::BEFORE,
Keyword::AFTER,
Keyword::INSTEAD,
])? {
Keyword::FOR => TriggerPeriod::For,
Keyword::BEFORE => TriggerPeriod::Before,
Keyword::AFTER => TriggerPeriod::After,
Keyword::INSTEAD => self
Expand Down
Loading