diff --git a/jax/numpy/lax_numpy.py b/jax/numpy/lax_numpy.py index db9ffe1768a0..5d4597537ff3 100644 --- a/jax/numpy/lax_numpy.py +++ b/jax/numpy/lax_numpy.py @@ -1438,7 +1438,7 @@ def _rewriting_take(arr, idx, axis=0): # Handle slice index (only static, otherwise an error is raised) elif isinstance(idx, slice): - if not _all(elt is None or isinstance(core.get_aval(elt), ConcreteArray) + if not _all(elt is None or type(core.get_aval(elt)) is ConcreteArray for elt in (idx.start, idx.stop, idx.step)): msg = ("Array slice indices must have static start/stop/step to be used " "with Numpy indexing syntax. Try lax.dynamic_slice instead.") @@ -1448,6 +1448,27 @@ def _rewriting_take(arr, idx, axis=0): result = lax.slice_in_dim(arr, start, limit, stride, axis=axis) return lax.rev(result, [axis]) if needs_rev else result + # Handle non-advanced bool index (only static, otherwise an error is raised) + elif (isinstance(abstract_idx, ShapedArray) and onp.issubdtype(abstract_idx.dtype, onp.bool_) + or isinstance(idx, list) and _all(not _shape(e) and onp.issubdtype(_dtype(e), onp.bool_) + for e in idx)): + if isinstance(idx, list): + idx = array(idx) + abstract_idx = core.get_aval(idx) + + if not type(abstract_idx) is ConcreteArray: + msg = ("Array boolean indices must be static (e.g. no dependence on an " + "argument to a jit or vmap function).") + raise IndexError(msg) + else: + if idx.ndim > arr.ndim or idx.shape != arr.shape[:idx.ndim]: + msg = "Boolean index shape did not match indexed array shape prefix." + raise IndexError(msg) + else: + reshaped_arr = arr.reshape((-1,) + arr.shape[idx.ndim:]) + int_idx, = onp.where(idx.ravel()) + return lax.index_take(reshaped_arr, (int_idx,), (0,)) + # Handle non-advanced tuple indices by recursing once elif isinstance(idx, tuple) and _all(onp.ndim(elt) == 0 for elt in idx): canonical_idx = _canonicalize_tuple_index(arr, idx) @@ -1487,10 +1508,11 @@ def _rewriting_take(arr, idx, axis=0): # https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#combining-advanced-and-basic-indexing elif _is_advanced_int_indexer(idx): canonical_idx = _canonicalize_tuple_index(arr, tuple(idx)) - idx_noadvanced = [slice(None) if _is_int(e) else e for e in canonical_idx] + idx_noadvanced = [slice(None) if _is_int_arraylike(e) else e + for e in canonical_idx] arr_sliced = _rewriting_take(arr, tuple(idx_noadvanced)) - advanced_pairs = ((e, i) for i, e in enumerate(canonical_idx) if _is_int(e)) + advanced_pairs = ((e, i) for i, e in enumerate(canonical_idx) if _is_int_arraylike(e)) idx_advanced, axes = zip(*advanced_pairs) idx_advanced = broadcast_arrays(*idx_advanced) @@ -1522,11 +1544,11 @@ def _is_advanced_int_indexer(idx): # https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing if isinstance(idx, (tuple, list)): # We assume this check comes *after* the check for non-advanced tuple index, - # and hence we already know at least one element is a sequence - return _all(e is None or e is Ellipsis or isinstance(e, slice) or _is_int(e) - for e in idx) + # and hence we already know at least one element is a sequence if it's a tuple + return _all(e is None or e is Ellipsis or isinstance(e, slice) + or _is_int_arraylike(e) for e in idx) else: - return _is_int(idx) + return _is_int_arraylike(idx) def _is_advanced_int_indexer_without_slices(idx): @@ -1539,11 +1561,11 @@ def _is_advanced_int_indexer_without_slices(idx): return True -def _is_int(x): +def _is_int_arraylike(x): """Returns True if x is array-like with integer dtype, False otherwise.""" return (isinstance(x, int) and not isinstance(x, bool) or onp.issubdtype(getattr(x, "dtype", None), onp.integer) - or isinstance(x, (list, tuple)) and _all(_is_int(e) for e in x)) + or isinstance(x, (list, tuple)) and _all(_is_int_arraylike(e) for e in x)) def _canonicalize_tuple_index(arr, idx): diff --git a/tests/lax_numpy_indexing_test.py b/tests/lax_numpy_indexing_test.py index c6f034ed099a..42cfb66a90e0 100644 --- a/tests/lax_numpy_indexing_test.py +++ b/tests/lax_numpy_indexing_test.py @@ -62,11 +62,10 @@ def check_grads(f, args, order, atol=None, rtol=None, eps=None): class IndexingTest(jtu.JaxTestCase): """Tests for Numpy indexing translation rules.""" - @parameterized.named_parameters(jtu.cases_from_list({ - "testcase_name": - "{}_inshape={}_indexer={}".format( - name, jtu.format_shape_dtype_string( shape, dtype), indexer), - "shape": shape, "dtype": dtype, "rng": rng, "indexer": indexer + @parameterized.named_parameters({ + "testcase_name": "{}_inshape={}_indexer={}".format( + name, jtu.format_shape_dtype_string( shape, dtype), indexer), + "shape": shape, "dtype": dtype, "rng": rng, "indexer": indexer } for name, index_specs in [ ("OneIntIndex", [ IndexSpec(shape=(3,), indexer=1), @@ -154,14 +153,14 @@ class IndexingTest(jtu.JaxTestCase): IndexSpec(shape=(3, 4), indexer=()), ]), ] for shape, indexer in index_specs for dtype in all_dtypes - for rng in [jtu.rand_default()])) + for rng in [jtu.rand_default()]) @jtu.skip_on_devices("tpu") def testStaticIndexing(self, shape, dtype, rng, indexer): args_maker = lambda: [rng(shape, dtype)] fun = lambda x: x[indexer] self._CompileAndCheck(fun, args_maker, check_dtypes=True) - @parameterized.named_parameters(jtu.cases_from_list({ + @parameterized.named_parameters({ "testcase_name": "{}_inshape={}_indexer={}".format(name, jtu.format_shape_dtype_string( @@ -233,7 +232,7 @@ def testStaticIndexing(self, shape, dtype, rng, indexer): # IndexSpec(shape=(3, 4), indexer=()), # ]), ] for shape, indexer in index_specs for dtype in float_dtypes - for rng in [jtu.rand_default()])) + for rng in [jtu.rand_default()]) @jtu.skip_on_devices("tpu") def testStaticIndexingGrads(self, shape, dtype, rng, indexer): tol = 1e-2 if onp.finfo(dtype).bits == 32 else None @@ -257,7 +256,7 @@ def _ReplaceSlicesWithTuples(self, idx): else: return idx, lambda x: x - @parameterized.named_parameters(jtu.cases_from_list( + @parameterized.named_parameters( {"testcase_name": "{}_inshape={}_indexer={}" .format(name, jtu.format_shape_dtype_string(shape, dtype), indexer), "shape": shape, "dtype": dtype, "rng": rng, "indexer": indexer} @@ -280,7 +279,7 @@ def _ReplaceSlicesWithTuples(self, idx): ] for shape, indexer in index_specs for dtype in all_dtypes - for rng in [jtu.rand_default()])) + for rng in [jtu.rand_default()]) def testDynamicIndexingWithSlicesErrors(self, shape, dtype, rng, indexer): unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer) @@ -292,7 +291,7 @@ def fun(x, unpacked_indexer): args_maker = lambda: [rng(shape, dtype), unpacked_indexer] self.assertRaises(IndexError, lambda: fun(*args_maker())) - @parameterized.named_parameters(jtu.cases_from_list( + @parameterized.named_parameters( {"testcase_name": "{}_inshape={}_indexer={}" .format(name, jtu.format_shape_dtype_string(shape, dtype), indexer), "shape": shape, "dtype": dtype, "rng": rng, "indexer": indexer} @@ -312,7 +311,7 @@ def fun(x, unpacked_indexer): ] for shape, indexer in index_specs for dtype in all_dtypes - for rng in [jtu.rand_default()])) + for rng in [jtu.rand_default()]) def testDynamicIndexingWithIntegers(self, shape, dtype, rng, indexer): unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer) @@ -324,7 +323,7 @@ def fun(x, unpacked_indexer): self._CompileAndCheck(fun, args_maker, check_dtypes=True) @skip - @parameterized.named_parameters(jtu.cases_from_list( + @parameterized.named_parameters( {"testcase_name": "{}_inshape={}_indexer={}" .format(name, jtu.format_shape_dtype_string(shape, dtype), indexer), "shape": shape, "dtype": dtype, "rng": rng, "indexer": indexer} @@ -346,7 +345,7 @@ def fun(x, unpacked_indexer): ] for shape, indexer in index_specs for dtype in float_dtypes - for rng in [jtu.rand_default()])) + for rng in [jtu.rand_default()]) def DISABLED_testDynamicIndexingWithIntegersGrads(self, shape, dtype, rng, indexer): # TODO(mattjj): re-enable (test works but for grad-of-compile, in flux) tol = 1e-2 if onp.finfo(dtype).bits == 32 else None @@ -360,7 +359,7 @@ def fun(unpacked_indexer, x): arr = rng(shape, dtype) check_grads(partial(fun, unpacked_indexer), (arr,), 2, tol, tol, tol) - @parameterized.named_parameters(jtu.cases_from_list( + @parameterized.named_parameters( {"testcase_name": "{}_inshape={}_indexer={}" .format(name, jtu.format_shape_dtype_string(shape, dtype), indexer), "shape": shape, "dtype": dtype, "rng": rng, "indexer": indexer} @@ -412,13 +411,13 @@ def fun(unpacked_indexer, x): ] for shape, indexer in index_specs for dtype in all_dtypes - for rng in [jtu.rand_default()])) + for rng in [jtu.rand_default()]) def testAdvancedIntegerIndexing(self, shape, dtype, rng, indexer): args_maker = lambda: [rng(shape, dtype), indexer] fun = lambda x, idx: x[idx] self._CompileAndCheck(fun, args_maker, check_dtypes=True) - @parameterized.named_parameters(jtu.cases_from_list( + @parameterized.named_parameters( {"testcase_name": "{}_inshape={}_indexer={}" .format(name, jtu.format_shape_dtype_string(shape, dtype), indexer), "shape": shape, "dtype": dtype, "rng": rng, "indexer": indexer} @@ -470,14 +469,14 @@ def testAdvancedIntegerIndexing(self, shape, dtype, rng, indexer): ] for shape, indexer in index_specs for dtype in float_dtypes - for rng in [jtu.rand_default()])) + for rng in [jtu.rand_default()]) def testAdvancedIntegerIndexingGrads(self, shape, dtype, rng, indexer): tol = 1e-2 if onp.finfo(dtype).bits == 32 else None arg = rng(shape, dtype) fun = lambda x: x[indexer]**2 check_grads(fun, (arg,), 2, tol, tol, tol) - @parameterized.named_parameters(jtu.cases_from_list( + @parameterized.named_parameters( {"testcase_name": "{}_inshape={}_indexer={}" .format(name, jtu.format_shape_dtype_string(shape, dtype), indexer), "shape": shape, "dtype": dtype, "rng": rng, "indexer": indexer} @@ -533,7 +532,7 @@ def testAdvancedIntegerIndexingGrads(self, shape, dtype, rng, indexer): ] for shape, indexer in index_specs for dtype in all_dtypes - for rng in [jtu.rand_default()])) + for rng in [jtu.rand_default()]) def testMixedAdvancedIntegerIndexing(self, shape, dtype, rng, indexer): indexer_with_dummies = [e if isinstance(e, onp.ndarray) else () for e in indexer] @@ -588,6 +587,49 @@ def foo(x): self.assertAllClose(a1, a2, check_dtypes=True) + def testBooleanIndexingArray1D(self): + idx = onp.array([True, True, False]) + x = api.device_put(onp.arange(3)) + ans = x[idx] + expected = onp.arange(3)[idx] + self.assertAllClose(ans, expected, check_dtypes=False) + + def testBooleanIndexingList1D(self): + idx = [True, True, False] + x = api.device_put(onp.arange(3)) + ans = x[idx] + expected = onp.arange(3)[idx] + self.assertAllClose(ans, expected, check_dtypes=False) + + def testBooleanIndexingArray2DBroadcast(self): + idx = onp.array([True, True, False, True]) + x = onp.arange(8).reshape(4, 2) + ans = api.device_put(x)[idx] + expected = x[idx] + self.assertAllClose(ans, expected, check_dtypes=False) + + def testBooleanIndexingList2DBroadcast(self): + idx = [True, True, False, True] + x = onp.arange(8).reshape(4, 2) + ans = api.device_put(x)[idx] + expected = x[idx] + self.assertAllClose(ans, expected, check_dtypes=False) + + def testBooleanIndexingArray2D(self): + idx = onp.array([[True, False], + [False, True], + [False, False], + [True, True]]) + x = onp.arange(8).reshape(4, 2) + ans = api.device_put(x)[idx] + expected = x[idx] + self.assertAllClose(ans, expected, check_dtypes=False) + + def testBooleanIndexingDynamicShapeError(self): + x = onp.zeros(3) + i = onp.array([True, True, False]) + self.assertRaises(IndexError, lambda: api.jit(lambda x, i: x[i])(x, i)) + if __name__ == "__main__": absltest.main() diff --git a/tests/scipy_stats_test.py b/tests/scipy_stats_test.py index 7d7338f628fd..b0edb80c5641 100644 --- a/tests/scipy_stats_test.py +++ b/tests/scipy_stats_test.py @@ -50,7 +50,8 @@ def args_maker(): x, a, b, loc, scale = map(rng, shapes, dtypes) return [x, onp.abs(a), onp.abs(b), loc, onp.abs(scale)] - self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True, + tol=1e-4) self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True) @genNamedParametersNArgs(3, jtu.rand_default())