From 650088b2a3b3ab51daae0e87df1b65494f5f61f5 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Sat, 27 Jul 2024 08:28:45 -0600 Subject: [PATCH 01/45] Add topk --- flox/aggregate_flox.py | 119 ++++++++++++++++++++++++++--------------- flox/aggregations.py | 13 +++++ flox/core.py | 9 +++- flox/xarray.py | 25 +++++---- 4 files changed, 113 insertions(+), 53 deletions(-) diff --git a/flox/aggregate_flox.py b/flox/aggregate_flox.py index 7174552c8..b51a0aaaa 100644 --- a/flox/aggregate_flox.py +++ b/flox/aggregate_flox.py @@ -46,74 +46,107 @@ def _lerp(a, b, *, t, dtype, out=None): return out -def quantile_(array, inv_idx, *, q, axis, skipna, group_idx, dtype=None, out=None): - inv_idx = np.concatenate((inv_idx, [array.shape[-1]])) +def quantile_or_topk( + array, inv_idx, *, q=None, k=None, axis, skipna, group_idx, dtype=None, out=None +): + assert q or k - array_nanmask = isnull(array) - actual_sizes = np.add.reduceat(~array_nanmask, inv_idx[:-1], axis=axis) - newshape = (1,) * (array.ndim - 1) + (inv_idx.size - 1,) - full_sizes = np.reshape(np.diff(inv_idx), newshape) - nanmask = full_sizes != actual_sizes + inv_idx = np.concatenate((inv_idx, [array.shape[-1]])) - # The approach here is to use (complex_array.partition) because + # The approach for quantiles and topk, both of which are basically grouped partition, + # here is to use (complex_array.partition) because # 1. The full np.lexsort((array, labels), axis=-1) is slow and unnecessary # 2. Using record_array.partition(..., order=["labels", "array"]) is incredibly slow. - # partition will first sort by real part, then by imaginary part, so it is a two element lex-partition. - # So we set + # partition will first sort by real part, then by imaginary part, so it is a two element + # lex-partition. Therefore we set # complex_array = group_idx + 1j * array # group_idx is an integer (guaranteed), but array can have NaNs. 
Now, # 1 + 1j*NaN = NaN + 1j * NaN # so we must replace all NaNs with the maximum array value in the group so these NaNs # get sorted to the end. + + # Replace NaNs with the maximum value for each group. # Partly inspired by https://krstn.eu/np.nanpercentile()-there-has-to-be-a-faster-way/ - # TODO: Don't know if this array has been copied in _prepare_for_flox. This is potentially wasteful + array_nanmask = isnull(array) + actual_sizes = np.add.reduceat(~array_nanmask, inv_idx[:-1], axis=axis) + newshape = (1,) * (array.ndim - 1) + (inv_idx.size - 1,) + full_sizes = np.reshape(np.diff(inv_idx), newshape) + nanmask = full_sizes != actual_sizes + # TODO: Don't know if this array has been copied in _prepare_for_flox. + # This is potentially wasteful array = np.where(array_nanmask, -np.inf, array) maxes = np.maximum.reduceat(array, inv_idx[:-1], axis=axis) replacement = np.repeat(maxes, np.diff(inv_idx), axis=axis) array[array_nanmask] = replacement[array_nanmask] - qin = q - q = np.atleast_1d(qin) - q = np.reshape(q, (len(q),) + (1,) * array.ndim) - - # This is numpy's method="linear" - # TODO: could support all the interpolations here - virtual_index = q * (actual_sizes - 1) + inv_idx[:-1] + param = q or k + if k is not None: + assert k > 0 + is_scalar_param = False + param = np.arange(k) + else: + is_scalar_param = is_scalar(q) + param = np.atleast_1d(param) + param = np.reshape(param, (param.size,) + (1,) * array.ndim) - is_scalar_q = is_scalar(qin) - if is_scalar_q: - virtual_index = virtual_index.squeeze(axis=0) + if is_scalar_param: idxshape = array.shape[:-1] + (actual_sizes.shape[-1],) else: - idxshape = (q.shape[0],) + array.shape[:-1] + (actual_sizes.shape[-1],) + idxshape = (param.shape[0],) + array.shape[:-1] + (actual_sizes.shape[-1],) - lo_ = np.floor( - virtual_index, casting="unsafe", out=np.empty(virtual_index.shape, dtype=np.int64) - ) - hi_ = np.ceil( - virtual_index, casting="unsafe", out=np.empty(virtual_index.shape, dtype=np.int64) - ) - kth = 
np.unique(np.concatenate([lo_.reshape(-1), hi_.reshape(-1)])) + if q is not None: + # This is numpy's method="linear" + # TODO: could support all the interpolations here + virtual_index = param * (actual_sizes - 1) + inv_idx[:-1] + + if is_scalar_param: + virtual_index = virtual_index.squeeze(axis=0) + + lo_ = np.floor( + virtual_index, casting="unsafe", out=np.empty(virtual_index.shape, dtype=np.int64) + ) + hi_ = np.ceil( + virtual_index, casting="unsafe", out=np.empty(virtual_index.shape, dtype=np.int64) + ) + kth = np.unique(np.concatenate([lo_.reshape(-1), hi_.reshape(-1)])) + + else: + virtual_index = (actual_sizes - k) + inv_idx[:-1] + kth = np.unique(virtual_index) + kth = kth[kth > 0] + k_offset = np.arange(k).reshape((k,) + (1,) * virtual_index.ndim) + lo_ = k_offset + virtual_index[np.newaxis, ...] # partition the complex array in-place labels_broadcast = np.broadcast_to(group_idx, array.shape) with np.errstate(invalid="ignore"): cmplx = labels_broadcast + 1j * array cmplx.partition(kth=kth, axis=-1) - if is_scalar_q: + + if is_scalar_param: a_ = cmplx.imag else: - a_ = np.broadcast_to(cmplx.imag, (q.shape[0],) + array.shape) + a_ = np.broadcast_to(cmplx.imag, (param.shape[0],) + array.shape) - # get bounds, Broadcast to (num quantiles, ..., num labels) loval = np.take_along_axis(a_, np.broadcast_to(lo_, idxshape), axis=axis) - hival = np.take_along_axis(a_, np.broadcast_to(hi_, idxshape), axis=axis) + if q is not None: + # get bounds, Broadcast to (num quantiles, ..., num labels) + hival = np.take_along_axis(a_, np.broadcast_to(hi_, idxshape), axis=axis) + + # TODO: could support all the interpolations here + gamma = np.broadcast_to(virtual_index, idxshape) - lo_ + result = _lerp(loval, hival, t=gamma, out=out, dtype=dtype) + else: + import ipdb - # TODO: could support all the interpolations here - gamma = np.broadcast_to(virtual_index, idxshape) - lo_ - result = _lerp(loval, hival, t=gamma, out=out, dtype=dtype) + ipdb.set_trace() + result = loval + 
result[lo_ < 0] = np.nan if not skipna and np.any(nanmask): result[..., nanmask] = np.nan + if k is not None: + result = result.astype(array.dtype, copy=False) + np.copyto(out, result) return result @@ -138,10 +171,11 @@ def _np_grouped_op( if out is None: q = kwargs.get("q", None) - if q is None: + k = kwargs.get("k", None) + if not q and not k: out = np.full(array.shape[:-1] + (size,), fill_value=fill_value, dtype=dtype) else: - nq = len(np.atleast_1d(q)) + nq = len(np.atleast_1d(q)) if q is not None else k out = np.full((nq,) + array.shape[:-1] + (size,), fill_value=fill_value, dtype=dtype) kwargs["group_idx"] = group_idx @@ -178,10 +212,11 @@ def _nan_grouped_op(group_idx, array, func, fillna, *args, **kwargs): nanmax = partial(_nan_grouped_op, func=max, fillna=-np.inf) min = partial(_np_grouped_op, op=np.minimum.reduceat) nanmin = partial(_nan_grouped_op, func=min, fillna=np.inf) -quantile = partial(_np_grouped_op, op=partial(quantile_, skipna=False)) -nanquantile = partial(_np_grouped_op, op=partial(quantile_, skipna=True)) -median = partial(partial(_np_grouped_op, q=0.5), op=partial(quantile_, skipna=False)) -nanmedian = partial(partial(_np_grouped_op, q=0.5), op=partial(quantile_, skipna=True)) +quantile = partial(_np_grouped_op, op=partial(quantile_or_topk, skipna=False)) +topk = partial(_np_grouped_op, op=partial(quantile_or_topk, skipna=True)) +nanquantile = partial(_np_grouped_op, op=partial(quantile_or_topk, skipna=True)) +median = partial(partial(_np_grouped_op, q=0.5), op=partial(quantile_or_topk, skipna=False)) +nanmedian = partial(partial(_np_grouped_op, q=0.5), op=partial(quantile_or_topk, skipna=True)) # TODO: all, any diff --git a/flox/aggregations.py b/flox/aggregations.py index 51e650a66..3472c5a69 100644 --- a/flox/aggregations.py +++ b/flox/aggregations.py @@ -563,6 +563,10 @@ def quantile_new_dims_func(q) -> tuple[Dim]: return (Dim(name="quantile", values=q),) +def topk_new_dims_func(k) -> tuple[Dim]: + return (Dim(name="k", 
values=np.arange(k)),) + + quantile = Aggregation( name="quantile", fill_value=dtypes.NA, @@ -579,6 +583,14 @@ def quantile_new_dims_func(q) -> tuple[Dim]: final_dtype=np.floating, new_dims_func=quantile_new_dims_func, ) +topk = Aggregation( + name="topk", + fill_value=dtypes.NINF, + chunk=None, + combine=None, + final_dtype=None, + new_dims_func=topk_new_dims_func, +) mode = Aggregation(name="mode", fill_value=dtypes.NA, chunk=None, combine=None) nanmode = Aggregation(name="nanmode", fill_value=dtypes.NA, chunk=None, combine=None) @@ -778,6 +790,7 @@ def scan_binary_op(left_state: ScanState, right_state: ScanState, *, agg: Scan) "nanquantile": nanquantile, "mode": mode, "nanmode": nanmode, + "topk": topk, # "cumsum": cumsum, "nancumsum": nancumsum, "ffill": ffill, diff --git a/flox/core.py b/flox/core.py index 8eaf54b74..a6c24fa65 100644 --- a/flox/core.py +++ b/flox/core.py @@ -42,6 +42,7 @@ _initialize_aggregation, generic_aggregate, quantile_new_dims_func, + topk_new_dims_func, ) from .cache import memoize from .xrutils import ( @@ -1081,6 +1082,10 @@ def chunk_reduce( new_dims_shape = tuple( dim.size for dim in quantile_new_dims_func(**kw) if not dim.is_scalar ) + elif reduction == "topk": + new_dims_shape = tuple( + dim.size for dim in topk_new_dims_func(**kw) if not dim.is_scalar + ) else: new_dims_shape = tuple() result = result.reshape(new_dims_shape + final_array_shape[:-1] + found_groups_shape) @@ -2205,7 +2210,7 @@ def _choose_engine(by, agg: Aggregation): not_arg_reduce = not _is_arg_reduction(agg) - if agg.name in ["quantile", "nanquantile", "median", "nanmedian"]: + if agg.name in ["quantile", "nanquantile", "median", "nanmedian", "topk"]: logger.debug(f"_choose_engine: Choosing 'flox' since {agg.name}") return "flox" @@ -2258,7 +2263,7 @@ def groupby_reduce( equality check are for dimensions of size 1 in `by`. 
func : {"all", "any", "count", "sum", "nansum", "mean", "nanmean", \ "max", "nanmax", "min", "nanmin", "argmax", "nanargmax", "argmin", "nanargmin", \ - "quantile", "nanquantile", "median", "nanmedian", "mode", "nanmode", \ + "quantile", "nanquantile", "median", "nanmedian", "topk", "mode", "nanmode", \ "first", "nanfirst", "last", "nanlast"} or Aggregation Single function name or an Aggregation instance expected_groups : (optional) Sequence diff --git a/flox/xarray.py b/flox/xarray.py index 11cf706d4..e8f8d4302 100644 --- a/flox/xarray.py +++ b/flox/xarray.py @@ -9,7 +9,13 @@ from packaging.version import Version from xarray.core.duck_array_ops import _datetime_nanmin -from .aggregations import Aggregation, Dim, _atleast_1d, quantile_new_dims_func +from .aggregations import ( + Aggregation, + Dim, + _atleast_1d, + quantile_new_dims_func, + topk_new_dims_func, +) from .core import ( _convert_expected_groups_to_index, _get_expected_groups, @@ -92,7 +98,7 @@ def xarray_reduce( Variables with which to group by ``obj`` func : {"all", "any", "count", "sum", "nansum", "mean", "nanmean", \ "max", "nanmax", "min", "nanmin", "argmax", "nanargmax", "argmin", "nanargmin", \ - "quantile", "nanquantile", "median", "nanmedian", "mode", "nanmode", \ + "quantile", "nanquantile", "median", "nanmedian", "topk", "mode", "nanmode", \ "first", "nanfirst", "last", "nanlast"} or Aggregation Single function name or an Aggregation instance expected_groups : str or sequence @@ -390,17 +396,18 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs): result, *groups = groupby_reduce(array, *by, func=func, **kwargs) - # Transpose the new quantile dimension to the end. This is ugly. + # Transpose the new quantile or topk dimension to the end. This is ugly. 
# but new core dimensions are expected at the end :/ # but groupby_reduce inserts them at the beginning if func in ["quantile", "nanquantile"]: (newdim,) = quantile_new_dims_func(**finalize_kwargs) - if not newdim.is_scalar: - # NOTE: _restore_dim_order will move any new dims to the end anyway. - # This transpose is simply makes it easy to specify output_core_dims - # output dim order: (*broadcast_dims, *group_dims, quantile_dim) - result = np.moveaxis(result, 0, -1) - + elif func == "topk": + (newdim,) = topk_new_dims_func(**finalize_kwargs) + if not newdim.is_scalar: + # NOTE: _restore_dim_order will move any new dims to the end anyway. + # This transpose is simply makes it easy to specify output_core_dims + # output dim order: (*broadcast_dims, *group_dims, quantile_dim) + result = np.moveaxis(result, 0, -1) # Output of count has an int dtype. if requires_numeric and func != "count": if is_npdatetime: From 889be0c580b5f6b5aaf9ab1cdca7a72f3d357e7a Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Sat, 27 Jul 2024 20:41:18 -0600 Subject: [PATCH 02/45] Negative k --- flox/aggregate_flox.py | 31 +++++++++++++++++++------------ flox/aggregations.py | 4 +++- 2 files changed, 22 insertions(+), 13 deletions(-) diff --git a/flox/aggregate_flox.py b/flox/aggregate_flox.py index b51a0aaaa..3c4c875bb 100644 --- a/flox/aggregate_flox.py +++ b/flox/aggregate_flox.py @@ -47,7 +47,17 @@ def _lerp(a, b, *, t, dtype, out=None): def quantile_or_topk( - array, inv_idx, *, q=None, k=None, axis, skipna, group_idx, dtype=None, out=None + array, + inv_idx, + *, + q=None, + k=None, + axis, + skipna, + group_idx, + dtype=None, + out=None, + fill_value=None, ): assert q or k @@ -81,9 +91,8 @@ def quantile_or_topk( param = q or k if k is not None: - assert k > 0 is_scalar_param = False - param = np.arange(k) + param = np.arange(abs(k)) else: is_scalar_param = is_scalar(q) param = np.atleast_1d(param) @@ -111,10 +120,10 @@ def quantile_or_topk( kth = 
np.unique(np.concatenate([lo_.reshape(-1), hi_.reshape(-1)])) else: - virtual_index = (actual_sizes - k) + inv_idx[:-1] + virtual_index = inv_idx[:-1] + ((actual_sizes - k) if k > 0 else abs(k) - 1) kth = np.unique(virtual_index) kth = kth[kth > 0] - k_offset = np.arange(k).reshape((k,) + (1,) * virtual_index.ndim) + k_offset = param.reshape((abs(k),) + (1,) * virtual_index.ndim) lo_ = k_offset + virtual_index[np.newaxis, ...] # partition the complex array in-place @@ -137,15 +146,12 @@ def quantile_or_topk( gamma = np.broadcast_to(virtual_index, idxshape) - lo_ result = _lerp(loval, hival, t=gamma, out=out, dtype=dtype) else: - import ipdb - - ipdb.set_trace() result = loval - result[lo_ < 0] = np.nan + result[lo_ < 0] = fill_value if not skipna and np.any(nanmask): - result[..., nanmask] = np.nan + result[..., nanmask] = fill_value if k is not None: - result = result.astype(array.dtype, copy=False) + result = result.astype(dtype, copy=False) np.copyto(out, result) return result @@ -175,9 +181,10 @@ def _np_grouped_op( if not q and not k: out = np.full(array.shape[:-1] + (size,), fill_value=fill_value, dtype=dtype) else: - nq = len(np.atleast_1d(q)) if q is not None else k + nq = len(np.atleast_1d(q)) if q is not None else abs(k) out = np.full((nq,) + array.shape[:-1] + (size,), fill_value=fill_value, dtype=dtype) kwargs["group_idx"] = group_idx + kwargs["fill_value"] = fill_value if (len(uniques) == size) and (uniques == np.arange(size, like=array)).all(): # The previous version of this if condition diff --git a/flox/aggregations.py b/flox/aggregations.py index 3472c5a69..a41ca0173 100644 --- a/flox/aggregations.py +++ b/flox/aggregations.py @@ -564,7 +564,7 @@ def quantile_new_dims_func(q) -> tuple[Dim]: def topk_new_dims_func(k) -> tuple[Dim]: - return (Dim(name="k", values=np.arange(k)),) + return (Dim(name="k", values=np.arange(abs(k))),) quantile = Aggregation( @@ -848,6 +848,8 @@ def _initialize_aggregation( ), } + if agg.name == "topk" and 
finalize_kwargs["k"] < 0: + agg.fill_value["intermediate"] = (dtypes.INF,) # Replace sentinel fill values according to dtype agg.fill_value["user"] = fill_value agg.fill_value["intermediate"] = tuple( From 996ff2a4b18516787e6126b80cc9bfcbf35faf17 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Sat, 27 Jul 2024 20:41:27 -0600 Subject: [PATCH 03/45] dask support --- flox/aggregations.py | 9 +++------ flox/xrutils.py | 18 ++++++++++++++++++ 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/flox/aggregations.py b/flox/aggregations.py index a41ca0173..355603645 100644 --- a/flox/aggregations.py +++ b/flox/aggregations.py @@ -586,8 +586,8 @@ def topk_new_dims_func(k) -> tuple[Dim]: topk = Aggregation( name="topk", fill_value=dtypes.NINF, - chunk=None, - combine=None, + chunk="topk", + combine=xrutils.topk, final_dtype=None, new_dims_func=topk_new_dims_func, ) @@ -890,10 +890,7 @@ def _initialize_aggregation( simple_combine: list[Callable | None] = [] for combine in agg.combine: if isinstance(combine, str): - if combine in ["nanfirst", "nanlast"]: - simple_combine.append(getattr(xrutils, combine)) - else: - simple_combine.append(getattr(np, combine)) + simple_combine.append(getattr(np, combine)) else: simple_combine.append(combine) diff --git a/flox/xrutils.py b/flox/xrutils.py index 12bf54a10..9f72f04b7 100644 --- a/flox/xrutils.py +++ b/flox/xrutils.py @@ -378,3 +378,21 @@ def nanlast(values, axis, keepdims=False): return np.expand_dims(result, axis=axis) else: return result + + +def topk(a, k, axis, keepdims): + """Chunk and combine function of topk + + Extract the k largest elements from a on the given axis. + If k is negative, extract the -k smallest elements instead. + Note that, unlike in the parent function, the returned elements + are not sorted internally. 
+ """ + assert keepdims is True + axis = axis[0] + if abs(k) >= a.shape[axis]: + return a + + a = np.partition(a, -k, axis=axis) + k_slice = slice(-k, None) if k > 0 else slice(-k) + return a[tuple(k_slice if i == axis else slice(None) for i in range(a.ndim))] From 776d2339a5c65921b17737e0fbbc60378a10f310 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Sat, 27 Jul 2024 20:41:38 -0600 Subject: [PATCH 04/45] test --- flox/aggregate_flox.py | 19 ++++++++++++------- flox/aggregations.py | 4 +++- flox/core.py | 6 +++++- flox/xrutils.py | 7 +++++-- tests/test_properties.py | 13 +++++++++++++ 5 files changed, 38 insertions(+), 11 deletions(-) diff --git a/flox/aggregate_flox.py b/flox/aggregate_flox.py index 3c4c875bb..3ef800102 100644 --- a/flox/aggregate_flox.py +++ b/flox/aggregate_flox.py @@ -98,10 +98,8 @@ def quantile_or_topk( param = np.atleast_1d(param) param = np.reshape(param, (param.size,) + (1,) * array.ndim) - if is_scalar_param: - idxshape = array.shape[:-1] + (actual_sizes.shape[-1],) - else: - idxshape = (param.shape[0],) + array.shape[:-1] + (actual_sizes.shape[-1],) + # For topk(.., k=+1 or -1), we always return the singleton dimension. + idxshape = (param.shape[0],) + array.shape[:-1] + (actual_sizes.shape[-1],) if q is not None: # This is numpy's method="linear" @@ -110,6 +108,7 @@ def quantile_or_topk( if is_scalar_param: virtual_index = virtual_index.squeeze(axis=0) + idxshape = array.shape[:-1] + (actual_sizes.shape[-1],) lo_ = np.floor( virtual_index, casting="unsafe", out=np.empty(virtual_index.shape, dtype=np.int64) @@ -122,7 +121,7 @@ def quantile_or_topk( else: virtual_index = inv_idx[:-1] + ((actual_sizes - k) if k > 0 else abs(k) - 1) kth = np.unique(virtual_index) - kth = kth[kth > 0] + kth = kth[kth >= 0] k_offset = param.reshape((abs(k),) + (1,) * virtual_index.ndim) lo_ = k_offset + virtual_index[np.newaxis, ...] 
@@ -147,12 +146,18 @@ def quantile_or_topk( result = _lerp(loval, hival, t=gamma, out=out, dtype=dtype) else: result = loval - result[lo_ < 0] = fill_value + # This happens if numel in group < abs(k) + badmask = lo_ < 0 + if badmask.any(): + result[badmask] = fill_value + if not skipna and np.any(nanmask): result[..., nanmask] = fill_value + if k is not None: result = result.astype(dtype, copy=False) - np.copyto(out, result) + if out is not None: + np.copyto(out, result) return result diff --git a/flox/aggregations.py b/flox/aggregations.py index 355603645..0c2e82bbf 100644 --- a/flox/aggregations.py +++ b/flox/aggregations.py @@ -830,7 +830,7 @@ def _initialize_aggregation( ) final_dtype = _normalize_dtype(dtype_ or agg.dtype_init["final"], array_dtype, fill_value) - if agg.name not in ["first", "last", "nanfirst", "nanlast", "min", "max", "nanmin", "nanmax"]: + if agg.name not in ["first", "last", "nanfirst", "nanlast", "min", "max", "nanmin", "nanmax", "topk"]: final_dtype = _maybe_promote_int(final_dtype) agg.dtype = { "user": dtype, # Save to automatically choose an engine @@ -892,6 +892,8 @@ def _initialize_aggregation( if isinstance(combine, str): simple_combine.append(getattr(np, combine)) else: + if agg.name == "topk": + combine = partial(combine, **finalize_kwargs) simple_combine.append(combine) agg.simple_combine = tuple(simple_combine) diff --git a/flox/core.py b/flox/core.py index a6c24fa65..d357aa672 100644 --- a/flox/core.py +++ b/flox/core.py @@ -958,7 +958,7 @@ def chunk_reduce( nfuncs = len(funcs) dtypes = _atleast_1d(dtype, nfuncs) fill_values = _atleast_1d(fill_value, nfuncs) - kwargss = _atleast_1d({}, nfuncs) if kwargs is None else kwargs + kwargss = _atleast_1d({} if kwargs is None else kwargs, nfuncs) if isinstance(axis, Sequence): axes: T_Axes = axis @@ -1645,6 +1645,7 @@ def dask_groupby_agg( dtype=agg.dtype["intermediate"], reindex=reindex, user_dtype=agg.dtype["user"], + kwargs=agg.finalize_kwargs if agg.name == "topk" else None, ) if 
do_simple_combine: # Add a dummy dimension that then gets reduced over @@ -2372,6 +2373,9 @@ def groupby_reduce( "Use engine='flox' instead (it is also much faster), " "or set engine=None to use the default." ) + if func == "topk": + if finalize_kwargs is None or "k" not in finalize_kwargs: + raise ValueError("Please pass `k` for topk calculations.") bys: T_Bys = tuple(np.asarray(b) if not is_duck_array(b) else b for b in by) nby = len(bys) diff --git a/flox/xrutils.py b/flox/xrutils.py index 9f72f04b7..e85e327f5 100644 --- a/flox/xrutils.py +++ b/flox/xrutils.py @@ -389,10 +389,13 @@ def topk(a, k, axis, keepdims): are not sorted internally. """ assert keepdims is True - axis = axis[0] + (axis,) = axis + axis = normalize_axis_index(axis, a.ndim) if abs(k) >= a.shape[axis]: return a + # TODO: handle NaNs a = np.partition(a, -k, axis=axis) k_slice = slice(-k, None) if k > 0 else slice(-k) - return a[tuple(k_slice if i == axis else slice(None) for i in range(a.ndim))] + result = a[tuple(k_slice if i == axis else slice(None) for i in range(a.ndim))] + return result diff --git a/tests/test_properties.py b/tests/test_properties.py index 26150519c..6735367ad 100644 --- a/tests/test_properties.py +++ b/tests/test_properties.py @@ -210,3 +210,16 @@ def test_first_last(data, array: dask.array.Array, func: str) -> None: first, *_ = groupby_reduce(array, by, func=func, engine="flox") second, *_ = groupby_reduce(array, by, func=mate, engine="flox") assert_equal(first, second) + + +@given(data=st.data(), array=chunked_arrays()) +def test_topk_max_min(data, array): + "top 1 == max; top -1 == min" + size = array.shape[-1] + by = data.draw(by_arrays(shape=(size,))) + k, npfunc = data.draw(st.sampled_from([(1, "max"), (-1, "min")])) + + for a in (array, array.compute()): + actual, _ = groupby_reduce(a, by, func="topk", finalize_kwargs={"k": k}) + expected, _ = groupby_reduce(a, by, func=npfunc) + assert_equal(actual, expected[np.newaxis, :]) From 
a5eb7b957f4934b8e55e3f25d981276b7c301a1a Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Sat, 27 Jul 2024 21:21:25 -0600 Subject: [PATCH 05/45] wip --- tests/test_properties.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/test_properties.py b/tests/test_properties.py index 6735367ad..821c93188 100644 --- a/tests/test_properties.py +++ b/tests/test_properties.py @@ -212,6 +212,10 @@ def test_first_last(data, array: dask.array.Array, func: str) -> None: assert_equal(first, second) +from hypothesis import settings + + +@settings(report_multiple_bugs=False) @given(data=st.data(), array=chunked_arrays()) def test_topk_max_min(data, array): "top 1 == max; top -1 == min" From 4fa9a4cea72bfe0f56f1c80607a89b21da2a5f03 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Sat, 27 Jul 2024 23:00:50 -0600 Subject: [PATCH 06/45] fix --- flox/xarray.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/flox/xarray.py b/flox/xarray.py index e8f8d4302..aef0421f7 100644 --- a/flox/xarray.py +++ b/flox/xarray.py @@ -403,7 +403,9 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs): (newdim,) = quantile_new_dims_func(**finalize_kwargs) elif func == "topk": (newdim,) = topk_new_dims_func(**finalize_kwargs) - if not newdim.is_scalar: + else: + newdim = None + if newdim is not None and not newdim.is_scalar: # NOTE: _restore_dim_order will move any new dims to the end anyway. 
# This transpose is simply makes it easy to specify output_core_dims # output dim order: (*broadcast_dims, *group_dims, quantile_dim) From 4b04fde339e36d05089e74b50a3908a7bf51c8e0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 28 Jul 2024 20:16:56 +0000 Subject: [PATCH 07/45] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- flox/aggregations.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/flox/aggregations.py b/flox/aggregations.py index 0c2e82bbf..f33a0d513 100644 --- a/flox/aggregations.py +++ b/flox/aggregations.py @@ -830,7 +830,17 @@ def _initialize_aggregation( ) final_dtype = _normalize_dtype(dtype_ or agg.dtype_init["final"], array_dtype, fill_value) - if agg.name not in ["first", "last", "nanfirst", "nanlast", "min", "max", "nanmin", "nanmax", "topk"]: + if agg.name not in [ + "first", + "last", + "nanfirst", + "nanlast", + "min", + "max", + "nanmin", + "nanmax", + "topk", + ]: final_dtype = _maybe_promote_int(final_dtype) agg.dtype = { "user": dtype, # Save to automatically choose an engine From 93800aae1b8ddcd72a9681ac8a84b486d9c8915d Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 31 Jul 2024 14:58:07 -0600 Subject: [PATCH 08/45] Handle dtypes.NA properly for datetime/timedelta --- flox/aggregations.py | 18 +++++++++--------- tests/test_properties.py | 5 +++-- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/flox/aggregations.py b/flox/aggregations.py index f33a0d513..644dca437 100644 --- a/flox/aggregations.py +++ b/flox/aggregations.py @@ -158,12 +158,12 @@ def _get_fill_value(dtype, fill_value): return np.nan # This is madness, but npg checks that fill_value is compatible # with array dtype even if the fill_value is never used. 
- elif ( - np.issubdtype(dtype, np.integer) - or np.issubdtype(dtype, np.timedelta64) - or np.issubdtype(dtype, np.datetime64) - ): + elif np.issubdtype(dtype, np.integer): return dtypes.get_neg_infinity(dtype, min_for_int=True) + elif np.issubdtype(dtype, np.timedelta64): + return np.timedelta64("NaT") + elif np.issubdtype(dtype, np.datetime64): + return np.datetime64("NaT") else: return None return fill_value @@ -435,9 +435,9 @@ def _std_finalize(sumsq, sum_, count, ddof=0): min_ = Aggregation("min", chunk="min", combine="min", fill_value=dtypes.INF) -nanmin = Aggregation("nanmin", chunk="nanmin", combine="nanmin", fill_value=np.nan) +nanmin = Aggregation("nanmin", chunk="nanmin", combine="nanmin", fill_value=dtypes.NA) max_ = Aggregation("max", chunk="max", combine="max", fill_value=dtypes.NINF) -nanmax = Aggregation("nanmax", chunk="nanmax", combine="nanmax", fill_value=np.nan) +nanmax = Aggregation("nanmax", chunk="nanmax", combine="nanmax", fill_value=dtypes.NA) def argreduce_preprocess(array, axis): @@ -741,7 +741,7 @@ def scan_binary_op(left_state: ScanState, right_state: ScanState, *, agg: Scan) binary_op=None, reduction="nanlast", scan="ffill", - identity=np.nan, + identity=dtypes.NA, mode="concat_then_scan", ) bfill = Scan( @@ -749,7 +749,7 @@ def scan_binary_op(left_state: ScanState, right_state: ScanState, *, agg: Scan) binary_op=None, reduction="nanlast", scan="ffill", - identity=np.nan, + identity=dtypes.NA, mode="concat_then_scan", preprocess=reverse, finalize=reverse, diff --git a/tests/test_properties.py b/tests/test_properties.py index 821c93188..0490c56f2 100644 --- a/tests/test_properties.py +++ b/tests/test_properties.py @@ -218,10 +218,11 @@ def test_first_last(data, array: dask.array.Array, func: str) -> None: @settings(report_multiple_bugs=False) @given(data=st.data(), array=chunked_arrays()) def test_topk_max_min(data, array): - "top 1 == max; top -1 == min" + "top 1 == nanmax; top -1 == nanmin" size = array.shape[-1] + 
note(array.compute()) by = data.draw(by_arrays(shape=(size,))) - k, npfunc = data.draw(st.sampled_from([(1, "max"), (-1, "min")])) + k, npfunc = data.draw(st.sampled_from([(1, "nanmax"), (-1, "nanmin")])) for a in (array, array.compute()): actual, _ = groupby_reduce(a, by, func="topk", finalize_kwargs={"k": k}) From 80c67f434b18f4c19be30b23ec5ff91f3685452a Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 31 Jul 2024 15:16:30 -0600 Subject: [PATCH 09/45] Fix --- flox/aggregate_flox.py | 15 ++++++--- flox/aggregations.py | 68 +++++----------------------------------- flox/xrdtypes.py | 55 ++++++++++++++++++++++++++++++++ tests/test_core.py | 11 ++++--- tests/test_properties.py | 6 ++-- 5 files changed, 84 insertions(+), 71 deletions(-) diff --git a/flox/aggregate_flox.py b/flox/aggregate_flox.py index 3ef800102..e4d111416 100644 --- a/flox/aggregate_flox.py +++ b/flox/aggregate_flox.py @@ -2,6 +2,7 @@ import numpy as np +from . import xrdtypes as dtypes from .xrutils import is_scalar, isnull, notnull @@ -60,6 +61,7 @@ def quantile_or_topk( fill_value=None, ): assert q or k + assert axis == -1 inv_idx = np.concatenate((inv_idx, [array.shape[-1]])) @@ -84,7 +86,7 @@ def quantile_or_topk( nanmask = full_sizes != actual_sizes # TODO: Don't know if this array has been copied in _prepare_for_flox. 
# This is potentially wasteful - array = np.where(array_nanmask, -np.inf, array) + array = np.where(array_nanmask, dtypes.get_neg_infinity(array.dtype, min_for_int=True), array) maxes = np.maximum.reduceat(array, inv_idx[:-1], axis=axis) replacement = np.repeat(maxes, np.diff(inv_idx), axis=axis) array[array_nanmask] = replacement[array_nanmask] @@ -128,7 +130,7 @@ def quantile_or_topk( # partition the complex array in-place labels_broadcast = np.broadcast_to(group_idx, array.shape) with np.errstate(invalid="ignore"): - cmplx = labels_broadcast + 1j * array + cmplx = labels_broadcast + 1j * (array.view(int) if array.dtype.kind in "Mm" else array) cmplx.partition(kth=kth, axis=-1) if is_scalar_param: @@ -136,6 +138,9 @@ def quantile_or_topk( else: a_ = np.broadcast_to(cmplx.imag, (param.shape[0],) + array.shape) + if array.dtype.kind in "Mm": + a_ = a_.astype(array.dtype) + loval = np.take_along_axis(a_, np.broadcast_to(lo_, idxshape), axis=axis) if q is not None: # get bounds, Broadcast to (num quantiles, ..., num labels) @@ -204,6 +209,8 @@ def _np_grouped_op( def _nan_grouped_op(group_idx, array, func, fillna, *args, **kwargs): + if fillna in [dtypes.INF, dtypes.NINF]: + fillna = dtypes._get_fill_value(kwargs.get("dtype", array.dtype), fillna) result = func(group_idx, np.where(isnull(array), fillna, array), *args, **kwargs) # np.nanmax([np.nan, np.nan]) = np.nan # To recover this behaviour, we need to search for the fillna value @@ -221,9 +228,9 @@ def _nan_grouped_op(group_idx, array, func, fillna, *args, **kwargs): prod = partial(_np_grouped_op, op=np.multiply.reduceat) nanprod = partial(_nan_grouped_op, func=prod, fillna=1) max = partial(_np_grouped_op, op=np.maximum.reduceat) -nanmax = partial(_nan_grouped_op, func=max, fillna=-np.inf) +nanmax = partial(_nan_grouped_op, func=max, fillna=dtypes.NINF) min = partial(_np_grouped_op, op=np.minimum.reduceat) -nanmin = partial(_nan_grouped_op, func=min, fillna=np.inf) +nanmin = partial(_nan_grouped_op, func=min, 
fillna=dtypes.INF) quantile = partial(_np_grouped_op, op=partial(quantile_or_topk, skipna=False)) topk = partial(_np_grouped_op, op=partial(quantile_or_topk, skipna=True)) nanquantile = partial(_np_grouped_op, op=partial(quantile_or_topk, skipna=True)) diff --git a/flox/aggregations.py b/flox/aggregations.py index 644dca437..6df5aad2a 100644 --- a/flox/aggregations.py +++ b/flox/aggregations.py @@ -115,60 +115,6 @@ def generic_aggregate( return result -def _normalize_dtype(dtype: DTypeLike, array_dtype: np.dtype, fill_value=None) -> np.dtype: - if dtype is None: - dtype = array_dtype - if dtype is np.floating: - # mean, std, var always result in floating - # but we preserve the array's dtype if it is floating - if array_dtype.kind in "fcmM": - dtype = array_dtype - else: - dtype = np.dtype("float64") - elif not isinstance(dtype, np.dtype): - dtype = np.dtype(dtype) - if fill_value not in [None, dtypes.INF, dtypes.NINF, dtypes.NA]: - dtype = np.result_type(dtype, fill_value) - return dtype - - -def _maybe_promote_int(dtype) -> np.dtype: - # https://numpy.org/doc/stable/reference/generated/numpy.prod.html - # The dtype of a is used by default unless a has an integer dtype of less precision - # than the default platform integer. - if not isinstance(dtype, np.dtype): - dtype = np.dtype(dtype) - if dtype.kind == "i": - dtype = np.result_type(dtype, np.intp) - elif dtype.kind == "u": - dtype = np.result_type(dtype, np.uintp) - return dtype - - -def _get_fill_value(dtype, fill_value): - """Returns dtype appropriate infinity. 
Returns +Inf equivalent for None.""" - if fill_value in [None, dtypes.NA] and dtype.kind in "US": - return "" - if fill_value == dtypes.INF or fill_value is None: - return dtypes.get_pos_infinity(dtype, max_for_int=True) - if fill_value == dtypes.NINF: - return dtypes.get_neg_infinity(dtype, min_for_int=True) - if fill_value == dtypes.NA: - if np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating): - return np.nan - # This is madness, but npg checks that fill_value is compatible - # with array dtype even if the fill_value is never used. - elif np.issubdtype(dtype, np.integer): - return dtypes.get_neg_infinity(dtype, min_for_int=True) - elif np.issubdtype(dtype, np.timedelta64): - return np.timedelta64("NaT") - elif np.issubdtype(dtype, np.datetime64): - return np.datetime64("NaT") - else: - return None - return fill_value - - def _atleast_1d(inp, min_length: int = 1): if xrutils.is_scalar(inp): inp = (inp,) * min_length @@ -646,7 +592,7 @@ def last(self) -> AlignedArrays: # TODO: automate? 
engine="flox", dtype=self.array.dtype, - fill_value=_get_fill_value(self.array.dtype, dtypes.NA), + fill_value=dtypes._get_fill_value(self.array.dtype, dtypes.NA), expected_groups=None, ) return AlignedArrays(array=reduced["intermediates"][0], group_idx=reduced["groups"]) @@ -829,7 +775,9 @@ def _initialize_aggregation( np.dtype(dtype) if dtype is not None and not isinstance(dtype, np.dtype) else dtype ) - final_dtype = _normalize_dtype(dtype_ or agg.dtype_init["final"], array_dtype, fill_value) + final_dtype = dtypes._normalize_dtype( + dtype_ or agg.dtype_init["final"], array_dtype, fill_value + ) if agg.name not in [ "first", "last", @@ -841,14 +789,14 @@ def _initialize_aggregation( "nanmax", "topk", ]: - final_dtype = _maybe_promote_int(final_dtype) + final_dtype = dtypes._maybe_promote_int(final_dtype) agg.dtype = { "user": dtype, # Save to automatically choose an engine "final": final_dtype, "numpy": (final_dtype,), "intermediate": tuple( ( - _normalize_dtype(int_dtype, np.result_type(array_dtype, final_dtype), int_fv) + dtypes._normalize_dtype(int_dtype, np.result_type(array_dtype, final_dtype), int_fv) if int_dtype is None else np.dtype(int_dtype) ) @@ -863,10 +811,10 @@ def _initialize_aggregation( # Replace sentinel fill values according to dtype agg.fill_value["user"] = fill_value agg.fill_value["intermediate"] = tuple( - _get_fill_value(dt, fv) + dtypes._get_fill_value(dt, fv) for dt, fv in zip(agg.dtype["intermediate"], agg.fill_value["intermediate"]) ) - agg.fill_value[func] = _get_fill_value(agg.dtype["final"], agg.fill_value[func]) + agg.fill_value[func] = dtypes._get_fill_value(agg.dtype["final"], agg.fill_value[func]) fv = fill_value if fill_value is not None else agg.fill_value[agg.name] if _is_arg_reduction(agg): diff --git a/flox/xrdtypes.py b/flox/xrdtypes.py index 2d6ce3698..3fd0f4fec 100644 --- a/flox/xrdtypes.py +++ b/flox/xrdtypes.py @@ -1,6 +1,7 @@ import functools import numpy as np +from numpy.typing import DTypeLike from . 
import xrutils as utils @@ -147,3 +148,57 @@ def get_neg_infinity(dtype, min_for_int=False): def is_datetime_like(dtype): """Check if a dtype is a subclass of the numpy datetime types""" return np.issubdtype(dtype, np.datetime64) or np.issubdtype(dtype, np.timedelta64) + + +def _normalize_dtype(dtype: DTypeLike, array_dtype: np.dtype, fill_value=None) -> np.dtype: + if dtype is None: + dtype = array_dtype + if dtype is np.floating: + # mean, std, var always result in floating + # but we preserve the array's dtype if it is floating + if array_dtype.kind in "fcmM": + dtype = array_dtype + else: + dtype = np.dtype("float64") + elif not isinstance(dtype, np.dtype): + dtype = np.dtype(dtype) + if fill_value not in [None, INF, NINF, NA]: + dtype = np.result_type(dtype, fill_value) + return dtype + + +def _maybe_promote_int(dtype) -> np.dtype: + # https://numpy.org/doc/stable/reference/generated/numpy.prod.html + # The dtype of a is used by default unless a has an integer dtype of less precision + # than the default platform integer. + if not isinstance(dtype, np.dtype): + dtype = np.dtype(dtype) + if dtype.kind == "i": + dtype = np.result_type(dtype, np.intp) + elif dtype.kind == "u": + dtype = np.result_type(dtype, np.uintp) + return dtype + + +def _get_fill_value(dtype, fill_value): + """Returns dtype appropriate infinity. Returns +Inf equivalent for None.""" + if fill_value in [None, NA] and dtype.kind in "US": + return "" + if fill_value == INF or fill_value is None: + return get_pos_infinity(dtype, max_for_int=True) + if fill_value == NINF: + return get_neg_infinity(dtype, min_for_int=True) + if fill_value == NA: + if np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating): + return np.nan + # This is madness, but npg checks that fill_value is compatible + # with array dtype even if the fill_value is never used. 
+ elif np.issubdtype(dtype, np.integer): + return get_neg_infinity(dtype, min_for_int=True) + elif np.issubdtype(dtype, np.timedelta64): + return np.timedelta64("NaT") + elif np.issubdtype(dtype, np.datetime64): + return np.datetime64("NaT") + else: + return None + return fill_value diff --git a/tests/test_core.py b/tests/test_core.py index e12e695db..998a26dee 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -13,8 +13,9 @@ from numpy_groupies.aggregate_numpy import aggregate import flox +from flox import xrdtypes as dtypes from flox import xrutils -from flox.aggregations import Aggregation, _initialize_aggregation, _maybe_promote_int +from flox.aggregations import Aggregation, _initialize_aggregation from flox.core import ( HAS_NUMBAGG, _choose_engine, @@ -161,7 +162,7 @@ def test_groupby_reduce( if func == "mean" or func == "nanmean": expected_result = np.array(expected, dtype=np.float64) elif func == "sum": - expected_result = np.array(expected, dtype=_maybe_promote_int(array.dtype)) + expected_result = np.array(expected, dtype=dtypes._maybe_promote_int(array.dtype)) elif func == "count": expected_result = np.array(expected, dtype=np.intp) @@ -389,7 +390,7 @@ def test_groupby_reduce_preserves_dtype(dtype, func): array = np.ones((2, 12), dtype=dtype) by = np.array([labels] * 2) result, _ = groupby_reduce(from_array(array, chunks=(-1, 4)), by, func=func) - expect_dtype = _maybe_promote_int(array.dtype) + expect_dtype = dtypes._maybe_promote_int(array.dtype) assert result.dtype == expect_dtype @@ -1027,7 +1028,7 @@ def test_dtype_preservation(dtype, func, engine): # https://github.com/numbagg/numbagg/issues/121 pytest.skip() if func == "sum": - expected = _maybe_promote_int(dtype) + expected = dtypes._maybe_promote_int(dtype) elif func == "mean" and "int" in dtype: expected = np.float64 else: @@ -1058,7 +1059,7 @@ def test_cohorts_map_reduce_consistent_dtypes(method, dtype, labels_dtype): actual, actual_groups = groupby_reduce(array, labels, func="sum", 
method=method) assert_equal(actual_groups, np.arange(6, dtype=labels.dtype)) - expect_dtype = _maybe_promote_int(dtype) + expect_dtype = dtypes._maybe_promote_int(dtype) assert_equal(actual, np.array([0, 4, 24, 6, 12, 20], dtype=expect_dtype)) diff --git a/tests/test_properties.py b/tests/test_properties.py index 0490c56f2..73d6ea6da 100644 --- a/tests/test_properties.py +++ b/tests/test_properties.py @@ -215,8 +215,9 @@ def test_first_last(data, array: dask.array.Array, func: str) -> None: from hypothesis import settings +# TODO: do all_arrays instead of numeric_arrays @settings(report_multiple_bugs=False) -@given(data=st.data(), array=chunked_arrays()) +@given(data=st.data(), array=chunked_arrays(arrays=numeric_arrays)) def test_topk_max_min(data, array): "top 1 == nanmax; top -1 == nanmin" size = array.shape[-1] @@ -226,5 +227,6 @@ def test_topk_max_min(data, array): for a in (array, array.compute()): actual, _ = groupby_reduce(a, by, func="topk", finalize_kwargs={"k": k}) - expected, _ = groupby_reduce(a, by, func=npfunc) + # TODO: do numbagg, flox + expected, _ = groupby_reduce(a, by, func=npfunc, engine="numpy") assert_equal(actual, expected[np.newaxis, :]) From c924017e142f3c14b0feb0e04b8d6f0299ef80a4 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 6 Jan 2025 19:27:07 -0700 Subject: [PATCH 10/45] Fixes --- flox/aggregations.py | 5 ++++- flox/core.py | 1 + flox/xrutils.py | 3 +++ 3 files changed, 8 insertions(+), 1 deletion(-) diff --git a/flox/aggregations.py b/flox/aggregations.py index a4b061de1..30920a3d5 100644 --- a/flox/aggregations.py +++ b/flox/aggregations.py @@ -877,7 +877,10 @@ def _initialize_aggregation( simple_combine: list[Callable | None] = [] for combine in agg.combine: if isinstance(combine, str): - simple_combine.append(getattr(np, combine)) + if combine in ["nanfirst", "nanlast"]: + simple_combine.append(getattr(xrutils, combine)) + else: + simple_combine.append(getattr(np, combine)) else: if agg.name == "topk": combine = 
partial(combine, **finalize_kwargs) diff --git a/flox/core.py b/flox/core.py index 735462c9d..b82df9d90 100644 --- a/flox/core.py +++ b/flox/core.py @@ -879,6 +879,7 @@ def chunk_argreduce( engine: T_Engine = "numpy", sort: bool = True, user_dtype=None, + kwargs: Sequence[dict] | None = None, ) -> IntermediateDict: """ Per-chunk arg reduction. diff --git a/flox/xrutils.py b/flox/xrutils.py index a2ba259f1..8e6d88e2a 100644 --- a/flox/xrutils.py +++ b/flox/xrutils.py @@ -385,6 +385,9 @@ def topk(a, k, axis, keepdims): If k is negative, extract the -k smallest elements instead. Note that, unlike in the parent function, the returned elements are not sorted internally. + + NOTE: This function was copied from the dask project under the terms + of their LICENSE. """ assert keepdims is True (axis,) = axis From 7a794ba5f6e3cf4018c9a97c67bc9ccd24f33bc3 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 6 Jan 2025 19:30:57 -0700 Subject: [PATCH 11/45] one more fix --- flox/aggregate_flox.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flox/aggregate_flox.py b/flox/aggregate_flox.py index 4397c2757..aecba2e3a 100644 --- a/flox/aggregate_flox.py +++ b/flox/aggregate_flox.py @@ -192,7 +192,7 @@ def _np_grouped_op( if out is None: q = kwargs.get("q", None) k = kwargs.get("k", None) - if not q and not k: + if q is None and k is None: out = np.full(array.shape[:-1] + (size,), fill_value=fill_value, dtype=dtype) else: nq = len(np.atleast_1d(q)) if q is not None else abs(k) From eec4dd4a5658ab766e3fd4b97e1cb10f0ee48f01 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 6 Jan 2025 19:33:06 -0700 Subject: [PATCH 12/45] fix --- flox/aggregations.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/flox/aggregations.py b/flox/aggregations.py index 30920a3d5..37a70022a 100644 --- a/flox/aggregations.py +++ b/flox/aggregations.py @@ -835,7 +835,11 @@ def _initialize_aggregation( ), } - if agg.name == "topk" and 
finalize_kwargs["k"] < 0: + if finalize_kwargs is not None: + assert isinstance(finalize_kwargs, dict) + agg.finalize_kwargs = finalize_kwargs + + if agg.name == "topk" and agg.finalize_kwargs["k"] < 0: agg.fill_value["intermediate"] = (dtypes.INF,) # Replace sentinel fill values according to dtype agg.fill_value["user"] = fill_value @@ -852,10 +856,6 @@ def _initialize_aggregation( else: agg.fill_value["numpy"] = (fv,) - if finalize_kwargs is not None: - assert isinstance(finalize_kwargs, dict) - agg.finalize_kwargs = finalize_kwargs - # This is needed for the dask pathway. # Because we use intermediate fill_value since a group could be # absent in one block, but present in another block @@ -883,7 +883,7 @@ def _initialize_aggregation( simple_combine.append(getattr(np, combine)) else: if agg.name == "topk": - combine = partial(combine, **finalize_kwargs) + combine = partial(combine, **agg.finalize_kwargs) simple_combine.append(combine) agg.simple_combine = tuple(simple_combine) From 6ac9a1f3d7dfb9f0d029d6ad5c08678d096ec14f Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 6 Jan 2025 19:44:36 -0700 Subject: [PATCH 13/45] one more fix --- flox/aggregate_flox.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flox/aggregate_flox.py b/flox/aggregate_flox.py index aecba2e3a..ff007cdd1 100644 --- a/flox/aggregate_flox.py +++ b/flox/aggregate_flox.py @@ -60,7 +60,7 @@ def quantile_or_topk( out=None, fill_value=None, ): - assert q or k + assert q is not None or k is not None assert axis == -1 inv_idx = np.concatenate((inv_idx, [array.shape[-1]])) @@ -91,7 +91,7 @@ def quantile_or_topk( replacement = np.repeat(maxes, np.diff(inv_idx), axis=axis) array[array_nanmask] = replacement[array_nanmask] - param = q or k + param = q if q is not None else k if k is not None: is_scalar_param = False param = np.arange(abs(k)) From 83594e8c3bd30d0068d036cdebc17d721fdad08a Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 6 Jan 2025 22:18:01 -0700 
Subject: [PATCH 14/45] Fixes. --- flox/aggregate_flox.py | 38 +++++++++++++++++++++-------------- flox/aggregations.py | 5 ++++- flox/core.py | 11 +++++----- flox/xrutils.py | 7 +++++-- tests/test_properties.py | 43 +++++++++++++++++++++------------------- 5 files changed, 61 insertions(+), 43 deletions(-) diff --git a/flox/aggregate_flox.py b/flox/aggregate_flox.py index ff007cdd1..d8605f09a 100644 --- a/flox/aggregate_flox.py +++ b/flox/aggregate_flox.py @@ -79,14 +79,21 @@ def quantile_or_topk( # Replace NaNs with the maximum value for each group. # Partly inspired by https://krstn.eu/np.nanpercentile()-there-has-to-be-a-faster-way/ + # TODO: optimize for int, string? array_nanmask = isnull(array) actual_sizes = np.add.reduceat(~array_nanmask, inv_idx[:-1], axis=axis) newshape = (1,) * (array.ndim - 1) + (inv_idx.size - 1,) full_sizes = np.reshape(np.diff(inv_idx), newshape) - nanmask = full_sizes != actual_sizes + # These groups get replaced with the fill_value. For topk, we only replace with fill-values + # if non-NaN values in group < k + nanmask = (full_sizes - actual_sizes) > (abs(k) if k is not None else 0) # TODO: Don't know if this array has been copied in _prepare_for_flox. 
# This is potentially wasteful - array = np.where(array_nanmask, dtypes.get_neg_infinity(array.dtype, min_for_int=True), array) + # FIXME: should the filling handle sign(k) + fill = dtypes.get_neg_infinity(array.dtype, min_for_int=True) + if k is not None and k < 0: + fill = dtypes.get_pos_infinity(array.dtype, max_for_int=True) + array = np.where(array_nanmask, fill, array) maxes = np.maximum.reduceat(array, inv_idx[:-1], axis=axis) replacement = np.repeat(maxes, np.diff(inv_idx), axis=axis) array[array_nanmask] = replacement[array_nanmask] @@ -126,16 +133,19 @@ def quantile_or_topk( # partition the complex array in-place labels_broadcast = np.broadcast_to(group_idx, array.shape) with np.errstate(invalid="ignore"): - cmplx = labels_broadcast + 1j * (array.view(int) if array.dtype.kind in "Mm" else array) + cmplx = 1j * (array.view(int) if array.dtype.kind in "Mm" else array) + # This is a very intentional way of handling `array` with -inf/+inf values :/ + # a simple (labels + 1j * array) will yield `nan+inf * 1j` instead of `0 + inf * j` + # TODO: optimize copies here + cmplx.real = labels_broadcast cmplx.partition(kth=kth, axis=-1) - if is_scalar_param: - a_ = cmplx.imag - else: + a_ = cmplx.imag + if not is_scalar_param: a_ = np.broadcast_to(cmplx.imag, (param.shape[0],) + array.shape) if array.dtype.kind in "Mm": - a_ = a_.astype(array.dtype) + a_ = a_.view(array.dtype) loval = np.take_along_axis(a_, np.broadcast_to(lo_, idxshape), axis=axis) if q is not None: @@ -145,15 +155,13 @@ def quantile_or_topk( # TODO: could support all the interpolations here gamma = np.broadcast_to(virtual_index, idxshape) - lo_ result = _lerp(loval, hival, t=gamma, out=out, dtype=dtype) + if not skipna and np.any(nanmask): + result[..., nanmask] = fill_value else: result = loval - # This happens if numel in group < abs(k) - badmask = lo_ < 0 - if badmask.any(): - result[badmask] = fill_value - - if not skipna and np.any(nanmask): - result[..., nanmask] = fill_value + # The first 
clause is True if numel in group < abs(k) + badmask = np.broadcast_to(lo_ < 0, idxshape) | nanmask + result[badmask] = fill_value if k is not None: result = result.astype(dtype, copy=False) @@ -235,8 +243,8 @@ def _nan_grouped_op(group_idx, array, func, fillna, *args, **kwargs): nanmax = partial(_nan_grouped_op, func=max, fillna=dtypes.NINF) min = partial(_np_grouped_op, op=np.minimum.reduceat) nanmin = partial(_nan_grouped_op, func=min, fillna=dtypes.INF) -quantile = partial(_np_grouped_op, op=partial(quantile_or_topk, skipna=False)) topk = partial(_np_grouped_op, op=partial(quantile_or_topk, skipna=True)) +quantile = partial(_np_grouped_op, op=partial(quantile_or_topk, skipna=False)) nanquantile = partial(_np_grouped_op, op=partial(quantile_or_topk, skipna=True)) median = partial(partial(_np_grouped_op, q=0.5), op=partial(quantile_or_topk, skipna=False)) nanmedian = partial(partial(_np_grouped_op, q=0.5), op=partial(quantile_or_topk, skipna=True)) diff --git a/flox/aggregations.py b/flox/aggregations.py index 37a70022a..d92848684 100644 --- a/flox/aggregations.py +++ b/flox/aggregations.py @@ -393,6 +393,7 @@ def _std_finalize(sumsq, sum_, count, ddof=0): "nanmin", chunk="nanmin", combine="nanmin", + # FIXME: This is wrong, we need it to be NA for nan, INF for nanmin, NINF for nanmax, I think fill_value=dtypes.NA, preserves_dtype=True, ) @@ -882,7 +883,9 @@ def _initialize_aggregation( else: simple_combine.append(getattr(np, combine)) else: - if agg.name == "topk": + # TODO: bah, we need to pass `k` to the combine topk function + # this is ugly. 
+ if agg.name == "topk" and not isinstance(combine, str): combine = partial(combine, **agg.finalize_kwargs) simple_combine.append(combine) diff --git a/flox/core.py b/flox/core.py index b82df9d90..7ddc62604 100644 --- a/flox/core.py +++ b/flox/core.py @@ -879,7 +879,6 @@ def chunk_argreduce( engine: T_Engine = "numpy", sort: bool = True, user_dtype=None, - kwargs: Sequence[dict] | None = None, ) -> IntermediateDict: """ Per-chunk arg reduction. @@ -1653,6 +1652,9 @@ def dask_groupby_agg( # use the "non dask" code path, but applied blockwise blockwise_method = partial(_reduce_blockwise, agg=agg, fill_value=fill_value, reindex=reindex) else: + extra = {} + if agg.name == "topk": + extra["kwargs"] = (agg.finalize_kwargs, *(({},) * (len(agg.chunk) - 1))) # choose `chunk_reduce` or `chunk_argreduce` blockwise_method = partial( _get_chunk_reduction(agg.reduction_type), @@ -1661,7 +1663,7 @@ def dask_groupby_agg( dtype=agg.dtype["intermediate"], reindex=reindex, user_dtype=agg.dtype["user"], - kwargs=agg.finalize_kwargs if agg.name == "topk" else None, + **extra, ) if do_simple_combine: # Add a dummy dimension that then gets reduced over @@ -2392,9 +2394,8 @@ def groupby_reduce( "Use engine='flox' instead (it is also much faster), " "or set engine=None to use the default." 
) - if func == "topk": - if finalize_kwargs is None or "k" not in finalize_kwargs: - raise ValueError("Please pass `k` for topk calculations.") + if func == "topk" and (finalize_kwargs is None or "k" not in finalize_kwargs): + raise ValueError("Please pass `k` in ``finalize_kwargs`` for topk calculations.") bys: T_Bys = tuple(np.asarray(b) if not is_duck_array(b) else b for b in by) nby = len(bys) diff --git a/flox/xrutils.py b/flox/xrutils.py index 8e6d88e2a..b049d681f 100644 --- a/flox/xrutils.py +++ b/flox/xrutils.py @@ -395,8 +395,11 @@ def topk(a, k, axis, keepdims): if abs(k) >= a.shape[axis]: return a - # TODO: handle NaNs + # TODO: This may not need to handle NaNs + # if a.dtype.kind in ["cfO"]: + # fill = xrdtypes.get_neg_infinity(a.dtype) if k > 0 else xrdtypes.get_pos_infinity(a.dtype) + # a = np.where(isnull(a), fill) a = np.partition(a, -k, axis=axis) k_slice = slice(-k, None) if k > 0 else slice(-k) result = a[tuple(k_slice if i == axis else slice(None) for i in range(a.ndim))] - return result + return result.astype(a.dtype, copy=False) diff --git a/tests/test_properties.py b/tests/test_properties.py index a8f3b08ac..b27ac3b0a 100644 --- a/tests/test_properties.py +++ b/tests/test_properties.py @@ -246,26 +246,6 @@ def test_first_last_useless(data, func): assert_equal(actual, expected) -from hypothesis import settings - - -# TODO: do all_arrays instead of numeric_arrays -@settings(report_multiple_bugs=False) -@given(data=st.data(), array=chunked_arrays(arrays=numeric_arrays)) -def test_topk_max_min(data, array): - "top 1 == nanmax; top -1 == nanmin" - size = array.shape[-1] - note(array.compute()) - by = data.draw(by_arrays(shape=(size,))) - k, npfunc = data.draw(st.sampled_from([(1, "nanmax"), (-1, "nanmin")])) - - for a in (array, array.compute()): - actual, _ = groupby_reduce(a, by, func="topk", finalize_kwargs={"k": k}) - # TODO: do numbagg, flox - expected, _ = groupby_reduce(a, by, func=npfunc, engine="numpy") - assert_equal(actual, 
expected[np.newaxis, :]) - - @given( func=st.sampled_from(["sum", "prod", "nansum", "nanprod"]), engine=st.sampled_from(["numpy", "flox"]), @@ -286,3 +266,26 @@ def test_agg_dtype_specified(func, array_dtype, dtype, engine): ) expected = getattr(np, func)(counts, keepdims=True, dtype=dtype) assert actual.dtype == expected.dtype + + +from hypothesis import settings + + +# TODO: do all_arrays instead of numeric_arrays +@settings(report_multiple_bugs=False) +@given(data=st.data(), array=chunked_arrays(arrays=numeric_arrays)) +def test_topk_max_min(data, array): + "top 1 == nanmax; top -1 == nanmin" + size = array.shape[-1] + note(array.compute()) # FIXME + by = data.draw(by_arrays(shape=(size,))) + k, npfunc = data.draw(st.sampled_from([(1, "nanmax"), (-1, "nanmin")])) + + for a in (array, array.compute()): + actual, _ = groupby_reduce(a, by, func="topk", finalize_kwargs={"k": k}) + # TODO: do numbagg, flox + # FIXME: this is wrong + expected, _ = groupby_reduce(a, by, func=npfunc, engine="numpy") + # if a.dtype.kind in "cf": + # expected[np.isnan(expected)] = -np.inf if k == 1 else np.inf + assert_equal(actual, expected[np.newaxis, :]) From 740f85f26f9ef7ed0b771560b121f8e61a7bad68 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 7 Jan 2025 07:23:59 -0700 Subject: [PATCH 15/45] WIP --- flox/aggregate_flox.py | 5 ++++- flox/aggregations.py | 1 - tests/test_properties.py | 4 ++-- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/flox/aggregate_flox.py b/flox/aggregate_flox.py index d8605f09a..40cb70a19 100644 --- a/flox/aggregate_flox.py +++ b/flox/aggregate_flox.py @@ -86,7 +86,10 @@ def quantile_or_topk( full_sizes = np.reshape(np.diff(inv_idx), newshape) # These groups get replaced with the fill_value. 
For topk, we only replace with fill-values # if non-NaN values in group < k - nanmask = (full_sizes - actual_sizes) > (abs(k) if k is not None else 0) + if k is not None: + nanmask = (actual_sizes < abs(k)) + else: + nanmask = full_sizes != actual_sizes # TODO: Don't know if this array has been copied in _prepare_for_flox. # This is potentially wasteful # FIXME: should the filling handle sign(k) diff --git a/flox/aggregations.py b/flox/aggregations.py index d92848684..3ca3a5cb5 100644 --- a/flox/aggregations.py +++ b/flox/aggregations.py @@ -578,7 +578,6 @@ def topk_new_dims_func(k) -> tuple[Dim]: fill_value=dtypes.NINF, chunk="topk", combine=xrutils.topk, - final_dtype=None, new_dims_func=topk_new_dims_func, preserves_dtype=True, ) diff --git a/tests/test_properties.py b/tests/test_properties.py index b27ac3b0a..44734698b 100644 --- a/tests/test_properties.py +++ b/tests/test_properties.py @@ -286,6 +286,6 @@ def test_topk_max_min(data, array): # TODO: do numbagg, flox # FIXME: this is wrong expected, _ = groupby_reduce(a, by, func=npfunc, engine="numpy") - # if a.dtype.kind in "cf": - # expected[np.isnan(expected)] = -np.inf if k == 1 else np.inf + if a.dtype.kind in "cf": + expected[np.isnan(expected)] = -np.inf if k == 1 else np.inf assert_equal(actual, expected[np.newaxis, :]) From e177efddc17a4270bbccf1c4822a2aeb76380995 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 7 Jan 2025 14:06:22 -0700 Subject: [PATCH 16/45] fixes --- flox/aggregate_flox.py | 8 ++------ flox/aggregations.py | 20 +++++++++++++++----- tests/test_properties.py | 4 ++-- 3 files changed, 19 insertions(+), 13 deletions(-) diff --git a/flox/aggregate_flox.py b/flox/aggregate_flox.py index ef605e882..311c0e9bd 100644 --- a/flox/aggregate_flox.py +++ b/flox/aggregate_flox.py @@ -73,10 +73,6 @@ def quantile_or_topk( # The approach here is to use (complex_array.partition) because # 1. The full np.lexsort((array, labels), axis=-1) is slow and unnecessary # 2. 
Using record_array.partition(..., order=["labels", "array"]) is incredibly slow. - # partition will first sort by real part, then by imaginary part, so it is a two element - # lex-partition. Therefore we set - # partition will first sort by real part, then by imaginary part, so it is a two element lex-partition. - # So we set # 3. For complex arrays, partition will first sort by real part, then by imaginary part, so it is a two element # lex-partition. # Therefore we use approach (3) and set @@ -106,7 +102,6 @@ def quantile_or_topk( # This is numpy's method="linear" # TODO: could support all the interpolations here offset = actual_sizes.cumsum(axis=-1) - actual_sizes -= 1 # For topk(.., k=+1 or -1), we always return the singleton dimension. idxshape = (param.shape[0],) + array.shape[:-1] + (actual_sizes.shape[-1],) @@ -114,6 +109,7 @@ def quantile_or_topk( # This is numpy's method="linear" # TODO: could support all the interpolations here virtual_index = param * actual_sizes + actual_sizes -= 1 # virtual_index is relative to group starts, so now offset that virtual_index[..., 1:] += offset[..., :-1] @@ -126,7 +122,7 @@ def quantile_or_topk( kth = np.unique(np.concatenate([lo_.reshape(-1), hi_.reshape(-1)])) else: - virtual_index = (actual_sizes - k) if k > 0 else (abs(k) - 1) + virtual_index = (actual_sizes - k) if k > 0 else (np.zeros_like(actual_sizes) + abs(k) - 1) # virtual_index is relative to group starts, so now offset that virtual_index[..., 1:] += offset[..., :-1] kth = np.unique(virtual_index) diff --git a/flox/aggregations.py b/flox/aggregations.py index 3ca3a5cb5..495b5fde4 100644 --- a/flox/aggregations.py +++ b/flox/aggregations.py @@ -573,16 +573,26 @@ def topk_new_dims_func(k) -> tuple[Dim]: final_dtype=np.float64, new_dims_func=quantile_new_dims_func, ) +mode = Aggregation(name="mode", fill_value=dtypes.NA, chunk=None, combine=None, preserves_dtype=True) +nanmode = Aggregation(name="nanmode", fill_value=dtypes.NA, chunk=None, combine=None, 
preserves_dtype=True) + + +def _topk_finalize(result, counts, *, k): + # TODO: pass through final_fill_value + # TODO: apply in numpy code-path too + result[..., counts < k] = np.nan + return result + + topk = Aggregation( name="topk", - fill_value=dtypes.NINF, - chunk="topk", - combine=xrutils.topk, + fill_value=(dtypes.NINF, 0), + chunk=("topk", "nanlen"), + combine=(xrutils.topk, "sum"), + finalize=_topk_finalize, new_dims_func=topk_new_dims_func, preserves_dtype=True, ) -mode = Aggregation(name="mode", fill_value=dtypes.NA, chunk=None, combine=None, preserves_dtype=True) -nanmode = Aggregation(name="nanmode", fill_value=dtypes.NA, chunk=None, combine=None, preserves_dtype=True) @dataclass diff --git a/tests/test_properties.py b/tests/test_properties.py index 44734698b..b27ac3b0a 100644 --- a/tests/test_properties.py +++ b/tests/test_properties.py @@ -286,6 +286,6 @@ def test_topk_max_min(data, array): # TODO: do numbagg, flox # FIXME: this is wrong expected, _ = groupby_reduce(a, by, func=npfunc, engine="numpy") - if a.dtype.kind in "cf": - expected[np.isnan(expected)] = -np.inf if k == 1 else np.inf + # if a.dtype.kind in "cf": + # expected[np.isnan(expected)] = -np.inf if k == 1 else np.inf assert_equal(actual, expected[np.newaxis, :]) From 93934705426e1700e0ceb6ed98183a1e9f3d0a28 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 7 Jan 2025 14:16:48 -0700 Subject: [PATCH 17/45] fix --- flox/aggregations.py | 16 +++++----------- flox/core.py | 1 + 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/flox/aggregations.py b/flox/aggregations.py index 495b5fde4..8ec803018 100644 --- a/flox/aggregations.py +++ b/flox/aggregations.py @@ -575,21 +575,12 @@ def topk_new_dims_func(k) -> tuple[Dim]: ) mode = Aggregation(name="mode", fill_value=dtypes.NA, chunk=None, combine=None, preserves_dtype=True) nanmode = Aggregation(name="nanmode", fill_value=dtypes.NA, chunk=None, combine=None, preserves_dtype=True) - - -def _topk_finalize(result, counts, *, 
k): - # TODO: pass through final_fill_value - # TODO: apply in numpy code-path too - result[..., counts < k] = np.nan - return result - - topk = Aggregation( name="topk", fill_value=(dtypes.NINF, 0), + final_fill_value=dtypes.NA, chunk=("topk", "nanlen"), combine=(xrutils.topk, "sum"), - finalize=_topk_finalize, new_dims_func=topk_new_dims_func, preserves_dtype=True, ) @@ -850,7 +841,7 @@ def _initialize_aggregation( agg.finalize_kwargs = finalize_kwargs if agg.name == "topk" and agg.finalize_kwargs["k"] < 0: - agg.fill_value["intermediate"] = (dtypes.INF,) + agg.fill_value["intermediate"] = (dtypes.INF, 0) # Replace sentinel fill values according to dtype agg.fill_value["user"] = fill_value agg.fill_value["intermediate"] = tuple( @@ -866,6 +857,9 @@ def _initialize_aggregation( else: agg.fill_value["numpy"] = (fv,) + if agg.name == "topk": + min_count = max(min_count or 0, abs(agg.finalize_kwargs["k"])) + # This is needed for the dask pathway. # Because we use intermediate fill_value since a group could be # absent in one block, but present in another block diff --git a/flox/core.py b/flox/core.py index 7ddc62604..ed8c82a91 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1150,6 +1150,7 @@ def _finalize_results( if count_mask.any(): # For one count_mask.any() prevents promoting bool to dtype(fill_value) unless # necessary + fill_value = fill_value or agg.fill_value[agg.name] if fill_value is None: raise ValueError("Filling is required but fill_value is None.") # This allows us to match xarray's type promotion rules From 17eb915c6a57b1f2f8058e9153d6e4d6599c0453 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 7 Jan 2025 14:18:45 -0700 Subject: [PATCH 18/45] cleanup --- flox/aggregate_flox.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/flox/aggregate_flox.py b/flox/aggregate_flox.py index 311c0e9bd..c64b58691 100644 --- a/flox/aggregate_flox.py +++ b/flox/aggregate_flox.py @@ -68,7 +68,11 @@ def quantile_or_topk( 
array_validmask = notnull(array) actual_sizes = np.add.reduceat(array_validmask, inv_idx[:-1], axis=axis) newshape = (1,) * (array.ndim - 1) + (inv_idx.size - 1,) - full_sizes = np.reshape(np.diff(inv_idx), newshape) + if k is not None: + nanmask = actual_sizes < abs(k) + else: + full_sizes = np.reshape(np.diff(inv_idx), newshape) + nanmask = full_sizes != actual_sizes # The approach here is to use (complex_array.partition) because # 1. The full np.lexsort((array, labels), axis=-1) is slow and unnecessary @@ -86,11 +90,6 @@ def quantile_or_topk( # So we determine which indices we need using the fact that NaNs get sorted to the end. # This *was* partly inspired by https://krstn.eu/np.nanpercentile()-there-has-to-be-a-faster-way/ # but not any more now that I use partition and avoid replacing NaNs - if k is not None: - nanmask = actual_sizes < abs(k) - else: - nanmask = full_sizes != actual_sizes - if k is not None: is_scalar_param = False param = np.arange(abs(k)) From dc0df3ed130f90cbc8f2e73bd70e7eb52b7904da Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 7 Jan 2025 14:21:47 -0700 Subject: [PATCH 19/45] works? 
--- tests/test_properties.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/test_properties.py b/tests/test_properties.py index b27ac3b0a..b55b8f732 100644 --- a/tests/test_properties.py +++ b/tests/test_properties.py @@ -284,8 +284,6 @@ def test_topk_max_min(data, array): for a in (array, array.compute()): actual, _ = groupby_reduce(a, by, func="topk", finalize_kwargs={"k": k}) # TODO: do numbagg, flox - # FIXME: this is wrong - expected, _ = groupby_reduce(a, by, func=npfunc, engine="numpy") - # if a.dtype.kind in "cf": - # expected[np.isnan(expected)] = -np.inf if k == 1 else np.inf + # FIXME: this is wrong, remove this compute, add a property test checking dask vs numpy + expected, _ = groupby_reduce(dask.compute(a)[0], by, func=npfunc, engine="numpy") assert_equal(actual, expected[np.newaxis, :]) From 83ae5d8b3590b3e6f4f297b688a165cc6bc75166 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 7 Jan 2025 15:34:14 -0700 Subject: [PATCH 20/45] fix quantile --- flox/aggregate_flox.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flox/aggregate_flox.py b/flox/aggregate_flox.py index c64b58691..d229b3f19 100644 --- a/flox/aggregate_flox.py +++ b/flox/aggregate_flox.py @@ -107,8 +107,8 @@ def quantile_or_topk( if q is not None: # This is numpy's method="linear" # TODO: could support all the interpolations here - virtual_index = param * actual_sizes actual_sizes -= 1 + virtual_index = param * actual_sizes # virtual_index is relative to group starts, so now offset that virtual_index[..., 1:] += offset[..., :-1] From 95d20b88f3bc0fbc533737f27df6bc0e5a5ac499 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 7 Jan 2025 15:34:45 -0700 Subject: [PATCH 21/45] optimize xrutils.topk --- flox/xrutils.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/flox/xrutils.py b/flox/xrutils.py index b049d681f..74175e1ec 100644 --- a/flox/xrutils.py +++ b/flox/xrutils.py @@ -395,11 +395,7 @@ def topk(a, 
k, axis, keepdims): if abs(k) >= a.shape[axis]: return a - # TODO: This may not need to handle NaNs - # if a.dtype.kind in ["cfO"]: - # fill = xrdtypes.get_neg_infinity(a.dtype) if k > 0 else xrdtypes.get_pos_infinity(a.dtype) - # a = np.where(isnull(a), fill) - a = np.partition(a, -k, axis=axis) + a.partition(-k, axis=axis) k_slice = slice(-k, None) if k > 0 else slice(-k) result = a[tuple(k_slice if i == axis else slice(None) for i in range(a.ndim))] return result.astype(a.dtype, copy=False) From caa98b88a3d378d1fa1a9532bc63a20ce07669e2 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 7 Jan 2025 20:43:37 -0700 Subject: [PATCH 22/45] Update tests/test_properties.py --- tests/test_properties.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_properties.py b/tests/test_properties.py index a1ce172e8..9d728fbfb 100644 --- a/tests/test_properties.py +++ b/tests/test_properties.py @@ -318,5 +318,5 @@ def test_topk_max_min(data, array): actual, _ = groupby_reduce(a, by, func="topk", finalize_kwargs={"k": k}) # TODO: do numbagg, flox # FIXME: this is wrong, remove this compute, add a property test checking dask vs numpy - expected, _ = groupby_reduce(dask.compute(a)[0], by, func=npfunc, engine="numpy") + expected, _ = groupby_reduce(a, by, func=npfunc, engine="numpy") assert_equal(actual, expected[np.newaxis, :]) From 820d46c789b40f72718ca9fe8ded23ea7e9c4b3d Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Sun, 12 Jan 2025 18:15:13 -0600 Subject: [PATCH 23/45] generalize new_dims_func --- flox/core.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/flox/core.py b/flox/core.py index 969b29c13..9dba803c1 100644 --- a/flox/core.py +++ b/flox/core.py @@ -41,8 +41,6 @@ _atleast_1d, _initialize_aggregation, generic_aggregate, - quantile_new_dims_func, - topk_new_dims_func, ) from .cache import memoize from .xrutils import ( @@ -937,6 +935,7 @@ def chunk_reduce( kwargs: Sequence[dict] | None = None, sort: 
bool = True, user_dtype=None, + new_dims_func: Callable | None = None, ) -> IntermediateDict: """ Wrapper for numpy_groupies aggregate that supports nD ``array`` and @@ -961,6 +960,9 @@ def chunk_reduce( axis : (optional) int or Sequence[int] If None, reduce along all dimensions of array. Else reduce along specified axes. + new_dims_func: Callable + Function that returns expected shape for any new dimensions + (needed for quantile and topk) Returns ------- @@ -1089,11 +1091,8 @@ def chunk_reduce( if hasnan: # remove NaN group label which should be last result = result[..., :-1] - # TODO: Figure out how to generalize this - if reduction in ("quantile", "nanquantile"): - new_dims_shape = tuple(dim.size for dim in quantile_new_dims_func(**kw) if not dim.is_scalar) - elif reduction == "topk": - new_dims_shape = tuple(dim.size for dim in topk_new_dims_func(**kw) if not dim.is_scalar) + if new_dims_func is not None: + new_dims_shape = tuple(dim.size for dim in new_dims_func(**kw) if not dim.is_scalar) else: new_dims_shape = tuple() result = result.reshape(new_dims_shape + final_array_shape[:-1] + found_groups_shape) @@ -1448,6 +1447,7 @@ def _reduce_blockwise( sort=sort, reindex=reindex, user_dtype=agg.dtype["user"], + new_dims_func=agg.new_dims_func, ) if _is_arg_reduction(agg): From 6aa923a806aea35786e1c88d889a18d74cee716a Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Sun, 12 Jan 2025 19:30:40 -0600 Subject: [PATCH 24/45] Revert "generalize new_dims_func" This reverts commit 820d46c789b40f72718ca9fe8ded23ea7e9c4b3d. 
--- flox/core.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/flox/core.py b/flox/core.py index 4aba36d66..d1080b6e8 100644 --- a/flox/core.py +++ b/flox/core.py @@ -41,6 +41,8 @@ _atleast_1d, _initialize_aggregation, generic_aggregate, + quantile_new_dims_func, + topk_new_dims_func, ) from .cache import memoize from .xrutils import ( @@ -935,7 +937,6 @@ def chunk_reduce( kwargs: Sequence[dict] | None = None, sort: bool = True, user_dtype=None, - new_dims_func: Callable | None = None, ) -> IntermediateDict: """ Wrapper for numpy_groupies aggregate that supports nD ``array`` and @@ -960,9 +961,6 @@ def chunk_reduce( axis : (optional) int or Sequence[int] If None, reduce along all dimensions of array. Else reduce along specified axes. - new_dims_func: Callable - Function that returns expected shape for any new dimensions - (needed for quantile and topk) Returns ------- @@ -1091,8 +1089,11 @@ def chunk_reduce( if hasnan: # remove NaN group label which should be last result = result[..., :-1] - if new_dims_func is not None: - new_dims_shape = tuple(dim.size for dim in new_dims_func(**kw) if not dim.is_scalar) + # TODO: Figure out how to generalize this + if reduction in ("quantile", "nanquantile"): + new_dims_shape = tuple(dim.size for dim in quantile_new_dims_func(**kw) if not dim.is_scalar) + elif reduction == "topk": + new_dims_shape = tuple(dim.size for dim in topk_new_dims_func(**kw) if not dim.is_scalar) else: new_dims_shape = tuple() result = result.reshape(new_dims_shape + final_array_shape[:-1] + found_groups_shape) @@ -1447,7 +1448,6 @@ def _reduce_blockwise( sort=sort, reindex=reindex, user_dtype=agg.dtype["user"], - new_dims_func=agg.new_dims_func, ) if _is_arg_reduction(agg): From 2c6d48626d7712390e12d219c34b5128c9671c49 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 13 Jan 2025 16:11:44 -0600 Subject: [PATCH 25/45] Support bool --- flox/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/flox/core.py b/flox/core.py index e450370b2..c7db3b061 100644 --- a/flox/core.py +++ b/flox/core.py @@ -179,7 +179,7 @@ def _is_bool_supported_reduction(func: T_Agg) -> bool: if isinstance(func, Aggregation): func = func.name return ( - func in ["all", "any"] + func in ["all", "any", "topk"] # TODO: enable in npg # or _is_first_last_reduction(func) # or _is_minmax_reduction(func) From 0dcd87c373ab6b68adde48897c7581bc6d4dede7 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 13 Jan 2025 16:29:16 -0600 Subject: [PATCH 26/45] more skipping --- tests/test_properties.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/test_properties.py b/tests/test_properties.py index 1eb1769ad..baecd9ffe 100644 --- a/tests/test_properties.py +++ b/tests/test_properties.py @@ -337,6 +337,12 @@ def test_agg_dtype_specified(func, array_dtype, dtype, engine): @given(data=st.data(), array=chunked_arrays()) def test_topk_max_min(data, array): "top 1 == nanmax; top -1 == nanmin" + + if array.dtype.kind in "mM": + # we cast to float and back, so this is the effective limit + assume((array.view(np.int64) < 2**53).all()) + elif array.dtype.kind == "i": + assume((array < 2**53).all()) size = array.shape[-1] note(array.compute()) # FIXME by = data.draw(by_arrays(shape=(size,))) From 9b874ea2be51912bf8aba1112de22be6a331cb66 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 14 Jan 2025 17:52:10 -0600 Subject: [PATCH 27/45] fix --- flox/aggregate_flox.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/flox/aggregate_flox.py b/flox/aggregate_flox.py index d229b3f19..c14c4df38 100644 --- a/flox/aggregate_flox.py +++ b/flox/aggregate_flox.py @@ -159,7 +159,8 @@ def quantile_or_topk( result = loval # The first clause is True if numel in group < abs(k) badmask = np.broadcast_to(lo_ < 0, idxshape) | nanmask - result[badmask] = fill_value + if badmask.any(): + result[badmask] = fill_value if k is not None: result = result.astype(dtype, copy=False) From 
adebbec4c0e53797e1d02acbb4c681758d504d0b Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 14 Jan 2025 18:43:23 -0600 Subject: [PATCH 28/45] more xfail --- tests/test_properties.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/test_properties.py b/tests/test_properties.py index baecd9ffe..794466052 100644 --- a/tests/test_properties.py +++ b/tests/test_properties.py @@ -338,11 +338,13 @@ def test_agg_dtype_specified(func, array_dtype, dtype, engine): def test_topk_max_min(data, array): "top 1 == nanmax; top -1 == nanmin" - if array.dtype.kind in "mM": + if array.dtype.kind in "Mm": # we cast to float and back, so this is the effective limit - assume((array.view(np.int64) < 2**53).all()) - elif array.dtype.kind == "i": - assume((array < 2**53).all()) + assume((np.abs(array.view(np.int64)) < 2**53).all()) + elif _contains_cftime_datetimes(array): + asint = datetime_to_numeric(array, datetime_unit="us") + assume((np.abs(asint.view(np.int64)) < 2**53).all()) + size = array.shape[-1] note(array.compute()) # FIXME by = data.draw(by_arrays(shape=(size,))) From 4f35230f6a2c638ad8e8743e39075e11604f1883 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Sat, 18 Jan 2025 20:48:12 -0700 Subject: [PATCH 29/45] cleanup --- tests/test_properties.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_properties.py b/tests/test_properties.py index 794466052..dc00fd1d9 100644 --- a/tests/test_properties.py +++ b/tests/test_properties.py @@ -346,13 +346,11 @@ def test_topk_max_min(data, array): assume((np.abs(asint.view(np.int64)) < 2**53).all()) size = array.shape[-1] - note(array.compute()) # FIXME by = data.draw(by_arrays(shape=(size,))) k, npfunc = data.draw(st.sampled_from([(1, "nanmax"), (-1, "nanmin")])) for a in (array, array.compute()): actual, _ = groupby_reduce(a, by, func="topk", finalize_kwargs={"k": k}) # TODO: do numbagg, flox - # FIXME: this is wrong, remove this compute, add a property test checking dask vs numpy 
expected, _ = groupby_reduce(a, by, func=npfunc, engine="numpy") assert_equal(actual, expected[np.newaxis, :]) From cd2f1509922fdaafd606e09bf00aec4a81bb5f94 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Sat, 18 Jan 2025 20:56:27 -0700 Subject: [PATCH 30/45] one more xfail --- tests/test_properties.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/test_properties.py b/tests/test_properties.py index dc00fd1d9..718a8c330 100644 --- a/tests/test_properties.py +++ b/tests/test_properties.py @@ -329,18 +329,16 @@ def test_agg_dtype_specified(func, array_dtype, dtype, engine): assert actual.dtype == expected.dtype -from hypothesis import settings - - -# TODO: do all_arrays instead of numeric_arrays -@settings(report_multiple_bugs=False) @given(data=st.data(), array=chunked_arrays()) def test_topk_max_min(data, array): "top 1 == nanmax; top -1 == nanmin" - if array.dtype.kind in "Mm": + if array.dtype.kind == "i": # we cast to float and back, so this is the effective limit + assume((np.abs(array) < 2**53).all()) + elif array.dtype.kind in "Mm": assume((np.abs(array.view(np.int64)) < 2**53).all()) + # we cast to float and back, so this is the effective limit elif _contains_cftime_datetimes(array): asint = datetime_to_numeric(array, datetime_unit="us") assume((np.abs(asint.view(np.int64)) < 2**53).all()) From 70e6f225058d0558c259c757ec3997231eeff467 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Sat, 18 Jan 2025 21:00:00 -0700 Subject: [PATCH 31/45] typing --- flox/aggregations.py | 1 + 1 file changed, 1 insertion(+) diff --git a/flox/aggregations.py b/flox/aggregations.py index ce36fe140..310725361 100644 --- a/flox/aggregations.py +++ b/flox/aggregations.py @@ -900,6 +900,7 @@ def _initialize_aggregation( # TODO: bah, we need to pass `k` to the combine topk function # this is ugly. 
if agg.name == "topk" and not isinstance(combine, str): + assert combine is not None combine = partial(combine, **agg.finalize_kwargs) simple_combine.append(combine) From 5d45603656f860dc646f74b93d6fdcfd34f43313 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Sat, 18 Jan 2025 21:02:28 -0700 Subject: [PATCH 32/45] minor docs --- docs/source/aggregations.md | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/docs/source/aggregations.md b/docs/source/aggregations.md index d3591d2dc..82562cc3a 100644 --- a/docs/source/aggregations.md +++ b/docs/source/aggregations.md @@ -9,19 +9,16 @@ the `func` kwarg: - `"mean"`, `"nanmean"` - `"var"`, `"nanvar"` - `"std"`, `"nanstd"` -- `"argmin"` -- `"argmax"` +- `"argmin"`, `"nanargmin"` +- `"argmax"`, `"nanargmax"` - `"first"`, `"nanfirst"` - `"last"`, `"nanlast"` - `"median"`, `"nanmedian"` - `"mode"`, `"nanmode"` - `"quantile"`, `"nanquantile"` +- `"topk"` -```{tip} -We would like to add support for `cumsum`, `cumprod` ([issue](https://github.com/xarray-contrib/flox/issues/91)). Contributions are welcome! -``` - -## Custom Aggregations +## Custom Reductions `flox` also allows you to specify a custom Aggregation (again inspired by dask.dataframe), though this might not be fully functional at the moment. See `aggregations.py` for examples. @@ -46,3 +43,7 @@ mean = Aggregation( final_fill_value=np.nan, ) ``` + +## Custom Scans + +Coming soon! 
From 096f6b92eadab78f21d1008e3266de8fe1edf604 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Sat, 18 Jan 2025 22:17:53 -0700 Subject: [PATCH 33/45] disable log in CI --- .github/workflows/ci.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 59fa216e5..b64b3932c 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -74,7 +74,7 @@ jobs: id: status run: | python -c "import xarray; xarray.show_versions()" - pytest --durations=20 --durations-min=0.5 -n auto --cov=./ --cov-report=xml --hypothesis-profile ci + pytest --durations=20 --durations-min=0.5 -n auto --cov=./ --cov-report=xml --hypothesis-profile ci --log-disable=flox - name: Upload code coverage to Codecov uses: codecov/codecov-action@v5.1.2 with: From 0277cb9b7ab2ff198301b75690aba05d1a9f3d98 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Sat, 18 Jan 2025 22:28:51 -0700 Subject: [PATCH 34/45] Fix boolean --- flox/xrdtypes.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/flox/xrdtypes.py b/flox/xrdtypes.py index e1b9bccec..2b9b4525a 100644 --- a/flox/xrdtypes.py +++ b/flox/xrdtypes.py @@ -109,6 +109,9 @@ def get_pos_infinity(dtype, max_for_int=False): if issubclass(dtype.type, np.complexfloating): return np.inf + 1j * np.inf + if issubclass(dtype.type, np.bool): + return True + return INF @@ -142,6 +145,9 @@ def get_neg_infinity(dtype, min_for_int=False): if issubclass(dtype.type, np.complexfloating): return -np.inf - 1j * np.inf + if issubclass(dtype.type, np.bool): + return False + return NINF From 6c7e84a05aaabd0240b4bc55e1efa15c87fd0d7c Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 20 Jan 2025 10:04:18 -0700 Subject: [PATCH 35/45] bool -> bool_ --- flox/xrdtypes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flox/xrdtypes.py b/flox/xrdtypes.py index 2b9b4525a..05a060733 100644 --- a/flox/xrdtypes.py +++ b/flox/xrdtypes.py @@ -109,7 +109,7 @@ def 
get_pos_infinity(dtype, max_for_int=False): if issubclass(dtype.type, np.complexfloating): return np.inf + 1j * np.inf - if issubclass(dtype.type, np.bool): + if issubclass(dtype.type, np.bool_): return True return INF @@ -145,7 +145,7 @@ def get_neg_infinity(dtype, min_for_int=False): if issubclass(dtype.type, np.complexfloating): return -np.inf - 1j * np.inf - if issubclass(dtype.type, np.bool): + if issubclass(dtype.type, np.bool_): return False return NINF From 43c3408b012ade53023ae0c83774a30358a006d2 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 20 Jan 2025 10:10:03 -0700 Subject: [PATCH 36/45] update int limits --- tests/test_properties.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_properties.py b/tests/test_properties.py index 718a8c330..689cd3b47 100644 --- a/tests/test_properties.py +++ b/tests/test_properties.py @@ -333,7 +333,7 @@ def test_agg_dtype_specified(func, array_dtype, dtype, engine): def test_topk_max_min(data, array): "top 1 == nanmax; top -1 == nanmin" - if array.dtype.kind == "i": + if array.dtype.kind in "iu": # we cast to float and back, so this is the effective limit assume((np.abs(array) < 2**53).all()) elif array.dtype.kind in "Mm": From 01eabfbdf97cca47cfb22ab0eff263507e08a551 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 20 Jan 2025 10:11:13 -0700 Subject: [PATCH 37/45] fix rtd --- readthedocs.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/readthedocs.yml b/readthedocs.yml index 51b6b6b18..b42bd07c4 100644 --- a/readthedocs.yml +++ b/readthedocs.yml @@ -1,5 +1,9 @@ version: 2 +sphinx: + # Path to your Sphinx configuration file. 
+ configuration: docs/source/conf.py + build: os: "ubuntu-lts-latest" tools: From 6e4ce6944b96ab8759f7560b78f1548a10d54c9d Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 20 Jan 2025 10:30:44 -0700 Subject: [PATCH 38/45] Add note --- flox/core.py | 5 +++++ flox/xarray.py | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/flox/core.py b/flox/core.py index 2fedddd6d..063d117cf 100644 --- a/flox/core.py +++ b/flox/core.py @@ -2374,6 +2374,11 @@ def groupby_reduce( finalize_kwargs : dict, optional Kwargs passed to finalize the reduction such as ``ddof`` for var, std or ``q`` for quantile. + Notes + ----- + ``topk`` and ``quantile`` are implemented by converting to a complex number and so are limited to values between +-``2**53-1``, + i.e. the range of integers that a ``float64`` can represent exactly. Offset your data appropriately if you need a larger range. + Returns ------- result diff --git a/flox/xarray.py b/flox/xarray.py index d605a5ca5..d51bf95d6 100644 --- a/flox/xarray.py +++ b/flox/xarray.py @@ -181,6 +181,11 @@ def xarray_reduce( DataArray or Dataset Reduced object + Notes + ----- + ``topk`` and ``quantile`` are implemented by converting to a complex number and so are limited to values between +-``2**53-1``, + i.e. the range of integers that a ``float64`` can represent exactly. Offset your data appropriately if you need a larger range. 
+ See Also -------- flox.core.groupby_reduce From 8f60477b2a98a2da896a5c59c272962960b0f518 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Thu, 23 Jan 2025 21:06:31 -0700 Subject: [PATCH 39/45] Add unit test --- flox/aggregate_flox.py | 2 +- flox/xrutils.py | 6 +++--- tests/test_core.py | 19 ++++++++++++++++--- 3 files changed, 20 insertions(+), 7 deletions(-) diff --git a/flox/aggregate_flox.py b/flox/aggregate_flox.py index c14c4df38..fae8ff5d3 100644 --- a/flox/aggregate_flox.py +++ b/flox/aggregate_flox.py @@ -92,7 +92,7 @@ def quantile_or_topk( # but not any more now that I use partition and avoid replacing NaNs if k is not None: is_scalar_param = False - param = np.arange(abs(k)) + param = np.sort(np.arange(abs(k)) * np.sign(k)) else: is_scalar_param = is_scalar(q) param = np.atleast_1d(q) diff --git a/flox/xrutils.py b/flox/xrutils.py index fc12cec38..d58b1da6a 100644 --- a/flox/xrutils.py +++ b/flox/xrutils.py @@ -8,6 +8,7 @@ import numpy as np import pandas as pd +from numpy.lib.array_utils import normalize_axis_tuple from packaging.version import Version @@ -398,7 +399,7 @@ def nanlast(values, axis, keepdims=False): return result -def topk(a, k, axis, keepdims): +def topk(a: np.ndarray, k: int, axis, keepdims: bool = True) -> np.ndarray: """Chunk and combine function of topk Extract the k largest elements from a on the given axis. @@ -410,8 +411,7 @@ def topk(a, k, axis, keepdims): of their LICENSE. 
""" assert keepdims is True - (axis,) = axis - axis = normalize_axis_index(axis, a.ndim) + (axis,) = normalize_axis_tuple(axis, a.ndim) if abs(k) >= a.shape[axis]: return a diff --git a/tests/test_core.py b/tests/test_core.py index b4c83c989..74bbefc08 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -83,7 +83,7 @@ def npfunc(x, **kwargs): x = np.asarray(x) return (~xrutils.isnull(x)).sum(**kwargs) - elif func in ["nanfirst", "nanlast"]: + elif func in ["nanfirst", "nanlast", "topk"]: npfunc = getattr(xrutils, func) elif func in SCIPY_STATS_FUNCS: @@ -252,6 +252,10 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine): ] fill_value = None tolerance = None + elif func == "topk": + finalize_kwargs = [{"k": 3}, {"k": -3}] + fill_value = None + tolerance = None else: fill_value = None tolerance = None @@ -281,6 +285,8 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine): expected = getattr(np, func_)(array_, axis=-1, **kwargs) else: expected = array_func(array_[..., ~nanmask], axis=-1, **kwargs) + if func == "topk": + expected = np.sort(np.swapaxes(expected, array.ndim - 1, 0), axis=0) for _ in range(nby): expected = np.expand_dims(expected, -1) @@ -288,7 +294,7 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine): assert chunks == -1 actual, *groups = groupby_reduce(array, *by, **flox_kwargs) - if "quantile" in func and isinstance(kwargs["q"], list): + if ("quantile" in func and isinstance(kwargs["q"], list)) or func == "topk": assert actual.ndim == expected.ndim == (array.ndim + nby) else: assert actual.ndim == expected.ndim == (array.ndim + nby - 1) @@ -298,9 +304,12 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine): assert_equal(actual_group, expect) if "arg" in func: assert actual.dtype.kind == "i" + if func == "topk": + actual = np.sort(actual, axis=0) assert_equal(expected, actual, tolerance) - if "nan" not in func and "arg" not in func: + # FIXME: topk vs 
nantopk + if "nan" not in func and "arg" not in func and "topk" not in func: # test non-NaN skipping behaviour when NaNs are present nanned = array_.copy() # remove nans in by to reduce complexity @@ -310,6 +319,10 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine): nanned.reshape(-1)[0] = np.nan actual, *_ = groupby_reduce(nanned, *by_, **flox_kwargs) expected_0 = array_func(nanned, axis=-1, **kwargs) + if func == "topk": + expected_0 = np.sort(np.swapaxes(expected_0, array.ndim - 1, 0), axis=-1) + actual = np.sort(actual, axis=-1) + for _ in range(nby): expected_0 = np.expand_dims(expected_0, -1) assert_equal(expected_0, actual, tolerance) From 15fcfa1b9458245a539143825c16ed4f17504cb9 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Thu, 23 Jan 2025 21:39:10 -0700 Subject: [PATCH 40/45] WIP --- tests/test_core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_core.py b/tests/test_core.py index 74bbefc08..18cb9d835 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -224,7 +224,7 @@ def gen_array_by(size, func): @pytest.mark.parametrize("size", ((1, 12), (12,), (12, 9))) @pytest.mark.parametrize("nby", [1, 2, 3]) @pytest.mark.parametrize("add_nan_by", [True, False]) -@pytest.mark.parametrize("func", ALL_FUNCS) +@pytest.mark.parametrize("func", ["topk"]) def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine): if ("arg" in func and engine in ["flox", "numbagg"]) or (func in BLOCKWISE_FUNCS and chunks != -1): pytest.skip() From a5bcc5be642c0c0c825ccb536208a0b736d569e3 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Thu, 23 Jan 2025 21:51:31 -0700 Subject: [PATCH 41/45] fix --- flox/core.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/flox/core.py b/flox/core.py index 063d117cf..eec663ef7 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1071,8 +1071,16 @@ def chunk_reduce( # optimize that out. 
previous_reduction: T_Func = "" for reduction, fv, kw, dt in zip(funcs, fill_values, kwargss, dtypes): + # TODO: Figure out how to generalize this + if reduction in ("quantile", "nanquantile"): + new_dims_shape = tuple(dim.size for dim in quantile_new_dims_func(**kw) if not dim.is_scalar) + elif reduction == "topk": + new_dims_shape = tuple(dim.size for dim in topk_new_dims_func(**kw) if not dim.is_scalar) + else: + new_dims_shape = tuple() + if empty: - result = np.full(shape=final_array_shape, fill_value=fv) + result = np.full(shape=new_dims_shape + final_array_shape, fill_value=fv) elif is_nanlen(reduction) and is_nanlen(previous_reduction): result = results["intermediates"][-1] else: @@ -1101,13 +1109,6 @@ def chunk_reduce( if hasnan: # remove NaN group label which should be last result = result[..., :-1] - # TODO: Figure out how to generalize this - if reduction in ("quantile", "nanquantile"): - new_dims_shape = tuple(dim.size for dim in quantile_new_dims_func(**kw) if not dim.is_scalar) - elif reduction == "topk": - new_dims_shape = tuple(dim.size for dim in topk_new_dims_func(**kw) if not dim.is_scalar) - else: - new_dims_shape = tuple() result = result.reshape(new_dims_shape + final_array_shape[:-1] + found_groups_shape) results["intermediates"].append(result) previous_reduction = reduction From 91e1d0733671e6deca173e7466dde42482ef2c3f Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 17 Mar 2025 21:08:13 -0600 Subject: [PATCH 42/45] Switch DUMMY_AXIS to 0 --- flox/core.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/flox/core.py b/flox/core.py index a557bfdb5..9279b372e 100644 --- a/flox/core.py +++ b/flox/core.py @@ -113,7 +113,7 @@ # This dummy axis is inserted using np.expand_dims # and then reduced over during the combine stage by # _simple_combine. 
-DUMMY_AXIS = -2 +DUMMY_AXIS = 0 logger = logging.getLogger("flox") @@ -1203,8 +1203,15 @@ def _aggregate( return _finalize_results(results, agg, axis, expected_groups, reindex) -def _expand_dims(results: IntermediateDict) -> IntermediateDict: - results["intermediates"] = tuple(np.expand_dims(array, DUMMY_AXIS) for array in results["intermediates"]) +def _expand_dims(results: IntermediateDict, agg: Aggregation) -> IntermediateDict: + if agg.name == "topk": + results["intermediates"] = tuple(results["intermediates"][:1]) + tuple( + np.expand_dims(array, DUMMY_AXIS) for array in results["intermediates"][1:] + ) + else: + results["intermediates"] = tuple( + np.expand_dims(array, DUMMY_AXIS) for array in results["intermediates"] + ) return results @@ -1254,7 +1261,7 @@ def _simple_combine( results: IntermediateDict = {"groups": unique_groups} results["intermediates"] = [] - axis_ = axis[:-1] + (DUMMY_AXIS,) + axis_ = (DUMMY_AXIS,) + tuple(a + 1 for a in axis[:-1]) for idx, combine in enumerate(agg.simple_combine): array = _conc2(x_chunk, key1="intermediates", key2=idx, axis=axis_) assert array.ndim >= 2 @@ -1262,7 +1269,9 @@ def _simple_combine( warnings.filterwarnings("ignore", r"All-NaN (slice|axis) encountered") assert callable(combine) result = combine(array, axis=axis_, keepdims=True) - if is_aggregate: + # FIXME: The `idx > 0` clause assumes that DUMMY_AXIS = 0 + # and is inserted by the first elem of simple_combine. + if is_aggregate and (agg.new_dims_func is None or idx > 0): # squeeze out DUMMY_AXIS if this is the last step i.e. 
called from _aggregate result = result.squeeze(axis=DUMMY_AXIS) results["intermediates"].append(result) @@ -1677,7 +1686,7 @@ def dask_groupby_agg( ) if do_simple_combine: # Add a dummy dimension that then gets reduced over - blockwise_method = tlz.compose(_expand_dims, blockwise_method) + blockwise_method = tlz.compose(partial(_expand_dims, agg=agg), blockwise_method) # apply reduction on chunk intermediate = dask.array.blockwise( From 2d868fe6fe8f55e67b288986d83636164eb54601 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 17 Mar 2025 21:08:37 -0600 Subject: [PATCH 43/45] More support for edge cases --- flox/aggregate_flox.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/flox/aggregate_flox.py b/flox/aggregate_flox.py index fae8ff5d3..9987704d6 100644 --- a/flox/aggregate_flox.py +++ b/flox/aggregate_flox.py @@ -126,8 +126,12 @@ def quantile_or_topk( virtual_index[..., 1:] += offset[..., :-1] kth = np.unique(virtual_index) kth = kth[kth >= 0] + kth[kth >= array.shape[axis]] = array.shape[axis] - 1 k_offset = param.reshape((abs(k),) + (1,) * virtual_index.ndim) lo_ = k_offset + virtual_index[np.newaxis, ...] 
+ not_enough_elems = actual_sizes < np.abs(k) + lo_[..., not_enough_elems] = 0 + badmask = np.broadcast_to(not_enough_elems, idxshape) | nanmask # partition the complex array in-place labels_broadcast = np.broadcast_to(group_idx, array.shape) @@ -157,8 +161,6 @@ def quantile_or_topk( result[..., nanmask] = fill_value else: result = loval - # The first clause is True if numel in group < abs(k) - badmask = np.broadcast_to(lo_ < 0, idxshape) | nanmask if badmask.any(): result[badmask] = fill_value From d244d60e1ce94e0eed48df6bf61d8c40f8f9fa3e Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 17 Mar 2025 21:13:12 -0600 Subject: [PATCH 44/45] minor --- flox/aggregations.py | 1 + flox/core.py | 2 +- tests/test_core.py | 4 ++++ 3 files changed, 6 insertions(+), 1 deletion(-) diff --git a/flox/aggregations.py b/flox/aggregations.py index 310725361..d8e2ff094 100644 --- a/flox/aggregations.py +++ b/flox/aggregations.py @@ -580,6 +580,7 @@ def topk_new_dims_func(k) -> tuple[Dim]: name="topk", fill_value=(dtypes.NINF, 0), final_fill_value=dtypes.NA, + # FIXME: set numpy chunk=("topk", "nanlen"), combine=(xrutils.topk, "sum"), new_dims_func=topk_new_dims_func, diff --git a/flox/core.py b/flox/core.py index 9279b372e..acb81fd49 100644 --- a/flox/core.py +++ b/flox/core.py @@ -2257,7 +2257,7 @@ def _choose_method( return method -def _choose_engine(by, agg: Aggregation): +def _choose_engine(by, agg: Aggregation) -> T_Engine: dtype = agg.dtype["user"] not_arg_reduce = not _is_arg_reduction(agg) diff --git a/tests/test_core.py b/tests/test_core.py index 59ad261a8..02bda8ebf 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -266,6 +266,8 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine): for kwargs in finalize_kwargs: if "quantile" in func and isinstance(kwargs["q"], list) and engine != "flox": continue + if "topk" in func and engine != "flox": + continue flox_kwargs = dict(func=func, engine=engine, finalize_kwargs=kwargs, 
fill_value=fill_value) with np.errstate(invalid="ignore", divide="ignore"): with warnings.catch_warnings(): @@ -286,6 +288,8 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine): else: expected = array_func(array_[..., ~nanmask], axis=-1, **kwargs) if func == "topk": + if nanmask.all(): + expected = np.full(expected.shape[:-1] + (abs(kwargs["k"]),), np.nan) expected = np.sort(np.swapaxes(expected, array.ndim - 1, 0), axis=0) for _ in range(nby): expected = np.expand_dims(expected, -1) From 8319f7f5ab82f1055295131684f1cc5d4c5e72e7 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 17 Mar 2025 21:28:17 -0600 Subject: [PATCH 45/45] [WIP] failing test --- tests/test_core.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_core.py b/tests/test_core.py index 02bda8ebf..0552af142 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -216,14 +216,14 @@ def gen_array_by(size, func): "chunks", [ None, - pytest.param(-1, marks=requires_dask), - pytest.param(3, marks=requires_dask), - pytest.param(4, marks=requires_dask), + # pytest.param(-1, marks=requires_dask), + # pytest.param(3, marks=requires_dask), + # pytest.param(4, marks=requires_dask), ], ) -@pytest.mark.parametrize("size", ((1, 12), (12,), (12, 9))) -@pytest.mark.parametrize("nby", [1, 2, 3]) -@pytest.mark.parametrize("add_nan_by", [True, False]) +@pytest.mark.parametrize("size", ((12, 6),)) +@pytest.mark.parametrize("nby", [2]) +@pytest.mark.parametrize("add_nan_by", [True]) @pytest.mark.parametrize("func", ["topk"]) def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine): if ("arg" in func and engine in ["flox", "numbagg"]) or (func in BLOCKWISE_FUNCS and chunks != -1):