Skip to content

Commit aeb15d3

Browse files
authored
docs: burgers equation example using deeponets (#46)
* refactor: move tutorials to pre-built models * docs: working burgers example * fix: minor corrections * docs: build docs on buildkite * ci: dont fail CI if codecov fails * docs: fix codecov badge
1 parent 160732f commit aeb15d3

File tree

19 files changed

+255
-66
lines changed

19 files changed

+255
-66
lines changed

.buildkite/documentation.yml

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
steps:
2+
- label: ":julia: Documentation"
3+
plugins:
4+
- JuliaCI/julia#v1:
5+
version: "1.10"
6+
- JuliaCI/julia-coverage#v1:
7+
codecov: true
8+
command: |
9+
julia --project -e '
10+
println("--- :julia: Instantiating project")
11+
using Pkg
12+
Pkg.instantiate()
13+
Pkg.activate("docs")
14+
Pkg.instantiate()
15+
push!(LOAD_PATH, @__DIR__)
16+
17+
println("+++ :julia: Building documentation")
18+
include("docs/make.jl")'
19+
agents:
20+
queue: "juliagpu"
21+
cuda: "*"
22+
env:
23+
DATADEPS_ALWAYS_ACCEPT: true
24+
if: build.message !~ /\[skip docs\]/ && !build.pull_request.draft
25+
timeout_in_minutes: 1000
26+
27+
env:
28+
SECRET_CODECOV_TOKEN: "vn5M+4wSwUFje6fl6UB/Q/rTmLHu3OlCCMgoPOXPQHYpLZTLz2hOHsV44MadAnxw8MsNVxLKZlXBKqP3IydU9gUfV7QUBtnvbUmIvgUHbr+r0bVaIVVhw6cnd0s8/b+561nU483eRJd35bjYDOlO+V5eDxkbdh/0bzLefXNXy5+ALxsBYzsp75Sx/9nuREfRqWwU6S45mne2ZlwCDpZlFvBDXQ2ICKYXpA45MpxhW9RuqfpQdi6sSR6I/HdHkV2cuJO99dqqh8xfUy6vWPC/+HUVrn9ETsrXtayX1MX3McKj869htGICpR8vqd311HTONYVprH2AN1bJqr5MOIZ8Xg==;U2FsdGVkX1+W55pTI7zq+NwYrbK6Cgqe+Gp8wMCmXY+W10aXTB0bS6zshiDYSQ1Y3piT91xFyNhS+9AsajY0yQ=="
29+
SECRET_DOCUMENTER_KEY: "hnN1QYai+CvEgtp2M9S9ZxK4xBthWAhx8snqR7MIKk3cqG7qHMsTd49KMSYWLXEsrsMbI2N4OLjNE3SgVZwyfuHop6x3ZqwVWD0OtldK6QvL/c+lmpzYMkIRCqKfXOKm9cQTWrr/eKzEh0uEUEdzCH7vrqLP7ywf5F9BMJbRmHW2f4cE1RKiG5cWVtgyO2OyRg7CxYrWgg0eKba+Vh8dLj1xL1Gs1XpLr7TDYRfSo7VMKYdkvB7Os7hmDN04YwzRFvGQBlroBaiC7KgbI5kZcL33R/whCHUYneSsXozGVyJGGGUbdxCaSslJ3p3Vb3vpnn+eAn0GyriWF/TaKIUlpQ==;U2FsdGVkX1/7Mqz8JnY+1EpErS6yTG4GfBJTNI+3WLtOuhyr+6SWZgIBCoitfIVSqvAws+RXVY09rG5Vsnvperr/lAzuBINsfOdIkSnS4HjfRHrKXGWzwzdTawPw1tn0wJa9h4RNahoBQk7Qs3clEb150AlomRJ//UAOW6UgD/nvr8TASjc5aHPkZodBogtMn93ti7hVRlUCNZAk58zpyXEsJorX9RT4g473y1aJp+CMd+Momo7/eljjwq3uCFyWpBnIwmipal4myUszDnSHuEQHcgh74ghJ1LkBjKcLUrRj06wKM/PfF23/P7h/KZCeryvghPSczKCJHx3wjH10DUuFK4maX9Yli/vg0SBq7/crWxNPXS5GmHYg/NBD2pgbe9qFvSCX4xJvkDl2ixHkNu5wQCCCTixtLfv00/XKWT5nh+PLSapQkY7ouesLFk8adpLAESPlX4pjSbljy0BrcHeFXFclFerq+Ia4StCLehWI4k2o/gfUEx7cnNN5ZKQpcYtQVOKRQriq8ri6dz+znQTmWqIM0xf+ZfWHy1OA3UXi269b0jdIlRQNtAjWhah9WWBaCiPY63/bgW54nIh1TVS1YaHzx2elWtFzHNpeTxIoibmltUivCE9r9Mk+YpfYUHJVUP9DwaLXcCiMZoLPNU4XWAzTejkpKJX/TdVFudLm2WNsCWi7nzjTfWSwlUOjrccLy9JVAx186gMlFhf5rHWeNFydBgEHBUig0FWkASwSPyv5iwD63NoL8UsTCRTNKMDtzn1XuCNxbmxgsbNO5bl+mcCDB/ZJuXexw3Ek04Ur4eZNyhf5qcn2cUzm76Jqi2vHSnY8JS9U9sSCdOd1gHFhxTOmO/OWh4jP5GKtWSWH23T2UOTS95daw1z2+GwLLP3R8chnFUCWvnJ+SUAHKoHYvbiT34gOyhcbnmoYTBCuvy4eghq1pVALUooPHOHzq2WsHAI07y6dUfbq55xaQZYiT++e+UisIID4mamITC5IFAQhkXPSAbVo56VrRCNQ+WnxLhzB6HNceAUdJCvH0IKcC9NqvfhCNioQ0EktfNla4G1BydJwgqexu8Tm4f5dl9e9V65hTn05yRYY2QN4hPCsw9Fr2NOrGfudCp5PEvA9TNuaobXvR5lSHPLL19ImbkCR1FGndrpwzWHR/0xjyS4lfdo3cW6SGVK33WZzOouM2ShGZc5Ma4yJdcPr1g0RsnWMGVX0A2gzJD9EL9RtNLgsDERnRLkEJgvhgVGXI+nmXadWzmn+plWgGEjSCLeZH1a4msFmcco3NFj0c6i5g+6cG09yILwhUTX+iegwm71r71pbPsMBk/hz/0ntVdklqFfCUZOC1n+MGJJmFfHmHAYHXSX5RmJhmvZG59g+oBEKxCh4VaIsSjLbopR4/TxS/uHdq3JJMOHXpgTQDack0qz3iEZFt+tKvKmah15PijdCv/Y8SHlN+txgSpW8+bQpxaMFq2li+/vApT18hBLjgGoMRv+I4ZmEmLUN1C+BzkCBTtErn4REQn4aKQmqlKC0jhQ5jatR5/Olpa1msiNjVyX0l4RBZFmYlpRY0KxKuVyguQtBF2qaxP2Lgrf2lp7tI7bK1z2MMBOxRTKHYs/Jt77mL/oeOOd2263EuZ2ku6K/iaaEY7jOCN7Du60UYVh0ePPorCTICHJ/arKLQ/uhh/d6Lo71bkj72B7CG75KGjZOg3+FRsFI0BLVjqvEJasR3apRCeipt3iMqmWc67wDp6DMXGCb/LlZh+oSfXtTY1VuMK4L4qcPl15Lf9H2xIt7N1YR6riq5R2EAokqQtazfxMfIk8+R25674O2Jksy0HnkIQ/H1/r2vaQYFzB0KVl671RdzHVYpfSMQQUcQtbzJHHsw5OIAPRwGd2RxFwW3GieKPYNS1u01TW+7z7Tq80qCOqewf+L+4ckbgYiiDDux+lzEK+PowtyOnQR1sjw7PE+BCPqF4Uc7b65EItFacseBD3mNVRXFUSG29rehgU4OrBjS7uUjn43QoMfonVCRhiLxydrfgGTdXW0g3qg3NYRsNeX3kP7b8MrzJ1zhPMZCca0YPbLeZ6Ab8lSc/GeJgrJ1m6ollE5yC2YisJMvNyJU525A7pReSk7u77nd1HTcddvngn8glJjcPn5/2aDuPuP4ZSEt67QLUaxVXce83XMY3FqiDngPhbJYSwGb9dlDizu9ph6a/SgZQyVpMBVL76RK9xropTUFDfDf5SOjukv5VWC2klIh8Fc4v9ur6otP5kx0MYV4ndlCoU6CyfEaQIQKXyZIIOWEPxyBvy8GnOSyh+5E82rJ50+PHkU9+iEJLBaZENnHlDJA30ANHSLF33t3CAxlqOvhFC5hx+BWCbK6R92mITBJOmo58+Fcz1e7u6SRLhOrp5rO+t3z6eB7SuVr0j16Tn4CJNGiYVGmhDjA1NfXoI7uEUkLVGuhra0eFzUCH8c+ot1pamb35E0knzaXy7azX5nuRXgt/491L0f9uc20ZqsrP63JtRQO574sbHxM6SnBWcIZSEEhAGRf4busL7DqjfHSFVx8TJH4FO0fCteCveJcxDoc94Wj7J+u5rUQBsRLDQ24XPjSgflqfhyU0lNKXuDWUSJPzqVuUM3fnYCsCdSd4/Y23t5k5dL68wOj3tBEunpGKltzG4zXjopay0VSkpty0EkaQbejvgXWnZv2G/SZlH/dW0fqp4a/M824kuzh9liF34fcV1uReoc+SKXs7iLYHgDXXajWy40HqvQR1NxtD17eIBvfhsAeKM3WR6YqudVkl4DYXU2hGM7L7znex2JAVs1thOtzBD6DbTKtlEciTs9O2B++22XnSmrusweKE68/8P0FN7DKA2oNQCig9J7BMT3pWA2jtoeq5sZU3vgvzlCLELefO4+YUWHY436znIp773oE8+IDdbxwLqKQfRI7zQAUGXNeUNEzAIla51zSARIFocslZfmF8ugx3NDBNqd6cjU+gk4qVAmUZW/Sw7SM+aPbC6hJPrDCfsMYhG7CNO5DZNBSO6Hhh0GSfl/zuB4Lzzg6xJuCfupg/A9DaUMfaIacYcsP7TjqVf07SbiPIXhlrFvuv5d43tkGBQdAQ/dNtpm2R5//u5Fi/mOlELiJBV5aLI2wZn2pGJRuOa6tb8GHSW3eLmv7RlrvBUEXutrfnevzUFu73Ue6jArVLhNQBhq64GhDEMQQvXH8QM/C0aTL8/IhFx80+kESUdnNa1OW7bNrBhOxTJJ4ho2JvenqVP44UQjJKOQGTjD5LSAqy9mkzK73lzRBVTNhpI2MTbKEnfXxN0wCce7cdpzJAUOTYF4fa4A9kkL3dqq/nhvkfzFY7lJbGI0J6hwMa9lWPwhB+wBU5hZu7s9eLDsfMoPE3vti8p8e91SgnPs+QtzPoe72GfBbJe8QQNnGeGW1aI2+qF82yarez1hTFaUPX1dgwSAKJS3/RSy1pMm1EPcM50P5At2ervJRJM5VCKrVJ8rvcFKuzfQfcZ/sqCW4l3E8ojJ7+7MAu4Ilfg0C6ZVY2S+rH52jIJtpTqbtGDaC0Zp7QTEbD5Eu1GkIxUmjpg23VGrNi4SnV+3ZbzOue8qTTstQdwzD2gDgb+zXUSkeGgZYYyEJQpGSjEJOPVi6CByqUtdK+Lfd19WRgSFJMmDfupQFUk7D19/d4K7/dj0NcuH+DQBfPvZl4cJkGqHeBXlK8uPN1A+q1S8D2QA7m4BzHAHz0Sy99KDq4spuFlm5t9oYAielqduZHieY4BugcQ1KcwR6pmc2ayEZvVRwfTNnxyodbOQCg4yh/pV4kX40G+tGnwtimaIVzyeIaZhL+iQ0M1+4KJCbKT2NGnk82hUKyrnBYSTPAG6vjRkg3P6KZj6tj4ABpxnymkZnMrUq2gpx/OlF0nl2CV0iCvj0RqZ5kX3NDVDtW8XmXh1sTATEGltfEclYoKMA/c47adE+rj16CiPwInnI/+JReNSjI14UoNn+27n8yQfxb3LbYkDSYkZcV7aysKPLQunMO0f0n6WUc2nprUIwQSO1Ccdhnio+YJs9JdvICN6JDpXfXM3eM6YGoSFv+XUownwfCW+fc1LMEXb4qjlUv7TIZKgXN5cEokKC2ZPpFJxF1TB2B7M3ib3a0orWuL8RsvexU7fpbUWo0+QI0CmiTqkmtsq1QeC7OCUU6Ocmz0iKUcYzcuLc9gfRyxXE8p7GqamK+okZWA0CJcMGvg0UXZcIaaHvUa9FuYLnjEOwvz1V92GGL0epIR9Vm4N6Aa1HzppbDuyY+donj0KBQkDxhix4+sF79LzdAJq1cvuxWW3km8f9CXYRdJ3oOY2nqwQXsi08K+yE+Zv2YVPd6Z1tJfk7Y+KiyhSIsKxDqIY1SaWwgRRYpi8c52BjTtnaIbYjFI="

.buildkite/pipeline.yml

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,22 @@ steps:
1919
agents:
2020
queue: "juliagpu"
2121

22+
- path:
23+
- "src/"
24+
- "ext/"
25+
- "test/"
26+
- "docs/"
27+
- "Project.toml"
28+
- ".buildkite/"
29+
config:
30+
command: "buildkite-agent pipeline upload .buildkite/documentation.yml"
31+
agents:
32+
queue: "juliagpu"
33+
2234
- label: "Triggering Pipelines (Main Branch / Tag)"
2335
if: build.branch == "main" || build.tag != null
2436
agents:
2537
queue: "juliagpu"
26-
command: "buildkite-agent pipeline upload .buildkite/testing.yml"
38+
command: |
39+
buildkite-agent pipeline upload .buildkite/testing.yml
40+
buildkite-agent pipeline upload .buildkite/documentation.yml

.github/workflows/CI.yml

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ jobs:
2828
fail-fast: false
2929
matrix:
3030
version:
31-
- "1"
31+
- "1.10"
3232
os:
3333
- ubuntu-latest
3434
- macos-latest
@@ -56,7 +56,7 @@ jobs:
5656
files: lcov.info
5757
token: ${{ secrets.CODECOV_TOKEN }}
5858
verbose: true
59-
fail_ci_if_error: true
59+
fail_ci_if_error: false
6060

6161
downgrade:
6262
if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && github.base_ref == github.event.repository.default_branch }}
@@ -65,7 +65,7 @@ jobs:
6565
strategy:
6666
fail-fast: false
6767
matrix:
68-
version: ["1"]
68+
version: ["1.10"]
6969
steps:
7070
- uses: actions/checkout@v4
7171
- uses: julia-actions/setup-julia@v2
@@ -80,7 +80,7 @@ jobs:
8080
files: lcov.info
8181
token: ${{ secrets.CODECOV_TOKEN }}
8282
verbose: true
83-
fail_ci_if_error: true
83+
fail_ci_if_error: false
8484

