diff --git a/Project.toml b/Project.toml index c894b31..d9ea062 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b" authors = [ "Vaibhav Dixit , Guillaume Dalle and contributors", ] -version = "1.5.0" +version = "1.5.1" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/dense.jl b/src/dense.jl index c6e812a..7958a67 100644 --- a/src/dense.jl +++ b/src/dense.jl @@ -19,6 +19,10 @@ end mode(::AutoChainRules) = ForwardOrReverseMode() # specialized in the extension +function Base.show(io::IO, backend::AutoChainRules) + print(io, "AutoChainRules(ruleconfig=$(repr(backend.ruleconfig, context=io)))") +end + """ AutoDiffractor @@ -58,6 +62,14 @@ end mode(::AutoEnzyme) = ForwardOrReverseMode() # specialized in the extension +function Base.show(io::IO, backend::AutoEnzyme) + if isnothing(backend.mode) + print(io, "AutoEnzyme()") + else + print(io, "AutoEnzyme(mode=$(repr(backend.mode, context=io)))") + end +end + """ AutoFastDifferentiation @@ -98,6 +110,24 @@ end mode(::AutoFiniteDiff) = ForwardMode() +function Base.show(io::IO, backend::AutoFiniteDiff) + s = "AutoFiniteDiff(" + if backend.fdtype != Val(:forward) + s *= "fdtype=$(repr(backend.fdtype, context=io)), " + end + if backend.fdjtype != backend.fdtype + s *= "fdjtype=$(repr(backend.fdjtype, context=io)), " + end + if backend.fdhtype != Val(:hcentral) + s *= "fdhtype=$(repr(backend.fdhtype, context=io)), " + end + if endswith(s, ", ") + s = s[1:(end - 2)] + end + s *= ")" + print(io, s) +end + """ AutoFiniteDifferences{T} @@ -119,6 +149,10 @@ end mode(::AutoFiniteDifferences) = ForwardMode() +function Base.show(io::IO, backend::AutoFiniteDifferences) + print(io, "AutoFiniteDifferences(fdm=$(repr(backend.fdm, context=io)))") +end + """ AutoForwardDiff{chunksize,T} @@ -148,6 +182,21 @@ end mode(::AutoForwardDiff) = ForwardMode() +function Base.show(io::IO, backend::AutoForwardDiff{chunksize}) where {chunksize} + s = "AutoForwardDiff(" + if chunksize !== nothing + s *= "chunksize=$chunksize, " + end + if backend.tag !== nothing + s *= "tag=$(repr(backend.tag, context=io)), " + end + if endswith(s, ", ") + s = s[1:(end - 2)] + end + s *= ")" + print(io, s) +end + """ AutoPolyesterForwardDiff{chunksize,T} @@ -177,6 +226,21 @@ end mode(::AutoPolyesterForwardDiff) = ForwardMode() +function Base.show(io::IO, backend::AutoPolyesterForwardDiff{chunksize}) where {chunksize} + s = "AutoPolyesterForwardDiff(" + if chunksize !== nothing + s *= "chunksize=$chunksize, " + end + if backend.tag !== nothing + s *= "tag=$(repr(backend.tag, context=io)), " + end + if endswith(s, ", ") + s = s[1:(end - 2)] + end + s *= ")" + print(io, s) +end + """ AutoReverseDiff @@ -193,7 +257,7 @@ Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl). - `compile::Union{Val, Bool}`: whether to [compile the tape](https://juliadiff.org/ReverseDiff.jl/api/#ReverseDiff.compile) prior to differentiation """ struct AutoReverseDiff{C} <: AbstractADType - compile::Bool # this field if left for legacy reasons + compile::Bool # this field is left for legacy reasons function AutoReverseDiff(; compile::Union{Val, Bool} = Val(false)) _compile = _unwrap_val(compile) @@ -212,6 +276,14 @@ end mode(::AutoReverseDiff) = ReverseMode() +function Base.show(io::IO, ::AutoReverseDiff{compile}) where {compile} + if !compile + print(io, "AutoReverseDiff()") + else + print(io, "AutoReverseDiff(compile=true)") + end +end + """ AutoSymbolics @@ -248,6 +320,14 @@ end mode(::AutoTapir) = ReverseMode() +function Base.show(io::IO, backend::AutoTapir) + if backend.safe_mode + print(io, "AutoTapir()") + else + print(io, "AutoTapir(safe_mode=false)") + end +end + """ AutoTracker diff --git a/src/sparse.jl b/src/sparse.jl index 70bf692..85a8e97 100644 --- a/src/sparse.jl +++ b/src/sparse.jl @@ -154,6 +154,18 @@ function AutoSparse( }(dense_ad, sparsity_detector, coloring_algorithm) end +function Base.show(io::IO, backend::AutoSparse) + s = "AutoSparse(dense_ad=$(repr(backend.dense_ad, context=io)), " + if backend.sparsity_detector != NoSparsityDetector() + s *= "sparsity_detector=$(repr(backend.sparsity_detector, context=io)), " + end + if backend.coloring_algorithm != NoColoringAlgorithm() + s *= "coloring_algorithm=$(repr(backend.coloring_algorithm, context=io))), " + end + s = s[1:(end - 2)] * ")" + print(io, s) +end + """ dense_ad(ad::AutoSparse)::AbstractADType diff --git a/test/misc.jl b/test/misc.jl index 00a56d7..a1b2274 100644 --- a/test/misc.jl +++ b/test/misc.jl @@ -1,3 +1,21 @@ -for ad in every_ad() - @test identity.(ad) == ad +@testset "Broadcasting" begin + for ad in every_ad() + @test identity.(ad) == ad + end +end + +@testset "Printing" begin + for ad in every_ad_with_options() + @test startswith(string(ad), "Auto") + @test endswith(string(ad), ")") + end + + sparse_backend1 = AutoSparse(AutoForwardDiff()) + sparse_backend2 = AutoSparse( + AutoForwardDiff(); + sparsity_detector = FakeSparsityDetector(), + coloring_algorithm = FakeColoringAlgorithm() + ) + @test contains(string(sparse_backend1), string(AutoForwardDiff())) + @test length(string(sparse_backend1)) < length(string(sparse_backend2)) end diff --git a/test/runtests.jl b/test/runtests.jl index e7d72f9..76fccba 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -31,6 +31,9 @@ struct ForwardRuleConfig <: RuleConfig{Union{HasForwardsMode, NoReverseMode}} en struct ReverseRuleConfig <: RuleConfig{Union{NoForwardsMode, HasReverseMode}} end struct ForwardOrReverseRuleConfig <: RuleConfig{Union{HasForwardsMode, HasReverseMode}} end +struct FakeSparsityDetector <: ADTypes.AbstractSparsityDetector end +struct FakeColoringAlgorithm <: ADTypes.AbstractColoringAlgorithm end + function every_ad() return [ AutoChainRules(; ruleconfig = :rc), @@ -49,6 +52,30 @@ function every_ad() ] end +function every_ad_with_options() + return [ + AutoChainRules(; ruleconfig = :rc), + AutoDiffractor(), + AutoEnzyme(), + AutoEnzyme(mode = :forward), + AutoFastDifferentiation(), + AutoFiniteDiff(), + AutoFiniteDiff(fdtype = :fd, fdjtype = :fdj, fdhtype = :fdh), + AutoFiniteDifferences(; fdm = :fdm), + AutoForwardDiff(), + AutoForwardDiff(chunksize = 3, tag = :tag), + AutoPolyesterForwardDiff(), + AutoPolyesterForwardDiff(chunksize = 3, tag = :tag), + AutoReverseDiff(), + AutoReverseDiff(compile = true), + AutoSymbolics(), + AutoTapir(), + AutoTapir(safe_mode = false), + AutoTracker(), + AutoZygote() + ] +end + ## Tests @testset verbose=true "ADTypes.jl" begin