Skip to content

Commit 965750d

Browse files
Merge pull request #988 from SciML/gd/update_ad
Update AD comparison
2 parents a514075 + a6c7b0c commit 965750d

File tree

3 files changed

+452
-380
lines changed

3 files changed

+452
-380
lines changed

benchmarks/AutomaticDifferentiation/JuliaAD.jmd

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,30 +21,23 @@ function paritytrig(x::AbstractVector{T}) where {T}
2121
end
2222

2323
backends = [
24-
AutoEnzyme(Enzyme.Reverse),
25-
AutoTapir(),
24+
AutoEnzyme(mode=Enzyme.Reverse),
25+
AutoTapir(safe_mode=false),
2626
AutoZygote(),
2727
];
2828

2929
scenarios = [
30-
GradientScenario(paritytrig, x=rand(100); operator=:inplace),
31-
GradientScenario(paritytrig, x=rand(10_000); operator=:inplace)
30+
GradientScenario(paritytrig; x=rand(100), y=0.0, nb_args=1, place=:inplace),
31+
GradientScenario(paritytrig; x=rand(10_000), y=0.0, nb_args=1, place=:inplace)
3232
];
3333

34-
result = benchmark_differentiation(backends, scenarios, logging=false);
35-
36-
data = DataFrame(result);
37-
38-
filtered_data = @chain data begin
39-
@select(:backend, :operator, :func, :input_type, :input_size, :time, :bytes, :allocs, :compile_fraction, :gc_fraction)
40-
@rsubset(string(:operator) in ["gradient!"])
41-
end
34+
data = benchmark_differentiation(backends, scenarios, logging=true);
4235

4336
table = PrettyTables.pretty_table(
4437
String,
45-
filtered_data;
38+
data;
4639
backend=Val(:markdown),
47-
header=names(filtered_data),
40+
header=names(data),
4841
formatters=PrettyTables.ft_printf("%.1e"),
4942
)
5043

0 commit comments

Comments
 (0)