Skip to content

Commit 0d21ebc

Browse files
committed
aggregations: Support postgres array_agg()
Add essential support for postgres's `array_agg()` aggregation function. Includes an update to logictests that let us convert incoming pg TextArray objects as a single string. This is sufficient for our needs in logictests. Fixes: REA-5905 Release-Note-Core: Add base support for postgres accumulating aggregation function `array_agg()`. Change-Id: I068c0fb0fec591d73a176437044572a2931ca5a2 Reviewed-on: https://gerrit.readyset.name/c/readyset/+/10463 Reviewed-by: Michael Zink <michael.z@readyset.io> Tested-by: Buildkite CI
1 parent fe07a6a commit 0d21ebc

File tree

15 files changed

+226
-21
lines changed

15 files changed

+226
-21
lines changed

dataflow-expression/src/grouped/accumulator.rs

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
//! This module implements SQL aggregation functions that accumulate multiple input values
2-
//! into a single output value. Currently only supports mysql's `GROUP_CONCAT`, postgres' `STRING_AGG`,
3-
//! and the various JSON object aggregation functions.
2+
//! into a single output value. Currently only supports mysql's `GROUP_CONCAT`, postgres' `STRING_AGG`
3+
//! and `ARRAY_AGG`, and the various JSON object aggregation functions.
44
use crate::eval::json;
55
use readyset_data::DfValue;
66
use readyset_errors::{internal_err, ReadySetResult};
@@ -9,6 +9,8 @@ use serde::{Deserialize, Serialize};
99
/// Supported accumulation operators.
1010
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
1111
pub enum AccumulationOp {
12+
/// Concatenate values into an array. Allows NULL values in the output array.
13+
ArrayAgg,
1214
/// Concatenates using the given separator between values. The string result is with the concatenated non-NULL
1315
/// values from a group. It returns NULL if there are no non-NULL values.
1416
GroupConcat { separator: String },
@@ -21,6 +23,25 @@ pub enum AccumulationOp {
2123
}
2224

2325
impl AccumulationOp {
26+
pub fn ignore_nulls(&self) -> bool {
27+
match self {
28+
Self::ArrayAgg { .. } => false,
29+
Self::GroupConcat { .. } | Self::JsonObjectAgg { .. } | Self::StringAgg { .. } => true,
30+
}
31+
}
32+
33+
fn apply_array_agg(&self, data: &[DfValue]) -> ReadySetResult<DfValue> {
34+
if data.is_empty() {
35+
Ok(DfValue::Array(std::sync::Arc::new(
36+
readyset_data::Array::from(vec![]),
37+
)))
38+
} else {
39+
Ok(DfValue::Array(std::sync::Arc::new(
40+
readyset_data::Array::from(data.to_vec()),
41+
)))
42+
}
43+
}
44+
2445
fn apply_group_concat(&self, data: &[DfValue], separator: &str) -> ReadySetResult<DfValue> {
2546
// return SQL NULL if no non-NULL values in the `data`. we won't have NULL values as we've
2647
// filtered those out already.
@@ -70,6 +91,7 @@ impl AccumulationOp {
7091

7192
pub fn apply(&self, data: &[DfValue]) -> ReadySetResult<DfValue> {
7293
match self {
94+
AccumulationOp::ArrayAgg => self.apply_array_agg(data),
7395
AccumulationOp::GroupConcat { separator } => self.apply_group_concat(data, separator),
7496
AccumulationOp::JsonObjectAgg {
7597
allow_duplicate_keys,

dataflow-expression/src/reader_processing.rs

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use serde::{Deserialize, Serialize};
1313
// TODO(aspen): It would be really nice to deduplicate this somehow with the grouped operator itself
1414
#[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)]
1515
pub enum PostLookupAggregateFunction {
16+
ArrayAgg,
1617
/// Add together all the input numbers
1718
///
1819
/// Note that this encapsulates both `SUM` *and* `COUNT` in base SQL, as re-aggregating counts
@@ -21,15 +22,21 @@ pub enum PostLookupAggregateFunction {
2122
/// Multiply together all the input numbers
2223
Product,
2324
/// Concatenate together all the input strings with the given separator
24-
GroupConcat { separator: String },
25+
GroupConcat {
26+
separator: String,
27+
},
2528
/// Take the maximum input value
2629
Max,
2730
/// Take the minimum input value
2831
Min,
2932
/// Use specified Key-value pair to build a JSON Object
30-
JsonObjectAgg { allow_duplicate_keys: bool },
33+
JsonObjectAgg {
34+
allow_duplicate_keys: bool,
35+
},
3136
/// Concatenate together all the input strings with the given separator
32-
StringAgg { separator: String },
37+
StringAgg {
38+
separator: String,
39+
},
3340
}
3441

3542
impl PostLookupAggregateFunction {
@@ -38,6 +45,15 @@ impl PostLookupAggregateFunction {
3845
/// This forms a semigroup.
3946
pub fn apply(&self, val1: &DfValue, val2: &DfValue) -> ReadySetResult<DfValue> {
4047
match self {
48+
PostLookupAggregateFunction::ArrayAgg => match (val1, val2) {
49+
(DfValue::Array(a1), DfValue::Array(a2)) => {
50+
let result: Vec<_> = a1.values().chain(a2.values()).cloned().collect();
51+
Ok(DfValue::Array(std::sync::Arc::new(
52+
readyset_data::Array::from(result),
53+
)))
54+
}
55+
_ => internal!("trying to using `array_agg()` for non-array types"),
56+
},
4157
PostLookupAggregateFunction::Sum => val1 + val2,
4258
PostLookupAggregateFunction::Product => val1 * val2,
4359
PostLookupAggregateFunction::GroupConcat { separator } => Ok(format!(
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
statement ok
2+
create table dogs (id int, name text);
3+
4+
statement ok
5+
insert into dogs values
6+
(1, 'kidnap'),
7+
(2, 'snoopy');
8+
9+
statement ok
10+
create table tags (d_id int, tag text);
11+
12+
statement ok
13+
insert into tags values
14+
(1, 'bbb'),
15+
(1, 'aaa'),
16+
(1, 'ccc'),
17+
(2, 'yyy'),
18+
(2, 'zzz');
19+
20+
21+
# simple query (no NULLs in data)
22+
query T
23+
select array_agg(t.tag) from tags t where d_id = 1;
24+
----
25+
{bbb,aaa,ccc}
26+
27+
statement ok
28+
insert into tags values (1, 'bbb');
29+
30+
# simple query (no NULLs in data)
31+
query T
32+
select array_agg(t.tag) from tags t where d_id = 1;
33+
----
34+
{bbb,aaa,ccc,bbb}
35+
36+
37+
# simple query (with NULLs in data)
38+
# Note: array_agg() _does_ include nulls in the output
39+
statement ok
40+
insert into tags values
41+
(1, NULL),
42+
(1, 'aaa');
43+
44+
query T
45+
select array_agg(t.tag) from tags t where d_id = 1;
46+
----
47+
{bbb,aaa,ccc,bbb,NULL,aaa}
48+
49+
50+
# multiple aggregations
51+
query IT
52+
select sum(d_id), array_agg(t.tag) from tags t where d_id = 1;
53+
----
54+
6
55+
{bbb,aaa,ccc,bbb,NULL,aaa}
56+
57+
58+
# join and group by clause
59+
query ITT
60+
select d.id, d.name, array_agg(t.tag) as tag_arr
61+
from dogs as d
62+
left join tags as t on d.id = t.d_id
63+
where d.id = 1
64+
group by d.id, d.name;
65+
----
66+
1
67+
kidnap
68+
{bbb,aaa,ccc,bbb,NULL,aaa}
69+
70+
71+
#######
72+
# test only empty/NULL inputs
73+
74+
# empty results
75+
query
76+
select array_agg(tag) from tags where d_id = 3;
77+
----
78+
NULL
79+
80+
statement ok
81+
insert into tags values
82+
(3, NULL);
83+
84+
# only NULL rows
85+
query
86+
select array_agg(tag) from tags where d_id = 3;
87+
----
88+
{NULL}
89+
90+
91+
#####
92+
# post-lookup aggregations
93+
94+
query T
95+
select array_agg(t.tag) from tags t where d_id in (1, 2);
96+
----
97+
{bbb,aaa,ccc,bbb,NULL,aaa,yyy,zzz}

query-generator/src/lib.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1129,6 +1129,8 @@ impl Arbitrary for WindowFunctionConfig {
11291129
#[derive(Debug, Eq, PartialEq, Clone, Serialize, Deserialize, Arbitrary)]
11301130
#[arbitrary(args = QueryDialect)]
11311131
pub enum AggregateType {
1132+
#[weight(u32::from(*args_shared == ParseDialect::PostgreSQL))]
1133+
ArrayAgg,
11321134
Count {
11331135
#[any(generate_arrays = false, dialect = Some(args_shared.0))]
11341136
column_type: SqlType,
@@ -1164,6 +1166,7 @@ pub enum AggregateType {
11641166
impl AggregateType {
11651167
pub fn column_type(&self) -> SqlType {
11661168
match self {
1169+
AggregateType::ArrayAgg => SqlType::Text,
11671170
AggregateType::Avg { column_type, .. } => column_type.clone(),
11681171
AggregateType::Count { column_type, .. } => column_type.clone(),
11691172
AggregateType::GroupConcat => SqlType::Text,
@@ -1180,7 +1183,8 @@ impl AggregateType {
11801183
AggregateType::Avg { distinct, .. } => *distinct,
11811184
AggregateType::Count { distinct, .. } => *distinct,
11821185
AggregateType::Sum { distinct, .. } => *distinct,
1183-
AggregateType::GroupConcat
1186+
AggregateType::ArrayAgg
1187+
| AggregateType::GroupConcat
11841188
| AggregateType::JsonObjectAgg { .. }
11851189
| AggregateType::Max { .. }
11861190
| AggregateType::Min { .. }
@@ -1646,6 +1650,7 @@ const ALL_PAGINATE: &[QueryOperation] = &[
16461650
];
16471651

16481652
const ALL_AGGREGATE_TYPES: &[AggregateType] = &[
1653+
AggregateType::ArrayAgg,
16491654
AggregateType::Count {
16501655
column_type: SqlType::Int(None),
16511656
distinct: true,
@@ -1910,6 +1915,7 @@ impl QueryOperation {
19101915
}));
19111916

19121917
let func = match *agg {
1918+
ArrayAgg => FunctionExpr::ArrayAgg { expr },
19131919
Avg { distinct, .. } => FunctionExpr::Avg { expr, distinct },
19141920
Count { distinct, .. } => FunctionExpr::Count { expr, distinct },
19151921
GroupConcat => FunctionExpr::GroupConcat {
@@ -2667,6 +2673,7 @@ impl FromStr for Operations {
26672673
.into()),
26682674
"topk" => Ok(ALL_TOPK.to_vec().into()),
26692675
"paginate" => Ok(ALL_PAGINATE.to_vec().into()),
2676+
"array_agg" => Ok(vec![ColumnAggregate(AggregateType::ArrayAgg)].into()),
26702677
"string_agg" => Ok(vec![ColumnAggregate(AggregateType::StringAgg)].into()),
26712678
s => Err(anyhow!("unknown query operation: {}", s)),
26722679
}
@@ -3315,6 +3322,7 @@ mod tests {
33153322
res,
33163323
vec![
33173324
Operations(vec![
3325+
QueryOperation::ColumnAggregate(AggregateType::ArrayAgg),
33183326
QueryOperation::ColumnAggregate(AggregateType::Count {
33193327
column_type: SqlType::Int(None),
33203328
distinct: true,

readyset-dataflow/src/ops/grouped/accumulator.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@ impl Accumulator {
4646
let collation = over_col_ty.collation().unwrap_or(Collation::Utf8);
4747

4848
let out_ty = match &op {
49+
AccumulationOp::ArrayAgg => {
50+
antithesis_sdk::assert_reachable!("Accumulation::ArrayAgg");
51+
DfType::Array(Box::new(over_col_ty.clone()))
52+
}
4953
AccumulationOp::GroupConcat { .. } => {
5054
antithesis_sdk::assert_reachable!("Accumulation::GroupConcat");
5155
DfType::Text(collation)
@@ -147,7 +151,7 @@ impl GroupedOperation for Accumulator {
147151
group_hash,
148152
} in diffs
149153
{
150-
if value.is_none() {
154+
if self.op.ignore_nulls() && value.is_none() {
151155
continue;
152156
}
153157

@@ -182,6 +186,9 @@ impl GroupedOperation for Accumulator {
182186

183187
fn description(&self) -> String {
184188
let op_string = match &self.op {
189+
AccumulationOp::ArrayAgg => {
190+
format!("ArrayAgg({})", self.over)
191+
}
185192
AccumulationOp::GroupConcat { separator } => {
186193
format!("GroupConcat({}, {:?})", self.over, separator)
187194
}

readyset-logictest/src/ast.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,18 @@ impl<'a> pgsql::types::FromSql<'a> for Value {
379379
Type::FLOAT8 => Ok(Self::from(f64::from_sql(ty, raw)?)),
380380
Type::NUMERIC => Ok(Self::Numeric(Decimal::from_sql(ty, raw)?)),
381381
Type::TEXT => Ok(Self::Text(String::from_sql(ty, raw)?)),
382+
Type::TEXT_ARRAY => {
383+
// convert a text array into something like "{string1,string2,NULL,string3}"
384+
// we need to handle NULLs in some sane way, and this mimics what
385+
// pg's array_agg() function does.
386+
let string_array = Vec::<Option<String>>::from_sql(ty, raw)?;
387+
let joined = string_array
388+
.iter()
389+
.map(|opt| opt.as_deref().unwrap_or("NULL"))
390+
.collect::<Vec<_>>()
391+
.join(",");
392+
Ok(Self::Text(format!("{{{}}}", joined)))
393+
}
382394
Type::DATE => {
383395
// This is a hack to work around the fact that we don't have
384396
// a distinct 'Date' type, and that the existing 'Date' is
@@ -422,6 +434,7 @@ impl<'a> pgsql::types::FromSql<'a> for Value {
422434
| Type::FLOAT8
423435
| Type::NUMERIC
424436
| Type::TEXT
437+
| Type::TEXT_ARRAY
425438
| Type::DATE
426439
| Type::TIMESTAMP
427440
| Type::TIMESTAMPTZ

readyset-mir/src/node/node_inner.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,7 @@ impl MirNodeInner {
447447
..
448448
} => {
449449
let op_string = match kind {
450+
AccumulationOp::ArrayAgg => format!("ArrayAgg({})", on.name.as_str()),
450451
AccumulationOp::GroupConcat { separator } => {
451452
format!(
452453
"GroupConcat([{}], \"{}\")",

readyset-mir/src/visualize.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ impl GraphViz for MirNodeInner {
159159
..
160160
} => {
161161
let op_string = match kind {
162+
AccumulationKind::ArrayAgg => format!("ArrayAgg({on})"),
162163
AccumulationKind::GroupConcat { separator: s } => {
163164
format!("GroupConcat({on}, \\\"{s}\\\")")
164165
}
@@ -301,6 +302,7 @@ impl GraphViz for MirNodeInner {
301302
.map(|aggregate| format!(
302303
"{}({})",
303304
match aggregate.function {
305+
PostLookupAggregateFunction::ArrayAgg => "ArrayAgg",
304306
PostLookupAggregateFunction::Sum => "Σ",
305307
PostLookupAggregateFunction::Product => "Π",
306308
PostLookupAggregateFunction::GroupConcat { .. } => "GC",

readyset-server/src/controller/sql/mir/grouped.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,10 @@ pub(super) fn post_lookup_aggregates(
340340
aggregates.push(PostLookupAggregate {
341341
column: Column::named(alias.clone()).aliased_as_table(query_name.clone()),
342342
function: match function {
343+
ArrayAgg { .. } => {
344+
antithesis_sdk::assert_reachable!("PostLookupAggregateFunction::ArrayAgg");
345+
PostLookupAggregateFunction::ArrayAgg
346+
}
343347
Avg { .. } => {
344348
unsupported!("Average is not supported as a post-lookup aggregate")
345349
}

readyset-server/src/controller/sql/mir/mod.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -882,6 +882,11 @@ impl SqlToMirConverter {
882882
}
883883

884884
Ok(match function {
885+
ArrayAgg { ref expr } => mknode(
886+
Column::from(get_column(expr)),
887+
GroupedNodeType::Accumulation(AccumulationOp::ArrayAgg),
888+
false,
889+
),
885890
Sum { ref expr, distinct } if is_column(expr) => mknode(
886891
Column::from(get_column(expr)),
887892
GroupedNodeType::Aggregation(Aggregation::Sum),

0 commit comments

Comments
 (0)