Skip to content

Commit 185d110

Browse files
committed
Fix sparse reindexing
1 parent b6f62ca commit 185d110

File tree

3 files changed

+73
-20
lines changed

3 files changed

+73
-20
lines changed

flox/core.py

Lines changed: 47 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@
6868
from numpy.core.numeric import normalize_axis_tuple # type: ignore[no-redef]
6969

7070
HAS_NUMBAGG = module_available("numbagg", minversion="0.3.0")
71+
HAS_SPARSE = module_available("sparse")
7172

7273
if TYPE_CHECKING:
7374
try:
@@ -255,6 +256,12 @@ def _is_bool_supported_reduction(func: T_Agg) -> bool:
255256
)
256257

257258

259+
def _is_sparse_supported_reduction(func: T_Agg) -> bool:
260+
if isinstance(func, Aggregation):
261+
func = func.name
262+
return not HAS_SPARSE or all(f not in func for f in ["first", "last", "prod", "var", "std"])
263+
264+
258265
def _get_expected_groups(by: T_By, sort: bool) -> T_ExpectIndex:
259266
if is_duck_dask_array(by):
260267
raise ValueError("Please provide expected_groups if not grouping by a numpy array.")
@@ -736,12 +743,12 @@ def rechunk_for_blockwise(array: DaskArray, axis: T_Axis, labels: np.ndarray) ->
736743
return array.rechunk({axis: newchunks})
737744

738745

739-
def reindex_numpy(array, from_, to, fill_value, dtype, axis):
746+
def reindex_numpy(array, from_: pd.Index, to: pd.Index, fill_value, dtype, axis: int):
740747
idx = from_.get_indexer(to)
741748
indexer = [slice(None, None)] * array.ndim
742749
indexer[axis] = idx
743750
reindexed = array[tuple(indexer)]
744-
if any(idx == -1):
751+
if (idx == -1).any():
745752
if fill_value is None:
746753
raise ValueError("Filling is required. fill_value cannot be None.")
747754
indexer[axis] = idx == -1
@@ -750,25 +757,43 @@ def reindex_numpy(array, from_, to, fill_value, dtype, axis):
750757
return reindexed
751758

752759

753-
def reindex_pydata_sparse_coo(array, from_, to, fill_value, dtype, axis):
760+
def reindex_pydata_sparse_coo(array, from_: pd.Index, to: pd.Index, fill_value, dtype, axis: int):
754761
import sparse
755762

756763
assert axis == -1
757764

758-
if fill_value is None:
759-
raise ValueError("Filling is required for sparse arrays. fill_value cannot be None.")
765+
needs_reindex = (from_.difference(to)).size > 0
766+
if needs_reindex and fill_value is None:
767+
raise ValueError("Filling is required. fill_value cannot be None.")
768+
760769
idx = to.get_indexer(from_)
761-
mask = idx != -1
770+
mask = idx != -1 # indices along last axis to keep
771+
if mask.all():
772+
mask = slice(None)
762773
shape = array.shape
763-
ranges = np.broadcast_arrays(*np.ix_(*(tuple(np.arange(size) for size in shape[:axis]) + (idx[mask],))))
764-
coords = np.stack(ranges, axis=0).reshape(array.ndim, -1)
765774

766-
data = array[..., mask].data if isinstance(array, sparse.COO) else array[..., mask].reshape(-1)
775+
if isinstance(array, sparse.COO):
776+
subset = array[..., mask]
777+
data = subset.data
778+
coords = subset.coords
779+
if subset.nnz > 0:
780+
coords[-1, :] = idx[mask][coords[-1, :]]
781+
if fill_value is None:
782+
# no reindexing is actually needed (dense case)
783+
# preserve the fill_value
784+
fill_value = array.fill_value
785+
else:
786+
ranges = np.broadcast_arrays(
787+
*np.ix_(*(tuple(np.arange(size) for size in shape[:axis]) + (idx[mask],)))
788+
)
789+
coords = np.stack(ranges, axis=0).reshape(array.ndim, -1)
790+
data = array[..., mask].reshape(-1)
767791

768792
reindexed = sparse.COO(
769793
coords=coords,
770794
data=data.astype(dtype, copy=False),
771795
shape=(*array.shape[:axis], to.size),
796+
fill_value=fill_value,
772797
)
773798

