|
| 1 | +# Learning the motion of a double pendulum |
| 2 | + |
| 3 | +## Data Loading |
| 4 | + |
| 5 | +```julia |
| 6 | +using DataDeps, CSV, MLUtils, DataFrames |
| 7 | +using Printf |
| 8 | + |
| 9 | +register( |
| 10 | + DataDep( |
| 11 | + "DoublePendulumChaotic", |
| 12 | + """ |
| 13 | + Dataset was generated on the basis of 21 individual runs of a double pendulum. |
| 14 | + Each of the recorded sequences lasted around 40s and consisted of around 17500 frames. |
| 15 | +
|
| 16 | + * `x_red`: Horizontal pixel coordinate of the red point (the central pivot to the |
| 17 | + first pendulum) |
| 18 | + * `y_red`: Vertical pixel coordinate of the red point (the central pivot to the first |
| 19 | + pendulum) |
| 20 | + * `x_green`: Horizontal pixel coordinate of the green point (the first pendulum) |
| 21 | + * `y_green`: Vertical pixel coordinate of the green point (the first pendulum) |
| 22 | + * `x_blue`: Horizontal pixel coordinate of the blue point (the second pendulum) |
| 23 | + * `y_blue`: Vertical pixel coordinate of the blue point (the second pendulum) |
| 24 | +
|
| 25 | + Page: https://developer.ibm.com/exchanges/data/all/double-pendulum-chaotic/ |
| 26 | + """, |
| 27 | + "https://dax-cdn.cdn.appdomain.cloud/dax-double-pendulum-chaotic/2.0.1/double-pendulum-chaotic.tar.gz", |
| 28 | + "4ca743b4b783094693d313ebedc2e8e53cf29821ee8b20abd99f8fb4c0866f8d"; |
| 29 | + post_fetch_method=unpack, |
| 30 | + ), |
| 31 | +) |
| 32 | + |
| 33 | +function get_data(; i=0, n=-1) |
| 34 | + data_path = joinpath(datadep"DoublePendulumChaotic", "original", "dpc_dataset_csv") |
| 35 | + df = CSV.read( |
| 36 | + joinpath(data_path, "$i.csv"), |
| 37 | + DataFrame; |
| 38 | + header=[:x_red, :y_red, :x_green, :y_green, :x_blue, :y_blue], |
| 39 | + ) |
| 40 | + |
| 41 | + n < 0 && return collect(Float32, Matrix(df)') |
| 42 | + return collect(Float32, Matrix(df)')[:, 1:n] |
| 43 | +end |
| 44 | + |
| 45 | +function preprocess(x; Δt=1, nx=30, ny=30) |
| 46 | + # move red point to (0, 0) |
| 47 | + xs_red, ys_red = x[1, :], x[2, :] |
| 48 | + x[3, :] -= xs_red |
| 49 | + x[5, :] -= xs_red |
| 50 | + x[4, :] -= ys_red |
| 51 | + x[6, :] -= ys_red |
| 52 | + |
| 53 | + # needs only green and blue points |
| 54 | + x = reshape(x[3:6, 1:Δt:end], 1, 4, :) |
| 55 | + # velocity of green and blue points |
| 56 | + ∇x = x[:, :, 2:end] - x[:, :, 1:(end - 1)] |
| 57 | + # merge info of pos and velocity |
| 58 | + x = cat(x[:, :, 1:(end - 1)], ∇x; dims=1) |
| 59 | + |
| 60 | + # with info of first nx steps to inference next ny steps |
| 61 | + n = size(x, ndims(x)) - (nx + ny) + 1 |
| 62 | + xs = Array{Float32}(undef, size(x)[1:2]..., nx, n) |
| 63 | + ys = Array{Float32}(undef, size(x)[1:2]..., ny, n) |
| 64 | + for i in 1:n |
| 65 | + xs[:, :, :, i] .= x[:, :, i:(i + nx - 1)] |
| 66 | + ys[:, :, :, i] .= x[:, :, (i + nx):(i + nx + ny - 1)] |
| 67 | + end |
| 68 | + |
| 69 | + return permutedims(xs, (3, 2, 1, 4)), permutedims(ys, (3, 2, 1, 4)) |
| 70 | +end |
| 71 | + |
| 72 | +function get_dataloader(; n_file=20, Δt=1, nx=30, ny=30, ratio=0.9, batchsize=128) |
| 73 | + xs, ys = Array{Float32}(undef, nx, 4, 2, 0), Array{Float32}(undef, ny, 4, 2, 0) |
| 74 | + for i in 1:n_file |
| 75 | + xs_i, ys_i = preprocess(get_data(; i=i - 1); Δt, nx, ny) |
| 76 | + xs, ys = cat(xs, xs_i; dims=4), cat(ys, ys_i; dims=4) |
| 77 | + end |
| 78 | + |
| 79 | + data_train, data_test = splitobs(shuffleobs((xs, ys)); at=ratio) |
| 80 | + |
| 81 | + trainloader = DataLoader(data_train; batchsize, shuffle=true, partial=false) |
| 82 | + testloader = DataLoader(data_test; batchsize, shuffle=false, partial=false) |
| 83 | + |
| 84 | + return trainloader, testloader |
| 85 | +end |
| 86 | +``` |
| 87 | + |
| 88 | +## Model |
| 89 | + |
| 90 | +```julia |
| 91 | +using Lux, NeuralOperators, Optimisers, Random, Reactant |
| 92 | + |
| 93 | +const cdev = cpu_device() |
| 94 | +const xdev = reactant_device(; force=true) |
| 95 | + |
| 96 | +fno = FourierNeuralOperator( |
| 97 | + (16, 4), 2, 2, 64; num_layers=6, activation=gelu, positional_embedding=:none |
| 98 | +) |
| 99 | +ps, st = Lux.setup(Random.default_rng(), fno) |> xdev; |
| 100 | +``` |
| 101 | + |
| 102 | +## Training |
| 103 | + |
| 104 | +```julia |
| 105 | +trainloader, testloader = get_dataloader(; Δt=1, nx=30, ny=30) |> xdev; |
| 106 | + |
| 107 | +function prediction_loss(model, x, ps, st, y) |
| 108 | + return MSELoss()(first(model(x, ps, st)), y) |
| 109 | +end |
| 110 | + |
| 111 | +function train_model!(model, ps, st, trainloader, testloader; epochs=20) |
| 112 | + train_state = Training.TrainState(model, ps, st, AdamW(; eta=3.0f-4, lambda=1.0f-5)) |
| 113 | + |
| 114 | + (xtest, ytest) = first(testloader) |
| 115 | + prediction_loss_compiled = Reactant.with_config(; |
| 116 | + convolution_precision=PrecisionConfig.HIGH, |
| 117 | + dot_general_precision=PrecisionConfig.HIGH, |
| 118 | + ) do |
| 119 | + @compile prediction_loss( |
| 120 | + model, xtest, train_state.parameters, train_state.states, ytest |
| 121 | + ) |
| 122 | + end |
| 123 | + |
| 124 | + for epoch in 1:epochs |
| 125 | + for data in trainloader |
| 126 | + (_, _, _, train_state) = Training.single_train_step!( |
| 127 | + AutoEnzyme(), MSELoss(), data, train_state; return_gradients=Val(false) |
| 128 | + ) |
| 129 | + end |
| 130 | + |
| 131 | + test_loss, nbatches = 0.0f0, 0 |
| 132 | + for (xtest, ytest) in testloader |
| 133 | + nbatch = size(xtest, ndims(xtest)) |
| 134 | + nbatches += nbatch |
| 135 | + test_loss += |
| 136 | + Float32( |
| 137 | + prediction_loss_compiled( |
| 138 | + model, xtest, train_state.parameters, train_state.states, ytest |
| 139 | + ), |
| 140 | + ) * nbatch |
| 141 | + end |
| 142 | + test_loss /= nbatches |
| 143 | + |
| 144 | + @printf("Epoch [%3d/%3d]\tTest Loss: %12.6f\n", epoch, epochs, test_loss) |
| 145 | + end |
| 146 | + |
| 147 | + return train_state.parameters, train_state.states |
| 148 | +end |
| 149 | + |
| 150 | +ps_trained, st_trained = train_model!(fno, ps, st, trainloader, testloader; epochs=50); |
| 151 | +nothing #hide |
| 152 | +``` |
| 153 | + |
| 154 | +## Plotting |
| 155 | + |
| 156 | +```julia |
| 157 | +using CairoMakie, AlgebraOfGraphics |
| 158 | +const AoG = AlgebraOfGraphics |
| 159 | +AoG.set_aog_theme!() |
| 160 | + |
| 161 | +x_data, y_data = preprocess(get_data(; i=20)); |
| 162 | +gt_data = cat([x_data[:, :, :, i] for i in 1:size(x_data, 1):size(x_data, 4)]...; dims=1)[ |
| 163 | + :, :, 1 |
| 164 | +]'; |
| 165 | + |
| 166 | +n = 5 |
| 167 | +inferenced_data = x_data[:, :, :, 1:1] |
| 168 | +for i in 1:n |
| 169 | + input_data = inferenced_data[:, :, :, i:i] |> xdev |
| 170 | + prediction = first( |
| 171 | + Reactant.with_config(; |
| 172 | + convolution_precision=PrecisionConfig.HIGH, |
| 173 | + dot_general_precision=PrecisionConfig.HIGH, |
| 174 | + ) do |
| 175 | + @jit fno(input_data, ps_trained, st_trained) |
| 176 | + end, |
| 177 | + ) |
| 178 | + inferenced_data = cat(inferenced_data, cdev(prediction); dims=4) |
| 179 | +end |
| 180 | +inferenced_data = cat([inferenced_data[:, :, :, i] for i in 1:n]...; dims=1)[:, :, 1]' |
| 181 | + |
| 182 | +begin |
| 183 | + c = [ |
| 184 | + RGBf([239, 71, 111] / 255...), |
| 185 | + RGBf([6, 214, 160] / 255...), |
| 186 | + RGBf([17, 138, 178] / 255...), |
| 187 | + ] |
| 188 | + xi, yi = [2, 4, 6], [1, 3, 5] |
| 189 | + |
| 190 | + time = Observable(1) |
| 191 | + |
| 192 | + gx_data = @lift [0, 0, gt_data[:, $(time)]...][xi] |
| 193 | + gy_data = @lift [0, 0, gt_data[:, $(time)]...][yi] |
| 194 | + ix_data = @lift [0, 0, inferenced_data[:, $(time)]...][xi] |
| 195 | + iy_data = @lift [0, 0, inferenced_data[:, $(time)]...][yi] |
| 196 | + |
| 197 | + fig = Figure(; size=(512, 512)) |
| 198 | + ax = Axis( |
| 199 | + fig[1, 1]; |
| 200 | + title="Predicting the motion of the double pendulum", |
| 201 | + subtitle=@lift("t = $($(time))"), |
| 202 | + ) |
| 203 | + xlims!(ax, -1200, 1200) |
| 204 | + ylims!(ax, -1200, 1200) |
| 205 | + |
| 206 | + lines!(ax, gx_data, gy_data; color=:black, linewidth=2, linestyle=:solid) |
| 207 | + scatter!(ax, gx_data, gy_data; color=c, markersize=35, strokewidth=2) |
| 208 | + lines!(ax, ix_data, iy_data; color=:gray, linewidth=2, linestyle=:dash) |
| 209 | + scatter!(ax, ix_data, iy_data; color=c, markersize=15, strokewidth=2) |
| 210 | + |
| 211 | + record( |
| 212 | + fig, |
| 213 | + joinpath(@__DIR__, "double_pendulum.gif"), |
| 214 | + 1:size(inferenced_data, 2); |
| 215 | + framerate=30 |
| 216 | + ) do t |
| 217 | + time[] = t |
| 218 | + end |
| 219 | +end |
| 220 | +``` |
| 221 | + |
| 222 | + |
0 commit comments