Skip to content

Added multi-axis advanced indexing support #343

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 97 additions & 14 deletions sparse/_coo/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def getitem(x, index):
from .core import COO

# If string, this is an index into an np.void

# Custom dtype.
if isinstance(index, str):
data = x.data[index]
Expand Down Expand Up @@ -86,6 +87,7 @@ def getitem(x, index):
i = 0

sorted = adv_idx is None or adv_idx.pos == 0
adv_idx_added = False
for ind in index:
# Nothing is added to shape or coords if the index is an integer.
if isinstance(ind, Integral):
Expand All @@ -100,8 +102,10 @@ def getitem(x, index):
sorted = False
# Add the index and shape for the advanced index.
elif isinstance(ind, np.ndarray):
shape.append(adv_idx.length)
coords.append(adv_idx.idx)
if not adv_idx_added:
shape.append(adv_idx.length)
coords.append(adv_idx.idx)
adv_idx_added = True
i += 1
# Add a dimension for None.
elif ind is None:
Expand Down Expand Up @@ -141,20 +145,37 @@ def _mask(coords, indices, shape):

if len(adv_idx) != 0:
if len(adv_idx) != 1:
raise IndexError(
"Only indices with at most one iterable index are supported."

# Ensure if multiple advanced indices are passed, all are of the same length
# Also check each advanced index to ensure each is only a one-dimensional iterable
adv_ix_len = len(adv_idx[0])
for ai in adv_idx:
if len(ai) != adv_ix_len:
raise IndexError(
"shape mismatch: indexing arrays could not be broadcast together. Ensure all indexing arrays are of the same length."
)
if ai.ndim != 1:
raise IndexError("Only one-dimensional iterable indices supported.")

mask, aidxs = _compute_multi_axis_multi_mask(
coords,
_ind_ar_from_indices(indices),
np.array(adv_idx, dtype=np.intp),
np.array(adv_idx_pos, dtype=np.intp),
)
return mask, _AdvIdxInfo(aidxs, adv_idx_pos, adv_ix_len)

adv_idx = adv_idx[0]
adv_idx_pos = adv_idx_pos[0]
else:
adv_idx = adv_idx[0]
adv_idx_pos = adv_idx_pos[0]

if adv_idx.ndim != 1:
raise IndexError("Only one-dimensional iterable indices supported.")
if adv_idx.ndim != 1:
raise IndexError("Only one-dimensional iterable indices supported.")

mask, aidxs = _compute_multi_mask(
coords, _ind_ar_from_indices(indices), adv_idx, adv_idx_pos
)
return mask, _AdvIdxInfo(aidxs, adv_idx_pos, len(adv_idx))
mask, aidxs = _compute_multi_mask(
coords, _ind_ar_from_indices(indices), adv_idx, adv_idx_pos
)
return mask, _AdvIdxInfo(aidxs, adv_idx_pos, len(adv_idx))

mask, is_slice = _compute_mask(coords, _ind_ar_from_indices(indices))

Expand Down Expand Up @@ -276,6 +297,68 @@ def _separate_adv_indices(indices):
return new_idx, adv_idx, adv_idx_pos


@numba.jit(nopython=True, nogil=True)
def _compute_multi_axis_multi_mask(
coords, indices, adv_idx, adv_idx_pos
): # pragma: no cover
"""
Computes a mask with the advanced index, and also returns the advanced index
dimension.

Parameters
----------
coords : np.ndarray
Coordinates of the input array.
indices : np.ndarray
The indices in slice format.
adv_idx : np.ndarray
List of advanced indices.
adv_idx_pos : np.ndarray
The position of the advanced indices.

Returns
-------
mask : np.ndarray
The mask.
aidxs : np.ndarray
The advanced array index.
"""
n_adv_idx = len(adv_idx_pos)
mask = numba.typed.List.empty_list(numba.types.intp)
a_indices = numba.typed.List.empty_list(numba.types.intp)
full_idx = np.empty((len(indices) + len(adv_idx_pos), 3), dtype=np.intp)

# Get location of non-advanced indices
if len(indices) != 0:
ixx = 0
for ix in range(coords.shape[0]):
isin = False
for ax in adv_idx_pos:
if ix == ax:
isin = True
break
if not isin:
full_idx[ix] = indices[ixx]
ixx += 1

for i in range(len(adv_idx[0])):
for ii in range(n_adv_idx):
full_idx[adv_idx_pos[ii]] = [adv_idx[ii][i], adv_idx[ii][i] + 1, 1]

partial_mask, is_slice = _compute_mask(coords, full_idx)
if is_slice:
slice_mask = numba.typed.List.empty_list(numba.types.intp)
for j in range(partial_mask[0], partial_mask[1]):
slice_mask.append(j)
partial_mask = array_from_list_intp(slice_mask)

for j in range(len(partial_mask)):
mask.append(partial_mask[j])
a_indices.append(i)

return array_from_list_intp(mask), array_from_list_intp(a_indices)


@numba.jit(nopython=True, nogil=True)
def _compute_multi_mask(coords, indices, adv_idx, adv_idx_pos): # pragma: no cover
"""
Expand All @@ -288,9 +371,9 @@ def _compute_multi_mask(coords, indices, adv_idx, adv_idx_pos): # pragma: no co
Coordinates of the input array.
indices : np.ndarray
The indices in slice format.
adv_idx : int
adv_idx : list(int)
The advanced index.
adv_idx_pos : int
adv_idx_pos : list(int)
The position of the advanced index.

Returns
Expand Down
5 changes: 4 additions & 1 deletion sparse/tests/test_coo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1264,6 +1264,10 @@ def test_gt():
(1, Ellipsis, None),
(1, 1, 1, Ellipsis),
(Ellipsis, 1, None),
# With multi-axis advanced indexing
([0, 1],) * 2,
([0, 1], [0, 2]),
([0, 0, 0], [0, 1, 2], [1, 2, 1]),
# Pathological - Slices larger than array
(slice(None, 1000)),
(slice(None), slice(None, 1000)),
Expand Down Expand Up @@ -1336,7 +1340,6 @@ def test_custom_dtype_slicing():
0.5,
[0.5],
{"potato": "kartoffel"},
([0, 1],) * 2,
([[0, 1]],),
],
)
Expand Down