@@ -393,7 +393,7 @@ Base.float(d::Dual) = convert(float(typeof(d)), d)
393393# ##################################
394394
395395for (M, f, arity) in DiffRules. diffrules(filter_modules = nothing )
396- if (M, f) in ((:Base, :^ ), (:NaNMath, :pow), (:Base, :/ ), (:Base, :+ ), (:Base, :- ))
396+ if (M, f) in ((:Base, :^ ), (:NaNMath, :pow), (:Base, :/ ), (:Base, :+ ), (:Base, :- ), (:Base, :sin), (:Base, :cos) )
397397 continue # Skip methods which we define elsewhere.
398398 elseif ! (isdefined(@__MODULE__, M) && isdefined(getfield(@__MODULE__, M), f))
399399 continue # Skip rules for methods not defined in the current scope
@@ -622,12 +622,19 @@ end
622622 Dual{Tz}(muladd(x, y, value(z)), partials(z)) # z_body
623623)
624624
625- # sincos #
625+ # sin/cos #
626626# --------#
627+ function Base. sin(d:: Dual{T} ) where T
628+ s, c = sincos(value(d))
629+ return Dual{T}(s, c * partials(d))
630+ end
627631
628- @inline sincos(x) = (sin(x), cos(x))
632+ function Base. cos(d:: Dual{T} ) where T
633+ s, c = sincos(value(d))
634+ return Dual{T}(c, - s * partials(d))
635+ end
629636
630- @inline function sincos(d:: Dual{T} ) where T
637+ @inline function Base . sincos(d:: Dual{T} ) where T
631638 sd, cd = sincos(value(d))
632639 return (Dual{T}(sd, cd * partials(d)), Dual{T}(cd, - sd * partials(d)))
633640end
0 commit comments