Skip to content

Added warnings for NaN propagation #469

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 25, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions sparse/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
from functools import wraps
from itertools import chain
from collections.abc import Iterable
from scipy.sparse import spmatrix
from numba import literal_unroll
import warnings

from ._sparse_array import SparseArray
from ._utils import check_compressed_axes, normalize_axis, check_zero_fill_value
Expand Down Expand Up @@ -33,6 +36,50 @@
)


@numba.njit
def nan_check(*args):
"""
Check for the NaN values in Numpy Arrays

Parameters
----------
Union[Numpy Array, Integer, Float]

Returns
-------
Boolean Whether Numpy Array Contains NaN

"""
for i in literal_unroll(args):
if np.isnan(np.min(np.asarray(i))):
return True
return False


def check_class_nan(test):
"""
Check NaN for Sparse Arrays

Parameters
----------
test : Union[sparse.COO, sparse.GCXS, scipy.sparse.spmatrix, Numpy Ndarrays]

Returns
-------
Boolean Whether Sparse Array Contains NaN

"""
from ._compressed import GCXS
from ._coo import COO

if isinstance(test, (GCXS, COO)):
return nan_check(test.fill_value, test.data)
elif isinstance(test, spmatrix):
return nan_check(test.data)
else:
return nan_check(test)


def tensordot(a, b, axes=2, *, return_type=None):
"""
Perform the equivalent of :obj:`numpy.tensordot`.
Expand Down Expand Up @@ -174,6 +221,11 @@ def matmul(a, b):
"Cannot perform dot product on types %s, %s" % (type(a), type(b))
)

if check_class_nan(a) or check_class_nan(b):
warnings.warn(
"Nan will not be propagated in matrix multiplication", RuntimeWarning
)

# When b is 2-d, it is equivalent to dot
if b.ndim <= 2:
return dot(a, b)
Expand Down
38 changes: 38 additions & 0 deletions sparse/tests/test_dot.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,44 @@ def test_matmul_errors():
sparse.matmul(sa, sb)


@pytest.mark.parametrize(
"a, b",
[
(
sparse.GCXS.from_numpy(
np.random.choice(
[0, np.nan, 2], size=[100, 100], p=[0.99, 0.001, 0.009]
)
),
sparse.random((100, 100), density=0.01),
),
(
sparse.COO.from_numpy(
np.random.choice(
[0, np.nan, 2], size=[100, 100], p=[0.99, 0.001, 0.009]
)
),
sparse.random((100, 100), density=0.01),
),
(
sparse.GCXS.from_numpy(
np.random.choice(
[0, np.nan, 2], size=[100, 100], p=[0.99, 0.001, 0.009]
)
),
scipy.sparse.random(100, 100),
),
(
np.random.choice([0, np.nan, 2], size=[100, 100], p=[0.99, 0.001, 0.009]),
sparse.random((100, 100), density=0.01),
),
],
)
def test_matmul_nan_warnings(a, b):
with pytest.warns(RuntimeWarning):
a @ b


@pytest.mark.parametrize(
"a_shape, b_shape",
[
Expand Down