Reverse-mode AD extremely slow for large number of observations #1642
Comments
Try Essentially,
@torfjelde @mohamed82008 A side note, any reason that we shouldn't translate
@torfjelde thank you for the hint. I'm still a little confused: did you mean
which is blazingly fast but yields wrong values for \mu and \sigma, or
which seems to have a very similar runtime to the original (still waiting for the ETA)? Or is there another way to use filldist? I was a little confused by the documentation here.
@yebai dot broadcasting should be fast on observations, even GPU compatible most of the time. That is unless something changed recently.
@anhi try Zygote and ReverseDiff. You might have better luck with Zygote here because ReverseDiff's performance is a bit brittle depending on whether we use an array of tracked reals (slow) or a tracked array (fast).
I meant the latter, i.e. the one with Also, as @mohamed82008 pointed out this is for observations so my argument above doesn't actually hold.
It seems like he's using Zygote already?
Yes my bad, didn't see this. So try ReverseDiff then :)
One thing I've noticed in the past: if the function being
:) ok, I'll try... The ETA for Zygote was ~22 hours, btw, while ForwardDiff with filldist took 61 seconds (a little longer than without filldist, which was ~42 seconds).
Btw, one thing you can do if you want to go really fast, is to use

```julia
logx = log.(x)
zval = @. (logx - μ) / σ # `StatsFuns.normlogpdf(μ, σ, x)` has an if-statement in it, so we circumvent this by computing the `zval` ourselves.
@addlogprob! sum(StatsFuns.normlogpdf.(zval)) - sum(logx)
```

which should be the same (maybe check this though), assuming you've done […]. This should be muuuch faster using Zygote.
ok, this is getting close...
I've also changed the TruncatedNormals to Normals because I was not sure if they contain ifs as well...
so this is indeed much faster than the standard LogNormal implementation. I'm running the same experiment with the TruncatedNormals again, and it seems similarly fast. ReverseDiff seems to have similar problems as Zygote, and similar timings. But I'll try that again later.
Using TruncatedNormals, the sampling took 268 seconds, which indeed looks much better than the several days I started out with. However, I just noticed that the sampling returns wrong values for \sigma. The data was generated with \mu = 1.5 and \sigma = 0.5, but sampling with non-truncated normals returned a \sigma of 41, the TruncatedNormal version a \sigma of 10. \mu was close to 1.5 in both cases...
Probably unnecessary. This is specifically a problem when you're doing
Yeah sorry, this is because I made a mistake in the above (told yah it needed some checking 😅). It should be `sum(StatsFuns.normlogpdf.(zval)) - sum(logx) - sum(log.(σ))`. I forgot the log-abs-det-jacobian term from computing the
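The corrected expression can be checked numerically against `Distributions.jl`. The snippet below is only a sketch under stated assumptions: the data size, seed and the scalar `μ`/`σ` values are made up for illustration (the values 1.5 and 0.5 are taken from the thread). One thing worth noting: if `σ` is a scalar, as the simple example seems to suggest, the `log(σ)` term enters once per observation, so it would be `length(x) * log(σ)`; `sum(log.(σ))` gives the same thing only when `σ` is a vector with one entry per observation.

```julia
using Distributions, StatsFuns, Random

# Hypothetical data and parameters, chosen only for illustration.
rng = MersenneTwister(1)
μ, σ = 1.5, 0.5
x = rand(rng, LogNormal(μ, σ), 1_000)

# Reference value: summed LogNormal log-density from Distributions.jl.
d = LogNormal(μ, σ)
reference = sum(xi -> logpdf(d, xi), x)

# Manual version from the thread, including the log σ term.
logx = log.(x)
zval = @. (logx - μ) / σ
# Scalar σ: -log(σ) appears once per observation.
manual = sum(normlogpdf.(zval)) - sum(logx) - length(x) * log(σ)

# Without the log σ term the two values differ by length(x) * log(σ).
without_term = sum(normlogpdf.(zval)) - sum(logx)

println((reference = reference, manual = manual, without_term = without_term))
```

If `reference` and `manual` agree up to floating-point error, the hand-written log-density is the same as the `LogNormal` one, and the missing term explains why `σ` was estimated badly while `μ` looked fine.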
Closed in favour of #1934
When trying to optimize our Turing code, we experimented with the different AD engines. It seems as if the reverse-mode AD engines are extremely slow for large numbers of observations. Our original model has several hundred dimensions, but the effect can be demonstrated on this simple example:
On my machine, this takes about 42 seconds:
41.868502 seconds (21.57 M allocations: 1.340 GiB, 1.49% gc time, 18.19% compilation time)
I understand that for such a simple model, forward-mode should be more efficient. But when switching to reverse-mode
it takes several hours on my machine to even arrive at an ETA, which starts out at several days. This seems a little excessive.
Are we doing anything wrong, or is reverse-mode just not useable for large numbers of observations?
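For reference, a minimal sketch of the vectorised formulation suggested in the comments, written as a complete model. This is not the model from this issue: the model name `lognormal_manual`, the priors, the data size and the number of samples are all hypothetical, and the sampler runs with whatever AD backend the installed Turing version uses by default. It assumes scalar `μ` and `σ`.

```julia
using Turing, StatsFuns, Random, Statistics

# Hypothetical model: scalar μ and σ, LogNormal observations,
# with the log-likelihood added by hand in vectorised form.
@model function lognormal_manual(x)
    μ ~ Normal(0, 5)                      # hypothetical prior
    σ ~ truncated(Normal(0, 5), 0, Inf)   # hypothetical prior, σ > 0
    logx = log.(x)
    zval = (logx .- μ) ./ σ
    # LogNormal log-density: standard-normal term, minus log(x),
    # minus log(σ) once per observation.
    Turing.@addlogprob! sum(normlogpdf.(zval)) - sum(logx) - length(x) * log(σ)
end

# Synthetic data with the parameter values mentioned in the thread.
rng = MersenneTwister(42)
x = rand(rng, LogNormal(1.5, 0.5), 1_000)

chain = sample(rng, lognormal_manual(x), NUTS(), 500; progress = false)
μ_hat = mean(chain[:μ])
σ_hat = mean(chain[:σ])
println((μ_hat = μ_hat, σ_hat = σ_hat))
```

Because the likelihood is a handful of broadcasts and reductions rather than one `~` statement per observation, a reverse-mode backend has far fewer operations to trace, which is the point made in the comments.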