Skip to content

Commit 58d29f9

Browse files
authored
feat: closer feature parity with neuraloperators (#63)
* feat: grid embedding layer * feat: more support for complex * feat: ComplexDecomposedLayer * feat: soft gating * feat: use fftshift * chore: bump min Reactant * fix: cleanup printing * fix: make shift into an option * feat: constant grid in IR * feat: generalize OperatorKernel * fix: channel mlp * fix: rename * feat: more options * chore: run fmt * test: skip some tests * ci: env vars
1 parent da5454a commit 58d29f9

File tree

15 files changed

+421
-67
lines changed

15 files changed

+421
-67
lines changed

.buildkite/documentation.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@ steps:
1919
agents:
2020
queue: "juliagpu"
2121
cuda: "*"
22-
env:
23-
DATADEPS_ALWAYS_ACCEPT: true
2422
if: build.message !~ /\[skip docs\]/ && !build.pull_request.draft
2523
timeout_in_minutes: 1000
2624

2725
env:
26+
JULIA_DEBUG: "Documenter"
27+
DATADEPS_ALWAYS_ACCEPT: true
2828
SECRET_CODECOV_TOKEN: "vn5M+4wSwUFje6fl6UB/Q/rTmLHu3OlCCMgoPOXPQHYpLZTLz2hOHsV44MadAnxw8MsNVxLKZlXBKqP3IydU9gUfV7QUBtnvbUmIvgUHbr+r0bVaIVVhw6cnd0s8/b+561nU483eRJd35bjYDOlO+V5eDxkbdh/0bzLefXNXy5+ALxsBYzsp75Sx/9nuREfRqWwU6S45mne2ZlwCDpZlFvBDXQ2ICKYXpA45MpxhW9RuqfpQdi6sSR6I/HdHkV2cuJO99dqqh8xfUy6vWPC/+HUVrn9ETsrXtayX1MX3McKj869htGICpR8vqd311HTONYVprH2AN1bJqr5MOIZ8Xg==;U2FsdGVkX1+W55pTI7zq+NwYrbK6Cgqe+Gp8wMCmXY+W10aXTB0bS6zshiDYSQ1Y3piT91xFyNhS+9AsajY0yQ=="
2929
SECRET_DOCUMENTER_KEY: "hnN1QYai+CvEgtp2M9S9ZxK4xBthWAhx8snqR7MIKk3cqG7qHMsTd49KMSYWLXEsrsMbI2N4OLjNE3SgVZwyfuHop6x3ZqwVWD0OtldK6QvL/c+lmpzYMkIRCqKfXOKm9cQTWrr/eKzEh0uEUEdzCH7vrqLP7ywf5F9BMJbRmHW2f4cE1RKiG5cWVtgyO2OyRg7CxYrWgg0eKba+Vh8dLj1xL1Gs1XpLr7TDYRfSo7VMKYdkvB7Os7hmDN04YwzRFvGQBlroBaiC7KgbI5kZcL33R/whCHUYneSsXozGVyJGGGUbdxCaSslJ3p3Vb3vpnn+eAn0GyriWF/TaKIUlpQ==;U2FsdGVkX1/7Mqz8JnY+1EpErS6yTG4GfBJTNI+3WLtOuhyr+6SWZgIBCoitfIVSqvAws+RXVY09rG5Vsnvperr/lAzuBINsfOdIkSnS4HjfRHrKXGWzwzdTawPw1tn0wJa9h4RNahoBQk7Qs3clEb150AlomRJ//UAOW6UgD/nvr8TASjc5aHPkZodBogtMn93ti7hVRlUCNZAk58zpyXEsJorX9RT4g473y1aJp+CMd+Momo7/eljjwq3uCFyWpBnIwmipal4myUszDnSHuEQHcgh74ghJ1LkBjKcLUrRj06wKM/PfF23/P7h/KZCeryvghPSczKCJHx3wjH10DUuFK4maX9Yli/vg0SBq7/crWxNPXS5GmHYg/NBD2pgbe9qFvSCX4xJvkDl2ixHkNu5wQCCCTixtLfv00/XKWT5nh+PLSapQkY7ouesLFk8adpLAESPlX4pjSbljy0BrcHeFXFclFerq+Ia4StCLehWI4k2o/gfUEx7cnNN5ZKQpcYtQVOKRQriq8ri6dz+znQTmWqIM0xf+ZfWHy1OA3UXi269b0jdIlRQNtAjWhah9WWBaCiPY63/bgW54nIh1TVS1YaHzx2elWtFzHNpeTxIoibmltUivCE9r9Mk+YpfYUHJVUP9DwaLXcCiMZoLPNU4XWAzTejkpKJX/TdVFudLm2WNsCWi7nzjTfWSwlUOjrccLy9JVAx186gMlFhf5rHWeNFydBgEHBUig0FWkASwSPyv5iwD63NoL8UsTCRTNKMDtzn1XuCNxbmxgsbNO5bl+mcCDB/ZJuXexw3Ek04Ur4eZNyhf5qcn2cUzm76Jqi2vHSnY8JS9U9sSCdOd1gHFhxTOmO/OWh4jP5GKtWSWH23T2UOTS95daw1z2+GwLLP3R8chnFUCWvnJ+SUAHKoHYvbiT34gOyhcbnmoYTBCuvy4eghq1pVALUooPHOHzq2WsHAI07y6dUfbq55xaQZYiT++e+UisIID4mamITC5IFAQhkXPSAbVo56VrRCNQ+WnxLhzB6HNceAUdJCvH0IKcC9NqvfhCNioQ0EktfNla4G1BydJwgqexu8Tm4f5dl9e9V65hTn05yRYY2QN4hPCsw9Fr2NOrGfudCp5PEvA9TNuaobXvR5lSHPLL19ImbkCR1FGndrpwzWHR/0xjyS4lfdo3cW6SGVK33WZzOouM2ShGZc5Ma4yJdcPr1g0RsnWMGVX0A2gzJD9EL9RtNLgsDERnRLkEJgvhgVGXI+nmXadWzmn+plWgGEjSCLeZH1a4msFmcco3NFj0c6i5g+6cG09yILwhUTX+iegwm71r71pbPsMBk/hz/0ntVdklqFfCUZOC1n+MGJJmFfHmHAYHXSX5RmJhmvZG59g+oBEKxCh4VaIsSjLbopR4/TxS/uHdq3JJMOHXpgTQDack0qz3iEZFt+tKvKmah15PijdCv/Y8SHlN+txgSpW8+bQpxaMFq2li+/vApT18hBLjgGoMRv+I4ZmEmLUN1C+BzkCBTtErn4REQn4aKQmqlKC0jhQ5jatR5/Olpa1msiNjVyX0l4RBZFmYlpRY0KxKuVyguQtBF2qaxP2Lgrf2lp7tI7bK1z2MMBOxRTKHYs/Jt77mL/oeOOd2263EuZ2ku6K/iaaEY7jOCN7Du60UYVh0ePPorCTICHJ/arKLQ/uhh/d6Lo71bkj72B7CG75KGjZOg3+FRsFI0BLVjqvEJasR3apRCeipt3iMqmWc67wDp6DMXGCb/LlZh+oSfXtTY1VuMK4L4qcPl15Lf9H2xIt7N1YR6riq5R2EAokqQtazfxMfIk8+R25674O2Jksy0HnkIQ/H1/r2vaQYFzB0KVl671RdzHVYpfSMQQUcQtbzJHHsw5OIAPRwGd2RxFwW3GieKPYNS1u01TW+7z7Tq80qCOqewf+L+4ckbgYiiDDux+lzEK+PowtyOnQR1sjw7PE+BCPqF4Uc7b65EItFacseBD3mNVRXFUSG29rehgU4OrBjS7uUjn43QoMfonVCRhiLxydrfgGTdXW0g3qg3NYRsNeX3kP7b8MrzJ1zhPMZCca0YPbLeZ6Ab8lSc/GeJgrJ1m6ollE5yC2YisJMvNyJU525A7pReSk7u77nd1HTcddvngn8glJjcPn5/2aDuPuP4ZSEt67QLUaxVXce83XMY3FqiDngPhbJYSwGb9dlDizu9ph6a/SgZQyVpMBVL76RK9xropTUFDfDf5SOjukv5VWC2klIh8Fc4v9ur6otP5kx0MYV4ndlCoU6CyfEaQIQKXyZIIOWEPxyBvy8GnOSyh+5E82rJ50+PHkU9+iEJLBaZENnHlDJA30ANHSLF33t3CAxlqOvhFC5hx+BWCbK6R92mITBJOmo58+Fcz1e7u6SRLhOrp5rO+t3z6eB7SuVr0j16Tn4CJNGiYVGmhDjA1NfXoI7uEUkLVGuhra0eFzUCH8c+ot1pamb35E0knzaXy7azX5nuRXgt/491L0f9uc20ZqsrP63JtRQO574sbHxM6SnBWcIZSEEhAGRf4busL7DqjfHSFVx8TJH4FO0fCteCveJcxDoc94Wj7J+u5rUQBsRLDQ24XPjSgflqfhyU0lNKXuDWUSJPzqVuUM3fnYCsCdSd4/Y23t5k5dL68wOj3tBEunpGKltzG4zXjopay0VSkpty0EkaQbejvgXWnZv2G/SZlH/dW0fqp4a/M824kuzh9liF34fcV1uReoc+SKXs7iLYHgDXXajWy40HqvQR1NxtD17eIBvfhsAeKM3WR6YqudVkl4DYXU2hGM7L7znex2JAVs1thOtzBD6DbTKtlEciTs9O2B++22XnSmrusweKE68/8P0FN7DKA2oNQCig9J7BMT3pWA2jtoeq5sZU3vgvzlCLELefO4+YUWHY436znIp773oE8+IDdbxwLqKQfRI7zQAUGXNeUNEzAIla51zSARIFocslZfmF8ugx3NDBNqd6cjU+gk4qVAmUZW/Sw7SM+aPbC6hJPrDCfsMYhG7CNO5DZNBSO6Hhh0GSfl/zuB4Lzzg6xJuCfupg/A9DaUMfaIacYcsP7TjqVf07SbiPIXhlrFvuv5d43tkGBQdAQ/dNtpm2R5//u5Fi/mOlELiJBV5aLI2wZn2pGJRuOa6tb8GHSW3eLmv7RlrvBUEXutrfnevzUFu73Ue6jArVLhNQBhq64GhDEMQQvXH8QM/C0aTL8/IhFx80+kESUdnNa1OW7bNrBhOxTJJ4ho2JvenqVP44UQjJKOQGTjD5LSAqy9mkzK73lzRBVTNhpI2MTbKEnfXxN0wCce7cdpzJAUOTYF4fa4A9kkL3dqq/nhvkfzFY7lJbGI0J6hwMa9lWPwhB+wBU5hZu7s9eLDsfMoPE3vti8p8e91SgnPs+QtzPoe72GfBbJe8QQNnGeGW1aI2+qF82yarez1hTFaUPX1dgwSAKJS3/RSy1pMm1EPcM50P5At2ervJRJM5VCKrVJ8rvcFKuzfQfcZ/sqCW4l3E8ojJ7+7MAu4Ilfg0C6ZVY2S+rH52jIJtpTqbtGDaC0Zp7QTEbD5Eu1GkIxUmjpg23VGrNi4SnV+3ZbzOue8qTTstQdwzD2gDgb+zXUSkeGgZYYyEJQpGSjEJOPVi6CByqUtdK+Lfd19WRgSFJMmDfupQFUk7D19/d4K7/dj0NcuH+DQBfPvZl4cJkGqHeBXlK8uPN1A+q1S8D2QA7m4BzHAHz0Sy99KDq4spuFlm5t9oYAielqduZHieY4BugcQ1KcwR6pmc2ayEZvVRwfTNnxyodbOQCg4yh/pV4kX40G+tGnwtimaIVzyeIaZhL+iQ0M1+4KJCbKT2NGnk82hUKyrnBYSTPAG6vjRkg3P6KZj6tj4ABpxnymkZnMrUq2gpx/OlF0nl2CV0iCvj0RqZ5kX3NDVDtW8XmXh1sTATEGltfEclYoKMA/c47adE+rj16CiPwInnI/+JReNSjI14UoNn+27n8yQfxb3LbYkDSYkZcV7aysKPLQunMO0f0n6WUc2nprUIwQSO1Ccdhnio+YJs9JdvICN6JDpXfXM3eM6YGoSFv+XUownwfCW+fc1LMEXb4qjlUv7TIZKgXN5cEokKC2ZPpFJxF1TB2B7M3ib3a0orWuL8RsvexU7fpbUWo0+QI0CmiTqkmtsq1QeC7OCUU6Ocmz0iKUcYzcuLc9gfRyxXE8p7GqamK+okZWA0CJcMGvg0UXZcIaaHvUa9FuYLnjEOwvz1V92GGL0epIR9Vm4N6Aa1HzppbDuyY+donj0KBQkDxhix4+sF79LzdAJq1cvuxWW3km8f9CXYRdJ3oOY2nqwQXsi08K+yE+Zv2YVPd6Z1tJfk7Y+KiyhSIsKxDqIY1SaWwgRRYpi8c52BjTtnaIbYjFI="

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,6 @@ examples
66
docs/build
77

88
.CondaPkg
9+
.venv
10+
.python-version
11+
uv.lock

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NeuralOperators"
22
uuid = "ea5c82af-86e5-48da-8ee1-382d6ad7af4b"
33
authors = ["Avik Pal <[email protected]>"]
4-
version = "0.6.0"
4+
version = "0.6.1"
55

66
[deps]
77
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"

docs/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,4 @@ NeuralOperators = "0.6"
2929
Optimisers = "0.4"
3030
Printf = "1.10"
3131
PythonCall = "0.9.23"
32-
Reactant = "0.2.127"
32+
Reactant = "0.2.130"

docs/src/api.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ OperatorConv
1515
SpectralConv
1616
OperatorKernel
1717
SpectralKernel
18+
GridEmbedding
19+
ComplexDecomposedLayer
20+
SoftGating
1821
```
1922

2023
## Transform API

docs/src/tutorials/burgers_deeponet.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,9 +131,10 @@ draw(
131131
:data_sequence => L"u(x)";
132132
color=:label => "",
133133
layout=:sequence => nonnumeric,
134+
linestyle=:label => "",
134135
) *
135-
visual(Lines),
136-
scales(; Color=(; palette=:tab10));
136+
visual(Lines; linewidth=4),
137+
scales(; Color=(; palette=:tab10), LineStyle = (; palette = [:solid, :dash]));
137138
figure=(;
138139
size=(1024, 1024),
139140
title="Using DeepONet to solve the Burgers equation",

docs/src/tutorials/burgers_fno.md

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,7 @@ x_data = reshape(T.(collect(read(file, "a")[1:N, 1:Δsamples:end])), N, :)
4646
y_data = reshape(T.(collect(read(file, "u")[1:N, 1:Δsamples:end])), N, :)
4747
close(file)
4848
49-
x_data = hcat(
50-
repeat(reshape(collect(T, range(0, 1; length=grid_size)), :, 1, 1), 1, 1, N),
51-
reshape(permutedims(x_data, (2, 1)), grid_size, 1, N)
52-
);
49+
x_data = reshape(permutedims(x_data, (2, 1)), grid_size, 1, N);
5350
y_data = reshape(permutedims(y_data, (2, 1)), grid_size, 1, N);
5451
```
5552

@@ -62,9 +59,7 @@ const cdev = cpu_device()
6259
const xdev = reactant_device(; force=true)
6360
6461
fno = FourierNeuralOperator(
65-
gelu;
66-
chs = (2, 32, 32, 32, 1),
67-
modes = (16,)
62+
(16,), 2, 1, 32; activation=gelu, stabilizer=tanh, shift=true
6863
)
6964
ps, st = Lux.setup(Random.default_rng(), fno) |> xdev;
7065
```
@@ -74,13 +69,16 @@ ps, st = Lux.setup(Random.default_rng(), fno) |> xdev;
7469
```@example burgers_fno
7570
dataloader = DataLoader((x_data, y_data); batchsize=128, shuffle=true) |> xdev;
7671
77-
function train_model!(model, ps, st, dataloader; epochs=5000)
72+
function train_model!(model, ps, st, dataloader; epochs=1000)
7873
train_state = Training.TrainState(model, ps, st, Adam(0.0001f0))
7974
80-
for epoch in 1:epochs, data in dataloader
81-
(_, loss, _, train_state) = Training.single_train_step!(
82-
AutoEnzyme(), MAELoss(), data, train_state; return_gradients=Val(false)
83-
)
75+
for epoch in 1:epochs
76+
loss = -Inf
77+
for data in dataloader
78+
(_, loss, _, train_state) = Training.single_train_step!(
79+
AutoEnzyme(), MAELoss(), data, train_state; return_gradients=Val(false)
80+
)
81+
end
8482
8583
if epoch % 100 == 1 || epoch == epochs
8684
@printf("Epoch %d: loss = %.6e\n", epoch, loss)
@@ -90,7 +88,7 @@ function train_model!(model, ps, st, dataloader; epochs=5000)
9088
return train_state.parameters, train_state.states
9189
end
9290
93-
(ps_trained, st_trained) = train_model!(fno, ps, st, dataloader)
91+
ps_trained, st_trained = train_model!(fno, ps, st, dataloader)
9492
nothing #hide
9593
```
9694

@@ -104,7 +102,7 @@ AoG.set_aog_theme!()
104102
x_data_dev = x_data |> xdev;
105103
y_data_dev = y_data |> xdev;
106104
107-
grid = x_data[:, 1, :]
105+
grid = range(0, 1; length=grid_size)
108106
pred = first(
109107
Reactant.with_config(;
110108
convolution_precision=PrecisionConfig.HIGH,
@@ -116,7 +114,7 @@ pred = first(
116114
117115
data_sequence, sequence, repeated_grid, label = Float32[], Int[], Float32[], String[]
118116
for i in 1:16
119-
append!(repeated_grid, vcat(grid[:, i], grid[:, i]))
117+
append!(repeated_grid, repeat(grid, 2))
120118
append!(sequence, repeat([i], grid_size * 2))
121119
append!(label, repeat(["Ground Truth"], grid_size))
122120
append!(label, repeat(["Predictions"], grid_size))
@@ -135,7 +133,7 @@ draw(
135133
linestyle=:label => "",
136134
) *
137135
visual(Lines; linewidth=4),
138-
scales(; Color=(; palette=:tab10), LineStyle = (; palette = [:solid, :dash, :dot]));
136+
scales(; Color=(; palette=:tab10), LineStyle = (; palette = [:solid, :dash]));
139137
figure=(;
140138
size=(1024, 1024),
141139
title="Using FNO to solve the Burgers equation",

src/NeuralOperators.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
module NeuralOperators
22

3-
using AbstractFFTs: rfft, irfft
3+
using AbstractFFTs: fft, rfft, ifft, irfft, fftshift
44
using ConcreteStructs: @concrete
55
using Random: Random, AbstractRNG
66

7-
using Lux: Lux, Chain, Dense, Conv, Parallel, NoOpLayer, WrappedFunction
7+
using Lux: Lux, Chain, Dense, Conv, Parallel, NoOpLayer, WrappedFunction, Scale
88
using LuxCore: LuxCore, AbstractLuxLayer, AbstractLuxWrapperLayer
99
using LuxLib: fast_activation!!
1010
using NNlib: NNlib, batched_mul, pad_constant, gelu
@@ -21,6 +21,7 @@ include("models/nomad.jl")
2121

2222
export FourierTransform
2323
export SpectralConv, OperatorConv, SpectralKernel, OperatorKernel
24+
export GridEmbedding, ComplexDecomposedLayer, SoftGating
2425

2526
export FourierNeuralOperator
2627
export DeepONet

src/layers.jl

Lines changed: 162 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,13 @@ julia> OperatorConv(2 => 5, (16,), FourierTransform{ComplexF32}((16,)));
3030
init_weight
3131
end
3232

33+
function Base.show(io::IO, layer::OperatorConv)
34+
print(io, "OperatorConv(")
35+
print(io, layer.in_chs, " => ", layer.out_chs, ", ")
36+
print(io, layer.tform, ")")
37+
return nothing
38+
end
39+
3340
function LuxCore.initialparameters(rng::AbstractRNG, layer::OperatorConv)
3441
in_chs, out_chs = layer.in_chs, layer.out_chs
3542
scale = real(one(eltype(layer.tform))) / (in_chs * out_chs)
@@ -54,20 +61,17 @@ function OperatorConv(
5461
end
5562

5663
function (conv::OperatorConv)(x::AbstractArray{T,N}, ps, st) where {T,N}
57-
return operator_conv(x, conv.tform, ps.weight), st
58-
end
59-
60-
function operator_conv(x, tform::AbstractTransform, weights)
61-
x_t = transform(tform, x)
62-
x_tr = truncate_modes(tform, x_t)
63-
x_p = apply_pattern(x_tr, weights)
64+
x_t = transform(conv.tform, x)
65+
x_tr = truncate_modes(conv.tform, x_t)
66+
x_p = apply_pattern(x_tr, ps.weight)
6467

6568
pad_dims = size(x_t)[1:(end - 2)] .- size(x_p)[1:(end - 2)]
6669
x_padded = pad_constant(
6770
x_p, expand_pad_dims(pad_dims), false; dims=ntuple(identity, ndims(x_p) - 2)
6871
)
72+
out = inverse(conv.tform, x_padded, x)
6973

70-
return inverse(tform, x_padded, size(x))
74+
return out, st
7175
end
7276

7377
"""
@@ -83,8 +87,10 @@ julia> SpectralConv(2 => 5, (16,));
8387
8488
```
8589
"""
86-
function SpectralConv(ch::Pair{<:Integer,<:Integer}, modes::Dims; kwargs...)
87-
return OperatorConv(ch, modes, FourierTransform{ComplexF32}(modes); kwargs...)
90+
function SpectralConv(
91+
ch::Pair{<:Integer,<:Integer}, modes::Dims; shift::Bool=false, kwargs...
92+
)
93+
return OperatorConv(ch, modes, FourierTransform{ComplexF32}(modes, shift); kwargs...)
8894
end
8995

9096
"""
@@ -119,17 +125,72 @@ function OperatorKernel(
119125
modes::Dims{N},
120126
transform::AbstractTransform,
121127
act=identity;
128+
stabilizer=identity,
129+
complex_data::Bool=false,
130+
fno_skip::Symbol=:linear,
131+
channel_mlp_skip::Symbol=:soft_gating,
132+
use_channel_mlp::Bool=false,
133+
channel_mlp_expansion::Real=0.5,
122134
kwargs...,
123135
) where {N}
136+
in_chs, out_chs = ch
137+
138+
complex_data && (stabilizer = Base.Fix1(decomposed_activation, stabilizer))
139+
stabilizer = WrappedFunction(Base.BroadcastFunction(stabilizer))
140+
141+
activation = complex_data ? Base.Fix1(decomposed_activation, act) : act
142+
143+
conv_layer = OperatorConv(ch, modes, transform; kwargs...)
144+
145+
fno_skip_layer = __fno_skip_connection(in_chs, out_chs, N, false, fno_skip)
146+
complex_data && (fno_skip_layer = ComplexDecomposedLayer(fno_skip_layer))
147+
148+
if use_channel_mlp
149+
channel_mlp_hidden_channels = round(Int, out_chs * channel_mlp_expansion)
150+
channel_mlp = Chain(
151+
Conv(ntuple(Returns(1), N), out_chs => channel_mlp_hidden_channels),
152+
Conv(ntuple(Returns(1), N), channel_mlp_hidden_channels => out_chs),
153+
)
154+
complex_data && (channel_mlp = ComplexDecomposedLayer(channel_mlp))
155+
156+
channel_mlp_skip_layer = __fno_skip_connection(
157+
in_chs, out_chs, N, false, channel_mlp_skip
158+
)
159+
complex_data &&
160+
(channel_mlp_skip_layer = ComplexDecomposedLayer(channel_mlp_skip_layer))
161+
162+
return OperatorKernel(
163+
Parallel(
164+
Fix1(add_act, activation),
165+
Chain(
166+
Parallel(
167+
Fix1(add_act, act), fno_skip_layer, Chain(; stabilizer, conv_layer)
168+
),
169+
channel_mlp,
170+
),
171+
channel_mlp_skip_layer,
172+
),
173+
)
174+
end
175+
124176
return OperatorKernel(
125-
Parallel(
126-
Fix1(add_act, act),
127-
Conv(ntuple(one, N), ch),
128-
OperatorConv(ch, modes, transform; kwargs...),
129-
),
177+
Parallel(Fix1(add_act, act), fno_skip_layer, Chain(; stabilizer, conv_layer))
130178
)
131179
end
132180

181+
function __fno_skip_connection(in_chs, out_chs, n_dims, use_bias, skip_type)
182+
if skip_type == :linear
183+
return Conv(ntuple(Returns(1), n_dims), in_chs => out_chs; use_bias)
184+
elseif skip_type == :soft_gating
185+
@assert in_chs == out_chs "For soft gating, in_chs must equal out_chs"
186+
return SoftGating(out_chs, n_dims; use_bias)
187+
elseif skip_type == :none
188+
return NoOpLayer()
189+
else
190+
error("Invalid skip_type: $(skip_type)")
191+
end
192+
end
193+
133194
"""
134195
SpectralKernel(args...; kwargs...)
135196
@@ -143,6 +204,90 @@ julia> SpectralKernel(2 => 5, (16,));
143204
144205
```
145206
"""
146-
function SpectralKernel(ch::Pair{<:Integer,<:Integer}, modes::Dims, act=identity; kwargs...)
147-
return OperatorKernel(ch, modes, FourierTransform{ComplexF32}(modes), act; kwargs...)
207+
function SpectralKernel(
208+
ch::Pair{<:Integer,<:Integer}, modes::Dims, act=identity; shift::Bool=false, kwargs...
209+
)
210+
return OperatorKernel(
211+
ch, modes, FourierTransform{ComplexF32}(modes, shift), act; kwargs...
212+
)
213+
end
214+
215+
"""
216+
GridEmbedding(grid_boundaries::Vector{<:Tuple{<:Real,<:Real}})
217+
218+
Appends a uniform grid embedding to the input data along the penultimate dimension.
219+
"""
220+
@concrete struct GridEmbedding <: AbstractLuxLayer
221+
grid_boundaries <: Vector{<:Tuple{<:Real,<:Real}}
222+
end
223+
224+
function Base.show(io::IO, layer::GridEmbedding)
225+
return print(io, "GridEmbedding(", join(layer.grid_boundaries, ", "), ")")
226+
end
227+
228+
function (layer::GridEmbedding)(x::AbstractArray{T,N}, ps, st) where {T,N}
229+
@assert length(layer.grid_boundaries) == N - 2
230+
231+
grid = meshgrid(map(enumerate(layer.grid_boundaries)) do (i, (min, max))
232+
range(T(min), T(max); length=size(x, i))
233+
end...)
234+
235+
grid = repeat(
236+
Lux.Utils.contiguous(reshape(grid, size(grid)..., 1)),
237+
ntuple(Returns(1), N - 1)...,
238+
size(x, N),
239+
)
240+
return cat(grid, x; dims=N - 1), st
241+
end
242+
243+
"""
244+
ComplexDecomposedLayer(layer::AbstractLuxLayer)
245+
246+
Decomposes complex activations into real and imaginary parts and applies the given layer to
247+
each component separately, and then recombines the real and imaginary parts.
248+
"""
249+
@concrete struct ComplexDecomposedLayer <: AbstractLuxWrapperLayer{:layer}
250+
layer <: AbstractLuxLayer
251+
end
252+
253+
function LuxCore.initialparameters(rng::AbstractRNG, layer::ComplexDecomposedLayer)
254+
return (;
255+
real=LuxCore.initialparameters(rng, layer.layer),
256+
imag=LuxCore.initialparameters(rng, layer.layer),
257+
)
258+
end
259+
260+
function LuxCore.initialstates(rng::AbstractRNG, layer::ComplexDecomposedLayer)
261+
return (;
262+
real=LuxCore.initialstates(rng, layer.layer),
263+
imag=LuxCore.initialstates(rng, layer.layer),
264+
)
265+
end
266+
267+
function (layer::ComplexDecomposedLayer)(x::AbstractArray{T,N}, ps, st) where {T,N}
268+
rx = real.(x)
269+
ix = imag.(x)
270+
271+
rfn_rx, st_real = layer.layer(rx, ps.real, st.real)
272+
rfn_ix, st_real = layer.layer(ix, ps.real, st_real)
273+
274+
ifn_rx, st_imag = layer.layer(rx, ps.imag, st.imag)
275+
ifn_ix, st_imag = layer.layer(ix, ps.imag, st_imag)
276+
277+
out = Complex.(rfn_rx .- ifn_ix, rfn_ix .+ ifn_rx)
278+
return out, (; real=st_real, imag=st_imag)
279+
end
280+
281+
"""
282+
SoftGating(chs::Integer, ndims::Integer; kwargs...)
283+
284+
Constructs a wrapper over `Scale` with `dims = (ntuple(Returns(1), ndims)..., chs)`. All
285+
keyword arguments are passed to the `Scale` constructor.
286+
"""
287+
@concrete struct SoftGating <: AbstractLuxWrapperLayer{:layer}
288+
layer <: Scale
289+
end
290+
291+
function SoftGating(chs::Integer, ndims::Integer; kwargs...)
292+
return SoftGating(Scale(ntuple(Returns(1), ndims)..., chs; kwargs...))
148293
end

0 commit comments

Comments
 (0)