Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions src/nodes/predefined/softdot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,21 @@ const softdot = SoftDot
m_γ = mean(q_γ)
return (-mean(log, q_γ) + log2π + m_γ * (V_y + m_y^2 - 2m_γ * m_y * m_θ'm_x + mul_trace(V_θ, V_x) + m_x'V_θ * m_x + m_θ' * (V_x + m_x * m_x') * m_θ)) / 2
end

@average_energy softdot (q_y_x::MultivariateNormalDistributionsFamily, q_θ::NormalDistributionsFamily, q_γ::GammaShapeRate) = begin
mθ, Vθ = mean_cov(q_θ)
myx, Vyx = mean_cov(q_y_x)
mγ = mean(q_γ)

order = length(mθ)
F = order == 1 ? Univariate : Multivariate

mx, Vx = ar_slice(F, myx, (order + 1):(2order)), ar_slice(F, Vyx, (order + 1):(2order), (order + 1):(2order))
my1, Vy1 = first(myx), first(Vyx)
Vy1x = ar_slice(F, Vyx, 1, (order + 1):(2order))

# Equivalent to AE = (-mean(log, q_γ) + log2π + mγ*(Vy1+my1^2 - 2*mθ'*(Vy1x + mx*my1) + tr(Vθ*Vx) + mx'*Vθ*mx + mθ'*(Vx + mx*mx')*mθ)) / 2
AE = (-mean(log, q_γ) + log2π + mγ * (Vy1 + my1^2 - 2 * mθ' * (Vy1x + mx * my1) + mul_trace(Vθ, Vx) + dot(mx, Vθ, mx) + dot(mθ, Vx, mθ) + abs2(dot(mθ, mx)))) / 2

return AE
end
4 changes: 4 additions & 0 deletions src/rules/dot_product/in1.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,7 @@
@rule typeof(dot)(:in1, Marginalisation) (m_out::UnivariateNormalDistributionsFamily, m_in2::PointMass, meta::Union{AbstractCorrectionStrategy, Nothing}) = begin
return @call_rule typeof(dot)(:in2, Marginalisation) (m_out = m_out, m_in1 = m_in2, meta = meta)
end

@rule typeof(dot)(:in1, Marginalisation) (m_out::UnivariateNormalDistributionsFamily, m_in2::NormalDistributionsFamily, meta::Union{AbstractCorrectionStrategy, Nothing}) = begin
return error("The rule for the dot product node between two NormalDistributionsFamily instances is not available in closed form. Please use SoftDot instead.")
end
4 changes: 4 additions & 0 deletions src/rules/dot_product/in2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,7 @@

return convert(promote_variate_type(variate_form(typeof(m_in1)), NormalWeightedMeanPrecision), ξ, W)
end

@rule typeof(dot)(:in2, Marginalisation) (m_out::UnivariateNormalDistributionsFamily, m_in1::NormalDistributionsFamily, meta::Union{AbstractCorrectionStrategy, Nothing}) = begin
return error("The rule for the dot product node between two NormalDistributionsFamily instances is not available in closed form. Please use SoftDot instead.")
end
4 changes: 4 additions & 0 deletions src/rules/dot_product/out.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,7 @@ end
in2_mean, in2_cov = mean_cov(m_in2)
return NormalMeanVariance(dot(A, in2_mean), dot(A, in2_cov, A))
end

@rule typeof(dot)(:out, Marginalisation) (m_in1::NormalDistributionsFamily, m_in2::NormalDistributionsFamily, meta::Union{AbstractCorrectionStrategy, Nothing}) = begin
return error("The rule for the dot product node between two NormalDistributionsFamily instances is not available in closed form. Please use SoftDot instead.")
end
11 changes: 6 additions & 5 deletions src/rules/predefined.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,6 @@ include("dot_product/out.jl")
include("dot_product/in1.jl")
include("dot_product/in2.jl")

include("softdot/y.jl")
include("softdot/x.jl")
include("softdot/theta.jl")
include("softdot/gamma.jl")

include("transition/marginals.jl")
include("transition/out.jl")
include("transition/in.jl")
Expand All @@ -123,6 +118,12 @@ include("autoregressive/theta.jl")
include("autoregressive/gamma.jl")
include("autoregressive/marginals.jl")

include("softdot/y.jl")
include("softdot/x.jl")
include("softdot/theta.jl")
include("softdot/gamma.jl")
include("softdot/marginals.jl")

include("probit/marginals.jl")
include("probit/in.jl")
include("probit/out.jl")
Expand Down
22 changes: 22 additions & 0 deletions src/rules/softdot/gamma.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,25 @@
β += (mul_trace(Vx, Vθ) + mθ'Vx * mθ + mx'Vθ * mx + mθ'mx * mx'mθ) / 2
return GammaShapeRate(α, β)
end

# Variational MP: Structured
@rule softdot(:γ, Marginalisation) (q_y_x::MultivariateNormalDistributionsFamily, q_θ::Any) = begin
# q_y is always Univariate
order = length(q_y_x) - 1
F = order == 1 ? Univariate : Multivariate

y_x_mean, y_x_cov = mean_cov(q_y_x)
mθ, Vθ = mean_cov(q_θ)

my, Vy = first(y_x_mean), first(y_x_cov)
mx, Vx = ar_slice(F, y_x_mean, 2:(order + 1)), ar_slice(F, y_x_cov, 2:(order + 1), 2:(order + 1))
Vyx = ar_slice(F, y_x_cov, 2:(order + 1))

C = rank1update(Vx, mx)
R = rank1update(Vy, my)
L = Vyx + mx * my

B = first(R) - 2 * first(mθ' * L) + first(mθ' * C * mθ) + mul_trace(Vθ, C)

return GammaShapeRate(convert(eltype(B), 3//2), B / 2)
end
27 changes: 27 additions & 0 deletions src/rules/softdot/marginals.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@

# The following marginal rule is adaptation of the marginal rule for Autoregressive node
@marginalrule SoftDot(:y_x) (m_y::NormalDistributionsFamily, m_x::NormalDistributionsFamily, q_θ::NormalDistributionsFamily, q_γ::Any) = begin
mθ, Vθ = mean_cov(q_θ)
mγ = mean(q_γ)

b_my, b_Vy = mean_cov(m_y)
f_mx, f_Vx = mean_cov(m_x)

inv_b_Vy = cholinv(b_Vy)
inv_f_Vx = cholinv(f_Vx)

D = inv_f_Vx + mγ * Vθ

W_11 = inv_b_Vy + mγ

W_12 = -mγ * mθ'

W_21 = -mθ * mγ

W_22 = D + mθ * mγ * mθ'

W = [W_11 W_12; W_21 W_22]
ξ = [inv_b_Vy * b_my; inv_f_Vx * f_mx]

return MvNormalWeightedMeanPrecision(ξ, W)
end
20 changes: 20 additions & 0 deletions src/rules/softdot/theta.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,23 @@
zθ = mγ * mx * my
return convert(promote_variate_type(variate_form(typeof(q_x)), NormalWeightedMeanPrecision), zθ, Dθ)
end

# Variational MP: Structured
@rule softdot(:θ, Marginalisation) (q_y_x::MultivariateNormalDistributionsFamily, q_γ::Any) = begin
# q_y is always Univariate
order = length(q_y_x) - 1
F = order == 1 ? Univariate : Multivariate

myx, Vyx = mean_cov(q_y_x)
my, Vy = first(myx), first(Vyx)
mx, Vx = ar_slice(F, myx, 2:(order + 1)), ar_slice(F, Vyx, 2:(order + 1), 2:(order + 1))
Vyx = ar_slice(F, Vyx, 2:(order + 1))

mγ = mean(q_γ)

W = mγ * (Vx + mx * mx')

ξ = (Vyx + mx * my') * mγ

return convert(promote_variate_type(F, NormalWeightedMeanPrecision), ξ, W)
end
18 changes: 18 additions & 0 deletions src/rules/softdot/x.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,21 @@
zx = mγ * mθ * my
return convert(promote_variate_type(variate_form(typeof(q_θ)), NormalWeightedMeanPrecision), zx, Dx)
end

# Variational MP: Structured
@rule softdot(:x, Marginalisation) (m_y::UnivariateNormalDistributionsFamily, q_θ::Any, q_γ::Any) = begin
# the naive call of AR rule is not possible, because the softdot rule expects m_y to be a UnivariateNormalDistributionsFamily
mθ, Vθ = mean_cov(q_θ)
my, Vy = mean_cov(m_y)

mγ = mean(q_γ)

mV = inv(mγ)

C = mθ * inv(add_transition(Vy, mV))

W = C * mθ' + mγ * Vθ
ξ = C * my

return convert(promote_variate_type(variate_form(typeof(q_θ)), NormalWeightedMeanPrecision), ξ, W)
end
4 changes: 4 additions & 0 deletions src/rules/softdot/y.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,7 @@

# Variational MP: Mean-field
@rule softdot(:y, Marginalisation) (q_θ::Any, q_x::Any, q_γ::Any) = NormalMeanPrecision(mean(q_θ)'mean(q_x), mean(q_γ))

@rule softdot(:y, Marginalisation) (q_θ::Any, m_x::Any, q_γ::Any) = NormalMeanVariance(
first.(mean_cov((@call_rule AR(:y, Marginalisation) (m_x = m_x, q_θ = q_θ, q_γ = q_γ, meta = ARMeta(variate_form(typeof(m_x)), length(q_θ), ARsafe())))))...
)
9 changes: 9 additions & 0 deletions test/nodes/predefined/softdot_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,14 @@

@test score(AverageEnergy(), SoftDot, Val{(:y, :θ, :x, :γ)}(), marginals, nothing) ≈ 8.15193210352257
end

begin
q_y_x = MvNormalMeanCovariance(zeros(2), diageye(2))
q_θ = NormalMeanVariance(0.0, 1.0)
q_γ = GammaShapeRate(2.0, 3.0)

marginals = (Marginal(q_y_x, false, false, nothing), Marginal(q_θ, false, false, nothing), Marginal(q_γ, false, false, nothing))
@test score(AverageEnergy(), SoftDot, Val{(:y_x, :θ, :γ)}(), marginals, nothing) ≈ 1.92351917665616
end
end # testset: AverageEnergy
end # testset
8 changes: 8 additions & 0 deletions test/rules/dot_product/in1_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,4 +84,12 @@
)
]
end

@testset "Error Belief Propagation: (m_out::UnivariateNormalDistributionsFamily, m_in2::NormalDistributionsFamily)" begin
@test_throws r"The rule for the dot product node between two NormalDistributionsFamily instances is not available in closed form. Please use SoftDot instead." @call_rule typeof(
dot
)(
:in1, Marginalisation
) (m_out = NormalMeanVariance(2.0, 2.0), m_in2 = MvNormalMeanCovariance([-1.0, 1.0], [2.0 -1.0; -1.0 4.0]), meta = NoCorrection())
end
end
8 changes: 8 additions & 0 deletions test/rules/dot_product/in2_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,4 +84,12 @@
)
]
end

@testset "Error Belief Propagation: (m_out::UnivariateNormalDistributionsFamily, m_in1::NormalDistributionsFamily)" begin
@test_throws r"The rule for the dot product node between two NormalDistributionsFamily instances is not available in closed form. Please use SoftDot instead." @call_rule typeof(
dot
)(
:in2, Marginalisation
) (m_out = NormalMeanVariance(2.0, 2.0), m_in1 = MvNormalMeanCovariance([-1.0, 1.0], [2.0 -1.0; -1.0 4.0]), meta = NoCorrection())
end
end
8 changes: 8 additions & 0 deletions test/rules/dot_product/out_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,4 +97,12 @@
)
]
end

@testset "Error belief Propagation: (m_in1::NormalDistributionsFamily, m_in2::NormalDistributionsFamily)" begin
@test_throws r"The rule for the dot product node between two NormalDistributionsFamily instances is not available in closed form. Please use SoftDot instead." @call_rule typeof(
dot
)(
:out, Marginalisation
) (m_in1 = NormalMeanVariance(2.0, 2.0), m_in2 = NormalMeanVariance(2.0, 2.0), meta = NoCorrection())
end
end
21 changes: 21 additions & 0 deletions test/rules/softdot/gamma_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,4 +131,25 @@
)
end
end # testset: mean-field

