Skip to content

Commit becf4ff

Browse files
authored
Merge pull request #248 from skoghoern/feature/logscale_normal_normal
Optimize compute_logscale for Normal×Normal distributions
2 parents ecfea82 + fbcf951 commit becf4ff

File tree

7 files changed

+144
-132
lines changed

7 files changed

+144
-132
lines changed

benchmark/benchmarks/normal_family/mv_normal_mean_covariance.jl

Lines changed: 32 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,72 +1,74 @@
11
using LinearAlgebra
22
using StaticArrays
33
using ExponentialFamily.BayesBase
4+
using StableRNGs
45

56
SUITE["mvnormal_mean_covariance"] = BenchmarkGroup(
67
["mvnormal_mean_covariance", "normal_family", "distribution"],
78
"prod" => BenchmarkGroup(["prod", "multiplication"])
89
)
910

1011
# Helpers ==========================
11-
spd_matrix(T::Type, d::Integer) = begin
12-
A = rand(T, d, d)
12+
spd_matrix(rng, ::Type{T}, d::Integer) where {T <: Real} = begin
13+
A = rand(rng, T, d, d)
1314
# Make symmetric positive definite and well-conditioned enough for Float16
14-
Σ = A' * A
15+
Σ = A' * A .+ diagm(ones(T, d))
1516
Matrix{T}(Symmetric(Σ))
1617
end
1718

18-
dense_dist(::Type{T}, d::Integer) where {T <: Real} =
19-
MvNormalMeanCovariance(rand(T, d), spd_matrix(T, d))
19+
dense_dist(rng, ::Type{T}, d::Integer) where {T <: Real} =
20+
MvNormalMeanCovariance(rand(rng, T, d), spd_matrix(rng, T, d))
2021

21-
diag_dist(::Type{T}, d::Integer) where {T <: Real} = begin
22-
μ = rand(T, d)
23-
σ = abs.(rand(T, d)) .+ one(T)
22+
diag_dist(rng, ::Type{T}, d::Integer) where {T <: Real} = begin
23+
μ = rand(rng, T, d)
24+
σ = abs.(rand(rng, T, d)) .+ one(T)
2425
MvNormalMeanCovariance(μ, Diagonal(σ))
2526
end
2627

27-
static_dist(::Type{T}, ::Val{D}) where {T <: Real, D} = begin
28-
μ = SVector{D, T}(rand(T, D))
29-
A = SMatrix{D, D, T}(rand(T, D, D))
30-
Σ = A' * A
28+
static_dist(rng, ::Type{T}, ::Val{D}) where {T <: Real, D} = begin
29+
μ = SVector{D, T}(rand(rng, T, D))
30+
A = SMatrix{D, D, T}(rand(rng, T, D, D))
31+
Σ = A' * A .+ diagm(ones(T, D))
3132
MvNormalMeanCovariance(μ, Σ)
3233
end
3334

3435
# prod (PreserveType) ==============
3536
let dims_dense = (10, 50, 100)
3637
# Dense × Dense (BLAS path for Float32/64)
38+
rng = StableRNG(42)
3739
for d in dims_dense
3840
for (TL, TR) in ((Float64, Float64), (Float32, Float32), (Float32, Float64))
39-
left = dense_dist(TL, d)
40-
right = dense_dist(TR, d)
41+
left = dense_dist(rng, TL, d)
42+
right = dense_dist(rng, TR, d)
4143
SUITE["mvnormal_mean_covariance"]["prod"]["PreserveType"]["Dense×Dense"]["d=$d"]["$(TL)×$(TR)"] =
4244
@benchmarkable prod(PreserveTypeProd(Distribution), $left, $right)
4345
end
4446
# Mixed lower precision (generic path)
4547
let TL = Float16, TR = Float64
46-
left = dense_dist(TL, d)
47-
right = dense_dist(TR, d)
48+
left = dense_dist(rng, TL, d)
49+
right = dense_dist(rng, TR, d)
4850
SUITE["mvnormal_mean_covariance"]["prod"]["PreserveType"]["Dense×Dense"]["d=$d"]["Float16×Float64"] =
4951
@benchmarkable prod(PreserveTypeProd(Distribution), $left, $right)
5052
end
5153
end
5254

5355
# Dense × Diagonal
5456
for d in dims_dense
55-
left = dense_dist(Float64, d)
56-
right = diag_dist(Float64, d)
57+
left = dense_dist(rng, Float64, d)
58+
right = diag_dist(rng, Float64, d)
5759
SUITE["mvnormal_mean_covariance"]["prod"]["PreserveType"]["Dense×Diag"]["d=$d"]["Float64×Float64"] =
5860
@benchmarkable prod(PreserveTypeProd(Distribution), $left, $right)
5961

