Open
Description
A week ago, I opened a thread on Discourse. Since the problem still persists, I decided to open an issue here. I have a convolutional network that works with complex numbers and uses a custom loss function. Here is a minimal working example (MWE):
using Flux
using CUDA
using Flux: glorot_uniform
using Statistics: mean
# Turn scalar (element-by-element) indexing of GPU arrays into an error, so any
# accidental host-side element access in the model or loss is reported immediately.
CUDA.allowscalar(false); # disallowing scalar operations on GPU
"""
    Enc

Encoder layer made of two independent real-valued branches: `rConv` is applied to the
real part of the input and `iConv` to the imaginary part.

    Enc(filter, stride, in, out, pad)

Build each branch as `Conv(filter, in=>out, leakyrelu)` (Glorot-uniform init, given
`stride` and `pad`) followed by `BatchNorm(out, relu)`.

    Enc(rConv::Chain, iConv::Chain)

Wrap two existing chains.
"""
mutable struct Enc
    rConv::Chain
    iConv::Chain

    function Enc(filter, stride, in, out, pad)
        # Every call creates a fresh, independently initialised branch;
        # the real-part branch is built first.
        makebranch() = Chain(
            Conv(filter, in=>out, leakyrelu; init=glorot_uniform, stride=stride, pad=pad),
            BatchNorm(out, relu)
        )
        realbranch = makebranch()
        imagbranch = makebranch()
        return new(realbranch, imagbranch)
    end

    Enc(rConv::Chain, iConv::Chain) = new(rConv, iConv)
end

# Register both branches with Flux so `params`/`gpu` traverse them.
Flux.@functor Enc
# Forward pass of `Enc` on a complex array `x`.
# `rConv` runs on `real(x)` and `iConv` on `imag(x)`. The two real feature maps are
# recombined into a complex result with real part `rC - iC` and imaginary part `rC + iC`.
# NOTE(review): a textbook complex convolution would use
# `rConv(imag(x)) + iConv(real(x))` for the imaginary part; confirm which one is intended.
# NOTE(review): the `Complex{ForwardDiff.Dual}` values in the reported stacktrace appear
# to originate from Zygote's GPU broadcasting of `complex.(…)`. Verify this against the
# installed Zygote version before relying on this path under `gradient`.
function (enc::Enc)(x)
    rC = enc.rConv(real(x))
    iC = enc.iConv(imag(x))
    # Use fresh names so both parts see the *original* branch outputs.
    # The previous in-place reassignment of `rC` made the imaginary part
    # `(rC - iC) + iC == rC`, which silently discarded the `iConv` branch.
    realpart = rC - iC
    imagpart = rC + iC
    return complex.(realpart, imagpart)
end
"""
    multistft(spectrogram, framelen=1024, hopsize=div(framelen, 2))

Reconstruct time-domain signals from a batch of one-sided complex spectra. The steps are:
an inverse real FFT along the first axis, a symmetric Hann window, and normalised
overlap-add. Each output sample is divided by the summed squared windows, plus a
`1.0e-8` floor per frame.

`spectrogram` is a 4-D `CuArray` indexed as `(bin, frame, channel, sample)`.
`irfft(…, framelen, 1)` requires `size(spectrogram, 1) == framelen ÷ 2 + 1`.
Frames are overlap-added as two interleaved groups (odd and even frames). This aligns
them correctly only for an even `framelen` with `hopsize == framelen ÷ 2`.

Returns a `Float32` matrix of size `(expectedlen * channels, samples)`, where
`expectedlen = framelen + (numframes - 1) * hopsize`. The channels of one sample are
stored one after another along the first axis. When `channels == 1`, this is simply
`(expectedlen, samples)`.

Throws `ArgumentError` if `framelen`/`hopsize` do not describe 50% overlap.
"""
function multistft(spectrogram::CuArray{T, 4},
                   framelen::Int=1024,
                   hopsize::Int=div(framelen, 2)) where T <: Complex
    (iseven(framelen) && hopsize == div(framelen, 2)) ||
        throw(ArgumentError("multistft needs an even framelen and hopsize == framelen ÷ 2; got framelen=$framelen, hopsize=$hopsize"))
    freqbins, origframes, channels, samples = size(spectrogram)
    expectedlen = framelen + (origframes - 1) * hopsize

    # The odd/even split below needs an even frame count, so append one all-zero frame if necessary.
    padded = isodd(origframes)
    spectrogram = padded ? hcat(spectrogram, CUDA.zeros(eltype(spectrogram), freqbins, 1, channels, samples)) : spectrogram
    numframes = padded ? origframes + 1 : origframes

    # Symmetric Hann window on every real frame. The padding frame gets a zero window,
    # so it neither contributes signal nor inflates the normalisation term `winsum`.
    hann = CUDA.CuArray(Float32.(.5 .* (1 .- cos.(2 .* pi .* collect(0:framelen - 1)/(framelen - 1)))))
    framemask = CUDA.CuArray(Float32.(reshape(collect(1:numframes) .<= origframes, 1, numframes)))
    window = CUDA.ones(Float32, (framelen, numframes, channels, samples)) .* hann .* framemask
    windows = CUDA.fill(Float32(1.0e-8), framelen, numframes, channels, samples) .+ (window.^2)

    # Odd frames tile [0, framelen), [framelen, 2framelen), ... and even frames use the same
    # tiling shifted by one hop. Each group is laid out contiguously *per channel* and offset
    # by `hopsize` zeros, which performs the 50% overlap-add without scalar indexing.
    grouplen = framelen * div(numframes, 2)
    shift = CUDA.zeros(Float32, hopsize, channels, samples)
    odds = reshape(windows[:, 1:2:end, :, :], grouplen, channels, samples)
    evens = reshape(windows[:, 2:2:end, :, :], grouplen, channels, samples)
    winsum = vcat(odds, shift) .+ vcat(shift, evens)
    wr_odd = window[:, 1:2:end, :, :] .* CUDA.CUFFT.irfft(spectrogram[:, 1:2:end, :, :], framelen, 1)
    wr_even = window[:, 2:2:end, :, :] .* CUDA.CUFFT.irfft(spectrogram[:, 2:2:end, :, :], framelen, 1)
    reconstructed = vcat(reshape(wr_odd, grouplen, channels, samples), shift) .+
                    vcat(shift, reshape(wr_even, grouplen, channels, samples))

    # If a padding frame was added, drop the trailing hop that only it covered
    # (this is a no-op otherwise). Then stack the channels along the first axis.
    signal = (reconstructed ./ winsum)[1:expectedlen, :, :]
    return reshape(signal, expectedlen * channels, samples)
