diff --git a/src/pyhf/tensor/jax_backend.py b/src/pyhf/tensor/jax_backend.py index 2a85006c04..855aa93f6f 100644 --- a/src/pyhf/tensor/jax_backend.py +++ b/src/pyhf/tensor/jax_backend.py @@ -157,6 +157,8 @@ def tile(self, tensor_in, repeats): Returns: JAX ndarray: The tensor with repeated axes """ + if not isinstance(tensor_in, jnp.ndarray): + tensor_in = jnp.array(tensor_in) return jnp.tile(tensor_in, repeats) def conditional(self, predicate, true_callable, false_callable):