@@ -30,6 +30,13 @@ julia> OperatorConv(2 => 5, (16,), FourierTransform{ComplexF32}((16,)));
     init_weight
 end
 
+function Base.show(io::IO, layer::OperatorConv)
+    print(io, "OperatorConv(")
+    print(io, layer.in_chs, " => ", layer.out_chs, ", ")
+    print(io, layer.tform, ")")
+    return nothing
+end
+
 function LuxCore.initialparameters(rng::AbstractRNG, layer::OperatorConv)
     in_chs, out_chs = layer.in_chs, layer.out_chs
     scale = real(one(eltype(layer.tform))) / (in_chs * out_chs)
@@ -54,20 +61,17 @@ function OperatorConv(
 end
 
 function (conv::OperatorConv)(x::AbstractArray{T,N}, ps, st) where {T,N}
-    return operator_conv(x, conv.tform, ps.weight), st
-end
-
-function operator_conv(x, tform::AbstractTransform, weights)
-    x_t = transform(tform, x)
-    x_tr = truncate_modes(tform, x_t)
-    x_p = apply_pattern(x_tr, weights)
+    x_t = transform(conv.tform, x)
+    x_tr = truncate_modes(conv.tform, x_t)
+    x_p = apply_pattern(x_tr, ps.weight)
 
     pad_dims = size(x_t)[1:(end - 2)] .- size(x_p)[1:(end - 2)]
     x_padded = pad_constant(
         x_p, expand_pad_dims(pad_dims), false; dims=ntuple(identity, ndims(x_p) - 2)
     )
+    out = inverse(conv.tform, x_padded, x)
 
-    return inverse(tform, x_padded, size(x))
+    return out, st
 end
 
 """
@@ -83,8 +87,10 @@ julia> SpectralConv(2 => 5, (16,));
 
 ```
 """
-function SpectralConv(ch::Pair{<:Integer,<:Integer}, modes::Dims; kwargs...)
-    return OperatorConv(ch, modes, FourierTransform{ComplexF32}(modes); kwargs...)
+function SpectralConv(
+    ch::Pair{<:Integer,<:Integer}, modes::Dims; shift::Bool=false, kwargs...
+)
+    return OperatorConv(ch, modes, FourierTransform{ComplexF32}(modes, shift); kwargs...)
 end
 
 """
@@ -119,17 +125,72 @@ function OperatorKernel(
     modes::Dims{N},
     transform::AbstractTransform,
     act=identity;
+    stabilizer=identity,
+    complex_data::Bool=false,
+    fno_skip::Symbol=:linear,
+    channel_mlp_skip::Symbol=:soft_gating,
+    use_channel_mlp::Bool=false,
+    channel_mlp_expansion::Real=0.5,
     kwargs...,
 ) where {N}
+    in_chs, out_chs = ch
+
+    complex_data && (stabilizer = Base.Fix1(decomposed_activation, stabilizer))
+    stabilizer = WrappedFunction(Base.BroadcastFunction(stabilizer))
+
+    activation = complex_data ? Base.Fix1(decomposed_activation, act) : act
+
+    conv_layer = OperatorConv(ch, modes, transform; kwargs...)
+
+    fno_skip_layer = __fno_skip_connection(in_chs, out_chs, N, false, fno_skip)
+    complex_data && (fno_skip_layer = ComplexDecomposedLayer(fno_skip_layer))
+
+    if use_channel_mlp
+        channel_mlp_hidden_channels = round(Int, out_chs * channel_mlp_expansion)
+        channel_mlp = Chain(
+            Conv(ntuple(Returns(1), N), out_chs => channel_mlp_hidden_channels),
+            Conv(ntuple(Returns(1), N), channel_mlp_hidden_channels => out_chs),
+        )
+        complex_data && (channel_mlp = ComplexDecomposedLayer(channel_mlp))
+
+        channel_mlp_skip_layer = __fno_skip_connection(
+            in_chs, out_chs, N, false, channel_mlp_skip
+        )
+        complex_data &&
+            (channel_mlp_skip_layer = ComplexDecomposedLayer(channel_mlp_skip_layer))
+
+        return OperatorKernel(
+            Parallel(
+                Fix1(add_act, activation),
+                Chain(
+                    Parallel(
+                        Fix1(add_act, act), fno_skip_layer, Chain(; stabilizer, conv_layer)
+                    ),
+                    channel_mlp,
+                ),
+                channel_mlp_skip_layer,
+            ),
+        )
+    end
+
     return OperatorKernel(
-        Parallel(
-            Fix1(add_act, act),
-            Conv(ntuple(one, N), ch),
-            OperatorConv(ch, modes, transform; kwargs...),
-        ),
+        Parallel(Fix1(add_act, act), fno_skip_layer, Chain(; stabilizer, conv_layer))
     )
 end
 
+function __fno_skip_connection(in_chs, out_chs, n_dims, use_bias, skip_type)
+    if skip_type == :linear
+        return Conv(ntuple(Returns(1), n_dims), in_chs => out_chs; use_bias)
+    elseif skip_type == :soft_gating
+        @assert in_chs == out_chs "For soft gating, in_chs must equal out_chs"
+        return SoftGating(out_chs, n_dims; use_bias)
+    elseif skip_type == :none
+        return NoOpLayer()
+    else
+        error("Invalid skip_type: $(skip_type)")
+    end
+end
+
 """
     SpectralKernel(args...; kwargs...)
 
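Annotation (not part of the diff): `OperatorKernel` gains a stabilizer, optional complex-data handling, configurable skip connections, and an optional channel MLP. A hedged construction sketch using the keyword names from this diff (the chosen values are illustrative only):

```julia
# default: 1x1-conv (:linear) FNO skip, no channel MLP
OperatorKernel(2 => 5, (16,), FourierTransform{ComplexF32}((16,)), gelu)

# soft-gating skips require in_chs == out_chs; channel MLP width = out_chs * expansion
OperatorKernel(
    8 => 8, (16,), FourierTransform{ComplexF32}((16,)), gelu;
    use_channel_mlp=true, channel_mlp_expansion=0.5,
    fno_skip=:soft_gating, channel_mlp_skip=:soft_gating,
)
```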
@@ -143,6 +204,90 @@ julia> SpectralKernel(2 => 5, (16,));
 
 ```
 """
-function SpectralKernel(ch::Pair{<:Integer,<:Integer}, modes::Dims, act=identity; kwargs...)
-    return OperatorKernel(ch, modes, FourierTransform{ComplexF32}(modes), act; kwargs...)
+function SpectralKernel(
+    ch::Pair{<:Integer,<:Integer}, modes::Dims, act=identity; shift::Bool=false, kwargs...
+)
+    return OperatorKernel(
+        ch, modes, FourierTransform{ComplexF32}(modes, shift), act; kwargs...
+    )
+end
+
+"""
+    GridEmbedding(grid_boundaries::Vector{<:Tuple{<:Real,<:Real}})
+
+Appends a uniform grid embedding to the input data along the penultimate dimension.
+"""
+@concrete struct GridEmbedding <: AbstractLuxLayer
+    grid_boundaries <: Vector{<:Tuple{<:Real,<:Real}}
+end
+
+function Base.show(io::IO, layer::GridEmbedding)
+    return print(io, "GridEmbedding(", join(layer.grid_boundaries, ", "), ")")
+end
+
+function (layer::GridEmbedding)(x::AbstractArray{T,N}, ps, st) where {T,N}
+    @assert length(layer.grid_boundaries) == N - 2
+
+    grid = meshgrid(map(enumerate(layer.grid_boundaries)) do (i, (min, max))
+        range(T(min), T(max); length=size(x, i))
+    end...)
+
+    grid = repeat(
+        Lux.Utils.contiguous(reshape(grid, size(grid)..., 1)),
+        ntuple(Returns(1), N - 1)...,
+        size(x, N),
+    )
+    return cat(grid, x; dims=N - 1), st
+end
+
+"""
+    ComplexDecomposedLayer(layer::AbstractLuxLayer)
+
+Decomposes complex activations into real and imaginary parts, applies the given layer to
+each component separately, and then recombines the results into a complex output.
+"""
+@concrete struct ComplexDecomposedLayer <: AbstractLuxWrapperLayer{:layer}
+    layer <: AbstractLuxLayer
+end
+
+function LuxCore.initialparameters(rng::AbstractRNG, layer::ComplexDecomposedLayer)
+    return (;
+        real=LuxCore.initialparameters(rng, layer.layer),
+        imag=LuxCore.initialparameters(rng, layer.layer),
+    )
+end
+
+function LuxCore.initialstates(rng::AbstractRNG, layer::ComplexDecomposedLayer)
+    return (;
+        real=LuxCore.initialstates(rng, layer.layer),
+        imag=LuxCore.initialstates(rng, layer.layer),
+    )
+end
+
+function (layer::ComplexDecomposedLayer)(x::AbstractArray{T,N}, ps, st) where {T,N}
+    rx = real.(x)
+    ix = imag.(x)
+
+    rfn_rx, st_real = layer.layer(rx, ps.real, st.real)
+    rfn_ix, st_real = layer.layer(ix, ps.real, st_real)
+
+    ifn_rx, st_imag = layer.layer(rx, ps.imag, st.imag)
+    ifn_ix, st_imag = layer.layer(ix, ps.imag, st_imag)
+
+    out = Complex.(rfn_rx .- ifn_ix, rfn_ix .+ ifn_rx)
+    return out, (; real=st_real, imag=st_imag)
+end
+
+"""
+    SoftGating(chs::Integer, ndims::Integer; kwargs...)
+
+Constructs a wrapper over `Scale` with `dims = (ntuple(Returns(1), ndims)..., chs)`. All
+keyword arguments are passed to the `Scale` constructor.
+"""
+@concrete struct SoftGating <: AbstractLuxWrapperLayer{:layer}
+    layer <: Scale
+end
+
+function SoftGating(chs::Integer, ndims::Integer; kwargs...)
+    return SoftGating(Scale(ntuple(Returns(1), ndims)..., chs; kwargs...))
 end
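Annotation (not part of the diff): a sketch of the three new helper layers in use. Shapes, the wrapped layers, and the surrounding setup are illustrative assumptions, not documented API:

```julia
using Lux, Random

rng = Random.default_rng()

# GridEmbedding adds N-2 coordinate channels: (16, 16, 3, B) -> (16, 16, 5, B)
embed = GridEmbedding([(0.0, 1.0), (0.0, 1.0)])
ps, st = Lux.setup(rng, embed)
y, _ = embed(rand(Float32, 16, 16, 3, 4), ps, st)

# SoftGating: a per-channel learnable Scale, broadcast over the spatial dimensions
gate = SoftGating(5, 2)   # 5 channels, 2 spatial dims -> Scale over dims (1, 1, 5)

# ComplexDecomposedLayer: run a real-valued layer on the real/imaginary parts of a
# complex input and recombine as (R(re) - I(im)) + i*(R(im) + I(re))
clayer = ComplexDecomposedLayer(Dense(5 => 5))
```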