diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 865577cae..7b143349b 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -873,12 +873,28 @@ impl fmt::Display for Assignment { } } +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum FunctionArg { + Named { name: Ident, arg: Expr }, + Unnamed(Expr), +} + +impl fmt::Display for FunctionArg { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + FunctionArg::Named { name, arg } => write!(f, "{} => {}", name, arg), + FunctionArg::Unnamed(unnamed_arg) => write!(f, "{}", unnamed_arg), + } + } +} + /// A function call #[derive(Debug, Clone, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct Function { pub name: ObjectName, - pub args: Vec, + pub args: Vec, pub over: Option, // aggregate functions may specify eg `COUNT(DISTINCT x)` pub distinct: bool, diff --git a/src/ast/query.rs b/src/ast/query.rs index 73477b126..ce57fcf7b 100644 --- a/src/ast/query.rs +++ b/src/ast/query.rs @@ -226,7 +226,7 @@ pub enum TableFactor { /// Arguments of a table-valued function, as supported by Postgres /// and MSSQL. Note that deprecated MSSQL `FROM foo (NOLOCK)` syntax /// will also be parsed as `args`. - args: Vec, + args: Vec, /// MSSQL-specific `WITH (...)` hints such as NOLOCK. with_hints: Vec, }, diff --git a/src/parser.rs b/src/parser.rs index 1c3c4eaf9..625a424fb 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -2203,11 +2203,24 @@ impl Parser { Ok(Assignment { id, value }) } - pub fn parse_optional_args(&mut self) -> Result, ParserError> { + fn parse_function_args(&mut self) -> Result { + if self.peek_nth_token(1) == Token::RArrow { + let name = self.parse_identifier()?; + + self.expect_token(&Token::RArrow)?; + let arg = self.parse_expr()?; + + Ok(FunctionArg::Named { name, arg }) + } else { + Ok(FunctionArg::Unnamed(self.parse_expr()?)) + } + } + + pub fn parse_optional_args(&mut self) -> Result, ParserError> { if self.consume_token(&Token::RParen) { Ok(vec![]) } else { - let args = self.parse_comma_separated(Parser::parse_expr)?; + let args = self.parse_comma_separated(Parser::parse_function_args)?; self.expect_token(&Token::RParen)?; Ok(args) } diff --git a/src/tokenizer.rs b/src/tokenizer.rs index 177402599..644066989 100644 --- a/src/tokenizer.rs +++ b/src/tokenizer.rs @@ -99,6 +99,8 @@ pub enum Token { LBrace, /// Right brace `}` RBrace, + /// Right Arrow `=>` + RArrow, } impl fmt::Display for Token { @@ -139,6 +141,7 @@ impl fmt::Display for Token { Token::Pipe => f.write_str("|"), Token::LBrace => f.write_str("{"), Token::RBrace => f.write_str("}"), + Token::RArrow => f.write_str("=>"), } } } @@ -400,7 +403,13 @@ impl<'a> Tokenizer<'a> { _ => Ok(Some(Token::Pipe)), } } - '=' => self.consume_and_return(chars, Token::Eq), + '=' => { + chars.next(); // consume + match chars.peek() { + Some('>') => self.consume_and_return(chars, Token::RArrow), + _ => Ok(Some(Token::Eq)), + } + } '.' => self.consume_and_return(chars, Token::Period), '!' => { chars.next(); // consume @@ -766,6 +775,23 @@ mod tests { compare(expected, tokens); } + #[test] + fn tokenize_right_arrow() { + let sql = String::from("FUNCTION(key=>value)"); + let dialect = GenericDialect {}; + let mut tokenizer = Tokenizer::new(&dialect, &sql); + let tokens = tokenizer.tokenize().unwrap(); + let expected = vec![ + Token::make_word("FUNCTION", None), + Token::LParen, + Token::make_word("key", None), + Token::RArrow, + Token::make_word("value", None), + Token::RParen, + ]; + compare(expected, tokens); + } + #[test] fn tokenize_is_null() { let sql = String::from("a IS NULL"); diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 898c39e47..5443c06e2 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -325,7 +325,7 @@ fn parse_select_count_wildcard() { assert_eq!( &Expr::Function(Function { name: ObjectName(vec![Ident::new("COUNT")]), - args: vec![Expr::Wildcard], + args: vec![FunctionArg::Unnamed(Expr::Wildcard)], over: None, distinct: false, }), @@ -340,10 +340,10 @@ fn parse_select_count_distinct() { assert_eq!( &Expr::Function(Function { name: ObjectName(vec![Ident::new("COUNT")]), - args: vec![Expr::UnaryOp { + args: vec![FunctionArg::Unnamed(Expr::UnaryOp { op: UnaryOperator::Plus, expr: Box::new(Expr::Identifier(Ident::new("x"))) - }], + })], over: None, distinct: true, }), @@ -883,7 +883,7 @@ fn parse_select_having() { Some(Expr::BinaryOp { left: Box::new(Expr::Function(Function { name: ObjectName(vec![Ident::new("COUNT")]), - args: vec![Expr::Wildcard], + args: vec![FunctionArg::Unnamed(Expr::Wildcard)], over: None, distinct: false })), @@ -1589,7 +1589,32 @@ fn parse_scalar_function_in_projection() { assert_eq!( &Expr::Function(Function { name: ObjectName(vec![Ident::new("sqrt")]), - args: vec![Expr::Identifier(Ident::new("id"))], + args: vec![FunctionArg::Unnamed(Expr::Identifier(Ident::new("id")))], + over: None, + distinct: false, + }), + expr_from_projection(only(&select.projection)) + ); +} + +#[test] +fn parse_named_argument_function() { + let sql = "SELECT FUN(a => '1', b => '2') FROM foo"; + let select = verified_only_select(sql); + + assert_eq!( + &Expr::Function(Function { + name: ObjectName(vec![Ident::new("FUN")]), + args: vec![ + FunctionArg::Named { + name: Ident::new("a"), + arg: Expr::Value(Value::SingleQuotedString("1".to_owned())) + }, + FunctionArg::Named { + name: Ident::new("b"), + arg: Expr::Value(Value::SingleQuotedString("2".to_owned())) + }, + ], over: None, distinct: false, }),