end
# this loss is user-defined
"""
    wsdrLoss(x, ŷ, y; ϵ=1e-8)

Weighted SDR loss. `x` is presumably the noisy input mixture, `ŷ` the prediction and
`y` the ground truth. All three are complex spectrograms that are first turned into
waveforms with `multistft`.

For each batch sample, the weight is `α = ‖y‖² / (‖y‖² + ‖x - y‖² + ϵ)`. The per-sample
loss is `α * sdr(ŷ, y) + (1 - α) * sdr(x - ŷ, x - y)`, and the mean over the batch is
returned.
"""
function wsdrLoss(x, ŷ, y; ϵ=1e-8)
    x = x |> multistft
    ŷ = ŷ |> multistft
    y = y |> multistft
    z = x .- y   # target residual (x minus ground truth)
    ẑ = x .- ŷ   # estimated residual (x minus prediction)
    # Per-sample energies, flattened to vectors of length `samples`.
    nd = vec(sum(y.^2; dims=1))
    dom = vec(sum(z.^2; dims=1))
    ϵ_array = CUDA.fill(Float32(ϵ), size(nd))
    aux = nd ./ (nd .+ dom .+ ϵ_array)
    # `sdr` returns a 1×samples row. Flatten it so it pairs element-wise with `aux`
    # instead of broadcasting a column against a row into a samples×samples matrix.
    wSDR = aux .* vec(sdr(ŷ, y)) .+ (1 .- aux) .* vec(sdr(ẑ, z))
    return mean(wSDR)
end
"""
    multiNorm(A; dims)

Euclidean (L2) norm of `A` along `dims`; the reduced dimensions are kept with size 1.
`real(A .* conj(A))` handles both real and complex inputs. Base `sqrt` (rather than
`CUDA.sqrt`) keeps this usable on CPU arrays as well as in GPU broadcasts.
"""
multiNorm(A; dims) = sqrt.(sum(real(A .* conj(A)), dims=dims))
"""
    sdr(ypred, ygold; ϵ=1e-8)

Negative normalised correlation between `ypred` and `ygold` along the first axis:
`-(sum(ygold .* ypred) / (‖ygold‖ * ‖ypred‖ + ϵ))`. The first dimension is kept with size 1.
"""
function sdr(ypred, ygold; ϵ=1e-8)
    stabiliser = Float32(ϵ)
    normprod = multiNorm(ygold, dims=1) .* multiNorm(ypred, dims=1)
    inner = sum(ygold .* ypred, dims=1)
    return -(inner ./ (normprod .+ CUDA.fill(stabiliser, size(normprod))))
