Description
Summary
Hello; perhaps this is already known, but I thought I'd file a bug report just in case. I was testing the upper_limits API and discovered that the toms748_scan example given in the documentation doesn't seem to work with the JAX backend. It fails with a complaint about an unhashable array type (see the traceback). If I switch to the numpy backend, as shown in the documentation, it runs fine.
I see this both on EL7 in an ATLAS environment (StatAnalysis,0.3,latest) and on my own desktop (Fedora 38). In both cases I have the same pyhf version (0.7.6), and I manually installed jax[CPU] == 0.4.26 on top of that.
I should add that things work fine with JAX if I use the version of upper_limits where I pass in a range of mu values to scan, so I guess some extra type conversion may be needed to go from the JAX array type to a Python float or some other hashable type?
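To make the hashing problem concrete, here is a minimal sketch of a dict-based memoization pattern like the one in the traceback (if poi not in cache:). The helpers make_cached and expensive are hypothetical, and a 0-d NumPy array stands in for the JAX array, since neither type is hashable. It suggests that keying the cache on float(poi) avoids the TypeError; whether pyhf should fix it exactly this way is an assumption, not a confirmed patch.

```python
import numpy as np


# Hypothetical stand-in for the memoized helper in upper_limits.py:
# results are cached in a dict keyed by (a function of) the POI value.
def make_cached(func, key=lambda poi: poi):
    cache = {}

    def f_cached(poi):
        k = key(poi)
        if k not in cache:  # hashes k, so this fails for array-valued k
            cache[k] = func(poi)
        return cache[k]

    return f_cached, cache


def expensive(poi):
    # Placeholder for the real (hypothetical here) hypotest call.
    return float(poi) ** 2


# A 0-d NumPy array stands in for the JAX array: neither type is hashable.
poi = np.asarray(1.5)

naive, _ = make_cached(expensive)
try:
    naive(poi)
except TypeError as err:
    print(err)  # unhashable type: 'numpy.ndarray'

# Converting to a Python float before the lookup gives a hashable key.
fixed, cache = make_cached(expensive, key=float)
print(fixed(poi))   # 2.25
print(list(cache))  # [1.5]
```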
OS / Environment
# Linux
$ cat /etc/os-release
NAME="Fedora Linux"
VERSION="38 (Thirty Eight)"
ID=fedora
VERSION_ID=38
VERSION_CODENAME=""
PLATFORM_ID="platform:f38"
PRETTY_NAME="Fedora Linux 38 (Thirty Eight)"
ANSI_COLOR="0;38;2;60;110;180"
LOGO=fedora-logo-icon
CPE_NAME="cpe:/o:fedoraproject:fedora:38"
DEFAULT_HOSTNAME="fedora"
HOME_URL="https://fedoraproject.org/"
DOCUMENTATION_URL="https://docs.fedoraproject.org/en-US/fedora/f38/system-administrators-guide/"
SUPPORT_URL="https://ask.fedoraproject.org/"
BUG_REPORT_URL="https://bugzilla.redhat.com/"
REDHAT_BUGZILLA_PRODUCT="Fedora"
REDHAT_BUGZILLA_PRODUCT_VERSION=38
REDHAT_SUPPORT_PRODUCT="Fedora"
REDHAT_SUPPORT_PRODUCT_VERSION=38
SUPPORT_END=2024-05-14
Steps to Reproduce
Install pyhf and JAX through pip; then try to run the example in the documentation, but with the JAX backend instead of numpy:
import numpy as np
import pyhf
pyhf.set_backend("JAX")
model = pyhf.simplemodels.uncorrelated_background(
    signal=[12.0, 11.0], bkg=[50.0, 52.0], bkg_uncertainty=[3.0, 7.0]
)
observations = [51, 48]
data = pyhf.tensorlib.astensor(observations + model.config.auxdata)
obs_limit, exp_limits = pyhf.infer.intervals.upper_limits.toms748_scan(
    data, model, 0., 5., rtol=0.01
)
Expected Results
Ideally the example would run without crashing (as it does with the numpy backend).
Actual Results
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/bjr/.local/lib/python3.11/site-packages/pyhf/infer/intervals/upper_limits.py", line 130, in toms748_scan
    toms748(f, bounds_low, bounds_up, args=(level, 0), k=2, xtol=atol, rtol=rtol)
  File "/usr/lib64/python3.11/site-packages/scipy/optimize/_zeros_py.py", line 1374, in toms748
    result = solver.solve(f, a, b, args=args, k=k, xtol=xtol, rtol=rtol,
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib64/python3.11/site-packages/scipy/optimize/_zeros_py.py", line 1229, in solve
    fc = self._callf(c)
         ^^^^^^^^^^^^^^
  File "/usr/lib64/python3.11/site-packages/scipy/optimize/_zeros_py.py", line 1083, in _callf
    fx = self.f(x, *self.args)
         ^^^^^^^^^^^^^^^^^^^^^
  File "/home/bjr/.local/lib/python3.11/site-packages/pyhf/infer/intervals/upper_limits.py", line 95, in f
    f_cached(poi)[0] - level
    ^^^^^^^^^^^^^
  File "/home/bjr/.local/lib/python3.11/site-packages/pyhf/infer/intervals/upper_limits.py", line 80, in f_cached
    if poi not in cache:
       ^^^^^^^^^^^^^^^^
TypeError: unhashable type: 'jaxlib.xla_extension.ArrayImpl'
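One possible reason the numpy backend is unaffected, sketched under the assumption that scipy's toms748 builds each new trial point from arithmetic on earlier return values of f (as the solve and _callf frames suggest): with NumPy, subtracting a Python float from a 0-d array returns a NumPy scalar (np.float64), which is hashable, so trial points stay usable as dict keys. JAX arithmetic instead returns a JAX array, which is the ArrayImpl that reaches poi in the last frame. The secant step below is a simplified stand-in for toms748's interpolation, not its actual update rule, and the CLs values are made up.

```python
import numpy as np

level = 0.05

# With the NumPy backend a CLs value is assumed to come back as a 0-d array;
# subtracting a Python float yields a NumPy scalar (np.float64).
fa = np.asarray(0.40) - level
fb = np.asarray(0.01) - level
print(type(fa).__name__)  # float64

# A secant step (simplified stand-in for toms748's interpolation) built
# from these values is also a NumPy scalar, so it works as a dict key.
a, b = 0.0, 5.0
c = b - fb * (b - a) / (fb - fa)
cache = {c: "ok"}
print(isinstance(c, np.floating), hash(c) == hash(float(c)))  # True True
```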
pyhf Version
$ pyhf --version
pyhf, version 0.7.6
Code of Conduct
- I agree to follow the Code of Conduct