Skip to content

Commit 9fdfe75

Browse files
authored
Merge pull request #565 from ReactiveBayes/node_mv_normal_wishart
`MvNormalWishart` node implementation
2 parents e6be62e + e028206 commit 9fdfe75

File tree

5 files changed

+18
-0
lines changed

5 files changed

+18
-0
lines changed

src/nodes/predefined.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ include("predefined/mv_normal_mean_precision.jl")
99
include("predefined/mv_normal_mean_scale_precision.jl")
1010
include("predefined/mv_normal_weighted_mean_precision.jl")
1111
include("predefined/mv_normal_mean_scale_matrix_precision.jl")
12+
include("predefined/mv_normal_wishart.jl")
1213
include("predefined/gamma.jl")
1314
include("predefined/gamma_inverse.jl")
1415
include("predefined/gamma_shape_rate.jl")
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
@node MvNormalWishart Stochastic [out, (μ, aliases = [mean]), (W, aliases = [scale]), λ, ν]

src/rules/mv_normal_wishart/out.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
@rule MvNormalWishart(:out, Marginalisation) (q_μ::PointMass, q_W::PointMass, q_λ::PointMass, q_ν::PointMass) = MvNormalWishart(mean(q_μ), mean(q_W), mean(q_λ), mean(q_ν))

src/rules/predefined.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ include("mv_normal_mean_covariance/out.jl")
5555
include("mv_normal_mean_covariance/mean.jl")
5656
include("mv_normal_mean_covariance/covariance.jl")
5757
include("mv_normal_mean_covariance/marginals.jl")
58+
include("mv_normal_wishart/out.jl")
5859

5960
include("normal_mean_variance/out.jl")
6061
include("normal_mean_variance/mean.jl")
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
2+
@testitem "rules:MvNormalWishart:out" begin
3+
using ReactiveMP, BayesBase, Random, ExponentialFamily, Distributions
4+
5+
import ReactiveMP: @test_rules
6+
7+
@testset "Variational Message Passing: (q_μ::PointMass, q_λ::PointMass, q_W::PointMass, q_ν::PointMass)" begin
8+
# Type promotion is false because λ and ν will also be promoted but not necessarily promote μ and W.
9+
@test_rules [check_type_promotion = false] MvNormalWishart(:out, Marginalisation) [(
10+
input = (q_μ = PointMass([1.0, 2.0]), q_W = PointMass([1.0 0.0; 0.0 1.0]), q_λ = PointMass(1.0), q_ν = PointMass(1.0)),
11+
output = MvNormalWishart([1.0, 2.0], [1.0 0.0; 0.0 1.0], 1.0, 1.0)
12+
),]
13+
end
14+
end # testset

0 commit comments

Comments
 (0)