Skip to content
This repository was archived by the owner on Jul 3, 2023. It is now read-only.

Commit 3aa3e42

Browse files
committed
Add finalize step to the Aggregator trait.
A recent extension of the `Aggregator` trait requires aggregated values to form a semigroup, which is not the case, for instance, for the AVG aggregate. The standard solution is to use an intermediate representation of the aggregate that forms a semigroup and only compute the actual value of the aggregate after all partitions have been combined. To support such aggregates we extend the `Aggregator` trait with the `Accumulator` associated type returned by `aggregate`, which must implement `trait Semigroup`. The `Aggregator::Output` associated type represents the final output of the aggregator computed by applying the `finalize` method to the final value of the accumulator.
1 parent 728de5e commit 3aa3e42

File tree

9 files changed

+169
-88
lines changed

9 files changed

+169
-88
lines changed

benches/fraud.rs

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -195,8 +195,8 @@ struct Args {
195195
type Weight = i32;
196196

197197
type EnrichedTransactions = OrdIndexedZSet<(F64, i64), (Transaction, Demographics), Weight>;
198-
type AverageSpendingPerWeek = OrdPartitionedIndexedZSet<F64, i64, Option<Avg<F64, i32>>, Weight>;
199-
type AverageSpendingPerMonth = OrdPartitionedIndexedZSet<F64, i64, Option<Avg<F64, i32>>, Weight>;
198+
type AverageSpendingPerWeek = OrdPartitionedIndexedZSet<F64, i64, Option<F64>, Weight>;
199+
type AverageSpendingPerMonth = OrdPartitionedIndexedZSet<F64, i64, Option<F64>, Weight>;
200200
type TransactionFrequency = OrdPartitionedIndexedZSet<F64, i64, Option<i32>, Weight>;
201201

202202
struct FraudBenchmark {
@@ -233,16 +233,12 @@ impl FraudBenchmark {
233233
let avg_spend_pw: Stream<_, AverageSpendingPerWeek> = amounts
234234
.partitioned_rolling_aggregate_linear(
235235
|amt| Avg::new(*amt, 1),
236+
|avg| avg.compute_avg().unwrap(),
236237
RelRange::new(RelOffset::Before(DAY_IN_SECONDS * 7), RelOffset::Before(1)),
237238
);
238239

239-
// TODO: this should be returned directly by `partitioned_rolling_aggregate`
240-
let avg_spend_pw_indexed = avg_spend_pw.map_index(|(cc_num, (ts, avg_amt))| {
241-
(
242-
(*cc_num, *ts),
243-
avg_amt.as_ref().and_then(|avg| avg.compute_avg()),
244-
)
245-
});
240+
let avg_spend_pw_indexed =
241+
avg_spend_pw.map_index(|(cc_num, (ts, avg_amt))| ((*cc_num, *ts), *avg_amt));
246242

247243
// AVG(amt) OVER(
248244
// PARTITION BY CAST(cc_num AS NUMERIC)
@@ -252,15 +248,12 @@ impl FraudBenchmark {
252248
let avg_spend_pm: Stream<_, AverageSpendingPerMonth> = amounts
253249
.partitioned_rolling_aggregate_linear(
254250
|amt| Avg::new(*amt, 1),
251+
|avg| avg.compute_avg().unwrap(),
255252
RelRange::new(RelOffset::Before(DAY_IN_SECONDS * 30), RelOffset::Before(1)),
256253
);
257254

258-
let avg_spend_pm_indexed = avg_spend_pm.map_index(|(cc_num, (ts, avg_amt))| {
259-
(
260-
(*cc_num, *ts),
261-
avg_amt.as_ref().and_then(|avg| avg.compute_avg()),
262-
)
263-
});
255+
let avg_spend_pm_indexed =
256+
avg_spend_pm.map_index(|(cc_num, (ts, avg_amt))| ((*cc_num, *ts), *avg_amt));
264257

265258
// COUNT(*) OVER(
266259
// PARTITION BY CAST(cc_num AS NUMERIC)
@@ -270,6 +263,7 @@ impl FraudBenchmark {
270263
let trans_freq_24: Stream<_, TransactionFrequency> = amounts
271264
.partitioned_rolling_aggregate_linear(
272265
|_amt| 1,
266+
|cnt| cnt,
273267
RelRange::new(RelOffset::Before(DAY_IN_SECONDS), RelOffset::Before(1)),
274268
);
275269

src/operator/aggregate/fold.rs

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,16 +63,17 @@ impl<V, T, R, A, S, O, SF, OF> Aggregator<V, T, R> for Fold<A, S, SF, OF>
6363
where
6464
T: Timestamp,
6565
R: MonoidValue,
66-
A: Clone + 'static,
66+
A: DBData,
6767
SF: Fn(&mut A, &V, R) + Clone + 'static,
6868
OF: Fn(A) -> O + Clone + 'static,
69-
S: Semigroup<O> + Clone + 'static,
69+
S: Semigroup<A> + Clone + 'static,
7070
O: DBData,
7171
{
72+
type Accumulator = A;
7273
type Output = O;
7374
type Semigroup = S;
7475

75-
fn aggregate<'s, C>(&self, cursor: &mut C) -> Option<Self::Output>
76+
fn aggregate<'s, C>(&self, cursor: &mut C) -> Option<Self::Accumulator>
7677
where
7778
C: Cursor<'s, V, (), T, R>,
7879
{
@@ -91,8 +92,10 @@ where
9192
cursor.step_key();
9293
}
9394

94-
// Aggregator must return None iff the input cursor is empty (all keys have
95-
// weight 0).
96-
non_empty.then(|| (self.output)(acc))
95+
non_empty.then_some(acc)
96+
}
97+
98+
fn finalize(&self, acc: Self::Accumulator) -> Self::Output {
99+
(self.output)(acc)
97100
}
98101
}

src/operator/aggregate/max.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,12 @@ where
2828
T: Timestamp,
2929
R: MonoidValue,
3030
{
31+
type Accumulator = V;
3132
type Output = V;
3233
type Semigroup = MaxSemigroup<V>;
3334

3435
// TODO: this can be more efficient with reverse iterator.
35-
fn aggregate<'s, C>(&self, cursor: &mut C) -> Option<Self::Output>
36+
fn aggregate<'s, C>(&self, cursor: &mut C) -> Option<Self::Accumulator>
3637
where
3738
C: Cursor<'s, V, (), T, R>,
3839
{
@@ -52,4 +53,8 @@ where
5253

5354
result
5455
}
56+
57+
fn finalize(&self, accumulator: Self::Accumulator) -> Self::Output {
58+
accumulator
59+
}
5560
}

src/operator/aggregate/min.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ where
3030
T: Timestamp,
3131
R: MonoidValue,
3232
{
33+
type Accumulator = V;
3334
type Output = V;
3435
type Semigroup = MinSemigroup<V>;
3536

@@ -53,4 +54,8 @@ where
5354

5455
None
5556
}
57+
58+
fn finalize(&self, accumulator: Self::Accumulator) -> Self::Output {
59+
accumulator
60+
}
5661
}

src/operator/aggregate/mod.rs

Lines changed: 54 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,42 @@ pub use min::Min;
3939
/// A trait for aggregator objects. An aggregator summarizes the contents
4040
/// of a Z-set into a single value.
4141
///
42+
/// This trait supports two aggregation methods that can be combined in
43+
/// various ways to compute the aggregate efficiently. First, the
44+
/// [`aggregate`](`Self::aggregate`) method takes a cursor pointing to
45+
/// the first key of a Z-set and scans the cursor to compute the aggregate
46+
/// over all values in the cursor. Second, the
47+
/// [`Semigroup`](`Self::Semigroup`) associated type allows aggregating
48+
/// partitioned Z-sets by combining aggregates computed over individual
49+
/// partitions (e.g., computed by different worker threads).
50+
///
51+
/// This design requires aggregate values to form a semigroup with a `+`
52+
/// operation such that `Agg(x + y) = Agg(x) + Agg(y)`, i.e., aggregating
53+
/// a union of two Z-sets must produce the same result as the sum of
54+
/// individual aggregates. Not all aggregates have this property. E.g.,
55+
/// the average value of a Z-set cannot be computed by combining the
56+
/// averages of its subsets. We can get around this problem by representing
57+
/// average as `(sum, count)` tuple with point-wise `+` operator
58+
/// `(sum1, count1) + (sum2, count2) = (sum1+sum2, count1+count2)`.
59+
/// The final result is converted to an actual average value by dividing
60+
/// the first element of the tuple by the second. This is a general technique
61+
/// that works for all aggregators (although it may not always be optimal).
62+
///
63+
/// To support such aggregates, this trait distinguishes between the
64+
/// `Accumulator` type returned by [`aggregate`](`Self::aggregate`),
65+
/// which must implement
66+
/// [`trait Semigroup`](`crate::algebra::Semigroup`), and the final
67+
/// [`Output`](`Self::Output`) of the aggregator. The latter is
68+
/// computed by applying the [`finalize`](`Self::finalize`) method
69+
/// to the final value of the accumulator.
70+
///
4271
/// This is a low-level trait that is mostly used to build libraries of
4372
/// aggregators. Users will typically work with ready-made implementations
4473
/// like [`Min`] and [`Fold`].
4574
// TODO: Owned aggregation using `Consumer`
4675
pub trait Aggregator<K, T, R>: Clone + 'static {
47-
/// Aggregate type output by this aggregator.
48-
type Output: DBData;
76+
/// Accumulator type returned by [`aggregate`](`Self::aggregate`).
77+
type Accumulator: DBData;
4978

5079
/// Semigroup structure over aggregate values.
5180
///
@@ -61,10 +90,11 @@ pub trait Aggregator<K, T, R>: Clone + 'static {
6190
// per-worker aggregates computed over arbitrary subsets of values,
6291
// which additionally requires commutativity. Do we want to introduce
6392
// the `CommutativeSemigroup` trait?
64-
type Semigroup: Semigroup<Self::Output>;
93+
type Semigroup: Semigroup<Self::Accumulator>;
94+
95+
/// Aggregate type produced by this aggregator.
96+
type Output: DBData;
6597

66-
/// Compute an aggregate over a Z-set.
67-
///
6898
/// Takes a cursor pointing to the first key of a Z-set and outputs
6999
/// an aggregate of the Z-set.
70100
///
@@ -75,9 +105,20 @@ pub trait Aggregator<K, T, R>: Clone + 'static {
75105
///
76106
/// * The method must return `None` if the total weight of each key is zero.
77107
/// It must return `Some` otherwise.
78-
fn aggregate<'s, C>(&self, cursor: &mut C) -> Option<Self::Output>
108+
fn aggregate<'s, C>(&self, cursor: &mut C) -> Option<Self::Accumulator>
79109
where
80110
C: Cursor<'s, K, (), T, R>;
111+
112+
/// Compute the final value of the aggregate.
113+
fn finalize(&self, accumulator: Self::Accumulator) -> Self::Output;
114+
115+
/// Applies `aggregate` to `cursor` followed by `finalize`.
116+
fn aggregate_and_finalize<'s, C>(&self, cursor: &mut C) -> Option<Self::Output>
117+
where
118+
C: Cursor<'s, K, (), T, R>,
119+
{
120+
self.aggregate(cursor).map(|x| self.finalize(x))
121+
}
81122
}
82123

83124
/// Aggregator used internally by [`Stream::aggregate_linear`]. Computes
@@ -90,6 +131,7 @@ where
90131
T: Timestamp,
91132
R: DBWeight,
92133
{
134+
type Accumulator = R;
93135
type Output = R;
94136
type Semigroup = DefaultSemigroup<R>;
95137

@@ -107,6 +149,10 @@ where
107149
Some(weight)
108150
}
109151
}
152+
153+
fn finalize(&self, accumulator: Self::Accumulator) -> Self::Output {
154+
accumulator
155+
}
110156
}
111157

