Description
Samplers with multiple steps, like Slice, could be written to loop in PyTensor. On the Numba/JAX backends this should be much faster for models with a cheap logp, since we can skip Python altogether.
Simple step samplers like Metropolis could also benefit from this optimization when used alone (not blocked with other step samplers).
Historically this didn't make much sense because Scan in the C backend is quite slow anyway. JAX and Numba fare much better, although there is still room for improvement. This may only become relevant after #7926.
We could also write them in Rust, which, as nutpie showed, can interoperate well with Numba.