Skip to content

Commit 5187357

Browse files
committed
fix: rename to centered_truncation
1 parent 60e0428 commit 5187357

File tree

7 files changed

+60
-43
lines changed

7 files changed

+60
-43
lines changed

docs/src/tutorials/burgers_deeponet.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,9 +131,10 @@ draw(
131131
:data_sequence => L"u(x)";
132132
color=:label => "",
133133
layout=:sequence => nonnumeric,
134+
linestyle=:label => "",
134135
) *
135-
visual(Lines),
136-
scales(; Color=(; palette=:tab10));
136+
visual(Lines; linewidth=4),
137+
scales(; Color=(; palette=:tab10), LineStyle = (; palette = [:solid, :dash]));
137138
figure=(;
138139
size=(1024, 1024),
139140
title="Using DeepONet to solve the Burgers equation",

docs/src/tutorials/burgers_fno.md

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,7 @@ x_data = reshape(T.(collect(read(file, "a")[1:N, 1:Δsamples:end])), N, :)
4646
y_data = reshape(T.(collect(read(file, "u")[1:N, 1:Δsamples:end])), N, :)
4747
close(file)
4848
49-
x_data = hcat(
50-
repeat(reshape(collect(T, range(0, 1; length=grid_size)), :, 1, 1), 1, 1, N),
51-
reshape(permutedims(x_data, (2, 1)), grid_size, 1, N)
52-
);
49+
x_data = reshape(permutedims(x_data, (2, 1)), grid_size, 1, N);
5350
y_data = reshape(permutedims(y_data, (2, 1)), grid_size, 1, N);
5451
```
5552

@@ -62,9 +59,7 @@ const cdev = cpu_device()
6259
const xdev = reactant_device(; force=true)
6360
6461
fno = FourierNeuralOperator(
65-
gelu;
66-
chs = (2, 32, 32, 32, 1),
67-
modes = (16,)
62+
(16,), 2, 1, 32; activation=gelu, stabilizer=tanh, centered_truncation=true
6863
)
6964
ps, st = Lux.setup(Random.default_rng(), fno) |> xdev;
7065
```
@@ -74,13 +69,16 @@ ps, st = Lux.setup(Random.default_rng(), fno) |> xdev;
7469
```@example burgers_fno
7570
dataloader = DataLoader((x_data, y_data); batchsize=128, shuffle=true) |> xdev;
7671
77-
function train_model!(model, ps, st, dataloader; epochs=5000)
72+
function train_model!(model, ps, st, dataloader; epochs=1000)
7873
train_state = Training.TrainState(model, ps, st, Adam(0.0001f0))
7974
80-
for epoch in 1:epochs, data in dataloader
81-
(_, loss, _, train_state) = Training.single_train_step!(
82-
AutoEnzyme(), MAELoss(), data, train_state; return_gradients=Val(false)
83-
)
75+
for epoch in 1:epochs
76+
loss = -Inf
77+
for data in dataloader
78+
(_, loss, _, train_state) = Training.single_train_step!(
79+
AutoEnzyme(), MAELoss(), data, train_state; return_gradients=Val(false)
80+
)
81+
end
8482
8583
if epoch % 100 == 1 || epoch == epochs
8684
@printf("Epoch %d: loss = %.6e\n", epoch, loss)
@@ -90,7 +88,7 @@ function train_model!(model, ps, st, dataloader; epochs=5000)
9088
return train_state.parameters, train_state.states
9189
end
9290
93-
(ps_trained, st_trained) = train_model!(fno, ps, st, dataloader)
91+
ps_trained, st_trained = train_model!(fno, ps, st, dataloader)
9492
nothing #hide
9593
```
9694

@@ -104,7 +102,7 @@ AoG.set_aog_theme!()
104102
x_data_dev = x_data |> xdev;
105103
y_data_dev = y_data |> xdev;
106104
107-
grid = x_data[:, 1, :]
105+
grid = range(0, 1; length=grid_size)
108106
pred = first(
109107
Reactant.with_config(;
110108
convolution_precision=PrecisionConfig.HIGH,
@@ -116,7 +114,7 @@ pred = first(
116114
117115
data_sequence, sequence, repeated_grid, label = Float32[], Int[], Float32[], String[]
118116
for i in 1:16
119-
append!(repeated_grid, vcat(grid[:, i], grid[:, i]))
117+
append!(repeated_grid, repeat(grid, 2))
120118
append!(sequence, repeat([i], grid_size * 2))
121119
append!(label, repeat(["Ground Truth"], grid_size))
122120
append!(label, repeat(["Predictions"], grid_size))
@@ -135,7 +133,7 @@ draw(
135133
linestyle=:label => "",
136134
) *
137135
visual(Lines; linewidth=4),
138-
scales(; Color=(; palette=:tab10), LineStyle = (; palette = [:solid, :dash, :dot]));
136+
scales(; Color=(; palette=:tab10), LineStyle = (; palette = [:solid, :dash]));
139137
figure=(;
140138
size=(1024, 1024),
141139
title="Using FNO to solve the Burgers equation",

src/layers.jl

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,11 @@ julia> SpectralConv(2 => 5, (16,));
8989
```
9090
"""
9191
function SpectralConv(
92-
ch::Pair{<:Integer,<:Integer}, modes::Dims; shift::Bool=false, kwargs...
92+
ch::Pair{<:Integer,<:Integer}, modes::Dims; centered_truncation::Bool=false, kwargs...
9393
)
94-
return OperatorConv(ch, modes, FourierTransform{ComplexF32}(modes, shift); kwargs...)
94+
return OperatorConv(
95+
ch, modes, FourierTransform{ComplexF32}(modes, centered_truncation); kwargs...
96+
)
9597
end
9698

9799
"""
@@ -129,8 +131,8 @@ function OperatorKernel(
129131
stabilizer=identity,
130132
complex_data::Bool=false,
131133
fno_skip::Symbol=:linear,
132-
channel_mlp_skip::Symbol=:linear,
133-
use_channel_mlp::Bool=true,
134+
channel_mlp_skip::Symbol=:soft_gating,
135+
use_channel_mlp::Bool=false,
134136
channel_mlp_expansion::Real=0.5,
135137
kwargs...,
136138
) where {N}
@@ -205,10 +207,14 @@ julia> SpectralKernel(2 => 5, (16,));
205207
```
206208
"""
207209
function SpectralKernel(
208-
ch::Pair{<:Integer,<:Integer}, modes::Dims, act=identity; shift::Bool=false, kwargs...
210+
ch::Pair{<:Integer,<:Integer},
211+
modes::Dims,
212+
act=identity;
213+
centered_truncation::Bool=false,
214+
kwargs...,
209215
)
210216
return OperatorKernel(
211-
ch, modes, FourierTransform{ComplexF32}(modes, shift), act; kwargs...
217+
ch, modes, FourierTransform{ComplexF32}(modes, centered_truncation), act; kwargs...
212218
)
213219
end
214220

src/models/fno.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ function FourierNeuralOperator(
8181
fno_skip::Symbol=:linear,
8282
complex_data::Bool=false,
8383
stabilizer=tanh,
84-
shift::Bool=false,
84+
centered_truncation::Bool=false,
8585
) where {N}
8686
lifting_channels = hidden_channels * lifting_channel_ratio
8787
projection_channels = out_channels * projection_channel_ratio
@@ -114,7 +114,7 @@ function FourierNeuralOperator(
114114
modes,
115115
activation;
116116
stabilizer,
117-
shift,
117+
centered_truncation,
118118
use_channel_mlp,
119119
channel_mlp_expansion,
120120
channel_mlp_skip,

src/transform.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,30 +19,30 @@ function truncate_modes end
1919
function inverse end
2020

2121
"""
22-
FourierTransform{T}(modes, shift::Bool=false)
22+
FourierTransform{T}(modes, centered_truncation::Bool=false)
2323
2424
A concrete implementation of `AbstractTransform` for Fourier transforms.
2525
"""
2626
struct FourierTransform{T,M} <: AbstractTransform{T}
2727
modes::M
28-
shift::Bool
28+
centered_truncation::Bool
2929
end
3030

31-
function FourierTransform{T}(modes::Dims, shift::Bool=false) where {T}
32-
return FourierTransform{T,typeof(modes)}(modes, shift)
31+
function FourierTransform{T}(modes::Dims, centered_truncation::Bool=false) where {T}
32+
return FourierTransform{T,typeof(modes)}(modes, centered_truncation)
3333
end
3434

3535
function Base.show(io::IO, ft::FourierTransform)
3636
print(io, "FourierTransform{", eltype(ft), "}(")
37-
print(io, ft.modes, ", shift=", ft.shift, ")")
37+
print(io, ft.modes, ", centered_truncation=", ft.centered_truncation, ")")
3838
return nothing
3939
end
4040

4141
Base.ndims(T::FourierTransform) = length(T.modes)
4242

4343
function transform(ft::FourierTransform, x::AbstractArray)
4444
res = Lux.Utils.eltype(x) <: Complex ? fft(x, 1:ndims(ft)) : rfft(x, 1:ndims(ft))
45-
if ft.shift && ndims(ft) > 1
45+
if ft.centered_truncation && ndims(ft) > 1
4646
res = fftshift(res, 1:ndims(ft))
4747
end
4848
return res
@@ -57,7 +57,7 @@ truncate_modes(ft::FourierTransform, x_fft::AbstractArray) = low_pass(ft, x_fft)
5757
function inverse(
5858
ft::FourierTransform, x_fft::AbstractArray{T,N}, x::AbstractArray{T2,N}
5959
) where {T,T2,N}
60-
if ft.shift && ndims(ft) > 1
60+
if ft.centered_truncation && ndims(ft) > 1
6161
x_fft = fftshift(x_fft, 1:ndims(ft))
6262
end
6363

test/fno_tests.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,26 +7,27 @@
77
chs=(2, 64, 64, 64, 64, 64, 128, 1),
88
x_size=(1024, 2, 5),
99
y_size=(1024, 1, 5),
10-
shift=false,
10+
centered_truncation=false,
1111
),
1212
(
1313
modes=(16, 16),
1414
chs=(2, 64, 64, 64, 64, 64, 128, 4),
1515
x_size=(32, 32, 2, 5),
1616
y_size=(32, 32, 4, 5),
17-
shift=false,
17+
centered_truncation=false,
1818
),
1919
(
2020
modes=(16, 16),
2121
chs=(2, 64, 64, 64, 64, 64, 128, 4),
2222
x_size=(32, 32, 2, 5),
2323
y_size=(32, 32, 4, 5),
24-
shift=true,
24+
centered_truncation=true,
2525
),
2626
]
2727

28-
@testset "$(length(setup.modes))D | shift=$(setup.shift)" for setup in setups
29-
fno = FourierNeuralOperator(; setup.chs, setup.modes, setup.shift)
28+
@testset "$(length(setup.modes))D | centered_truncation=$(setup.centered_truncation)" for setup in
29+
setups
30+
fno = FourierNeuralOperator(; setup.chs, setup.modes, setup.centered_truncation)
3031
display(fno)
3132
ps, st = Lux.setup(rng, fno)
3233

test/layers_tests.jl

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,33 @@
33

44
opconv = [SpectralConv, SpectralKernel]
55
setups = [
6-
(; m=(16,), x_size=(1024, 2, 5), y_size=(1024, 16, 5), shift=false),
7-
(; m=(10, 10), x_size=(22, 22, 1, 5), y_size=(22, 22, 16, 5), shift=false),
8-
(; m=(10, 10), x_size=(22, 22, 1, 5), y_size=(22, 22, 16, 5), shift=true),
6+
(; m=(16,), x_size=(1024, 2, 5), y_size=(1024, 16, 5), centered_truncation=false),
7+
(;
8+
m=(10, 10),
9+
x_size=(22, 22, 1, 5),
10+
y_size=(22, 22, 16, 5),
11+
centered_truncation=false,
12+
),
13+
(;
14+
m=(10, 10),
15+
x_size=(22, 22, 1, 5),
16+
y_size=(22, 22, 16, 5),
17+
centered_truncation=true,
18+
),
919
]
1020

1121
rdev = reactant_device()
1222

13-
@testset "$(op) $(length(setup.m))D | shift=$(setup.shift)" for op in opconv,
23+
@testset "$(op) $(length(setup.m))D | centered_truncation=$(setup.centered_truncation)" for op in
24+
opconv,
1425
setup in setups
1526

1627
in_chs = setup.x_size[end - 1]
1728
out_chs = setup.y_size[end - 1]
1829
ch = 4 => out_chs
1930

2031
l1 = Conv(ntuple(_ -> 1, length(setup.m)), in_chs => first(ch))
21-
m = Chain(l1, op(ch, setup.m; setup.shift))
32+
m = Chain(l1, op(ch, setup.m; setup.centered_truncation))
2233
display(m)
2334
ps, st = Lux.setup(rng, m)
2435

0 commit comments

Comments
 (0)