Skip to content

Commit efdc12e

Browse files
committed
Promote ReverseDiff compile field to type
1 parent 72f806d commit efdc12e

File tree

5 files changed

+24
-10
lines changed

5 files changed

+24
-10
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
33
authors = [
44
"Vaibhav Dixit <[email protected]>, Guillaume Dalle and contributors",
55
]
6-
version = "1.4.0"
6+
version = "1.5.0"
77

88
[deps]
99
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/ADTypes.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ abstract type AbstractADType end
1616

1717
Base.broadcastable(ad::AbstractADType) = Ref(ad)
1818

19+
@inline _unwrap_val(::Val{T}) where {T} = T
20+
@inline _unwrap_val(x) = x
21+
1922
include("mode.jl")
2023
include("dense.jl")
2124
include("sparse.jl")

src/dense.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -186,14 +186,19 @@ Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl).
186186
187187
# Constructors
188188
189-
AutoReverseDiff(; compile=false)
189+
AutoReverseDiff(; compile::Union{Val, Bool} = Val(false))
190190
191191
# Fields
192192
193-
- `compile::Bool`: whether to [compile the tape](https://juliadiff.org/ReverseDiff.jl/api/#ReverseDiff.compile) prior to differentiation
193+
- `compile::Union{Val, Bool}`: whether to [compile the tape](https://juliadiff.org/ReverseDiff.jl/api/#ReverseDiff.compile) prior to differentiation
194194
"""
195-
Base.@kwdef struct AutoReverseDiff <: AbstractADType
196-
compile::Bool = false
195+
struct AutoReverseDiff{C} <: AbstractADType
196+
compile::Bool # this field if left for legacy reasons
197+
198+
function AutoReverseDiff(; compile::Union{Val, Bool} = Val(false))
199+
_compile = _unwrap_val(compile)
200+
return new{_compile}(_compile)
201+
end
197202
end
198203

199204
mode(::AutoReverseDiff) = ReverseMode()

src/symbols.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ ADTypes.AutoZygote()
2222
Auto(package::Symbol, args...; kws...) = Auto(Val(package), args...; kws...)
2323

2424
for backend in (:ChainRules, :Diffractor, :Enzyme, :FastDifferentiation,
25-
:FiniteDiff, :FiniteDifferences, :ForwardDiff, :PolyesterForwardDiff,
26-
:ReverseDiff, :Symbolics, :Tapir, :Tracker, :Zygote)
27-
@eval Auto(::Val{$(QuoteNode(backend))}, args...; kws...) = $(Symbol(:Auto, backend))(args...; kws...)
25+
:FiniteDiff, :FiniteDifferences, :ForwardDiff, :PolyesterForwardDiff,
26+
:ReverseDiff, :Symbolics, :Tapir, :Tracker, :Zygote)
27+
@eval Auto(::Val{$(QuoteNode(backend))}, args...; kws...) = $(Symbol(:Auto, backend))(
28+
args...; kws...)
2829
end
29-

test/dense.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ end
113113
end
114114

115115
@testset "AutoReverseDiff" begin
116-
ad = AutoReverseDiff()
116+
ad = @inferred AutoReverseDiff()
117117
@test ad isa AbstractADType
118118
@test ad isa AutoReverseDiff
119119
@test mode(ad) isa ReverseMode
@@ -124,6 +124,12 @@ end
124124
@test ad isa AutoReverseDiff
125125
@test mode(ad) isa ReverseMode
126126
@test ad.compile
127+
128+
ad = @inferred AutoReverseDiff(; compile = Val(true))
129+
@test ad isa AbstractADType
130+
@test ad isa AutoReverseDiff
131+
@test mode(ad) isa ReverseMode
132+
@test ad.compile
127133
end
128134

129135
@testset "AutoSymbolics" begin

0 commit comments

Comments
 (0)