@testset "Structured: (q_y_x::MultivariateNormalDistributionsFamily, q_γ::Any)" begin
@test_rules [check_type_promotion = true] SoftDot(:γ, Marginalisation) [
(input = (q_y_x = MvNormalMeanCovariance(ones(2), diageye(2)), q_θ = NormalMeanPrecision(1.0, 1.0)), output = GammaShapeRate(3 / 2, 2.0)),
(input = (q_y_x = MvNormalMeanCovariance(2 * ones(2), diageye(2)), q_θ = NormalMeanPrecision(2.0, 1.0)), output = GammaShapeRate(3 / 2, 7.0))
]
end

@testset "Structured : (q_y_x::MultivariateNormalDistributionsFamily, q_θ::Any)" begin
order = 2
@test_rules [check_type_promotion = true] SoftDot(:γ, Marginalisation) [
(
input = (q_y_x = MvNormalMeanCovariance(ones(order + 1), diageye(order + 1)), q_θ = MvNormalMeanPrecision(ones(order), diageye(order))),
output = GammaShapeRate(3 / 2, 4.0)
),
(
input = (q_y_x = MvNormalMeanCovariance(ones(order + 1), diageye(order + 1)), q_θ = MvNormalMeanPrecision(zeros(order), diageye(order))),
output = GammaShapeRate(3 / 2, 3.0)
)
]
end
end # testset
26 changes: 26 additions & 0 deletions test/rules/softdot/test_marginals.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@

