Skip to content

Commit d033d2a

Browse files
use SIMD.jl for explicit vectorization of partial operations (#557)
Co-authored-by: Yingbo Ma <[email protected]>
1 parent 0af523a commit d033d2a

File tree

4 files changed

+29
-25
lines changed

4 files changed

+29
-25
lines changed

.github/workflows/ci.yml

-1
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,6 @@ jobs:
1616
fail-fast: false
1717
matrix:
1818
version:
19-
- '1.0'
2019
- '1'
2120
- 'nightly'
2221
os:

Project.toml

+3-1
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@ NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
1212
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
1313
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1414
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
15+
SIMD = "fdea26ae-647d-5447-a871-4b548cad5224"
1516
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1617
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1718

@@ -24,9 +25,10 @@ DiffTests = "0.0.1, 0.1"
2425
LogExpFunctions = "0.3"
2526
NaNMath = "0.2.2, 0.3"
2627
Preferences = "1"
28+
SIMD = "3"
2729
SpecialFunctions = "0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.10, 1.0"
2830
StaticArrays = "0.8.3, 0.9, 0.10, 0.11, 0.12, 1.0"
29-
julia = "1"
31+
julia = "1.6"
3032

3133
[extras]
3234
Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9"

src/partials.jl

+25-22
Original file line number | Diff line number | Diff line change
@@ -197,29 +197,32 @@ end
197197
return tupexpr(i -> :(rand(V)), N)
198198
end
199199

200-
@generated function scale_tuple(tup::NTuple{N}, x) where N
201-
return tupexpr(i -> :(tup[$i] * x), N)
202-
end
203-
204-
@generated function div_tuple_by_scalar(tup::NTuple{N}, x) where N
205-
return tupexpr(i -> :(tup[$i] / x), N)
206-
end
207-
208-
@generated function add_tuples(a::NTuple{N}, b::NTuple{N}) where N
209-
return tupexpr(i -> :(a[$i] + b[$i]), N)
210-
end
211200

212-
@generated function sub_tuples(a::NTuple{N}, b::NTuple{N}) where N
213-
return tupexpr(i -> :(a[$i] - b[$i]), N)
214-
end
215-
216-
@generated function minus_tuple(tup::NTuple{N}) where N
217-
return tupexpr(i -> :(-tup[$i]), N)
218-
end
219-
220-
@generated function mul_tuples(a::NTuple{N}, b::NTuple{N}, afactor, bfactor) where N
221-
return tupexpr(i -> :((afactor * a[$i]) + (bfactor * b[$i])), N)
222-
end
201+
# Floating-point element types accepted by the SIMD-specialized tuple methods.
const SIMDFloat = Union{Float64, Float32}
202+
const SIMDInt = Union{
203+
Int128, Int64, Int32, Int16, Int8,
204+
UInt128, UInt64, UInt32, UInt16, UInt8,
205+
}
206+
# Every element type (float or integer) accepted by the SIMD-specialized tuple methods.
const SIMDType = Union{SIMDFloat, SIMDInt}
207+
# Short alias for `NTuple`, used to keep the method signatures in this file compact.
const NT{N,T} = NTuple{N,T}
208+
using SIMD
209+
210+
# SIMD implementation
211+
# Elementwise sum of two same-length tuples sharing one SIMD-capable element type.
# Both tuples are wrapped in `Vec`, added, and the result is converted back to a tuple.
function add_tuples(a::NT{N,T}, b::NT{N,T}) where {N, T<:SIMDType}
    lhs = Vec(a)
    rhs = Vec(b)
    return Tuple(lhs + rhs)
end
212+
# Elementwise difference `a - b` of two same-length tuples sharing one SIMD-capable
# element type. Both tuples are wrapped in `Vec`, subtracted, and converted back.
function sub_tuples(a::NT{N,T}, b::NT{N,T}) where {N, T<:SIMDType}
    lhs = Vec(a)
    rhs = Vec(b)
    return Tuple(lhs - rhs)
end
213+
# Multiply every element of `tup` by the scalar `x`. This method only applies when `x`
# has exactly the tuple's element type `T`; other combinations use the generic method.
function scale_tuple(tup::NT{N,T}, x::T) where {N, T<:SIMDType}
    lanes = Vec(tup)
    return Tuple(lanes * x)
end
214+
# Divide every element of `tup` by the scalar `x`. The signature is restricted to
# `SIMDFloat`, so integer tuples are not handled here and use the generic method.
function div_tuple_by_scalar(tup::NT{N,T}, x::T) where {N, T<:SIMDFloat}
    lanes = Vec(tup)
    return Tuple(lanes / x)
end
215+
# Negate every element of `tup` by wrapping it in `Vec`, applying unary minus, and
# converting the result back to a tuple.
function minus_tuple(tup::NT{N,T}) where {N, T<:SIMDType}
    lanes = Vec(tup)
    return Tuple(-lanes)
end
216+
# Compute `af * a + bf * b` elementwise for two same-length tuples whose element type
# matches both scalar factors. The scalars are expanded to `Vec{N,T}` and the final
# multiply-add goes through `muladd`, so floating-point results may differ in the last
# bits from evaluating `(af * a[i]) + (bf * b[i])` separately.
function mul_tuples(a::NT{N,T}, b::NT{N,T}, af::T, bf::T) where {N, T<:SIMDType}
    af_lanes = Vec{N,T}(af)
    bf_lanes = Vec{N,T}(bf)
    scaled_b = bf_lanes * Vec(b)
    return Tuple(muladd(af_lanes, Vec(a), scaled_b))
end
217+
218+
219+
# Fallback implementations
220+
# Generic elementwise sum of two length-`N` tuples with no element-type restriction.
@generated function add_tuples(a::NT{N}, b::NT{N}) where N
    # One `a[k] + b[k]` expression per index; `tupexpr` presumably assembles them
    # into a single tuple expression (helper defined elsewhere in this file).
    return tupexpr(k -> :(a[$k] + b[$k]), N)
end
221+
# Generic elementwise difference of two length-`N` tuples with no element-type
# restriction.
@generated function sub_tuples(a::NT{N}, b::NT{N}) where N
    # One `a[k] - b[k]` expression per index; `tupexpr` presumably assembles them
    # into a single tuple expression (helper defined elsewhere in this file).
    return tupexpr(k -> :(a[$k] - b[$k]), N)
end
222+
# Generic version: multiply every element of a length-`N` tuple by `x`, for any
# element and scalar types.
@generated function scale_tuple(tup::NT{N}, x) where N
    # One `tup[k] * x` expression per index; `tupexpr` presumably assembles them
    # into a single tuple expression (helper defined elsewhere in this file).
    return tupexpr(k -> :(tup[$k] * x), N)
end
223+
# Generic version: divide every element of a length-`N` tuple by `x`, for any element
# and scalar types.
@generated function div_tuple_by_scalar(tup::NT{N}, x) where N
    # One `tup[k] / x` expression per index; `tupexpr` presumably assembles them
    # into a single tuple expression (helper defined elsewhere in this file).
    return tupexpr(k -> :(tup[$k] / x), N)
end
224+
# Generic version: negate every element of a length-`N` tuple of any element type.
@generated function minus_tuple(tup::NT{N}) where N
    # One `-tup[k]` expression per index; `tupexpr` presumably assembles them
    # into a single tuple expression (helper defined elsewhere in this file).
    return tupexpr(k -> :(-tup[$k]), N)
end
225+
# Generic version: compute `(af * a[k]) + (bf * b[k])` for each index of two
# length-`N` tuples, with no restriction on element or factor types.
@generated function mul_tuples(a::NT{N}, b::NT{N}, af, bf) where N
    # `tupexpr` presumably assembles the per-index expressions into a single tuple
    # expression (helper defined elsewhere in this file).
    return tupexpr(k -> :((af * a[$k]) + (bf * b[$k])), N)
end
223226

224227
###################
225228
# Pretty Printing #

test/PartialsTest.jl

+1-1
Original file line number | Diff line number | Diff line change
@@ -114,7 +114,7 @@ for N in (0, 3), T in (Int, Float32, Float64)
114114

115115
if N > 0
116116
@test ForwardDiff._div_partials(PARTIALS, PARTIALS2, X, Y) == ForwardDiff._mul_partials(PARTIALS, PARTIALS2, inv(Y), -X/(Y^2))
117-
@test ForwardDiff._mul_partials(PARTIALS, PARTIALS2, X, Y).values == map((a, b) -> (X * a) + (Y * b), VALUES, VALUES2)
117+
@test all(isapprox.(ForwardDiff._mul_partials(PARTIALS, PARTIALS2, X, Y).values, map((a, b) -> (X * a) + (Y * b), VALUES, VALUES2)))
118118
@test ForwardDiff._mul_partials(ZERO_PARTIALS, PARTIALS, X, Y) == Y * PARTIALS
119119
@test ForwardDiff._mul_partials(PARTIALS, ZERO_PARTIALS, X, Y) == X * PARTIALS
120120

0 commit comments

Comments
 (0)