Skip to content

Commit d053bd7

Browse files
committed
use explicit simd for iszero check on partials (#559)
1 parent d551bbe commit d053bd7

File tree

2 files changed

+16
-8
lines changed

2 files changed

+16
-8
lines changed

src/ForwardDiff.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ if VERSION >= v"1.6"
88
end
99
using Random
1010
using LinearAlgebra
11+
import SIMD: Vec
1112

1213
import Printf
1314
import NaNMath

src/partials.jl

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,13 @@ end
141141
@inline _mul_partials(a::Partials{0,A}, b::Partials{N,B}, afactor, bfactor) where {N,A,B} = bfactor * b
142142
@inline _mul_partials(a::Partials{N,A}, b::Partials{0,B}, afactor, bfactor) where {N,A,B} = afactor * a
143143

144+
const SIMDFloat = Union{Float64, Float32}
145+
const SIMDInt = Union{
146+
Int128, Int64, Int32, Int16, Int8,
147+
UInt128, UInt64, UInt32, UInt16, UInt8,
148+
}
149+
const SIMDType = Union{SIMDFloat, SIMDInt}
150+
144151
##################################
145152
# Generated Functions on NTuples #
146153
##################################
@@ -164,6 +171,7 @@ end
164171
@inline rand_tuple(::AbstractRNG, ::Type{Tuple{}}) = tuple()
165172
@inline rand_tuple(::Type{Tuple{}}) = tuple()
166173

174+
iszero_tuple(tup::NTuple{N,V}) where {N, V<:SIMDType} = sum(Vec(tup) != zero(V)) == 0
167175
@generated function iszero_tuple(tup::NTuple{N,V}) where {N,V}
168176
ex = Expr(:&&, [:(z == tup[$i]) for i=1:N]...)
169177
return quote
@@ -205,15 +213,14 @@ const SIMDInt = Union{
205213
}
206214
const SIMDType = Union{SIMDFloat, SIMDInt}
207215
const NT{N,T} = NTuple{N,T}
208-
using SIMD
209216

210217
# SIMD implementation
211-
add_tuples(a::NT{N,T}, b::NT{N,T}) where {N, T<:SIMDType} = Tuple(Vec(a) + Vec(b))
212-
sub_tuples(a::NT{N,T}, b::NT{N,T}) where {N, T<:SIMDType} = Tuple(Vec(a) - Vec(b))
213-
scale_tuple(tup::NT{N,T}, x::T) where {N, T<:SIMDType} = Tuple(Vec(tup) * x)
214-
div_tuple_by_scalar(tup::NT{N,T}, x::T) where {N, T<:SIMDFloat} = Tuple(Vec(tup) / x)
215-
minus_tuple(tup::NT{N,T}) where {N, T<:SIMDType} = Tuple(-Vec(tup))
216-
mul_tuples(a::NT{N,T}, b::NT{N,T}, af::T, bf::T) where {N, T<:SIMDType} = Tuple(muladd(Vec{N,T}(af), Vec(a), Vec{N,T}(bf) * Vec(b)))
218+
@inline add_tuples(a::NT{N,T}, b::NT{N,T}) where {N, T<:SIMDType} = Tuple(Vec(a) + Vec(b))
219+
@inline sub_tuples(a::NT{N,T}, b::NT{N,T}) where {N, T<:SIMDType} = Tuple(Vec(a) - Vec(b))
220+
@inline scale_tuple(tup::NT{N,T}, x::T) where {N, T<:SIMDType} = Tuple(Vec(tup) * x)
221+
@inline div_tuple_by_scalar(tup::NT{N,T}, x::T) where {N, T<:SIMDFloat} = Tuple(Vec(tup) / x)
222+
@inline minus_tuple(tup::NT{N,T}) where {N, T<:SIMDType} = Tuple(-Vec(tup))
223+
@inline mul_tuples(a::NT{N,T}, b::NT{N,T}, af::T, bf::T) where {N, T<:SIMDType} = Tuple(muladd(af, Vec(a), bf * Vec(b)))
217224

218225

219226
# Fallback implementations
@@ -222,7 +229,7 @@ mul_tuples(a::NT{N,T}, b::NT{N,T}, af::T, bf::T) where {N, T<:SIMDType} = Tuple
222229
@generated scale_tuple(tup::NT{N}, x) where N = tupexpr(i -> :(tup[$i] * x), N)
223230
@generated div_tuple_by_scalar(tup::NT{N}, x) where N = tupexpr(i -> :(tup[$i] / x), N)
224231
@generated minus_tuple(tup::NT{N}) where N = tupexpr(i -> :(-tup[$i]), N)
225-
@generated mul_tuples(a::NT{N}, b::NT{N}, af, bf) where N = tupexpr(i -> :((af * a[$i]) + (bf * b[$i])), N)
232+
@generated mul_tuples(a::NT{N}, b::NT{N}, af, bf) where N = tupexpr(i -> :(muladd(af, a[$i], bf * b[$i])), N)
226233

227234
###################
228235
# Pretty Printing #

0 commit comments

Comments
 (0)