
Implement jax.numpy.<fun>.accumulate #11281

Closed
@joeryjoery


I was trying to compute a cumulative minimum when I realized that this isn't part of the main `jax.numpy` API, even though NumPy supports it through ufunc methods such as `np.minimum.accumulate`.
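For the cumulative minimum specifically, here is a minimal sketch: `jax.lax` has a dedicated `jax.lax.cummin` primitive (alongside `cummax`, `cumsum` and `cumprod`) that computes a running reduction along one axis. It is not exposed as a `jnp.minimum.accumulate` method. The input array below is made up for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Illustrative input (not from the original report).
x = jnp.array([3.0, 1.0, 2.0, 0.5, 4.0])

# jax.lax.cummin computes a running minimum along `axis`;
# pass reverse=True to scan from the end instead.
running_min = jax.lax.cummin(x, axis=0)

# Should agree with NumPy's ufunc method.
assert np.allclose(running_min, np.minimum.accumulate(np.asarray(x)))
```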

For now I've managed to simply swap it out with:

```python
import jax
import jax.numpy as jnp
import numpy as np

key = jax.random.PRNGKey(0)
values = jax.random.normal(key, shape=(10,))

# Workaround: an associative scan with jnp.minimum yields the running minimum.
jax.lax.associative_scan(jnp.minimum, values)
# DeviceArray([-0.3721109, -0.3721109, -0.3721109, -0.7368197, -0.7368197,
#              -0.7368197, -0.7368197, -0.7368197, -0.7368197, -0.7368197], dtype=float32)

# NumPy API:
np.minimum.accumulate(values)
# array([-0.3721109, -0.3721109, -0.3721109, -0.7368197, -0.7368197,
#        -0.7368197, -0.7368197, -0.7368197, -0.7368197, -0.7368197],
#       dtype=float32)
```

So I have a working workaround, but it is not as clean as the NumPy API.
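The workaround generalizes to other binary functions. Below is a minimal sketch of a hypothetical `accumulate` helper (not part of JAX) that mimics the `ufunc.accumulate(array, axis=...)` call shape on top of `jax.lax.associative_scan`. It assumes the function is elementwise and associative. For a non-associative function such as `jnp.subtract`, the parallel scan would not match NumPy's strictly left-to-right result.

```python
import jax
import jax.numpy as jnp
import numpy as np


def accumulate(binary_fn, x, axis=0):
    """Hypothetical stand-in for NumPy's `ufunc.accumulate` (not a JAX API).

    Assumes `binary_fn` is elementwise and associative, e.g. jnp.minimum,
    jnp.maximum, jnp.add or jnp.multiply.
    """
    x = jnp.asarray(x)
    axis = axis % x.ndim  # normalize negative axes before scanning
    return jax.lax.associative_scan(binary_fn, x, axis=axis)


key = jax.random.PRNGKey(0)
values = jax.random.normal(key, shape=(4, 5))

# Running minimum along each row, checked against NumPy's ufunc method.
ours = accumulate(jnp.minimum, values, axis=1)
ref = np.minimum.accumulate(np.asarray(values), axis=1)
assert np.allclose(ours, ref)

# Running product down the columns (tolerance covers float32 reordering).
prod = accumulate(jnp.multiply, values, axis=0)
assert np.allclose(prod, np.multiply.accumulate(np.asarray(values), axis=0), rtol=1e-5)
```

Because `binary_fn` and `axis` are plain Python values, they would need to be marked static (e.g. via `static_argnums`) if the helper were wrapped in `jax.jit`.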

Is this perhaps related to #1565?

Labels: enhancement (New feature or request)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions