Skip to content

Commit 16f421d

Browse files
authored
Pretty printing (#64)
* Pretty printing * Fix default for FiniteDiff * Use repr
1 parent eab4336 commit 16f421d

File tree

5 files changed

+141
-4
lines changed

5 files changed

+141
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
33
authors = [
44
"Vaibhav Dixit <[email protected]>, Guillaume Dalle and contributors",
55
]
6-
version = "1.5.0"
6+
version = "1.5.1"
77

88
[deps]
99
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/dense.jl

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ end
1919

2020
mode(::AutoChainRules) = ForwardOrReverseMode() # specialized in the extension
2121

22+
function Base.show(io::IO, backend::AutoChainRules)
23+
print(io, "AutoChainRules(ruleconfig=$(repr(backend.ruleconfig, context=io)))")
24+
end
25+
2226
"""
2327
AutoDiffractor
2428
@@ -58,6 +62,14 @@ end
5862

5963
mode(::AutoEnzyme) = ForwardOrReverseMode() # specialized in the extension
6064

65+
function Base.show(io::IO, backend::AutoEnzyme)
66+
if isnothing(backend.mode)
67+
print(io, "AutoEnzyme()")
68+
else
69+
print(io, "AutoEnzyme(mode=$(repr(backend.mode, context=io)))")
70+
end
71+
end
72+
6173
"""
6274
AutoFastDifferentiation
6375
@@ -98,6 +110,24 @@ end
98110

99111
mode(::AutoFiniteDiff) = ForwardMode()
100112

113+
function Base.show(io::IO, backend::AutoFiniteDiff)
114+
s = "AutoFiniteDiff("
115+
if backend.fdtype != Val(:forward)
116+
s *= "fdtype=$(repr(backend.fdtype, context=io)), "
117+
end
118+
if backend.fdjtype != backend.fdtype
119+
s *= "fdjtype=$(repr(backend.fdjtype, context=io)), "
120+
end
121+
if backend.fdhtype != Val(:hcentral)
122+
s *= "fdhtype=$(repr(backend.fdhtype, context=io)), "
123+
end
124+
if endswith(s, ", ")
125+
s = s[1:(end - 2)]
126+
end
127+
s *= ")"
128+
print(io, s)
129+
end
130+
101131
"""
102132
AutoFiniteDifferences{T}
103133
@@ -119,6 +149,10 @@ end
119149

120150
mode(::AutoFiniteDifferences) = ForwardMode()
121151

152+
function Base.show(io::IO, backend::AutoFiniteDifferences)
153+
print(io, "AutoFiniteDifferences(fdm=$(repr(backend.fdm, context=io)))")
154+
end
155+
122156
"""
123157
AutoForwardDiff{chunksize,T}
124158
@@ -148,6 +182,21 @@ end
148182

149183
mode(::AutoForwardDiff) = ForwardMode()
150184

185+
function Base.show(io::IO, backend::AutoForwardDiff{chunksize}) where {chunksize}
186+
s = "AutoForwardDiff("
187+
if chunksize !== nothing
188+
s *= "chunksize=$chunksize, "
189+
end
190+
if backend.tag !== nothing
191+
s *= "tag=$(repr(backend.tag, context=io)), "
192+
end
193+
if endswith(s, ", ")
194+
s = s[1:(end - 2)]
195+
end
196+
s *= ")"
197+
print(io, s)
198+
end
199+
151200
"""
152201
AutoPolyesterForwardDiff{chunksize,T}
153202
@@ -177,6 +226,21 @@ end
177226

178227
mode(::AutoPolyesterForwardDiff) = ForwardMode()
179228

229+
function Base.show(io::IO, backend::AutoPolyesterForwardDiff{chunksize}) where {chunksize}
230+
s = "AutoPolyesterForwardDiff("
231+
if chunksize !== nothing
232+
s *= "chunksize=$chunksize, "
233+
end
234+
if backend.tag !== nothing
235+
s *= "tag=$(repr(backend.tag, context=io)), "
236+
end
237+
if endswith(s, ", ")
238+
s = s[1:(end - 2)]
239+
end
240+
s *= ")"
241+
print(io, s)
242+
end
243+
180244
"""
181245
AutoReverseDiff
182246
@@ -193,7 +257,7 @@ Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl).
193257
- `compile::Union{Val, Bool}`: whether to [compile the tape](https://juliadiff.org/ReverseDiff.jl/api/#ReverseDiff.compile) prior to differentiation
194258
"""
195259
struct AutoReverseDiff{C} <: AbstractADType
196-
compile::Bool # this field if left for legacy reasons
260+
compile::Bool # this field is left for legacy reasons
197261

198262
function AutoReverseDiff(; compile::Union{Val, Bool} = Val(false))
199263
_compile = _unwrap_val(compile)
@@ -212,6 +276,14 @@ end
212276

213277
mode(::AutoReverseDiff) = ReverseMode()
214278

279+
function Base.show(io::IO, ::AutoReverseDiff{compile}) where {compile}
280+
if !compile
281+
print(io, "AutoReverseDiff()")
282+
else
283+
print(io, "AutoReverseDiff(compile=true)")
284+
end
285+
end
286+
215287
"""
216288
AutoSymbolics
217289
@@ -248,6 +320,14 @@ end
248320

249321
mode(::AutoTapir) = ReverseMode()
250322

323+
function Base.show(io::IO, backend::AutoTapir)
324+
if backend.safe_mode
325+
print(io, "AutoTapir()")
326+
else
327+
print(io, "AutoTapir(safe_mode=false)")
328+
end
329+
end
330+
251331
"""
252332
AutoTracker
253333

src/sparse.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,18 @@ function AutoSparse(
154154
}(dense_ad, sparsity_detector, coloring_algorithm)
155155
end
156156

157+
function Base.show(io::IO, backend::AutoSparse)
158+
s = "AutoSparse(dense_ad=$(repr(backend.dense_ad, context=io)), "
159+
if backend.sparsity_detector != NoSparsityDetector()
160+
s *= "sparsity_detector=$(repr(backend.sparsity_detector, context=io)), "
161+
end
162+
if backend.coloring_algorithm != NoColoringAlgorithm()
163+
s *= "coloring_algorithm=$(repr(backend.coloring_algorithm, context=io))), "
164+
end
165+
s = s[1:(end - 2)] * ")"
166+
print(io, s)
167+
end
168+
157169
"""
158170
dense_ad(ad::AutoSparse)::AbstractADType
159171

test/misc.jl

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,21 @@
1-
for ad in every_ad()
2-
@test identity.(ad) == ad
1+
@testset "Broadcasting" begin
2+
for ad in every_ad()
3+
@test identity.(ad) == ad
4+
end
5+
end
6+
7+
@testset "Printing" begin
8+
for ad in every_ad_with_options()
9+
@test startswith(string(ad), "Auto")
10+
@test endswith(string(ad), ")")
11+
end
12+
13+
sparse_backend1 = AutoSparse(AutoForwardDiff())
14+
sparse_backend2 = AutoSparse(
15+
AutoForwardDiff();
16+
sparsity_detector = FakeSparsityDetector(),
17+
coloring_algorithm = FakeColoringAlgorithm()
18+
)
19+
@test contains(string(sparse_backend1), string(AutoForwardDiff()))
20+
@test length(string(sparse_backend1)) < length(string(sparse_backend2))
321
end

test/runtests.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ struct ForwardRuleConfig <: RuleConfig{Union{HasForwardsMode, NoReverseMode}} en
3131
struct ReverseRuleConfig <: RuleConfig{Union{NoForwardsMode, HasReverseMode}} end
3232
struct ForwardOrReverseRuleConfig <: RuleConfig{Union{HasForwardsMode, HasReverseMode}} end
3333

34+
struct FakeSparsityDetector <: ADTypes.AbstractSparsityDetector end
35+
struct FakeColoringAlgorithm <: ADTypes.AbstractColoringAlgorithm end
36+
3437
function every_ad()
3538
return [
3639
AutoChainRules(; ruleconfig = :rc),
@@ -49,6 +52,30 @@ function every_ad()
4952
]
5053
end
5154

55+
function every_ad_with_options()
56+
return [
57+
AutoChainRules(; ruleconfig = :rc),
58+
AutoDiffractor(),
59+
AutoEnzyme(),
60+
AutoEnzyme(mode = :forward),
61+
AutoFastDifferentiation(),
62+
AutoFiniteDiff(),
63+
AutoFiniteDiff(fdtype = :fd, fdjtype = :fdj, fdhtype = :fdh),
64+
AutoFiniteDifferences(; fdm = :fdm),
65+
AutoForwardDiff(),
66+
AutoForwardDiff(chunksize = 3, tag = :tag),
67+
AutoPolyesterForwardDiff(),
68+
AutoPolyesterForwardDiff(chunksize = 3, tag = :tag),
69+
AutoReverseDiff(),
70+
AutoReverseDiff(compile = true),
71+
AutoSymbolics(),
72+
AutoTapir(),
73+
AutoTapir(safe_mode = false),
74+
AutoTracker(),
75+
AutoZygote()
76+
]
77+
end
78+
5279
## Tests
5380

5481
@testset verbose=true "ADTypes.jl" begin

0 commit comments

Comments
 (0)