Skip to content

Commit d71efa6

Browse files
committed
big update to make AdvancedVI compatible with more recent versions of
packages
1 parent bb7e85c commit d71efa6

File tree

12 files changed

+956
-123
lines changed

12 files changed

+956
-123
lines changed

Project.toml

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@ uuid = "b5ca4192-6429-45e5-a2d9-87aec30a685c"
33
version = "0.1.6"
44

55
[deps]
6+
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
67
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
8+
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
79
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
810
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
911
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
@@ -16,22 +18,30 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1618
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
1719
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
1820

21+
[weakdeps]
22+
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
23+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
24+
25+
[extensions]
26+
AdvancedVIReverseDiffExt = ["ReverseDiff"]
27+
AdvancedVIZygoteExt = ["Zygote"]
28+
1929
[compat]
20-
Bijectors = "0.4.0, 0.5, 0.6, 0.7, 0.8, 0.9, 0.10"
30+
ADTypes = "0.2, 1"
31+
Bijectors = "0.4.0, 0.5, 0.6, 0.7, 0.8, 0.9, 0.10, 0.13"
2132
Distributions = "0.21, 0.22, 0.23, 0.24, 0.25"
2233
DistributionsAD = "0.2, 0.3, 0.4, 0.5, 0.6"
2334
DocStringExtensions = "0.8, 0.9"
2435
ForwardDiff = "0.10.3"
2536
ProgressMeter = "1.0.0"
2637
Requires = "0.5, 1.0"
38+
ReverseDiff = "1"
2739
StatsBase = "0.32, 0.33"
2840
StatsFuns = "0.8, 0.9, 1"
2941
Tracker = "0.2.3"
42+
Zygote = "0.6"
3043
julia = "1"
3144

3245
[extras]
33-
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
34-
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
35-
36-
[targets]
37-
test = ["Pkg", "Test"]
46+
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
47+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

ext/AdvancedVIFluxExt.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
module AdvancedVIFluxExt
2+
3+
if isdefined(Base, :get_extension)
4+
using AdvancedVI: AdvancedVI
5+
using Flux: Flux
6+
else
7+
using ..AdvancedVI: AdvancedVI
8+
using ..Flux: Flux
9+
end
10+
11+
AdvancedVI.apply!(o::Flux.Optimise.AbstractOptimizer, x, Δ) = Flux.Optimise.apply!(o, x, Δ)
12+
13+
end

ext/AdvancedVIReverseDiffExt.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
module AdvancedVIReverseDiffExt
2+
3+
if isdefined(Base, :get_extension)
4+
using AdvancedVI: AdvancedVI, ADTypes, DiffResults, Distributions
5+
using ReverseDiff: ReverseDiff
6+
else
7+
using ..AdvancedVI: ADTypes, AdvancedVI
8+
using ..ReverseDiff: ReverseDiff
9+
end
10+
11+
AdvancedVI.ADBackend(::Val{:reversediff}) = ADTypes.AutoReverseDiff()
12+
13+
function AdvancedVI.setadbackend(::Val{:reversediff})
14+
Base.depwarn("`setadbackend` is deprecated. Please pass a `ADTypes.AbstractADType` as a keyword argument to the VI algorithm.", :setadbackend)
15+
AdvancedVI.ADBACKEND[] = :reversediff
16+
end
17+
18+
tape(f, x) = ReverseDiff.GradientTape(f, x)
19+
20+
function AdvancedVI.grad!(
21+
vo,
22+
alg::AdvancedVI.VariationalInference{<:ADTypes.AutoReverseDiff},
23+
q,
24+
model,
25+
θ::AbstractVector{<:Real},
26+
out::DiffResults.MutableDiffResult,
27+
args...
28+
)
29+
f(θ) =
30+
if (q isa Distributions.Distribution)
31+
-vo(alg, AdvancedVI.update(q, θ), model, args...)
32+
else
33+
-vo(alg, q(θ), model, args...)
34+
end
35+
tp = tape(f, θ)
36+
ReverseDiff.gradient!(out, tp, θ)
37+
return out
38+
end
39+
40+
end

ext/AdvancedVIZygoteExt.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
module AdvancedVIZygoteExt
2+
3+
if isdefined(Base, :get_extension)
4+
using AdvancedVI: AdvancedVI, ADTypes, DiffResults, Distributions
5+
using Zygote: Zygote
6+
else
7+
using ..AdvancedVI: ADTypes, AdvancedVI
8+
using ..Zygote: Zygote
9+
end
10+
11+
AdvancedVI.ADBackend(::Val{:zygote}) = ADTypes.AutoZygote()
12+
function AdvancedVI.setadbackend(::Val{:zygote})
13+
Base.depwarn("`setadbackend` is deprecated. Please pass a `ADTypes.AbstractADType` as a keyword argument to the VI algorithm.", :setadbackend)
14+
AdvancedVI.ADBACKEND[] = :zygote
15+
end
16+
17+
function AdvancedVI.grad!(
18+
vo,
19+
alg::AdvancedVI.VariationalInference{<:ADTypes.AutoZygote},
20+
q,
21+
model,
22+
θ::AbstractVector{<:Real},
23+
out::DiffResults.MutableDiffResult,
24+
args...
25+
)
26+
f(θ) =
27+
if (q isa Distributions.Distribution)
28+
-vo(alg, AdvancedVI.update(q, θ), model, args...)
29+
else
30+
-vo(alg, q(θ), model, args...)
31+
end
32+
y, back = Zygote.pullback(f, θ)
33+
dy = first(back(1.0))
34+
DiffResults.value!(out, y)
35+
DiffResults.gradient!(out, dy)
36+
return out
37+
end
38+
39+
end

