diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 56350777..ca7f7e43 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -18,8 +18,6 @@ jobs: fail-fast: false matrix: version: - - '1.0' - - '1.6' - '1' - 'nightly' os: diff --git a/Project.toml b/Project.toml index f7c66517..6e6c2514 100644 --- a/Project.toml +++ b/Project.toml @@ -12,6 +12,7 @@ NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +SIMD = "fdea26ae-647d-5447-a871-4b548cad5224" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" @@ -25,8 +26,9 @@ LogExpFunctions = "0.3" NaNMath = "0.2.2, 0.3, 1" Preferences = "1" SpecialFunctions = "0.8, 0.9, 0.10, 1.0, 2" +SIMD = "3" StaticArrays = "0.8.3, 0.9, 0.10, 0.11, 0.12, 1.0" -julia = "1" +julia = "1.6" [extensions] ForwardDiffStaticArraysExt = "StaticArrays" @@ -43,4 +45,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" test = ["Calculus", "DiffTests", "SparseArrays", "Test", "InteractiveUtils", "StaticArrays"] [weakdeps] -StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" \ No newline at end of file +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" diff --git a/src/ForwardDiff.jl b/src/ForwardDiff.jl index a27d6dba..d1d7bbfa 100644 --- a/src/ForwardDiff.jl +++ b/src/ForwardDiff.jl @@ -7,6 +7,7 @@ if VERSION >= v"1.6" end using Random using LinearAlgebra +import SIMD: Vec import Printf import NaNMath @@ -14,6 +15,13 @@ import SpecialFunctions import LogExpFunctions import CommonSubexpressions +const SIMDFloat = Union{Float64, Float32} +const SIMDInt = Union{ + Int128, Int64, Int32, Int16, Int8, + UInt128, UInt64, UInt32, UInt16, UInt8, + } +const SIMDType = Union{SIMDFloat, SIMDInt} + include("prelude.jl") include("partials.jl") include("dual.jl") diff --git a/src/dual.jl b/src/dual.jl index 5afb2144..76fa70cb 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -196,13 +196,13 @@ macro define_ternary_dual_op(f, xyz_body, xy_body, xz_body, yz_body, x_body, y_b end # Support complex-valued functions such as `hankelh1` -function dual_definition_retval(::Val{T}, val::Real, deriv::Real, partial::Partials) where {T} +@inline function dual_definition_retval(::Val{T}, val::Real, deriv::Real, partial::Partials) where {T} return Dual{T}(val, deriv * partial) end -function dual_definition_retval(::Val{T}, val::Real, deriv1::Real, partial1::Partials, deriv2::Real, partial2::Partials) where {T} +@inline function dual_definition_retval(::Val{T}, val::Real, deriv1::Real, partial1::Partials, deriv2::Real, partial2::Partials) where {T} return Dual{T}(val, _mul_partials(partial1, partial2, deriv1, deriv2)) end -function dual_definition_retval(::Val{T}, val::Complex, deriv::Union{Real,Complex}, partial::Partials) where {T} +@inline function dual_definition_retval(::Val{T}, val::Complex, deriv::Union{Real,Complex}, partial::Partials) where {T} reval, imval = reim(val) if deriv isa Real p = deriv * partial @@ -212,7 +212,7 @@ function dual_definition_retval(::Val{T}, val::Complex, deriv::Union{Real,Comple return Complex(Dual{T}(reval, rederiv * partial), Dual{T}(imval, imderiv * partial)) end end -function dual_definition_retval(::Val{T}, val::Complex, deriv1::Union{Real,Complex}, partial1::Partials, deriv2::Union{Real,Complex}, partial2::Partials) where {T} +@inline function dual_definition_retval(::Val{T}, val::Complex, deriv1::Union{Real,Complex}, partial1::Partials, deriv2::Union{Real,Complex}, partial2::Partials) where {T} reval, imval = reim(val) if deriv1 isa Real && deriv2 isa Real p = _mul_partials(partial1, partial2, deriv1, deriv2) @@ -592,6 +592,16 @@ end # fma # #-----# +@inline function calc_fma_xyz(x::Dual{T,V,N}, + y::Dual{T,V,N}, + z::Dual{T,V,N}) where {T, V<:SIMDFloat,N} + xv, yv, zv = value(x), value(y), value(z) + rv = fma(xv, yv, zv) + N == 0 && return Dual{T}(rv) + xp, yp, zp = Vec(partials(x).values), Vec(partials(y).values), Vec(partials(z).values) + parts = Tuple(fma(xv, yp, fma(yv, xp, zp))) + Dual{T}(rv, parts) +end @generated function calc_fma_xyz(x::Dual{T,<:Any,N}, y::Dual{T,<:Any,N}, z::Dual{T,<:Any,N}) where {T,N} @@ -634,6 +644,16 @@ end # muladd # #--------# +@inline function calc_muladd_xyz(x::Dual{T,V,N}, + y::Dual{T,V,N}, + z::Dual{T,V,N}) where {T, V<:SIMDType,N} + xv, yv, zv = value(x), value(y), value(z) + rv = muladd(xv, yv, zv) + N == 0 && return Dual{T}(rv) + xp, yp, zp = Vec(partials(x).values), Vec(partials(y).values), Vec(partials(z).values) + parts = Tuple(muladd(xv, yp, muladd(yv, xp, zp))) + Dual{T}(rv, parts) +end @generated function calc_muladd_xyz(x::Dual{T,<:Any,N}, y::Dual{T,<:Any,N}, z::Dual{T,<:Any,N}) where {T,N} diff --git a/src/partials.jl b/src/partials.jl index fce67b0a..7a94884e 100644 --- a/src/partials.jl +++ b/src/partials.jl @@ -141,6 +141,13 @@ end @inline _mul_partials(a::Partials{0,A}, b::Partials{N,B}, afactor, bfactor) where {N,A,B} = bfactor * b @inline _mul_partials(a::Partials{N,A}, b::Partials{0,B}, afactor, bfactor) where {N,A,B} = afactor * a +const SIMDFloat = Union{Float64, Float32} +const SIMDInt = Union{ + Int128, Int64, Int32, Int16, Int8, + UInt128, UInt64, UInt32, UInt16, UInt8, + } +const SIMDType = Union{SIMDFloat, SIMDInt} + ################################## # Generated Functions on NTuples # ################################## @@ -164,6 +171,7 @@ end @inline rand_tuple(::AbstractRNG, ::Type{Tuple{}}) = tuple() @inline rand_tuple(::Type{Tuple{}}) = tuple() +iszero_tuple(tup::NTuple{N,V}) where {N, V<:SIMDType} = sum(Vec(tup) != zero(V)) == 0 @generated function iszero_tuple(tup::NTuple{N,V}) where {N,V} ex = Expr(:&&, [:(z == tup[$i]) for i=1:N]...) return quote @@ -197,29 +205,24 @@ end return tupexpr(i -> :(rand(V)), N) end -@generated function scale_tuple(tup::NTuple{N}, x) where N - return tupexpr(i -> :(tup[$i] * x), N) -end +const NT{N,T} = NTuple{N,T} -@generated function div_tuple_by_scalar(tup::NTuple{N}, x) where N - return tupexpr(i -> :(tup[$i] / x), N) -end +# SIMD implementation +@inline add_tuples(a::NT{N,T}, b::NT{N,T}) where {N, T<:SIMDType} = Tuple(Vec(a) + Vec(b)) +@inline sub_tuples(a::NT{N,T}, b::NT{N,T}) where {N, T<:SIMDType} = Tuple(Vec(a) - Vec(b)) +@inline scale_tuple(tup::NT{N,T}, x::T) where {N, T<:SIMDType} = Tuple(Vec(tup) * x) +@inline div_tuple_by_scalar(tup::NT{N,T}, x::T) where {N, T<:SIMDFloat} = Tuple(Vec(tup) / x) +@inline minus_tuple(tup::NT{N,T}) where {N, T<:SIMDType} = Tuple(-Vec(tup)) +@inline mul_tuples(a::NT{N,T}, b::NT{N,T}, af::T, bf::T) where {N, T<:SIMDType} = Tuple(muladd(af, Vec(a), bf * Vec(b))) -@generated function add_tuples(a::NTuple{N}, b::NTuple{N}) where N - return tupexpr(i -> :(a[$i] + b[$i]), N) -end -@generated function sub_tuples(a::NTuple{N}, b::NTuple{N}) where N - return tupexpr(i -> :(a[$i] - b[$i]), N) -end - -@generated function minus_tuple(tup::NTuple{N}) where N - return tupexpr(i -> :(-tup[$i]), N) -end - -@generated function mul_tuples(a::NTuple{N}, b::NTuple{N}, afactor, bfactor) where N - return tupexpr(i -> :((afactor * a[$i]) + (bfactor * b[$i])), N) -end +# Fallback implementations +@generated add_tuples(a::NT{N}, b::NT{N}) where N = tupexpr(i -> :(a[$i] + b[$i]), N) +@generated sub_tuples(a::NT{N}, b::NT{N}) where N = tupexpr(i -> :(a[$i] - b[$i]), N) +@generated scale_tuple(tup::NT{N}, x) where N = tupexpr(i -> :(tup[$i] * x), N) +@generated div_tuple_by_scalar(tup::NT{N}, x) where N = tupexpr(i -> :(tup[$i] / x), N) +@generated minus_tuple(tup::NT{N}) where N = tupexpr(i -> :(-tup[$i]), N) +@generated mul_tuples(a::NT{N}, b::NT{N}, af, bf) where N = tupexpr(i -> :(muladd(af, a[$i], bf * b[$i])), N) ################### # Pretty Printing # diff --git a/test/PartialsTest.jl b/test/PartialsTest.jl index 39fb05d7..08e76bda 100644 --- a/test/PartialsTest.jl +++ b/test/PartialsTest.jl @@ -114,7 +114,7 @@ for N in (0, 3), T in (Int, Float32, Float64) if N > 0 @test ForwardDiff._div_partials(PARTIALS, PARTIALS2, X, Y) == ForwardDiff._mul_partials(PARTIALS, PARTIALS2, inv(Y), -X/(Y^2)) - @test ForwardDiff._mul_partials(PARTIALS, PARTIALS2, X, Y).values == map((a, b) -> (X * a) + (Y * b), VALUES, VALUES2) + @test all(isapprox.(ForwardDiff._mul_partials(PARTIALS, PARTIALS2, X, Y).values, map((a, b) -> (X * a) + (Y * b), VALUES, VALUES2))) @test ForwardDiff._mul_partials(ZERO_PARTIALS, PARTIALS, X, Y) == Y * PARTIALS @test ForwardDiff._mul_partials(PARTIALS, ZERO_PARTIALS, X, Y) == X * PARTIALS