Revert patch for index underflow after location 127 when mode=JAX
#395
Comments
Now that I didn't expect xD
Seems to be a JAX bug?

```python
import jax
import jax.numpy as jnp
import numpy as np

def subtensor(x):
    return x[np.array(128, dtype="uint8")]

subtensor(jnp.arange(200))  # Array(199, dtype=int32)

def subtensor(x):
    return x[np.array(128, dtype="uint16")]

subtensor(jnp.arange(200))  # Array(128, dtype=int32)
```

Opened an issue: jax-ml/jax#16836
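For comparison, here is a NumPy-only check (NumPy being the reference semantics `jax.numpy` follows): it returns element 128 for every index dtype. It also shows why 127 is the boundary in the issue title, since 128 is the smallest value whose `uint8` bit pattern reads as negative when reinterpreted as `int8`. This only illustrates the boundary; it is not a claim about how JAX handles the index internally.

```python
import numpy as np

x = np.arange(200)

# NumPy returns element 128 regardless of the dtype of the index constant.
results = {dt: int(x[np.array(128, dtype=dt)]) for dt in ("uint8", "uint16", "int64")}
print(results)

# 127 is the largest int8 value; 128 stored as uint8 has the same bits as int8 -128.
reinterpreted = np.array([127, 128], dtype="uint8").view("int8")
print(reinterpreted.tolist())
```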
On our side we can exclude the `local_uint_constant_indices` rewrite from the JAX mode (pytensor/pytensor/compile/mode.py, lines 450 to 456 in 6b189ee):

```python
JAX = Mode(
    JAXLinker(),
    RewriteDatabaseQuery(
        include=["fast_run", "jax"],
        exclude=["cxx_only", "BlasOpt", "fusion", "inplace", "local_uint_constant_indices"],
    ),
)
```
I think I was wrong though, it's casting to
I stepped through the debugger and it was keeping the same type as the one reported in the
```python
import pytensor
import pytensor.tensor as pt

x = pt.dvector("x")
z = x[128]
pytensor.dprint(z, print_type=True)

# Log each rewrite applied while compiling the graph
pytensor.config.optimizer_verbose = True
f = pytensor.function([x], z, mode="JAX")
pytensor.dprint(f, print_type=True)
```
I guess I'm asking why we rewrite to `uint8` then? Isn't it needlessly restrictive?
The rewrite picks the smallest type that can still hold the index, because on most systems smaller index types are faster. It was introduced in aesara-devs/aesara#1150
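A minimal sketch of the "smallest type that still fits" idea, assuming only non-negative integer constants are narrowed and only unsigned types are considered. `smallest_uint_dtype` is a hypothetical helper written for illustration, not PyTensor's implementation. Note that the index 128 from this issue lands in `uint8`, which is exactly the case that triggered the JAX bug. NumPy's own `np.min_scalar_type` gives the same answer for non-negative integers.

```python
import numpy as np

def smallest_uint_dtype(index: int) -> np.dtype:
    """Return the narrowest unsigned dtype that can hold a non-negative index.

    Hypothetical helper illustrating the idea behind the rewrite;
    it is not PyTensor's implementation.
    """
    if index < 0:
        raise ValueError("only non-negative constant indices are narrowed")
    for dtype in (np.uint8, np.uint16, np.uint32, np.uint64):
        if index <= np.iinfo(dtype).max:
            return np.dtype(dtype)
    raise OverflowError(f"{index} does not fit in uint64")

# Compare the sketch with NumPy's built-in minimal-type helper.
for idx in (127, 128, 255, 256, 70_000):
    print(idx, smallest_uint_dtype(idx), np.min_scalar_type(idx))
```

Both 127 and 128 map to `uint8`; 256 is the first value that needs `uint16`.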
This was fixed in JAX, but may be worth waiting a while longer before reverting the patch |
Describe the issue:
It seems the JAX linker downcasts index constants to `uint8`? `mode=None` and `mode="NUMBA"` work as expected. Declaring an index variable (`i = pt.lscalar('i'); z = x[i]`) also works as expected.

Reproducible code example:
Error message:
No response
PyTensor version information:
PyTensor 2.13.1
Context for the issue:
No response