
Commit 8fdff72

Red-Portal, github-actions[bot], and sunxd3 authored
Proximal operator for the entropy of location-scale families (#168)
* add proximal operator for the entropy of location-scale families

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* improve docstring for zero gradient entropy estimators
* add missing file
* add documentation for proximal operator
* run formatter

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* fix improve type stability
* apply formatter

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* fix typo in doctring

Co-authored-by: Xianda Sun <[email protected]>

* fix typo in comment

Co-authored-by: Xianda Sun <[email protected]>

* apply code review comments
* bump compat bound for subprojects

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Xianda Sun <[email protected]>
1 parent 1aaf0ac commit 8fdff72

17 files changed (+479, -32 lines)

Project.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 name = "AdvancedVI"
 uuid = "b5ca4192-6429-45e5-a2d9-87aec30a685c"
-version = "0.3.2"
+version = "0.4.0"

 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

bench/Project.toml

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

 [compat]
 ADTypes = "1"
-AdvancedVI = "0.3"
+AdvancedVI = "0.3, 0.4"
 BenchmarkTools = "1"
 Bijectors = "0.13, 0.14, 0.15"
 Distributions = "0.25.111"

docs/Project.toml

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"

 [compat]
 ADTypes = "1"
-AdvancedVI = "0.3, 0.2"
+AdvancedVI = "0.4"
 Bijectors = "0.13.6, 0.14, 0.15"
 Distributions = "0.25"
 Documenter = "1"

docs/src/optimization.md

Lines changed: 10 additions & 1 deletion
@@ -33,11 +33,20 @@ For this, an operator acting on the parameters can be supplied via the `operato

 ### [`ClipScale`](@id clipscale)

-For the location scale, it is often the case that optimization is stable only when the smallest eigenvalue of the scale matrix is strictly positive[^D2020].
+For the location-scale family, it is often the case that optimization is stable only when the smallest eigenvalue of the scale matrix is strictly positive[^D2020].
 To ensure this, we provide the following projection operator:

 ```@docs
 ClipScale
 ```

+### [`ProximalLocationScaleEntropy`](@id proximalocationscaleentropy)
+
+ELBO maximization with the location-scale family tends to be unstable when the scale has small eigenvalues or the stepsize is large.
+To remedy this, a proximal operator of the entropy[^D2020] can be used.
+
+```@docs
+ProximalLocationScaleEntropy
+```
+
 [^D2020]: Domke, J. (2020). Provable smoothness guarantees for black-box variational inference. In *International Conference on Machine Learning*.
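
As a reading aid for the documentation change above, the closed-form update behind the proximal operator can be derived per diagonal entry. This is a sketch under an assumption not stated in the commit: the scale matrix is triangular with positive diagonal entries, so the entropy depends on the parameters only through the sum of the log-diagonal. The symbols $s_i$, $s_i^{\star}$, and $\gamma$ (the stepsize) are notation introduced here for illustration.

```math
\begin{aligned}
\mathbb{H}(q_{\lambda}) &= \sum_{i} \log s_i + \text{const}, \\
s_i^{\star} &= \operatorname*{argmin}_{s_i^{\prime} > 0} \; -\log s_i^{\prime} + \frac{1}{2\gamma} \left(s_i - s_i^{\prime}\right)^2, \\
0 &= -\frac{1}{s_i^{\star}} + \frac{s_i^{\star} - s_i}{\gamma}
\;\Longrightarrow\; \left(s_i^{\star}\right)^2 - s_i s_i^{\star} - \gamma = 0, \\
s_i^{\star} &= \frac{s_i + \sqrt{s_i^2 + 4\gamma}}{2}
= s_i + \frac{1}{2}\left(\sqrt{s_i^2 + 4\gamma} - s_i\right).
\end{aligned}
```

Only the positive root of the quadratic is admissible. Under the same assumption, the location and the off-diagonal entries of the scale do not enter the entropy, so the proximal step leaves them unchanged.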

ext/AdvancedVIBijectorsExt.jl

Lines changed: 22 additions & 0 deletions
@@ -9,6 +9,7 @@ using Random
 function AdvancedVI.apply(
     op::ClipScale,
     ::Type{<:Bijectors.TransformedDistribution{<:AdvancedVI.MvLocationScale}},
+    state,
     params,
     restructure,
 )
@@ -27,6 +28,7 @@ end
 function AdvancedVI.apply(
     op::ClipScale,
     ::Type{<:Bijectors.TransformedDistribution{<:AdvancedVI.MvLocationScaleLowRank}},
+    state,
     params,
     restructure,
 )
@@ -40,6 +42,26 @@ function AdvancedVI.apply(
     return params
 end

+function AdvancedVI.apply(
+    ::AdvancedVI.ProximalLocationScaleEntropy,
+    ::Type{<:Bijectors.TransformedDistribution{<:AdvancedVI.MvLocationScale}},
+    leaf::Optimisers.Leaf{<:Union{<:DoG,<:DoWG,<:Descent},S},
+    params,
+    restructure,
+) where {S}
+    q = restructure(params)
+
+    stepsize = AdvancedVI.stepsize_from_optimizer_state(leaf.rule, leaf.state)
+    diag_idx = diagind(q.dist.scale)
+    scale_diag = q.dist.scale[diag_idx]
+    @. q.dist.scale[diag_idx] =
+        scale_diag + 1 / 2 * (sqrt(scale_diag^2 + 4 * stepsize) - scale_diag)
+
+    params, _ = Optimisers.destructure(q)
+
+    return params
+end
+
 function AdvancedVI.reparam_with_entropy(
     rng::Random.AbstractRNG,
     q::Bijectors.TransformedDistribution,

src/AdvancedVI.jl

Lines changed: 14 additions & 6 deletions
@@ -177,13 +177,14 @@ function estimate_gradient! end
 abstract type AbstractEntropyEstimator end

 """
-    estimate_entropy(entropy_estimator, mc_samples, q)
+    estimate_entropy(entropy_estimator, mc_samples, q, q_stop)

 Estimate the entropy of `q`.

 # Arguments
 - `entropy_estimator`: Entropy estimation strategy.
 - `q`: Variational approximation.
+- `q_stop`: Variational approximation detached from the automatic differentiation graph.
 - `mc_samples`: Monte Carlo samples used to estimate the entropy. (Only used for Monte Carlo strategies.)

 # Returns
@@ -192,7 +193,12 @@ Estimate the entropy of `q`.
 function estimate_entropy end

 export RepGradELBO,
-    ScoreGradELBO, ClosedFormEntropy, StickingTheLandingEntropy, MonteCarloEntropy
+    ScoreGradELBO,
+    ClosedFormEntropy,
+    StickingTheLandingEntropy,
+    MonteCarloEntropy,
+    ClosedFormEntropyZeroGradient,
+    StickingTheLandingEntropyZeroGradient

 include("objectives/elbo/entropy.jl")
 include("objectives/elbo/repgradelbo.jl")
@@ -259,20 +265,21 @@ export NoAveraging, PolynomialAveraging
 abstract type AbstractOperator end

 """
-    apply(op::AbstractOperator, family, params, restructure)
+    apply(op::AbstractOperator, family, rule, opt_state, params, restructure)

 Apply operator `op` on the variational parameters `params`. For instance, `op` could be a projection or proximal operator.

 # Arguments
 - `op::AbstractOperator`: Operator operating on the parameters `params`.
 - `family::Type`: Type of the variational approximation `restructure(params)`.
+- `opt_state`: State of the optimizer.
 - `params`: Variational parameters.
 - `restructure`: Function that reconstructs the variational approximation from `params`.

 # Returns
 - `oped_params`: Parameters resulting from applying the operator.
 """
-function apply(::AbstractOperator, ::Type, ::Any, ::Any) end
+function apply(::AbstractOperator, ::Type, ::Optimisers.AbstractRule, ::Any, ::Any, ::Any) end

 """
     IdentityOperator()
@@ -281,11 +288,12 @@ Identity operator.
 """
 struct IdentityOperator <: AbstractOperator end

-apply(::IdentityOperator, ::Type, params, restructure) = params
+apply(::IdentityOperator, ::Type, opt_st, params, restructure) = params

 include("optimization/clip_scale.jl")
+include("optimization/proximal_location_scale_entropy.jl")

-export IdentityOperator, ClipScale
+export IdentityOperator, ClipScale, ProximalLocationScaleEntropy

 # Main optimization routine
 function optimize end

src/objectives/elbo/entropy.jl

Lines changed: 59 additions & 8 deletions
@@ -1,4 +1,19 @@

+"""
+    ClosedFormEntropyZeroGradient()
+
+Use closed-form expression of entropy but detach it from the AD graph.
+This is expected to be used only with `ProximalLocationScaleEntropy`.
+
+# Requirements
+- The variational approximation implements `entropy`.
+"""
+struct ClosedFormEntropyZeroGradient <: AbstractEntropyEstimator end
+
+function estimate_entropy(::ClosedFormEntropyZeroGradient, ::Any, ::Any, q_stop)
+    return entropy(q_stop)
+end
+
 """
     ClosedFormEntropy()

@@ -9,12 +24,27 @@ Use closed-form expression of entropy[^TL2014][^KTRGB2017].
 """
 struct ClosedFormEntropy <: AbstractEntropyEstimator end

-maybe_stop_entropy_score(::AbstractEntropyEstimator, q, q_stop) = q
-
-function estimate_entropy(::ClosedFormEntropy, ::Any, q)
+function estimate_entropy(::ClosedFormEntropy, ::Any, q, q_stop)
     return entropy(q)
 end

+"""
+    MonteCarloEntropy()
+
+Monte Carlo estimation of the entropy.
+
+# Requirements
+- The variational approximation `q` implements `logpdf`.
+- `logpdf(q, η)` must be differentiable by the selected AD framework.
+"""
+struct MonteCarloEntropy <: AbstractEntropyEstimator end
+
+function estimate_entropy(::MonteCarloEntropy, mc_samples::AbstractMatrix, q, q_stop)
+    return mean(eachcol(mc_samples)) do mc_sample
+        -logpdf(q, mc_sample)
+    end
+end
+
 """
     StickingTheLandingEntropy()

@@ -26,14 +56,35 @@ The "sticking the landing" entropy estimator[^RWD2017].
 """
 struct StickingTheLandingEntropy <: AbstractEntropyEstimator end

-struct MonteCarloEntropy <: AbstractEntropyEstimator end
+function estimate_entropy(
+    ::StickingTheLandingEntropy, mc_samples::AbstractMatrix, q, q_stop
+)
+    return mean(eachcol(mc_samples)) do mc_sample
+        -logpdf(q_stop, mc_sample)
+    end
+end

-maybe_stop_entropy_score(::StickingTheLandingEntropy, q, q_stop) = q_stop
+"""
+    StickingTheLandingEntropyZeroGradient()
+
+The "sticking the landing" entropy estimator[^RWD2017] but modified to have a gradient of mean zero.
+This is expected to be used only with `ProximalLocationScaleEntropy`.
+
+# Requirements
+- The variational approximation `q` implements `logpdf`.
+- `logpdf(q, η)` must be differentiable by the selected AD framework.
+- The variational approximation implements `entropy`.
+"""
+struct StickingTheLandingEntropyZeroGradient <: AbstractEntropyEstimator end

 function estimate_entropy(
-    ::Union{MonteCarloEntropy,StickingTheLandingEntropy}, mc_samples::AbstractMatrix, q
+    ::Union{MonteCarloEntropy,StickingTheLandingEntropyZeroGradient},
+    mc_samples::AbstractMatrix,
+    q,
+    q_stop,
 )
-    mean(eachcol(mc_samples)) do mc_sample
-        -logpdf(q, mc_sample)
+    entropy_stl = mean(eachcol(mc_samples)) do mc_sample
+        -logpdf(q_stop, mc_sample)
     end
+    return entropy_stl - entropy(q) + entropy(q_stop)
 end

src/objectives/elbo/repgradelbo.jl

Lines changed: 1 addition & 8 deletions
@@ -67,13 +67,6 @@ function Base.show(io::IO, obj::RepGradELBO)
     return print(io, ")")
 end

-function estimate_entropy_maybe_stl(
-    entropy_estimator::AbstractEntropyEstimator, samples, q, q_stop
-)
-    q_maybe_stop = maybe_stop_entropy_score(entropy_estimator, q, q_stop)
-    return estimate_entropy(entropy_estimator, samples, q_maybe_stop)
-end
-
 function estimate_energy_with_samples(prob, samples)
     return mean(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples))
 end
@@ -98,7 +91,7 @@ function reparam_with_entropy(
     rng::Random.AbstractRNG, q, q_stop, n_samples::Int, ent_est::AbstractEntropyEstimator
 )
     samples = rand(rng, q, n_samples)
-    entropy = estimate_entropy_maybe_stl(ent_est, samples, q, q_stop)
+    entropy = estimate_entropy(ent_est, samples, q, q_stop)
     return samples, entropy
 end

src/optimization/clip_scale.jl

Lines changed: 3 additions & 3 deletions
@@ -9,11 +9,11 @@ Optimisers.@def struct ClipScale <: AbstractOperator
     epsilon = 1e-5
 end

-function apply(::ClipScale, family::Type, params, restructure)
+function apply(::ClipScale, family::Type, state, params, restructure)
     return error("`ClipScale` is not defined for the variational family of type $(family).")
 end

-function apply(op::ClipScale, ::Type{<:MvLocationScale}, params, restructure)
+function apply(op::ClipScale, ::Type{<:MvLocationScale}, state, params, restructure)
     q = restructure(params)
     ϵ = convert(eltype(params), op.epsilon)

@@ -26,7 +26,7 @@ function apply(op::ClipScale, ::Type{<:MvLocationScale}, params, restructure)
     return params
 end

-function apply(op::ClipScale, ::Type{<:MvLocationScaleLowRank}, params, restructure)
+function apply(op::ClipScale, ::Type{<:MvLocationScaleLowRank}, state, params, restructure)
     q = restructure(params)
     ϵ = convert(eltype(params), op.epsilon)

src/optimization/proximal_location_scale_entropy.jl

Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
+
+"""
+    ProximalLocationScaleEntropy()
+
+Proximal operator for the entropy of a location-scale distribution, which is defined as
+```math
+\\mathrm{prox}(\\lambda) = \\argmin_{\\lambda^{\\prime}} - \\mathbb{H}(q_{\\lambda^{\\prime}}) + \\frac{1}{2 \\gamma_t} \\left\\lVert \\lambda - \\lambda^{\\prime} \\right\\rVert_2^2 ,
+```
+where \$\\gamma_t\$ is the stepsize of the optimizer used with the proximal operator.
+This assumes the variational family is `<:MvLocationScale` and the optimizer is one of the following:
+- `DoG`
+- `DoWG`
+- `Descent`
+
+For ELBO maximization, since this proximal operator handles the entropy, the gradient estimator for the ELBO must ignore the entropy term.
+That is, the `entropy` keyword argument of `RepGradELBO` must be one of the following:
+- `ClosedFormEntropyZeroGradient`
+- `StickingTheLandingEntropyZeroGradient`
+"""
+struct ProximalLocationScaleEntropy <: AbstractOperator end
+
+function apply(::ProximalLocationScaleEntropy, family, state, params, restructure)
+    return error("`ProximalLocationScaleEntropy` only supports `<:MvLocationScale`.")
+end
+
+function stepsize_from_optimizer_state(rule::Optimisers.AbstractRule, state)
+    return error(
+        "`ProximalLocationScaleEntropy` does not support optimization rule $(typeof(rule))."
+    )
+end
+
+stepsize_from_optimizer_state(rule::Descent, ::Any) = rule.eta
+
+function stepsize_from_optimizer_state(::DoG, state)
+    _, v, r = state
+    return r / sqrt(v)
+end
+
+function stepsize_from_optimizer_state(::DoWG, state)
+    _, v, r = state
+    return r * r / sqrt(v)
+end
+
+function apply(
+    ::ProximalLocationScaleEntropy,
+    ::Type{<:MvLocationScale},
+    leaf::Optimisers.Leaf{<:Union{<:DoG,<:DoWG,<:Descent},S},
+    params,
+    restructure,
+) where {S}
+    q = restructure(params)
+
+    stepsize = stepsize_from_optimizer_state(leaf.rule, leaf.state)
+    diag_idx = diagind(q.scale)
+    scale_diag = q.scale[diag_idx]
+    @. q.scale[diag_idx] = scale_diag + (sqrt(scale_diag^2 + 4 * stepsize) - scale_diag) / 2
+
+    params, _ = Optimisers.destructure(q)
+
+    return params
+end
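
The diagonal update in the `apply` method above can be sanity-checked in isolation. The following is a minimal sketch in plain Julia with no package dependencies. The helper names `prox_entropy_diag` and `prox_objective` and the values of `s` and `stepsize` are hypothetical and not part of AdvancedVI. It assumes the per-entry proximal objective is `-log(t) + (t - s)^2 / (2 * stepsize)` and checks that the update is a stationary point of it.

```julia
# Scalar form of the update applied to each diagonal entry of the scale.
# `prox_entropy_diag` and `prox_objective` are hypothetical helper names.
prox_entropy_diag(s, stepsize) = s + (sqrt(s^2 + 4 * stepsize) - s) / 2

# Assumed per-entry proximal objective: -log(t) + (t - s)^2 / (2 * stepsize)
prox_objective(t, s, stepsize) = -log(t) + (t - s)^2 / (2 * stepsize)

# Illustrative values only: a small diagonal entry and a fairly large stepsize.
s, stepsize = 0.1, 0.5
s_new = prox_entropy_diag(s, stepsize)

# First-order optimality: the derivative of the objective in t should vanish at s_new.
residual = -1 / s_new + (s_new - s) / stepsize

# The updated value should not be worse than nearby candidates on either side.
f(t) = prox_objective(t, s, stepsize)
is_local_min = f(s_new) <= f(s_new - 1e-3) && f(s_new) <= f(s_new + 1e-3)

println((s_new=s_new, residual=residual, is_local_min=is_local_min))
```

The update always increases a positive diagonal entry, which is consistent with the motivation in the documentation: small scale entries are pushed away from zero, and more strongly so when the stepsize is large.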
