Skip to content

Commit b05286d

Browse files
committed
feat: add ConstructionBaseExt to allow Setfield and Functors support
1 parent c5c2b8c commit b05286d

File tree

4 files changed

+69
-5
lines changed

4 files changed

+69
-5
lines changed

Project.toml

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,26 @@
11
name = "ADTypes"
22
uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
33
authors = ["Vaibhav Dixit <[email protected]>, Guillaume Dalle and contributors"]
4-
version = "1.9.1"
4+
version = "1.10.0"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
8+
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
89
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
910

1011
[weakdeps]
1112
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
13+
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
1214
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
1315

1416
[extensions]
1517
ADTypesChainRulesCoreExt = "ChainRulesCore"
18+
ADTypesConstructionBaseExt = "ConstructionBase"
1619
ADTypesEnzymeCoreExt = "EnzymeCore"
1720

1821
[compat]
1922
ChainRulesCore = "1.0.2"
23+
ConstructionBase = "1.5"
2024
EnzymeCore = "0.5.3,0.6,0.7,0.8"
2125
julia = "1.6"
2226

@@ -25,7 +29,8 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
2529
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
2630
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
2731
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
32+
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
2833
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2934

3035
[targets]
31-
test = ["Aqua", "ChainRulesCore", "EnzymeCore", "JET", "Test"]
36+
test = ["Aqua", "ChainRulesCore", "EnzymeCore", "JET", "Setfield", "Test"]

ext/ADTypesConstructionBaseExt.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
module ADTypesConstructionBaseExt
2+
3+
using ADTypes: AutoEnzyme, AutoForwardDiff, AutoPolyesterForwardDiff
4+
using ConstructionBase: ConstructionBase
5+
6+
function ConstructionBase.constructorof(::Type{<:AutoEnzyme{M, A}}) where {M, A}
7+
return AutoEnzyme{A}
8+
end
9+
10+
function ConstructionBase.constructorof(::Type{<:AutoForwardDiff{chunksize}}) where {chunksize}
11+
return AutoForwardDiff{chunksize}
12+
end
13+
14+
function ConstructionBase.constructorof(::Type{<:AutoPolyesterForwardDiff{chunksize}}) where {chunksize}
15+
return AutoPolyesterForwardDiff{chunksize}
16+
end
17+
18+
end

src/dense.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,11 @@ struct AutoEnzyme{M, A} <: AbstractADType
6767
mode::M
6868
end
6969

70+
AutoEnzyme{A}(mode::M) where {M, A} = AutoEnzyme{M, A}(mode)
71+
7072
function AutoEnzyme(;
7173
mode::M = nothing, function_annotation::Type{A} = Nothing) where {M, A}
72-
return AutoEnzyme{M, A}(mode)
74+
return AutoEnzyme{A}(mode)
7375
end
7476

7577
mode(::AutoEnzyme) = ForwardOrReverseMode() # specialized in the extension
@@ -181,8 +183,10 @@ struct AutoForwardDiff{chunksize, T} <: AbstractADType
181183
tag::T
182184
end
183185

186+
AutoForwardDiff{chunksize}(tag::T) where {chunksize, T} = AutoForwardDiff{chunksize, T}(tag)
187+
184188
function AutoForwardDiff(; chunksize = nothing, tag = nothing)
185-
AutoForwardDiff{chunksize, typeof(tag)}(tag)
189+
return AutoForwardDiff{chunksize}(tag)
186190
end
187191

188192
mode(::AutoForwardDiff) = ForwardMode()
@@ -271,8 +275,12 @@ struct AutoPolyesterForwardDiff{chunksize, T} <: AbstractADType
271275
tag::T
272276
end
273277

278+
function AutoPolyesterForwardDiff{chunksize}(tag::T) where {chunksize, T}
279+
return AutoPolyesterForwardDiff{chunksize, T}(tag)
280+
end
281+
274282
function AutoPolyesterForwardDiff(; chunksize = nothing, tag = nothing)
275-
AutoPolyesterForwardDiff{chunksize, typeof(tag)}(tag)
283+
return AutoPolyesterForwardDiff{chunksize}(tag)
276284
end
277285

278286
mode(::AutoPolyesterForwardDiff) = ForwardMode()

test/misc.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,3 +61,36 @@ for backend in [
6161
]
6262
println(backend)
6363
end
64+
65+
using Setfield
66+
67+
@testset "Setfield compatibility" begin
68+
ad = AutoEnzyme()
69+
@test ad.mode === nothing
70+
@set! ad.mode = EnzymeCore.Reverse
71+
@test ad.mode isa EnzymeCore.ReverseMode
72+
73+
struct CustomTestTag end
74+
75+
ad = AutoForwardDiff()
76+
@test ad.tag === nothing
77+
@set! ad.tag = CustomTestTag()
78+
@test ad.tag isa CustomTestTag
79+
80+
ad = AutoForwardDiff(; chunksize = 10)
81+
@test ad.tag === nothing
82+
@set! ad.tag = CustomTestTag()
83+
@test ad.tag isa CustomTestTag
84+
@test ad isa AutoForwardDiff{10}
85+
86+
ad = AutoPolyesterForwardDiff()
87+
@test ad.tag === nothing
88+
@set! ad.tag = CustomTestTag()
89+
@test ad.tag isa CustomTestTag
90+
91+
ad = AutoPolyesterForwardDiff(; chunksize = 10)
92+
@test ad.tag === nothing
93+
@set! ad.tag = CustomTestTag()
94+
@test ad.tag isa CustomTestTag
95+
@test ad isa AutoPolyesterForwardDiff{10}
96+
end

0 commit comments

Comments
 (0)