Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 3 additions & 13 deletions src/nodes/mv_normal_weighted_mean_precision.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,7 @@ import StatsFuns: log2π
@node MvNormalWeightedMeanPrecision Stochastic [out, (ξ, aliases = [xi, weightedmean]), (Λ, aliases = [invcov, precision])]

@average_energy MvNormalWeightedMeanPrecision (q_out::Any, q_ξ::Any, q_Λ::Any) = begin
m_mean, v_mean = mean_cov(q_ξ)
m_out, v_out = mean_cov(q_out)
return (ndims(q_out) * log2π - chollogdet(mean(q_Λ)) + tr(mean(q_Λ) * (v_out + v_mean + (m_out - m_mean) * (m_out - m_mean)'))) / 2
end

@average_energy MvNormalWeightedMeanPrecision (q_out_ξ::Any, q_Λ::Any) = begin
m, V = mean_cov(q_out_ξ)
d = div(ndims(q_out_ξ), 2)
return @views (
d * log2π +
-chollogdet(mean(q_Λ)) +
tr(mean(q_Λ) * (V[1:d, 1:d] - V[1:d, (d + 1):end] - V[(d + 1):end, 1:d] + V[(d + 1):end, (d + 1):end] + (m[1:d] - m[(d + 1):end]) * (m[1:d] - m[(d + 1):end])'))
) / 2
m_ξ, v_ξ = mean_cov(q_ξ)
m_out, v_out = mean_cov(q_out)
return (ndims(q_out) * log2π - mean(logdet, q_Λ) + tr(mean(q_Λ) * (m_out * m_out' + v_out)) - 2m_out'm_ξ + tr(mean(inv, q_Λ) * (m_ξ * m_ξ' + v_ξ))) / 2
end
3 changes: 3 additions & 0 deletions src/rules/mv_normal_weightedmean_precision/out.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
@rule MvNormalWeightedMeanPrecision(:out, Marginalisation) (m_ξ::PointMass, m_Λ::PointMass) = MvNormalWeightedMeanPrecision(mean(m_ξ), mean(m_Λ))

@rule MvNormalWeightedMeanPrecision(:out, Marginalisation) (q_ξ::Any, q_Λ::Any) = MvNormalWeightedMeanPrecision(mean(q_ξ), mean(q_Λ))
1 change: 1 addition & 0 deletions src/rules/prototypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ include("mv_normal_mean_scale_precision/precision.jl")
include("mv_normal_mean_scale_precision/marginals.jl")

include("mv_normal_weightedmean_precision/marginals.jl")
include("mv_normal_weightedmean_precision/out.jl")

include("normal_mean_precision/out.jl")
include("normal_mean_precision/mean.jl")
Expand Down
62 changes: 62 additions & 0 deletions test/nodes/test_mv_normal_weightedmean_precision.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
module MvNormalWeightedMeanPrecisionNodeTest

using Test, ReactiveMP, Random, BayesBase, ExponentialFamily

import ReactiveMP: make_node

@testset "MvNormalWeightedMeanPrecisionNodeTest" begin
@testset "Creation" begin
node = make_node(MvNormalWeightedMeanPrecision)

@test functionalform(node) === MvNormalWeightedMeanPrecision
@test sdtype(node) === Stochastic()
@test name.(interfaces(node)) === (:out, :ξ, :Λ)
@test factorisation(node) === ((1, 2, 3),)
@test localmarginalnames(node) === (:out_ξ_Λ,)
@test metadata(node) === nothing

node = make_node(MvNormalWeightedMeanPrecision, FactorNodeCreationOptions(nothing, 1, nothing))

@test functionalform(node) === MvNormalWeightedMeanPrecision
@test sdtype(node) === Stochastic()
@test name.(interfaces(node)) === (:out, :ξ, :Λ)
@test factorisation(node) === ((1, 2, 3),)
@test localmarginalnames(node) === (:out_ξ_Λ,)
@test metadata(node) === 1

node = make_node(MvNormalWeightedMeanPrecision, FactorNodeCreationOptions(((1,), (2, 3)), nothing, nothing))

@test functionalform(node) === MvNormalWeightedMeanPrecision
@test sdtype(node) === Stochastic()
@test name.(interfaces(node)) === (:out, :ξ, :Λ)
@test factorisation(node) === ((1,), (2, 3))
@test localmarginalnames(node) === (:out, :ξ_Λ)
@test metadata(node) === nothing
end

@testset "AverageEnergy" begin
begin
for i in 2:5
mean_in, L = randn(i), randn(i, i)
Cov_in = L * L'
mean_out = randn(i)

q_out = PointMass(mean_out)
q_Σ = PointMass(Cov_in)
q_μ = PointMass(mean_in)

q_Λ = PointMass(inv(mean(q_Σ)))
q_ξ = PointMass(mean(q_Λ) * mean(q_μ))

for N in (MvNormalMeanPrecision, MvNormalMeanCovariance, MvNormalWeightedMeanPrecision)
marginalsξ = (Marginal(q_out, false, false, nothing), Marginal(q_ξ, false, false, nothing), Marginal(q_Λ, false, false, nothing))
marginalsμ = (Marginal(q_out, false, false, nothing), Marginal(q_μ, false, false, nothing), Marginal(q_Σ, false, false, nothing))
@test score(AverageEnergy(), MvNormalWeightedMeanPrecision, Val{(:out, :ξ, :Λ)}(), marginalsξ, nothing) ≈
score(AverageEnergy(), MvNormalMeanCovariance, Val{(:out, :μ, :Σ)}(), marginalsμ, nothing)
end
end
end
# NOTE: tests for average energy when interfaces are not PointMass are not implemented
end
end
end
46 changes: 46 additions & 0 deletions test/rules/mv_normal_weightedmean_precision/test_out.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
module RulesMvNormalWeightedMeanPrecisionOutTest

using Test, ReactiveMP, BayesBase, Random, ExponentialFamily, Distributions

import ReactiveMP: @test_rules

@testset "rules:MvNormalWeightedMeanPrecision:out" begin
@testset "Belief Propagation: (m_ξ::PointMass, m_Λ::PointMass)" begin
@test_rules [check_type_promotion = true] MvNormalWeightedMeanPrecision(:out, Marginalisation) [
(input = (m_ξ = PointMass([-1.0]), m_Λ = PointMass([2.0])), output = MvNormalWeightedMeanPrecision([-1.0], [2.0])),
(input = (m_ξ = PointMass([1.0]), m_Λ = PointMass([2.0])), output = MvNormalWeightedMeanPrecision([1.0], [2.0])),
(input = (m_ξ = PointMass([2.0]), m_Λ = PointMass([1.0])), output = MvNormalWeightedMeanPrecision([2.0], [1.0])),
(input = (m_ξ = PointMass([1.0, 3.0]), m_Λ = PointMass([3.0 2.0; 2.0 4.0])), output = MvNormalWeightedMeanPrecision([1.0, 3.0], [3.0 2.0; 2.0 4.0])),
(input = (m_ξ = PointMass([-1.0, 2.0]), m_Λ = PointMass([7.0 -1.0; -1.0 9.0])), output = MvNormalWeightedMeanPrecision([-1.0, 2.0], [7.0 -1.0; -1.0 9.0])),
(input = (m_ξ = PointMass([0.0, 0.0]), m_Λ = PointMass([1.0 0.0; 0.0 1.0])), output = MvNormalWeightedMeanPrecision([0.0, 0.0], [1.0 0.0; 0.0 1.0]))
]
end

@testset "Variational: (q_ξ::Any, q_Λ::Any)" begin
@test_rules [check_type_promotion = true] MvNormalWeightedMeanPrecision(:out, Marginalisation) [
(input = (q_ξ = PointMass([-1.0]), q_Λ = PointMass([2.0])), output = MvNormalWeightedMeanPrecision([-1.0], [2.0])),
(input = (q_ξ = PointMass([1.0]), q_Λ = PointMass([2.0])), output = MvNormalWeightedMeanPrecision([1.0], [2.0])),
(input = (q_ξ = PointMass([2.0]), q_Λ = PointMass([1.0])), output = MvNormalWeightedMeanPrecision([2.0], [1.0])),
(input = (q_ξ = PointMass([1.0, 3.0]), q_Λ = PointMass([3.0 2.0; 2.0 4.0])), output = MvNormalWeightedMeanPrecision([1.0, 3.0], [3.0 2.0; 2.0 4.0])),
(input = (q_ξ = PointMass([-1.0, 2.0]), q_Λ = PointMass([7.0 -1.0; -1.0 9.0])), output = MvNormalWeightedMeanPrecision([-1.0, 2.0], [7.0 -1.0; -1.0 9.0])),
(input = (q_ξ = PointMass([0.0, 0.0]), q_Λ = PointMass([1.0 0.0; 0.0 1.0])), output = MvNormalWeightedMeanPrecision([0.0, 0.0], [1.0 0.0; 0.0 1.0]))
]

@test_rules [check_type_promotion = true] MvNormalWeightedMeanPrecision(:out, Marginalisation) [
(
input = (q_ξ = MvNormalWeightedMeanPrecision([3.0 2.0; 2.0 4.0] * [2.0, 1.0], [3.0 2.0; 2.0 4.0]), q_Λ = Wishart(2.0, [6.0 4.0; 4.0 8.0] ./ 2.0)),
output = MvNormalWeightedMeanPrecision([2.0, 1.0], [6.0 4.0; 4.0 8.0])
),
(
input = (q_ξ = MvNormalWeightedMeanPrecision([0.0, 0.0], [7.0 -1.0; -1.0 9.0]), q_Λ = Wishart(3.0, [12.0 -2.0; -2.0 7.0] ./ 3.0)),
output = MvNormalWeightedMeanPrecision([0.0, 0.0], [12.0 -2.0; -2.0 7.0])
),
(
input = (q_ξ = MvNormalWeightedMeanPrecision([3.0, -1.0], [1.0 0.0; 0.0 1.0]), q_Λ = Wishart(4.0, [1.0 0.0; 0.0 1.0] ./ 4.0)),
output = MvNormalWeightedMeanPrecision([3.0, -1.0], [1.0 0.0; 0.0 1.0])
)
]
end
end

end
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ end
addtests(testrunner, "nodes/test_normal_mean_variance.jl")
addtests(testrunner, "nodes/test_mv_normal_mean_precision.jl")
addtests(testrunner, "nodes/test_mv_normal_mean_scale_precision.jl")
addtests(testrunner, "nodes/test_mv_normal_weightedmean_precision.jl")
addtests(testrunner, "nodes/test_mv_normal_mean_covariance.jl")
addtests(testrunner, "nodes/test_poisson.jl")
addtests(testrunner, "nodes/test_wishart_inverse.jl")
Expand Down Expand Up @@ -347,6 +348,8 @@ end
addtests(testrunner, "rules/mv_normal_mean_scale_precision/test_mean.jl")
addtests(testrunner, "rules/mv_normal_mean_scale_precision/test_precision.jl")

addtests(testrunner, "rules/mv_normal_weightedmean_precision/test_out.jl")

addtests(testrunner, "rules/probit/test_out.jl")
addtests(testrunner, "rules/probit/test_in.jl")

Expand Down