Skip to content

Commit 222fb8c

Browse files
authored
Numba indexing speedups (#128)
* Initial attempt at speeding up indexing with Numba. * Separate out starts and stops into different variables. * Get into mostly-working state. * Add benchmark. * Fix cache. * Fix flake8 * Fix coverage. * Clean up slicing. * Improve performance on large slices. * Coverage. * Update docstrings. * Weird numba bug when dividing by zero. * Add comments for easier review. * Make scalar slices a "view" of sorts. * Refactor a bit. * Undo "view". Too many things could prevent it from properly working. * Perform concatenation and filtering in a single function. Consumes less memory, too. * Refactor and explain all steps. * Better explanations and more comprehensive benchmark suite. * More tests. * More pathological cases. * Missing line breaks.
1 parent 5f19f0a commit 222fb8c

File tree

6 files changed

+618
-178
lines changed

6 files changed

+618
-178
lines changed

.gitignore

+4
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,10 @@ htmlcov/
4545
nosetests.xml
4646
coverage.xml
4747
*,cover
48+
.pytest_cache/
49+
50+
# Airspeed velocity
51+
.asv/
4852

4953
# Translations
5054
*.mo

benchmarks/benchmark_coo.py

+26
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,17 @@ def setup(self):
1212
self.x.sum_duplicates()
1313
self.y.sum_duplicates()
1414

15+
self.x + self.y # Numba compilation
16+
1517
def time_add(self):
1618
self.x + self.y
1719

1820
def time_mul(self):
1921
self.x * self.y
2022

23+
def time_index(self):
24+
self.x[5]
25+
2126

2227
class ElemwiseBroadcastingSuite(object):
2328
def setup(self):
@@ -33,3 +38,24 @@ def time_add(self):
3338

3439
def time_mul(self):
3540
self.x * self.y
41+
42+
43+
class IndexingSuite(object):
44+
def setup(self):
45+
np.random.seed(0)
46+
self.x = sparse.random((100, 100, 100), density=0.01)
47+
self.x.sum_duplicates()
48+
49+
self.x[5] # Numba compilation
50+
51+
def time_index_scalar(self):
52+
self.x[5]
53+
54+
def time_index_slice(self):
55+
self.x[:50]
56+
57+
def time_index_slice2(self):
58+
self.x[:50, :50]
59+
60+
def time_index_slice3(self):
61+
self.x[:50, :50, :50]

sparse/coo/core.py

+2-111
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
1-
import numbers
21
from collections import Iterable, defaultdict, deque
32

43
import numpy as np
54
import scipy.sparse
65
from numpy.lib.mixins import NDArrayOperatorsMixin
76

87
from .common import dot
8+
from .indexing import getitem
99
from .umath import elemwise, broadcast_to
1010
from ..compatibility import int, range
11-
from ..slicing import normalize_index
1211
from ..sparse_array import SparseArray
1312
from ..utils import _zero_of_dtype
1413

@@ -487,93 +486,7 @@ def __len__(self):
487486
def __sizeof__(self):
488487
return self.nbytes
489488

490-
def __getitem__(self, index):
491-
if not isinstance(index, tuple):
492-
if isinstance(index, str):
493-
data = self.data[index]
494-
idx = np.where(data)
495-
coords = list(self.coords[:, idx[0]])
496-
coords.extend(idx[1:])
497-
498-
return COO(coords, data[idx].flatten(),
499-
shape=self.shape + self.data.dtype[index].shape,
500-
has_duplicates=self.has_duplicates,
501-
sorted=self.sorted)
502-
else:
503-
index = (index,)
504-
505-
last_ellipsis = len(index) > 0 and index[-1] is Ellipsis
506-
index = normalize_index(index, self.shape)
507-
if len(index) != 0 and all(not isinstance(ind, Iterable) and ind == slice(None) for ind in index):
508-
return self
509-
mask = np.ones(self.nnz, dtype=np.bool)
510-
for i, ind in enumerate([i for i in index if i is not None]):
511-
if not isinstance(ind, Iterable) and ind == slice(None):
512-
continue
513-
mask &= _mask(self.coords[i], ind, self.shape[i])
514-
515-
n = mask.sum()
516-
coords = []
517-
shape = []
518-
i = 0
519-
for ind in index:
520-
if isinstance(ind, numbers.Integral):
521-
i += 1
522-
continue
523-
elif isinstance(ind, slice):
524-
step = ind.step if ind.step is not None else 1
525-
if step > 0:
526-
start = ind.start if ind.start is not None else 0
527-
start = max(start, 0)
528-
stop = ind.stop if ind.stop is not None else self.shape[i]
529-
stop = min(stop, self.shape[i])
530-
if start > stop:
531-
start = stop
532-
shape.append((stop - start + step - 1) // step)
533-
else:
534-
start = ind.start or self.shape[i] - 1
535-
stop = ind.stop if ind.stop is not None else -1
536-
start = min(start, self.shape[i] - 1)
537-
stop = max(stop, -1)
538-
if start < stop:
539-
start = stop
540-
shape.append((start - stop - step - 1) // (-step))
541-
542-
dt = np.min_scalar_type(min(-(dim - 1) if dim != 0 else -1 for dim in shape))
543-
coords.append((self.coords[i, mask].astype(dt) - start) // step)
544-
i += 1
545-
elif isinstance(ind, Iterable):
546-
old = self.coords[i][mask]
547-
new = np.empty(shape=old.shape, dtype=old.dtype)
548-
for j, item in enumerate(ind):
549-
new[old == item] = j
550-
coords.append(new)
551-
shape.append(len(ind))
552-
i += 1
553-
elif ind is None:
554-
coords.append(np.zeros(n))
555-
shape.append(1)
556-
557-
for j in range(i, self.ndim):
558-
coords.append(self.coords[j][mask])
559-
shape.append(self.shape[j])
560-
561-
if coords:
562-
coords = np.stack(coords, axis=0)
563-
else:
564-
if last_ellipsis:
565-
coords = np.empty((0, np.sum(mask)), dtype=np.uint8)
566-
else:
567-
if np.sum(mask) != 0:
568-
return self.data[mask][0]
569-
else:
570-
return _zero_of_dtype(self.dtype)[()]
571-
shape = tuple(shape)
572-
data = self.data[mask]
573-
574-
return COO(coords, data, shape=shape,
575-
has_duplicates=self.has_duplicates,
576-
sorted=self.sorted)
489+
__getitem__ = getitem
577490

578491
def __str__(self):
579492
return "<COO: shape=%s, dtype=%s, nnz=%d, sorted=%s, duplicates=%s>" % (
@@ -1572,28 +1485,6 @@ def _keepdims(original, new, axis):
15721485
return new.reshape(shape)
15731486

15741487

1575-
def _mask(coords, idx, shape):
1576-
if isinstance(idx, numbers.Integral):
1577-
return coords == idx
1578-
elif isinstance(idx, slice):
1579-
step = idx.step if idx.step is not None else 1
1580-
if step > 0:
1581-
start = idx.start if idx.start is not None else 0
1582-
stop = idx.stop if idx.stop is not None else shape
1583-
return (coords >= start) & (coords < stop) & \
1584-
(coords % step == start % step)
1585-
else:
1586-
start = idx.start if idx.start is not None else (shape - 1)
1587-
stop = idx.stop if idx.stop is not None else -1
1588-
return (coords <= start) & (coords > stop) & \
1589-
(coords % step == start % step)
1590-
elif isinstance(idx, Iterable):
1591-
mask = np.zeros(len(coords), dtype=np.bool)
1592-
for item in idx:
1593-
mask |= _mask(coords, item, shape)
1594-
return mask
1595-
1596-
15971488
def _grouped_reduce(x, groups, method, **kwargs):
15981489
"""
15991490
Performs a :code:`ufunc` grouped reduce.

0 commit comments

Comments
 (0)