8585
invalidations:
8686
# Only run on PRs to the default branch.
@@ -113,5 +113,3 @@ jobs:
113113

114114
env:
115115
BACKEND_GROUP: "CPU"
116-
RETESTITEMS_NWORKERS: 4
117-
RETESTITEMS_NWORKER_THREADS: 2

.github/workflows/Documentation.yml

Lines changed: 0 additions & 37 deletions
This file was deleted.

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,6 @@ Manifest-v*.toml
33
.vscode
44
wip
55
examples
6-
docs/build
6+
docs/build
7+
8+
.CondaPkg

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NeuralOperators"
22
uuid = "ea5c82af-86e5-48da-8ee1-382d6ad7af4b"
33
authors = ["Avik Pal <[email protected]>"]
4-
version = "0.5.0"
4+
version = "0.5.1"
55

66
[deps]
77
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
@@ -11,6 +11,7 @@ FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
1111
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
1212
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
1313
LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11"
14+
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
1415
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
1516
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1617
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
@@ -24,6 +25,7 @@ FFTW = "1.8"
2425
Lux = "1"
2526
LuxCore = "1"
2627
LuxLib = "1.2"
28+
MLDataDevices = "1.2.0"
2729
NNlib = "0.9.21"
2830
Random = "1.10"
2931
Static = "1.1.1"

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
[![Join the chat at https://julialang.zulipchat.com #sciml-bridged](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/279055-sciml-bridged)
44
[![Global Docs](https://img.shields.io/badge/docs-SciML-blue.svg)](https://docs.sciml.ai/NeuralOperators/stable/)
55

6-
[![codecov](https://codecov.io/gh/SciML/NeuralOperators.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/SciML/NeuralOperators.jl)
6+
[![codecov](https://codecov.io/gh/SciML/NeuralOperators.jl/branch/main/graph/badge.svg?token=wTS4cxrvB1)](https://codecov.io/gh/SciML/NeuralOperators.jl)
77
[![Build Status](https://github.com/SciML/NeuralOperators.jl/workflows/CI/badge.svg)](https://github.com/SciML/NeuralOperators.jl/actions?query=workflow%3ACI)
88
[![Build status](https://badge.buildkite.com/dd09599b08f61de1b5c7960aacd5390554c53e3b54f1407ba1.svg?branch=main)](https://buildkite.com/julialang/neuraloperators-dot-jl)
99

docs/CondaPkg.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
[deps]
2+
gdown = ""

docs/Project.toml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,29 @@
11
[deps]
22
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
3+
CondaPkg = "992eb4ea-22a4-4c89-a5bb-47a3300528ab"
4+
DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe"
35
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
46
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
7+
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
8+
MAT = "23992714-dd62-5051-b70f-ba57cb901cac"
9+
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
510
NeuralOperators = "ea5c82af-86e5-48da-8ee1-382d6ad7af4b"
611
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
12+
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
13+
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
714
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
815

916
[compat]
1017
CairoMakie = "0.12.11"
18+
CondaPkg = "0.2.23"
19+
DataDeps = "0.7.13"
1120
Documenter = "1.7.0"
1221
Lux = "1"
22+
LuxCUDA = "0.3.3"
23+
MAT = "0.10.7"
24+
MLUtils = "0.4.4"
1325
NeuralOperators = "0.5"
1426
Optimisers = "0.3.3"
27+
Printf = "1.10"
28+
PythonCall = "0.9.23"
1529
Zygote = "0.6.71"

docs/pages.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
pages = [
22
"NeuralOperators.jl" => "index.md",
3+
"Pre-built Models" => [
4+
"FNO" => "models/fno.md",
5+
"DeepONet" => "models/deeponet.md",
6+
"NOMAD" => "models/nomad.md"
7+
],
38
"Tutorials" => [
4-
"FNO" => "tutorials/fno.md",
5-
"DeepONet" => "tutorials/deeponet.md",
6-
"NOMAD" => "tutorials/nomad.md"
9+
"Burgers Equation" => "tutorials/burgers.md"
710
],
811
"API Reference" => "api.md"
912
]

docs/src/assets/Project.toml

Lines changed: 0 additions & 6 deletions
This file was deleted.

docs/src/index.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@ Pkg.add("NeuralOperators")
1414

1515
Currently provided operator architectures are :
1616

17-
- [Fourier Neural Operators (FNOs)](tutorials/fno.md)
18-
- [DeepONets](tutorials/deeponet.md)
19-
- [Nonlinear Manifold Decoders for Operator Learning (NOMADs)](tutorials/nomad.md)
17+
- [Fourier Neural Operators (FNOs)](models/fno.md)
18+
- [DeepONets](models/deeponet.md)
19+
- [Nonlinear Manifold Decoders for Operator Learning (NOMADs)](models/nomad.md)
2020

2121
## Reproducibility
2222

File renamed without changes.
File renamed without changes.
File renamed without changes.

docs/src/tutorials/burgers.md

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
# Burgers Equation using DeepONet
2+
3+
## Data Loading
4+
5+
```@example burgers
6+
using DataDeps, MAT, MLUtils
7+
using PythonCall, CondaPkg # For `gdown`
8+
using Printf
9+
10+
const gdown = pyimport("gdown")
11+
12+
register(
13+
DataDep(
14+
"Burgers",
15+
"""
16+
Burgers' equation dataset from
17+
[fourier_neural_operator](https://github.com/zongyi-li/fourier_neural_operator)
18+
19+
mapping between initial conditions to the solutions at the last point of time \
20+
evolution in some function space.
21+
22+
u(x,0) -> u(x, time_end):
23+
24+
* `a`: initial conditions u(x,0)
25+
* `u`: solutions u(x,t_end)
26+
""",
27+
"https://drive.google.com/uc?id=16a8od4vidbiNR3WtaBPCSZ0T3moxjhYe",
28+
"9cbbe5070556c777b1ba3bacd49da5c36ea8ed138ba51b6ee76a24b971066ecd";
29+
fetch_method=(url, local_dir) -> begin
30+
pyconvert(String, gdown.download(url, joinpath(local_dir, "Burgers_R10.zip")))
31+
end,
32+
post_fetch_method=unpack
33+
)
34+
)
35+
36+
filepath = joinpath(datadep"Burgers", "burgers_data_R10.mat")
37+
38+
const N = 2048
39+
const Δsamples = 2^3
40+
const grid_size = div(2^13, Δsamples)
41+
const T = Float32
42+
43+
file = matopen(filepath)
44+
x_data = reshape(T.(collect(read(file, "a")[1:N, 1:Δsamples:end])), N, :, 1)
45+
y_data = reshape(T.(collect(read(file, "u")[1:N, 1:Δsamples:end])), N, :, 1)
46+
close(file)
47+
48+
x_data = permutedims(x_data, (2, 1, 3))
49+
grid = reshape(T.(collect(range(0, 1; length=grid_size)')), :, grid_size, 1)
50+
```
51+
52+
## Model
53+
54+
```@example burgers
55+
using Lux, NeuralOperators, Optimisers, Zygote, Random
56+
using LuxCUDA
57+
58+
const cdev = cpu_device()
59+
const gdev = gpu_device()
60+
61+
deeponet = DeepONet(;
62+
branch=(size(x_data, 1), ntuple(Returns(32), 5)...),
63+
trunk=(size(grid, 1), ntuple(Returns(32), 5)...),
64+
branch_activation=tanh,
65+
trunk_activation=tanh
66+
)
67+
ps, st = Lux.setup(Random.default_rng(), deeponet) |> gdev;
68+
```
69+
70+
## Training
71+
72+
```@example burgers
73+
x_data_dev = x_data |> gdev
74+
y_data_dev = y_data |> gdev
75+
grid_dev = grid |> gdev
76+
77+
function loss_function(model, ps, st, ((v, y), u))
78+
û, stₙ = model((v, y), ps, st)
79+
return MAELoss()(û, u), stₙ, (;)
80+
end
81+
82+
function train_model!(model, ps, st, data; epochs=5000)
83+
train_state = Training.TrainState(model, ps, st, Adam(0.0001f0))
84+
85+
for epoch in 1:epochs
86+
_, loss, _, train_state = Training.single_train_step!(
87+
AutoZygote(), loss_function, data, train_state)
88+
89+
if epoch % 25 == 1 || epoch == epochs
90+
@printf("Epoch %d: loss = %.6e\n", epoch, loss)
91+
end
92+
end
93+
94+
return train_state.parameters, train_state.states
95+
end
96+
97+
ps_trained, st_trained = train_model!(
98+
deeponet, ps, st, ((x_data_dev, grid_dev), y_data_dev))
99+
```
100+
101+
## Plotting
102+
103+
```@example burgers
104+
using CairoMakie
105+
106+
pred = first(deeponet((x_data_dev, grid_dev), ps_trained, st_trained)) |> cdev
107+
108+
begin
109+
fig = Figure(; size=(1024, 1024))
110+
111+
axs = [Axis(fig[i, j]) for i in 1:4, j in 1:4]
112+
for i in 1:4, j in 1:4
113+
idx = i + (j - 1) * 4
114+
ax = axs[i, j]
115+
l1 = lines!(ax, vec(grid), pred[idx, :, 1])
116+
l2 = lines!(ax, vec(grid), y_data[idx, :, 1])
117+
118+
i == 4 && (ax.xlabel = "x")
119+
j == 1 && (ax.ylabel = "u(x)")
120+
121+
if i == 1 && j == 1
122+
axislegend(ax, [l1, l2], ["Predictions", "Ground Truth"])
123+
end
124+
end
125+
linkaxes!(axs...)
126+
127+
fig[0, :] = Label(fig, "Burgers Equation using DeepONet"; tellwidth=false, font=:bold)
128+
129+
fig
130+
end
131+
```

src/NeuralOperators.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module NeuralOperators
22

33
using ArgCheck: @argcheck
4-
using ChainRulesCore: @non_differentiable
4+
using ChainRulesCore: ChainRulesCore, NoTangent, @non_differentiable
55
using ConcreteStructs: @concrete
66
using FFTW: FFTW, irfft, rfft
77
using Random: Random, AbstractRNG
@@ -10,9 +10,11 @@ using Static: StaticBool, False, True, known, static
1010
using Lux
1111
using LuxCore: LuxCore, AbstractLuxLayer, AbstractLuxContainerLayer, AbstractLuxWrapperLayer
1212
using LuxLib: batched_matmul
13-
using NNlib: NNlib, batched_adjoint
13+
using MLDataDevices: AbstractDevice, AbstractGPUDevice
14+
using NNlib: NNlib
1415

1516
const BoolLike = Union{Bool, StaticBool, Val{true}, Val{false}}
17+
const CRC = ChainRulesCore
1618

1719
include("utils.jl")
1820

0 commit comments

Comments
 (0)