From 72ced4bffebbe81787500ea07e348a363994e948 Mon Sep 17 00:00:00 2001 From: Jamie Brandon Date: Wed, 22 May 2019 15:44:41 +0100 Subject: [PATCH] Support COUNT(DISTINCT x) and similar --- src/sqlast/mod.rs | 16 ++++++++++++++-- src/sqlparser.rs | 15 ++++++++++++++- tests/sqlparser_common.rs | 38 +++++++++++++++++++++++++++++++++++++- 3 files changed, 65 insertions(+), 4 deletions(-) diff --git a/src/sqlast/mod.rs b/src/sqlast/mod.rs index 678588ba0..842c3bd6d 100644 --- a/src/sqlast/mod.rs +++ b/src/sqlast/mod.rs @@ -112,6 +112,8 @@ pub enum ASTNode { name: SQLObjectName, args: Vec, over: Option, + // aggregate functions may specify eg `COUNT(DISTINCT x)` + distinct: bool, }, /// CASE [] WHEN THEN ... [ELSE ] END /// Note we only recognize a complete single expression as , not @@ -190,8 +192,18 @@ impl ToString for ASTNode { format!("{} {}", operator.to_string(), expr.as_ref().to_string()) } ASTNode::SQLValue(v) => v.to_string(), - ASTNode::SQLFunction { name, args, over } => { - let mut s = format!("{}({})", name.to_string(), comma_separated_string(args)); + ASTNode::SQLFunction { + name, + args, + over, + distinct, + } => { + let mut s = format!( + "{}({}{})", + name.to_string(), + if *distinct { "DISTINCT " } else { "" }, + comma_separated_string(args) + ); if let Some(o) = over { s += &format!(" OVER ({})", o.to_string()) } diff --git a/src/sqlparser.rs b/src/sqlparser.rs index edfa0c573..4283a8393 100644 --- a/src/sqlparser.rs +++ b/src/sqlparser.rs @@ -253,6 +253,14 @@ impl Parser { pub fn parse_function(&mut self, name: SQLObjectName) -> Result { self.expect_token(&Token::LParen)?; + let all = self.parse_keyword("ALL"); + let distinct = self.parse_keyword("DISTINCT"); + if all && distinct { + return parser_err!(format!( + "Cannot specify both ALL and DISTINCT in function: {}", + name.to_string(), + )); + } let args = self.parse_optional_args()?; let over = if self.parse_keyword("OVER") { // TBD: support window names (`OVER mywin`) in place of inline specification @@ -279,7 +287,12 @@ impl Parser { None }; - Ok(ASTNode::SQLFunction { name, args, over }) + Ok(ASTNode::SQLFunction { + name, + args, + over, + distinct, + }) } pub fn parse_window_frame(&mut self) -> Result, ParserError> { diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index e141c840c..9a66d6b0c 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -197,11 +197,44 @@ fn parse_select_count_wildcard() { name: SQLObjectName(vec!["COUNT".to_string()]), args: vec![ASTNode::SQLWildcard], over: None, + distinct: false, }, expr_from_projection(only(&select.projection)) ); } +#[test] +fn parse_select_count_distinct() { + let sql = "SELECT COUNT(DISTINCT + x) FROM customer"; + let select = verified_only_select(sql); + assert_eq!( + &ASTNode::SQLFunction { + name: SQLObjectName(vec!["COUNT".to_string()]), + args: vec![ASTNode::SQLUnary { + operator: SQLOperator::Plus, + expr: Box::new(ASTNode::SQLIdentifier("x".to_string())) + }], + over: None, + distinct: true, + }, + expr_from_projection(only(&select.projection)) + ); + + one_statement_parses_to( + "SELECT COUNT(ALL + x) FROM customer", + "SELECT COUNT(+ x) FROM customer", + ); + + let sql = "SELECT COUNT(ALL DISTINCT + x) FROM customer"; + let res = parse_sql_statements(sql); + assert_eq!( + ParserError::ParserError( + "Cannot specify both ALL and DISTINCT in function: COUNT".to_string() + ), + res.unwrap_err() + ); +} + #[test] fn parse_not() { let sql = "SELECT id FROM customer WHERE NOT salary = ''"; @@ -662,6 +695,7 @@ fn parse_scalar_function_in_projection() { name: SQLObjectName(vec!["sqrt".to_string()]), args: vec![ASTNode::SQLIdentifier("id".to_string())], over: None, + distinct: false, }, expr_from_projection(only(&select.projection)) ); @@ -690,7 +724,8 @@ fn parse_window_functions() { asc: Some(false) }], window_frame: None, - }) + }), + distinct: false, }, expr_from_projection(&select.projection[0]) ); @@ -762,6 +797,7 @@ fn parse_delimited_identifiers() { name: SQLObjectName(vec![r#""myfun""#.to_string()]), args: vec![], over: None, + distinct: false, }, expr_from_projection(&select.projection[1]), );