Skip to content

Commit 2687050

Browse files
authored
Merge pull request #430 from ReactiveBayes/refactor-cvi-projection-marginal-rule
Refactor cvi projection marginal rule (with proposal distribution)
2 parents 7b82652 + 39cd541 commit 2687050

File tree

6 files changed

+365
-13
lines changed

6 files changed

+365
-13
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
9090
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
9191
ExponentialFamilyProjection = "17f509fa-9a96-44ba-99b2-1c5f01f0931b"
9292
Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d"
93+
Manopt = "0fc0a36d-df90-57f3-8f93-d78a9fc72bb5"
9394
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
9495
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
9596
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
@@ -103,4 +104,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
103104
TestSetExtensions = "98d24dd4-01ad-11ea-1b02-c9a08f80db04"
104105

105106
[targets]
106-
test = ["Aqua", "Hwloc", "ReTestItems", "Test", "Pkg", "Logging", "InteractiveUtils", "TestSetExtensions", "Coverage", "Dates", "Distributed", "Documenter", "BenchmarkTools", "JET", "PkgBenchmark", "StableRNGs", "Optimisers", "DiffResults", "ExponentialFamilyProjection", "REPL"]
107+
test = ["Aqua", "Hwloc", "ReTestItems", "Test", "Pkg", "Logging", "InteractiveUtils", "TestSetExtensions", "Coverage", "Dates", "Distributed", "Documenter", "BenchmarkTools", "JET", "PkgBenchmark", "StableRNGs", "Optimisers", "DiffResults", "ExponentialFamilyProjection", "REPL", "Manopt"]

docs/src/lib/nodes/delta.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ ReactiveMP.UnscentedTransform
1616
ReactiveMP.ProdCVI
1717
ReactiveMP.CVI
1818
ReactiveMP.CVIProjection
19+
ReactiveMP.CVISamplingStrategy
20+
ReactiveMP.FullSampling
21+
ReactiveMP.MeanBased
22+
ReactiveMP.ProposalDistributionContainer
1923
ReactiveMP.cvi_setup!
2024
ReactiveMP.cvi_update!
2125
ReactiveMP.DeltaFnDefaultRuleLayout

ext/ReactiveMPProjectionExt/rules/marginals.jl

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,53 @@ end
4242
return FactorizedJoint((q,))
4343
end
4444

45+
function create_density_function(forms_match, i, pre_samples, logp_nc_drop_index, m_in)
46+
if forms_match
47+
return z -> logp_nc_drop_index(z, i, pre_samples)
48+
end
49+
return z -> logp_nc_drop_index(z, i, pre_samples) + logpdf(m_in, z)
50+
end
51+
52+
function optimize_parameters(i, pre_samples, m_ins, logp_nc_drop_index, method)
53+
m_in = m_ins[i]
54+
default_type = ExponentialFamily.exponential_family_typetag(m_in)
55+
prj = create_project_to_ins(method, m_in, i)
56+
57+
typeform = ExponentialFamilyProjection.get_projected_to_type(prj)
58+
dims = ExponentialFamilyProjection.get_projected_to_dims(prj)
59+
conditioner = prj.conditioner
60+
ef_in = convert(ExponentialFamilyDistribution, m_in)
61+
forms_match = typeform === default_type && dims == size(m_in) && conditioner == getconditioner(ef_in)
62+
63+
df = create_density_function(forms_match, i, pre_samples, logp_nc_drop_index, m_in)
64+
logp = convert(promote_variate_type(variate_form(typeof(m_in)), BayesBase.AbstractContinuousGenericLogPdf), UnspecifiedDomain(), df)
65+
66+
return forms_match ? project_to(prj, logp, m_in) : project_to(prj, logp)
67+
end
68+
69+
function generate_samples(rng, ::Nothing, m_ins, sampling_strategy::FullSampling)
70+
return zip(map(m_in -> ReactiveMP.cvilinearize(rand(rng, m_in, sampling_strategy.samples)), m_ins)...)
71+
end
72+
73+
function generate_samples(::Any, ::Nothing, m_ins, ::MeanBased)
74+
return zip(map(m_in -> [mean(m_in)], m_ins)...)
75+
end
76+
77+
function generate_samples(rng, proposal_distribution::FactorizedJoint, ::Any, sampling_strategy::FullSampling)
78+
return zip(map(q_in -> ReactiveMP.cvilinearize(rand(rng, q_in, sampling_strategy.samples)), proposal_distribution.multipliers)...)
79+
end
80+
81+
function generate_samples(::Any, proposal_distribution::FactorizedJoint, ::Any, ::MeanBased)
82+
return zip(map(q_in -> [mean(q_in)], proposal_distribution.multipliers)...)
83+
end
84+
4585
@marginalrule DeltaFn(:ins) (m_out::Any, m_ins::ManyOf{N, Any}, meta::DeltaMeta{M}) where {N, M <: CVIProjection} = begin
4686
method = ReactiveMP.getmethod(meta)
4787
rng = method.rng
48-
pre_samples = zip(map(m_in_k -> ReactiveMP.cvilinearize(rand(rng, m_in_k, method.marginalsamples)), m_ins)...)
88+
proposal_distribution_container = method.proposal_distribution
89+
sampling_strategy = method.sampling_strategy
90+
91+
pre_samples = generate_samples(rng, proposal_distribution_container.distribution, m_ins, sampling_strategy)
4992

5093
logp_nc_drop_index = let g = getnodefn(meta, Val(:out)), pre_samples = pre_samples
5194
(z, i, pre_samples) -> begin
@@ -84,5 +127,7 @@ end
84127
end
85128
end
86129

87-
return FactorizedJoint(ntuple(i -> optimize_natural_parameters(i, pre_samples), length(m_ins)))
130+
result = FactorizedJoint(ntuple(i -> optimize_natural_parameters(i, pre_samples), length(m_ins)))
131+
proposal_distribution_container.distribution = result
132+
return result
88133
end

src/approximations/cvi_projection.jl

Lines changed: 101 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,124 @@
11
export CVIProjection
22

3+
export CVISamplingStrategy, FullSampling, MeanBased
4+
5+
"""
6+
CVISamplingStrategy
7+
8+
An abstract type representing the sampling strategy for the CVI projection method.
9+
Concrete subtypes implement different approaches for generating samples used in
10+
approximating distributions.
11+
"""
12+
abstract type CVISamplingStrategy end
13+
14+
"""
15+
FullSampling <: CVISamplingStrategy
16+
FullSampling(samples::Int = 10)
17+
18+
A sampling strategy that uses multiple samples drawn from distributions.
19+
20+
# Arguments
21+
- `samples::Int`: The number of samples to draw from each distribution. Default is 10.
22+
23+
# Example
24+
```julia
25+
# Use 100 samples for more accurate approximation
26+
strategy = FullSampling(100)
27+
```
28+
"""
29+
struct FullSampling <: CVISamplingStrategy
30+
samples::Int
31+
32+
FullSampling(samples::Int = 10) = new(samples)
33+
end
34+
35+
"""
36+
MeanBased <: CVISamplingStrategy
37+
38+
A sampling strategy that uses only the mean of the proposal distribution as a single sample.
39+
"""
40+
struct MeanBased <: CVISamplingStrategy end
41+
42+
"""
43+
ProposalDistributionContainer{PD}
44+
45+
A mutable wrapper for proposal distributions used in the CVI projection method.
46+
47+
The container allows the proposal distribution to be updated during inference
48+
without recreating the entire approximation method structure.
49+
50+
# Fields
51+
- `distribution::PD`: The wrapped proposal distribution, can be of any compatible type.
52+
"""
53+
mutable struct ProposalDistributionContainer{PD}
54+
distribution::PD
55+
end
56+
357
"""
458
CVIProjection(; parameters...)
559
660
A structure representing the parameters for the Conjugate Variational Inference (CVI) projection method.
761
This structure is a subtype of `AbstractApproximationMethod` and is used to configure the settings for CVI.
862
9-
!!! note
10-
The `CVIProjection` method requires `ExponentialFamilyProjection` package installed in the current environment.
63+
CVI approximates the posterior distribution by projecting it onto a family of distributions with a conjugate form.
64+
65+
# Requirements
66+
67+
The `CVIProjection` method requires the `ExponentialFamilyProjection` package to be installed and loaded
68+
in the current environment with `using ExponentialFamilyProjection`.
1169
1270
# Parameters
1371
1472
- `rng::R`: The random number generator used for sampling. Default is `Random.MersenneTwister(42)`.
15-
- `marginalsamples::S`: The number of samples used for approximating marginal distributions. Default is `10`.
1673
- `outsamples::S`: The number of samples used for approximating output message distributions. Default is `100`.
17-
- `out_prjparams::OF`: the form parameter used to select the distribution form on which one to project out edge, if it's not provided will be infered from marginal form
18-
- `in_prjparams::IFS`: a namedtuple like object to select the form on which one to project in the input edge, if it's not provided will be infered from the incoming message onto this edge
74+
- `out_prjparams::OF`: The form parameter used to specify the target distribution family for the output message.
75+
If `nothing` (default), the form will be inferred from the marginal form.
76+
- `in_prjparams::IFS`: A NamedTuple-like object that specifies the target distribution family for each input edge.
77+
Keys should be of the form `:in_k` where `k` is the input edge index. If `nothing` (default), the forms
78+
will be inferred from the incoming messages.
79+
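- `proposal_distribution::PD`: A `ProposalDistributionContainer` wrapping the proposal distribution used for generating samples. By default the
80+
container holds `nothing`, in which case samples are generated from the incoming messages; the container is then updated with the computed marginal on each evaluation of the marginal rule.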
81+
- `sampling_strategy::SS`: The strategy for approximating the logpdf:
82+
- `FullSampling(n)`: Uses `n` samples drawn from distributions (default: `n=10`). Provides more accurate
83+
approximation at the cost of increased computation time.
84+
- `MeanBased()`: Uses only the mean of each distribution as a single sample. Significantly faster but
85+
less accurate for non-linear nodes or complex distributions.
86+
87+
# Examples
88+
89+
```julia
90+
# Standard CVI projection with default settings
91+
method = CVIProjection()
92+
93+
# Fast approximation using mean-based sampling
94+
method = CVIProjection(sampling_strategy = MeanBased())
95+
96+
# Custom proposal with increased sample count
97+
using Distributions
98+
proposal = FactorizedJoint((NormalMeanVariance(0.0, 1.0), NormalMeanVariance(0.0, 1.0)))
99+
method = CVIProjection(
100+
proposal_distribution = ProposalDistributionContainer(proposal),
101+
sampling_strategy = FullSampling(1000)
102+
)
103+
104+
# Specify projection family for the output message
105+
method = CVIProjection(out_prjparams = ProjectedTo(NormalMeanPrecision))
106+
107+
# Specify projection family for input edges
108+
method = CVIProjection(in_prjparams = (in_1 = ProjectedTo(NormalMeanVariance), in_2 = ProjectedTo(GammaShapeRate)))
109+
```
19110
20111
!!! note
21-
The `CVIProjection` method is an experimental enhancement of the now-deprecated `CVI`, offering better stability and improved accuracy.
22-
Note that the parameters of this structure, as well as their defaults, are subject to change during the experimentation phase.
112+
The `CVIProjection` method is an enhanced version of the deprecated `CVI`, offering better stability
113+
and improved accuracy. Parameters and defaults may change as the implementation evolves.
23114
"""
24-
Base.@kwdef struct CVIProjection{R, S, OF, IFS} <: AbstractApproximationMethod
115+
Base.@kwdef struct CVIProjection{R, S, OF, IFS, PD, SS} <: AbstractApproximationMethod
25116
rng::R = Random.MersenneTwister(42)
26-
marginalsamples::S = 10
27117
outsamples::S = 100
28118
out_prjparams::OF = nothing
29119
in_prjparams::IFS = nothing
120+
proposal_distribution::PD = ProposalDistributionContainer{Any}(nothing)
121+
sampling_strategy::SS = FullSampling(10)
30122
end
31123

32124
function get_kth_in_form(::CVIProjection{R, S, OF, Nothing}, ::Int) where {R, S, OF}
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
@testitem "CVI Projection Extension Tests" begin
2+
using ExponentialFamily
3+
using ExponentialFamilyProjection
4+
using BayesBase
5+
using ReactiveMP
6+
using Distributions
7+
using Random
8+
9+
ReactiveMPProjectionExt = Base.get_extension(ReactiveMP, :ReactiveMPProjectionExt)
10+
@test !isnothing(ReactiveMPProjectionExt)
11+
using .ReactiveMPProjectionExt
12+
13+
@testset "create_density_function" begin
14+
# Mock functions and data for testing
15+
pre_samples = [1.0, 2.0, 3.0]
16+
17+
# Mock message with a simple normal distribution
18+
m_in = NormalMeanVariance(0.0, 1.0)
19+
20+
# Mock logp_nc_drop_index function that ignores the index and samples and returns -0.5 * z^2
21+
logp_nc_drop_index = (z, i, samples) -> -0.5 * z^2
22+
23+
# Test when forms match (should not include the message logpdf)
24+
forms_match = true
25+
df_match = ReactiveMPProjectionExt.create_density_function(forms_match, 1, pre_samples, logp_nc_drop_index, m_in)
26+
@test df_match(0.5) ≈ logp_nc_drop_index(0.5, 1, pre_samples)
27+
@test df_match(1.0) ≈ -0.5 # Just the logp_nc_drop_index result
28+
29+
# Test when forms don't match (should include the message logpdf)
30+
forms_match = false
31+
df_no_match = ReactiveMPProjectionExt.create_density_function(forms_match, 1, pre_samples, logp_nc_drop_index, m_in)
32+
# Expected: logp_nc_drop_index + logpdf of the message
33+
expected_value = logp_nc_drop_index(0.5, 1, pre_samples) + logpdf(m_in, 0.5)
34+
@test df_no_match(0.5) ≈ expected_value
35+
end
36+
37+
@testset "optimize_parameters" begin
38+
# Test with normal distribution - we can derive exact expected results
39+
m_in = NormalMeanVariance(0.0, 1.0) # Prior: mean=0, variance=1 (precision=1)
40+
m_ins = [m_in]
41+
pre_samples = [0.0, 0.5, -0.5]
42+
method = CVIProjection()
43+
44+
# Case 1: Quadratic log-likelihood centered at 0 (-0.5*z²) corresponds to Normal(0,1)
45+
# When combining Normal(0,1) prior with Normal(0,1) likelihood:
46+
# Expected posterior: Normal(0, 0.5) - precision adds (1+1=2, variance=1/2=0.5)
47+
log_fn1 = (z, i, samples) -> -0.5 * z^2
48+
result1 = ReactiveMPProjectionExt.optimize_parameters(1, pre_samples, m_ins, log_fn1, method)
49+
50+
@test result1 isa NormalMeanVariance
51+
@test mean(result1) ≈ 0.0 atol = 1e-1
52+
@test var(result1) ≈ 0.5 atol = 1e-1
53+
54+
# Case 2: Quadratic centered at 2.0 (-0.5*(z-2)²) corresponds to Normal(2,1)
55+
# Combining Normal(0,1) prior with Normal(2,1) likelihood:
56+
# Expected posterior: Normal(1, 0.5) - weighted average of means
57+
log_fn2 = (z, i, samples) -> -0.5 * (z - 2.0)^2
58+
result2 = ReactiveMPProjectionExt.optimize_parameters(1, pre_samples, m_ins, log_fn2, method)
59+
60+
@test result2 isa NormalMeanVariance
61+
@test mean(result2) ≈ 1.0 atol = 1e-1 # (0*1 + 2*1)/(1+1) = 1.0
62+
@test var(result2) ≈ 0.5 atol = 1e-1 # 1/(1+1) = 0.5
63+
64+
# Case 3: Stronger quadratic (-2.0*(z-2)²) corresponds to Normal(2,0.25)
65+
# Combining Normal(0,1) prior with Normal(2,0.25) likelihood:
66+
# Expected posterior: Normal(1.6, 0.2)
67+
log_fn3 = (z, i, samples) -> -2.0 * (z - 2.0)^2
68+
result3 = ReactiveMPProjectionExt.optimize_parameters(1, pre_samples, m_ins, log_fn3, method)
69+
70+
@test result3 isa NormalMeanVariance
71+
@test mean(result3) ≈ 1.6 atol = 1e-1 # (0*1 + 2*4)/(1+4) = 8/5 = 1.6
72+
@test var(result3) ≈ 0.2 atol = 1e-1 # 1/(1+4) = 0.2
73+
74+
# Case 4: Test with a different prior
75+
m_in2 = NormalMeanVariance(1.0, 2.0) # Prior: mean=1, variance=2 (precision=0.5)
76+
m_ins2 = [m_in2]
77+
78+
# Combining Normal(1,2) prior with Normal(2,1) likelihood:
79+
# Expected posterior: Normal(5/3, 2/3)
80+
result4 = ReactiveMPProjectionExt.optimize_parameters(1, pre_samples, m_ins2, log_fn2, method)
81+
82+
@test result4 isa NormalMeanVariance
83+
@test mean(result4) ≈ 5 / 3 atol = 1e-1 # (1*0.5 + 2*1)/(0.5+1) = 1.67
84+
@test var(result4) ≈ 2 / 3 atol = 1e-1 # 1/(0.5+1) = 0.67
85+
end
86+
end
87+
88+
@testitem "optimize_parameters: with specified form" begin
89+
using ExponentialFamily
90+
using ExponentialFamilyProjection
91+
using BayesBase
92+
using ReactiveMP
93+
using Distributions
94+
using Random
95+
using Manopt
96+
97+
ReactiveMPProjectionExt = Base.get_extension(ReactiveMP, :ReactiveMPProjectionExt)
98+
@test !isnothing(ReactiveMPProjectionExt)
99+
using .ReactiveMPProjectionExt
100+
101+
m_in = NormalMeanVariance(0.0, 1.0)
102+
m_ins = [m_in]
103+
pre_samples = [0.0, 0.5, -0.5]
104+
method = CVIProjection(in_prjparams = (in_1 = ProjectedTo(NormalMeanVariance),))
105+
106+
log_fn1 = (z, i, samples) -> -0.5 * z^2
107+
result1 = ReactiveMPProjectionExt.optimize_parameters(1, pre_samples, m_ins, log_fn1, method)
108+
109+
@test result1 isa NormalMeanVariance
110+
@test mean(result1) ≈ 0.0 atol = 1e-1
111+
@test var(result1) ≈ 0.5 atol = 1e-1
112+
113+
m_in = Laplace(0.0, 1.0)
114+
m_ins = [m_in]
115+
pre_samples = [0.0, 0.5, -0.5]
116+
cost_recorder = Manopt.RecordCost()
117+
method = CVIProjection(in_prjparams = (in_1 = ProjectedTo(Laplace, conditioner = 1, kwargs = (record = [cost_recorder],)),))
118+
119+
log_fn1 = (z, i, samples) -> -0.5 * abs(z)
120+
result1 = ReactiveMPProjectionExt.optimize_parameters(1, pre_samples, m_ins, log_fn1, method)
121+
ef_result1 = convert(ExponentialFamilyDistribution, result1)
122+
123+
@test getconditioner(ef_result1) ≈ 1.0
124+
@test result1 isa Laplace
125+
@test cost_recorder.recorded_values[end] < cost_recorder.recorded_values[1]
126+
end

0 commit comments

Comments
 (0)