
Revert patch for index underflow after location 127 when mode=JAX #395


Closed
jessegrabowski opened this issue Jul 24, 2023 · 9 comments · Fixed by #767

Comments

@jessegrabowski
Member

jessegrabowski commented Jul 24, 2023

Describe the issue:

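Indexing a vector with a constant index larger than 127 returns the wrong element when compiling with mode="JAX". It seems the JAX linker downcasts index constants to uint8? mode=None and mode="NUMBA" work as expected, and so does declaring a symbolic index variable (i = pt.lscalar('i'); z = x[i]).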

Reproducible code example:

import pytensor
import pytensor.tensor as pt
import numpy as np

x = pt.dvector('x')
z1 = x[127]
z2 = x[128]
f = pytensor.function([x], [z1, z2], mode='JAX')

f(np.arange(200))
# out: [Array(127., dtype=float64), Array(0., dtype=float64)]

Error message:

No response

PyTensor version information:

PyTensor 2.13.1

Context for the issue:

No response

jessegrabowski added the bug (Something isn't working) label on Jul 24, 2023
@ricardoV94
Member

Now that I didn't expect xD

@ricardoV94
Member

ricardoV94 commented Jul 25, 2023

Seems to be a JAX bug?

import jax
import jax.numpy as jnp
import numpy as np

def subtensor(x):
    return x[np.array(128, dtype="uint8")]

subtensor(jnp.arange(200))  # Array(199, dtype=int32)

def subtensor(x):
    return x[np.array(128, dtype="uint16")]

subtensor(jnp.arange(200))  # Array(128, dtype=int32)
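As a baseline, NumPy returns the expected element for the same 0-d indices regardless of their integer width. That is consistent with the index value itself being fine and the problem sitting in how JAX handled the uint8 index. A minimal check, assuming only NumPy:

```python
import numpy as np

x = np.arange(200)

# NumPy accepts 0-d integer arrays of any width as scalar indices, so the
# uint8, uint16 and int64 versions of 128 all select the same element.
for dtype in ("uint8", "uint16", "int64"):
    idx = np.array(128, dtype=dtype)
    print(dtype, int(x[idx]))
```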

Opened an issue: jax-ml/jax#16836

@ricardoV94
Member

ricardoV94 commented Jul 25, 2023

On our side, we can exclude the rewrite local_uint_constant_indices from the JAX mode. The current definition:

JAX = Mode(
    JAXLinker(),
    RewriteDatabaseQuery(
        include=["fast_run", "jax"],
        exclude=["cxx_only", "BlasOpt", "fusion", "inplace"],
    ),
)

would become:

JAX = Mode(
    JAXLinker(),
    RewriteDatabaseQuery(
        include=["fast_run", "jax"],
        exclude=["cxx_only", "BlasOpt", "fusion", "inplace", "local_uint_constant_indices"],
    ),
)

@jessegrabowski
Member Author

I think I was wrong, though: it's casting to int8, not uint8 (int8 overflows at 128, while uint8 should only overflow at 256). Also, wouldn't dropping the sign (using an unsigned type) break negative indices?
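For reference, the ranges of the two 8-bit types can be checked with the standard library's ctypes, which does no overflow checking and wraps out-of-range values modulo 2**8. This only illustrates the integer ranges and the negative-index concern; it is not how PyTensor or JAX perform the cast. The helper store is hypothetical:

```python
import ctypes

def store(value, ctype):
    """Store a Python int in a fixed-width C integer and read it back.

    ctypes performs no overflow checking, so out-of-range values wrap
    around modulo 2**bits (two's complement for the signed types).
    """
    return ctype(value).value

# int8 holds -128..127: 127 fits, 128 wraps around to -128.
print(store(127, ctypes.c_int8), store(128, ctypes.c_int8))

# uint8 holds 0..255: both 128 and 255 fit...
print(store(128, ctypes.c_uint8), store(255, ctypes.c_uint8))

# ...but a negative index such as -1 cannot be represented and wraps to 255.
print(store(-1, ctypes.c_uint8))
```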

@ricardoV94
Member

I stepped through it with the debugger, and the index kept the same type as the one reported by dprint of the compiled function. I didn't see any implicit casting, only the explicit one done by the rewrite mentioned above.

@ricardoV94
Member

import pytensor
import pytensor.tensor as pt

x = pt.dvector('x')
z = x[128]

pytensor.dprint(z, print_type=True)

pytensor.config.optimizer_verbose=True
f = pytensor.function([x], z, mode='JAX')

pytensor.dprint(f, print_type=True)

Output:

Subtensor{i} [id A] <Scalar(float64, shape=())>
 ├─ x [id B] <Vector(float64, shape=(?,))>
 └─ 128 [id C] <int64>

rewriting: rewrite local_uint_constant_indices replaces Subtensor{i}.0 of Subtensor{i}(x, 128) with Subtensor{i}.0 of Subtensor{i}(x, 128)

DeepCopyOp [id A] <Scalar(float64, shape=())> 1
 └─ Subtensor{i} [id B] <Scalar(float64, shape=())> 0
    ├─ x [id C] <Vector(float64, shape=(?,))>
    └─ 128 [id D] <uint8>

@jessegrabowski
Member Author

I guess what I'm asking is why we rewrite to uint8 at all. Isn't that needlessly restrictive?

@ricardoV94
Member

The rewrite picks the smallest type that can still hold the index, because on most systems that is faster. It was introduced in aesara-devs/aesara#1150.
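A minimal sketch of that selection logic, assuming each non-negative constant index is mapped to the narrowest unsigned type that holds it (consistent with 128 becoming uint8 in the dprint output above) and assuming negative constants are left untouched, since an unsigned type cannot represent them. smallest_uint_dtype is a hypothetical helper, not PyTensor's actual code:

```python
def smallest_uint_dtype(index):
    """Return the name of the narrowest unsigned dtype that can hold ``index``.

    Hypothetical helper: returns None for negative indices, which an
    unsigned type cannot represent, so they would keep their original type.
    """
    if index < 0:
        return None
    for bits in (8, 16, 32, 64):
        if index < 2**bits:
            return f"uint{bits}"
    raise OverflowError(f"{index} does not fit in uint64")

for idx in (127, 128, 255, 256, 70_000, -1):
    print(idx, smallest_uint_dtype(idx))
```

Under this scheme 128 lands in uint8, which is a valid type for that value; the wrong result came from JAX mishandling uint8 indices, not from the narrowing itself.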

@ricardoV94
Member

This was fixed in JAX, but it may be worth waiting a while longer before reverting the patch.

ricardoV94 changed the title from "BUG: index underflow after location 127 when mode=JAX" to "Revert patch for index underflow after location 127 when mode=JAX" on Dec 7, 2023
ricardoV94 added the beginner friendly, jax, and maintenance labels and removed the bug (Something isn't working) label on Dec 7, 2023
2 participants