-
Notifications
You must be signed in to change notification settings - Fork 150
RFC: Rules for FFTs #495
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
RFC: Rules for FFTs #495
Conversation
Co-authored-by: Niklas Schmitz <[email protected]>
I was looking for this combination, this is great! Also, the example above using ForwardDiff.gradient(x -> sum(abs2, fft(x)), ForwardDiff.value.(x1))
# ERROR: LoadError: Cannot determine ordering of Dual tags Nothing and ForwardDiff.Tag{var"#17#18", Float64}
# Stacktrace:
# [1] ≺(a::Type, b::Type)
# @ ForwardDiff ~/.julia/packages/ForwardDiff/QOqCN/src/dual.jl:49
# [2] partials
# @ ~/.julia/packages/ForwardDiff/QOqCN/src/dual.jl:103 [inlined]
# [3] extract_gradient!(#unused#::Type{ForwardDiff.Tag{var"#17#18", Float64}}, result::Vector{Float64}, dual::ForwardDiff.Dual{Nothing, Float64, 3})
# @ ForwardDiff ~/.julia/packages/ForwardDiff/QOqCN/src/gradient.jl:81
# [4] vector_mode_gradient(f::var"#17#18", x::Vector{Float64}, cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{var"#17#18", Float64}, Float64, 3, Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#17#18", Float64}, Float64, 3}}})
# @ ForwardDiff ~/.julia/packages/ForwardDiff/QOqCN/src/gradient.jl:109
# [5] gradient(f::Function, x::Vector{Float64}, cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{var"#17#18", Float64}, Float64, 3, Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#17#18", Float64}, Float64, 3}}}, ::Val{true})
# @ ForwardDiff ~/.julia/packages/ForwardDiff/QOqCN/src/gradient.jl:19
# [6] gradient(f::Function, x::Vector{Float64}, cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{var"#17#18", Float64}, Float64, 3, Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#17#18", Float64}, Float64, 3}}}) (repeats 2 times)
# @ ForwardDiff ~/.julia/packages/ForwardDiff/QOqCN/src/gradient.jl:17
# [7] top-level scope |
I guess |
Thanks for the pointer, I've now modified ForwardDiff.tagtype(x::Complex{<:ForwardDiff.Dual{T,V,N}}) where {T,V,N} = T
ForwardDiff.tagtype(::Type{<:Complex{<:ForwardDiff.Dual{T,V,N}}}) where {T,V,N} = T
function _apply_plan(p::AbstractFFTs.Plan, x::AbstractArray)
xtil = p * ForwardDiff.value.(x)
dxtils = ntuple(ForwardDiff.npartials(eltype(x))) do n
p * ForwardDiff.partials.(x, n)
end
T = ForwardDiff.tagtype(eltype(x))
map(xtil, dxtils...) do val, parts...
Complex(
ForwardDiff.Dual{T}(real(val), map(real, parts)),
ForwardDiff.Dual{T}(imag(val), map(imag, parts)),
)
end
end which indeed seems to have fixed the problem: julia> ForwardDiff.gradient(x -> sum(abs2, fft(x)), [1,ℯ,pi])
3-element Vector{Float64}:
6.0
16.309690970754268
18.84955592153876 |
end | ||
end | ||
|
||
function _apply_plan(p::AbstractFFTs.Plan, x::AbstractArray) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is assuming the Plan
is complex-valued, but not all of them are.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Indeed. This is at the moment a very rough sketch of how FFT might be supported, but is far from handling everything, and probably full of bugs.
If anyone wished to tidy it up they would be most welcome. Whether the maintainers think that this is something this package should do at all, and do in this way, I don't know.
@test complexfloat(x1)[1] === complexfloat(x1[1]) === Dual(1.0, 2.0, 3.0) + 0im | ||
@test realfloat(x1)[1] === realfloat(x1[1]) === Dual(1.0, 2.0, 3.0) | ||
|
||
@test fft(x1, 1)[1] isa Complex{<:Dual} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add tests for rfft
and FFTW.r2r
This adds minimal support for using FFTW with ForwardDiff. This is done via the very lightweight AbstractFFTs package.
It would be the first mention of
Complex
in the code here. And it usesComplex{Dual{...}}
, although in factDual{Complex{...}}
would be more convenient, if I am thinking correctly. The latter means the complex pairs are still neighbours, and so an array could be reinterpreted to justComplexF64
, passed to FFTW, and then reinterpreted back. But, instead, this makes new arrays for the values & then the partials, processes them, and then re-assembles.Quick demonstration:
Possibly full of bugs, barely tested, etc. Doesn't try to treat
fft!
and friends.cc @tholdem