@testitem "marginalrules:SoftDot" begin
using ReactiveMP, BayesBase, Random, ExponentialFamily, Distributions

import ReactiveMP: @test_marginalrules

@testset "y_x: (m_y::UnivariateNormalDistributionsFamily, m_x::UnivariateNormalDistributionsFamily, q_θ::UnivariateNormalDistributionsFamily, q_γ::Any)" begin
@test_marginalrules [check_type_promotion = true] SoftDot(:y_x) [(
input = (m_y = NormalMeanPrecision(0.0, 1.0), m_x = NormalMeanPrecision(0.0, 1.0), q_θ = NormalMeanPrecision(1.0, 1.0), q_γ = GammaShapeRate(1.0, 1.0)),
output = MvNormalWeightedMeanPrecision(zeros(2), [2.0 -1.0; -1.0 3.0])
)]
end

@testset "y_x: (m_y::UnivariateNormalDistributionsFamily), m_x::MultivariateNormalDistributionsFamily, q_θ::MultivariateNormalDistributionsFamily, q_γ::Any)" begin
order = 2
@test_marginalrules [check_type_promotion = true] SoftDot(:y_x) [(
input = (
m_y = NormalMeanPrecision(1.0, 1.0),
m_x = MvNormalMeanCovariance(ones(order), diageye(order)),
q_θ = MvNormalMeanCovariance(ones(order), diageye(order)),
q_γ = GammaShapeRate(1.0, 1.0)
),
output = MvNormalWeightedMeanPrecision(ones(3), [2.0 -1.0 -1.0; -1.0 3.0 1.0; -1.0 1.0 3.0])
)]
end
end
21 changes: 21 additions & 0 deletions test/rules/softdot/theta_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,4 +129,25 @@
end
# NOTE: γ can theoretically be Any, so also NormalMeanVariance
end

@testset "Structured: (q_y_x::MultivariateNormalDistributionsFamily, q_γ::Any)" begin
@test_rules [check_type_promotion = true] SoftDot(:θ, Marginalisation) [
(input = (q_y_x = MvNormalMeanCovariance(ones(2), diageye(2)), q_γ = GammaShapeRate(1.0, 1.0)), output = NormalWeightedMeanPrecision(1.0, 2.0)),
(input = (q_y_x = MvNormalMeanCovariance(2 * ones(2), diageye(2)), q_γ = GammaShapeScale(2.0, 1.0)), output = NormalWeightedMeanPrecision(8.0, 10.0))
]
end

@testset "Structured : (q_y_x::MultivariateNormalDistributionsFamily, q_γ::Any)" begin
order = 2
@test_rules [check_type_promotion = true] SoftDot(:θ, Marginalisation) [
(
input = (q_y_x = MvNormalMeanCovariance(ones(order + 1), diageye(order + 1)), q_γ = GammaShapeRate(1.0, 1.0)),
output = MvNormalWeightedMeanPrecision(ones(order), [2.0 1.0; 1.0 2.0])
),
(
input = (q_y_x = MvNormalMeanCovariance(zeros(order + 1), diageye(order + 1)), q_γ = GammaShapeRate(1.0, 1.0)),
output = MvNormalWeightedMeanPrecision(zeros(order), [1.0 0.0; 0.0 1.0])
)
]
end
end # testset
30 changes: 30 additions & 0 deletions test/rules/softdot/x_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,4 +129,34 @@
end
# NOTE: γ can theoretically be Any, so also NormalMeanVariance
end

