Skip to content

add non-advanced boolean indexing support #169

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 23, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 31 additions & 9 deletions jax/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1438,7 +1438,7 @@ def _rewriting_take(arr, idx, axis=0):

# Handle slice index (only static, otherwise an error is raised)
elif isinstance(idx, slice):
if not _all(elt is None or isinstance(core.get_aval(elt), ConcreteArray)
if not _all(elt is None or type(core.get_aval(elt)) is ConcreteArray
for elt in (idx.start, idx.stop, idx.step)):
msg = ("Array slice indices must have static start/stop/step to be used "
"with Numpy indexing syntax. Try lax.dynamic_slice instead.")
Expand All @@ -1448,6 +1448,27 @@ def _rewriting_take(arr, idx, axis=0):
result = lax.slice_in_dim(arr, start, limit, stride, axis=axis)
return lax.rev(result, [axis]) if needs_rev else result

# Handle non-advanced bool index (only static, otherwise an error is raised)
elif (isinstance(abstract_idx, ShapedArray) and onp.issubdtype(abstract_idx.dtype, onp.bool_)
or isinstance(idx, list) and _all(not _shape(e) and onp.issubdtype(_dtype(e), onp.bool_)
for e in idx)):
if isinstance(idx, list):
idx = array(idx)
abstract_idx = core.get_aval(idx)

if not type(abstract_idx) is ConcreteArray:
msg = ("Array boolean indices must be static (e.g. no dependence on an "
"argument to a jit or vmap function).")
raise IndexError(msg)
else:
if idx.ndim > arr.ndim or idx.shape != arr.shape[:idx.ndim]:
msg = "Boolean index shape did not match indexed array shape prefix."
raise IndexError(msg)
else:
reshaped_arr = arr.reshape((-1,) + arr.shape[idx.ndim:])
int_idx, = onp.where(idx.ravel())
return lax.index_take(reshaped_arr, (int_idx,), (0,))

# Handle non-advanced tuple indices by recursing once
elif isinstance(idx, tuple) and _all(onp.ndim(elt) == 0 for elt in idx):
canonical_idx = _canonicalize_tuple_index(arr, idx)
Expand Down Expand Up @@ -1487,10 +1508,11 @@ def _rewriting_take(arr, idx, axis=0):
# https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#combining-advanced-and-basic-indexing
elif _is_advanced_int_indexer(idx):
canonical_idx = _canonicalize_tuple_index(arr, tuple(idx))
idx_noadvanced = [slice(None) if _is_int(e) else e for e in canonical_idx]
idx_noadvanced = [slice(None) if _is_int_arraylike(e) else e
for e in canonical_idx]
arr_sliced = _rewriting_take(arr, tuple(idx_noadvanced))

advanced_pairs = ((e, i) for i, e in enumerate(canonical_idx) if _is_int(e))
advanced_pairs = ((e, i) for i, e in enumerate(canonical_idx) if _is_int_arraylike(e))
idx_advanced, axes = zip(*advanced_pairs)
idx_advanced = broadcast_arrays(*idx_advanced)

Expand Down Expand Up @@ -1522,11 +1544,11 @@ def _is_advanced_int_indexer(idx):
# https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing
if isinstance(idx, (tuple, list)):
# We assume this check comes *after* the check for non-advanced tuple index,
# and hence we already know at least one element is a sequence
return _all(e is None or e is Ellipsis or isinstance(e, slice) or _is_int(e)
for e in idx)
# and hence we already know at least one element is a sequence if it's a tuple
return _all(e is None or e is Ellipsis or isinstance(e, slice)
or _is_int_arraylike(e) for e in idx)
else:
return _is_int(idx)
return _is_int_arraylike(idx)


def _is_advanced_int_indexer_without_slices(idx):
Expand All @@ -1539,11 +1561,11 @@ def _is_advanced_int_indexer_without_slices(idx):
return True


def _is_int(x):
def _is_int_arraylike(x):
"""Returns True if x is array-like with integer dtype, False otherwise."""
return (isinstance(x, int) and not isinstance(x, bool)
or onp.issubdtype(getattr(x, "dtype", None), onp.integer)
or isinstance(x, (list, tuple)) and _all(_is_int(e) for e in x))
or isinstance(x, (list, tuple)) and _all(_is_int_arraylike(e) for e in x))


def _canonicalize_tuple_index(arr, idx):
Expand Down
82 changes: 62 additions & 20 deletions tests/lax_numpy_indexing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,10 @@ def check_grads(f, args, order, atol=None, rtol=None, eps=None):
class IndexingTest(jtu.JaxTestCase):
"""Tests for Numpy indexing translation rules."""