774799
return reindexed
@@ -795,7 +820,11 @@ def reindex_(
795820

796821
if array.shape[axis] == 0:
797822
# all groups were NaN
798-
reindexed = np.full(array.shape[:-1] + (len(to),), fill_value, dtype=array.dtype)
823+
shape = array.shape[:-1] + (len(to),)
824+
if array_type in (ReindexArrayType.AUTO, ReindexArrayType.NUMPY):
825+
reindexed = np.full(shape, fill_value, dtype=array.dtype)
826+
else:
827+
raise NotImplementedError
799828
return reindexed
800829

801830
from_ = pd.Index(from_)
@@ -1044,7 +1073,7 @@ def chunk_argreduce(
10441073
sort=sort,
10451074
user_dtype=user_dtype,
10461075
)
1047-
if not isnull(results["groups"]).all():
1076+
if not all(isnull(results["groups"])):
10481077
idx = np.broadcast_to(idx, array.shape)
10491078

10501079
# array, by get flattened to 1D before passing to npg
@@ -1288,7 +1317,7 @@ def _finalize_results(
12881317
fill_value = agg.fill_value["user"]
12891318
if min_count > 0:
12901319
count_mask = counts < min_count
1291-
if count_mask.any():
1320+
if count_mask.any() or reindex.array_type is ReindexArrayType.SPARSE_COO:
12921321
# For one count_mask.any() prevents promoting bool to dtype(fill_value) unless
12931322
# necessary
12941323
if fill_value is None:
@@ -2815,6 +2844,12 @@ def groupby_reduce(
28152844
array.dtype,
28162845
)
28172846

2847+
if reindex.array_type is ReindexArrayType.SPARSE_COO and not _is_sparse_supported_reduction(func):
2848+
raise NotImplementedError(
2849+
f"Aggregation {func=!r} is not supported when reindexing to a sparse array. "
2850+
"Please raise an issue"
2851+
)
2852+
28182853
if TYPE_CHECKING:
28192854
assert isinstance(reindex, ReindexStrategy)
28202855
assert method is not None

flox/xrutils.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,9 @@ def notnull(data):
159159
return out
160160

161161

162-
def isnull(data):
162+
def isnull(data: Any):
163+
if data is None:
164+
return False
163165
if not is_duck_array(data):
164166
data = np.asarray(data)
165167
scalar_type = data.dtype.type
@@ -177,7 +179,7 @@ def isnull(data):
177179
else:
178180
# at this point, array should have dtype=object
179181
if isinstance(data, (np.ndarray, dask_array_type)): # noqa
180-
return pd.isnull(data)
182+
return pd.isnull(data) # type: ignore[arg-type]
181183
else:
182184
# Not reachable yet, but intended for use with other duck array
183185
# types. For full consistency with pandas, we should accept None as
@@ -374,9 +376,10 @@ def _select_along_axis(values, idx, axis):
374376
def nanfirst(values, axis, keepdims=False):
375377
if isinstance(axis, tuple):
376378
(axis,) = axis
377-
values = np.asarray(values)
379+
if not is_duck_array(values):
380+
values = np.asarray(values)
378381
axis = normalize_axis_index(axis, values.ndim)
379-
idx_first = np.argmax(~pd.isnull(values), axis=axis)
382+
idx_first = np.argmax(~isnull(values), axis=axis)
380383
result = _select_along_axis(values, idx_first, axis)
381384
if keepdims:
382385
return np.expand_dims(result, axis=axis)
@@ -387,10 +390,11 @@ def nanfirst(values, axis, keepdims=False):
387390
def nanlast(values, axis, keepdims=False):
388391
if isinstance(axis, tuple):
389392
(axis,) = axis
390-
values = np.asarray(values)
393+
if not is_duck_array(values):
394+
values = np.asarray(values)
391395
axis = normalize_axis_index(axis, values.ndim)
392396
rev = (slice(None),) * axis + (slice(None, None, -1),)
393-
idx_last = -1 - np.argmax(~pd.isnull(values)[rev], axis=axis)
397+
idx_last = -1 - np.argmax(~isnull(values)[rev], axis=axis)
394398
result = _select_along_axis(values, idx_last, axis)
395399
if keepdims:
396400
return np.expand_dims(result, axis=axis)

tests/test_core.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
_choose_engine,
2525
_convert_expected_groups_to_index,
2626
_get_optimal_chunks_for_groups,
27+
_is_sparse_supported_reduction,
2728
_normalize_indexes,
2829
_validate_reindex,
2930
factorize_,
@@ -320,13 +321,20 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
320321
if not has_dask or chunks is None or func in BLOCKWISE_FUNCS:
321322
continue
322323

323-
params = list(itertools.product(["map-reduce"], [True, False, None]))
324+
params = list(
325+
itertools.product(
326+
["map-reduce"],
327+
[True, False, None, ReindexStrategy(blockwise=False, array_type=ReindexArrayType.SPARSE_COO)],
328+
)
329+
)
324330
params.extend(itertools.product(["cohorts"], [False, None]))
325331
if chunks == -1:
326332
params.extend([("blockwise", None)])
327333

328334
combine_error = RuntimeError("This combine should not have been called.")
329335
for method, reindex in params:
336+
if isinstance(reindex, ReindexStrategy) and not _is_sparse_supported_reduction(func):
337+
continue
330338
call = partial(
331339
groupby_reduce,
332340
array,
@@ -360,6 +368,10 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
360368
assert_equal(actual_group, expect, tolerance)
361369
if "arg" in func:
362370
assert actual.dtype.kind == "i"
371+
if isinstance(reindex, ReindexStrategy):
372+
import sparse
373+
374+
expected = sparse.COO.from_numpy(expected)
363375
assert_equal(actual, expected, tolerance)
364376

365377

@@ -2085,7 +2097,9 @@ def mocked_reindex(*args, **kwargs):
20852097

20862098
with patch("flox.core.reindex_") as mocked_func:
20872099
mocked_func.side_effect = mocked_reindex
2088-
actual, *_ = groupby_reduce(array, by, func=func, reindex=reindex, expected_groups=expected_groups)
2100+
actual, *_ = groupby_reduce(
2101+
array, by, func=func, reindex=reindex, expected_groups=expected_groups, fill_value=0
2102+
)
20892103
assert_equal(actual, expected)
20902104
# once during graph construction, 10 times afterward
20912105
assert mocked_func.call_count > 1

0 commit comments

Comments
 (0)