Skip to content

Commit a69a43b

Browse files
committed
Fix sparse reindexing
1 parent b6f62ca commit a69a43b

File tree

3 files changed

+56
-14
lines changed

3 files changed

+56
-14
lines changed

flox/core.py

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@
6868
from numpy.core.numeric import normalize_axis_tuple # type: ignore[no-redef]
6969

7070
HAS_NUMBAGG = module_available("numbagg", minversion="0.3.0")
71+
HAS_SPARSE = module_available("sparse")
7172

7273
if TYPE_CHECKING:
7374
try:
@@ -255,6 +256,12 @@ def _is_bool_supported_reduction(func: T_Agg) -> bool:
255256
)
256257

257258

259+
def _is_sparse_supported_reduction(func: T_Agg) -> bool:
    """Report whether ``func`` can be used when reindexing to a sparse array.

    Parameters
    ----------
    func
        Either an ``Aggregation`` instance (its ``name`` attribute is
        inspected) or a value supporting ``in`` containment tests,
        e.g. the reduction name as a string.

    Returns
    -------
    bool
        ``True`` when ``HAS_SPARSE`` is falsy, i.e. the optional ``sparse``
        package was not detected. Otherwise ``True`` only if none of the
        tokens ``"first"``, ``"last"``, ``"prod"``, ``"var"`` or ``"std"``
        occurs in the reduction name.
    """
    name = func.name if isinstance(func, Aggregation) else func
    if not HAS_SPARSE:
        return True
    # Containment test: for a string name this is a substring match, so
    # prefixed variants such as "nanfirst" or "nanvar" are rejected too.
    unsupported_tokens = ("first", "last", "prod", "var", "std")
    return not any(token in name for token in unsupported_tokens)
263+
264+
258265
def _get_expected_groups(by: T_By, sort: bool) -> T_ExpectIndex:
259266
if is_duck_dask_array(by):
260267
raise ValueError("Please provide expected_groups if not grouping by a numpy array.")
@@ -741,7 +748,7 @@ def reindex_numpy(array, from_, to, fill_value, dtype, axis):
741748
indexer = [slice(None, None)] * array.ndim
742749
indexer[axis] = idx
743750
reindexed = array[tuple(indexer)]
744-
if any(idx == -1):
751+
if (idx == -1).any():
745752
if fill_value is None:
746753
raise ValueError("Filling is required. fill_value cannot be None.")
747754
indexer[axis] = idx == -1
@@ -755,20 +762,36 @@ def reindex_pydata_sparse_coo(array, from_, to, fill_value, dtype, axis):
755762

756763
assert axis == -1
757764

758-
if fill_value is None:
759-
raise ValueError("Filling is required for sparse arrays. fill_value cannot be None.")
765+
needs_reindex = (from_.get_indexer(to) == -1).any()
766+
if needs_reindex and fill_value is None:
767+
raise ValueError("Filling is required. fill_value cannot be None.")
768+
760769
idx = to.get_indexer(from_)
761-
mask = idx != -1
770+
mask = idx != -1 # indices along last axis to keep
762771
shape = array.shape
763-
ranges = np.broadcast_arrays(*np.ix_(*(tuple(np.arange(size) for size in shape[:axis]) + (idx[mask],))))
764-
coords = np.stack(ranges, axis=0).reshape(array.ndim, -1)
765772

766-
data = array[..., mask].data if isinstance(array, sparse.COO) else array[..., mask].reshape(-1)
773+
if isinstance(array, sparse.COO):
774+
subset = array[..., mask]
775+
data = subset.data
776+
coords = subset.coords
777+
if subset.nnz > 0:
778+
coords[-1, :] = coords[-1, :][idx[mask]]
779+
if fill_value is None:
780+
# no reindexing is actually needed (dense case)
781+
# preserve the fill_value
782+
fill_value = array.fill_value
783+
else:
784+
ranges = np.broadcast_arrays(
785+
*np.ix_(*(tuple(np.arange(size) for size in shape[:axis]) + (idx[mask],)))
786+
)
787+
coords = np.stack(ranges, axis=0).reshape(array.ndim, -1)
788+
data = array[..., mask].reshape(-1)
767789

768790
reindexed = sparse.COO(
769791
coords=coords,
770792
data=data.astype(dtype, copy=False),
771793
shape=(*array.shape[:axis], to.size),
794+
fill_value=fill_value,
772795
)
773796

