Skip to content

Commit ae10936

Browse files
authored
Merge pull request #211 from ReactiveBayes/fix-rand-matrixDirichlet
fix-rand-matrixDirichlet
2 parents cee36d7 + 335551b commit ae10936

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

src/distributions/matrix_dirichlet.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,12 +74,9 @@ function BayesBase.rand(rng::AbstractRNG, dist::MatrixDirichlet{T}, nsamples::In
7474
end
7575

7676
function BayesBase.rand!(rng::AbstractRNG, dist::MatrixDirichlet, container::AbstractMatrix{T}) where {T <: Real}
77-
samples = vmap(d -> rand(rng, Dirichlet(convert(Vector, d))), eachcol(dist.a))
78-
@views for row in 1:isqrt(length(container))
79-
b = container[:, row]
80-
b[:] .= samples[row]
77+
@views for (i, col) in enumerate(eachcol(dist.a))
78+
rand!(rng, Dirichlet(col), container[:, i])
8179
end
82-
8380
return container
8481
end
8582

test/distributions/matrix_dirichlet_tests.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,3 +148,13 @@ end
148148
@test promote_variate_type(Multivariate, MatrixDirichlet) === Dirichlet
149149
@test promote_variate_type(Matrixvariate, MatrixDirichlet) === MatrixDirichlet
150150
end
151+
152+
@testitem "MatrixDirichlet: rand" begin
153+
include("distributions_setuptests.jl")
154+
155+
@test_throws DimensionMismatch sum(rand(MatrixDirichlet(ones(3, 5))), dims = 1) [1.0;; 1.0;; 1.0]
156+
157+
@test sum(rand(MatrixDirichlet(ones(3, 5))), dims = 1) [1.0;; 1.0;; 1.0;; 1.0;; 1.0]
158+
@test sum(rand(MatrixDirichlet(ones(5, 3))), dims = 1) [1.0;; 1.0;; 1.0]
159+
@test sum(rand(MatrixDirichlet(ones(5, 5))), dims = 1) [1.0;; 1.0;; 1.0;; 1.0;; 1.0]
160+
end

0 commit comments

Comments
 (0)