AdvancedSubtensor error with JAX sampler #4431
When using the JAX-based NumPyro NUTS sampler, the following error is produced:
MissingInputError: Input 1 of the graph (indices start from 0), used to compute AdvancedSubtensor1(beta_mat ~ Deterministic, <TensorType(int32, vector)>), was not provided and not given a value. Use the Theano flag exception_verbosity='high', for more information on this error.
This error does not occur with the standard (non-JAX) sampler.
Tested with PyMC3 3.10.0 and 3.11.0.
Comments
Thanks! Can you post a minimal example to reproduce the error?
I'm working on generic, shareable code that triggers the same error. Unfortunately, I can't share the code that originally produced it.
@muunetheus Yes, shared variables are currently not supported in JAX; that is something we need to fix.
I think random sampling was the only situation in which shared variables didn't work; the other cases should work. We definitely need a minimal working example (MWE).
If the only remaining issue involves shared variables, then this issue is a duplicate of #4142.
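For context, AdvancedSubtensor1, the op named in the MissingInputError, is Theano's indexing of a tensor's first axis by an integer vector, the same operation as NumPy's `x[idx]` with a 1-D integer array. The NumPy-only sketch below illustrates what that graph node computes. The values of `beta_mat` and the name and contents of `group_idx` are hypothetical stand-ins, not the reporter's model. In the failing graph, Input 1 is the int32 index vector; the discussion suggests it is probably a Theano shared variable that the JAX conversion does not supply, but that has not been confirmed without an MWE.

```python
import numpy as np

# Hypothetical stand-ins (not from the original report):
# beta_mat holds one row of coefficients per group.
beta_mat = np.array([[0.1, 0.2],
                     [1.0, 2.0],
                     [5.0, 6.0]])

# int32 index vector, playing the role of "Input 1" in the error message.
group_idx = np.array([2, 0, 2, 1], dtype=np.int32)

# Integer-vector indexing along axis 0: the NumPy analogue of AdvancedSubtensor1.
# Each entry of group_idx selects one row of beta_mat.
per_obs_beta = beta_mat[group_idx]

print(per_obs_beta.shape)  # (4, 2)
```

The result has one row per index entry, which is why the node cannot be evaluated when the index vector is missing from the compiled graph.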