|
| 1 | +@testitem "CVI Projection Extension Tests" begin |
| 2 | + using ExponentialFamily |
| 3 | + using ExponentialFamilyProjection |
| 4 | + using BayesBase |
| 5 | + using ReactiveMP |
| 6 | + using Distributions |
| 7 | + using Random |
| 8 | + |
| 9 | + ReactiveMPProjectionExt = Base.get_extension(ReactiveMP, :ReactiveMPProjectionExt) |
| 10 | + @test !isnothing(ReactiveMPProjectionExt) |
| 11 | + using .ReactiveMPProjectionExt |
| 12 | + |
| 13 | + @testset "create_density_function" begin |
| 14 | + # Mock functions and data for testing |
| 15 | + pre_samples = [1.0, 2.0, 3.0] |
| 16 | + |
| 17 | + # Mock message with a simple normal distribution |
| 18 | + m_in = NormalMeanVariance(0.0, 1.0) |
| 19 | + |
| 20 | + # Mock logp_nc_drop_index function that just returns a constant + the input value |
| 21 | + logp_nc_drop_index = (z, i, samples) -> -0.5 * z^2 |
| 22 | + |
| 23 | + # Test when forms match (should not include the message logpdf) |
| 24 | + forms_match = true |
| 25 | + df_match = ReactiveMPProjectionExt.create_density_function(forms_match, 1, pre_samples, logp_nc_drop_index, m_in) |
| 26 | + @test df_match(0.5) ≈ logp_nc_drop_index(0.5, 1, pre_samples) |
| 27 | + @test df_match(1.0) ≈ -0.5 # Just the logp_nc_drop_index result |
| 28 | + |
| 29 | + # Test when forms don't match (should include the message logpdf) |
| 30 | + forms_match = false |
| 31 | + df_no_match = ReactiveMPProjectionExt.create_density_function(forms_match, 1, pre_samples, logp_nc_drop_index, m_in) |
| 32 | + # Expected: logp_nc_drop_index + logpdf of the message |
| 33 | + expected_value = logp_nc_drop_index(0.5, 1, pre_samples) + logpdf(m_in, 0.5) |
| 34 | + @test df_no_match(0.5) ≈ expected_value |
| 35 | + end |
| 36 | + |
| 37 | + @testset "optimize_parameters" begin |
| 38 | + # Test with normal distribution - we can derive exact expected results |
| 39 | + m_in = NormalMeanVariance(0.0, 1.0) # Prior: mean=0, variance=1 (precision=1) |
| 40 | + m_ins = [m_in] |
| 41 | + pre_samples = [0.0, 0.5, -0.5] |
| 42 | + method = CVIProjection() |
| 43 | + |
| 44 | + # Case 1: Quadratic log-likelihood centered at 0 (-0.5*z²) corresponds to Normal(0,1) |
| 45 | + # When combining Normal(0,1) prior with Normal(0,1) likelihood: |
| 46 | + # Expected posterior: Normal(0, 0.5) - precision adds (1+1=2, variance=1/2=0.5) |
| 47 | + log_fn1 = (z, i, samples) -> -0.5 * z^2 |
| 48 | + result1 = ReactiveMPProjectionExt.optimize_parameters(1, pre_samples, m_ins, log_fn1, method) |
| 49 | + |
| 50 | + @test result1 isa NormalMeanVariance |
| 51 | + @test mean(result1) ≈ 0.0 atol = 1e-1 |
| 52 | + @test var(result1) ≈ 0.5 atol = 1e-1 |
| 53 | + |
| 54 | + # Case 2: Quadratic centered at 2.0 (-0.5*(z-2)²) corresponds to Normal(2,1) |
| 55 | + # Combining Normal(0,1) prior with Normal(2,1) likelihood: |
| 56 | + # Expected posterior: Normal(1, 0.5) - weighted average of means |
| 57 | + log_fn2 = (z, i, samples) -> -0.5 * (z - 2.0)^2 |
| 58 | + result2 = ReactiveMPProjectionExt.optimize_parameters(1, pre_samples, m_ins, log_fn2, method) |
| 59 | + |
| 60 | + @test result2 isa NormalMeanVariance |
| 61 | + @test mean(result2) ≈ 1.0 atol = 1e-1 # (0*1 + 2*1)/(1+1) = 1.0 |
| 62 | + @test var(result2) ≈ 0.5 atol = 1e-1 # 1/(1+1) = 0.5 |
| 63 | + |
| 64 | + # Case 3: Stronger quadratic (-2.0*(z-2)²) corresponds to Normal(2,0.25) |
| 65 | + # Combining Normal(0,1) prior with Normal(2,0.25) likelihood: |
| 66 | + # Expected posterior: Normal(1.6, 0.2) |
| 67 | + log_fn3 = (z, i, samples) -> -2.0 * (z - 2.0)^2 |
| 68 | + result3 = ReactiveMPProjectionExt.optimize_parameters(1, pre_samples, m_ins, log_fn3, method) |
| 69 | + |
| 70 | + @test result3 isa NormalMeanVariance |
| 71 | + @test mean(result3) ≈ 1.6 atol = 1e-1 # (0*1 + 2*4)/(1+4) = 8/5 = 1.6 |
| 72 | + @test var(result3) ≈ 0.2 atol = 1e-1 # 1/(1+4) = 0.2 |
| 73 | + |
| 74 | + # Case 4: Test with a different prior |
| 75 | + m_in2 = NormalMeanVariance(1.0, 2.0) # Prior: mean=1, variance=2 (precision=0.5) |
| 76 | + m_ins2 = [m_in2] |
| 77 | + |
| 78 | + # Combining Normal(1,2) prior with Normal(2,1) likelihood: |
| 79 | + # Expected posterior: Normal(5/3, 2/3) |
| 80 | + result4 = ReactiveMPProjectionExt.optimize_parameters(1, pre_samples, m_ins2, log_fn2, method) |
| 81 | + |
| 82 | + @test result4 isa NormalMeanVariance |
| 83 | + @test mean(result4) ≈ 5 / 3 atol = 1e-1 # (1*0.5 + 2*1)/(0.5+1) = 1.67 |
| 84 | + @test var(result4) ≈ 2 / 3 atol = 1e-1 # 1/(0.5+1) = 0.67 |
| 85 | + end |
| 86 | +end |
| 87 | + |
| 88 | +@testitem "optimize_parameters: with specified form" begin |
| 89 | + using ExponentialFamily |
| 90 | + using ExponentialFamilyProjection |
| 91 | + using BayesBase |
| 92 | + using ReactiveMP |
| 93 | + using Distributions |
| 94 | + using Random |
| 95 | + using Manopt |
| 96 | + |
| 97 | + ReactiveMPProjectionExt = Base.get_extension(ReactiveMP, :ReactiveMPProjectionExt) |
| 98 | + @test !isnothing(ReactiveMPProjectionExt) |
| 99 | + using .ReactiveMPProjectionExt |
| 100 | + |
| 101 | + m_in = NormalMeanVariance(0.0, 1.0) |
| 102 | + m_ins = [m_in] |
| 103 | + pre_samples = [0.0, 0.5, -0.5] |
| 104 | + method = CVIProjection(in_prjparams = (in_1 = ProjectedTo(NormalMeanVariance),)) |
| 105 | + |
| 106 | + log_fn1 = (z, i, samples) -> -0.5 * z^2 |
| 107 | + result1 = ReactiveMPProjectionExt.optimize_parameters(1, pre_samples, m_ins, log_fn1, method) |
| 108 | + |
| 109 | + @test result1 isa NormalMeanVariance |
| 110 | + @test mean(result1) ≈ 0.0 atol = 1e-1 |
| 111 | + @test var(result1) ≈ 0.5 atol = 1e-1 |
| 112 | + |
| 113 | + m_in = Laplace(0.0, 1.0) |
| 114 | + m_ins = [m_in] |
| 115 | + pre_samples = [0.0, 0.5, -0.5] |
| 116 | + cost_recorder = Manopt.RecordCost() |
| 117 | + method = CVIProjection(in_prjparams = (in_1 = ProjectedTo(Laplace, conditioner = 1, kwargs = (record = [cost_recorder],)),)) |
| 118 | + |
| 119 | + log_fn1 = (z, i, samples) -> -0.5 * abs(z) |
| 120 | + result1 = ReactiveMPProjectionExt.optimize_parameters(1, pre_samples, m_ins, log_fn1, method) |
| 121 | + ef_result1 = convert(ExponentialFamilyDistribution, result1) |
| 122 | + |
| 123 | + @test getconditioner(ef_result1) ≈ 1.0 |
| 124 | + @test result1 isa Laplace |
| 125 | + @test cost_recorder.recorded_values[end] < cost_recorder.recorded_values[1] |
| 126 | +end |
0 commit comments