Skip to content

Conversation

@albertpod
Copy link
Member

This PR encompasses the following contributions for the MvNormalWeightedMeanPrecision node:

  1. Fix for average energy.
  2. Addition of rules for the out interface (particularly useful for priors).
  3. Inclusion of missing tests.

@albertpod albertpod requested a review from bartvanerp January 5, 2024 14:05
@codecov
Copy link

codecov bot commented Jan 5, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Comparison is base (f4835c2) 62.10% compared to head (b3444a2) 62.12%.
Report is 4 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #370      +/-   ##
==========================================
+ Coverage   62.10%   62.12%   +0.01%     
==========================================
  Files         180      181       +1     
  Lines        5927     5930       +3     
==========================================
+ Hits         3681     3684       +3     
  Misses       2246     2246              

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

m_mean, v_mean = mean_cov(q_ξ)
m_out, v_out = mean_cov(q_out)
return (ndims(q_out) * log2π - chollogdet(mean(q_Λ)) + tr(mean(q_Λ) * (v_out + v_mean + (m_out - m_mean) * (m_out - m_mean)'))) / 2
return @views (ndims(q_out) * log2π - mean(logdet, q_Λ) + tr(mean(q_Λ) * (v_out + v_mean + (m_out - m_mean) * (m_out - m_mean)'))) / 2
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch! Would it make sense to add the method mean(chollogdet, ...) to BayesBase.jl? Then we can use the symmetry information to speed up this computation.

m_mean, v_mean = mean_cov(q_ξ)
m_out, v_out = mean_cov(q_out)
return (ndims(q_out) * log2π - chollogdet(mean(q_Λ)) + tr(mean(q_Λ) * (v_out + v_mean + (m_out - m_mean) * (m_out - m_mean)'))) / 2
return @views (ndims(q_out) * log2π - mean(logdet, q_Λ) + tr(mean(q_Λ) * (v_out + v_mean + (m_out - m_mean) * (m_out - m_mean)'))) / 2
Copy link
Member

@bartvanerp bartvanerp Jan 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have some concerns as to whether this rule is actually correct (likely my bad). Just did the derivations again and seem to end up with
ndims(q_out) * log2π - mean(chollogdet, q_Λ) + tr(mean(q_Λ) * (v_out + v_mean + m_mean*m_mean'+ m_out * m_out')) - 2*m_out*m_mean ) / 2
I might have a mistake somewhere, but we better double check.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've changed it to a different form.

@bvdmitri
Copy link
Member

bvdmitri commented Jan 8, 2024

Something is completely wrong with this PR, I disagree with the change as for now. I also think that the tests might be incorrect. Lets discuss. But tldr, we cannot simply call average energy of the mean-precision parametrization, because semantics of the inputs is different.

@albertpod albertpod marked this pull request as draft January 8, 2024 13:08
@albertpod
Copy link
Member Author

Thanks, @bvdmitri and @bartvanerp, for spotting the mistakes. It's a little more difficult than I thought.

Anyway, just to re-iterate, I want to use this node as a prior mostly.

@albertpod albertpod marked this pull request as ready for review January 8, 2024 19:19
@bvdmitri bvdmitri merged commit 8b85e8a into main Jan 9, 2024
@bvdmitri bvdmitri deleted the dev-rule-mvnormalwmp branch January 9, 2024 11:53
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants