Closed as not planned
Description
I noticed that the FNO
does not work with ComponentArrays
(which are needed for OptimizationProblem).
Any ideas where the problem is?
# Minimal reproducer: evaluate a FourierNeuralOperator with its parameters flattened
# into a ComponentArray (the form used with OptimizationProblem).

# Bind the RNG once so the snippet runs standalone; previously `rng` was used below
# without ever being defined.
rng = Random.default_rng()

fno = FourierNeuralOperator(gelu; chs = (2, 64, 64, 128, 1), modes = (16,))
θ, st = Lux.setup(rng, fno)

# Float32 input; the first dimension (2) matches the first entry of `chs`.
# The remaining dimensions are presumably (grid points, batch) — TODO confirm.
v = rand(rng, Float32, 2, 40, 50)

# Reference evaluation with the parameters exactly as returned by Lux.setup.
c = fno(v, θ, st)[1] .- 1.0f0

# Shifted model output as a function of the parameters only.
ff = (θ) -> fno(v, θ, st)[1] .- 1.0f0

# NOTE(review): the reported stack trace shows the resulting ComponentVector has
# eltype ComplexF32. Because `v` is Float32, the complex eltype must come from the
# parameters, presumably the complex spectral weights. Flattening promoted every
# parameter to a single complex eltype, so the real-input `rfft` path later receives
# a complex array and fails in `AbstractFFTs.realfloat`.
init_params = ComponentArrays.ComponentArray(θ)
# Sum of squared entries of the shifted FNO output `ff(θ)` for the parameters `θ`.
total_loss(θ) = sum(abs2, ff(θ))

total_loss(θ)            # parameters as returned by Lux.setup
total_loss(init_params)  # ComponentArray parameters; this is the call reported to error
julia> total_loss(init_params)
ERROR: MethodError: no method matching realfloat(::Array{ComplexF32, 3})
Closest candidates are:
realfloat(::StridedArray{<:Union{Float32, Float64}})
@ AbstractFFTs ~/.julia/packages/AbstractFFTs/4iQz5/src/definitions.jl:42
realfloat(::AbstractArray{T}) where T<:Real
@ AbstractFFTs ~/.julia/packages/AbstractFFTs/4iQz5/src/definitions.jl:49
Stacktrace:
[1] plan_rfft(x::Array{ComplexF32, 3}, region::UnitRange{Int64}; kws::@Kwargs{})
@ AbstractFFTs ~/.julia/packages/AbstractFFTs/4iQz5/src/definitions.jl:221
[2] rfft(x::Array{ComplexF32, 3}, region::UnitRange{Int64})
@ AbstractFFTs ~/.julia/packages/AbstractFFTs/4iQz5/src/definitions.jl:67
[3] transform
@ ~/.julia/packages/NeuralOperators/rTBsc/src/transform.jl:24 [inlined]
[4] operator_conv
@ ~/.julia/packages/NeuralOperators/rTBsc/src/functional.jl:3 [inlined]
[5] (::OperatorConv{…})(x::Array{…}, ps::ComponentArrays.ComponentVector{…}, st::@NamedTuple{})
@ NeuralOperators ~/.julia/packages/NeuralOperators/rTBsc/src/layers.jl:66
[6] (::OperatorKernel{…})(x::Array{…}, ps::ComponentArrays.ComponentVector{…}, st::@NamedTuple{…})
@ NeuralOperators ~/.julia/packages/NeuralOperators/rTBsc/src/layers.jl:138
[7] (::FourierNeuralOperator{…})(x::Array{…}, ps::ComponentArrays.ComponentVector{…}, st::@NamedTuple{…})
@ NeuralOperators ~/.julia/packages/NeuralOperators/rTBsc/src/fno.jl:70
[8] (::var"#245#246")(θ::ComponentArrays.ComponentVector{ComplexF32, Vector{ComplexF32}, Tuple{ComponentArrays.Axis{…}}})
@ Main ./REPL[706]:1
[9] total_loss(θ::ComponentArrays.ComponentVector{ComplexF32, Vector{ComplexF32}, Tuple{ComponentArrays.Axis{…}}})
@ Main ./REPL[708]:2
[10] top-level scope
@ REPL[710]:1
Some type information was truncated. Use `show(err)` to see complete types.
Metadata
Metadata
Assignees
Labels
No labels