Skip to content

Don't include local_uint_constant_indices rewrite in JAX mode due to XLA bug #400

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 26, 2023

Conversation

ricardoV94
Copy link
Member

Temporarily minimizes #395

Graphs may still fail if users explicitly create uint indexes but that's not very likely, and anyway it's a JAX bug.

@ricardoV94 ricardoV94 added bug Something isn't working jax graph rewriting labels Jul 26, 2023
@ricardoV94 ricardoV94 changed the title Don't include local_uint_constant_indices rewrite in JAX mode due to XLA bug Don't include local_uint_constant_indices rewrite in JAX mode due to XLA bug Jul 26, 2023
Copy link
Member

@jessegrabowski jessegrabowski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me, it's certainty the simplest stop-gap. Do you think it would be worth it to instead temporarily disallow uint8, and make uint16 the smallest allowed? It seems like it would just need to add a check here:

if dtype == np.dtype('uint8'):
    dtype = np.dtype('uint16')

@ricardoV94
Copy link
Member Author

ricardoV94 commented Jul 26, 2023

AFAICT this issue happens for every uint:

import jax
import jax.numpy as jnp
import numpy as np

def subtensor(x):
    return x[np.array(65534, dtype="uint16")]

subtensor(jnp.arange(100_000))  # Array(99999, dtype=int32)

I don't think it's worth the trouble. We just wait for JAX to fix it.

We would be penalizing other backends as well

@ricardoV94 ricardoV94 force-pushed the exclude_uint_index_rewrite branch from 6668ed7 to 0312140 Compare July 26, 2023 10:24
@codecov-commenter
Copy link

Codecov Report

Merging #400 (0312140) into main (14d2454) will not change coverage.
The diff coverage is n/a.

Additional details and impacted files

Impacted file tree graph

@@           Coverage Diff           @@
##             main     #400   +/-   ##
=======================================
  Coverage   80.47%   80.47%           
=======================================
  Files         156      156           
  Lines       45556    45556           
  Branches    11161    11161           
=======================================
  Hits        36660    36660           
  Misses       6694     6694           
  Partials     2202     2202           
Files Changed Coverage Δ
pytensor/compile/mode.py 84.40% <ø> (ø)

@ricardoV94 ricardoV94 merged commit 4459199 into pymc-devs:main Jul 26, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working graph rewriting jax
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants