From dcb7f80e92ca46c42dfdaa792321f77788a2b95e Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 25 Jun 2024 09:58:41 +0200 Subject: [PATCH 1/3] Pretty printing --- Project.toml | 2 +- src/dense.jl | 82 +++++++++++++++++++++++++++++++++++++++++++++++- src/sparse.jl | 12 +++++++ test/misc.jl | 22 +++++++++++-- test/runtests.jl | 27 ++++++++++++++++ 5 files changed, 141 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index c894b31..d9ea062 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b" authors = [ "Vaibhav Dixit , Guillaume Dalle and contributors", ] -version = "1.5.0" +version = "1.5.1" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/dense.jl b/src/dense.jl index c6e812a..c1276d3 100644 --- a/src/dense.jl +++ b/src/dense.jl @@ -19,6 +19,10 @@ end mode(::AutoChainRules) = ForwardOrReverseMode() # specialized in the extension +function Base.show(io::IO, backend::AutoChainRules) + print(io, "AutoChainRules(ruleconfig=$(backend.ruleconfig))") +end + """ AutoDiffractor @@ -58,6 +62,14 @@ end mode(::AutoEnzyme) = ForwardOrReverseMode() # specialized in the extension +function Base.show(io::IO, backend::AutoEnzyme) + if isnothing(backend.mode) + print(io, "AutoEnzyme()") + else + print(io, "AutoEnzyme(mode=$(backend.mode))") + end +end + """ AutoFastDifferentiation @@ -98,6 +110,24 @@ end mode(::AutoFiniteDiff) = ForwardMode() +function Base.show(io::IO, backend::AutoFiniteDiff) + s = "AutoFiniteDiff(" + if backend.fdtype != Val(:forward) + s *= "fdtype=$(backend.fdtype), " + end + if backend.fdjtype != Val(:forward) + s *= "fdjtype=$(backend.fdjtype), " + end + if backend.fdhtype != Val(:hcentral) + s *= "fdhtype=$(backend.fdhtype), " + end + if endswith(s, ", ") + s = s[1:(end - 2)] + end + s *= ")" + print(io, s) +end + """ AutoFiniteDifferences{T} @@ -119,6 +149,10 @@ end mode(::AutoFiniteDifferences) = ForwardMode() +function Base.show(io::IO, backend::AutoFiniteDifferences) + print(io, "AutoFiniteDifferences(fdm=$(backend.fdm))") +end + """ AutoForwardDiff{chunksize,T} @@ -148,6 +182,21 @@ end mode(::AutoForwardDiff) = ForwardMode() +function Base.show(io::IO, backend::AutoForwardDiff{chunksize}) where {chunksize} + s = "AutoForwardDiff(" + if chunksize !== nothing + s *= "chunksize=$chunksize, " + end + if backend.tag !== nothing + s *= "tag=$(backend.tag), " + end + if endswith(s, ", ") + s = s[1:(end - 2)] + end + s *= ")" + print(io, s) +end + """ AutoPolyesterForwardDiff{chunksize,T} @@ -177,6 +226,21 @@ end mode(::AutoPolyesterForwardDiff) = ForwardMode() +function Base.show(io::IO, backend::AutoPolyesterForwardDiff{chunksize}) where {chunksize} + s = "AutoPolyesterForwardDiff(" + if chunksize !== nothing + s *= "chunksize=$chunksize, " + end + if backend.tag !== nothing + s *= "tag=$(backend.tag), " + end + if endswith(s, ", ") + s = s[1:(end - 2)] + end + s *= ")" + print(io, s) +end + """ AutoReverseDiff @@ -193,7 +257,7 @@ Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl). - `compile::Union{Val, Bool}`: whether to [compile the tape](https://juliadiff.org/ReverseDiff.jl/api/#ReverseDiff.compile) prior to differentiation """ struct AutoReverseDiff{C} <: AbstractADType - compile::Bool # this field if left for legacy reasons + compile::Bool # this field is left for legacy reasons function AutoReverseDiff(; compile::Union{Val, Bool} = Val(false)) _compile = _unwrap_val(compile) @@ -212,6 +276,14 @@ end mode(::AutoReverseDiff) = ReverseMode() +function Base.show(io::IO, ::AutoReverseDiff{compile}) where {compile} + if !compile + print(io, "AutoReverseDiff()") + else + print(io, "AutoReverseDiff(compile=true)") + end +end + """ AutoSymbolics @@ -248,6 +320,14 @@ end mode(::AutoTapir) = ReverseMode() +function Base.show(io::IO, backend::AutoTapir) + if backend.safe_mode + print(io, "AutoTapir()") + else + print(io, "AutoTapir(safe_mode=false)") + end +end + """ AutoTracker diff --git a/src/sparse.jl b/src/sparse.jl index 70bf692..0db7914 100644 --- a/src/sparse.jl +++ b/src/sparse.jl @@ -154,6 +154,18 @@ function AutoSparse( }(dense_ad, sparsity_detector, coloring_algorithm) end +function Base.show(io::IO, backend::AutoSparse) + s = "AutoSparse(dense_ad=$(backend.dense_ad), " + if backend.sparsity_detector != NoSparsityDetector() + s *= "sparsity_detector=$(backend.sparsity_detector), " + end + if backend.coloring_algorithm != NoColoringAlgorithm() + s *= "coloring_algorithm=$(backend.coloring_algorithm)), " + end + s = s[1:(end - 2)] * ")" + print(io, s) +end + """ dense_ad(ad::AutoSparse)::AbstractADType diff --git a/test/misc.jl b/test/misc.jl index 00a56d7..a1b2274 100644 --- a/test/misc.jl +++ b/test/misc.jl @@ -1,3 +1,21 @@ -for ad in every_ad() - @test identity.(ad) == ad +@testset "Broadcasting" begin + for ad in every_ad() + @test identity.(ad) == ad + end +end + +@testset "Printing" begin + for ad in every_ad_with_options() + @test startswith(string(ad), "Auto") + @test endswith(string(ad), ")") + end + + sparse_backend1 = AutoSparse(AutoForwardDiff()) + sparse_backend2 = AutoSparse( + AutoForwardDiff(); + sparsity_detector = FakeSparsityDetector(), + coloring_algorithm = FakeColoringAlgorithm() + ) + @test contains(string(sparse_backend1), string(AutoForwardDiff())) + @test length(string(sparse_backend1)) < length(string(sparse_backend2)) end diff --git a/test/runtests.jl b/test/runtests.jl index e7d72f9..76fccba 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -31,6 +31,9 @@ struct ForwardRuleConfig <: RuleConfig{Union{HasForwardsMode, NoReverseMode}} en struct ReverseRuleConfig <: RuleConfig{Union{NoForwardsMode, HasReverseMode}} end struct ForwardOrReverseRuleConfig <: RuleConfig{Union{HasForwardsMode, HasReverseMode}} end +struct FakeSparsityDetector <: ADTypes.AbstractSparsityDetector end +struct FakeColoringAlgorithm <: ADTypes.AbstractColoringAlgorithm end + function every_ad() return [ AutoChainRules(; ruleconfig = :rc), @@ -49,6 +52,30 @@ function every_ad() ] end +function every_ad_with_options() + return [ + AutoChainRules(; ruleconfig = :rc), + AutoDiffractor(), + AutoEnzyme(), + AutoEnzyme(mode = :forward), + AutoFastDifferentiation(), + AutoFiniteDiff(), + AutoFiniteDiff(fdtype = :fd, fdjtype = :fdj, fdhtype = :fdh), + AutoFiniteDifferences(; fdm = :fdm), + AutoForwardDiff(), + AutoForwardDiff(chunksize = 3, tag = :tag), + AutoPolyesterForwardDiff(), + AutoPolyesterForwardDiff(chunksize = 3, tag = :tag), + AutoReverseDiff(), + AutoReverseDiff(compile = true), + AutoSymbolics(), + AutoTapir(), + AutoTapir(safe_mode = false), + AutoTracker(), + AutoZygote() + ] +end + ## Tests @testset verbose=true "ADTypes.jl" begin From e7462030aa5bc4acabb29d9953535875c3702ed9 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 25 Jun 2024 10:13:24 +0200 Subject: [PATCH 2/3] Fix default for FiniteDiff --- src/dense.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dense.jl b/src/dense.jl index c1276d3..2355838 100644 --- a/src/dense.jl +++ b/src/dense.jl @@ -115,7 +115,7 @@ function Base.show(io::IO, backend::AutoFiniteDiff) if backend.fdtype != Val(:forward) s *= "fdtype=$(backend.fdtype), " end - if backend.fdjtype != Val(:forward) + if backend.fdjtype != backend.fdtype s *= "fdjtype=$(backend.fdjtype), " end if backend.fdhtype != Val(:hcentral) From 117cf236d021becc4907b07d6d2bc24a9ae488fc Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 25 Jun 2024 12:00:20 +0200 Subject: [PATCH 3/3] Use repr --- src/dense.jl | 16 ++++++++-------- src/sparse.jl | 6 +++--- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/dense.jl b/src/dense.jl index 2355838..7958a67 100644 --- a/src/dense.jl +++ b/src/dense.jl @@ -20,7 +20,7 @@ end mode(::AutoChainRules) = ForwardOrReverseMode() # specialized in the extension function Base.show(io::IO, backend::AutoChainRules) - print(io, "AutoChainRules(ruleconfig=$(backend.ruleconfig))") + print(io, "AutoChainRules(ruleconfig=$(repr(backend.ruleconfig, context=io)))") end """ @@ -66,7 +66,7 @@ function Base.show(io::IO, backend::AutoEnzyme) if isnothing(backend.mode) print(io, "AutoEnzyme()") else - print(io, "AutoEnzyme(mode=$(backend.mode))") + print(io, "AutoEnzyme(mode=$(repr(backend.mode, context=io)))") end end @@ -113,13 +113,13 @@ mode(::AutoFiniteDiff) = ForwardMode() function Base.show(io::IO, backend::AutoFiniteDiff) s = "AutoFiniteDiff(" if backend.fdtype != Val(:forward) - s *= "fdtype=$(backend.fdtype), " + s *= "fdtype=$(repr(backend.fdtype, context=io)), " end if backend.fdjtype != backend.fdtype - s *= "fdjtype=$(backend.fdjtype), " + s *= "fdjtype=$(repr(backend.fdjtype, context=io)), " end if backend.fdhtype != Val(:hcentral) - s *= "fdhtype=$(backend.fdhtype), " + s *= "fdhtype=$(repr(backend.fdhtype, context=io)), " end if endswith(s, ", ") s = s[1:(end - 2)] @@ -150,7 +150,7 @@ end mode(::AutoFiniteDifferences) = ForwardMode() function Base.show(io::IO, backend::AutoFiniteDifferences) - print(io, "AutoFiniteDifferences(fdm=$(backend.fdm))") + print(io, "AutoFiniteDifferences(fdm=$(repr(backend.fdm, context=io)))") end """ @@ -188,7 +188,7 @@ function Base.show(io::IO, backend::AutoForwardDiff{chunksize}) where {chunksize s *= "chunksize=$chunksize, " end if backend.tag !== nothing - s *= "tag=$(backend.tag), " + s *= "tag=$(repr(backend.tag, context=io)), " end if endswith(s, ", ") s = s[1:(end - 2)] @@ -232,7 +232,7 @@ function Base.show(io::IO, backend::AutoPolyesterForwardDiff{chunksize}) where { s *= "chunksize=$chunksize, " end if backend.tag !== nothing - s *= "tag=$(backend.tag), " + s *= "tag=$(repr(backend.tag, context=io)), " end if endswith(s, ", ") s = s[1:(end - 2)] diff --git a/src/sparse.jl b/src/sparse.jl index 0db7914..85a8e97 100644 --- a/src/sparse.jl +++ b/src/sparse.jl @@ -155,12 +155,12 @@ function AutoSparse( end function Base.show(io::IO, backend::AutoSparse) - s = "AutoSparse(dense_ad=$(backend.dense_ad), " + s = "AutoSparse(dense_ad=$(repr(backend.dense_ad, context=io)), " if backend.sparsity_detector != NoSparsityDetector() - s *= "sparsity_detector=$(backend.sparsity_detector), " + s *= "sparsity_detector=$(repr(backend.sparsity_detector, context=io)), " end if backend.coloring_algorithm != NoColoringAlgorithm() - s *= "coloring_algorithm=$(backend.coloring_algorithm)), " + s *= "coloring_algorithm=$(repr(backend.coloring_algorithm, context=io))), " end s = s[1:(end - 2)] * ")" print(io, s)