Skip to content

Commit 3b6f62b

Browse files
Make seeds and extract_jacobian gpu-friendly
Use broadcast/macro consistently Fix jac Add AllocationsTest.jl
1 parent 5792261 commit 3b6f62b

File tree

4 files changed

+57
-23
lines changed

4 files changed

+57
-23
lines changed

src/apiutils.jl

+8-14
Original file line numberDiff line numberDiff line change
@@ -55,36 +55,30 @@ end
5555

5656
function seed!(duals::AbstractArray{Dual{T,V,N}}, x,
5757
seed::Partials{N,V} = zero(Partials{N,V})) where {T,V,N}
58-
for i in eachindex(duals)
59-
duals[i] = Dual{T,V,N}(x[i], seed)
60-
end
58+
duals .= Dual{T,V,N}.(x, Ref(seed))
6159
return duals
6260
end
6361

6462
function seed!(duals::AbstractArray{Dual{T,V,N}}, x,
6563
seeds::NTuple{N,Partials{N,V}}) where {T,V,N}
66-
for i in 1:N
67-
duals[i] = Dual{T,V,N}(x[i], seeds[i])
68-
end
64+
dual_inds = 1:N
65+
duals[dual_inds] .= Dual{T,V,N}.(view(x,dual_inds), seeds)
6966
return duals
7067
end
7168

7269
function seed!(duals::AbstractArray{Dual{T,V,N}}, x, index,
7370
seed::Partials{N,V} = zero(Partials{N,V})) where {T,V,N}
7471
offset = index - 1
75-
for i in 1:N
76-
j = i + offset
77-
duals[j] = Dual{T,V,N}(x[j], seed)
78-
end
72+
dual_inds = (1:N) .+ offset
73+
duals[dual_inds] .= Dual{T,V,N}.(view(x, dual_inds), Ref(seed))
7974
return duals
8075
end
8176

8277
function seed!(duals::AbstractArray{Dual{T,V,N}}, x, index,
8378
seeds::NTuple{N,Partials{N,V}}, chunksize = N) where {T,V,N}
8479
offset = index - 1
85-
for i in 1:chunksize
86-
j = i + offset
87-
duals[j] = Dual{T,V,N}(x[j], seeds[i])
88-
end
80+
seed_inds = 1:chunksize
81+
dual_inds = seed_inds .+ offset
82+
duals[dual_inds] .= Dual{T,V,N}.(view(x, dual_inds), getindex.(Ref(seeds), seed_inds))
8983
return duals
9084
end

src/jacobian.jl

+10-9
Original file line numberDiff line numberDiff line change
@@ -111,9 +111,10 @@ end
111111

112112
function extract_jacobian!(::Type{T}, result::AbstractArray, ydual::AbstractArray, n) where {T}
113113
out_reshaped = reshape(result, length(ydual), n)
114-
for col in 1:size(out_reshaped, 2), row in 1:size(out_reshaped, 1)
115-
out_reshaped[row, col] = partials(T, ydual[row], col)
116-
end
114+
ydual_reshaped = vec(ydual)
115+
# Use closure to avoid GPU broadcasting with Type
116+
partials_wrap(ydual, nrange) = partials(T, ydual, nrange)
117+
out_reshaped .= partials_wrap.(ydual_reshaped, transpose(1:n))
117118
return result
118119
end
119120

@@ -123,13 +124,13 @@ function extract_jacobian!(::Type{T}, result::MutableDiffResult, ydual::Abstract
123124
end
124125

125126
function extract_jacobian_chunk!(::Type{T}, result, ydual, index, chunksize) where {T}
127+
ydual_reshaped = vec(ydual)
126128
offset = index - 1
127-
for i in 1:chunksize
128-
col = i + offset
129-
for row in eachindex(ydual)
130-
result[row, col] = partials(T, ydual[row], i)
131-
end
132-
end
129+
irange = 1:chunksize
130+
col = irange .+ offset
131+
# Use closure to avoid GPU broadcasting with Type
132+
partials_wrap(ydual, nrange) = partials(T, ydual, nrange)
133+
result[:, col] .= partials_wrap.(ydual_reshaped, transpose(irange))
133134
return result
134135
end
135136

test/AllocationsTest.jl

+33
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
module AllocationsTest
2+
3+
using ForwardDiff
4+
5+
include(joinpath(dirname(@__FILE__), "utils.jl"))
6+
7+
@testset "Test seed! allocations" begin
8+
x = rand(1000)
9+
cfg = ForwardDiff.GradientConfig(nothing, x)
10+
duals = cfg.duals
11+
seeds = cfg.seeds
12+
seed = cfg.seeds[1]
13+
14+
alloc = @allocated ForwardDiff.seed!(duals, x, seeds)
15+
alloc = @allocated ForwardDiff.seed!(duals, x, seeds)
16+
@test alloc == 0
17+
18+
alloc = @allocated ForwardDiff.seed!(duals, x, seed)
19+
alloc = @allocated ForwardDiff.seed!(duals, x, seed)
20+
@test alloc == 0
21+
22+
index = 1
23+
alloc = @allocated ForwardDiff.seed!(duals, x, index, seeds)
24+
alloc = @allocated ForwardDiff.seed!(duals, x, index, seeds)
25+
@test alloc == 0
26+
27+
index = 1
28+
alloc = @allocated ForwardDiff.seed!(duals, x, index, seed)
29+
alloc = @allocated ForwardDiff.seed!(duals, x, index, seed)
30+
@test alloc == 0
31+
end
32+
33+
end

test/runtests.jl

+6
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,9 @@ println("done (took $t seconds).")
3131
println("Testing miscellaneous functionality...")
3232
t = @elapsed include("MiscTest.jl")
3333
println("done (took $t seconds).")
34+
35+
if VERSION >= v"1.5-"
36+
println("Testing allocations...")
37+
t = @elapsed include("AllocationsTest.jl")
38+
println("done (took $t seconds).")
39+
end

0 commit comments

Comments
 (0)