Skip to content

MAINT: bump to sparse >=0.17 #318

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 4, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 15 additions & 13 deletions src/array_api_extra/_lib/_utils/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,26 +322,28 @@ def capabilities(
dict
Capabilities of the namespace.
"""
if is_pydata_sparse_namespace(xp):
# No __array_namespace_info__(); no indexing by sparse arrays
return {
"boolean indexing": False,
"data-dependent shapes": True,
"max dimensions": None,
}
out = xp.__array_namespace_info__().capabilities()
if is_jax_namespace(xp) and out["boolean indexing"]:
# FIXME https://github.com/jax-ml/jax/issues/27418
# Fixed in jax >=0.6.0
out = out.copy()
out["boolean indexing"] = False
if is_torch_namespace(xp):
if is_pydata_sparse_namespace(xp):
if out["boolean indexing"]:
# FIXME https://github.com/pydata/sparse/issues/876
# boolean indexing is supported, but not when the index is a sparse array.
# boolean indexing by list or numpy array is not part of the Array API.
out = out.copy()
out["boolean indexing"] = False
elif is_jax_namespace(xp):
if out["boolean indexing"]: # pragma: no cover
# Backwards compatibility with jax <0.6.0
# https://github.com/jax-ml/jax/issues/27418
out = out.copy()
out["boolean indexing"] = False
elif is_torch_namespace(xp):
# FIXME https://github.com/data-apis/array-api/issues/945
device = xp.get_default_device() if device is None else xp.device(device)
if device.type == "meta": # type: ignore[union-attr] # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess]
out = out.copy()
out["boolean indexing"] = False
out["data-dependent shapes"] = False

return out


Expand Down
21 changes: 11 additions & 10 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,7 @@ def test_complex(self, xp: ModuleType):
expect = xp.asarray([[1.0, -1.0j], [1.0j, 1.0]], dtype=xp.complex128)
xp_assert_close(actual, expect)

@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="matmul with nan fillvalue")
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="matmul with nan fillvalue")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this tracked by an issue?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it is now pydata/sparse#877

def test_empty(self, xp: ModuleType):
with warnings.catch_warnings(record=True):
warnings.simplefilter("always", RuntimeWarning)
Expand Down Expand Up @@ -451,7 +451,7 @@ def test_xp(self, xp: ModuleType):
)


@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="backend doesn't have arange")
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no arange", strict=False)
class TestOneHot:
@pytest.mark.parametrize("n_dim", range(4))
@pytest.mark.parametrize("num_classes", [1, 3, 10])
Expand Down Expand Up @@ -816,7 +816,7 @@ def test_bool_dtype(self, xp: ModuleType):
isclose(xp.asarray(True), b, atol=1), xp.asarray([True, True, True])
)

@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="index by sparse array")
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="index by sparse array")
@pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="unknown shape")
def test_none_shape(self, xp: ModuleType):
a = xp.asarray([1, 5, 0])
Expand All @@ -825,7 +825,7 @@ def test_none_shape(self, xp: ModuleType):
a = a[a < 5]
xp_assert_equal(isclose(a, b), xp.asarray([True, False]))

@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="index by sparse array")
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="index by sparse array")
@pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="unknown shape")
def test_none_shape_bool(self, xp: ModuleType):
a = xp.asarray([True, True, False])
Expand Down Expand Up @@ -1141,10 +1141,10 @@ def test_xp(self, xp: ModuleType):


class TestSinc:
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no linspace")
def test_simple(self, xp: ModuleType):
xp_assert_equal(sinc(xp.asarray(0.0)), xp.asarray(1.0))
w = sinc(xp.linspace(-1, 1, 100))
x = xp.asarray(np.linspace(-1, 1, 100))
w = sinc(x)
# check symmetry
xp_assert_close(w, xp.flip(w, axis=0))

Expand All @@ -1153,11 +1153,12 @@ def test_dtype(self, xp: ModuleType, x: int | complex):
with pytest.raises(ValueError, match="real floating data type"):
_ = sinc(xp.asarray(x))

@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no arange")
def test_3d(self, xp: ModuleType):
x = xp.reshape(xp.arange(18, dtype=xp.float64), (3, 3, 2))
expected = xp.zeros((3, 3, 2), dtype=xp.float64)
expected = at(expected)[0, 0, 0].set(1.0)
x = np.arange(18, dtype=np.float64).reshape((3, 3, 2))
expected = np.zeros_like(x)
expected[0, 0, 0] = 1
x = xp.asarray(x)
expected = xp.asarray(expected)
xp_assert_close(sinc(x), expected, atol=1e-15)

def test_device(self, xp: ModuleType, device: Device):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def override(func):
lazy_xp_function(in1d, jax_jit=False)


@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no unique_inverse")
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no unique_inverse")
@pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="no unique_inverse")
class TestIn1D:
# cover both code paths
Expand Down
2 changes: 1 addition & 1 deletion tests/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def test_assert_less(self, xp: ModuleType):
xp_assert_less(xp.asarray([1, 1]), xp.asarray([2, 1]))

@pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close, xp_assert_less])
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="index by sparse array")
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="index by sparse array")
@pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="boolean indexing")
def test_none_shape(self, xp: ModuleType, func: Callable[..., None]):
"""On Dask and other lazy backends, test that a shape with NaN's or None's
Expand Down