60-
left = dense_dist(Float32, d)
61-
right = diag_dist(Float64, d)
62+
left = dense_dist(rng, Float32, d)
63+
right = diag_dist(rng, Float64, d)
6264
SUITE["mvnormal_mean_covariance"]["prod"]["PreserveType"]["Dense×Diag"]["d=$d"]["Float32×Float64"] =
6365
@benchmarkable prod(PreserveTypeProd(Distribution), $left, $right)
6466
end
6567

6668
# Diagonal × Diagonal
6769
for d in (10, 50, 100)
68-
left = diag_dist(Float64, d)
69-
right = diag_dist(Float64, d)
70+
left = diag_dist(rng, Float64, d)
71+
right = diag_dist(rng, Float64, d)
7072
SUITE["mvnormal_mean_covariance"]["prod"]["PreserveType"]["Diag×Diag"]["d=$d"]["Float64×Float64"] =
7173
@benchmarkable prod(PreserveTypeProd(Distribution), $left, $right)
7274
end
@@ -76,20 +78,20 @@ let dims_dense = (10, 50, 100)
7678
for TL in (Float64, Float32, Float16)
7779
for TR in (Float64, Float32, Float16)
7880
# Static × Static
79-
left = static_dist(TL, Val(D))
80-
right = static_dist(TR, Val(D))
81+
left = static_dist(rng, TL, Val(D))
82+
right = static_dist(rng, TR, Val(D))
8183
SUITE["mvnormal_mean_covariance"]["prod"]["PreserveType"]["SArray×SArray"]["d=$D"]["$(TL)×$(TR)"] =
8284
@benchmarkable prod(PreserveTypeProd(Distribution), $left, $right)
8385

8486
# Static × Dense
85-
left = static_dist(TL, Val(D))
86-
right = dense_dist(TR, D)
87+
left = static_dist(rng, TL, Val(D))
88+
right = dense_dist(rng, TR, D)
8789
SUITE["mvnormal_mean_covariance"]["prod"]["PreserveType"]["SArray×Dense"]["d=$D"]["$(TL)×$(TR)"] =
8890
@benchmarkable prod(PreserveTypeProd(Distribution), $left, $right)
8991

9092
# Static × Diagonal
91-
left = static_dist(TL, Val(D))
92-
right = diag_dist(TR, D)
93+
left = static_dist(rng, TL, Val(D))
94+
right = diag_dist(rng, TR, D)
9395
SUITE["mvnormal_mean_covariance"]["prod"]["PreserveType"]["SArray×Diag"]["d=$D"]["$(TL)×$(TR)"] =
9496
@benchmarkable prod(PreserveTypeProd(Distribution), $left, $right)
9597
end
@@ -101,7 +103,8 @@ end
101103

102104
for dims in (10, 50, 100)
103105
for T in (Float64, Float32, Float16)
104-
dist = dense_dist(T, dims)
106+
rng = StableRNG(42)
107+
dist = dense_dist(rng, T, dims)
105108
SUITE["mvnormal_mean_covariance"]["compute_logscale"]["d=$dims"]["$(T)"] =
106109
@benchmarkable compute_logscale($dist, $dist, $dist)
107110
end

benchmark/benchmarks/normal_family/mv_normal_mean_precision.jl

Lines changed: 32 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,71 +1,73 @@
11
using LinearAlgebra
22
using StaticArrays
3+
using StableRNGs
34

45
SUITE["mvnormal_mean_precision"] = BenchmarkGroup(
56
["mvnormal_mean_precision", "normal_family", "distribution"],
67
"prod" => BenchmarkGroup(["prod", "multiplication"])
78
)
89

910
# Helpers ==========================
10-
spd_matrix(T::Type, d::Integer) = begin
11-
A = rand(T, d, d)
11+
spd_matrix(rng, ::Type{T}, d::Integer) where {T <: Real} = begin
12+
A = rand(rng, T, d, d)
1213
# Make symmetric positive definite and well-conditioned enough for Float16
13-
Σ = A' * A
14+
Σ = A' * A .+ diagm(ones(T, d))
1415
Matrix{T}(Symmetric(Σ))
1516
end
1617

17-
dense_dist(::Type{T}, d::Integer) where {T <: Real} =
18-
MvNormalMeanPrecision(rand(T, d), spd_matrix(T, d))
18+
dense_dist(rng, ::Type{T}, d::Integer) where {T <: Real} =
19+
MvNormalMeanPrecision(rand(rng, T, d), spd_matrix(rng, T, d))
1920

