I was trying to compute a cumulative minimum when I realized that it isn't part of the main jax.numpy API, even though NumPy provides it (np.minimum.accumulate).
For now, I've worked around it with jax.lax.associative_scan:
```python
import jax
import jax.numpy as jnp
import numpy as np

key = jax.random.PRNGKey(0)
values = jax.random.normal(key, shape=(10,))

# JAX workaround: scan jnp.minimum over the array.
jax.lax.associative_scan(jnp.minimum, values)
# DeviceArray([-0.3721109, -0.3721109, -0.3721109, -0.7368197, -0.7368197,
#              -0.7368197, -0.7368197, -0.7368197, -0.7368197, -0.7368197], dtype=float32)

# NumPy API:
np.minimum.accumulate(values)
# array([-0.3721109, -0.3721109, -0.3721109, -0.7368197, -0.7368197,
#        -0.7368197, -0.7368197, -0.7368197, -0.7368197, -0.7368197],
#       dtype=float32)
```
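The two calls agree because a cumulative minimum is a prefix scan over an associative operator. Writing $x_i$ for the inputs and $y_i$ for the outputs:

$$
y_0 = x_0, \qquad y_i = \min(x_0, x_1, \dots, x_i) = \min(y_{i-1}, x_i),
$$

and because $\min(\min(a, b), c) = \min(a, \min(b, c))$, the prefixes can be grouped in any order. Associativity is the property jax.lax.associative_scan requires of its combining function, so using jnp.minimum there yields the same result as NumPy's sequential accumulate.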
So I have a working workaround, but it is not as clean as the NumPy API.
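A small wrapper can make the workaround read closer to the NumPy call. This is a minimal sketch, assuming a JAX version where jax.lax.associative_scan accepts the axis and reverse keywords. The helper name cummin is hypothetical and is not part of jax.numpy. Recent JAX releases also expose jax.lax.cummin, which is used here only as a cross-check.

```python
import jax
import jax.numpy as jnp
import numpy as np


def cummin(x, axis=0, reverse=False):
    """Hypothetical helper (not a jax.numpy API): cumulative minimum along `axis`.

    jnp.minimum is associative, so jax.lax.associative_scan can combine
    prefixes in parallel and still match a sequential running minimum.
    """
    x = jnp.asarray(x)
    axis = axis % x.ndim  # normalise negative axes, e.g. -1 -> last axis
    return jax.lax.associative_scan(jnp.minimum, x, reverse=reverse, axis=axis)


key = jax.random.PRNGKey(0)
values = jax.random.normal(key, shape=(3, 4))

# Check the helper against NumPy's ufunc accumulate along every axis.
for ax in (0, 1, -1):
    expected = np.minimum.accumulate(np.asarray(values), axis=ax)
    np.testing.assert_array_equal(np.asarray(cummin(values, axis=ax)), expected)

# Cross-check against the built-in lax primitive (available in recent JAX).
np.testing.assert_array_equal(
    np.asarray(cummin(values, axis=1)), np.asarray(jax.lax.cummin(values, axis=1))
)
```

Because the helper only fixes the combining function, swapping jnp.minimum for jnp.maximum gives a cumulative maximum in the same way.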
Is this perhaps related to #1565?