@testset "VMP: structured rules" begin
@testset "(m_y::NormalMeanVariance, q_θ::NormalMeanVariance, q_γ::Any)" begin
@test_rules [check_type_promotion = true] SoftDot(:x, Marginalisation) [
(input = (m_y = NormalMeanVariance(1.0, 1.0), q_θ = NormalMeanVariance(1.0, 1.0), q_γ = GammaShapeRate(1.0, 1.0)), output = NormalWeightedMeanPrecision(0.5, 1.5)),
(
input = (m_y = NormalWeightedMeanPrecision(1.0, 1.0), q_θ = NormalMeanPrecision(1.0, 2.0), q_γ = GammaShapeScale(1.0, 1.0)),
output = NormalWeightedMeanPrecision(0.5, 1.0)
)
]
end

@testset "(m_y::UnivariateNormalDistributionsFamily, q_θ::MultivariateNormalDistributionsFamily, q_γ::Any)" begin
order = 2
@test_rules [check_type_promotion = false] SoftDot(:x, Marginalisation) [
(
input = (m_y = NormalMeanVariance(0.0, 1.0), q_θ = MvNormalMeanCovariance(ones(order), diageye(order)), q_γ = GammaShapeRate(1.0, 1.0)),
output = MvNormalWeightedMeanPrecision([0.0, 0.0], [1.5 0.5; 0.5 1.5])
),
(
input = (m_y = NormalMeanVariance(1.0, 1.0), q_θ = MvNormalMeanCovariance(zeros(order), diageye(order)), q_γ = GammaShapeScale(1.0, 1.0)),
output = MvNormalWeightedMeanPrecision([0.0, 0.0], [1.0 0.0; 0.0 1.0])
),
(
input = (m_y = NormalMeanVariance(1.0, 1.0), q_θ = MvNormalMeanCovariance(ones(order), diageye(order)), q_γ = Gamma(1.0, 1.0)),
output = MvNormalWeightedMeanPrecision([0.5, 0.5], [1.5 0.5; 0.5 1.5])
)
]
end
end
end # testset
33 changes: 33 additions & 0 deletions test/rules/softdot/y_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,4 +97,37 @@
)
end
end

@testset "VMP: structured rules" begin
@testset "(q_θ::NormalMeanVariance, m_x::NormalMeanVariance, q_γ::Any" begin
@test_rules [check_type_promotion = true] SoftDot(:y, Marginalisation) [
(input = (q_θ = PointMass(3.0), q_x = PointMass(11.0), q_γ = GammaShapeRate(7.0, 5.0)), output = NormalMeanPrecision(33.0, 1.4)),
(input = (q_θ = PointMass(3.0), q_x = PointMass(11.0), q_γ = GammaShapeScale(7.0, 5.0)), output = NormalMeanPrecision(33.0, 35.0))
]

@test_rules [check_type_promotion = true] SoftDot(:y, Marginalisation) [
(input = (m_x = NormalMeanVariance(1.0, 1.0), q_θ = NormalMeanVariance(1.0, 1.0), q_γ = GammaShapeRate(1.0, 1.0)), output = NormalMeanVariance(0.5, 1.5)),
(
input = (m_x = NormalWeightedMeanPrecision(1.0, 1.0), q_θ = NormalMeanPrecision(1.0, 2.0), q_γ = GammaShapeScale(2.0, 1.0)),
output = NormalMeanVariance(0.5, 1.0)
)
]
end

@testset "(q_θ::MvNormalMeanCovariance, m_x::MvNormalMeanCovariance, q_γ::Any" begin
order = 2
@test_rules [check_type_promotion = true] SoftDot(:y, Marginalisation) [
(
input = (
m_x = MvNormalMeanCovariance(ones(order), diageye(order)), q_θ = MvNormalMeanCovariance(zeros(order), diageye(order)), q_γ = GammaShapeScale(1.0, 1.0)
),
output = NormalMeanVariance(0.0, 1.0)
),
(
input = (m_x = MvNormalMeanCovariance(ones(order), diageye(order)), q_θ = MvNormalMeanCovariance(ones(order), diageye(order)), q_γ = Gamma(1.0, 1.0)),
output = NormalMeanVariance(1.0, 2.0)
)
]
end
end
end # testset