20-
diag_dist(::Type{T}, d::Integer) where {T <: Real} = begin
21-
μ = rand(T, d)
22-
σ = abs.(rand(T, d)) .+ one(T)
21+
diag_dist(rng, ::Type{T}, d::Integer) where {T <: Real} = begin
22+
μ = rand(rng, T, d)
23+
σ = abs.(rand(rng, T, d)) .+ one(T)
2324
MvNormalMeanPrecision(μ, Diagonal(σ))
2425
end
2526

26-
static_dist(::Type{T}, ::Val{D}) where {T <: Real, D} = begin
27-
μ = SVector{D, T}(rand(T, D))
28-
A = SMatrix{D, D, T}(rand(T, D, D))
29-
Σ = A' * A
27+
static_dist(rng, ::Type{T}, ::Val{D}) where {T <: Real, D} = begin
28+
μ = SVector{D, T}(rand(rng, T, D))
29+
A = SMatrix{D, D, T}(rand(rng, T, D, D))
30+
Σ = A' * A .+ diagm(ones(T, D))
3031
MvNormalMeanPrecision(μ, Σ)
3132
end
3233

3334
# prod (PreserveType) ==============
3435
let dims_dense = (10, 50, 100)
3536
# Dense × Dense (BLAS path for Float32/64)
37+
rng = StableRNG(42)
3638
for d in dims_dense
3739
for (TL, TR) in ((Float64, Float64), (Float32, Float32), (Float32, Float64))
38-
left = dense_dist(TL, d)
39-
right = dense_dist(TR, d)
40+
left = dense_dist(rng, TL, d)
41+
right = dense_dist(rng, TR, d)
4042
SUITE["mvnormal_mean_precision"]["prod"]["PreserveType"]["Dense×Dense"]["d=$d"]["$(TL)×$(TR)"] =
4143
@benchmarkable prod(PreserveTypeProd(Distribution), $left, $right)
4244
end
4345
# Mixed lower precision (generic path)
4446
let TL = Float16, TR = Float64
45-
left = dense_dist(TL, d)
46-
right = dense_dist(TR, d)
47+
left = dense_dist(rng, TL, d)
48+
right = dense_dist(rng, TR, d)
4749
SUITE["mvnormal_mean_precision"]["prod"]["PreserveType"]["Dense×Dense"]["d=$d"]["Float16×Float64"] =
4850
@benchmarkable prod(PreserveTypeProd(Distribution), $left, $right)
4951
end
5052
end
5153

5254
# Dense × Diagonal
5355
for d in dims_dense
54-
left = dense_dist(Float64, d)
55-
right = diag_dist(Float64, d)
56+
left = dense_dist(rng, Float64, d)
57+
right = diag_dist(rng, Float64, d)
5658
SUITE["mvnormal_mean_precision"]["prod"]["PreserveType"]["Dense×Diag"]["d=$d"]["Float64×Float64"] =
5759
@benchmarkable prod(PreserveTypeProd(Distribution), $left, $right)
5860

59-
left = dense_dist(Float32, d)
60-
right = diag_dist(Float64, d)
61+
left = dense_dist(rng, Float32, d)
62+
right = diag_dist(rng, Float64, d)
6163
SUITE["mvnormal_mean_precision"]["prod"]["PreserveType"]["Dense×Diag"]["d=$d"]["Float32×Float64"] =
6264
@benchmarkable prod(PreserveTypeProd(Distribution), $left, $right)
6365
end
6466

6567
# Diagonal × Diagonal
6668
for d in (10, 50, 100)
67-
left = diag_dist(Float64, d)
68-
right = diag_dist(Float64, d)
69+
left = diag_dist(rng, Float64, d)
70+
right = diag_dist(rng, Float64, d)
6971
SUITE["mvnormal_mean_precision"]["prod"]["PreserveType"]["Diag×Diag"]["d=$d"]["Float64×Float64"] =
7072
@benchmarkable prod(PreserveTypeProd(Distribution), $left, $right)
7173
end
@@ -75,20 +77,20 @@ let dims_dense = (10, 50, 100)
7577
for TL in (Float64, Float32, Float16)
7678
for TR in (Float64, Float32, Float16)
7779
# Static × Static
78-
left = static_dist(TL, Val(D))
79-
right = static_dist(TR, Val(D))
80+
left = static_dist(rng, TL, Val(D))
81+
right = static_dist(rng, TR, Val(D))
8082
SUITE["mvnormal_mean_precision"]["prod"]["PreserveType"]["SArray×SArray"]["d=$D"]["$(TL)×$(TR)"] =
8183
@benchmarkable prod(PreserveTypeProd(Distribution), $left, $right)
8284

