Skip to content

Commit 36f6e01

Browse files
committed
Some fixes.
1 parent 2a74158 commit 36f6e01

File tree

3 files changed

+61
-31
lines changed

3 files changed

+61
-31
lines changed

flox/aggregations.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -573,17 +573,21 @@ class Scan:
573573
# This dataclass is separate from aggregations since there's not much in common
574574
# between reductions and scans
575575
name: str
576-
# binary ufunc name (e.g. add)
577-
ufunc: np.ufunc
578-
# in-memory scan function (e.g. cumsum)
576+
# binary operation (e.g. add)
577+
binary_op: Callable
578+
# in-memory grouped scan function (e.g. cumsum)
579579
scan: str
580-
# reduction that yields the last result of the scan (e.g. sum)
580+
# Grouped reduction that yields the last result of the scan (e.g. sum)
581581
reduction: str
582+
# Identity element
583+
identity: Any
584+
# dtype of result
585+
dtype: Any = None
582586

583587

584-
cumsum = Scan("cumsum", ufunc=np.add, reduction="sum", scan="cumsum")
585-
nancumsum = Scan("nancumsum", ufunc=np.add, reduction="nansum", scan="nancumsum")
586-
# cumprod = Scan("cumprod", ufunc=np.multiply, preop="prod", scan="cumprod")
588+
cumsum = Scan("cumsum", binary_op=np.add, reduction="sum", scan="cumsum", identity=0)
589+
nancumsum = Scan("nancumsum", binary_op=np.add, reduction="nansum", scan="nancumsum", identity=0)
590+
# cumprod = Scan("cumprod", binary_op=np.multiply, preop="prod", scan="cumprod")
587591

588592

