Turn NUTS sampler into Theano Op #4210
Comments
I'm willing to take this one on if you can provide some guidance. Could be a good Christmas project.
Great! @brandonwillard is the person with the most mature vision on this.
This is a really important implementation, so, yeah, if you need any input from me, ask away. Otherwise, here's a template for one approach to converting our samplers to Theano The questions that arise within this task are mostly about where/when to "Theano-ize" the NUTS sampler. For instance, we could make one big NUTS The other end of the spectrum would have us rewriting the core functionality in One of the difficulties with this approach is that our samplers use some custom "helper" classes (e.g. Basically, Theano graphs model lambda expressions only, and Theano doesn't offer an OO abstraction on top of that, so one simply has to reformulate the OO-centric parts of the sampler implementations.
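To illustrate the kind of reformulation meant here, a minimal sketch in plain Python (not code from PyMC3 or Theano; `VarState`, `init_state`, `update` and `variance` are hypothetical names, and a running-variance accumulator is only a stand-in for whatever helper class is actually involved). The attributes a helper object would mutate become an explicit state value that a pure function takes in and returns, which is the shape a graph of lambda expressions can represent:

```python
from typing import NamedTuple


class VarState(NamedTuple):
    """Explicit state replacing the attributes of a hypothetical helper object."""
    count: int
    mean: float
    m2: float  # running sum of squared deviations from the mean


def init_state() -> VarState:
    """Starting state: no observations seen yet."""
    return VarState(count=0, mean=0.0, m2=0.0)


def update(state: VarState, x: float) -> VarState:
    """Welford update: returns a new state instead of mutating `self`."""
    count = state.count + 1
    delta = x - state.mean
    mean = state.mean + delta / count
    m2 = state.m2 + delta * (x - mean)
    return VarState(count, mean, m2)


def variance(state: VarState) -> float:
    """Unbiased sample variance; needs at least two observations."""
    return state.m2 / (state.count - 1)


# The loop only threads state through a pure function.
state = init_state()
for x in [1.0, 2.0, 3.0, 4.0]:
    state = update(state, x)
```

Because `update` has no side effects, the same logic could in principle be expressed as the step function of a symbolic loop, with `VarState` becoming a tuple of tensors.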
Ok so if we do the core rewrite approach - step one would be to go through
Yes, but, before that even, we might want to clarify exactly how this whole thing should work and what the inputs and outputs should be. For example, we could write a Once those inputs and their Theano types are clear, we can start to walk through the PyMC3 code and get an understanding of exactly what needs to be done. We can always take a cue from For that matter, we might be better off implementing a few simple Metropolis samplers first to see where the difficulties are (if any). Here's an old project that started doing just that. There's not much we can borrow from it except perhaps some examples of simple sampler loop logic (e.g. a Metropolis-Hastings sampler), but it's worth a look.
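For reference, the "simple sampler loop logic" of a Metropolis sampler can be sketched roughly as follows. This is a plain-Python illustration under stated assumptions, not the linked project's code or PyMC3's implementation: it handles a single scalar variable, uses a Gaussian random-walk proposal, and the names `mh_step` and `sample` are made up for this example. The step is written as a pure function of its inputs so that the inputs and outputs a Theano version would need are visible:

```python
import math
import random
from typing import Callable, List, Tuple


def mh_step(logp: Callable[[float], float], x: float, logp_x: float,
            scale: float, rng: random.Random) -> Tuple[float, float, bool]:
    """One random-walk Metropolis step: (state in) -> (state out)."""
    proposal = x + rng.gauss(0.0, scale)
    logp_prop = logp(proposal)
    # Symmetric proposal, so accept with probability min(1, p(proposal) / p(x)).
    log_u = math.log(1.0 - rng.random())  # 1 - U lies in (0, 1], so log is defined
    if log_u < logp_prop - logp_x:
        return proposal, logp_prop, True
    return x, logp_x, False


def sample(logp: Callable[[float], float], x0: float, n_draws: int,
           scale: float = 1.0, seed: int = 0) -> Tuple[List[float], float]:
    """Run the step in a loop; returns the draws and the acceptance rate."""
    rng = random.Random(seed)
    x, logp_x = x0, logp(x0)
    draws: List[float] = []
    n_accepted = 0
    for _ in range(n_draws):
        x, logp_x, accepted = mh_step(logp, x, logp_x, scale, rng)
        draws.append(x)
        n_accepted += accepted
    return draws, n_accepted / n_draws


# Usage: sample from a standard normal, given only its unnormalized log-density.
draws, accept_rate = sample(lambda x: -0.5 * x * x, x0=0.0, n_draws=20000)
```

The inputs here are the log-density, the current point with its cached log-density, a tuning parameter and a random stream; a Theano version would need typed equivalents of each.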
Makes sense. Since I'm not yet in a position to think critically about this, let me just get the requirements for the first step clear and I'll start working on it. What you're after here is
So just to be clear, the map that the sample function should take comprises
Having looked through
Delegating this to https://github.com/aesara-devs/aehmc
If we added a NUTS Op to Theano with implementations in JAX (e.g. from NumPyro) as well as C (here's an implementation: https://github.com/alumbreras/NUTS-Cpp, or we can use the Stan one directly), we could get incredible speeds across different execution backends.
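Whatever the backend, the numerical kernel a NUTS implementation repeats is the leapfrog integrator used by Hamiltonian Monte Carlo. A minimal sketch, assuming a single scalar position, unit mass, and a user-supplied gradient of the log-density (the function name `leapfrog` and its signature are chosen for this example and are not taken from Theano, NumPyro, Stan or NUTS-Cpp):

```python
from typing import Callable, Tuple


def leapfrog(q: float, p: float, grad_logp: Callable[[float], float],
             step_size: float, n_steps: int) -> Tuple[float, float]:
    """Integrate Hamiltonian dynamics for n_steps leapfrog steps (unit mass)."""
    # Initial half step for the momentum.
    p = p + 0.5 * step_size * grad_logp(q)
    for i in range(n_steps):
        # Full step for the position.
        q = q + step_size * p
        # Full momentum step, except after the last position update.
        if i < n_steps - 1:
            p = p + step_size * grad_logp(q)
    # Final half step for the momentum.
    p = p + 0.5 * step_size * grad_logp(q)
    return q, p


def energy(q: float, p: float) -> float:
    """Hamiltonian for a standard normal target: -logp(q) + kinetic energy."""
    return 0.5 * q * q + 0.5 * p * p


# Usage: standard normal target, whose log-density gradient is -q.
q0, p0 = 1.0, 0.0
q1, p1 = leapfrog(q0, p0, lambda q: -q, step_size=0.01, n_steps=100)
```

The integrator is time-reversible and nearly conserves the Hamiltonian, which is what a per-backend implementation would be checked against.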