-
Notifications
You must be signed in to change notification settings - Fork 3.1k
Closed
Labels
bug — Something isn't working
Description
JIT-compiling `jnp.tri` appears to cause an underflow:
import jax
jax.jit(jax.numpy.tri)(3, 3, 0)
Full traceback
ValueError Traceback (most recent call last)
Cell In[23], line 1
----> 1 jax.jit(jax.numpy.tri)(3,3,0)
[... skipping hidden 11 frame]
File ~/mambaforge/envs/pytensor-dev/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py:4823, in tri(N, M, k, dtype)
4821 M = M if M is not None else N
4822 dtype = dtype or float32
-> 4823 return lax_internal._tri(dtype, (N, M), k)
[... skipping hidden 4 frame]
File ~/mambaforge/envs/pytensor-dev/lib/python3.12/site-packages/jax/_src/numpy/reductions.py:540, in min(a, axis, out, keepdims, initial, where)
471 def min(a: ArrayLike, axis: Axis = None, out: None = None,
472 keepdims: bool = False, initial: ArrayLike | None = None,
473 where: ArrayLike | None = None) -> Array:
474 r"""Return the minimum of array elements along a given axis.
475
476 JAX implementation of :func:`numpy.min`.
(...)
538 Array([[0, 0, 0, 0]], dtype=int32)
539 """
--> 540 return _reduce_min(a, axis=_ensure_optional_axes(axis), out=out,
541 keepdims=keepdims, initial=initial, where=where)
[... skipping hidden 11 frame]
File ~/mambaforge/envs/pytensor-dev/lib/python3.12/site-packages/jax/_src/numpy/reductions.py:466, in _reduce_min(a, axis, out, keepdims, initial, where)
462 @partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True)
463 def _reduce_min(a: ArrayLike, axis: Axis = None, out: None = None,
464 keepdims: bool = False, initial: ArrayLike | None = None,
465 where: ArrayLike | None = None) -> Array:
--> 466 return _reduction(a, "min", np.min, lax.min, np.inf, has_identity=False,
467 axis=axis, out=out, keepdims=keepdims,
468 initial=initial, where_=where, parallel_reduce=lax.pmin)
File ~/mambaforge/envs/pytensor-dev/lib/python3.12/site-packages/jax/_src/numpy/reductions.py:111, in _reduction(a, name, np_fun, op, init_val, has_identity, preproc, bool_op, upcast_f16_for_computation, axis, dtype, out, keepdims, initial, where_, parallel_reduce, promote_integers)
109 a = a if isinstance(a, Array) else lax_internal.asarray(a)
110 a = preproc(a) if preproc else a
--> 111 pos_dims, dims = _reduction_dims(a, axis)
113 if initial is None and not has_identity:
114 shape = np.shape(a)
File ~/mambaforge/envs/pytensor-dev/lib/python3.12/site-packages/jax/_src/numpy/reductions.py:160, in _reduction_dims(a, axis)
158 elif not isinstance(axis, (np.ndarray, tuple, list)):
159 axis = (axis,) # type: ignore[assignment]
--> 160 canon_axis = tuple(_canonicalize_axis_allow_named(x, np.ndim(a))
161 for x in axis) # type: ignore[union-attr]
162 if len(canon_axis) != len(set(canon_axis)):
163 raise ValueError(f"duplicate value in 'axis': {axis}")
File ~/mambaforge/envs/pytensor-dev/lib/python3.12/site-packages/jax/_src/numpy/reductions.py:160, in <genexpr>(.0)
158 elif not isinstance(axis, (np.ndarray, tuple, list)):
159 axis = (axis,) # type: ignore[assignment]
--> 160 canon_axis = tuple(_canonicalize_axis_allow_named(x, np.ndim(a))
161 for x in axis) # type: ignore[union-attr]
162 if len(canon_axis) != len(set(canon_axis)):
163 raise ValueError(f"duplicate value in 'axis': {axis}")
File ~/mambaforge/envs/pytensor-dev/lib/python3.12/site-packages/jax/_src/numpy/reductions.py:153, in _canonicalize_axis_allow_named(x, rank)
152 def _canonicalize_axis_allow_named(x, rank):
--> 153 return maybe_named_axis(x, lambda i: _canonicalize_axis(i, rank), lambda name: name)
[... skipping hidden 1 frame]
File ~/mambaforge/envs/pytensor-dev/lib/python3.12/site-packages/jax/_src/numpy/reductions.py:153, in _canonicalize_axis_allow_named.<locals>.<lambda>(i)
152 def _canonicalize_axis_allow_named(x, rank):
--> 153 return maybe_named_axis(x, lambda i: _canonicalize_axis(i, rank), lambda name: name)
File ~/mambaforge/envs/pytensor-dev/lib/python3.12/site-packages/jax/_src/util.py:389, in canonicalize_axis(axis, num_dims)
387 axis = operator.index(axis)
388 if not -num_dims <= axis < num_dims:
--> 389 raise ValueError(f"axis {axis} is out of bounds for array of dimension {num_dims}")
390 if axis < 0:
391 axis = axis + num_dims
ValueError: axis 2147483647 is out of bounds for array of dimension 0
This previously failed with a `ConcretizationTypeError`, which is (or was?) the expected behavior for a function like this: under `jax.jit`, `N` and `M` are traced values, so the output shape is not static.
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.31
jaxlib: 0.4.30
numpy: 2.0.1
python: 3.12.4 | packaged by conda-forge | (main, Jun 17 2024, 10:23:07) [GCC 12.3.0]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='', release='5.15.153.1-microsoft-standard-WSL2', version='#1 SMP Fri Mar 29 23:14:13 UTC 2024', machine='x86_64')
$ nvidia-smi
Tue Jul 30 19:42:09 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.76.01 Driver Version: 552.22 CUDA Version: 12.4 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA GeForce GTX 1660 Ti On | 00000000:02:00.0 On | N/A |
| 31% 38C P8 9W / 120W | 5768MiB / 6144MiB | 8% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| 0 N/A N/A 1536557 C /python3.11 N/A |
+-----------------------------------------------------------------------------------------+
Metadata
Metadata
Assignees
Labels
bug — Something isn't working