src/AdvancedVI.jl

Lines changed: 25 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
module AdvancedVI
22

3-
using Random: AbstractRNG
3+
using Random: Random, AbstractRNG
44

55
using Distributions, DistributionsAD, Bijectors
66
using DocStringExtensions
77

88
using ProgressMeter, LinearAlgebra
99

10-
using ForwardDiff
11-
using Tracker
10+
using ADTypes: ADTypes
11+
using DiffResults: DiffResults
12+
13+
using ForwardDiff: ForwardDiff
14+
using Tracker: Tracker
1215

1316
const PROGRESS = Ref(true)
1417
function turnprogress(switch::Bool)
@@ -18,65 +21,6 @@ end
1821

1922
const DEBUG = Bool(parse(Int, get(ENV, "DEBUG_ADVANCEDVI", "0")))
2023

21-
include("ad.jl")
22-
23-
using Requires
24-
function __init__()
25-
@require Flux="587475ba-b771-5e3f-ad9e-33799f191a9c" begin
26-
apply!(o, x, Δ) = Flux.Optimise.apply!(o, x, Δ)
27-
Flux.Optimise.apply!(o::TruncatedADAGrad, x, Δ) = apply!(o, x, Δ)
28-
Flux.Optimise.apply!(o::DecayedADAGrad, x, Δ) = apply!(o, x, Δ)
29-
end
30-
@require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin
31-
include("compat/zygote.jl")
32-
export ZygoteAD
33-
34-
function AdvancedVI.grad!(
35-
vo,
36-
alg::VariationalInference{<:AdvancedVI.ZygoteAD},
37-
q,
38-
model,
39-
θ::AbstractVector{<:Real},
40-
out::DiffResults.MutableDiffResult,
41-
args...
42-
)
43-
f(θ) = if (q isa Distribution)
44-
- vo(alg, update(q, θ), model, args...)
45-
else
46-
- vo(alg, q(θ), model, args...)
47-
end
48-
y, back = Zygote.pullback(f, θ)
49-
dy = first(back(1.0))
50-
DiffResults.value!(out, y)
51-
DiffResults.gradient!(out, dy)
52-
return out
53-
end
54-
end
55-
@require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin
56-
include("compat/reversediff.jl")
57-
export ReverseDiffAD
58-
59-
function AdvancedVI.grad!(
60-
vo,
61-
alg::VariationalInference{<:AdvancedVI.ReverseDiffAD{false}},
62-
q,
63-
model,
64-
θ::AbstractVector{<:Real},
65-
out::DiffResults.MutableDiffResult,
66-
args...
67-
)
68-
f(θ) = if (q isa Distribution)
69-
- vo(alg, update(q, θ), model, args...)
70-
else
71-
- vo(alg, q(θ), model, args...)
72-
end
73-
tp = AdvancedVI.tape(f, θ)
74-
ReverseDiff.gradient!(out, tp, θ)
75-
return out
76-
end
77-
end
78-
end
79-
8024
export
8125
vi,
8226
ADVI,
@@ -86,10 +30,12 @@ export
8630
DecayedADAGrad,
8731
VariationalInference
8832

33+
include("compat.jl")
34+
include("ad.jl")
35+
8936
abstract type VariationalInference{AD} end
9037

91-
getchunksize(::Type{<:VariationalInference{AD}}) where AD = getchunksize(AD)
92-
getADtype(::VariationalInference{AD}) where AD = AD
38+
getchunksize(::ADTypes.AutoForwardDiff{chunk}) where chunk = chunk === nothing ? 0 : chunk
9339

9440
abstract type VariationalObjective end
9541

@@ -100,7 +46,7 @@ const VariationalPosterior = Distribution{Multivariate, Continuous}
10046
grad!(vo, alg::VariationalInference, q, model::Model, θ, out, args...)
10147
10248
Computes the gradients used in `optimize!`. Default implementation is provided for
103-
`VariationalInference{AD}` where `AD` is either `ForwardDiffAD` or `TrackerAD`.
49+
`VariationalInference{AD}` where `AD` is either `ADTypes.AutoForwardDiff` or `ADTypes.AutoTracker`.
10450
This implicitly also gives a default implementation of `optimize!`.
10551
10652
Variance reduction techniques, e.g. control variates, should be implemented in this function.
@@ -129,7 +75,7 @@ function update end
12975
# default implementations
13076
function grad!(
13177
vo,
132-
alg::VariationalInference{<:ForwardDiffAD},
78+
alg::VariationalInference{<:ADTypes.AutoForwardDiff},
13379
q,
13480
model,
13581
θ::AbstractVector{<:Real},
@@ -143,7 +89,7 @@ function grad!(
14389
end
14490

14591
# Set chunk size and do ForwardMode.
146-
chunk_size = getchunksize(typeof(alg))
92+
chunk_size = getchunksize(alg.adtype)
14793
config = if chunk_size == 0
14894
ForwardDiff.GradientConfig(f, θ)
14995
else
@@ -154,7 +100,7 @@ end
154100

155101
function grad!(
156102
vo,
157-
alg::VariationalInference{<:TrackerAD},
103+
alg::VariationalInference{<:ADTypes.AutoTracker},
158104
q,
159105
model,
160106
θ::AbstractVector{<:Real},
@@ -238,4 +184,15 @@ include("optimisers.jl")
238184
# VI algorithms
239185
include("advi.jl")
240186

187+
@static if !isdefined(Base, :get_extension)
188+
function __init__()
189+
@require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" include(
190+
"../ext/AdvancedVIReverseDiffExt.jl"
191+
)
192+
@require ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" include(
193+
"../ext/AdvancedVIZygoteExt.jl"
194+
)
195+
end
196+
end
197+
241198
end # module

src/ad.jl

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
1-
##############################
2-
# Global variables/constants #
3-
##############################
1+
# FIXME: All this should go away.
42
const ADBACKEND = Ref(:forwarddiff)
53
setadbackend(backend_sym::Symbol) = setadbackend(Val(backend_sym))
64
function setadbackend(::Val{:forward_diff})
75
Base.depwarn("`AdvancedVI.setadbackend(:forward_diff)` is deprecated. Please use `AdvancedVI.setadbackend(:forwarddiff)` to use `ForwardDiff`.", :setadbackend)
86
setadbackend(Val(:forwarddiff))
97
end
108
function setadbackend(::Val{:forwarddiff})
9+
Base.depwarn("`setadbackend` is deprecated. Please pass a `ADTypes.AbstractADType` as a keyword argument to the VI algorithm.", :setadbackend)
1110
ADBACKEND[] = :forwarddiff
1211
end
1312

@@ -16,6 +15,7 @@ function setadbackend(::Val{:reverse_diff})
1615
setadbackend(Val(:tracker))
1716
end
1817
function setadbackend(::Val{:tracker})
18+
Base.depwarn("`setadbackend` is deprecated. Please pass a `ADTypes.AbstractADType` as a keyword argument to the VI algorithm.", :setadbackend)
1919
ADBACKEND[] = :tracker
2020
end
2121

@@ -32,15 +32,11 @@ function setchunksize(chunk_size::Int)
3232
CHUNKSIZE[] = chunk_size
3333
end
3434

35-
abstract type ADBackend end
36-
struct ForwardDiffAD{chunk} <: ADBackend end
37-
getchunksize(::Type{<:ForwardDiffAD{chunk}}) where chunk = chunk
38-
39-
struct TrackerAD <: ADBackend end
35+
getchunksize(::Type{<:ADTypes.AutoForwardDiff{chunk}}) where chunk = chunk
4036

4137
ADBackend() = ADBackend(ADBACKEND[])
4238
ADBackend(T::Symbol) = ADBackend(Val(T))
4339

44-
ADBackend(::Val{:forwarddiff}) = ForwardDiffAD{CHUNKSIZE[]}
45-
ADBackend(::Val{:tracker}) = TrackerAD
40+
ADBackend(::Val{:forwarddiff}) = ADTypes.AutoForwardDiff(chunksize=CHUNKSIZE[])
41+
ADBackend(::Val{:tracker}) = ADTypes.AutoTracker()
4642
ADBackend(::Val) = error("The requested AD backend is not available. Make sure to load all required packages.")

src/advi.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,13 @@ struct ADVI{AD} <: VariationalInference{AD}
2020
samples_per_step::Int
2121
"Maximum number of gradient steps."
2222
max_iters::Int
23+
"AD backend used for automatic differentiation."
24+
adtype::AD
25+
2326
end
2427

25-
function ADVI(samples_per_step::Int=1, max_iters::Int=1000)
26-
return ADVI{ADBackend()}(samples_per_step, max_iters)
28+
function ADVI(samples_per_step::Int=1, max_iters::Int=1000; adtype::ADTypes.AbstractADType=ADTypes.AutoForwardDiff())
29+
return ADVI(samples_per_step, max_iters, adtype)
2730
end
2831

2932
alg_str(::ADVI) = "ADVI"

0 commit comments

Comments
 (0)