boolean indexing support #166

Closed
lukasheinrich opened this issue Dec 21, 2018 · 3 comments · Fixed by #169

@lukasheinrich

The README lists boolean index slices (a[idx]) as something to avoid in jit-able functions, but currently they do not seem to work in jax at all:

>>> import jax.numpy as np
>>> a = np.array([1,2,3])
>>> a[[True,False,False]]

This raises IndexError: Indexing mode not yet supported. Open a feature request!

Is this expected?

@mattjj
Collaborator

mattjj commented Dec 21, 2018

Thanks for opening this! It’s something we could and should support without jit, but we haven’t implemented it yet. We should add it! (Sorry if the readme is unclear on the current level of support.)

@mattjj mattjj added the enhancement New feature or request label Dec 21, 2018
@mattjj mattjj self-assigned this Dec 23, 2018
mattjj added a commit that referenced this issue Dec 23, 2018
also don't sub-sample indexing tests (run them all)
fixes #166
@mattjj
Collaborator

mattjj commented Dec 23, 2018

Once #169 gets in, we'll have support for basic boolean indexing, but not advanced boolean indexing. We could add support for advanced boolean indexing, like we support advanced integer indexing, but since no one has requested that yet specifically, I'd rather put it off.

For some examples of what's supported, check the test cases I added, which cover things like:

import numpy as onp  # plain NumPy, needed for onp.array below
import jax.numpy as np
x = np.array([1, 2, 3])

x[onp.array([True, False, True])]  # index by a boolean array
x[[True, False, False]]  # index by a list of booleans

y = np.array([[1, 2, 3], [4, 5, 6]])
y[[True, False]]  # index by a list of booleans on leading dimensions (one entry per row)

from jax import jit
jit(lambda i: x[i])(np.array([True, False, False]))  # error!
# IndexError: Array boolean indices must be static (e.g. no dependence on an argument to a jit or vmap function).
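
The static/traced distinction in that error message can be illustrated with a minimal sketch. It assumes a reasonably recent JAX release and uses the alias jnp for jax.numpy (the thread uses np); the names static_mask and select_static are hypothetical. A mask that is a concrete NumPy array is known at trace time, so the output shape is known too, whereas a mask passed as a jit argument would make the output shape depend on runtime data.

```python
import numpy as onp
import jax
import jax.numpy as jnp

x = jnp.array([1, 2, 3])

# A concrete NumPy mask is a constant during tracing, so the result shape (2,)
# is known when the function is compiled.
static_mask = onp.array([True, False, True])

@jax.jit
def select_static(values):
    return values[static_mask]

picked = select_static(x)

# Passing the mask as a jit argument makes it a traced value. The output shape
# would then depend on runtime data, so JAX raises an IndexError.
try:
    jax.jit(lambda values, mask: values[mask])(x, jnp.array(static_mask))
    traced_ok = True
except IndexError:
    traced_ok = False
```

Here picked holds the two selected elements, and traced_ok ends up False because the traced mask is rejected.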

@Baschdl
Contributor

Baschdl commented May 15, 2020

I just ran into IndexError: Array boolean indices must be static when trying:

@jit
def fun(data):
    norms = calculate_norms(data)
    return data[:, norms > threshold]

Would it be possible to implement this case?
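
One possible workaround for this case, sketched under assumptions rather than taken from the thread: keep the array shape fixed inside jit by masking columns with the three-argument jnp.where, and do the shape-changing selection outside jit where the mask is concrete. The helper calculate_norms (here a per-column Euclidean norm), the constant THRESHOLD, and the function name mask_columns are hypothetical stand-ins for the names in the snippet above.

```python
import numpy as onp
import jax
import jax.numpy as jnp

THRESHOLD = 1.0  # hypothetical cutoff, stands in for `threshold`

def calculate_norms(data):
    # Hypothetical stand-in: Euclidean norm of each column.
    return jnp.linalg.norm(data, axis=0)

@jax.jit
def mask_columns(data):
    norms = calculate_norms(data)
    keep = norms > THRESHOLD
    # Zero out rejected columns instead of dropping them, so the output
    # shape does not depend on the traced values.
    return jnp.where(keep[None, :], data, 0.0), keep

data = jnp.array([[0.1, 3.0, 0.2],
                  [0.2, 4.0, 2.0]])
masked, keep = mask_columns(data)

# Outside jit the mask is concrete, so ordinary boolean indexing works.
selected = data[:, onp.asarray(keep)]
```

Whether zeroing columns is acceptable depends on what the downstream computation does with them. If the columns really have to be dropped inside jit, the number of kept columns has to be fixed ahead of time in some way.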
