Skip to content

Pretty printing #64

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
authors = [
"Vaibhav Dixit <[email protected]>, Guillaume Dalle and contributors",
]
version = "1.5.0"
version = "1.5.1"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
82 changes: 81 additions & 1 deletion src/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ end

mode(::AutoChainRules) = ForwardOrReverseMode() # specialized in the extension

function Base.show(io::IO, backend::AutoChainRules)
print(io, "AutoChainRules(ruleconfig=$(repr(backend.ruleconfig, context=io)))")
end

"""
AutoDiffractor

Expand Down Expand Up @@ -58,6 +62,14 @@ end

mode(::AutoEnzyme) = ForwardOrReverseMode() # specialized in the extension

function Base.show(io::IO, backend::AutoEnzyme)
if isnothing(backend.mode)
print(io, "AutoEnzyme()")
else
print(io, "AutoEnzyme(mode=$(repr(backend.mode, context=io)))")
end
end

"""
AutoFastDifferentiation

Expand Down Expand Up @@ -98,6 +110,24 @@ end

mode(::AutoFiniteDiff) = ForwardMode()

function Base.show(io::IO, backend::AutoFiniteDiff)
s = "AutoFiniteDiff("
if backend.fdtype != Val(:forward)
s *= "fdtype=$(repr(backend.fdtype, context=io)), "
end
if backend.fdjtype != backend.fdtype
s *= "fdjtype=$(repr(backend.fdjtype, context=io)), "
end
if backend.fdhtype != Val(:hcentral)
s *= "fdhtype=$(repr(backend.fdhtype, context=io)), "
end
if endswith(s, ", ")
s = s[1:(end - 2)]
end
s *= ")"
print(io, s)
end

"""
AutoFiniteDifferences{T}

Expand All @@ -119,6 +149,10 @@ end

mode(::AutoFiniteDifferences) = ForwardMode()

function Base.show(io::IO, backend::AutoFiniteDifferences)
print(io, "AutoFiniteDifferences(fdm=$(repr(backend.fdm, context=io)))")
end

"""
AutoForwardDiff{chunksize,T}

Expand Down Expand Up @@ -148,6 +182,21 @@ end

mode(::AutoForwardDiff) = ForwardMode()

function Base.show(io::IO, backend::AutoForwardDiff{chunksize}) where {chunksize}
s = "AutoForwardDiff("
if chunksize !== nothing
s *= "chunksize=$chunksize, "
end
if backend.tag !== nothing
s *= "tag=$(repr(backend.tag, context=io)), "
end
if endswith(s, ", ")
s = s[1:(end - 2)]
end
s *= ")"
print(io, s)
end

"""
AutoPolyesterForwardDiff{chunksize,T}

Expand Down Expand Up @@ -177,6 +226,21 @@ end

mode(::AutoPolyesterForwardDiff) = ForwardMode()

function Base.show(io::IO, backend::AutoPolyesterForwardDiff{chunksize}) where {chunksize}
s = "AutoPolyesterForwardDiff("
if chunksize !== nothing
s *= "chunksize=$chunksize, "
end
if backend.tag !== nothing
s *= "tag=$(repr(backend.tag, context=io)), "
end
if endswith(s, ", ")
s = s[1:(end - 2)]
end
s *= ")"
print(io, s)
end

"""
AutoReverseDiff