@parameterized.named_parameters(jtu.cases_from_list({
"testcase_name":
"{}_inshape={}_indexer={}".format(
name, jtu.format_shape_dtype_string( shape, dtype), indexer),
"shape": shape, "dtype": dtype, "rng": rng, "indexer": indexer
@parameterized.named_parameters({
"testcase_name": "{}_inshape={}_indexer={}".format(
name, jtu.format_shape_dtype_string( shape, dtype), indexer),
"shape": shape, "dtype": dtype, "rng": rng, "indexer": indexer
} for name, index_specs in [
("OneIntIndex", [
IndexSpec(shape=(3,), indexer=1),
Expand Down Expand Up @@ -154,14 +153,14 @@ class IndexingTest(jtu.JaxTestCase):
IndexSpec(shape=(3, 4), indexer=()),
]),
] for shape, indexer in index_specs for dtype in all_dtypes
for rng in [jtu.rand_default()]))
for rng in [jtu.rand_default()])
@jtu.skip_on_devices("tpu")
def testStaticIndexing(self, shape, dtype, rng, indexer):
args_maker = lambda: [rng(shape, dtype)]
fun = lambda x: x[indexer]
self._CompileAndCheck(fun, args_maker, check_dtypes=True)

@parameterized.named_parameters(jtu.cases_from_list({
@parameterized.named_parameters({
"testcase_name":
"{}_inshape={}_indexer={}".format(name,
jtu.format_shape_dtype_string(
Expand Down Expand Up @@ -233,7 +232,7 @@ def testStaticIndexing(self, shape, dtype, rng, indexer):
# IndexSpec(shape=(3, 4), indexer=()),
# ]),
] for shape, indexer in index_specs for dtype in float_dtypes
for rng in [jtu.rand_default()]))
for rng in [jtu.rand_default()])
@jtu.skip_on_devices("tpu")
def testStaticIndexingGrads(self, shape, dtype, rng, indexer):
tol = 1e-2 if onp.finfo(dtype).bits == 32 else None
Expand All @@ -257,7 +256,7 @@ def _ReplaceSlicesWithTuples(self, idx):
else:
return idx, lambda x: x

@parameterized.named_parameters(jtu.cases_from_list(
@parameterized.named_parameters(
{"testcase_name": "{}_inshape={}_indexer={}"
.format(name, jtu.format_shape_dtype_string(shape, dtype), indexer),
"shape": shape, "dtype": dtype, "rng": rng, "indexer": indexer}
Expand All @@ -280,7 +279,7 @@ def _ReplaceSlicesWithTuples(self, idx):
]
for shape, indexer in index_specs
for dtype in all_dtypes
for rng in [jtu.rand_default()]))
for rng in [jtu.rand_default()])
def testDynamicIndexingWithSlicesErrors(self, shape, dtype, rng, indexer):
unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer)

Expand All @@ -292,7 +291,7 @@ def fun(x, unpacked_indexer):
args_maker = lambda: [rng(shape, dtype), unpacked_indexer]
self.assertRaises(IndexError, lambda: fun(*args_maker()))

@parameterized.named_parameters(jtu.cases_from_list(
@parameterized.named_parameters(
{"testcase_name": "{}_inshape={}_indexer={}"
.format(name, jtu.format_shape_dtype_string(shape, dtype), indexer),
"shape": shape, "dtype": dtype, "rng": rng, "indexer": indexer}
Expand All @@ -312,7 +311,7 @@ def fun(x, unpacked_indexer):
]
for shape, indexer in index_specs
for dtype in all_dtypes
for rng in [jtu.rand_default()]))
for rng in [jtu.rand_default()])
def testDynamicIndexingWithIntegers(self, shape, dtype, rng, indexer):
unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer)

Expand All @@ -324,7 +323,7 @@ def fun(x, unpacked_indexer):
self._CompileAndCheck(fun, args_maker, check_dtypes=True)

@skip
@parameterized.named_parameters(jtu.cases_from_list(
@parameterized.named_parameters(
{"testcase_name": "{}_inshape={}_indexer={}"
.format(name, jtu.format_shape_dtype_string(shape, dtype), indexer),
"shape": shape, "dtype": dtype, "rng": rng, "indexer": indexer}
Expand All @@ -346,7 +345,7 @@ def fun(x, unpacked_indexer):
]
for shape, indexer in index_specs
for dtype in float_dtypes
for rng in [jtu.rand_default()]))
for rng in [jtu.rand_default()])
def DISABLED_testDynamicIndexingWithIntegersGrads(self, shape, dtype, rng, indexer):
# TODO(mattjj): re-enable (test works but for grad-of-compile, in flux)
tol = 1e-2 if onp.finfo(dtype).bits == 32 else None
Expand All @@ -360,7 +359,7 @@ def fun(unpacked_indexer, x):
arr = rng(shape, dtype)
check_grads(partial(fun, unpacked_indexer), (arr,), 2, tol, tol, tol)

