@@ -18,17 +18,26 @@ is_supported(::ADTypes.AutoReverseDiff) = true
1818 LogDensityFunction(
1919 model::Model,
2020 varinfo::AbstractVarInfo=VarInfo(model),
21- context::AbstractContext=DefaultContext()
21+ context::AbstractContext=DefaultContext();
22+ adtype::Union{Nothing,ADTypes.AbstractADType}=model.adtype,
2223 )
2324
2425A struct which contains a model, along with all the information necessary to
2526calculate its log density at a given point.
2627
28+ If the `adtype` keyword argument is specified, it is used to overwrite the
29+ existing `adtype` in the model supplied.
30+
2731At its most basic level, a LogDensityFunction wraps the model together with its
2832the type of varinfo to be used, as well as the evaluation context. These must
2933be known in order to calculate the log density (using
3034[`DynamicPPL.evaluate!!`](@ref)).
3135
36+ Using this information, `DynamicPPL.LogDensityFunction` implements the
37+ LogDensityProblems.jl interface. If the underlying model's `adtype` is nothing,
38+ then only `logdensity` is implemented. If the model's `adtype` is a concrete AD
39+ backend type, then `logdensity_and_gradient` is also implemented.
40+
3241# Fields
3342$(FIELDS)
3443
@@ -77,6 +86,12 @@ julia> model_with_ad = Model(model, ADTypes.AutoForwardDiff());
7786
7887julia> f = LogDensityFunction(model_with_ad);
7988
89+ julia> LogDensityProblems.logdensity_and_gradient(f, [0.0])
90+ (-2.3378770664093453, [1.0])
91+
92+ julia> # Alternatively, we can set the AD backend when creating the LogDensityFunction.
93+ f = LogDensityFunction(model, adtype=ADTypes.AutoForwardDiff());
94+
8095julia> LogDensityProblems.logdensity_and_gradient(f, [0.0])
8196(-2.3378770664093453, [1.0])
8297```
@@ -94,18 +109,16 @@ struct LogDensityFunction{M<:Model,V<:AbstractVarInfo,C<:AbstractContext}
94109 function LogDensityFunction (
95110 model:: Model ,
96111 varinfo:: AbstractVarInfo = VarInfo (model),
97- context:: AbstractContext = leafcontext (model. context),
112+ context:: AbstractContext = leafcontext (model. context);
113+ adtype:: Union{Nothing,ADTypes.AbstractADType} = model. adtype,
98114 )
99- adtype = model. adtype
100115 if adtype === nothing
101116 prep = nothing
102117 else
103- # Make backend-specific tweaks to the adtype
104- # This should arguably be done in the model constructor, but it needs the
105- # varinfo and context to do so, and it seems excessive to construct a
106- # varinfo at the point of calling Model().
107118 adtype = tweak_adtype (adtype, model, varinfo, context)
108- model = Model (model, adtype)
119+ if adtype != model. adtype
120+ model = Model (model, adtype)
121+ end
109122 # Check whether it is supported
110123 is_supported (adtype) ||
111124 @warn " The AD backend $adtype is not officially supported by DynamicPPL. Gradient calculations may still work, but compatibility is not guaranteed."
@@ -148,7 +161,7 @@ function LogDensityFunction(
148161 return if adtype === f. model. adtype
149162 f # Avoid recomputing prep if not needed
150163 else
151- LogDensityFunction (Model ( f. model, adtype), f. varinfo, f. context)
164+ LogDensityFunction (f. model, f. varinfo, f. context; adtype = adtype )
152165 end
153166end
154167
0 commit comments