8385
# Static × Dense
84-
left = static_dist(TL, Val(D))
85-
right = dense_dist(TR, D)
86+
left = static_dist(rng, TL, Val(D))
87+
right = dense_dist(rng, TR, D)
8688
SUITE["mvnormal_mean_precision"]["prod"]["PreserveType"]["SArray×Dense"]["d=$D"]["$(TL)×$(TR)"] =
8789
@benchmarkable prod(PreserveTypeProd(Distribution), $left, $right)
8890

8991
# Static × Diagonal
90-
left = static_dist(TL, Val(D))
91-
right = diag_dist(TR, D)
92+
left = static_dist(rng, TL, Val(D))
93+
right = diag_dist(rng, TR, D)
9294
SUITE["mvnormal_mean_precision"]["prod"]["PreserveType"]["SArray×Diag"]["d=$D"]["$(TL)×$(TR)"] =
9395
@benchmarkable prod(PreserveTypeProd(Distribution), $left, $right)
9496
end
@@ -100,7 +102,8 @@ end
100102

101103
for dims in (10, 50, 100)
102104
for T in (Float64, Float32)
103-
dist = dense_dist(T, dims)
105+
rng = StableRNG(42)
106+
dist = dense_dist(rng, T, dims)
104107
SUITE["mvnormal_mean_precision"]["compute_logscale"]["d=$dims"]["$(T)"] =
105108
@benchmarkable compute_logscale($dist, $dist, $dist)
106109
end

benchmark/benchmarks/normal_family/mv_normal_weighted_mean_precision.jl

Lines changed: 31 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -7,65 +7,66 @@ SUITE["mvnormal_weighted_mean_precision"] = BenchmarkGroup(
77
)
88

99
# Helpers ==========================
10-
spd_matrix(T::Type, d::Integer) = begin
11-
A = rand(T, d, d)
10+
spd_matrix(rng, ::Type{T}, d::Integer) where {T <: Real} = begin
11+
A = rand(rng, T, d, d)
1212
# Make symmetric positive definite and well-conditioned enough for Float16
13-
Σ = A' * A
13+
Σ = A' * A .+ diagm(ones(T, d))
1414
Matrix{T}(Symmetric(Σ))
1515
end
1616

17-
dense_dist(::Type{T}, d::Integer) where {T <: Real} =
18-
MvNormalWeightedMeanPrecision(rand(T, d), spd_matrix(T, d))
17+
dense_dist(rng, ::Type{T}, d::Integer) where {T <: Real} =
18+
MvNormalWeightedMeanPrecision(rand(rng, T, d), spd_matrix(rng, T, d))
1919

20-
diag_dist(::Type{T}, d::Integer) where {T <: Real} = begin
21-
μ = rand(T, d)
22-
σ = abs.(rand(T, d)) .+ one(T)
20+
diag_dist(rng, ::Type{T}, d::Integer) where {T <: Real} = begin
21+
μ = rand(rng, T, d)
22+
σ = abs.(rand(rng, T, d)) .+ one(T)
2323
MvNormalMeanCovariance(μ, Diagonal(σ))
2424
end
2525

26-
static_dist(::Type{T}, ::Val{D}) where {T <: Real, D} = begin
27-
μ = SVector{D, T}(rand(T, D))
28-
A = SMatrix{D, D, T}(rand(T, D, D))
29-
Σ = A' * A
26+
static_dist(rng, ::Type{T}, ::Val{D}) where {T <: Real, D} = begin
27+
μ = SVector{D, T}(rand(rng, T, D))
28+
A = SMatrix{D, D, T}(rand(rng, T, D, D))
29+
Σ = A' * A .+ diagm(ones(T, D))
3030
MvNormalMeanCovariance(μ, Σ)
3131
end
3232

