Skip to content

RFC: Rules for FFTs #495

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed

RFC: Rules for FFTs #495

wants to merge 4 commits into from

Conversation

mcabbott
Copy link
Member

@mcabbott mcabbott commented Jan 14, 2021

This adds minimal support for using FFTW with ForwardDiff. This is done via the very lightweight AbstractFFTs package.

It would be the first mention of Complex in the code here. And it uses Complex{Dual{...}}, although in fact Dual{Complex{...}} would be more convenient, if I am thinking correctly. The latter means the complex pairs are still neighbours, and so an array could be reinterpreted to just ComplexF64, passed to FFTW, and then reinterpreted back. But, instead, this makes new arrays for the values & then the partials, processes them, and then re-assembles.

Quick demonstration:

julia> x1 = Dual.([1,ℯ,pi], [1,0,0], [0,1,0], [0,0,1]);

julia> Zygote.gradient(x -> sum(abs2, fft(x)), value.(x1))[1]
3-element Vector{ComplexF64}:
                6.0 + 0.0im
 16.309690970754268 + 0.0im
  18.84955592153876 + 0.0im

julia> partials(sum(abs2, fft(x1)))
3-element ForwardDiff.Partials{3, Float64}:
  6.0
 16.309690970754268
 18.84955592153876

Possibly full of bugs, barely tested, etc. Doesn't try to treat fft! and friends.

cc @tholdem

Co-authored-by: Niklas Schmitz <[email protected]>
@niklasschmitz
Copy link

niklasschmitz commented Jun 10, 2021

I was looking for this combination, this is great!

Also, the example above using partials explicitly works fine for me, but gradient seems to fail:

ForwardDiff.gradient(x -> sum(abs2, fft(x)), ForwardDiff.value.(x1))
# ERROR: LoadError: Cannot determine ordering of Dual tags Nothing and ForwardDiff.Tag{var"#17#18", Float64}
# Stacktrace:
# [1] ≺(a::Type, b::Type)
#   @ ForwardDiff ~/.julia/packages/ForwardDiff/QOqCN/src/dual.jl:49
# [2] partials
#   @ ~/.julia/packages/ForwardDiff/QOqCN/src/dual.jl:103 [inlined]
# [3] extract_gradient!(#unused#::Type{ForwardDiff.Tag{var"#17#18", Float64}}, result::Vector{Float64}, dual::ForwardDiff.Dual{Nothing, Float64, 3})
#   @ ForwardDiff ~/.julia/packages/ForwardDiff/QOqCN/src/gradient.jl:81
# [4] vector_mode_gradient(f::var"#17#18", x::Vector{Float64}, cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{var"#17#18", Float64}, Float64, 3, Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#17#18", Float64}, Float64, 3}}})
#   @ ForwardDiff ~/.julia/packages/ForwardDiff/QOqCN/src/gradient.jl:109
# [5] gradient(f::Function, x::Vector{Float64}, cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{var"#17#18", Float64}, Float64, 3, Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#17#18", Float64}, Float64, 3}}}, ::Val{true})
#   @ ForwardDiff ~/.julia/packages/ForwardDiff/QOqCN/src/gradient.jl:19
# [6] gradient(f::Function, x::Vector{Float64}, cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{var"#17#18", Float64}, Float64, 3, Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#17#18", Float64}, Float64, 3}}}) (repeats 2 times)
#   @ ForwardDiff ~/.julia/packages/ForwardDiff/QOqCN/src/gradient.jl:17
# [7] top-level scope

@mcabbott
Copy link
Member Author

I guess _apply_plan here constructs new Duals without propagating the tag from the input. That's where the ones with "tags Nothing" must be coming from, inside your ForwardDiff.gradient call.

@niklasschmitz
Copy link

Thanks for the pointer, I've now modified _apply_plan to propagate the input tagtype

ForwardDiff.tagtype(x::Complex{<:ForwardDiff.Dual{T,V,N}}) where {T,V,N} = T
ForwardDiff.tagtype(::Type{<:Complex{<:ForwardDiff.Dual{T,V,N}}}) where {T,V,N} = T

function _apply_plan(p::AbstractFFTs.Plan, x::AbstractArray)
    xtil = p * ForwardDiff.value.(x)
    dxtils = ntuple(ForwardDiff.npartials(eltype(x))) do n
        p * ForwardDiff.partials.(x, n)
    end
    T = ForwardDiff.tagtype(eltype(x))
    map(xtil, dxtils...) do val, parts...
        Complex(
            ForwardDiff.Dual{T}(real(val), map(real, parts)),
            ForwardDiff.Dual{T}(imag(val), map(imag, parts)),
        )
    end
end

which indeed seems to have fixed the problem:

julia> ForwardDiff.gradient(x -> sum(abs2, fft(x)), [1,ℯ,pi])
3-element Vector{Float64}:
  6.0
 16.309690970754268
 18.84955592153876

end
end

function _apply_plan(p::AbstractFFTs.Plan, x::AbstractArray)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is assuming the Plan is complex-valued, but not all of them are.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed. This is at the moment a very rough sketch of how FFT might be supported, but is far from handling everything, and probably full of bugs.

If anyone wished to tidy it up they would be most welcome. Whether the maintainers think that this is something this package should do at all, and do in this way, I don't know.

@test complexfloat(x1)[1] === complexfloat(x1[1]) === Dual(1.0, 2.0, 3.0) + 0im
@test realfloat(x1)[1] === realfloat(x1[1]) === Dual(1.0, 2.0, 3.0)

@test fft(x1, 1)[1] isa Complex{<:Dual}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add tests for rfft and FFTW.r2r

@dlfivefifty dlfivefifty mentioned this pull request Aug 3, 2021
@codecov-commenter

This comment has been minimized.

@mcabbott mcabbott closed this Nov 13, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants