Skip to content

Commit 8bc3f9c

Browse files
committed
Migrating to new ADTypes
1 parent e220325 commit 8bc3f9c

File tree

7 files changed

+39
-53
lines changed

7 files changed

+39
-53
lines changed

Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "BoundaryValueDiffEq"
22
uuid = "764a87c0-6b3e-53db-9096-fe964310641d"
3-
version = "5.7.1"
3+
version = "5.8.0"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -24,6 +24,7 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
2424
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
2525
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
2626
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
27+
SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
2728
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2829
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
2930

@@ -34,7 +35,7 @@ ODEInterface = "54ca160b-1b9f-5127-a996-1867f4bc2a2c"
3435
BoundaryValueDiffEqODEInterfaceExt = "ODEInterface"
3536

3637
[compat]
37-
ADTypes = "0.2.6"
38+
ADTypes = "1.2"
3839
Adapt = "4"
3940
Aqua = "0.8"
4041
ArrayInterface = "7.7"

src/BoundaryValueDiffEq.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ end
192192
nlsolve = NewtonRaphson(),
193193
jac_alg = BVPJacobianAlgorithm(;
194194
bc_diffmode = AutoForwardDiff(; chunksize = 2),
195-
nonbc_diffmode = AutoSparseForwardDiff(; chunksize = 2))))
195+
nonbc_diffmode = AutoSparse(AutoForwardDiff(; chunksize = 2)))))
196196
end
197197

198198
@compile_workload begin
@@ -257,13 +257,13 @@ end
257257
nlsolve = LevenbergMarquardt(; disable_geodesic = Val(true)),
258258
jac_alg = BVPJacobianAlgorithm(;
259259
bc_diffmode = AutoForwardDiff(; chunksize = 2),
260-
nonbc_diffmode = AutoSparseForwardDiff(; chunksize = 2))),
260+
nonbc_diffmode = AutoSparse(AutoForwardDiff(; chunksize = 2)))),
261261
MultipleShooting(10,
262262
Tsit5();
263263
nlsolve = GaussNewton(),
264264
jac_alg = BVPJacobianAlgorithm(;
265265
bc_diffmode = AutoForwardDiff(; chunksize = 2),
266-
nonbc_diffmode = AutoSparseForwardDiff(; chunksize = 2)))])
266+
nonbc_diffmode = AutoSparse(AutoForwardDiff(; chunksize = 2))))])
267267
end
268268

269269
@compile_workload begin

src/algorithms.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,9 @@ Significantly more stable than Single Shooting.
8484
on the input types and problem type.
8585
8686
+ For `TwoPointBVProblem`, only `diffmode` is used (defaults to
87-
`AutoSparseForwardDiff` if possible else `AutoSparseFiniteDiff`).
87+
`AutoSparse(AutoForwardDiff())` if possible else `AutoSparseFiniteDiff`).
8888
+ For `BVProblem`, `bc_diffmode` and `nonbc_diffmode` are used. For `nonbc_diffmode`
89-
we default to `AutoSparseForwardDiff` if possible else `AutoSparseFiniteDiff`. For
89+
we default to `AutoSparse(AutoForwardDiff())` if possible else `AutoSparseFiniteDiff`. For
9090
`bc_diffmode`, we default to `AutoForwardDiff` if possible else `AutoFiniteDiff`.
9191
- `grid_coarsening`: Coarsening the multiple-shooting grid to generate a stable IVP
9292
solution. Possible Choices:
@@ -160,9 +160,9 @@ for order in (2, 3, 4, 5, 6)
160160
`BVPJacobianAlgorithm()`, which automatically decides the best algorithm to
161161
use based on the input types and problem type.
162162
- For `TwoPointBVProblem`, only `diffmode` is used (defaults to
163-
`AutoSparseForwardDiff` if possible else `AutoSparseFiniteDiff`).
163+
`AutoSparse(AutoForwardDiff())` if possible else `AutoSparseFiniteDiff`).
164164
- For `BVProblem`, `bc_diffmode` and `nonbc_diffmode` are used. For
165-
`nonbc_diffmode` defaults to `AutoSparseForwardDiff` if possible else
165+
`nonbc_diffmode` defaults to `AutoSparse(AutoForwardDiff())` if possible else
166166
`AutoSparseFiniteDiff`. For `bc_diffmode`, defaults to `AutoForwardDiff` if
167167
possible else `AutoFiniteDiff`.
168168
- `defect_threshold`: Threshold for defect control.

src/solve/multiple_shooting.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ function __solve_nlproblem!(
113113
jac_prototype = init_jacobian(jac_cache)
114114

115115
ode_cache_jac_fn = __multiple_shooting_init_jacobian_odecache(
116-
ensemblealg, prob, jac_cache, alg.jac_alg.diffmode,
116+
ensemblealg, prob, jac_cache, __cache_trait(alg.jac_alg.diffmode),
117117
alg.ode_alg, cur_nshoot, u0; internal_ode_kwargs...)
118118

119119
loss_fnₚ = @closure (du, u) -> __multiple_shooting_2point_loss!(
@@ -158,7 +158,7 @@ function __solve_nlproblem!(::StandardBVProblem, alg::MultipleShooting, bcresid_
158158
ode_jac_cache = sparse_jacobian_cache(alg.jac_alg.nonbc_diffmode, sd_ode, nothing,
159159
similar(u_at_nodes, cur_nshoot * N), u_at_nodes)
160160
ode_cache_ode_jac_fn = __multiple_shooting_init_jacobian_odecache(
161-
ensemblealg, prob, ode_jac_cache, alg.jac_alg.nonbc_diffmode,
161+
ensemblealg, prob, ode_jac_cache, __cache_trait(alg.jac_alg.nonbc_diffmode),
162162
alg.ode_alg, cur_nshoot, u0; internal_ode_kwargs...)
163163

164164
# BC Part
@@ -167,7 +167,7 @@ function __solve_nlproblem!(::StandardBVProblem, alg::MultipleShooting, bcresid_
167167
bc_jac_cache = sparse_jacobian_cache(
168168
alg.jac_alg.bc_diffmode, sd_bc, nothing, similar(bcresid_prototype), u_at_nodes)
169169
ode_cache_bc_jac_fn = __multiple_shooting_init_jacobian_odecache(
170-
ensemblealg, prob, bc_jac_cache, alg.jac_alg.bc_diffmode,
170+
ensemblealg, prob, bc_jac_cache, __cache_trait(alg.jac_alg.bc_diffmode),
171171
alg.ode_alg, cur_nshoot, u0; internal_ode_kwargs...)
172172

173173
jac_prototype = vcat(init_jacobian(bc_jac_cache), init_jacobian(ode_jac_cache))
@@ -208,12 +208,12 @@ function __multiple_shooting_init_odecache(
208208
end
209209

210210
function __multiple_shooting_init_jacobian_odecache(
211-
ensemblealg, prob, jac_cache, ad, alg, nshoots, u; kwargs...)
211+
ensemblealg, prob, jac_cache, ::NoDiffCacheNeeded, alg, nshoots, u; kwargs...)
212212
return __multiple_shooting_init_odecache(ensemblealg, prob, alg, u, nshoots; kwargs...)
213213
end
214214

215-
function __multiple_shooting_init_jacobian_odecache(ensemblealg, prob, jac_cache,
216-
::Union{AutoForwardDiff, AutoSparseForwardDiff}, alg, nshoots, u; kwargs...)
215+
function __multiple_shooting_init_jacobian_odecache(
216+
ensemblealg, prob, jac_cache, ::DiffCacheNeeded, alg, nshoots, u; kwargs...)
217217
cache = jac_cache.cache
218218
if cache isa ForwardDiff.JacobianConfig
219219
xduals = reshape(cache.duals[2][1:length(u)], size(u))

src/solve/single_shooting.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ function __solve(prob::BVProblem, alg_::Shooting; odesolve_kwargs = (;),
4949
end
5050

5151
ode_cache_jac_fn = __single_shooting_jacobian_ode_cache(
52-
internal_prob, jac_cache, alg.jac_alg.diffmode, u0, alg.ode_alg; ode_kwargs...)
52+
internal_prob, jac_cache, __cache_trait(alg.jac_alg.diffmode),
53+
u0, alg.ode_alg; ode_kwargs...)
5354

5455
jac_prototype = init_jacobian(jac_cache)
5556

@@ -126,13 +127,13 @@ function __single_shooting_jacobian(J, u, jac_cache, diffmode, loss_fn::L) where
126127
return J
127128
end
128129

129-
function __single_shooting_jacobian_ode_cache(prob, jac_cache, alg, u0, ode_alg; kwargs...)
130+
function __single_shooting_jacobian_ode_cache(
131+
prob, jac_cache, ::NoDiffCacheNeeded, u0, ode_alg; kwargs...)
130132
return SciMLBase.__init(remake(prob; u0), ode_alg; kwargs...)
131133
end
132134

133135
function __single_shooting_jacobian_ode_cache(
134-
prob, jac_cache, ::Union{AutoForwardDiff, AutoSparseForwardDiff},
135-
u0, ode_alg; kwargs...)
136+
prob, jac_cache, ::DiffCacheNeeded, u0, ode_alg; kwargs...)
136137
cache = jac_cache.cache
137138
if cache isa ForwardDiff.JacobianConfig
138139
xduals = cache.duals isa Tuple ? cache.duals[2] : cache.duals

src/types.jl

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,8 @@ If user provided all the required fields, then return the user provided algorith
8989
Otherwise, based on the problem type and the algorithm, decide the missing fields.
9090
9191
For example, for `TwoPointBVProblem`, the `bc_diffmode` is set to
92-
`AutoSparseForwardDiff` while for `StandardBVProblem`, the `bc_diffmode` is set to
93-
`AutoForwardDiff`.
92+
`AutoSparse(AutoForwardDiff())` while for `StandardBVProblem`, the `bc_diffmode` is set to
93+
`AutoForwardDiff()`.
9494
"""
9595
function concrete_jacobian_algorithm(jac_alg::BVPJacobianAlgorithm, prob::BVProblem, alg)
9696
return concrete_jacobian_algorithm(jac_alg, prob.problem_type, prob, alg)
@@ -109,21 +109,13 @@ function concrete_jacobian_algorithm(
109109
return BVPJacobianAlgorithm(bc_diffmode, nonbc_diffmode, diffmode)
110110
end
111111

112-
struct BoundaryValueDiffEqTag end
113-
114-
function ForwardDiff.checktag(::Type{<:ForwardDiff.Tag{<:BoundaryValueDiffEqTag, <:T}},
115-
f::F, x::AbstractArray{T}) where {T, F}
116-
return true
117-
end
118-
119112
@inline function __default_sparse_ad(x::AbstractArray{T}) where {T}
120113
return isbitstype(T) ? __default_sparse_ad(T) : __default_sparse_ad(first(x))
121114
end
122115
@inline __default_sparse_ad(x::T) where {T} = __default_sparse_ad(T)
123116
@inline __default_sparse_ad(::Type{<:Complex}) = AutoSparseFiniteDiff()
124117
@inline function __default_sparse_ad(::Type{T}) where {T}
125-
return ForwardDiff.can_dual(T) ?
126-
AutoSparseForwardDiff(; tag = BoundaryValueDiffEqTag()) : AutoSparseFiniteDiff()
118+
return AutoSparse(ifelse(ForwardDiff.can_dual(T), AutoForwardDiff(), AutoFiniteDiff()))
127119
end
128120

129121
@inline function __default_nonsparse_ad(x::AbstractArray{T}) where {T}
@@ -132,8 +124,7 @@ end
132124
@inline __default_nonsparse_ad(x::T) where {T} = __default_nonsparse_ad(T)
133125
@inline __default_nonsparse_ad(::Type{<:Complex}) = AutoFiniteDiff()
134126
@inline function __default_nonsparse_ad(::Type{T}) where {T}
135-
return ForwardDiff.can_dual(T) ? AutoForwardDiff(; tag = BoundaryValueDiffEqTag()) :
136-
AutoFiniteDiff()
127+
return ifelse(ForwardDiff.can_dual(T), AutoForwardDiff(), AutoFiniteDiff())
137128
end
138129

139130
# This can cause Type Instability
@@ -146,9 +137,10 @@ Base.@deprecate MIRKJacobianComputationAlgorithm(
146137
diffmode = missing; collocation_diffmode = missing, bc_diffmode = missing) BVPJacobianAlgorithm(
147138
diffmode; nonbc_diffmode = collocation_diffmode, bc_diffmode)
148139

149-
__needs_diffcache(::Union{AutoForwardDiff, AutoSparseForwardDiff}) = true
150-
__needs_diffcache(_) = false
151-
function __needs_diffcache(jac_alg::BVPJacobianAlgorithm)
140+
@inline __needs_diffcache(::AutoForwardDiff) = true
141+
@inline __needs_diffcache(ad::AutoSparse) = __needs_diffcache(ADTypes.dense_ad(ad))
142+
@inline __needs_diffcache(_) = false
143+
@inline function __needs_diffcache(jac_alg::BVPJacobianAlgorithm)
152144
return __needs_diffcache(jac_alg.diffmode) ||
153145
__needs_diffcache(jac_alg.bc_diffmode) ||
154146
__needs_diffcache(jac_alg.nonbc_diffmode)
@@ -176,3 +168,11 @@ const MaybeDiffCache = Union{DiffCache, FakeDiffCache}
176168
PreallocationTools.get_tmp(dc, u)
177169
end
178170
end
171+
172+
# DiffCache
173+
struct DiffCacheNeeded end
174+
struct NoDiffCacheNeeded end
175+
176+
@inline __cache_trait(::AutoForwardDiff) = DiffCacheNeeded()
177+
@inline __cache_trait(ad::AutoSparse) = __cache_trait(ADTypes.dense_ad(ad))
178+
@inline __cache_trait(_) = NoDiffCacheNeeded()

src/utils.jl

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -220,24 +220,8 @@ end
220220
__vec_bc(sol, p, t, bc, u_size) = vec(bc(__restructure_sol(sol, u_size), p, t))
221221
__vec_bc(sol, p, bc, u_size) = vec(bc(reshape(sol, u_size), p))
222222

223-
__get_non_sparse_ad(ad::AbstractADType) = ad
224-
function __get_non_sparse_ad(ad::AbstractSparseADType)
225-
if ad isa AutoSparseForwardDiff
226-
return AutoForwardDiff{__get_chunksize(ad), typeof(ad.tag)}(ad.tag)
227-
elseif ad isa AutoSparseEnzyme
228-
return AutoEnzyme()
229-
elseif ad isa AutoSparseFiniteDiff
230-
return AutoFiniteDiff()
231-
elseif ad isa AutoSparseReverseDiff
232-
return AutoReverseDiff(ad.compile)
233-
elseif ad isa AutoSparseZygote
234-
return AutoZygote()
235-
else
236-
throw(ArgumentError("Unknown AD Type"))
237-
end
238-
end
239-
240-
__get_chunksize(::AutoSparseForwardDiff{CK}) where {CK} = CK
223+
@inline __get_non_sparse_ad(ad::AbstractADType) = ad
224+
@inline __get_non_sparse_ad(ad::AutoSparse) = ADTypes.dense_ad(ad)
241225

242226
# Restructure Solution
243227
function __restructure_sol(sol::Vector{<:AbstractArray}, u_size)

0 commit comments

Comments
 (0)