From fbba75a43277fbb4553182a088eb27ee488c3638 Mon Sep 17 00:00:00 2001 From: dcherian Date: Fri, 4 Feb 2022 10:08:48 -0700 Subject: [PATCH] Fix engine='numba' --- flox/aggregate_npg.py | 27 +++++++++++++++++++-------- flox/aggregations.py | 25 +++++++++++++++---------- 2 files changed, 34 insertions(+), 18 deletions(-) diff --git a/flox/aggregate_npg.py b/flox/aggregate_npg.py index 0de005b9e..5597803fd 100644 --- a/flox/aggregate_npg.py +++ b/flox/aggregate_npg.py @@ -2,11 +2,15 @@ import numpy_groupies as npg +def _get_aggregate(engine): + return npg.aggregate_numpy if engine == "numpy" else npg.aggregate_numba + + def sum_of_squares( - group_idx, array, *, axis=-1, func="sum", size=None, fill_value=None, dtype=None + group_idx, array, engine, *, axis=-1, func="sum", size=None, fill_value=None, dtype=None ): - return npg.aggregate_numpy.aggregate( + return _get_aggregate(engine).aggregate( group_idx, array**2, axis=axis, @@ -17,12 +21,12 @@ def sum_of_squares( ) -def nansum(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None): +def nansum(group_idx, array, engine, *, axis=-1, size=None, fill_value=None, dtype=None): # npg takes out NaNs before calling np.bincount # This means that all NaN groups are equivalent to absent groups # This behaviour does not work for xarray - return npg.aggregate_numpy.aggregate( + return _get_aggregate(engine).aggregate( group_idx, np.where(np.isnan(array), 0, array), axis=axis, @@ -33,12 +37,12 @@ def nansum(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None) ) -def nanprod(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None): +def nanprod(group_idx, array, engine, *, axis=-1, size=None, fill_value=None, dtype=None): # npg takes out NaNs before calling np.bincount # This means that all NaN groups are equivalent to absent groups # This behaviour does not work for xarray - return npg.aggregate_numpy.aggregate( + return _get_aggregate(engine).aggregate( group_idx, np.where(np.isnan(array), 1, array), axis=axis, @@ -49,7 +53,14 @@ def nanprod(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None ) -def nansum_of_squares(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None): +def nansum_of_squares(group_idx, array, engine, *, axis=-1, size=None, fill_value=None, dtype=None): return sum_of_squares( - group_idx, array, func="nansum", size=size, fill_value=fill_value, axis=axis, dtype=dtype + group_idx, + array, + engine=engine, + func="nansum", + size=size, + fill_value=fill_value, + axis=axis, + dtype=dtype, ) diff --git a/flox/aggregations.py b/flox/aggregations.py index 8f389e376..64e4d9cdc 100644 --- a/flox/aggregations.py +++ b/flox/aggregations.py @@ -10,24 +10,29 @@ def generic_aggregate( - group_idx, array, *, engine, func, axis=-1, size=None, fill_value=None, dtype=None, **kwargs + group_idx, + array, + *, + engine: str, + func: str, + axis=-1, + size=None, + fill_value=None, + dtype=None, + **kwargs, ): if engine == "flox": try: method = getattr(aggregate_flox, func) except AttributeError: method = partial(npg.aggregate_numpy.aggregate, func=func) - elif engine == "numpy": + elif engine in ["numpy", "numba"]: try: - # TODO: fix numba here - method = getattr(aggregate_npg, func) + method_ = getattr(aggregate_npg, func) + method = partial(method_, engine=engine) except AttributeError: - method = partial(npg.aggregate_np, func=func) - elif engine == "numba": - try: - method = getattr(aggregate_npg, f"{func}") - except AttributeError: - method = partial(npg.aggregate_nb, func=func) + aggregate = npg.aggregate_np if engine == "numpy" else npg.aggregate_nb + method = partial(aggregate, func=func) else: raise ValueError( f"Expected engine to be one of ['flox', 'numpy', 'numba']. Received {engine} instead."