end
# Custom Zygote rules for the CUDA array constructors called inside the loss.
# `CUDA.ones` / `CUDA.zeros`: the pullback returns `nothing` for every argument,
# so no gradient flows into the element type or size arguments.
# `CUDA.fill(x, dims...)`: the gradient w.r.t. the fill value `x` is `sum(Δ)`,
# and each dimension argument gets `nothing`.
# NOTE(review): these add methods to functions owned by other packages on types they
# do not own (type piracy). Newer Zygote/ChainRules releases may already ship
# equivalent rules, so confirm before keeping them.
Zygote.@adjoint CUDA.ones(x...) = CUDA.ones(x...), _ -> map(_ -> nothing, x)
Zygote.@adjoint CUDA.zeros(x...) = CUDA.zeros(x...), _ -> map(_ -> nothing, x)
Zygote.@adjoint CUDA.fill(x::Real, dims...) = CUDA.fill(x, dims...), Δ->(sum(Δ), map(_->nothing, dims)...)
# Random complex input and target, laid out as (bin, frame, channel, sample) as multistft expects.
x = CUDA.rand(ComplexF32, 513, 321, 1, 1); # input
y = CUDA.rand(ComplexF32, 513, 321, 1, 1); # output
# Dummy single-layer model moved to the GPU: 1×1 kernel, stride 1, 1 => 1 channels, no padding.
encoder = gpu(Chain(Enc((1, 1), (1, 1), 1, 1, (0, 0))))
# Trainable parameters and optimiser. The loss takes the input, the prediction and the ground truth.
θ = params(encoder)
opt = ADAM(1e-2)
∇ = gradient(θ) do
    wsdrLoss(x, encoder(x), y)
end
During the gradient calculation I get the following error:
ERROR: MethodError: no method matching plan_brfft(::CuArray{Complex{ForwardDiff.Dual{Nothing,Float32,2}},4}, ::Int64, ::Int64)
Closest candidates are:
plan_brfft(::CuArray{T,N}, ::Integer, ::Any) where {T<:Union{Complex{Float32}, Complex{Float64}}, N} at /opt/.julia/packages/CUDA/YeS8q/lib/cufft/fft.jl:306
plan_brfft(::AbstractArray, ::Integer; kws...) at /opt/.julia/packages/AbstractFFTs/mhQvY/src/definitions.jl:285
Stacktrace:
[1] plan_irfft(::CuArray{Complex{ForwardDiff.Dual{Nothing,Float32,2}},4}, ::Int64, ::Int64; kws::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /opt/.julia/packages/AbstractFFTs/mhQvY/src/definitions.jl:334
[2] plan_irfft(::CuArray{Complex{ForwardDiff.Dual{Nothing,Float32,2}},4}, ::Int64, ::Int64) at /opt/.julia/packages/AbstractFFTs/mhQvY/src/definitions.jl:334
[3] irfft(::CuArray{Complex{ForwardDiff.Dual{Nothing,Float32,2}},4}, ::Int64, ::Int64) at /opt/.julia/packages/AbstractFFTs/mhQvY/src/definitions.jl:284
[4] adjoint at /opt/.julia/packages/Zygote/xBjHw/src/lib/array.jl:929 [inlined]
[5] _pullback(::Zygote.Context, ::typeof(AbstractFFTs.irfft), ::CuArray{Complex{ForwardDiff.Dual{Nothing,Float32,2}},4}, ::Int64, ::Int64) at /opt/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:47
[6] multistft at ./REPL[13]:19 [inlined]
[7] _pullback(::Zygote.Context, ::typeof(multistft), ::CuArray{Complex{ForwardDiff.Dual{Nothing,Float32,2}},4}, ::Int64, ::Int64) at /opt/.julia/packages/Zygote/xBjHw/src/compiler/interface2.jl:0
[8] multistft at ./REPL[13]:5 [inlined] (repeats 2 times)
[9] |> at ./operators.jl:834 [inlined]
[10] #wsdrLoss#3 at ./REPL[14]:5 [inlined]
[11] _pullback(::Zygote.Context, ::var"##wsdrLoss#3", ::Float64, ::typeof(wsdrLoss), ::CuArray{Complex{Float32},4}, ::CuArray{Complex{ForwardDiff.Dual{Nothing,Float32,2}},4}, ::CuArray{Complex{Float32},4}) at /opt/.julia/packages/Zygote/xBjHw/src/compiler/interface2.jl:0
[12] wsdrLoss at ./REPL[14]:4 [inlined]
[13] _pullback(::Zygote.Context, ::typeof(wsdrLoss), ::CuArray{Complex{Float32},4}, ::CuArray{Complex{ForwardDiff.Dual{Nothing,Float32,2}},4}, ::CuArray{Complex{Float32},4}) at /opt/.julia/packages/Zygote/xBjHw/src/compiler/interface2.jl:0
[14] #24 at ./REPL[29]:1 [inlined]
[15] _pullback(::Zygote.Context, ::var"#24#25") at /opt/.julia/packages/Zygote/xBjHw/src/compiler/interface2.jl:0
[16] pullback(::Function, ::Params) at /opt/.julia/packages/Zygote/xBjHw/src/compiler/interface.jl:172
[17] gradient(::Function, ::Params) at /opt/.julia/packages/Zygote/xBjHw/src/compiler/interface.jl:53
[18] top-level scope at REPL[29]:1
There is also another Discourse discussion on the same or a similar topic, but it is beyond my current knowledge.
I cannot make progress because of this problem. What should I do?
B.R.
Metadata
Metadata
Assignees
Labels
No labels