>>> import pyhf
>>> pyhf.set_backend("jax")
>>> m = pyhf.simplemodels.hepdata_like([10], [15], [5])
>>> pyhf.infer.mle.fit([12.5], m)
crashes with the following error, whose message includes a possible hint:
This error can occur when a JAX Tracer object is passed to a raw numpy function, or a method on a numpy.ndarray object. You might want to check that you are using jnp together with import jax.numpy as jnp rather than using np via import numpy as np. If this error arises on a line that involves array indexing, like x[idx], it may be that the array being indexed x is a raw numpy.ndarray while the indices idx are a JAX Tracer instance; in that case, you can instead write jax.device_put(x)[idx].
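To make the hint concrete, here is a minimal standalone sketch that does not use pyhf. The function names `bad_sum`, `good_sum`, `index_into_raw` and `fails_under_jit` are hypothetical, chosen only for illustration. It is an assumption, not something confirmed above, that the pyhf crash comes from this same pattern somewhere inside the fit. Under `jax.jit`, arguments become Tracers. Handing a Tracer to raw NumPy (`np.asarray`) fails, while the `jax.numpy` equivalent works. Indexing a raw `numpy.ndarray` with a traced index works once the array is moved onto the device with `jax.device_put`, as the message suggests.

```python
import numpy as np
import jax
import jax.numpy as jnp

# A raw numpy.ndarray, analogous to a constant captured from outside a jitted function.
x = np.arange(5.0)


@jax.jit
def bad_sum(v):
    # Under jit, v is a Tracer; np.asarray tries to turn it into a concrete
    # numpy array and raises a JAX error (a TypeError subclass).
    return np.asarray(v).sum()


@jax.jit
def good_sum(v):
    # jnp functions accept Tracers, so this traces and compiles fine.
    return jnp.sum(v)


@jax.jit
def index_into_raw(idx):
    # idx is a Tracer here; moving x onto the device first, as the error
    # message suggests, makes the indexing traceable.
    return jax.device_put(x)[idx]


def fails_under_jit(fn, *args):
    """Return True if calling the jitted fn raises a TypeError while tracing."""
    try:
        fn(*args)
    except TypeError:
        return True
    return False


if __name__ == "__main__":
    print(fails_under_jit(bad_sum, x))  # the raw-numpy version fails
    print(float(good_sum(x)))           # sum of 0..4
    print(float(index_into_raw(2)))     # element at index 2
```

The try/except in `fails_under_jit` only catches `TypeError`, which is what JAX's Tracer conversion errors derive from. A different failure mode would propagate unchanged.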