589593
aggregations = {

flox/core.py

Lines changed: 35 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2644,31 +2644,37 @@ def __post_init__(self):
26442644
assert self.array.shape[-1] == self.group_idx.size
26452645

26462646

2647-
def grouped_scan(inp: AlignedArrays, *, func, axis, dtype=None, keepdims=None) -> AlignedArrays:
2647+
def grouped_scan(
2648+
inp: AlignedArrays, *, func: str, axis, fill_value=None, dtype=None, keepdims=None
2649+
) -> AlignedArrays:
26482650
assert axis == inp.array.ndim - 1
26492651
accumulated = generic_aggregate(
2650-
inp.group_idx, inp.array, axis=axis, engine="numpy", func=func, dtype=dtype
2652+
inp.group_idx,
2653+
inp.array,
2654+
axis=axis,
2655+
engine="numpy",
2656+
func=func,
2657+
dtype=dtype,
2658+
fill_value=fill_value,
26512659
)
26522660
return AlignedArrays(array=accumulated, group_idx=inp.group_idx)
26532661

26542662

2655-
def grouped_reduce(
2656-
inp: AlignedArrays, *, func, axis, fill_value=None, dtype=None, keepdims=None
2657-
) -> AlignedArrays:
2663+
def grouped_reduce(inp: AlignedArrays, *, agg: Scan, axis: int, keepdims=None) -> AlignedArrays:
26582664
assert axis == inp.array.ndim - 1
26592665
reduced = generic_aggregate(
26602666
inp.group_idx,
26612667
inp.array,
26622668
axis=axis,
26632669
engine="numpy",
2664-
func=func,
2665-
dtype=dtype,
2666-
fill_value=fill_value,
2670+
func=agg.reduction,
2671+
dtype=inp.array.dtype,
2672+
fill_value=agg.binary_op.identity,
26672673
)
26682674
return AlignedArrays(array=reduced, group_idx=np.arange(reduced.shape[-1]))
26692675

26702676

2671-
def grouped_binop(left: AlignedArrays, right: AlignedArrays, op: np.ufunc) -> AlignedArrays:
2677+
def grouped_binop(left: AlignedArrays, right: AlignedArrays, op: Callable) -> AlignedArrays:
26722678
reindexed = reindex_(
26732679
left.array,
26742680
from_=pd.Index(left.group_idx),
@@ -2708,26 +2714,39 @@ def dask_groupby_scan(array, by, axes: T_Axes, agg: Scan):
27082714
_zip, by, array, dtype=array.dtype, meta=array._meta, name="groupby-scan-preprocess"
27092715
)
27102716

2717+
# TODO: move to aggregate_npg.py
2718+
if agg.name in ["cumsum", "nancumsum"]:
2719+
# https://numpy.org/doc/stable/reference/generated/numpy.cumsum.html
2720+
# it defaults to the dtype of a, unless a
2721+
# has an integer dtype with a precision less than that of the default platform integer.
2722+
if array.dtype.kind == "i":
2723+
agg.dtype = np.result_type(array.dtype, np.intp)
2724+
elif array.dtype.kind == "u":
2725+
agg.dtype = np.result_type(array.dtype, np.uintp)
2726+
else:
2727+
agg.dtype = array.dtype
2728+
else:
2729+
agg.dtype = array.dtype
2730+
2731+
scan_ = partial(grouped_scan, func=agg.scan, fill_value=agg.identity)
27112732
# dask tokenizing error workaround
2712-
scan_ = partial(grouped_scan, func=agg.scan)
27132733
scan_.__name__ = scan_.func.__name__
27142734

27152735
# 2. Run the scan
27162736
accumulated = scan(
27172737
func=scan_,
2718-
binop=partial(grouped_binop, op=agg.ufunc),
2719-
ident=agg.ufunc.identity,
2738+
binop=partial(grouped_binop, op=agg.binary_op),
2739+
ident=agg.identity,
27202740
x=zipped,
27212741
axis=axis,
27222742
method="blelloch",
2723-
preop=partial(grouped_reduce, func=agg.reduction, fill_value=agg.ufunc.identity),
2724-
dtype=array.dtype,
2743+
preop=partial(grouped_reduce, agg=agg),
2744+
dtype=agg.dtype,
27252745
)
27262746

27272747
# 3. Unzip and extract the final result array, discard groups
2728-
result = map_blocks(extract_array, accumulated, dtype=array.dtype)
2748+
result = map_blocks(extract_array, accumulated, dtype=agg.dtype)
27292749

2730-
assert result.dtype == array.dtype
27312750
assert result.chunks == array.chunks
27322751

27332752
return result

tests/test_properties.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,12 @@
88
import numpy as np
99
from hypothesis import HealthCheck, assume, given, note, settings
1010

11+
from flox.aggregations import cumsum
1112
from flox.core import dask_groupby_scan, groupby_reduce
1213

1314
from . import ALL_FUNCS, SCIPY_STATS_FUNCS, assert_equal
1415

16+
dask.config.set(scheduler="sync")
1517
NON_NUMPY_FUNCS = ["first", "last", "nanfirst", "nanlast", "count", "any", "all"] + list(
1618
SCIPY_STATS_FUNCS
1719
)
@@ -128,25 +130,30 @@ def chunked_arrays(
128130
return from_array(array, chunks=("auto",) * (array.ndim - 1) + (chunks,))
129131

130132

131-
from flox.aggregations import cumsum
132-
133-
dask.config.set(scheduler="sync")
134-
135-
136133
def test():
137-
array = np.array([0.0, 0.0, 0.0], dtype=np.float32)
134+
# TODO: FIX
135+
# array =np.array([[5592407., 5592407.],
136+
# [5592407., 5592407.]], dtype=np.float32)
137+
138+
array = np.array([1, 1, 1], dtype=np.uint64)
138139
da = dask.array.from_array(array, chunks=2)
139140
actual = dask_groupby_scan(
140141
da, np.array([0] * array.shape[-1]), agg=cumsum, axes=(array.ndim - 1,)
141142
)
142143
actual.compute()
144+
expected = np.cumsum(array, axis=-1)
145+
np.testing.assert_array_equal(expected, actual)
143146

144147

145148
@given(data=st.data(), array=chunked_arrays())
146149
def test_scans(data, array):
147150
note(np.array(array))
151+
# overflow behaviour differs between bincount and sum (for example)
152+
assume(not_overflowing_array(np.asarray(array)))
153+
148154
actual = dask_groupby_scan(
149-
array, np.array([0] * array.shape[-1]), agg=cumsum, axes=(array.ndim - 1,)
155+
array, np.repeat(0, array.shape[-1]), agg=cumsum, axes=(array.ndim - 1,)
150156
)
151157
expected = np.cumsum(np.asarray(array), axis=-1)
152-
np.testing.assert_array_equal(expected, actual)
158+
tolerance = {"rtol": 1e-13, "atol": 1e-15}
159+
assert_equal(actual, expected, tolerance)

0 commit comments

Comments
 (0)