Skip to content

Commit 6c499ff

Browse files
authored
Improve precision for mean, std, var, cumsum. (#90)
* Improve precision for mean, std, var. np.bincount always accumulates to float64. So only cast after the division.
1 parent 12405c2 commit 6c499ff

File tree

2 files changed

+30
-9
lines changed

2 files changed

+30
-9
lines changed

numpy_groupies/aggregate_numpy.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -149,23 +149,27 @@ def _mean(group_idx, a, size, fill_value, dtype=np.dtype(np.float64)):
149149
sums.real = np.bincount(group_idx, weights=a.real, minlength=size)
150150
sums.imag = np.bincount(group_idx, weights=a.imag, minlength=size)
151151
else:
152-
sums = np.bincount(group_idx, weights=a, minlength=size).astype(
153-
dtype, copy=False
154-
)
152+
sums = np.bincount(group_idx, weights=a, minlength=size)
155153

156154
with np.errstate(divide="ignore", invalid="ignore"):
157-
ret = sums.astype(dtype, copy=False) / counts
155+
ret = sums / counts
158156
if not np.isnan(fill_value):
159157
ret[counts == 0] = fill_value
160-
return ret
158+
if iscomplexobj(a):
159+
return ret
160+
else:
161+
return ret.astype(dtype, copy=False)
161162

162163

163164
def _sum_of_squres(group_idx, a, size, fill_value, dtype=np.dtype(np.float64)):
164165
ret = np.bincount(group_idx, weights=a * a, minlength=size)
165166
if fill_value != 0:
166167
counts = np.bincount(group_idx, minlength=size)
167168
ret[counts == 0] = fill_value
168-
return ret
169+
if iscomplexobj(a):
170+
return ret
171+
else:
172+
return ret.astype(dtype, copy=False)
169173

170174

171175
def _var(
@@ -176,7 +180,7 @@ def _var(
176180
counts = np.bincount(group_idx, minlength=size)
177181
sums = np.bincount(group_idx, weights=a, minlength=size)
178182
with np.errstate(divide="ignore", invalid="ignore"):
179-
means = sums.astype(dtype, copy=False) / counts
183+
means = sums / counts
180184
counts = np.where(counts > ddof, counts - ddof, 0)
181185
ret = (
182186
np.bincount(group_idx, (a - means[group_idx]) ** 2, minlength=size) / counts
@@ -185,7 +189,10 @@ def _var(
185189
ret = np.sqrt(ret) # this is now std not var
186190
if not np.isnan(fill_value):
187191
ret[counts == 0] = fill_value
188-
return ret
192+
if iscomplexobj(a):
193+
return ret
194+
else:
195+
return ret.astype(dtype, copy=False)
189196

190197

191198
def _std(group_idx, a, size, fill_value, dtype=np.dtype(np.float64), ddof=0):
@@ -252,7 +259,10 @@ def _cumsum(group_idx, a, size, fill_value=None, dtype=None):
252259

253260
increasing = np.arange(len(a), dtype=int)
254261
group_starts = _min(group_idx_srt, increasing, size, fill_value=0)[group_idx_srt]
255-
a_srt_cumsum += -a_srt_cumsum[group_starts] + a_srt[group_starts]
262+
# First subtract large numbers
263+
a_srt_cumsum -= a_srt_cumsum[group_starts]
264+
# Then add potentially small numbers
265+
a_srt_cumsum += a_srt[group_starts]
256266
return a_srt_cumsum[invsortidx]
257267

258268

numpy_groupies/tests/test_generic.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,3 +570,14 @@ def test_var_with_nan_fill_value(aggregate_all, ddof, nan_inds, func):
570570
group_idx, a, axis=-1, fill_value=np.nan, func=func, ddof=ddof
571571
)
572572
np.testing.assert_equal(actual, expected)
573+
574+
575+
def test_cumsum_accuracy(aggregate_all):
576+
array = np.array(
577+
[0.00000000e00, 0.00000000e00, 0.00000000e00, 3.27680000e04, 9.99999975e-06]
578+
)
579+
group_idx = np.array([0, 0, 0, 0, 1])
580+
581+
actual = aggregate_all(group_idx, array, axis=-1, func="cumsum")
582+
expected = array
583+
np.testing.assert_allclose(actual, expected)

0 commit comments

Comments
 (0)