Skip to content

Commit db6fe0d

Browse files
committed
feat: more support for complex
1 parent 8f4c96b commit db6fe0d

File tree

4 files changed

+49
-0
lines changed

4 files changed

+49
-0
lines changed

Project.toml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,12 @@ NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
1313
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1414
WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"
1515

16+
[extensions]
17+
NeuralOperatorsReactantExt = "Reactant"
18+
19+
[weakdeps]
20+
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
21+
1622
[compat]
1723
AbstractFFTs = "1.5.0"
1824
ConcreteStructs = "0.2.3"
@@ -21,5 +27,6 @@ LuxCore = "1.2"
2127
LuxLib = "1.8"
2228
NNlib = "0.9.30"
2329
Random = "1.10"
30+
Reactant = "0.2.129"
2431
WeightInitializers = "1"
2532
julia = "1.10"

ext/NeuralOperatorsReactantExt.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
module NeuralOperatorsReactantExt
2+
3+
using NeuralOperators: NeuralOperators
4+
using Reactant: Reactant, TracedRNumber
5+
6+
NeuralOperators.unwrapped_eltype(x::TracedRNumber) = Reactant.unwrapped_eltype(x)
7+
8+
end

src/models/fno.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,22 @@ function FourierNeuralOperator(
6464
),
6565
)
6666
end
67+
68+
function FourierNeuralOperator(
69+
modes::Dims,
70+
in_channels::Integer,
71+
out_channels::Integer,
72+
hidden_channels::Integer;
73+
num_layers::Integer=4,
74+
lifting_channel_ratio::Integer=2,
75+
projection_channel_ratio::Integer=2,
76+
positional_embedding::Union{Symbol,AbstractLuxLayer}=:grid, # :grid | :none
77+
activation=gelu,
78+
use_channel_mlp::Bool=true,
79+
channel_mlp_dropout_rate::Real=0.0,
80+
channel_mlp_expansion::Real=0.5,
81+
channel_mlp_skip::Symbol=:soft_gating,
82+
fno_skip::Symbol=:linear,
83+
)
84+
return nothing
85+
end

src/utils.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,18 @@ function meshgrid(args::AbstractVector...)
3939
end
4040
end
4141
end
42+
43+
unwrapped_eltype(x) = eltype(x)
44+
45+
function decomposed_activation(f::F, x::Number) where {F}
46+
unwrapped_eltype(x) <: Complex && return Complex(f(real(x)), f(imag(x)))
47+
return f(x)
48+
end
49+
50+
apply_complex((rfn, ifn), x::Number) = apply_complex(rfn, ifn, x)
51+
function apply_complex(rfn, ifn, x::Number)
52+
@assert unwrapped_eltype(x) <: Complex "Expected a complex number, got \
53+
$(unwrapped_eltype(x))"
54+
rl, img = real(x), imag(x)
55+
return Complex(rfn(rl) - ifn(img), rfn(img) + ifn(rl))
56+
end

0 commit comments

Comments
 (0)