Skip to content

Commit 84e4e6c

Browse files
committed
Add back test disabled for debugging.
1 parent 9ec04e6 commit 84e4e6c

File tree

3 files changed

+75
-55
lines changed

3 files changed

+75
-55
lines changed

sparse/coo.py

Lines changed: 37 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from numpy.lib.mixins import NDArrayOperatorsMixin
1212

1313
from .slicing import normalize_index
14-
from .utils import _zero_of_dtype, isscalar
14+
from .utils import _zero_of_dtype, isscalar, PositinalArgumentPartial
1515
from .sparse_array import SparseArray
1616
from .compatibility import int, zip_longest, range, zip
1717

@@ -1895,6 +1895,24 @@ def tril(x, k=0):
18951895

18961896

18971897
def _nary_match(*arrays):
1898+
"""
1899+
Matches coordinates from N different 1-D arrays. Equivalent to
1900+
an SQL outer join.
1901+
1902+
Parameters
1903+
----------
1904+
arrays : tuple[numpy.ndarray]
1905+
Input arrays to match, must be sorted.
1906+
1907+
Returns
1908+
-------
1909+
matched : numpy.ndarray
1910+
The overall matched array.
1911+
1912+
matched_idx : list[numpy.ndarray]
1913+
The indices for the matched coordinates in each array.
1914+
"""
1915+
18981916
matched = arrays[0]
18991917
matched_idx = [np.arange(arrays[0].shape[0],
19001918
dtype=np.min_scalar_type(arrays[0].shape[0] - 1))]
@@ -2022,28 +2040,25 @@ def elemwise(func, *args, **kwargs):
20222040
args = list(args)
20232041
posargs = []
20242042
pos = []
2025-
for i in range(len(args)):
2026-
if isinstance(args[i], scipy.sparse.spmatrix):
2027-
args[i] = COO.from_scipy_sparse(args[i])
2028-
2029-
if isscalar(args[i]) or (isinstance(args[i], np.ndarray)
2030-
and not args[i].shape):
2043+
for i, arg in enumerate(args):
2044+
if isinstance(arg, scipy.sparse.spmatrix):
2045+
args[i] = COO.from_scipy_sparse(arg)
2046+
elif isscalar(arg) or (isinstance(arg, np.ndarray)
2047+
and not arg.shape):
20312048
# Faster and more reliable to pass ()-shaped ndarrays as scalars.
2032-
if isinstance(args[i], np.ndarray):
2033-
args[i] = args[i][()]
2049+
if isinstance(arg, np.ndarray):
2050+
args[i] = arg[()]
20342051

2035-
# The -scalars factor is there because we need to account for already
2036-
# added scalars in the function.
20372052
pos.append(i)
20382053
posargs.append(args[i])
2039-
elif isinstance(args[i], SparseArray) and not isinstance(args[i], COO):
2040-
args[i] = COO(args[i])
2041-
elif not isinstance(args[i], COO):
2054+
elif isinstance(arg, SparseArray) and not isinstance(arg, COO):
2055+
args[i] = COO(arg)
2056+
elif not isinstance(arg, COO):
20422057
raise ValueError("Performing this operation would produce "
20432058
"a dense result: %s" % str(func))
20442059

20452060
# Filter out scalars as they are 'baked' into the function.
2046-
func = _posarg_partial(func, pos, posargs)
2061+
func = PositinalArgumentPartial(func, pos, posargs)
20472062
args = list(filter(lambda arg: not isscalar(arg), args))
20482063

