diff --git a/docs/contributors.rst b/docs/contributors.rst index 9554373417..6e134e381f 100644 --- a/docs/contributors.rst +++ b/docs/contributors.rst @@ -36,3 +36,4 @@ Contributors include: - Lorenz Gaertner - Melissa Weber Mendonça - Matthias Bussonnier +- Peter Fackeldey diff --git a/src/pyhf/tensor/jax_backend.py b/src/pyhf/tensor/jax_backend.py index 2a85006c04..e616302446 100644 --- a/src/pyhf/tensor/jax_backend.py +++ b/src/pyhf/tensor/jax_backend.py @@ -2,6 +2,7 @@ config.update('jax_enable_x64', True) +from jax.core import Tracer from jax import Array import jax.numpy as jnp from jax.scipy.special import gammaln, xlogy @@ -14,6 +15,13 @@ log = logging.getLogger(__name__) +def _currently_jitting(): + """ + JAX turns arrays into Tracers during jit-compilation, so check for that. + """ + return isinstance(jnp.array(1), Tracer) + + class _BasicPoisson: def __init__(self, rate): self.rate = rate @@ -184,6 +192,9 @@ def conditional(self, predicate, true_callable, false_callable): return true_callable() if predicate else false_callable() def tolist(self, tensor_in): + if _currently_jitting(): + # .aval is the abstract value and has a little nicer representation + return tensor_in.aval try: return jnp.asarray(tensor_in).tolist() except (TypeError, ValueError): diff --git a/tests/test_backends.py b/tests/test_backends.py index 518a8a759b..dbde580cc1 100644 --- a/tests/test_backends.py +++ b/tests/test_backends.py @@ -84,3 +84,21 @@ def test_backend_array_type(backend): def test_tensor_array_types(): # can't really assert the content of them so easily assert pyhf.tensor.array_types + + +@pytest.mark.only_jax +def test_jax_data_shape_mismatch_during_jitting(backend): + """ + Validate that during JAX tracing time pyhf doesn't try + to convert the data to a list, which is not possible with tracers, + for a shape mismatch. + Instead, return the tracer itself for a proper error message. + Issue: https://github.com/scikit-hep/pyhf/issues/1422 + PR: https://github.com/scikit-hep/pyhf/pull/2580 + """ + model = pyhf.simplemodels.uncorrelated_background([10], [15], [5]) + with pytest.raises( + pyhf.exceptions.InvalidPdfData, + match="eval failed as data has len 1 but 2 was expected", + ): + pyhf.infer.mle.fit([12.5], model)