112158
impl<P, Z> Stream<Circuit<P>, Z>
@@ -329,7 +375,7 @@ where
329375
while cursor.key_valid() {
330376
if let Some(agg) = self
331377
.aggregator
332-
.aggregate(&mut CursorGroup::new(&mut cursor, ()))
378+
.aggregate_and_finalize(&mut CursorGroup::new(&mut cursor, ()))
333379
{
334380
builder.push((O::item_from(cursor.key().clone(), agg), O::R::one()));
335381
}
@@ -461,7 +507,7 @@ where
461507
// Z-set associated with `input_cursor.key()` at time `self.time`.
462508
if let Some(aggregate) = self
463509
.aggregator
464-
.aggregate(&mut CursorGroup::new(input_cursor, self.time.clone()))
510+
.aggregate_and_finalize(&mut CursorGroup::new(input_cursor, self.time.clone()))
465511
{
466512
output.push((key.clone(), Some(aggregate)));
467513
} else {

src/operator/time_series/radix_tree/partitioned_tree_aggregate.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -115,15 +115,15 @@ where
115115
pub fn partitioned_tree_aggregate<TS, V, Agg>(
116116
&self,
117117
aggregator: Agg,
118-
) -> OrdPartitionedRadixTreeStream<Z::Key, TS, Agg::Output, isize>
118+
) -> OrdPartitionedRadixTreeStream<Z::Key, TS, Agg::Accumulator, isize>
119119
where
120120
Z: PartitionedIndexedZSet<TS, V> + SizeOf,
121121
TS: DBData + PrimInt,
122122
V: DBData,
123123
Agg: Aggregator<V, (), Z::R>,
124-
Agg::Output: Default,
124+
Agg::Accumulator: Default,
125125
{
126-
self.partitioned_tree_aggregate_generic::<TS, V, Agg, OrdPartitionedRadixTree<Z::Key, TS, Agg::Output, isize>>(
126+
self.partitioned_tree_aggregate_generic::<TS, V, Agg, OrdPartitionedRadixTree<Z::Key, TS, Agg::Accumulator, isize>>(
127127
aggregator,
128128
)
129129
}
@@ -139,8 +139,8 @@ where
139139
TS: DBData + PrimInt,
140140
V: DBData,
141141
Agg: Aggregator<V, (), Z::R>,
142-
Agg::Output: Default,
143-
O: PartitionedRadixTreeBatch<TS, Agg::Output, Key = Z::Key>,
142+
Agg::Accumulator: Default,
143+
O: PartitionedRadixTreeBatch<TS, Agg::Accumulator, Key = Z::Key>,
144144
O::R: ZRingValue,
145145
{
146146
self.circuit()
@@ -345,10 +345,10 @@ where
345345
TS: DBData + PrimInt,
346346
V: DBData,
347347
IT: PartitionedBatchReader<TS, V, Key = Z::Key, R = Z::R> + Clone,
348-
OT: PartitionedRadixTreeReader<TS, Agg::Output, Key = Z::Key, R = O::R> + Clone,
348+
OT: PartitionedRadixTreeReader<TS, Agg::Accumulator, Key = Z::Key, R = O::R> + Clone,
349349
Agg: Aggregator<V, (), Z::R>,
350-
Agg::Output: Default,
351-
O: PartitionedRadixTreeBatch<TS, Agg::Output, Key = Z::Key>,
350+
Agg::Accumulator: Default,
351+
O: PartitionedRadixTreeBatch<TS, Agg::Accumulator, Key = Z::Key>,
352352
O::R: ZRingValue,
353353
{
354354
fn eval<'a>(

src/operator/time_series/radix_tree/tree_aggregate.rs

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -55,14 +55,16 @@ where
5555
pub fn tree_aggregate<Agg>(
5656
&self,
5757
aggregator: Agg,
58-
) -> Stream<Circuit<P>, OrdRadixTree<Z::Key, Agg::Output, isize>>
58+
) -> Stream<Circuit<P>, OrdRadixTree<Z::Key, Agg::Accumulator, isize>>
5959
where
6060
Z: IndexedZSet + SizeOf + NumEntries + Send,
6161
Z::Key: PrimInt,
6262
Agg: Aggregator<Z::Val, (), Z::R>,
63-
Agg::Output: Default,
63+
Agg::Accumulator: Default,
6464
{
65-
self.tree_aggregate_generic::<Agg, OrdRadixTree<Z::Key, Agg::Output, isize>>(aggregator)
65+
self.tree_aggregate_generic::<Agg, OrdRadixTree<Z::Key, Agg::Accumulator, isize>>(
66+
aggregator,
67+
)
6668
}
6769

6870
/// Like [`Self::tree_aggregate`], but can return any batch type.
@@ -71,8 +73,8 @@ where
7173
Z: IndexedZSet + SizeOf + NumEntries + Send,
7274
Z::Key: PrimInt,
7375
Agg: Aggregator<Z::Val, (), Z::R>,
74-
Agg::Output: Default,
75-
O: RadixTreeBatch<Z::Key, Agg::Output>,
76+
Agg::Accumulator: Default,
77+
O: RadixTreeBatch<Z::Key, Agg::Accumulator>,
7678
O::R: ZRingValue,
7779
{
7880
self.circuit()
@@ -186,10 +188,10 @@ where
186188
Z: IndexedZSet,
187189
Z::Key: PrimInt,
188190
IT: BatchReader<Key = Z::Key, Val = Z::Val, Time = (), R = Z::R> + Clone,
189-
OT: RadixTreeReader<Z::Key, Agg::Output, R = O::R> + Clone,
191+
OT: RadixTreeReader<Z::Key, Agg::Accumulator, R = O::R> + Clone,
190192
Agg: Aggregator<Z::Val, (), Z::R>,
191-
Agg::Output: Default,
192-
O: RadixTreeBatch<Z::Key, Agg::Output>,
193+
Agg::Accumulator: Default,
194+
O: RadixTreeBatch<Z::Key, Agg::Accumulator>,
193195
O::R: ZRingValue,
194196
{
195197
fn eval<'a>(

src/operator/time_series/radix_tree/updater.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -369,19 +369,19 @@ pub(super) fn radix_tree_update<'a, 'b, TS, V, R, Agg, UC, IC, TC, OR>(
369369
mut input: IC,
370370
tree: TC,
371371
aggregator: &Agg,
372-
output_updates: &'a mut Vec<TreeNodeUpdate<TS, Agg::Output>>,
372+
output_updates: &'a mut Vec<TreeNodeUpdate<TS, Agg::Accumulator>>,
373373
) where
374374
TS: PrimInt + Debug,
375375
R: MonoidValue,
376376
Agg: Aggregator<V, (), R>,
377-
Agg::Output: Clone + Default + Eq + Debug,
377+
Agg::Accumulator: Clone + Default + Eq + Debug,
378378
UC: Cursor<'b, TS, V, (), R>,
379379
IC: Cursor<'b, TS, V, (), R>,
380-
TC: RadixTreeCursor<'b, TS, Agg::Output, OR>,
380+
TC: RadixTreeCursor<'b, TS, Agg::Accumulator, OR>,
381381
OR: MonoidValue,
382382
{
383383
let mut tree_updater =
384-
<TreeUpdater<'a, TS, Agg::Output, OR, Agg::Semigroup, TC>>::new(tree, output_updates);
384+
<TreeUpdater<'a, TS, Agg::Accumulator, OR, Agg::Semigroup, TC>>::new(tree, output_updates);
385385

386386
while input_delta.key_valid() {
387387
//println!("affected key {:x?}", input_delta.key());

0 commit comments

Comments
 (0)