20492064
if len(args) == 0:
@@ -2124,14 +2139,14 @@ def _match_coo(*args):
21242139
21252140
Parameters
21262141
----------
2127-
args : tuple[COO]
2142+
args : Tuple[COO]
21282143
The input :obj:`COO` arrays.
21292144
21302145
Returns
21312146
-------
2132-
matched_idx : list[ndarray]
2147+
matched_idx : List[ndarray]
21332148
The indices of matched elements in the original arrays.
2134-
matched_arrays : list[COO]
2149+
matched_arrays : List[COO]
21352150
The expanded, matched :obj:`COO` objects.
21362151
"""
21372152
# If there's only a single input, return as-is.
@@ -2166,7 +2181,7 @@ def _match_coo(*args):
21662181
matched_coords = [c[:, idx] for c, idx in zip(coords, matched_indices)]
21672182

21682183
# Add the matched part.
2169-
matched_coords = _get_nary_matching_coords(matched_coords, broadcast_params, matched_shape)
2184+
matched_coords = _get_matching_coords(matched_coords, broadcast_params, matched_shape)
21702185
matched_data = [d[idx] for d, idx in zip(data, matched_indices)]
21712186

21722187
matched_indices = [sidx[midx] for sidx, midx in zip(sorted_indices, matched_indices)]
@@ -2206,7 +2221,7 @@ def _unmatch_coo(func, args, mask, **kwargs):
22062221
posargs = [_zero_of_dtype(arg.dtype)[()] for arg, m in zip(args, mask) if not m]
22072222
result_shape = _get_nary_broadcast_shape(*[arg.shape for arg in args])
22082223

2209-
partial = _posarg_partial(func, pos, posargs)
2224+
partial = PositinalArgumentPartial(func, pos, posargs)
22102225
matched_func = partial(*[a.data for a in matched_arrays], **kwargs)
22112226

22122227
unmatched_mask = matched_func != _zero_of_dtype(matched_func.dtype)
@@ -2468,7 +2483,7 @@ def _elemwise_unary(func, self, *args, **kwargs):
24682483
sorted=self.sorted)
24692484

24702485

2471-
def _get_nary_matching_coords(coords, params, shape):
2486+
def _get_matching_coords(coords, params, shape):
24722487
"""
24732488
Get the matching coords across a number of broadcast operands.
24742489
@@ -2481,7 +2496,7 @@ def _get_nary_matching_coords(coords, params, shape):
24812496
Returns
24822497
-------
24832498
numpy.ndarray
2484-
The broacasted coordinates.
2499+
The broacasted coordinates
24852500
"""
24862501
matching_coords = []
24872502
dims = np.zeros(len(params), dtype=np.uint8)
@@ -2516,35 +2531,3 @@ def _linear_loc(coords, shape, signed=False):
25162531
np.add(tmp, out, out=out)
25172532
strides *= d
25182533
return out
2519-
2520-
2521-
def _posarg_partial(func, pos, posargs):
2522-
if not isinstance(pos, Iterable):
2523-
pos = (pos,)
2524-
posargs = (posargs,)
2525-
2526-
n_partial_args = len(pos)
2527-
2528-
class Partial(object):
2529-
def __call__(self, *args, **kwargs):
2530-
j = 0
2531-
totargs = []
2532-
2533-
for i in range(len(args) + n_partial_args):
2534-
if j >= n_partial_args or i != pos[j]:
2535-
totargs.append(args[i - j])
2536-
else:
2537-
totargs.append(posargs[j])
2538-
j += 1
2539-
2540-
return func(*totargs, **kwargs)
2541-
2542-
def __str__(self):
2543-
return str(func)
2544-
2545-
def __repr__(self):
2546-
return repr(func)
2547-
2548-
__doc__ = func.__doc__
2549-
2550-
return Partial()

sparse/tests/test_coo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ def test_elemwise_scalar(func, scalar, convert_to_np_number):
309309
(operator.le, 3),
310310
(operator.eq, 1),
311311
])
312-
@pytest.mark.parametrize('convert_to_np_number', [True])
312+
@pytest.mark.parametrize('convert_to_np_number', [True, False])
313313
def test_leftside_elemwise_scalar(func, scalar, convert_to_np_number):
314314
xs = sparse.random((2, 3, 4), density=0.5)
315315
if convert_to_np_number:

sparse/utils.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22
from numbers import Integral
3+
from collections import Iterable
34

45

56
def assert_eq(x, y, **kwargs):
@@ -151,3 +152,39 @@ def random(
151152
def isscalar(x):
152153
from .sparse_array import SparseArray
153154
return not isinstance(x, SparseArray) and np.isscalar(x)
155+
156+
157+
class PositinalArgumentPartial(object):
158+
def __init__(self, func, pos, posargs):
159+
if not isinstance(pos, Iterable):
160+
pos = (pos,)
161+
posargs = (posargs,)
162+
163+
n_partial_args = len(pos)
164+
165+
self.pos = pos
166+
self.posargs = posargs
167+
self.func = func
168+
169+
self.n = n_partial_args
170+
171+
self.__doc__ = func.__doc__
172+
173+
def __call__(self, *args, **kwargs):
174+
j = 0
175+
totargs = []
176+
177+
for i in range(len(args) + self.n):
178+
if j >= self.n or i != self.pos[j]:
179+
totargs.append(args[i - j])
180+
else:
181+
totargs.append(self.posargs[j])
182+
j += 1
183+
184+
return self.func(*totargs, **kwargs)
185+
186+
def __str__(self):
187+
return str(self.func)
188+
189+
def __repr__(self):
190+
return repr(self.func)

0 commit comments

Comments
 (0)