Skip to content

Commit eb33e6c

Browse files
authored
Merge pull request #206 from ReactiveBayes/dev_mvscalenormal
Add MvNormalMeanScalePrecision distribution
2 parents 95af252 + 24108bb commit eb33e6c

File tree

6 files changed

+499
-3
lines changed

6 files changed

+499
-3
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ TinyHugeNumbers = "783c9a47-75a3-44ac-a16b-f1ab7b3acf04"
2828

2929
[compat]
3030
Aqua = "0.8.7"
31-
BayesBase = "1.2"
31+
BayesBase = "1.5.0"
3232
Distributions = "0.25"
3333
DomainSets = "0.5.2, 0.6, 0.7"
3434
FastCholesky = "1.0"
@@ -57,8 +57,8 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
5757
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
5858
CpuId = "adafc99b-e345-5852-983c-f28acb93d879"
5959
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
60-
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
6160
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
61+
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
6262
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
6363
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
6464

docs/src/library.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ ExponentialFamily.NormalWeightedMeanPrecision
1414
ExponentialFamily.MvNormalMeanPrecision
1515
ExponentialFamily.MvNormalMeanCovariance
1616
ExponentialFamily.MvNormalWeightedMeanPrecision
17+
ExponentialFamily.MvNormalMeanScalePrecision
1718
ExponentialFamily.JointNormal
1819
ExponentialFamily.JointGaussian
1920
ExponentialFamily.WishartFast