774797
return reindexed
@@ -795,7 +818,11 @@ def reindex_(
795818

796819
if array.shape[axis] == 0:
797820
# all groups were NaN
798-
reindexed = np.full(array.shape[:-1] + (len(to),), fill_value, dtype=array.dtype)
821+
shape = array.shape[:-1] + (len(to),)
822+
if array_type in (ReindexArrayType.AUTO, ReindexArrayType.NUMPY):
823+
reindexed = np.full(shape, fill_value, dtype=array.dtype)
824+
else:
825+
raise NotImplementedError
799826
return reindexed
800827

801828
from_ = pd.Index(from_)
@@ -1288,7 +1315,7 @@ def _finalize_results(
12881315
fill_value = agg.fill_value["user"]
12891316
if min_count > 0:
12901317
count_mask = counts < min_count
1291-
if count_mask.any():
1318+
if count_mask.any() or reindex.array_type is ReindexArrayType.SPARSE_COO:
12921319
# For one count_mask.any() prevents promoting bool to dtype(fill_value) unless
12931320
# necessary
12941321
if fill_value is None:
@@ -2815,6 +2842,12 @@ def groupby_reduce(
28152842
array.dtype,
28162843
)
28172844

2845+
if reindex.array_type is ReindexArrayType.SPARSE_COO and not _is_sparse_supported_reduction(func):
2846+
raise NotImplementedError(
2847+
f"Aggregation {func=!r} is not supported when reindexing to a sparse array. "
2848+
"Please raise an issue"
2849+
)
2850+
28182851
if TYPE_CHECKING:
28192852
assert isinstance(reindex, ReindexStrategy)
28202853
assert method is not None

flox/xrutils.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,9 @@ def notnull(data):
159159
return out
160160

161161

162-
def isnull(data):
162+
def isnull(data: Any) -> bool:
163+
if data is None:
164+
return False
163165
if not is_duck_array(data):
164166
data = np.asarray(data)
165167
scalar_type = data.dtype.type
@@ -374,9 +376,9 @@ def _select_along_axis(values, idx, axis):
374376
def nanfirst(values, axis, keepdims=False):
375377
if isinstance(axis, tuple):
376378
(axis,) = axis
377-
values = np.asarray(values)
379+
# values = np.asarray(values)
378380
axis = normalize_axis_index(axis, values.ndim)
379-
idx_first = np.argmax(~pd.isnull(values), axis=axis)
381+
idx_first = np.argmax(~isnull(values), axis=axis)
380382
result = _select_along_axis(values, idx_first, axis)
381383
if keepdims:
382384
return np.expand_dims(result, axis=axis)
@@ -387,10 +389,10 @@ def nanfirst(values, axis, keepdims=False):
387389
def nanlast(values, axis, keepdims=False):
388390
if isinstance(axis, tuple):
389391
(axis,) = axis
390-
values = np.asarray(values)
392+
# values = np.asarray(values)
391393
axis = normalize_axis_index(axis, values.ndim)
392394
rev = (slice(None),) * axis + (slice(None, None, -1),)
393-
idx_last = -1 - np.argmax(~pd.isnull(values)[rev], axis=axis)
395+
idx_last = -1 - np.argmax(~isnull(values)[rev], axis=axis)
394396
result = _select_along_axis(values, idx_last, axis)
395397
if keepdims:
396398
return np.expand_dims(result, axis=axis)

tests/test_core.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
_choose_engine,
2525
_convert_expected_groups_to_index,
2626
_get_optimal_chunks_for_groups,
27+
_is_sparse_supported_reduction,
2728
_normalize_indexes,
2829
_validate_reindex,
2930
factorize_,
@@ -327,6 +328,8 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
327328

328329
combine_error = RuntimeError("This combine should not have been called.")
329330
for method, reindex in params:
331+
if isinstance(reindex, ReindexStrategy) and not _is_sparse_supported_reduction(func):
332+
continue
330333
call = partial(
331334
groupby_reduce,
332335
array,
@@ -360,6 +363,10 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
360363
assert_equal(actual_group, expect, tolerance)
361364
if "arg" in func:
362365
assert actual.dtype.kind == "i"
366+
if isinstance(reindex, ReindexStrategy):
367+
import sparse
368+
369+
expected = sparse.COO.from_numpy(expected)
363370
assert_equal(actual, expected, tolerance)
364371

365372

0 commit comments

Comments
 (0)