The solution is to either revert 4235ccc or tweak the JAX Broadcast funcify to work with scalars.
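As an illustration of the second option, here is a minimal sketch of a broadcast that tolerates scalars. It assumes, as the issue title suggests, that the failure comes from a constant 0d array being downcast to a plain Python float/int before it reaches the broadcast. broadcast_to_scalar_safe is a hypothetical name, not the actual PyTensor JAX dispatch. NumPy stands in for jax.numpy (which also provides asarray and broadcast_to) so the sketch runs without JAX.

```python
import numpy as np


def broadcast_to_scalar_safe(x, shape):
    """Hypothetical BroadcastTo implementation that also accepts Python scalars.

    NumPy stands in for jax.numpy here; both provide asarray and broadcast_to.
    """
    # A constant 0d array that was downcast upstream arrives as a plain
    # int/float without .shape or .dtype; asarray turns it back into a 0d array.
    x = np.asarray(x)
    # The target shape may arrive as a tuple or as an array; normalize it
    # to a tuple of Python ints. An empty shape () means a 0d result.
    shape = tuple(int(s) for s in np.asarray(shape).reshape(-1))
    return np.broadcast_to(x, shape)


print(broadcast_to_scalar_safe(0.0, ()).shape)  # ()
print(broadcast_to_scalar_safe(2, (3,)))        # [2 2 2]
```

Normalizing the input with asarray keeps the fix local to one dispatch function, at the cost of every such function having to remember to do it.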
The source of the problem is actually silly: we have a BroadcastTo(np.array(0), np.array(0).shape) in the graph that should be optimized away. I think this specific case may no longer arise after #361.
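Broadcasting a 0d array to its own shape is a no-op, which is why such a node can be removed from the graph. The NumPy check below illustrates this. is_useless_broadcast is a hypothetical helper that only compares concrete shapes; a real graph rewrite would have to reason about symbolic shapes instead.

```python
import numpy as np

x = np.array(0)
target_shape = x.shape  # a 0d array has the empty shape ()

# Broadcasting to the input's own shape returns an equal array:
# same shape, same dtype, same values.
y = np.broadcast_to(x, target_shape)
print(target_shape, y.shape, y.dtype == x.dtype)  # () () True


def is_useless_broadcast(input_shape, target_shape):
    """Hypothetical rewrite condition: the broadcast changes nothing, so drop it."""
    return tuple(input_shape) == tuple(target_shape)


print(is_useless_broadcast(x.shape, target_shape))  # True
print(is_useless_broadcast((), (3,)))               # False
```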
I'm not opening a PR right away because we should probably first discuss whether we want to implicitly downcast constant 0d arrays to floats/integers in JAX or keep types consistent. Alternatively, we could tweak the specific Op dispatch functions to handle a constant TensorScalarVariable in a special way.
This touches on the more general question of how to handle scalars in our graphs, which also applies to other backends. See #107 and #349.
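The two options for constants (implicitly downcasting a 0d array vs. keeping types consistent) can be contrasted with a short NumPy sketch. as_backend_constant and its downcast_scalars flag are hypothetical, not an existing PyTensor API; they only illustrate what is lost when a constant 0d array becomes a Python scalar.

```python
import numpy as np


def as_backend_constant(value, downcast_scalars):
    """Hypothetical policy hook for materializing a constant in a backend."""
    arr = np.asarray(value)
    if downcast_scalars and arr.ndim == 0:
        # Implicit downcast: a Python int/float with no .shape and no dtype.
        return arr.item()
    # Keep types consistent: a 0d ndarray that preserves .shape and dtype.
    return arr


const = np.array(0, dtype="float32")
kept = as_backend_constant(const, downcast_scalars=False)
downcast = as_backend_constant(const, downcast_scalars=True)

print(type(kept).__name__, kept.shape, kept.dtype)          # ndarray () float32
print(type(downcast).__name__, hasattr(downcast, "shape"))  # float False
```

Downcasting also drops the dtype: a float32 constant becomes a Python float, which is double precision. Any Op implementation that reads .shape or .dtype from its inputs would then need special handling for such constants.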
ricardoV94 changed the title on Jul 7, 2023
from: BUG: Forward sampling fails for half-normal distribution when mode="JAX"
to: HalfNormal in JAX failing due to implicit downcasting of constant 0d TensorVariable to float
Describe the issue:
Forward sampling from a half-normal distribution fails in JAX mode.
Reproducible code example:
Error message:
PyMC version information:
Context for the issue:
No response