@parameterized.named_parameters(jtu.cases_from_list(
@parameterized.named_parameters(
{"testcase_name": "{}_inshape={}_indexer={}"
.format(name, jtu.format_shape_dtype_string(shape, dtype), indexer),
"shape": shape, "dtype": dtype, "rng": rng, "indexer": indexer}
Expand Down Expand Up @@ -412,13 +411,13 @@ def fun(unpacked_indexer, x):
]
for shape, indexer in index_specs
for dtype in all_dtypes
for rng in [jtu.rand_default()]))
for rng in [jtu.rand_default()])
def testAdvancedIntegerIndexing(self, shape, dtype, rng, indexer):
args_maker = lambda: [rng(shape, dtype), indexer]
fun = lambda x, idx: x[idx]
self._CompileAndCheck(fun, args_maker, check_dtypes=True)

@parameterized.named_parameters(jtu.cases_from_list(
@parameterized.named_parameters(
{"testcase_name": "{}_inshape={}_indexer={}"
.format(name, jtu.format_shape_dtype_string(shape, dtype), indexer),
"shape": shape, "dtype": dtype, "rng": rng, "indexer": indexer}
Expand Down Expand Up @@ -470,14 +469,14 @@ def testAdvancedIntegerIndexing(self, shape, dtype, rng, indexer):
]
for shape, indexer in index_specs
for dtype in float_dtypes
for rng in [jtu.rand_default()]))
for rng in [jtu.rand_default()])
def testAdvancedIntegerIndexingGrads(self, shape, dtype, rng, indexer):
tol = 1e-2 if onp.finfo(dtype).bits == 32 else None
arg = rng(shape, dtype)
fun = lambda x: x[indexer]**2
check_grads(fun, (arg,), 2, tol, tol, tol)

@parameterized.named_parameters(jtu.cases_from_list(
@parameterized.named_parameters(
{"testcase_name": "{}_inshape={}_indexer={}"
.format(name, jtu.format_shape_dtype_string(shape, dtype), indexer),
"shape": shape, "dtype": dtype, "rng": rng, "indexer": indexer}
Expand Down Expand Up @@ -533,7 +532,7 @@ def testAdvancedIntegerIndexingGrads(self, shape, dtype, rng, indexer):
]
for shape, indexer in index_specs
for dtype in all_dtypes
for rng in [jtu.rand_default()]))
for rng in [jtu.rand_default()])
def testMixedAdvancedIntegerIndexing(self, shape, dtype, rng, indexer):
indexer_with_dummies = [e if isinstance(e, onp.ndarray) else ()
for e in indexer]
Expand Down Expand Up @@ -588,6 +587,49 @@ def foo(x):

self.assertAllClose(a1, a2, check_dtypes=True)

def testBooleanIndexingArray1D(self):
idx = onp.array([True, True, False])
x = api.device_put(onp.arange(3))
ans = x[idx]
expected = onp.arange(3)[idx]
self.assertAllClose(ans, expected, check_dtypes=False)

def testBooleanIndexingList1D(self):
idx = [True, True, False]
x = api.device_put(onp.arange(3))
ans = x[idx]
expected = onp.arange(3)[idx]
self.assertAllClose(ans, expected, check_dtypes=False)

def testBooleanIndexingArray2DBroadcast(self):
idx = onp.array([True, True, False, True])
x = onp.arange(8).reshape(4, 2)
ans = api.device_put(x)[idx]
expected = x[idx]
self.assertAllClose(ans, expected, check_dtypes=False)

def testBooleanIndexingList2DBroadcast(self):
idx = [True, True, False, True]
x = onp.arange(8).reshape(4, 2)
ans = api.device_put(x)[idx]
expected = x[idx]
self.assertAllClose(ans, expected, check_dtypes=False)

def testBooleanIndexingArray2D(self):
idx = onp.array([[True, False],
[False, True],
[False, False],
[True, True]])
x = onp.arange(8).reshape(4, 2)
ans = api.device_put(x)[idx]
expected = x[idx]
self.assertAllClose(ans, expected, check_dtypes=False)

def testBooleanIndexingDynamicShapeError(self):
x = onp.zeros(3)
i = onp.array([True, True, False])
self.assertRaises(IndexError, lambda: api.jit(lambda x, i: x[i])(x, i))


if __name__ == "__main__":
absltest.main()
3 changes: 2 additions & 1 deletion tests/scipy_stats_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ def args_maker():
x, a, b, loc, scale = map(rng, shapes, dtypes)
return [x, onp.abs(a), onp.abs(b), loc, onp.abs(scale)]

self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True,
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

@genNamedParametersNArgs(3, jtu.rand_default())
Expand Down