Skip to content

Commit 9c84c27

Browse files
committed
feat: update to newest versions
1 parent f7c46ec commit 9c84c27

File tree

7 files changed

+18
-11
lines changed

7 files changed

+18
-11
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@ ChainRulesCore = "1.24.0"
2222
ConcreteStructs = "0.2.3"
2323
FFTW = "1.8.0"
2424
Lux = "0.5.62"
25-
LuxCore = "0.1.15"
25+
LuxCore = "0.1.21"
2626
LuxLib = "0.3.40"
27-
NNlib = "0.9.17"
27+
NNlib = "0.9.21"
2828
Random = "1.10"
2929
Reexport = "1.2.2"
3030
WeightInitializers = "1"

test/Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
44
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
55
Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d"
66
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
7+
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
78
LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
89
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
910
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
1011
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
12+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1113
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
1214
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1315
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
@@ -20,10 +22,12 @@ Documenter = "1.5.0"
2022
ExplicitImports = "1.9.0"
2123
Hwloc = "3.2.0"
2224
InteractiveUtils = "<0.0.1, 1"
25+
Lux = "0.5.62"
2326
LuxTestUtils = "1.1.2"
2427
MLDataDevices = "1.0.0"
2528
Optimisers = "0.3.3"
2629
Pkg = "1.10"
30+
Random = "1.10"
2731
Reexport = "1.2.2"
2832
ReTestItems = "1.24.0"
2933
StableRNGs = "1.0.2"

test/deeponet_tests.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@
5050
@test setup.out_size == size(pred)
5151

5252
__f = (u, y, ps) -> sum(abs2, first(deeponet((u, y), ps, st)))
53-
test_gradients(__f, u, y, ps; atol=1.0f-3, rtol=1.0f-3)
53+
test_gradients(
54+
__f, u, y, ps; atol=1.0f-3, rtol=1.0f-3, skip_backends=[AutoEnzyme()])
5455
end
5556

5657
@testset "Embedding layer mismatch" begin
@@ -62,9 +63,6 @@
6263

6364
ps, st = Lux.setup(rng, deeponet) |> dev
6465
@test_throws ArgumentError deeponet((u, y), ps, st)
65-
66-
__f = (u, y, ps) -> sum(abs2, first(deeponet((u, y), ps, st)))
67-
test_gradients(__f, u, y, ps; atol=1.0f-3, rtol=1.0f-3)
6866
end
6967
end
7068
end

test/fno_tests.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@
2929
end broken=broken
3030

3131
__f = (x, ps) -> sum(abs2, first(fno(x, ps, st)))
32-
test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3)
32+
test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3,
33+
skip_backends=[AutoEnzyme(), AutoTracker(), AutoReverseDiff()])
3334
end
3435
end
3536
end

test/layers_tests.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@
3838
end broken=broken
3939

4040
__f = (x, ps) -> sum(abs2, first(m(x, ps, st)))
41-
test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3)
41+
test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3,
42+
skip_backends=[AutoEnzyme(), AutoTracker(), AutoReverseDiff()])
4243
end
4344
end
4445
end

test/qa_tests.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ end
99
@testitem "Aqua: Quality Assurance" tags=[:qa] begin
1010
using Aqua
1111

12-
Aqua.test_all(NeuralOperators)
12+
Aqua.test_all(NeuralOperators; ambiguities=false)
13+
Aqua.test_ambiguities(NeuralOperators; recursive=false)
1314
end
1415

1516
@testitem "Explicit Imports: Quality Assurance" tags=[:qa] begin

test/shared_testsetup.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ import Reexport: @reexport
44
@reexport using Lux, Zygote, Optimisers, Random, StableRNGs, LuxTestUtils
55
using MLDataDevices
66

7+
LuxTestUtils.jet_target_modules!(["NeuralOperators", "Lux", "LuxLib"])
8+
79
const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "All"))
810

911
if BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda"
@@ -37,9 +39,9 @@ train!(args...; kwargs...) = train!(MSELoss(), AutoZygote(), args...; kwargs...)
3739
function train!(loss, backend, model, ps, st, data; epochs=10)
3840
l1 = loss(model, ps, st, first(data))
3941

40-
tstate = Lux.Experimental.TrainState(model, ps, st, Adam(0.01f0))
42+
tstate = Training.TrainState(model, ps, st, Adam(0.01f0))
4143
for _ in 1:epochs, (x, y) in data
42-
_, _, _, tstate = Lux.Experimental.single_train_step!(backend, loss, (x, y), tstate)
44+
_, _, _, tstate = Training.single_train_step!(backend, loss, (x, y), tstate)
4345
end
4446

4547
l2 = loss(model, ps, st, first(data))

0 commit comments

Comments
 (0)