Underflow in jax.jit(jnp.tri) #22751

@jessegrabowski

Description

Jit-compiling jnp.tri appears to cause an underflow:

import jax
jax.jit(jax.numpy.tri)(3, 3, 0)
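
For reference, the eager (non-jitted) call works as expected; the result is shown as a comment:

import jax.numpy as jnp
jnp.tri(3, 3, 0)  # Array([[1., 0., 0.],
                  #        [1., 1., 0.],
                  #        [1., 1., 1.]], dtype=float32)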
Full traceback
ValueError                                Traceback (most recent call last)
Cell In[23], line 1
----> 1 jax.jit(jax.numpy.tri)(3,3,0)

    [... skipping hidden 11 frame]

File ~/mambaforge/envs/pytensor-dev/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py:4823, in tri(N, M, k, dtype)
   4821 M = M if M is not None else N
   4822 dtype = dtype or float32
-> 4823 return lax_internal._tri(dtype, (N, M), k)

    [... skipping hidden 4 frame]

File ~/mambaforge/envs/pytensor-dev/lib/python3.12/site-packages/jax/_src/numpy/reductions.py:540, in min(a, axis, out, keepdims, initial, where)
    471 def min(a: ArrayLike, axis: Axis = None, out: None = None,
    472         keepdims: bool = False, initial: ArrayLike | None = None,
    473         where: ArrayLike | None = None) -> Array:
    474   r"""Return the minimum of array elements along a given axis.
    475
    476   JAX implementation of :func:`numpy.min`.
   (...)
    538     Array([[0, 0, 0, 0]], dtype=int32)
    539   """
--> 540   return _reduce_min(a, axis=_ensure_optional_axes(axis), out=out,
    541                      keepdims=keepdims, initial=initial, where=where)

    [... skipping hidden 11 frame]

File ~/mambaforge/envs/pytensor-dev/lib/python3.12/site-packages/jax/_src/numpy/reductions.py:466, in _reduce_min(a, axis, out, keepdims, initial, where)
    462 @partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True)
    463 def _reduce_min(a: ArrayLike, axis: Axis = None, out: None = None,
    464                 keepdims: bool = False, initial: ArrayLike | None = None,
    465                 where: ArrayLike | None = None) -> Array:
--> 466   return _reduction(a, "min", np.min, lax.min, np.inf, has_identity=False,
    467                     axis=axis, out=out, keepdims=keepdims,
    468                     initial=initial, where_=where, parallel_reduce=lax.pmin)

File ~/mambaforge/envs/pytensor-dev/lib/python3.12/site-packages/jax/_src/numpy/reductions.py:111, in _reduction(a, name, np_fun, op, init_val, has_identity, preproc, bool_op, upcast_f16_for_computation, axis, dtype, out, keepdims, initial, where_, parallel_reduce, promote_integers)
    109 a = a if isinstance(a, Array) else lax_internal.asarray(a)
    110 a = preproc(a) if preproc else a
--> 111 pos_dims, dims = _reduction_dims(a, axis)
    113 if initial is None and not has_identity:
    114   shape = np.shape(a)

File ~/mambaforge/envs/pytensor-dev/lib/python3.12/site-packages/jax/_src/numpy/reductions.py:160, in _reduction_dims(a, axis)
    158 elif not isinstance(axis, (np.ndarray, tuple, list)):
    159   axis = (axis,)  # type: ignore[assignment]
--> 160 canon_axis = tuple(_canonicalize_axis_allow_named(x, np.ndim(a))
    161                    for x in axis)  # type: ignore[union-attr]
    162 if len(canon_axis) != len(set(canon_axis)):
    163   raise ValueError(f"duplicate value in 'axis': {axis}")

File ~/mambaforge/envs/pytensor-dev/lib/python3.12/site-packages/jax/_src/numpy/reductions.py:160, in <genexpr>(.0)
    158 elif not isinstance(axis, (np.ndarray, tuple, list)):
    159   axis = (axis,)  # type: ignore[assignment]
--> 160 canon_axis = tuple(_canonicalize_axis_allow_named(x, np.ndim(a))
    161                    for x in axis)  # type: ignore[union-attr]
    162 if len(canon_axis) != len(set(canon_axis)):
    163   raise ValueError(f"duplicate value in 'axis': {axis}")

File ~/mambaforge/envs/pytensor-dev/lib/python3.12/site-packages/jax/_src/numpy/reductions.py:153, in _canonicalize_axis_allow_named(x, rank)
    152 def _canonicalize_axis_allow_named(x, rank):
--> 153   return maybe_named_axis(x, lambda i: _canonicalize_axis(i, rank), lambda name: name)

    [... skipping hidden 1 frame]

File ~/mambaforge/envs/pytensor-dev/lib/python3.12/site-packages/jax/_src/numpy/reductions.py:153, in _canonicalize_axis_allow_named.<locals>.<lambda>(i)
    152 def _canonicalize_axis_allow_named(x, rank):
--> 153   return maybe_named_axis(x, lambda i: _canonicalize_axis(i, rank), lambda name: name)

File ~/mambaforge/envs/pytensor-dev/lib/python3.12/site-packages/jax/_src/util.py:389, in canonicalize_axis(axis, num_dims)
    387 axis = operator.index(axis)
    388 if not -num_dims <= axis < num_dims:
--> 389   raise ValueError(f"axis {axis} is out of bounds for array of dimension {num_dims}")
    390 if axis < 0:
    391   axis = axis + num_dims

ValueError: axis 2147483647 is out of bounds for array of dimension 0

This previously failed with a ConcretizationTypeError, which is (was?) the expected behavior for a function like this (since the output size is not static).
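
A workaround in the meantime (assuming the goal is simply to compile a call to jnp.tri with fixed sizes) is to mark the size arguments as static, so jit sees concrete Python ints rather than tracers:

import jax
import jax.numpy as jnp

# N, M, and k are declared static, so their concrete values are available at trace time.
jax.jit(jnp.tri, static_argnums=(0, 1, 2))(3, 3, 0)

This only sidesteps the problem; the report here is about the confusing error raised when the arguments are traced.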

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.31
jaxlib: 0.4.30
numpy:  2.0.1
python: 3.12.4 | packaged by conda-forge | (main, Jun 17 2024, 10:23:07) [GCC 12.3.0]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='', release='5.15.153.1-microsoft-standard-WSL2', version='#1 SMP Fri Mar 29 23:14:13 UTC 2024', machine='x86_64')


$ nvidia-smi
Tue Jul 30 19:42:09 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.76.01              Driver Version: 552.22         CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce GTX 1660 Ti     On  |   00000000:02:00.0  On |                  N/A |
| 31%   38C    P8              9W /  120W |    5768MiB /   6144MiB |      8%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A   1536557      C   /python3.11                                 N/A      |
+-----------------------------------------------------------------------------------------+
