Skip to content

Commit a8d1fbb

Browse files
authored
Merge pull request #18 from NERC-CEH/feature/FPM-503-add_mean_sum_agg
Add MeanSum and Sum aggregations
2 parents 00c1ef8 + ad7763c commit a8d1fbb

File tree

2 files changed

+86
-5
lines changed

2 files changed

+86
-5
lines changed

src/time_stream/aggregation.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@ def expr(self, columns: List[str]) -> List[pl.Expr]:
4646
"""Return the Polars expressions for this aggregation."""
4747
pass
4848

49+
def post_expr(self, columns: List[str]) -> List[pl.Expr]:
    """Return additional Polars expressions to be applied after the aggregation.

    The base implementation contributes nothing. A subclass can override this
    hook when its aggregated value needs a follow-up calculation; ``apply``
    extends its list of ``with_columns`` expressions with whatever is returned.

    Args:
        columns: Names of the columns being aggregated.

    Returns:
        An empty list.
    """
    no_post_processing: List[pl.Expr] = []
    return no_post_processing
52+
4953
@classmethod
5054
def get(cls, aggregation: Union[str, Type["AggregationFunction"], "AggregationFunction"]) -> "AggregationFunction":
5155
"""Factory method to get an aggregation function instance from string names, class types, or existing instances.
@@ -158,6 +162,7 @@ def apply(
158162
with_columns_expressions = []
159163
with_columns_expressions.append(self._expected_count_expr(self.ts, aggregation_period))
160164
with_columns_expressions.extend(self._missing_data_expr(self.ts, columns, missing_criteria))
165+
with_columns_expressions.extend(self.post_expr(columns))
161166

162167
# Do the with_column methods
163168
for with_column in with_columns_expressions:
@@ -313,6 +318,34 @@ def expr(self, columns: List[str]) -> List[pl.Expr]:
313318
return [pl.col(col).mean().alias(f"mean_{col}") for col in columns]
314319

315320

321+
@register_aggregation
class Sum(AggregationFunction):
    """An aggregation class to calculate the sum (total) of values within each aggregation period."""

    # Registry key used to look this aggregation up by string name.
    name = "sum"

    def expr(self, columns: List[str]) -> List[pl.Expr]:
        """Return the `Polars` expressions for calculating the sum in an aggregation period.

        Args:
            columns: Names of the columns to aggregate.

        Returns:
            One expression per column, each aliased as ``sum_<column>``.
        """
        return [pl.col(col).sum().alias(f"sum_{col}") for col in columns]
330+
331+
332+
@register_aggregation
class MeanSum(AggregationFunction):
    """An aggregation class to calculate the mean sum (averaged total) of values within each aggregation period.

    This will estimate the sum when values are missing, according to how many values are expected in the period.
    """

    # Registry key used to look this aggregation up by string name.
    name = "mean_sum"

    def expr(self, columns: List[str]) -> List[pl.Expr]:
        """Return the `Polars` expressions for the mean of each column in an aggregation period.

        The mean sum is the mean multiplied by the expected count. Only the mean is
        calculated here; the multiplication is done afterwards by ``post_expr``.

        Args:
            columns: Names of the columns to aggregate.

        Returns:
            One expression per column, each aliased as ``mean_sum_<column>``.
        """
        return [pl.col(col).mean().alias(f"mean_sum_{col}") for col in columns]

    def post_expr(self, columns: List[str]) -> List[pl.Expr]:
        """Multiply the mean by the expected count to get the mean sum.

        Args:
            columns: Names of the columns that were aggregated by ``expr``.

        Returns:
            One expression per column that replaces ``mean_sum_<column>`` with its
            value multiplied by the ``expected_count_<time_name>`` column.
        """
        # The expected-count column is the same for every value column, so build it once.
        expected_count = pl.col(f"expected_count_{self.ts.time_name}")
        # Alias explicitly so overwriting `mean_sum_<col>` does not depend on Polars'
        # implicit "left-most input column" output naming.
        return [
            (pl.col(f"mean_sum_{col}") * expected_count).alias(f"mean_sum_{col}")
            for col in columns
        ]
347+
348+
316349
@register_aggregation
317350
class Min(AggregationFunction):
318351
"""An aggregation class to find the minimum of values within each aggregation period."""

tests/time_stream/test_aggregation.py

Lines changed: 53 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from time_stream.period import Period
1111
from time_stream.base import TimeSeries
12-
from time_stream.aggregation import AggregationFunction, Max, Mean, Min
12+
from time_stream.aggregation import AggregationFunction, Max, Mean, Min, MeanSum, Sum
1313

1414

1515
def generate_time_series(resolution: Period, periodicity: Period, length: int, missing_data: bool=False) -> TimeSeries:
@@ -129,31 +129,31 @@ def setUp(self):
129129
self.mock_ts.time_name = "timestamp"
130130

131131
@parameterized.expand([
132-
("mean", Mean), ("min", Min), ("max", Max)
132+
("mean", Mean), ("min", Min), ("max", Max), ("mean_sum", MeanSum), ("sum", Sum)
133133
])
134134
def test_get_with_string(self, get_input, expected):
135135
"""Test AggregationFunction.get() with string input."""
136136
agg = AggregationFunction.get(get_input)
137137
self.assertIsInstance(agg, expected)
138138

139139
@parameterized.expand([
140-
(Mean, Mean), (Min, Min), (Max, Max)
140+
(Mean, Mean), (Min, Min), (Max, Max), (MeanSum, MeanSum), (Sum, Sum)
141141
])
142142
def test_get_with_class(self, get_input, expected):
143143
"""Test AggregationFunction.get() with class input."""
144144
agg = AggregationFunction.get(get_input)
145145
self.assertIsInstance(agg, expected)
146146

147147
@parameterized.expand([
148-
(Mean(), Mean), (Min(), Min), (Max(), Max)
148+
(Mean(), Mean), (Min(), Min), (Max(), Max), (MeanSum(), MeanSum), (Sum(), Sum)
149149
])
150150
def test_get_with_instance(self, get_input, expected):
151151
"""Test AggregationFunction.get() with instance input."""
152152
agg = AggregationFunction.get(get_input)
153153
self.assertIsInstance(agg, expected)
154154

155155
@parameterized.expand([
156-
"Mean", "MIN", "mAx", "123"
156+
"Mean", "MIN", "mAx", "123", "meansum", "sUm"
157157
])
158158
def test_get_with_invalid_string(self, get_input):
159159
"""Test AggregationFunction.get() with invalid string."""
@@ -232,7 +232,14 @@ class TestSimpleAggregations(unittest.TestCase):
232232
233233
("hourly_to_daily_min", ts_PT1H_2days, Min, P1D, "value", [datetime(2025, 1, 1), datetime(2025, 1, 2)],
234234
[24, 24], {"value": [0, 24]}, [datetime(2025, 1, 1), datetime(2025, 1, 2)]),
235+
236+
("hourly_to_daily_mean_sum", ts_PT1H_2days, MeanSum, P1D, "value", [datetime(2025, 1, 1), datetime(2025, 1, 2)],
237+
[24, 24], {"value": [276, 852]}, None),
238+
239+
("hourly_to_daily_sum", ts_PT1H_2days, Sum, P1D, "value", [datetime(2025, 1, 1), datetime(2025, 1, 2)],
240+
[24, 24], {"value": [276, 852]}, None),
235241
])
242+
236243
def test_microsecond_to_microsecond(
237244
self, _, input_ts, aggregator, target_period, column, timestamps, counts, values, timestamps_of
238245
):
@@ -251,6 +258,12 @@ def test_microsecond_to_microsecond(
251258
252259
("hourly_to_monthly_min", ts_PT1H_2month, Min, P1M, "value", [datetime(2025, 1, 1), datetime(2025, 2, 1)],
253260
[744, 672], {"value": [0, 744]}, [datetime(2025, 1, 1), datetime(2025, 2, 1)]),
261+
262+
("hourly_to_monthly_mean_sum", ts_PT1H_2month, MeanSum, P1M, "value", [datetime(2025, 1, 1), datetime(2025, 2, 1)],
263+
[744, 672], {"value": [276396, 725424]}, None),
264+
265+
("hourly_to_monthly_sum", ts_PT1H_2month, Sum, P1M, "value", [datetime(2025, 1, 1), datetime(2025, 2, 1)],
266+
[744, 672], {"value": [276396, 725424]}, None),
254267
])
255268
def test_microsecond_to_month(
256269
self, _, input_ts, aggregator, target_period, column, timestamps, counts, values, timestamps_of
@@ -269,6 +282,12 @@ def test_microsecond_to_month(
269282
270283
("monthly_to_yearly_min", ts_P1M_2years, Min, P1Y, "value", [datetime(2025, 1, 1), datetime(2026, 1, 1)],
271284
[12, 12], {"value": [0, 12]}, [datetime(2025, 1, 1), datetime(2026, 1, 1)]),
285+
286+
("monthly_to_yearly_mean_sum", ts_P1M_2years, MeanSum, P1Y, "value", [datetime(2025, 1, 1), datetime(2026, 1, 1)],
287+
[12, 12], {"value": [66, 210]}, None),
288+
289+
("monthly_to_yearly_sum", ts_P1M_2years, Sum, P1Y, "value", [datetime(2025, 1, 1), datetime(2026, 1, 1)],
290+
[12, 12], {"value": [66, 210]}, None),
272291
])
273292
def test_month_to_month(
274293
self, _, input_ts, aggregator, target_period, column, timestamps, counts, values, timestamps_of
@@ -292,6 +311,14 @@ def test_month_to_month(
292311
[datetime(2025, 1, 1), datetime(2025, 1, 2)], [24, 24],
293312
{"value": [0, 24], "value_plus1": [1, 25], "value_times2": [0, 48]},
294313
[datetime(2025, 1, 1), datetime(2025, 1, 2)]),
314+
315+
("multi_column_mean_sum", ts_PT1H_2days, MeanSum, P1D, ["value", "value_plus1", "value_times2"],
316+
[datetime(2025, 1, 1), datetime(2025, 1, 2)], [24, 24],
317+
{"value": [276, 852], "value_plus1": [300, 876], "value_times2": [552, 1704]}, None),
318+
319+
("multi_column_sum", ts_PT1H_2days, Sum, P1D, ["value", "value_plus1", "value_times2"],
320+
[datetime(2025, 1, 1), datetime(2025, 1, 2)], [24, 24],
321+
{"value": [276, 852], "value_plus1": [300, 876], "value_times2": [552, 1704]}, None),
295322
])
296323
def test_multi_column(
297324
self, _, input_ts, aggregator, target_period, column, timestamps, counts, values, timestamps_of
@@ -492,6 +519,27 @@ def test_missing_criteria_available(self, _, valid, criteria):
492519
assert_frame_equal(result.df, expected_df, check_dtype=False, check_column_order=False, check_exact=False)
493520

494521

522+
class TestMeanSumWithMissingData(unittest.TestCase):
    """Check the MeanSum aggregation against an hourly time series with gaps."""

    def setUp(self):
        """Set up the gappy two-day hourly input and the expected daily results."""
        self.input_ts = ts_PT1H_2days_missing
        self.target_period = P1D
        self.column = "value"
        # One aggregated row per day: 1st and 2nd of January 2025.
        self.timestamps = [datetime(2025, 1, day) for day in (1, 2)]
        # Each day expects 24 hourly values; fewer are actually present.
        self.expected_counts = [24, 24]
        self.actual_counts = [20, 21]
        self.values = {"value": [280.8, 853.71]}

    def test_mean_sum_with_missing_data(self):
        """MeanSum over the gappy series should match the expected daily frame."""
        aggregated = MeanSum().apply(self.input_ts, self.target_period, self.column)
        expected_frame = generate_expected_df(
            self.timestamps,
            MeanSum,
            self.column,
            self.values,
            self.expected_counts,
            self.actual_counts,
        )
        # Loose comparison: ignore dtypes and column order, and allow float tolerance.
        comparison_options = {
            "check_dtype": False,
            "check_column_order": False,
            "check_exact": False,
        }
        assert_frame_equal(aggregated.df, expected_frame, **comparison_options)
541+
542+
495543
class TestPaddedAggregations(unittest.TestCase):
496544
"""Tests that aggregations work as expected with padded time series."""
497545

0 commit comments

Comments
 (0)