-
Notifications
You must be signed in to change notification settings - Fork 15
Add rules and tests for MvNormalWeightedMeanPrecision node #370
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #370 +/- ##
==========================================
+ Coverage 62.10% 62.12% +0.01%
==========================================
Files 180 181 +1
Lines 5927 5930 +3
==========================================
+ Hits 3681 3684 +3
Misses 2246 2246 ☔ View full report in Codecov by Sentry. |
| m_mean, v_mean = mean_cov(q_ξ) | ||
| m_out, v_out = mean_cov(q_out) | ||
| return (ndims(q_out) * log2π - chollogdet(mean(q_Λ)) + tr(mean(q_Λ) * (v_out + v_mean + (m_out - m_mean) * (m_out - m_mean)'))) / 2 | ||
| return @views (ndims(q_out) * log2π - mean(logdet, q_Λ) + tr(mean(q_Λ) * (v_out + v_mean + (m_out - m_mean) * (m_out - m_mean)'))) / 2 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch! Would it make sense to add the method mean(chollogdet, ...) to BayesBase.jl? Then we can use the symmetry information to speed up this computation.
| m_mean, v_mean = mean_cov(q_ξ) | ||
| m_out, v_out = mean_cov(q_out) | ||
| return (ndims(q_out) * log2π - chollogdet(mean(q_Λ)) + tr(mean(q_Λ) * (v_out + v_mean + (m_out - m_mean) * (m_out - m_mean)'))) / 2 | ||
| return @views (ndims(q_out) * log2π - mean(logdet, q_Λ) + tr(mean(q_Λ) * (v_out + v_mean + (m_out - m_mean) * (m_out - m_mean)'))) / 2 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have some concerns as to whether this rule is actually correct (likely my bad). Just did the derivations again and seem to end up with
ndims(q_out) * log2π - mean(chollogdet, q_Λ) + tr(mean(q_Λ) * (v_out + v_mean + m_mean*m_mean'+ m_out * m_out')) - 2*m_out*m_mean ) / 2
I might have a mistake somewhere, but we better double check.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've changed it to a different form.
|
Something is completely wrong with this PR, I disagree with the change as for now. I also think that the tests might be incorrect. Lets discuss. But tldr, we cannot simply call average energy of the mean-precision parametrization, because semantics of the inputs is different. |
|
Thanks, @bvdmitri and @bartvanerp, for spotting the mistakes. It's a little more difficult than I thought. Anyway, just to re-iterate, I want to use this node as a prior mostly. |
Co-authored-by: Bart van Erp <[email protected]>
This PR encompasses the following contributions for the
MvNormalWeightedMeanPrecisionnode: