|
| 1 | +export MvNormalMeanScalePrecision, MvGaussianMeanScalePrecision |
| 2 | + |
| 3 | +import Distributions: logdetcov, distrname, sqmahal, sqmahal!, AbstractMvNormal |
| 4 | +import LinearAlgebra: diag, Diagonal, dot |
| 5 | +import Base: ndims, precision, length, size, prod |
| 6 | + |
| 7 | +""" |
| 8 | + MvNormalMeanScalePrecision{T <: Real, M <: AbstractVector{T}} <: AbstractMvNormal |
| 9 | +
|
| 10 | +A multivariate normal distribution with mean `μ` and scale parameter `γ` that scales the identity precision matrix. |
| 11 | +
|
| 12 | +# Type Parameters |
| 13 | +- `T`: The element type of the mean vector and scale parameter |
| 14 | +- `M`: The type of the mean vector, which must be a subtype of `AbstractVector{T}` |
| 15 | +
|
| 16 | +# Fields |
| 17 | +- `μ::M`: The mean vector of the multivariate normal distribution |
| 18 | +- `γ::T`: The scale parameter that scales the identity precision matrix |
| 19 | +
|
| 20 | +# Notes |
| 21 | +The precision matrix of this distribution is `γ * I`, where `I` is the identity matrix. |
| 22 | +The covariance matrix is the inverse of the precision matrix, i.e., `(1/γ) * I`. |
| 23 | +""" |
| 24 | +struct MvNormalMeanScalePrecision{T <: Real, M <: AbstractVector{T}} <: AbstractMvNormal |
| 25 | + μ::M |
| 26 | + γ::T |
| 27 | +end |
| 28 | + |
| 29 | +const MvGaussianMeanScalePrecision = MvNormalMeanScalePrecision |
| 30 | + |
| 31 | +function MvNormalMeanScalePrecision(μ::AbstractVector{<:Real}, γ::Real) |
| 32 | + T = promote_type(eltype(μ), eltype(γ)) |
| 33 | + return MvNormalMeanScalePrecision(convert(AbstractArray{T}, μ), convert(T, γ)) |
| 34 | +end |
| 35 | + |
| 36 | +function MvNormalMeanScalePrecision(μ::AbstractVector{<:Integer}, γ::Real) |
| 37 | + return MvNormalMeanScalePrecision(float.(μ), float(γ)) |
| 38 | +end |
| 39 | + |
| 40 | +function MvNormalMeanScalePrecision(μ::AbstractVector{T}) where {T} |
| 41 | + return MvNormalMeanScalePrecision(μ, convert(T, 1)) |
| 42 | +end |
| 43 | + |
| 44 | +function MvNormalMeanScalePrecision(μ::AbstractVector{T1}, γ::T2) where {T1, T2} |
| 45 | + T = promote_type(T1, T2) |
| 46 | + μ_new = convert(AbstractArray{T}, μ) |
| 47 | + γ_new = convert(T, γ)(length(μ)) |
| 48 | + return MvNormalMeanScalePrecision(μ_new, γ_new) |
| 49 | +end |
| 50 | + |
| 51 | +function unpack_parameters(::Type{MvNormalMeanScalePrecision}, packed) |
| 52 | + p₁ = view(packed, 1:length(packed)-1) |
| 53 | + p₂ = packed[end] |
| 54 | + |
| 55 | + return (p₁, p₂) |
| 56 | +end |
| 57 | + |
| 58 | +function isproper(::NaturalParametersSpace, ::Type{MvNormalMeanScalePrecision}, η, conditioner) |
| 59 | + k = length(η) - 1 |
| 60 | + if length(η) < 2 || (length(η) !== k + 1) |
| 61 | + return false |
| 62 | + end |
| 63 | + (η₁, η₂) = unpack_parameters(MvNormalMeanScalePrecision, η) |
| 64 | + return isnothing(conditioner) && isone(size(η₂, 1)) && isposdef(-η₂) |
| 65 | +end |
| 66 | + |
| 67 | +function (::MeanToNatural{MvNormalMeanScalePrecision})(tuple_of_θ::Tuple{Any, Any}) |
| 68 | + (μ, γ) = tuple_of_θ |
| 69 | + return (γ * μ, - γ / 2) |
| 70 | +end |
| 71 | + |
| 72 | +function (::NaturalToMean{MvNormalMeanScalePrecision})(tuple_of_η::Tuple{Any, Any}) |
| 73 | + (η₁, η₂) = tuple_of_η |
| 74 | + γ = -2 * η₂ |
| 75 | + return (η₁ / γ, γ) |
| 76 | +end |
| 77 | + |
| 78 | +function nabs2(x) |
| 79 | + return sum(map(abs2, x)) |
| 80 | +end |
| 81 | + |
| 82 | +getsufficientstatistics(::Type{MvNormalMeanScalePrecision}) = (identity, nabs2) |
| 83 | + |
| 84 | +# Conversions |
| 85 | +function Base.convert( |
| 86 | + ::Type{MvNormal{T, C, M}}, |
| 87 | + dist::MvNormalMeanScalePrecision |
| 88 | +) where {T <: Real, C <: Distributions.PDMats.PDMat{T, Matrix{T}}, M <: AbstractVector{T}} |
| 89 | + m, σ = mean(dist), std(dist) |
| 90 | + return MvNormal(convert(M, m), convert(T, σ)) |
| 91 | +end |
| 92 | + |
| 93 | +function Base.convert( |
| 94 | + ::Type{MvNormalMeanScalePrecision{T, M}}, |
| 95 | + dist::MvNormalMeanScalePrecision |
| 96 | +) where {T <: Real, M <: AbstractArray{T}} |
| 97 | + m, γ = mean(dist), dist.γ |
| 98 | + return MvNormalMeanScalePrecision{T, M}(convert(M, m), convert(T, γ)) |
| 99 | +end |
| 100 | + |
| 101 | +function Base.convert( |
| 102 | + ::Type{MvNormalMeanScalePrecision{T}}, |
| 103 | + dist::MvNormalMeanScalePrecision |
| 104 | +) where {T <: Real} |
| 105 | + return convert(MvNormalMeanScalePrecision{T, AbstractArray{T, 1}}, dist) |
| 106 | +end |
| 107 | + |
| 108 | +function Base.convert(::Type{MvNormalMeanCovariance}, dist::MvNormalMeanScalePrecision) |
| 109 | + m, σ = mean(dist), cov(dist) |
| 110 | + return MvNormalMeanCovariance(m, σ * diagm(ones(length(m)))) |
| 111 | +end |
| 112 | + |
| 113 | +function Base.convert(::Type{MvNormalMeanPrecision}, dist::MvNormalMeanScalePrecision) |
| 114 | + m, γ = mean(dist), precision(dist) |
| 115 | + return MvNormalMeanPrecision(m, γ * diagm(ones(length(m)))) |
| 116 | +end |
| 117 | + |
| 118 | +function Base.convert(::Type{MvNormalWeightedMeanPrecision}, dist::MvNormalMeanScalePrecision) |
| 119 | + m, γ = mean(dist), precision(dist) |
| 120 | + return MvNormalWeightedMeanPrecision(γ * m, γ * diagm(ones(length(m)))) |
| 121 | +end |
| 122 | + |
| 123 | +Distributions.distrname(::MvNormalMeanScalePrecision) = "MvNormalMeanScalePrecision" |
| 124 | + |
| 125 | +BayesBase.weightedmean(dist::MvNormalMeanScalePrecision) = precision(dist) * mean(dist) |
| 126 | + |
| 127 | +BayesBase.mean(dist::MvNormalMeanScalePrecision) = dist.μ |
| 128 | +BayesBase.mode(dist::MvNormalMeanScalePrecision) = mean(dist) |
| 129 | +BayesBase.var(dist::MvNormalMeanScalePrecision) = diag(cov(dist)) |
| 130 | +BayesBase.cov(dist::MvNormalMeanScalePrecision) = cholinv(invcov(dist)) |
| 131 | +BayesBase.invcov(dist::MvNormalMeanScalePrecision) = scale(dist) * I(length(mean(dist))) |
| 132 | +BayesBase.std(dist::MvNormalMeanScalePrecision) = cholsqrt(cov(dist)) |
| 133 | +BayesBase.logdetcov(dist::MvNormalMeanScalePrecision) = -chollogdet(invcov(dist)) |
| 134 | +BayesBase.scale(dist::MvNormalMeanScalePrecision) = dist.γ |
| 135 | +BayesBase.params(dist::MvNormalMeanScalePrecision) = (mean(dist), scale(dist)) |
| 136 | + |
| 137 | +function Distributions.sqmahal(dist::MvNormalMeanScalePrecision, x::AbstractVector) |
| 138 | + T = promote_type(eltype(x), paramfloattype(dist)) |
| 139 | + return sqmahal!(similar(x, T), dist, x) |
| 140 | +end |
| 141 | + |
| 142 | +function Distributions.sqmahal!(r, dist::MvNormalMeanScalePrecision, x::AbstractVector) |
| 143 | + μ, γ = params(dist) |
| 144 | + @inbounds @simd for i in 1:length(r) |
| 145 | + r[i] = μ[i] - x[i] |
| 146 | + end |
| 147 | + return dot3arg(r, γ, r) # x' * A * x |
| 148 | +end |
| 149 | + |
| 150 | +Base.eltype(::MvNormalMeanScalePrecision{T}) where {T} = T |
| 151 | +Base.precision(dist::MvNormalMeanScalePrecision) = invcov(dist) |
| 152 | +Base.length(dist::MvNormalMeanScalePrecision) = length(mean(dist)) |
| 153 | +Base.ndims(dist::MvNormalMeanScalePrecision) = length(dist) |
| 154 | +Base.size(dist::MvNormalMeanScalePrecision) = (length(dist),) |
| 155 | + |
| 156 | +Base.convert(::Type{<:MvNormalMeanScalePrecision}, μ::AbstractVector, γ::Real) = MvNormalMeanScalePrecision(μ, γ) |
| 157 | + |
| 158 | +function Base.convert(::Type{<:MvNormalMeanScalePrecision{T}}, μ::AbstractVector, γ::Real) where {T <: Real} |
| 159 | + MvNormalMeanScalePrecision(convert(AbstractArray{T}, μ), convert(T, γ)) |
| 160 | +end |
| 161 | + |
| 162 | +BayesBase.vague(::Type{<:MvNormalMeanScalePrecision}, dims::Int) = |
| 163 | + MvNormalMeanScalePrecision(zeros(Float64, dims), convert(Float64, tiny)) |
| 164 | + |
| 165 | +BayesBase.default_prod_rule(::Type{<:MvNormalMeanScalePrecision}, ::Type{<:MvNormalMeanScalePrecision}) = PreserveTypeProd(Distribution) |
| 166 | + |
| 167 | +function BayesBase.prod(::PreserveTypeProd{Distribution}, left::MvNormalMeanScalePrecision, right::MvNormalMeanScalePrecision) |
| 168 | + w = scale(left) + scale(right) |
| 169 | + m = (scale(left) * mean(left) + scale(right) * mean(right)) / w |
| 170 | + return MvNormalMeanScalePrecision(m, w) |
| 171 | +end |
| 172 | + |
| 173 | +BayesBase.default_prod_rule(::Type{<:MultivariateNormalDistributionsFamily}, ::Type{<:MvNormalMeanScalePrecision}) = PreserveTypeProd(Distribution) |
| 174 | + |
| 175 | +function BayesBase.prod( |
| 176 | + ::PreserveTypeProd{Distribution}, |
| 177 | + left::L, |
| 178 | + right::R |
| 179 | +) where {L <: MultivariateNormalDistributionsFamily, R <: MvNormalMeanScalePrecision} |
| 180 | + wleft = convert(MvNormalWeightedMeanPrecision, left) |
| 181 | + wright = convert(MvNormalWeightedMeanPrecision, right) |
| 182 | + return prod(BayesBase.default_prod_rule(wleft, wright), wleft, wright) |
| 183 | +end |
| 184 | + |
| 185 | +function BayesBase.rand(rng::AbstractRNG, dist::MvGaussianMeanScalePrecision{T}) where {T} |
| 186 | + μ, γ = params(dist) |
| 187 | + d = length(μ) |
| 188 | + return rand!(rng, dist, Vector{T}(undef, d)) |
| 189 | +end |
| 190 | + |
| 191 | +function BayesBase.rand(rng::AbstractRNG, dist::MvGaussianMeanScalePrecision{T}, size::Int64) where {T} |
| 192 | + container = Matrix{T}(undef, length(dist), size) |
| 193 | + return rand!(rng, dist, container) |
| 194 | +end |
| 195 | + |
| 196 | +# FIXME: This is not the most efficient way to generate random samples within container |
| 197 | +# it needs to work with scale method, not with std |
| 198 | +function BayesBase.rand!( |
| 199 | + rng::AbstractRNG, |
| 200 | + dist::MvGaussianMeanScalePrecision, |
| 201 | + container::AbstractArray{T} |
| 202 | +) where {T <: Real} |
| 203 | + preallocated = similar(container) |
| 204 | + randn!(rng, reshape(preallocated, length(preallocated))) |
| 205 | + μ, L = mean_std(dist) |
| 206 | + @views for i in axes(preallocated, 2) |
| 207 | + copyto!(container[:, i], μ) |
| 208 | + mul!(container[:, i], L, preallocated[:, i], 1, 1) |
| 209 | + end |
| 210 | + container |
| 211 | +end |
| 212 | + |
| 213 | +function getsupport(ef::ExponentialFamilyDistribution{MvNormalMeanScalePrecision}) |
| 214 | + dim = length(getnaturalparameters(ef)) - 1 |
| 215 | + return Domain(IndicatorFunction{AbstractVector}(MvNormalDomainIndicator(dim))) |
| 216 | +end |
| 217 | + |
| 218 | +isbasemeasureconstant(::Type{MvNormalMeanScalePrecision}) = ConstantBaseMeasure() |
| 219 | + |
| 220 | +getbasemeasure(::Type{MvNormalMeanScalePrecision}) = (x) -> (2π)^(-length(x) / 2) |
| 221 | + |
| 222 | +getlogbasemeasure(::Type{MvNormalMeanScalePrecision}) = (x) -> -length(x) / 2 * log2π |
| 223 | + |
| 224 | +getlogpartition(::NaturalParametersSpace, ::Type{MvNormalMeanScalePrecision}) = |
| 225 | + (η) -> begin |
| 226 | + η1 = @view η[1:end-1] |
| 227 | + η2 = η[end] |
| 228 | + k = length(η1) |
| 229 | + Cinv = inv(η2) |
| 230 | + return -dot(η1, 1/4*Cinv, η1) - (k / 2)*log(-2*η2) |
| 231 | + end |
| 232 | + |
| 233 | +getgradlogpartition(::NaturalParametersSpace, ::Type{MvNormalMeanScalePrecision}) = |
| 234 | + (η) -> begin |
| 235 | + η1 = @view η[1:end-1] |
| 236 | + η2 = η[end] |
| 237 | + inv2 = inv(η2) |
| 238 | + k = length(η1) |
| 239 | + return pack_parameters(MvNormalMeanCovariance, (-1/(2*η2) * η1, dot(η1,η1) / 4*inv2^2 - k/2 * inv2)) |
| 240 | + end |
| 241 | + |
| 242 | +getfisherinformation(::NaturalParametersSpace, ::Type{MvNormalMeanScalePrecision}) = |
| 243 | + (η) -> begin |
| 244 | + η1 = @view η[1:end-1] |
| 245 | + η2 = η[end] |
| 246 | + k = length(η1) |
| 247 | + |
| 248 | + η1_part = -inv(2*η2)* I(length(η1)) |
| 249 | + η1η2 = zeros(k, 1) |
| 250 | + η1η2 .= η1*inv(2*η2^2) |
| 251 | + |
| 252 | + η2_part = k*inv(2abs2(η2)) - dot(η1,η1) / (2*η2^3) |
| 253 | + |
| 254 | + return ArrowheadMatrix(η2_part, η1η2, diag(η1_part)) |
| 255 | + end |
| 256 | + |
| 257 | + |
| 258 | +getfisherinformation(::MeanParametersSpace, ::Type{MvNormalMeanScalePrecision}) = |
| 259 | + (θ) -> begin |
| 260 | + μ = @view θ[1:end-1] |
| 261 | + γ = θ[end] |
| 262 | + k = length(μ) |
| 263 | + |
| 264 | + matrix = zeros(eltype(μ), (k+1)) |
| 265 | + matrix[1:k] .= γ |
| 266 | + matrix[k+1] = k*inv(2abs2(γ)) |
| 267 | + return Diagonal(matrix) |
| 268 | + end |
0 commit comments