Expand All @@ -193,7 +257,7 @@ Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl).
- `compile::Union{Val, Bool}`: whether to [compile the tape](https://juliadiff.org/ReverseDiff.jl/api/#ReverseDiff.compile) prior to differentiation
"""
struct AutoReverseDiff{C} <: AbstractADType
compile::Bool # this field if left for legacy reasons
compile::Bool # this field is left for legacy reasons

function AutoReverseDiff(; compile::Union{Val, Bool} = Val(false))
_compile = _unwrap_val(compile)
Expand All @@ -212,6 +276,14 @@ end

mode(::AutoReverseDiff) = ReverseMode()

function Base.show(io::IO, ::AutoReverseDiff{compile}) where {compile}
if !compile
print(io, "AutoReverseDiff()")
else
print(io, "AutoReverseDiff(compile=true)")
end
end

"""
AutoSymbolics

Expand Down Expand Up @@ -248,6 +320,14 @@ end

mode(::AutoTapir) = ReverseMode()

function Base.show(io::IO, backend::AutoTapir)
if backend.safe_mode
print(io, "AutoTapir()")
else
print(io, "AutoTapir(safe_mode=false)")
end
end

"""
AutoTracker

Expand Down
12 changes: 12 additions & 0 deletions src/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,18 @@ function AutoSparse(
}(dense_ad, sparsity_detector, coloring_algorithm)
end

function Base.show(io::IO, backend::AutoSparse)
s = "AutoSparse(dense_ad=$(repr(backend.dense_ad, context=io)), "
if backend.sparsity_detector != NoSparsityDetector()
s *= "sparsity_detector=$(repr(backend.sparsity_detector, context=io)), "
end
if backend.coloring_algorithm != NoColoringAlgorithm()
s *= "coloring_algorithm=$(repr(backend.coloring_algorithm, context=io))), "
end
s = s[1:(end - 2)] * ")"
print(io, s)
end

"""
dense_ad(ad::AutoSparse)::AbstractADType

Expand Down
22 changes: 20 additions & 2 deletions test/misc.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,21 @@
for ad in every_ad()
@test identity.(ad) == ad
@testset "Broadcasting" begin
for ad in every_ad()
@test identity.(ad) == ad
end
end

@testset "Printing" begin
for ad in every_ad_with_options()
@test startswith(string(ad), "Auto")
@test endswith(string(ad), ")")
end

sparse_backend1 = AutoSparse(AutoForwardDiff())
sparse_backend2 = AutoSparse(
AutoForwardDiff();
sparsity_detector = FakeSparsityDetector(),
coloring_algorithm = FakeColoringAlgorithm()
)
@test contains(string(sparse_backend1), string(AutoForwardDiff()))
@test length(string(sparse_backend1)) < length(string(sparse_backend2))
end
27 changes: 27 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ struct ForwardRuleConfig <: RuleConfig{Union{HasForwardsMode, NoReverseMode}} en
struct ReverseRuleConfig <: RuleConfig{Union{NoForwardsMode, HasReverseMode}} end
struct ForwardOrReverseRuleConfig <: RuleConfig{Union{HasForwardsMode, HasReverseMode}} end

struct FakeSparsityDetector <: ADTypes.AbstractSparsityDetector end
struct FakeColoringAlgorithm <: ADTypes.AbstractColoringAlgorithm end

function every_ad()
return [
AutoChainRules(; ruleconfig = :rc),
Expand All @@ -49,6 +52,30 @@ function every_ad()
]
end

function every_ad_with_options()
return [
AutoChainRules(; ruleconfig = :rc),
AutoDiffractor(),
AutoEnzyme(),
AutoEnzyme(mode = :forward),
AutoFastDifferentiation(),
AutoFiniteDiff(),
AutoFiniteDiff(fdtype = :fd, fdjtype = :fdj, fdhtype = :fdh),
AutoFiniteDifferences(; fdm = :fdm),
AutoForwardDiff(),
AutoForwardDiff(chunksize = 3, tag = :tag),
AutoPolyesterForwardDiff(),
AutoPolyesterForwardDiff(chunksize = 3, tag = :tag),
AutoReverseDiff(),
AutoReverseDiff(compile = true),
AutoSymbolics(),
AutoTapir(),
AutoTapir(safe_mode = false),
AutoTracker(),
AutoZygote()
]
end

## Tests

@testset verbose=true "ADTypes.jl" begin
Expand Down
Loading