Skip to content

Commit 9f06eb4

Browse files
authored
Merge pull request #226 from ReactiveBayes/logmean
Implement log mean for tensordirichlet
2 parents 7a5d361 + a5c0164 commit 9f06eb4

File tree

2 files changed

+24
-0
lines changed

2 files changed

+24
-0
lines changed

src/distributions/tensor_dirichlet.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@ getlogbasemeasure(::Type{TensorDirichlet}, conditioner) = (x) -> zero(Float64)
7575
getsufficientstatistics(::Type{TensorDirichlet}, conditioner) = (x -> vmap(log, x),)
7676

7777
BayesBase.mean(dist::TensorDirichlet) = dist.a ./ dist.α0
78+
BayesBase.mean(::BroadcastFunction{typeof(log)}, dist::TensorDirichlet) = digamma.(dist.a) .- digamma.(dist.α0)
79+
7880
function BayesBase.cov(dist::TensorDirichlet{T}) where {T}
7981
s = size(dist.a)
8082
news = (first(s), first(s), Base.tail(s)...)

test/distributions/tensor_dirichlet_test.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,28 @@ end
7373
end
7474
end
7575

76+
@testitem "TensorDirichlet: logmean" begin
77+
include("distributions_setuptests.jl")
78+
79+
for rank in (3, 5)
80+
for d in (2, 5, 10)
81+
for _ in 1:10
82+
alpha = rand([d for _ in 1:rank]...)
83+
84+
distribution = TensorDirichlet(alpha)
85+
mat_of_dir = Dirichlet.(eachslice(alpha, dims = Tuple(2:rank)))
86+
87+
temp = mean.(Base.Broadcast.BroadcastFunction(log), mat_of_dir)
88+
mat_mean = similar(alpha)
89+
for i in CartesianIndices(Base.tail(size(alpha)))
90+
mat_mean[:, i] = temp[i]
91+
end
92+
@test mean(Base.Broadcast.BroadcastFunction(log), distribution) mat_mean
93+
end
94+
end
95+
end
96+
end
97+
7698
@testitem "TensorDirichlet: std" begin
7799
include("distributions_setuptests.jl")
78100

0 commit comments

Comments
 (0)