Skip to content

Commit f39a36b

Browse files
Make seeds and extract_jacobian gpu-friendly
Use broadcast/macro consistently. Fix jac.
1 parent 4c7495d commit f39a36b

File tree

2 files changed

+18
-23
lines changed

2 files changed

+18
-23
lines changed

src/apiutils.jl

+8-14
Original file line numberDiff line numberDiff line change
@@ -55,36 +55,30 @@ end
5555

# Overwrite `duals` with `Dual{T,V,N}(x[i], seed)` for each element of `x`, pairing
# every value with the same `seed` (a zero `Partials{N,V}` unless one is supplied).
# `duals` and `x` are combined by broadcasting, so their shapes must be
# broadcast-compatible. Returns `duals`.
function seed!(duals::AbstractArray{Dual{T,V,N}}, x,
               seed::Partials{N,V} = zero(Partials{N,V})) where {T,V,N}
    # Local constructor so the broadcast is over a function, not the type object.
    to_dual(value, p) = Dual{T,V,N}(value, p)
    # `Ref(seed)` makes the single seed behave as a scalar in the broadcast.
    duals .= to_dual.(x, Ref(seed))
    return duals
end
6361

# Seed the first `N` entries of `duals`: entry `k` becomes
# `Dual{T,V,N}(x[k], seeds[k])` for `k` in `1:N`. Entries past `N` are left
# untouched. Returns `duals`.
function seed!(duals::AbstractArray{Dual{T,V,N}}, x,
               seeds::NTuple{N,Partials{N,V}}) where {T,V,N}
    dual_inds = 1:N
    # Local constructor so the broadcast is over a function, not the type object.
    to_dual(value, p) = Dual{T,V,N}(value, p)
    # The `N`-tuple of seeds is broadcast elementwise against the `N` viewed values.
    duals[dual_inds] .= to_dual.(view(x, dual_inds), seeds)
    return duals
end
7168

# Seed the `N` consecutive entries of `duals` starting at `index`: each entry `j`
# in that window becomes `Dual{T,V,N}(x[j], seed)`, all sharing the same `seed`
# (a zero `Partials{N,V}` unless one is supplied). Returns `duals`.
function seed!(duals::AbstractArray{Dual{T,V,N}}, x, index,
               seed::Partials{N,V} = zero(Partials{N,V})) where {T,V,N}
    # Shift `1:N` so the window begins at `index`.
    dual_inds = (1:N) .+ (index - 1)
    # Local constructor so the broadcast is over a function, not the type object.
    to_dual(value, p) = Dual{T,V,N}(value, p)
    # `Ref(seed)` makes the single seed behave as a scalar in the broadcast.
    duals[dual_inds] .= to_dual.(view(x, dual_inds), Ref(seed))
    return duals
end
8176

# Seed `chunksize` consecutive entries of `duals` starting at `index`: for `i` in
# `1:chunksize`, entry `j = index + i - 1` becomes `Dual{T,V,N}(x[j], seeds[i])`.
# `chunksize` defaults to `N`, the number of seeds. Returns `duals`.
function seed!(duals::AbstractArray{Dual{T,V,N}}, x, index,
               seeds::NTuple{N,Partials{N,V}}, chunksize = N) where {T,V,N}
    offset = index - 1
    seed_inds = 1:chunksize
    dual_inds = seed_inds .+ offset
    # Look the seeds up elementwise inside the fused broadcast. Slicing the tuple
    # with a runtime range (`seeds[1:chunksize]`) would build a tuple whose length
    # cannot be inferred from the argument types.
    duals[dual_inds] .= Dual{T,V,N}.(view(x, dual_inds),
                                     getindex.(Ref(seeds), seed_inds))
    return duals
end

src/jacobian.jl

+10-9
Original file line numberDiff line numberDiff line change
@@ -111,9 +111,10 @@ end
111111

# Write into `reshape(result, length(ydual), n)` so that entry `(row, col)` holds
# `partials(T, ydual[row], col)` for `col` in `1:n`, with `ydual` traversed in
# linear order. Returns `result`.
function extract_jacobian!(::Type{T}, result::AbstractArray, ydual::AbstractArray, n) where {T}
    jac = reshape(result, length(ydual), n)
    ydual_flat = vec(ydual)
    # Use closure to avoid GPU broadcasting with Type
    extract(d, col) = partials(T, d, col)
    # Column of duals against a row of partial indices yields the full matrix.
    jac .= extract.(ydual_flat, transpose(1:n))
    return result
end
119120

@@ -123,13 +124,13 @@ function extract_jacobian!(::Type{T}, result::MutableDiffResult, ydual::Abstract
123124
end
124125

# Write `chunksize` columns of `result`, starting at column `index`: for `i` in
# `1:chunksize`, column `index + i - 1` receives `partials(T, d, i)` for each
# element `d` of `ydual` taken in linear order. Returns `result`.
function extract_jacobian_chunk!(::Type{T}, result, ydual, index, chunksize) where {T}
    ydual_flat = reshape(ydual, length(ydual))
    chunk_inds = 1:chunksize
    # Shift `1:chunksize` so the destination columns begin at `index`.
    cols = chunk_inds .+ (index - 1)
    # Use closure to avoid GPU broadcasting with Type
    extract(d, k) = partials(T, d, k)
    # Column of duals against a row of partial indices fills the column block.
    result[:, cols] .= extract.(ydual_flat, transpose(chunk_inds))
    return result
end
135136

0 commit comments

Comments
 (0)