boolean indexing support #166
Comments
Thanks for opening this! It’s something we could and should support without
also don't sub-sample indexing tests (run them all) fixes #166
Once #169 gets in, we'll have support for basic boolean indexing, but not advanced boolean indexing. We could add support for advanced boolean indexing, as we do for advanced integer indexing, but since no one has specifically requested that yet, I'd rather put it off. For some examples of what's supported, check the test cases I added, which cover things like:
import jax.numpy as np
import numpy as onp
x = np.array([1, 2, 3])
x[onp.array([True, False, True])] # index by a boolean array
x[[True, False, False]] # index by a list of booleans
y = np.array([[1, 2, 3], [4, 5, 6]])
y[[True, False]] # index by a list of booleans on leading dimensions
from jax import jit
jit(lambda i: x[i])(np.array([True, False, False])) # error!
# IndexError: Array boolean indices must be static (e.g. no dependence on an argument to a jit or vmap function).
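To make the static-mask restriction above concrete, here is a minimal sketch, not taken from the test cases themselves. It assumes the same aliases as the snippet above (`np` for `jax.numpy`, `onp` for plain NumPy); the helper name `masked` is hypothetical. Instead of dropping elements under `jit`, it keeps the output shape fixed and zeroes out the masked-off entries with `np.where`, which is one common way to sidestep the restriction when a fixed-shape result is acceptable.

```python
import numpy as onp
import jax.numpy as np
from jax import jit

x = np.array([1, 2, 3])

# A concrete (static) mask is fine: the output shape is known up front.
static_mask = onp.array([True, False, True])
picked = x[static_mask]

# Under jit the mask is a traced value, so rather than dropping elements
# we keep the shape fixed and replace the masked-off entries with 0.
# `masked` is a hypothetical helper name used only for this sketch.
@jit
def masked(x, mask):
    return np.where(mask, x, 0)

kept = masked(x, np.array([True, False, True]))
```

Here `picked` has two elements while `kept` still has three, with the middle one zeroed. Whether zero is a sensible fill value depends on what is done with the result afterwards.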
I just ran into the error above:
@jit
def fun(data):
    norms = calculate_norms(data)
    return data[:, norms > threshold]
Would it be possible to implement this case?
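For a case like this one, a possible workaround (a sketch under assumptions, not something the maintainers proposed in this thread) is to return a fixed-shape masked array plus the mask from the jitted function, and do the actual column filtering outside `jit`, where the mask is concrete. The `calculate_norms` body and the `THRESHOLD` value below are hypothetical stand-ins, since the original comment does not define them.

```python
import numpy as onp
import jax.numpy as np
from jax import jit

THRESHOLD = 1.0  # hypothetical value; not given in the original comment

def calculate_norms(data):
    # Hypothetical stand-in: Euclidean norm of each column.
    return np.sqrt(np.sum(data ** 2, axis=0))

@jit
def fun(data):
    norms = calculate_norms(data)
    keep = norms > THRESHOLD
    # Broadcast the per-column mask over rows; the output shape stays fixed.
    return np.where(keep[None, :], data, 0.0), keep

data = np.array([[3.0, 0.1, 0.0],
                 [4.0, 0.2, 2.0]])
masked, keep = fun(data)

# Outside jit the mask is a concrete array, so real filtering works.
selected = data[:, onp.asarray(keep)]
```

The trade-off is that `masked` still carries the zeroed columns, so any reduction over it has to be one where zeros are harmless, or has to use `keep` explicitly.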
The README mentions boolean index slices (a[idx]) in relation to jit-able functions as something to avoid, but currently it seems they do not work in JAX at all. Indexing this way raises
IndexError: Indexing mode not yet supported. Open a feature request!
Is this expected?