Skip to content

Fix engine='numba' #73

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 19 additions & 8 deletions flox/aggregate_npg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@
import numpy_groupies as npg


def _get_aggregate(engine):
return npg.aggregate_numpy if engine == "numpy" else npg.aggregate_numba


def sum_of_squares(
group_idx, array, *, axis=-1, func="sum", size=None, fill_value=None, dtype=None
group_idx, array, engine, *, axis=-1, func="sum", size=None, fill_value=None, dtype=None
):

return npg.aggregate_numpy.aggregate(
return _get_aggregate(engine).aggregate(
group_idx,
array**2,
axis=axis,
Expand All @@ -17,12 +21,12 @@ def sum_of_squares(
)


def nansum(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None):
def nansum(group_idx, array, engine, *, axis=-1, size=None, fill_value=None, dtype=None):
# npg takes out NaNs before calling np.bincount
# This means that all NaN groups are equivalent to absent groups
# This behaviour does not work for xarray

return npg.aggregate_numpy.aggregate(
return _get_aggregate(engine).aggregate(
group_idx,
np.where(np.isnan(array), 0, array),
axis=axis,
Expand All @@ -33,12 +37,12 @@ def nansum(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None)
)


def nanprod(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None):
def nanprod(group_idx, array, engine, *, axis=-1, size=None, fill_value=None, dtype=None):
# npg takes out NaNs before calling np.bincount
# This means that all NaN groups are equivalent to absent groups
# This behaviour does not work for xarray

return npg.aggregate_numpy.aggregate(
return _get_aggregate(engine).aggregate(
group_idx,
np.where(np.isnan(array), 1, array),
axis=axis,
Expand All @@ -49,7 +53,14 @@ def nanprod(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None
)


def nansum_of_squares(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None):
def nansum_of_squares(group_idx, array, engine, *, axis=-1, size=None, fill_value=None, dtype=None):
return sum_of_squares(
group_idx, array, func="nansum", size=size, fill_value=fill_value, axis=axis, dtype=dtype
group_idx,
array,
engine=engine,
func="nansum",
size=size,
fill_value=fill_value,
axis=axis,
dtype=dtype,
)
25 changes: 15 additions & 10 deletions flox/aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,29 @@


def generic_aggregate(
group_idx, array, *, engine, func, axis=-1, size=None, fill_value=None, dtype=None, **kwargs
group_idx,
array,
*,
engine: str,
func: str,
axis=-1,
size=None,
fill_value=None,
dtype=None,
**kwargs,
):
if engine == "flox":
try:
method = getattr(aggregate_flox, func)
except AttributeError:
method = partial(npg.aggregate_numpy.aggregate, func=func)
elif engine == "numpy":
elif engine in ["numpy", "numba"]:
try:
# TODO: fix numba here
method = getattr(aggregate_npg, func)
method_ = getattr(aggregate_npg, func)
method = partial(method_, engine=engine)
except AttributeError:
method = partial(npg.aggregate_np, func=func)
elif engine == "numba":
try:
method = getattr(aggregate_npg, f"{func}")
except AttributeError:
method = partial(npg.aggregate_nb, func=func)
aggregate = npg.aggregate_np if engine == "numpy" else npg.aggregate_nb
method = partial(aggregate, func=func)
else:
raise ValueError(
f"Expected engine to be one of ['flox', 'numpy', 'numba']. Received {engine} instead."
Expand Down