Skip to content

Commit 47a7ec6

Browse files
authored
Restructure sparse backends and replace subtyping by traits (#40)
1 parent c9b2692 commit 47a7ec6

File tree

16 files changed

+978
-381
lines changed

16 files changed

+978
-381
lines changed

.github/workflows/CI.yml

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,8 @@ jobs:
1818
fail-fast: false
1919
matrix:
2020
version:
21-
- '1.6'
21+
- '1.10'
2222
- '1'
23-
- 'nightly'
2423
os:
2524
- ubuntu-latest
2625
arch:
@@ -34,3 +33,9 @@ jobs:
3433
- uses: julia-actions/cache@v1
3534
- uses: julia-actions/julia-buildpkg@v1
3635
- uses: julia-actions/julia-runtest@v1
36+
- uses: julia-actions/julia-processcoverage@v1
37+
- uses: codecov/codecov-action@v4
38+
with:
39+
files: lcov.info
40+
token: ${{ secrets.CODECOV_TOKEN }}
41+
fail_ci_if_error: true

Project.toml

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,29 @@
11
name = "ADTypes"
22
uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
3-
authors = ["Vaibhav Dixit <[email protected]> and contributors"]
4-
version = "0.2.7"
3+
authors = [
4+
"Vaibhav Dixit <[email protected]>, Guillaume Dalle and contributors",
5+
]
6+
version = "1.0.0"
7+
8+
[weakdeps]
9+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
10+
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
11+
12+
[extensions]
13+
ADTypesChainRulesCoreExt = "ChainRulesCore"
14+
ADTypesEnzymeCoreExt = "EnzymeCore"
515

616
[compat]
7-
julia = "1.6"
17+
ChainRulesCore = "1.23.0"
18+
EnzymeCore = "0.7.2"
19+
julia = "1.10"
820

921
[extras]
22+
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
23+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
24+
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
25+
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
1026
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1127

1228
[targets]
13-
test = ["Test"]
29+
test = ["Aqua", "ChainRulesCore", "EnzymeCore", "JET", "Test"]

README.md

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,29 +4,24 @@
44
[![Docs dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://SciML.github.io/ADTypes.jl/dev/)
55
[![Build Status](https://github.com/SciML/ADTypes.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/SciML/ADTypes.jl/actions/workflows/CI.yml?query=branch%3Amain)
66

7-
ADTypes.jl is a multi-valued logic system specifying the choice of an automatic differentiation (AD) library and its parameters.
7+
ADTypes.jl is a multi-valued logic system to choose an automatic differentiation (AD) package and specify its parameters.
88

99
## Which AD libraries are supported?
1010

1111
See the API reference in the documentation.
12+
If a given package is missing, feel free to open an issue or pull request.
1213

13-
## Why should packages adopt this standard?
14+
## Why should AD users adopt this standard?
1415

15-
A common practice is the use of a boolean keyword argument like `autodiff = true/false`.
16-
However, boolean logic is not precise enough for all the choices required.
17-
For instance, forward mode AD is implemented by both ForwardDiff and Enzyme, which makes `autodiff = true` ambiguous.
18-
Something like `ChooseForwardDiff()` is thus required, possibly with additional parameters depending on the library.
16+
A natural approach is to use a keyword argument with e.g. `Bool` or `Symbol` values.
17+
Let's see a few examples to understand why this is not enough:
1918

20-
The risk is that every package developer might develop their own version of `ChooseForwardDiff()`, which would ruin interoperability.
21-
This is why ADTypes.jl provides a single set of shared types for this task, as an extremely lightweight dependency.
22-
Wonder no more: `ADTypes.AutoForwardDiff()` is the way to go.
23-
24-
## Why define types instead of enums?
25-
26-
If we used enums, they would not contain type-level information useful for dispatch.
27-
This is needed by many AD libraries to ensure type stability.
28-
Notably, the choice of config or cache type is different with each AD, so we must know statically which AD library is chosen.
19+
- `autodiff = true`: ambiguous, we don't know which AD package should be used
20+
- `autodiff = :forward`: ambiguous, there are several AD packages implementing both forward and reverse mode (and there are other modes beyond that)
21+
- `autodiff = :Enzyme`: ambiguous, some AD packages can work both in forward and reverse mode
22+
- `autodiff = (:Enzyme, :forward)`: not too bad, but many AD packages require additional configuration (number of chunks, tape compilation, etc.)
2923

30-
## Why is this AD package missing?
31-
32-
Feel free to open a pull request adding it.
24+
A more involved struct is thus required, with package-specific parameters.
25+
If every AD user develops their own version of said struct, it will ruin interoperability.
26+
This is why ADTypes.jl provides a single set of shared types for this task, as an extremely lightweight dependency.
27+
They are types and not enums because we need AD choice information statically to use it for dispatch.

docs/src/index.md

Lines changed: 79 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,87 @@ CollapsedDocStrings = true
77

88
Documentation for [ADTypes.jl](https://github.com/SciML/ADTypes.jl/).
99

10-
## Public
10+
```@docs
11+
ADTypes
12+
AbstractADType
13+
```
14+
15+
## Dense AD
16+
17+
### Forward mode
18+
19+
Algorithmic differentiation:
20+
21+
```@docs
22+
AutoForwardDiff
23+
AutoPolyesterForwardDiff
24+
```
25+
26+
Finite differences:
27+
28+
```@docs
29+
AutoFiniteDiff
30+
AutoFiniteDifferences
31+
```
32+
33+
### Reverse mode
34+
35+
```@docs
36+
AutoReverseDiff
37+
AutoTapir
38+
AutoTracker
39+
AutoZygote
40+
```
41+
42+
### Forward or reverse mode
43+
44+
```@docs
45+
AutoEnzyme
46+
AutoChainRules
47+
AutoDiffractor
48+
```
49+
50+
### Symbolic mode
51+
52+
```@docs
53+
AutoFastDifferentiation
54+
AutoSymbolics
55+
```
56+
57+
## Sparse AD
58+
59+
```@docs
60+
AutoSparse
61+
ADTypes.dense_ad
62+
```
63+
64+
### Sparsity detector
65+
66+
```@docs
67+
ADTypes.sparsity_detector
68+
ADTypes.AbstractSparsityDetector
69+
ADTypes.jacobian_sparsity
70+
ADTypes.hessian_sparsity
71+
ADTypes.NoSparsityDetector
72+
```
73+
74+
### Coloring algorithm
1175

12-
```@autodocs
13-
Modules = [ADTypes]
14-
Private = false
76+
```@docs
77+
ADTypes.coloring_algorithm
78+
ADTypes.AbstractColoringAlgorithm
79+
ADTypes.column_coloring
80+
ADTypes.row_coloring
81+
ADTypes.NoColoringAlgorithm
1582
```
1683

17-
## Internal
84+
## Modes
1885

19-
```@autodocs
20-
Modules = [ADTypes]
21-
Public = false
86+
```@docs
87+
ADTypes.mode
88+
ADTypes.AbstractMode
89+
ADTypes.ForwardMode
90+
ADTypes.ForwardOrReverseMode
91+
ADTypes.ReverseMode
92+
ADTypes.SymbolicMode
2293
```

ext/ADTypesChainRulesCoreExt.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
module ADTypesChainRulesCoreExt
2+
3+
using ADTypes: ADTypes, AutoChainRules
4+
using ChainRulesCore: HasForwardsMode, HasReverseMode,
5+
NoForwardsMode, NoReverseMode,
6+
RuleConfig
7+
8+
# see https://juliadiff.org/ChainRulesCore.jl/stable/rule_author/superpowers/ruleconfig.html
9+
10+
function ADTypes.mode(::AutoChainRules{RC}) where {
11+
RC <: RuleConfig{>:HasForwardsMode}
12+
}
13+
return ADTypes.ForwardMode()
14+
end
15+
16+
function ADTypes.mode(::AutoChainRules{RC}) where {
17+
RC <: RuleConfig{>:HasReverseMode}
18+
}
19+
return ADTypes.ReverseMode()
20+
end
21+
22+
function ADTypes.mode(::AutoChainRules{RC}) where {
23+
RC <: RuleConfig{>:Union{HasForwardsMode, HasReverseMode}}
24+
}
25+
# more specific than the previous two
26+
return ADTypes.ForwardOrReverseMode()
27+
end
28+
29+
end

ext/ADTypesEnzymeCoreExt.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
module ADTypesEnzymeCoreExt
2+
3+
using ADTypes: ADTypes, AutoEnzyme
4+
using EnzymeCore: EnzymeCore
5+
6+
ADTypes.mode(::AutoEnzyme{M}) where {M <: EnzymeCore.ForwardMode} = ADTypes.ForwardMode()
7+
ADTypes.mode(::AutoEnzyme{M}) where {M <: EnzymeCore.ReverseMode} = ADTypes.ReverseMode()
8+
9+
end

0 commit comments

Comments
 (0)