Skip to content

Commit eab4336

Browse files
Merge pull request #63 from SciML/ap/reversediff
Promote ReverseDiff compile field to type
2 parents 72f806d + 6c27b87 commit eab4336

File tree

5 files changed

+36
-10
lines changed

5 files changed

+36
-10
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
33
authors = [
44
"Vaibhav Dixit <[email protected]>, Guillaume Dalle and contributors",
55
]
6-
version = "1.4.0"
6+
version = "1.5.0"
77

88
[deps]
99
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/ADTypes.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ abstract type AbstractADType end
1616

1717
Base.broadcastable(ad::AbstractADType) = Ref(ad)
1818

19+
@inline _unwrap_val(::Val{T}) where {T} = T
20+
@inline _unwrap_val(x) = x
21+
1922
include("mode.jl")
2023
include("dense.jl")
2124
include("sparse.jl")

src/dense.jl

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -186,14 +186,28 @@ Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl).
186186
187187
# Constructors
188188
189-
AutoReverseDiff(; compile=false)
189+
AutoReverseDiff(; compile::Union{Val, Bool} = Val(false))
190190
191191
# Fields
192192
193-
- `compile::Bool`: whether to [compile the tape](https://juliadiff.org/ReverseDiff.jl/api/#ReverseDiff.compile) prior to differentiation
193+
- `compile::Union{Val, Bool}`: whether to [compile the tape](https://juliadiff.org/ReverseDiff.jl/api/#ReverseDiff.compile) prior to differentiation
194194
"""
195-
Base.@kwdef struct AutoReverseDiff <: AbstractADType
196-
compile::Bool = false
195+
struct AutoReverseDiff{C} <: AbstractADType
196+
compile::Bool # this field if left for legacy reasons
197+
198+
function AutoReverseDiff(; compile::Union{Val, Bool} = Val(false))
199+
_compile = _unwrap_val(compile)
200+
return new{_compile}(_compile)
201+
end
202+
end
203+
204+
function Base.getproperty(ad::AutoReverseDiff, s::Symbol)
205+
if s === :compile
206+
Base.depwarn(
207+
"`ad.compile` where `ad` is `AutoReverseDiff` has been deprecated and will be removed in v2. Instead it is available as a compile-time constant as `AutoReverseDiff{true}` or `AutoReverseDiff{false}`.",
208+
:getproperty)
209+
end
210+
return getfield(ad, s)
197211
end
198212

199213
mode(::AutoReverseDiff) = ReverseMode()

src/symbols.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ ADTypes.AutoZygote()
2222
Auto(package::Symbol, args...; kws...) = Auto(Val(package), args...; kws...)
2323

2424
for backend in (:ChainRules, :Diffractor, :Enzyme, :FastDifferentiation,
25-
:FiniteDiff, :FiniteDifferences, :ForwardDiff, :PolyesterForwardDiff,
26-
:ReverseDiff, :Symbolics, :Tapir, :Tracker, :Zygote)
27-
@eval Auto(::Val{$(QuoteNode(backend))}, args...; kws...) = $(Symbol(:Auto, backend))(args...; kws...)
25+
:FiniteDiff, :FiniteDifferences, :ForwardDiff, :PolyesterForwardDiff,
26+
:ReverseDiff, :Symbolics, :Tapir, :Tracker, :Zygote)
27+
@eval Auto(::Val{$(QuoteNode(backend))}, args...; kws...) = $(Symbol(:Auto, backend))(
28+
args...; kws...)
2829
end
29-

test/dense.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,17 +113,26 @@ end
113113
end
114114

115115
@testset "AutoReverseDiff" begin
116-
ad = AutoReverseDiff()
116+
ad = @inferred AutoReverseDiff()
117117
@test ad isa AbstractADType
118118
@test ad isa AutoReverseDiff
119119
@test mode(ad) isa ReverseMode
120120
@test !ad.compile
121+
@test_deprecated ad.compile
121122

122123
ad = AutoReverseDiff(; compile = true)
123124
@test ad isa AbstractADType
124125
@test ad isa AutoReverseDiff
125126
@test mode(ad) isa ReverseMode
126127
@test ad.compile
128+
@test_deprecated ad.compile
129+
130+
ad = @inferred AutoReverseDiff(; compile = Val(true))
131+
@test ad isa AbstractADType
132+
@test ad isa AutoReverseDiff
133+
@test mode(ad) isa ReverseMode
134+
@test ad.compile
135+
@test_deprecated ad.compile
127136
end
128137

129138
@testset "AutoSymbolics" begin

0 commit comments

Comments
 (0)