Skip to content

Commit 41abaac

Browse files
committed
more test
1 parent 7f2981f commit 41abaac

File tree

2 files changed

+33
-3
lines changed

2 files changed

+33
-3
lines changed

flox/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -762,7 +762,7 @@ def reindex_pydata_sparse_coo(array, from_: pd.Index, to: pd.Index, fill_value,
762762

763763
assert axis == -1
764764

765-
needs_reindex = (from_.difference(to)).size > 0
765+
needs_reindex = (from_.get_indexer(to) == -1).any()
766766
if needs_reindex and fill_value is None:
767767
raise ValueError("Filling is required. fill_value cannot be None.")
768768

tests/test_core.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
assert_equal_tuple,
4545
has_cubed,
4646
has_dask,
47+
has_sparse,
4748
raise_if_dask_computes,
4849
requires_cubed,
4950
requires_dask,
@@ -459,7 +460,18 @@ def test_numpy_reduce_nd_md():
459460

460461

461462
@requires_dask
462-
@pytest.mark.parametrize("reindex", [None, False, True])
463+
@pytest.mark.parametrize(
464+
"reindex",
465+
[
466+
None,
467+
False,
468+
True,
469+
pytest.param(
470+
ReindexStrategy(blockwise=False, array_type=ReindexArrayType.SPARSE_COO),
471+
marks=pytest.mark.skipif(not has_sparse, reason="no sparse"),
472+
),
473+
],
474+
)
463475
@pytest.mark.parametrize("func", ALL_FUNCS)
464476
@pytest.mark.parametrize("add_nan", [False, True])
465477
@pytest.mark.parametrize("dtype", (float,))
@@ -482,6 +494,9 @@ def test_groupby_agg_dask(func, shape, array_chunks, group_chunks, add_nan, dtyp
482494
if "arg" in func and (engine in ["flox", "numbagg"] or reindex):
483495
pytest.skip()
484496

497+
if isinstance(reindex, ReindexStrategy) and not _is_sparse_supported_reduction(func):
498+
pytest.skip()
499+
485500
rng = np.random.default_rng(12345)
486501
array = dask.array.from_array(rng.random(shape), chunks=array_chunks).astype(dtype)
487502
array = dask.array.ones(shape, chunks=array_chunks)
@@ -787,6 +802,11 @@ def test_groupby_reduce_axis_subset_against_numpy(func, axis, engine):
787802
(None, None),
788803
pytest.param(False, (2, 2, 3), marks=requires_dask),
789804
pytest.param(True, (2, 2, 3), marks=requires_dask),
805+
pytest.param(
806+
ReindexStrategy(blockwise=False, array_type=ReindexArrayType.SPARSE_COO),
807+
(2, 2, 3),
808+
marks=(requires_dask, pytest.mark.skipif(not has_sparse, reason="no sparse")),
809+
),
790810
],
791811
)
792812
@pytest.mark.parametrize(
@@ -833,7 +853,17 @@ def _maybe_chunk(arr):
833853
@requires_dask
834854
@pytest.mark.parametrize(
835855
"expected_groups, reindex",
836-
[(None, None), (None, False), ([0, 1, 2], True), ([0, 1, 2], False)],
856+
[
857+
(None, None),
858+
(None, False),
859+
([0, 1, 2], True),
860+
([0, 1, 2], False),
861+
pytest.param(
862+
[0, 1, 2],
863+
ReindexStrategy(blockwise=False, array_type=ReindexArrayType.SPARSE_COO),
864+
marks=pytest.mark.skipif(not has_sparse, reason="no sparse"),
865+
),
866+
],
837867
)
838868
def test_groupby_all_nan_blocks_dask(expected_groups, reindex, engine):
839869
labels = np.array([0, 0, 2, 2, 2, 1, 1, 2, 2, 1, 1, 0])

0 commit comments

Comments
 (0)