Commit 36b447b

docs: double pendulum example (#68)
* docs: double pendulum example
* docs: update
* docs: more fixes
* docs: finish the example
* chore: add typos
1 parent 680bc91 commit 36b447b

File tree

9 files changed: +272 −44 lines


.buildkite/documentation.yml

Lines changed: 1 addition & 0 deletions

@@ -24,6 +24,7 @@ steps:

   env:
     JULIA_DEBUG: "Documenter"
+    XLA_REACTANT_GPU_MEM_FRACTION: "0.95"
     DATADEPS_ALWAYS_ACCEPT: true
SECRET_CODECOV_TOKEN: "vn5M+4wSwUFje6fl6UB/Q/rTmLHu3OlCCMgoPOXPQHYpLZTLz2hOHsV44MadAnxw8MsNVxLKZlXBKqP3IydU9gUfV7QUBtnvbUmIvgUHbr+r0bVaIVVhw6cnd0s8/b+561nU483eRJd35bjYDOlO+V5eDxkbdh/0bzLefXNXy5+ALxsBYzsp75Sx/9nuREfRqWwU6S45mne2ZlwCDpZlFvBDXQ2ICKYXpA45MpxhW9RuqfpQdi6sSR6I/HdHkV2cuJO99dqqh8xfUy6vWPC/+HUVrn9ETsrXtayX1MX3McKj869htGICpR8vqd311HTONYVprH2AN1bJqr5MOIZ8Xg==;U2FsdGVkX1+W55pTI7zq+NwYrbK6Cgqe+Gp8wMCmXY+W10aXTB0bS6zshiDYSQ1Y3piT91xFyNhS+9AsajY0yQ=="
SECRET_DOCUMENTER_KEY: "hnN1QYai+CvEgtp2M9S9ZxK4xBthWAhx8snqR7MIKk3cqG7qHMsTd49KMSYWLXEsrsMbI2N4OLjNE3SgVZwyfuHop6x3ZqwVWD0OtldK6QvL/c+lmpzYMkIRCqKfXOKm9cQTWrr/eKzEh0uEUEdzCH7vrqLP7ywf5F9BMJbRmHW2f4cE1RKiG5cWVtgyO2OyRg7CxYrWgg0eKba+Vh8dLj1xL1Gs1XpLr7TDYRfSo7VMKYdkvB7Os7hmDN04YwzRFvGQBlroBaiC7KgbI5kZcL33R/whCHUYneSsXozGVyJGGGUbdxCaSslJ3p3Vb3vpnn+eAn0GyriWF/TaKIUlpQ==;U2FsdGVkX1/7Mqz8JnY+1EpErS6yTG4GfBJTNI+3WLtOuhyr+6SWZgIBCoitfIVSqvAws+RXVY09rG5Vsnvperr/lAzuBINsfOdIkSnS4HjfRHrKXGWzwzdTawPw1tn0wJa9h4RNahoBQk7Qs3clEb150AlomRJ//UAOW6UgD/nvr8TASjc5aHPkZodBogtMn93ti7hVRlUCNZAk58zpyXEsJorX9RT4g473y1aJp+CMd+Momo7/eljjwq3uCFyWpBnIwmipal4myUszDnSHuEQHcgh74ghJ1LkBjKcLUrRj06wKM/PfF23/P7h/KZCeryvghPSczKCJHx3wjH10DUuFK4maX9Yli/vg0SBq7/crWxNPXS5GmHYg/NBD2pgbe9qFvSCX4xJvkDl2ixHkNu5wQCCCTixtLfv00/XKWT5nh+PLSapQkY7ouesLFk8adpLAESPlX4pjSbljy0BrcHeFXFclFerq+Ia4StCLehWI4k2o/gfUEx7cnNN5ZKQpcYtQVOKRQriq8ri6dz+znQTmWqIM0xf+ZfWHy1OA3UXi269b0jdIlRQNtAjWhah9WWBaCiPY63/bgW54nIh1TVS1YaHzx2elWtFzHNpeTxIoibmltUivCE9r9Mk+YpfYUHJVUP9DwaLXcCiMZoLPNU4XWAzTejkpKJX/TdVFudLm2WNsCWi7nzjTfWSwlUOjrccLy9JVAx186gMlFhf5rHWeNFydBgEHBUig0FWkASwSPyv5iwD63NoL8UsTCRTNKMDtzn1XuCNxbmxgsbNO5bl+mcCDB/ZJuXexw3Ek04Ur4eZNyhf5qcn2cUzm76Jqi2vHSnY8JS9U9sSCdOd1gHFhxTOmO/OWh4jP5GKtWSWH23T2UOTS95daw1z2+GwLLP3R8chnFUCWvnJ+SUAHKoHYvbiT34gOyhcbnmoYTBCuvy4eghq1pVALUooPHOHzq2WsHAI07y6dUfbq55xaQZYiT++e+UisIID4mamITC5IFAQhkXPSAbVo56VrRCNQ+WnxLhzB6HNceAUdJCvH0IKcC9NqvfhCNioQ0EktfNla4G1BydJwgqexu8Tm4f5dl9e9V65hTn05yRYY2QN4hPCsw9Fr2NOrGfudCp5PEvA9TNuaobXvR5lSHPLL19ImbkCR1FGndrpwzWHR/0xjyS4lfdo3cW6SGVK33WZzOouM2ShGZc5Ma4yJdcPr1g0RsnWMGVX0A2gzJD9EL9RtNLgsDERnRLkEJgvhgVGXI+nmXadWzmn+plWgGEjSCLeZH1a4msFmcco3NFj0c6i5g+6cG09yILwhUTX+iegwm71r71pbPsMBk/hz/0ntVdklqFfCUZOC1n+MGJJmFfHmHAYHXSX5RmJhmvZG59g+oBEKxCh4VaIsSjLbopR4/TxS/uHdq3JJMOHXpgTQDack0qz3iEZFt+tKvKmah15PijdCv/Y8SHlN+txgSpW8+bQpxaMFq2li+/vApT18hBLjgGoMRv+I4ZmEmLUN1C+BzkCBTtErn4REQn4aKQmqlKC0jhQ5jatR5/Olpa1msiNjVyX0l4RBZFmYlpRY0KxKuVyguQtBF2qaxP2Lgrf2lp7tI7bK1z2MMBOxRTKHYs/Jt77mL/oeOOd2263EuZ2ku6K/iaaEY7jOCN7Du60UYVh0ePPorCTICHJ/arKLQ/uhh/d6Lo71bkj72B7CG75KGjZOg3+FRsFI0BLVjqvEJasR3apRCeipt3iMqmWc67wDp6DMXGCb/LlZh+oSfXtTY1VuMK4L4qcPl15Lf9H2xIt7N1YR6riq5R2EAokqQtazfxMfIk8+R25674O2Jksy0HnkIQ/H1/r2vaQYFzB0KVl671RdzHVYpfSMQQUcQtbzJHHsw5OIAPRwGd2RxFwW3GieKPYNS1u01TW+7z7Tq80qCOqewf+L+4ckbgYiiDDux+lzEK+PowtyOnQR1sjw7PE+BCPqF4Uc7b65EItFacseBD3mNVRXFUSG29rehgU4OrBjS7uUjn43QoMfonVCRhiLxydrfgGTdXW0g3qg3NYRsNeX3kP7b8MrzJ1zhPMZCca0YPbLeZ6Ab8lSc/GeJgrJ1m6ollE5yC2YisJMvNyJU525A7pReSk7u77nd1HTcddvngn8glJjcPn5/2aDuPuP4ZSEt67QLUaxVXce83XMY3FqiDngPhbJYSwGb9dlDizu9ph6a/SgZQyVpMBVL76RK9xropTUFDfDf5SOjukv5VWC2klIh8Fc4v9ur6otP5kx0MYV4ndlCoU6CyfEaQIQKXyZIIOWEPxyBvy8GnOSyh+5E82rJ50+PHkU9+iEJLBaZENnHlDJA30ANHSLF33t3CAxlqOvhFC5hx+BWCbK6R92mITBJOmo58+Fcz1e7u6SRLhOrp5rO+t3z6eB7SuVr0j16Tn4CJNGiYVGmhDjA1NfXoI7uEUkLVGuhra0eFzUCH8c+ot1pamb35E0knzaXy7azX5nuRXgt/491L0f9uc20ZqsrP63JtRQO574sbHxM6SnBWcIZSEEhAGRf4busL7DqjfHSFVx8TJH4FO0fCteCveJcxDoc94Wj7J+u5rUQBsRLDQ24XPjSgflqfhyU0lNKXuDWUSJPzqVuUM3fnYCsCdSd4/Y23t5k5dL68wOj3tBEunpGKltzG4zXjopay0VSkpty0EkaQbejvgXWnZv2G/SZlH/dW0fqp4a/M824kuzh9liF34fcV1uReoc+SKXs7iLYHgDXXajWy40HqvQR1NxtD17eIBvfhsAeKM3WR6YqudVkl4DYXU2hGM7L7znex2JAVs1thOtzBD6DbTKtlEciTs9O2B++22XnSmrusweKE68/8P0FN7DKA2oNQCig9J7BMT3pWA2jtoeq5sZU3vgvzlCLELefO4+YUWHY436znIp773oE8+IDdbxwLqKQfRI7zQAUGXNeUNEzAIla51zSARIFocslZfmF8ugx3NDBNqd6cjU+gk4qVAmUZW/Sw7SM+aPbC6hJPrDCfsMYhG7CNO5DZNBSO6Hhh0GSfl/zuB4Lzzg6xJuCfupg/A9DaUMfaIacYcsP7TjqVf07SbiPIXhlrFvuv5d43tkGBQdAQ/dNtpm2R5//u5Fi/mOlELiJBV5aLI2wZn2pGJRuOa6tb8GHSW3eLmv7RlrvBUEXutrfnevzUFu73Ue6jArVL
hNQBhq64GhDEMQQvXH8QM/C0aTL8/IhFx80+kESUdnNa1OW7bNrBhOxTJJ4ho2JvenqVP44UQjJKOQGTjD5LSAqy9mkzK73lzRBVTNhpI2MTbKEnfXxN0wCce7cdpzJAUOTYF4fa4A9kkL3dqq/nhvkfzFY7lJbGI0J6hwMa9lWPwhB+wBU5hZu7s9eLDsfMoPE3vti8p8e91SgnPs+QtzPoe72GfBbJe8QQNnGeGW1aI2+qF82yarez1hTFaUPX1dgwSAKJS3/RSy1pMm1EPcM50P5At2ervJRJM5VCKrVJ8rvcFKuzfQfcZ/sqCW4l3E8ojJ7+7MAu4Ilfg0C6ZVY2S+rH52jIJtpTqbtGDaC0Zp7QTEbD5Eu1GkIxUmjpg23VGrNi4SnV+3ZbzOue8qTTstQdwzD2gDgb+zXUSkeGgZYYyEJQpGSjEJOPVi6CByqUtdK+Lfd19WRgSFJMmDfupQFUk7D19/d4K7/dj0NcuH+DQBfPvZl4cJkGqHeBXlK8uPN1A+q1S8D2QA7m4BzHAHz0Sy99KDq4spuFlm5t9oYAielqduZHieY4BugcQ1KcwR6pmc2ayEZvVRwfTNnxyodbOQCg4yh/pV4kX40G+tGnwtimaIVzyeIaZhL+iQ0M1+4KJCbKT2NGnk82hUKyrnBYSTPAG6vjRkg3P6KZj6tj4ABpxnymkZnMrUq2gpx/OlF0nl2CV0iCvj0RqZ5kX3NDVDtW8XmXh1sTATEGltfEclYoKMA/c47adE+rj16CiPwInnI/+JReNSjI14UoNn+27n8yQfxb3LbYkDSYkZcV7aysKPLQunMO0f0n6WUc2nprUIwQSO1Ccdhnio+YJs9JdvICN6JDpXfXM3eM6YGoSFv+XUownwfCW+fc1LMEXb4qjlUv7TIZKgXN5cEokKC2ZPpFJxF1TB2B7M3ib3a0orWuL8RsvexU7fpbUWo0+QI0CmiTqkmtsq1QeC7OCUU6Ocmz0iKUcYzcuLc9gfRyxXE8p7GqamK+okZWA0CJcMGvg0UXZcIaaHvUa9FuYLnjEOwvz1V92GGL0epIR9Vm4N6Aa1HzppbDuyY+donj0KBQkDxhix4+sF79LzdAJq1cvuxWW3km8f9CXYRdJ3oOY2nqwQXsi08K+yE+Zv2YVPd6Z1tJfk7Y+KiyhSIsKxDqIY1SaWwgRRYpi8c52BjTtnaIbYjFI="

.gitignore

Lines changed: 2 additions & 0 deletions

@@ -9,3 +9,5 @@ docs/build
 .venv
 .python-version
 uv.lock
+
+double_pendulum.gif

.typos.toml

Lines changed: 2 additions & 0 deletions

@@ -0,0 +1,2 @@
+[default.extend-words]
+iy = "iy"

docs/Project.toml

Lines changed: 2 additions & 0 deletions

@@ -1,8 +1,10 @@
 [deps]
 AlgebraOfGraphics = "cbdf2221-f076-402e-a563-3d30da359d67"
+CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
 CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
 CondaPkg = "992eb4ea-22a4-4c89-a5bb-47a3300528ab"
 DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe"
+DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
 MAT = "23992714-dd62-5051-b70f-ba57cb901cac"

docs/pages.jl

Lines changed: 1 addition & 0 deletions

@@ -11,6 +11,7 @@ pages = [
         "FNO" => "tutorials/burgers_fno.md",
     ],
     "Solving Poisson Equation" => "tutorials/poisson_equation.md",
+    "Double Pendulum" => "tutorials/double_pendulum.md",
 ],
 "API Reference" => "api.md",
]

docs/src/tutorials/burgers_deeponet.md

Lines changed: 21 additions & 22 deletions

@@ -11,27 +11,26 @@ const gdown = pyimport("gdown")

 register(
     DataDep(
-        "Burgers",
-        """
-        Burgers' equation dataset from
-        [fourier_neural_operator](https://github.com/zongyi-li/fourier_neural_operator)
-
-        mapping between initial conditions to the solutions at the last point of time \
-        evolution in some function space.
-
-        u(x,0) -> u(x, time_end):
-
-        * `a`: initial conditions u(x,0)
-        * `u`: solutions u(x,t_end)
-        """,
-        "https://drive.google.com/uc?id=16a8od4vidbiNR3WtaBPCSZ0T3moxjhYe",
-        "9cbbe5070556c777b1ba3bacd49da5c36ea8ed138ba51b6ee76a24b971066ecd";
-        fetch_method=(url,
-            local_dir) -> begin
-            pyconvert(String, gdown.download(url, joinpath(local_dir, "Burgers_R10.zip")))
-        end,
-        post_fetch_method=unpack
-    )
+        "Burgers",
+        """
+        Burgers' equation dataset from
+        [fourier_neural_operator](https://github.com/zongyi-li/fourier_neural_operator)
+
+        mapping between initial conditions to the solutions at the last point of time \
+        evolution in some function space.
+
+        u(x,0) -> u(x, time_end):
+
+        * `a`: initial conditions u(x,0)
+        * `u`: solutions u(x,t_end)
+        """,
+        "https://drive.google.com/uc?id=16a8od4vidbiNR3WtaBPCSZ0T3moxjhYe",
+        "9cbbe5070556c777b1ba3bacd49da5c36ea8ed138ba51b6ee76a24b971066ecd";
+        fetch_method=(url, local_dir) -> begin
+            pyconvert(String, gdown.download(url, joinpath(local_dir, "Burgers_R10.zip")))
+        end,
+        post_fetch_method=unpack,
+    ),
 )

 filepath = joinpath(datadep"Burgers", "burgers_data_R10.mat")

@@ -54,7 +53,7 @@ grid = reshape(collect(T, range(0, 1; length=grid_size)), 1, :)
 ## Model

 ```@example burgers
-using Lux, NeuralOperators, Optimisers, Random, Reactant
+using Lux, NeuralOperators, Optimisers, Random, Reactant

 const cdev = cpu_device()
 const xdev = reactant_device(; force=true)

docs/src/tutorials/burgers_fno.md

Lines changed: 21 additions & 22 deletions

@@ -11,27 +11,26 @@ const gdown = pyimport("gdown")

 register(
     DataDep(
-        "Burgers",
-        """
-        Burgers' equation dataset from
-        [fourier_neural_operator](https://github.com/zongyi-li/fourier_neural_operator)
-
-        mapping between initial conditions to the solutions at the last point of time \
-        evolution in some function space.
-
-        u(x,0) -> u(x, time_end):
-
-        * `a`: initial conditions u(x,0)
-        * `u`: solutions u(x,t_end)
-        """,
-        "https://drive.google.com/uc?id=16a8od4vidbiNR3WtaBPCSZ0T3moxjhYe",
-        "9cbbe5070556c777b1ba3bacd49da5c36ea8ed138ba51b6ee76a24b971066ecd";
-        fetch_method=(url,
-            local_dir) -> begin
-            pyconvert(String, gdown.download(url, joinpath(local_dir, "Burgers_R10.zip")))
-        end,
-        post_fetch_method=unpack
-    )
+        "Burgers",
+        """
+        Burgers' equation dataset from
+        [fourier_neural_operator](https://github.com/zongyi-li/fourier_neural_operator)
+
+        mapping between initial conditions to the solutions at the last point of time \
+        evolution in some function space.
+
+        u(x,0) -> u(x, time_end):
+
+        * `a`: initial conditions u(x,0)
+        * `u`: solutions u(x,t_end)
+        """,
+        "https://drive.google.com/uc?id=16a8od4vidbiNR3WtaBPCSZ0T3moxjhYe",
+        "9cbbe5070556c777b1ba3bacd49da5c36ea8ed138ba51b6ee76a24b971066ecd";
+        fetch_method=(url, local_dir) -> begin
+            pyconvert(String, gdown.download(url, joinpath(local_dir, "Burgers_R10.zip")))
+        end,
+        post_fetch_method=unpack,
+    ),
 )

 filepath = joinpath(datadep"Burgers", "burgers_data_R10.mat")

@@ -59,7 +58,7 @@ const cdev = cpu_device()
 const xdev = reactant_device(; force=true)

 fno = FourierNeuralOperator(
-    (16,), 2, 1, 32; activation=gelu, stabilizer=tanh, shift=true
+    (16,), 2, 1, 32; activation=gelu, stabilizer=tanh
 )
 ps, st = Lux.setup(Random.default_rng(), fno) |> xdev;
 ```

Image file (839 KB); preview not shown.

docs/src/tutorials/double_pendulum.md

Lines changed: 222 additions & 0 deletions

@@ -0,0 +1,222 @@
# Learning the motion of a double pendulum

## Data Loading

```julia
using DataDeps, CSV, MLUtils, DataFrames
using Printf

register(
    DataDep(
        "DoublePendulumChaotic",
        """
        The dataset was generated from 21 individual runs of a double pendulum.
        Each recorded sequence lasts around 40 s and contains around 17500 frames.

        * `x_red`: Horizontal pixel coordinate of the red point (the central pivot of
          the first pendulum)
        * `y_red`: Vertical pixel coordinate of the red point (the central pivot of
          the first pendulum)
        * `x_green`: Horizontal pixel coordinate of the green point (the first pendulum)
        * `y_green`: Vertical pixel coordinate of the green point (the first pendulum)
        * `x_blue`: Horizontal pixel coordinate of the blue point (the second pendulum)
        * `y_blue`: Vertical pixel coordinate of the blue point (the second pendulum)

        Page: https://developer.ibm.com/exchanges/data/all/double-pendulum-chaotic/
        """,
        "https://dax-cdn.cdn.appdomain.cloud/dax-double-pendulum-chaotic/2.0.1/double-pendulum-chaotic.tar.gz",
        "4ca743b4b783094693d313ebedc2e8e53cf29821ee8b20abd99f8fb4c0866f8d";
        post_fetch_method=unpack,
    ),
)

function get_data(; i=0, n=-1)
    data_path = joinpath(datadep"DoublePendulumChaotic", "original", "dpc_dataset_csv")
    df = CSV.read(
        joinpath(data_path, "$i.csv"),
        DataFrame;
        header=[:x_red, :y_red, :x_green, :y_green, :x_blue, :y_blue],
    )

    n < 0 && return collect(Float32, Matrix(df)')
    return collect(Float32, Matrix(df)')[:, 1:n]
end

function preprocess(x; Δt=1, nx=30, ny=30)
    # move the red point (the central pivot) to (0, 0)
    xs_red, ys_red = x[1, :], x[2, :]
    x[3, :] -= xs_red
    x[5, :] -= xs_red
    x[4, :] -= ys_red
    x[6, :] -= ys_red

    # keep only the green and blue points
    x = reshape(x[3:6, 1:Δt:end], 1, 4, :)
    # velocity of the green and blue points
    ∇x = x[:, :, 2:end] - x[:, :, 1:(end - 1)]
    # concatenate position and velocity channels
    x = cat(x[:, :, 1:(end - 1)], ∇x; dims=1)

    # use the first nx steps to predict the next ny steps
    n = size(x, ndims(x)) - (nx + ny) + 1
    xs = Array{Float32}(undef, size(x)[1:2]..., nx, n)
    ys = Array{Float32}(undef, size(x)[1:2]..., ny, n)
    for i in 1:n
        xs[:, :, :, i] .= x[:, :, i:(i + nx - 1)]
        ys[:, :, :, i] .= x[:, :, (i + nx):(i + nx + ny - 1)]
    end

    return permutedims(xs, (3, 2, 1, 4)), permutedims(ys, (3, 2, 1, 4))
end

function get_dataloader(; n_file=20, Δt=1, nx=30, ny=30, ratio=0.9, batchsize=128)
    xs, ys = Array{Float32}(undef, nx, 4, 2, 0), Array{Float32}(undef, ny, 4, 2, 0)
    for i in 1:n_file
        xs_i, ys_i = preprocess(get_data(; i=i - 1); Δt, nx, ny)
        xs, ys = cat(xs, xs_i; dims=4), cat(ys, ys_i; dims=4)
    end

    data_train, data_test = splitobs(shuffleobs((xs, ys)); at=ratio)

    trainloader = DataLoader(data_train; batchsize, shuffle=true, partial=false)
    testloader = DataLoader(data_test; batchsize, shuffle=false, partial=false)

    return trainloader, testloader
end
```
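
Given the preprocessing above, every sample stacks `nx` time steps × 4 pixel coordinates (x/y of the green and blue points) × 2 channels (position and velocity). The snippet below is a quick sanity check, not part of the original example; the expected shapes follow from the `permutedims` calls above and assume the default `nx=30`, `ny=30`, and `batchsize=128` (the `_chk` names are ours):

```julia
# Inputs and targets should come out as (timesteps, coordinates, channels, batch).
trainloader_chk, testloader_chk = get_dataloader(; nx=30, ny=30, batchsize=128)
x_chk, y_chk = first(trainloader_chk)
@assert size(x_chk) == (30, 4, 2, 128)
@assert size(y_chk) == (30, 4, 2, 128)
```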

## Model

```julia
using Lux, NeuralOperators, Optimisers, Random, Reactant

const cdev = cpu_device()
const xdev = reactant_device(; force=true)

fno = FourierNeuralOperator(
    (16, 4), 2, 2, 64; num_layers=6, activation=gelu, positional_embedding=:none
)
ps, st = Lux.setup(Random.default_rng(), fno) |> xdev;
```
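
Before training, a forward pass on a dummy window confirms that the operator preserves the `(timesteps, coordinates, channels, batch)` layout while mapping 2 input channels to 2 output channels. This smoke test is a sketch rather than part of the original example; the dummy input and the expected output shape are assumptions based on the data layout above:

```julia
# Run the untrained model on a random window holding a single sample.
x_dummy = xdev(rand(Float32, 30, 4, 2, 1))
y_dummy, _ = @jit fno(x_dummy, ps, st)
size(y_dummy)  # expected: (30, 4, 2, 1)
```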

## Training

```julia
trainloader, testloader = get_dataloader(; Δt=1, nx=30, ny=30) |> xdev;

function prediction_loss(model, x, ps, st, y)
    return MSELoss()(first(model(x, ps, st)), y)
end

function train_model!(model, ps, st, trainloader, testloader; epochs=20)
    train_state = Training.TrainState(model, ps, st, AdamW(; eta=3.0f-4, lambda=1.0f-5))

    (xtest, ytest) = first(testloader)
    prediction_loss_compiled = Reactant.with_config(;
        convolution_precision=PrecisionConfig.HIGH,
        dot_general_precision=PrecisionConfig.HIGH,
    ) do
        @compile prediction_loss(
            model, xtest, train_state.parameters, train_state.states, ytest
        )
    end

    for epoch in 1:epochs
        for data in trainloader
            (_, _, _, train_state) = Training.single_train_step!(
                AutoEnzyme(), MSELoss(), data, train_state; return_gradients=Val(false)
            )
        end

        test_loss, nbatches = 0.0f0, 0
        for (xtest, ytest) in testloader
            nbatch = size(xtest, ndims(xtest))
            nbatches += nbatch
            test_loss +=
                Float32(
                    prediction_loss_compiled(
                        model, xtest, train_state.parameters, train_state.states, ytest
                    ),
                ) * nbatch
        end
        test_loss /= nbatches

        @printf("Epoch [%3d/%3d]\tTest Loss: %12.6f\n", epoch, epochs, test_loss)
    end

    return train_state.parameters, train_state.states
end

ps_trained, st_trained = train_model!(fno, ps, st, trainloader, testloader; epochs=50);
nothing #hide
```
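
After training, the `prediction_loss` helper defined above can be reused to spot-check the trained parameters on a single held-out batch. This is a sketch, not part of the original example; it assumes the `testloader` and the parameters returned by `train_model!`:

```julia
# Compile and evaluate the loss on one test batch with the trained parameters.
xtest, ytest = first(testloader)
batch_loss = @jit prediction_loss(fno, xtest, ps_trained, st_trained, ytest)
@printf("Test loss on one held-out batch: %12.6f\n", Float32(batch_loss))
```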

## Plotting

```julia
using CairoMakie, AlgebraOfGraphics
const AoG = AlgebraOfGraphics
AoG.set_aog_theme!()

x_data, y_data = preprocess(get_data(; i=20));
gt_data = cat([x_data[:, :, :, i] for i in 1:size(x_data, 1):size(x_data, 4)]...; dims=1)[
    :, :, 1
]';

n = 5
inferenced_data = x_data[:, :, :, 1:1]
for i in 1:n
    input_data = inferenced_data[:, :, :, i:i] |> xdev
    prediction = first(
        Reactant.with_config(;
            convolution_precision=PrecisionConfig.HIGH,
            dot_general_precision=PrecisionConfig.HIGH,
        ) do
            @jit fno(input_data, ps_trained, st_trained)
        end,
    )
    inferenced_data = cat(inferenced_data, cdev(prediction); dims=4)
end
inferenced_data = cat([inferenced_data[:, :, :, i] for i in 1:n]...; dims=1)[:, :, 1]'

begin
    c = [
        RGBf([239, 71, 111] / 255...),
        RGBf([6, 214, 160] / 255...),
        RGBf([17, 138, 178] / 255...),
    ]
    xi, yi = [2, 4, 6], [1, 3, 5]

    time = Observable(1)

    gx_data = @lift [0, 0, gt_data[:, $(time)]...][xi]
    gy_data = @lift [0, 0, gt_data[:, $(time)]...][yi]
    ix_data = @lift [0, 0, inferenced_data[:, $(time)]...][xi]
    iy_data = @lift [0, 0, inferenced_data[:, $(time)]...][yi]

    fig = Figure(; size=(512, 512))
    ax = Axis(
        fig[1, 1];
        title="Predicting the motion of the double pendulum",
        subtitle=@lift("t = $($(time))"),
    )
    xlims!(ax, -1200, 1200)
    ylims!(ax, -1200, 1200)

    lines!(ax, gx_data, gy_data; color=:black, linewidth=2, linestyle=:solid)
    scatter!(ax, gx_data, gy_data; color=c, markersize=35, strokewidth=2)
    lines!(ax, ix_data, iy_data; color=:gray, linewidth=2, linestyle=:dash)
    scatter!(ax, ix_data, iy_data; color=c, markersize=15, strokewidth=2)

    record(
        fig,
        joinpath(@__DIR__, "double_pendulum.gif"),
        1:size(inferenced_data, 2);
        framerate=30
    ) do t
        time[] = t
    end
end
```

![double_pendulum](double_pendulum.gif)
