Skip to content

Commit 646479e

Browse files
authored
Merge pull request #77 from benesch/count-distinct
Support COUNT(DISTINCT x) and similar
2 parents 86a2fbd + 72ced4b commit 646479e

File tree

3 files changed

+65
-4
lines changed

3 files changed

+65
-4
lines changed

src/sqlast/mod.rs

+14-2
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,8 @@ pub enum ASTNode {
112112
name: SQLObjectName,
113113
args: Vec<ASTNode>,
114114
over: Option<SQLWindowSpec>,
115+
// aggregate functions may specify eg `COUNT(DISTINCT x)`
116+
distinct: bool,
115117
},
116118
/// CASE [<operand>] WHEN <condition> THEN <result> ... [ELSE <result>] END
117119
/// Note we only recognize a complete single expression as <condition>, not
@@ -190,8 +192,18 @@ impl ToString for ASTNode {
190192
format!("{} {}", operator.to_string(), expr.as_ref().to_string())
191193
}
192194
ASTNode::SQLValue(v) => v.to_string(),
193-
ASTNode::SQLFunction { name, args, over } => {
194-
let mut s = format!("{}({})", name.to_string(), comma_separated_string(args));
195+
ASTNode::SQLFunction {
196+
name,
197+
args,
198+
over,
199+
distinct,
200+
} => {
201+
let mut s = format!(
202+
"{}({}{})",
203+
name.to_string(),
204+
if *distinct { "DISTINCT " } else { "" },
205+
comma_separated_string(args)
206+
);
195207
if let Some(o) = over {
196208
s += &format!(" OVER ({})", o.to_string())
197209
}

src/sqlparser.rs

+14-1
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,14 @@ impl Parser {
270270

271271
pub fn parse_function(&mut self, name: SQLObjectName) -> Result<ASTNode, ParserError> {
272272
self.expect_token(&Token::LParen)?;
273+
let all = self.parse_keyword("ALL");
274+
let distinct = self.parse_keyword("DISTINCT");
275+
if all && distinct {
276+
return parser_err!(format!(
277+
"Cannot specify both ALL and DISTINCT in function: {}",
278+
name.to_string(),
279+
));
280+
}
273281
let args = self.parse_optional_args()?;
274282
let over = if self.parse_keyword("OVER") {
275283
// TBD: support window names (`OVER mywin`) in place of inline specification
@@ -296,7 +304,12 @@ impl Parser {
296304
None
297305
};
298306

299-
Ok(ASTNode::SQLFunction { name, args, over })
307+
Ok(ASTNode::SQLFunction {
308+
name,
309+
args,
310+
over,
311+
distinct,
312+
})
300313
}
301314

302315
pub fn parse_window_frame(&mut self) -> Result<Option<SQLWindowFrame>, ParserError> {

tests/sqlparser_common.rs

+37-1
Original file line numberDiff line numberDiff line change
@@ -211,11 +211,44 @@ fn parse_select_count_wildcard() {
211211
name: SQLObjectName(vec!["COUNT".to_string()]),
212212
args: vec![ASTNode::SQLWildcard],
213213
over: None,
214+
distinct: false,
214215
},
215216
expr_from_projection(only(&select.projection))
216217
);
217218
}
218219

220+
#[test]
221+
fn parse_select_count_distinct() {
222+
let sql = "SELECT COUNT(DISTINCT + x) FROM customer";
223+
let select = verified_only_select(sql);
224+
assert_eq!(
225+
&ASTNode::SQLFunction {
226+
name: SQLObjectName(vec!["COUNT".to_string()]),
227+
args: vec![ASTNode::SQLUnary {
228+
operator: SQLOperator::Plus,
229+
expr: Box::new(ASTNode::SQLIdentifier("x".to_string()))
230+
}],
231+
over: None,
232+
distinct: true,
233+
},
234+
expr_from_projection(only(&select.projection))
235+
);
236+
237+
one_statement_parses_to(
238+
"SELECT COUNT(ALL + x) FROM customer",
239+
"SELECT COUNT(+ x) FROM customer",
240+
);
241+
242+
let sql = "SELECT COUNT(ALL DISTINCT + x) FROM customer";
243+
let res = parse_sql_statements(sql);
244+
assert_eq!(
245+
ParserError::ParserError(
246+
"Cannot specify both ALL and DISTINCT in function: COUNT".to_string()
247+
),
248+
res.unwrap_err()
249+
);
250+
}
251+
219252
#[test]
220253
fn parse_not() {
221254
let sql = "SELECT id FROM customer WHERE NOT salary = ''";
@@ -676,6 +709,7 @@ fn parse_scalar_function_in_projection() {
676709
name: SQLObjectName(vec!["sqrt".to_string()]),
677710
args: vec![ASTNode::SQLIdentifier("id".to_string())],
678711
over: None,
712+
distinct: false,
679713
},
680714
expr_from_projection(only(&select.projection))
681715
);
@@ -704,7 +738,8 @@ fn parse_window_functions() {
704738
asc: Some(false)
705739
}],
706740
window_frame: None,
707-
})
741+
}),
742+
distinct: false,
708743
},
709744
expr_from_projection(&select.projection[0])
710745
);
@@ -776,6 +811,7 @@ fn parse_delimited_identifiers() {
776811
name: SQLObjectName(vec![r#""myfun""#.to_string()]),
777812
args: vec![],
778813
over: None,
814+
distinct: false,
779815
},
780816
expr_from_projection(&select.projection[1]),
781817
);

0 commit comments

Comments
 (0)