src/ExponentialFamily.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ include("distributions/normal_family/mv_normal_mean_covariance.jl")
4646
include("distributions/normal_family/mv_normal_mean_precision.jl")
4747
include("distributions/normal_family/mv_normal_weighted_mean_precision.jl")
4848
include("distributions/normal_family/normal_family.jl")
49+
include("distributions/normal_family/mv_normal_mean_scale_precision.jl")
4950
include("distributions/gamma_inverse.jl")
5051
include("distributions/geometric.jl")
5152
include("distributions/matrix_dirichlet.jl")
Lines changed: 268 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,268 @@
1+
export MvNormalMeanScalePrecision, MvGaussianMeanScalePrecision
2+
3+
import Distributions: logdetcov, distrname, sqmahal, sqmahal!, AbstractMvNormal
4+
import LinearAlgebra: diag, Diagonal, dot
5+
import Base: ndims, precision, length, size, prod
6+
7+
"""
8+
MvNormalMeanScalePrecision{T <: Real, M <: AbstractVector{T}} <: AbstractMvNormal
9+
10+
A multivariate normal distribution with mean `μ` and scale parameter `γ` that scales the identity precision matrix.
11+
12+
# Type Parameters
13+
- `T`: The element type of the mean vector and scale parameter
14+
- `M`: The type of the mean vector, which must be a subtype of `AbstractVector{T}`
15+
16+
# Fields
17+
- `μ::M`: The mean vector of the multivariate normal distribution
18+
- `γ::T`: The scale parameter that scales the identity precision matrix
19+
20+
# Notes
21+
The precision matrix of this distribution is `γ * I`, where `I` is the identity matrix.
22+
The covariance matrix is the inverse of the precision matrix, i.e., `(1/γ) * I`.
23+
"""
24+
struct MvNormalMeanScalePrecision{T <: Real, M <: AbstractVector{T}} <: AbstractMvNormal
25+
μ::M
26+
γ::T
27+
end
28+
29+
const MvGaussianMeanScalePrecision = MvNormalMeanScalePrecision
30+
31+
function MvNormalMeanScalePrecision::AbstractVector{<:Real}, γ::Real)
32+
T = promote_type(eltype(μ), eltype(γ))
33+
return MvNormalMeanScalePrecision(convert(AbstractArray{T}, μ), convert(T, γ))
34+
end
35+
36+
function MvNormalMeanScalePrecision::AbstractVector{<:Integer}, γ::Real)
37+
return MvNormalMeanScalePrecision(float.(μ), float(γ))
38+
end
39+
40+
function MvNormalMeanScalePrecision::AbstractVector{T}) where {T}
41+
return MvNormalMeanScalePrecision(μ, convert(T, 1))
42+
end
43+
44+
function MvNormalMeanScalePrecision::AbstractVector{T1}, γ::T2) where {T1, T2}
45+
T = promote_type(T1, T2)
46+
μ_new = convert(AbstractArray{T}, μ)
47+
γ_new = convert(T, γ)(length(μ))
48+
return MvNormalMeanScalePrecision(μ_new, γ_new)
49+
end
50+
51+
function unpack_parameters(::Type{MvNormalMeanScalePrecision}, packed)
52+
p₁ = view(packed, 1:length(packed)-1)
53+
p₂ = packed[end]
54+
55+
return (p₁, p₂)
56+
end
57+
58+
function isproper(::NaturalParametersSpace, ::Type{MvNormalMeanScalePrecision}, η, conditioner)
59+
k = length(η) - 1
60+
if length(η) < 2 || (length(η) !== k + 1)
61+
return false
62+
end
63+
(η₁, η₂) = unpack_parameters(MvNormalMeanScalePrecision, η)
64+
return isnothing(conditioner) && isone(size(η₂, 1)) && isposdef(-η₂)
65+
end
66+
67+
function (::MeanToNatural{MvNormalMeanScalePrecision})(tuple_of_θ::Tuple{Any, Any})
68+
(μ, γ) = tuple_of_θ
69+
return* μ, - γ / 2)
70+
end
71+
72+
function (::NaturalToMean{MvNormalMeanScalePrecision})(tuple_of_η::Tuple{Any, Any})
73+
(η₁, η₂) = tuple_of_η
74+
γ = -2 * η₂
75+
return (η₁ / γ, γ)
76+
end
77+
78+
function nabs2(x)
79+
return sum(map(abs2, x))
80+
end
81+
82+
getsufficientstatistics(::Type{MvNormalMeanScalePrecision}) = (identity, nabs2)
83+
84+
# Conversions
85+
function Base.convert(
86+
::Type{MvNormal{T, C, M}},
87+
dist::MvNormalMeanScalePrecision
88+
) where {T <: Real, C <: Distributions.PDMats.PDMat{T, Matrix{T}}, M <: AbstractVector{T}}
89+
m, σ = mean(dist), std(dist)
90+
return MvNormal(convert(M, m), convert(T, σ))
91+
end
92+
93+
function Base.convert(
94+
::Type{MvNormalMeanScalePrecision{T, M}},
95+
dist::MvNormalMeanScalePrecision
96+
) where {T <: Real, M <: AbstractArray{T}}
97+
m, γ = mean(dist), dist.γ
98+
return MvNormalMeanScalePrecision{T, M}(convert(M, m), convert(T, γ))
99+
end
100+
101+
function Base.convert(
102+
::Type{MvNormalMeanScalePrecision{T}},
103+
dist::MvNormalMeanScalePrecision
104+
) where {T <: Real}
105+
return convert(MvNormalMeanScalePrecision{T, AbstractArray{T, 1}}, dist)
106+
end
107+
108+
function Base.convert(::Type{MvNormalMeanCovariance}, dist::MvNormalMeanScalePrecision)
109+
m, σ = mean(dist), cov(dist)
110+
return MvNormalMeanCovariance(m, σ * diagm(ones(length(m))))
111+
end
112+
113+
function Base.convert(::Type{MvNormalMeanPrecision}, dist::MvNormalMeanScalePrecision)
114+
m, γ = mean(dist), precision(dist)
115+
return MvNormalMeanPrecision(m, γ * diagm(ones(length(m))))
116+
end
117+
118+
function Base.convert(::Type{MvNormalWeightedMeanPrecision}, dist::MvNormalMeanScalePrecision)
119+
m, γ = mean(dist), precision(dist)
120+
return MvNormalWeightedMeanPrecision* m, γ * diagm(ones(length(m))))
121+
end
122+
123+
Distributions.distrname(::MvNormalMeanScalePrecision) = "MvNormalMeanScalePrecision"
124+
125+
BayesBase.weightedmean(dist::MvNormalMeanScalePrecision) = precision(dist) * mean(dist)
126+
127+
BayesBase.mean(dist::MvNormalMeanScalePrecision) = dist.μ
128+
BayesBase.mode(dist::MvNormalMeanScalePrecision) = mean(dist)
129+
BayesBase.var(dist::MvNormalMeanScalePrecision) = diag(cov(dist))
130+
BayesBase.cov(dist::MvNormalMeanScalePrecision) = cholinv(invcov(dist))
131+
BayesBase.invcov(dist::MvNormalMeanScalePrecision) = scale(dist) * I(length(mean(dist)))
132+
BayesBase.std(dist::MvNormalMeanScalePrecision) = cholsqrt(cov(dist))
133+
BayesBase.logdetcov(dist::MvNormalMeanScalePrecision) = -chollogdet(invcov(dist))
134+
BayesBase.scale(dist::MvNormalMeanScalePrecision) = dist.γ
135+
BayesBase.params(dist::MvNormalMeanScalePrecision) = (mean(dist), scale(dist))
136+
137+
function Distributions.sqmahal(dist::MvNormalMeanScalePrecision, x::AbstractVector)
138+
T = promote_type(eltype(x), paramfloattype(dist))
139+
return sqmahal!(similar(x, T), dist, x)
140+
end
141+
142+
function Distributions.sqmahal!(r, dist::MvNormalMeanScalePrecision, x::AbstractVector)
143+
μ, γ = params(dist)
144+
@inbounds @simd for i in 1:length(r)
145+
r[i] = μ[i] - x[i]
146+
end
147+
return dot3arg(r, γ, r) # x' * A * x
148+
end
149+
150+
Base.eltype(::MvNormalMeanScalePrecision{T}) where {T} = T
151+
Base.precision(dist::MvNormalMeanScalePrecision) = invcov(dist)
152+
Base.length(dist::MvNormalMeanScalePrecision) = length(mean(dist))
153+
Base.ndims(dist::MvNormalMeanScalePrecision) = length(dist)
154+
Base.size(dist::MvNormalMeanScalePrecision) = (length(dist),)
155+
156+
Base.convert(::Type{<:MvNormalMeanScalePrecision}, μ::AbstractVector, γ::Real) = MvNormalMeanScalePrecision(μ, γ)
157+
158+
function Base.convert(::Type{<:MvNormalMeanScalePrecision{T}}, μ::AbstractVector, γ::Real) where {T <: Real}
159+
MvNormalMeanScalePrecision(convert(AbstractArray{T}, μ), convert(T, γ))
160+
end
161+
162+
BayesBase.vague(::Type{<:MvNormalMeanScalePrecision}, dims::Int) =
163+
MvNormalMeanScalePrecision(zeros(Float64, dims), convert(Float64, tiny))
164+
165+
BayesBase.default_prod_rule(::Type{<:MvNormalMeanScalePrecision}, ::Type{<:MvNormalMeanScalePrecision}) = PreserveTypeProd(Distribution)
166+
167+
function BayesBase.prod(::PreserveTypeProd{Distribution}, left::MvNormalMeanScalePrecision, right::MvNormalMeanScalePrecision)
168+
w = scale(left) + scale(right)
169+
m = (scale(left) * mean(left) + scale(right) * mean(right)) / w
170+
return MvNormalMeanScalePrecision(m, w)
171+
end
172+
173+
BayesBase.default_prod_rule(::Type{<:MultivariateNormalDistributionsFamily}, ::Type{<:MvNormalMeanScalePrecision}) = PreserveTypeProd(Distribution)
174+
175+
function BayesBase.prod(
176+
::PreserveTypeProd{Distribution},
177+
left::L,
178+
right::R
179+
) where {L <: MultivariateNormalDistributionsFamily, R <: MvNormalMeanScalePrecision}
180+
wleft = convert(MvNormalWeightedMeanPrecision, left)
181+
wright = convert(MvNormalWeightedMeanPrecision, right)
182+
return prod(BayesBase.default_prod_rule(wleft, wright), wleft, wright)
183+
end
184+
185+
function BayesBase.rand(rng::AbstractRNG, dist::MvGaussianMeanScalePrecision{T}) where {T}
186+
μ, γ = params(dist)
187+
d = length(μ)
188+
return rand!(rng, dist, Vector{T}(undef, d))
189+
end
190+
191+
function BayesBase.rand(rng::AbstractRNG, dist::MvGaussianMeanScalePrecision{T}, size::Int64) where {T}
192+
container = Matrix{T}(undef, length(dist), size)
193+
return rand!(rng, dist, container)
194+
end
195+
196+
# FIXME: This is not the most efficient way to generate random samples within container
197+
# it needs to work with scale method, not with std
198+
function BayesBase.rand!(
199+
rng::AbstractRNG,
200+
dist::MvGaussianMeanScalePrecision,
201+
container::AbstractArray{T}
202+
) where {T <: Real}
203+
preallocated = similar(container)
204+
randn!(rng, reshape(preallocated, length(preallocated)))
205+
μ, L = mean_std(dist)
206+
@views for i in axes(preallocated, 2)
207+
copyto!(container[:, i], μ)
208+
mul!(container[:, i], L, preallocated[:, i], 1, 1)
209+
end
210+
container
211+
end
212+
213+
function getsupport(ef::ExponentialFamilyDistribution{MvNormalMeanScalePrecision})
214+
dim = length(getnaturalparameters(ef)) - 1
215+
return Domain(IndicatorFunction{AbstractVector}(MvNormalDomainIndicator(dim)))
216+
end
217+
218+
isbasemeasureconstant(::Type{MvNormalMeanScalePrecision}) = ConstantBaseMeasure()
219+
220+
getbasemeasure(::Type{MvNormalMeanScalePrecision}) = (x) -> (2π)^(-length(x) / 2)
221+
222+
getlogbasemeasure(::Type{MvNormalMeanScalePrecision}) = (x) -> -length(x) / 2 * log2π
223+
224+
getlogpartition(::NaturalParametersSpace, ::Type{MvNormalMeanScalePrecision}) =
225+
(η) -> begin
226+
η1 = @view η[1:end-1]
227+
η2 = η[end]
228+
k = length(η1)
229+
Cinv = inv(η2)
230+
return -dot(η1, 1/4*Cinv, η1) - (k / 2)*log(-2*η2)
231+
end
232+
233+
getgradlogpartition(::NaturalParametersSpace, ::Type{MvNormalMeanScalePrecision}) =
234+
(η) -> begin
235+
η1 = @view η[1:end-1]
236+
η2 = η[end]
237+
inv2 = inv(η2)
238+
k = length(η1)
239+
return pack_parameters(MvNormalMeanCovariance, (-1/(2*η2) * η1, dot(η1,η1) / 4*inv2^2 - k/2 * inv2))
240+
end
241+
242+
getfisherinformation(::NaturalParametersSpace, ::Type{MvNormalMeanScalePrecision}) =
243+
(η) -> begin
244+
η1 = @view η[1:end-1]
245+
η2 = η[end]
246+
k = length(η1)
247+
248+
η1_part = -inv(2*η2)* I(length(η1))
249+
η1η2 = zeros(k, 1)
250+
η1η2 .= η1*inv(2*η2^2)
251+
252+
η2_part = k*inv(2abs2(η2)) - dot(η1,η1) / (2*η2^3)
253+
254+
return ArrowheadMatrix(η2_part, η1η2, diag(η1_part))
255+
end
256+
257+
258+
getfisherinformation(::MeanParametersSpace, ::Type{MvNormalMeanScalePrecision}) =
259+
(θ) -> begin
260+
μ = @view θ[1:end-1]
261+
γ = θ[end]
262+
k = length(μ)
263+
264+
matrix = zeros(eltype(μ), (k+1))
265+
matrix[1:k] .= γ
266+
matrix[k+1] = k*inv(2abs2(γ))
267+
return Diagonal(matrix)
268+
end

test/distributions/distributions_setuptests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -557,4 +557,4 @@ function test_generic_simple_exponentialfamily_product(
557557
end
558558

559559
return true
560-
end
560+
end

0 commit comments

Comments
 (0)