3333
# prod (PreserveType) ==============
3434
let dims_dense = (10, 50, 100)
3535
# Dense × Dense (BLAS path for Float32/64)
36+
rng = StableRNG(42)
3637
for d in dims_dense
3738
for (TL, TR) in ((Float64, Float64), (Float32, Float32), (Float32, Float64))
38-
left = dense_dist(TL, d)
39-
right = dense_dist(TR, d)
39+
left = dense_dist(rng, TL, d)
40+
right = dense_dist(rng, TR, d)
4041
SUITE["mvnormal_weighted_mean_precision"]["prod"]["PreserveType"]["Dense×Dense"]["d=$d"]["$(TL)×$(TR)"] =
4142
@benchmarkable prod(PreserveTypeProd(Distribution), $left, $right)
4243
end
4344
# Mixed lower precision (generic path)
4445
let TL = Float16, TR = Float64
45-
left = dense_dist(TL, d)
46-
right = dense_dist(TR, d)
46+
left = dense_dist(rng, TL, d)
47+
right = dense_dist(rng, TR, d)
4748
SUITE["mvnormal_weighted_mean_precision"]["prod"]["PreserveType"]["Dense×Dense"]["d=$d"]["Float16×Float64"] =
4849
@benchmarkable prod(PreserveTypeProd(Distribution), $left, $right)
4950
end
5051
end
5152

5253
# Dense × Diagonal
5354
for d in dims_dense
54-
left = dense_dist(Float64, d)
55-
right = diag_dist(Float64, d)
55+
left = dense_dist(rng, Float64, d)
56+
right = diag_dist(rng, Float64, d)
5657
SUITE["mvnormal_weighted_mean_precision"]["prod"]["PreserveType"]["Dense×Diag"]["d=$d"]["Float64×Float64"] =
5758
@benchmarkable prod(PreserveTypeProd(Distribution), $left, $right)
5859

59-
left = dense_dist(Float32, d)
60-
right = diag_dist(Float64, d)
60+
left = dense_dist(rng, Float32, d)
61+
right = diag_dist(rng, Float64, d)
6162
SUITE["mvnormal_weighted_mean_precision"]["prod"]["PreserveType"]["Dense×Diag"]["d=$d"]["Float32×Float64"] =
6263
@benchmarkable prod(PreserveTypeProd(Distribution), $left, $right)
6364
end
6465

6566
# Diagonal × Diagonal
6667
for d in (10, 50, 100)
67-
left = diag_dist(Float64, d)
68-
right = diag_dist(Float64, d)
68+
left = diag_dist(rng, Float64, d)
69+
right = diag_dist(rng, Float64, d)
6970
SUITE["mvnormal_weighted_mean_precision"]["prod"]["PreserveType"]["Diag×Diag"]["d=$d"]["Float64×Float64"] =
7071
@benchmarkable prod(PreserveTypeProd(Distribution), $left, $right)
7172
end
@@ -75,20 +76,20 @@ let dims_dense = (10, 50, 100)
7576
for TL in (Float64, Float32, Float16)
7677
for TR in (Float64, Float32, Float16)
7778
# Static × Static
78-
left = static_dist(TL, Val(D))
79-
right = static_dist(TR, Val(D))
79+
left = static_dist(rng, TL, Val(D))
80+
right = static_dist(rng, TR, Val(D))
8081
SUITE["mvnormal_weighted_mean_precision"]["prod"]["PreserveType"]["SArray×SArray"]["d=$D"]["$(TL)×$(TR)"] =
8182
@benchmarkable prod(PreserveTypeProd(Distribution), $left, $right)
8283

8384
# Static × Dense
84-
left = static_dist(TL, Val(D))
85-
right = dense_dist(TR, D)
85+
left = static_dist(rng, TL, Val(D))
86+
right = dense_dist(rng, TR, D)
8687
SUITE["mvnormal_weighted_mean_precision"]["prod"]["PreserveType"]["SArray×Dense"]["d=$D"]["$(TL)×$(TR)"] =
8788
@benchmarkable prod(PreserveTypeProd(Distribution), $left, $right)
8889

8990
# Static × Diagonal
90-
left = static_dist(TL, Val(D))
91-
right = diag_dist(TR, D)
91+
left = static_dist(rng, TL, Val(D))
92+
right = diag_dist(rng, TR, D)
9293
SUITE["mvnormal_weighted_mean_precision"]["prod"]["PreserveType"]["SArray×Diag"]["d=$D"]["$(TL)×$(TR)"] =
9394
@benchmarkable prod(PreserveTypeProd(Distribution), $left, $right)
9495
end
@@ -100,7 +101,8 @@ end
100101

101102
for dims in (10, 50, 100)
102103
for T in (Float64, Float32)
103-
dist = dense_dist(T, dims)
104+
rng = StableRNG(42)
105+
dist = dense_dist(rng, T, dims)
104106
SUITE["mvnormal_weighted_mean_precision"]["compute_logscale"]["d=$dims"]["$(T)"] =
105107
@benchmarkable compute_logscale($dist, $dist, $dist)
106108
end

0 commit comments

Comments
 (0)