Skip to content

Commit f3c651e

Browse files
CompatHelper: bump compat for "CUDAapi" to "4.0" (#23)
* CompatHelper: bump compat for "CUDAapi" to "4.0" * Make gather support Fill * Fix gitlab-ci Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Yueh-Hua Tu <[email protected]>
1 parent af62987 commit f3c651e

File tree

6 files changed

+74
-56
lines changed

6 files changed

+74
-56
lines changed

.gitlab-ci.yml

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,27 +5,27 @@ variables:
55
include:
66
- 'https://raw.githubusercontent.com/JuliaGPU/gitlab-ci/master/templates/v6.yml'
77

8-
image: juliagpu/cuda:10.1-cudnn7-cutensor1-devel-ubuntu18.04
8+
image: nvidia/cuda:10.1-cudnn7-devel-ubuntu18.04
99

10-
test:v1.3:
10+
test:v1.4:
1111
extends:
12-
- .julia:1.3
12+
- .julia:1.4
1313
- .test
14-
variables:
15-
CI_VERSION_TAG: 'v1.3'
16-
17-
test:dev:
18-
extends:
19-
- .julia:nightly
20-
- .test
21-
allow_failure: true
2214
variables:
2315
CI_VERSION_TAG: 'v1.4'
2416

17+
# test:dev:
18+
# extends:
19+
# - .julia:nightly
20+
# - .test
21+
# allow_failure: true
22+
# variables:
23+
# CI_VERSION_TAG: 'v1.5'
24+
2525
coverage:
2626
stage: post
2727
extends:
28-
- .julia:1.2
28+
- .julia:1.4
2929
script:
3030
- julia -e 'using Pkg;
3131
Pkg.instantiate();

Manifest.toml

Lines changed: 49 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,12 @@ git-tree-sha1 = "c88cfc7f9c1f9f8633cddf0b56e86302b70f64c5"
1818
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
1919
version = "1.0.1"
2020

21+
[[ArrayLayouts]]
22+
deps = ["FillArrays", "LinearAlgebra"]
23+
git-tree-sha1 = "bc779df8d73be70e4e05a63727d3a4dfb4c52b1f"
24+
uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
25+
version = "0.1.5"
26+
2127
[[Base64]]
2228
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
2329

@@ -58,27 +64,33 @@ version = "0.6.0"
5864

5965
[[ColorTypes]]
6066
deps = ["FixedPointNumbers", "Random"]
61-
git-tree-sha1 = "7b62b728a5f3dd6ee3b23910303ccf27e82fad5e"
67+
git-tree-sha1 = "b9de8dc6106e09c79f3f776c27c62360d30e5eb8"
6268
uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
63-
version = "0.8.1"
69+
version = "0.9.1"
6470

6571
[[Colors]]
6672
deps = ["ColorTypes", "FixedPointNumbers", "InteractiveUtils", "Printf", "Reexport"]
67-
git-tree-sha1 = "c9c1845d6bf22e34738bee65c357a69f416ed5d1"
73+
git-tree-sha1 = "177d8b959d3c103a6d57574c38ee79c81059c31b"
6874
uuid = "5ae59095-9a9b-59fe-a467-6f913c188581"
69-
version = "0.9.6"
75+
version = "0.11.2"
7076

7177
[[CommonSubexpressions]]
7278
deps = ["Test"]
7379
git-tree-sha1 = "efdaf19ab11c7889334ca247ff4c9f7c322817b0"
7480
uuid = "bbf7d656-a473-5ed7-a52c-81e309532950"
7581
version = "0.2.0"
7682

83+
[[CompilerSupportLibraries_jll]]
84+
deps = ["Libdl", "Pkg"]
85+
git-tree-sha1 = "b57c5d019367c90f234a7bc7e24ff0a84971da5d"
86+
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
87+
version = "0.2.0+1"
88+
7789
[[CuArrays]]
7890
deps = ["AbstractFFTs", "Adapt", "CEnum", "CUDAapi", "CUDAdrv", "CUDAnative", "DataStructures", "GPUArrays", "Libdl", "LinearAlgebra", "MacroTools", "NNlib", "Printf", "Random", "Requires", "SparseArrays", "TimerOutputs"]
79-
git-tree-sha1 = "7c20c5a45bb245cf248f454d26966ea70255b271"
91+
git-tree-sha1 = "7fa1331a0e0cd10e43b94b280027bda45990cb63"
8092
uuid = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
81-
version = "1.7.2"
93+
version = "1.7.3"
8294

8395
[[DataAPI]]
8496
git-tree-sha1 = "674b67f344687a88310213ddfa8a2b3c76cc4252"
@@ -87,9 +99,9 @@ version = "1.1.0"
8799

88100
[[DataStructures]]
89101
deps = ["InteractiveUtils", "OrderedCollections"]
90-
git-tree-sha1 = "b7720de347734f4716d1815b00ce5664ed6bbfd4"
102+
git-tree-sha1 = "5a431d46abf2ef2a4d5d00bd0ae61f651cf854c8"
91103
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
92-
version = "0.17.9"
104+
version = "0.17.10"
93105

94106
[[Dates]]
95107
deps = ["Printf"]
@@ -107,9 +119,9 @@ version = "1.0.2"
107119

108120
[[DiffRules]]
109121
deps = ["NaNMath", "Random", "SpecialFunctions"]
110-
git-tree-sha1 = "10dca52cf6d4a62d82528262921daf63b99704a2"
122+
git-tree-sha1 = "eb0c34204c8410888844ada5359ac8b96292cfd1"
111123
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
112-
version = "1.0.0"
124+
version = "1.0.1"
113125

114126
[[Distributed]]
115127
deps = ["Random", "Serialization", "Sockets"]
@@ -123,26 +135,26 @@ version = "1.2.0"
123135

124136
[[FFTW_jll]]
125137
deps = ["Libdl", "Pkg"]
126-
git-tree-sha1 = "05674f209a6e3387dd103a945b0113eeb64b1a58"
138+
git-tree-sha1 = "ddb57f4cf125243b4aa4908c94d73a805f3cbf2c"
127139
uuid = "f5851436-0d7a-5f13-b9de-f02708fd171a"
128-
version = "3.3.9+3"
140+
version = "3.3.9+4"
129141

130142
[[FillArrays]]
131143
deps = ["LinearAlgebra", "Random", "SparseArrays"]
132-
git-tree-sha1 = "fec413d4fc547992eb62a5c544cedb6d7853c1f5"
144+
git-tree-sha1 = "85c6b57e2680fa28d5c8adc798967377646fbf66"
133145
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
134-
version = "0.8.4"
146+
version = "0.8.5"
135147

136148
[[FixedPointNumbers]]
137-
git-tree-sha1 = "d14a6fa5890ea3a7e5dcab6811114f132fec2b4b"
149+
git-tree-sha1 = "4aaea64dd0c30ad79037084f8ca2b94348e65eaa"
138150
uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
139-
version = "0.6.1"
151+
version = "0.7.1"
140152

141153
[[Flux]]
142154
deps = ["AbstractTrees", "Adapt", "CodecZlib", "Colors", "CuArrays", "DelimitedFiles", "Juno", "MacroTools", "NNlib", "Pkg", "Printf", "Random", "Reexport", "SHA", "Statistics", "StatsBase", "Test", "ZipFile", "Zygote"]
143-
git-tree-sha1 = "8134adbb0f10b0d22b22f8b4299d0d20509edc5f"
155+
git-tree-sha1 = "b5647a92b4d547f835b0eac904331a97c45d773d"
144156
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
145-
version = "0.10.1"
157+
version = "0.10.3"
146158

147159
[[ForwardDiff]]
148160
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "NaNMath", "Random", "SpecialFunctions", "StaticArrays"]
@@ -173,18 +185,19 @@ deps = ["Markdown"]
173185
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
174186

175187
[[Juno]]
176-
deps = ["Base64", "Logging", "Media", "Profile", "Test"]
177-
git-tree-sha1 = "30d94657a422d09cb97b6f86f04f750fa9c50df8"
188+
deps = ["Base64", "Logging", "Media", "Profile"]
189+
git-tree-sha1 = "e1ba2a612645b3e07c773c3a208f215745081fe6"
178190
uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
179-
version = "0.7.2"
191+
version = "0.8.1"
180192

181193
[[LLVM]]
182194
deps = ["CEnum", "Libdl", "Printf", "Unicode"]
183-
git-tree-sha1 = "1d08d7e4250f452f6cb20e4574daaebfdbee0ff7"
195+
git-tree-sha1 = "b6b86801ae2f2682e0a4889315dc76b68db2de71"
184196
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
185-
version = "1.3.3"
197+
version = "1.3.4"
186198

187199
[[LibGit2]]
200+
deps = ["Printf"]
188201
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
189202

190203
[[Libdl]]
@@ -230,20 +243,20 @@ uuid = "a63ad114-7e13-5084-954f-fe012c677804"
230243

231244
[[NNlib]]
232245
deps = ["BinaryProvider", "Libdl", "LinearAlgebra", "Requires", "Statistics"]
233-
git-tree-sha1 = "755c0bab3912ff782167e1b4b774b833f8a0e550"
246+
git-tree-sha1 = "d9f196d911f55aeaff11b11f681b135980783824"
234247
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
235-
version = "0.6.4"
248+
version = "0.6.6"
236249

237250
[[NaNMath]]
238251
git-tree-sha1 = "928b8ca9b2791081dc71a51c55347c27c618760f"
239252
uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
240253
version = "0.3.3"
241254

242255
[[OpenSpecFun_jll]]
243-
deps = ["Libdl", "Pkg"]
244-
git-tree-sha1 = "65f672edebf3f4e613ddf37db9dcbd7a407e5e90"
256+
deps = ["CompilerSupportLibraries_jll", "Libdl", "Pkg"]
257+
git-tree-sha1 = "d51c416559217d974a1113522d5919235ae67a87"
245258
uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e"
246-
version = "0.5.3+1"
259+
version = "0.5.3+3"
247260

248261
[[OrderedCollections]]
249262
deps = ["Random", "Serialization", "Test"]
@@ -252,7 +265,7 @@ uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
252265
version = "1.1.0"
253266

254267
[[Pkg]]
255-
deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Test", "UUIDs"]
268+
deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
256269
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
257270

258271
[[Printf]]
@@ -304,9 +317,9 @@ uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
304317

305318
[[SpecialFunctions]]
306319
deps = ["OpenSpecFun_jll"]
307-
git-tree-sha1 = "268052ee908b2c086cc0011f528694f02f3e2408"
320+
git-tree-sha1 = "e19b98acb182567bcb7b75bb5d9eedf3a3b5ec6c"
308321
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
309-
version = "0.9.0"
322+
version = "0.10.0"
310323

311324
[[StaticArrays]]
312325
deps = ["LinearAlgebra", "Random", "Statistics"]
@@ -320,9 +333,9 @@ uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
320333

321334
[[StatsBase]]
322335
deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics"]
323-
git-tree-sha1 = "c53e809e63fe5cf5de13632090bc3520649c9950"
336+
git-tree-sha1 = "19bfcb46245f69ff4013b3df3b977a289852c3a1"
324337
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
325-
version = "0.32.0"
338+
version = "0.32.2"
326339

327340
[[Test]]
328341
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
@@ -360,10 +373,10 @@ uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
360373
version = "1.2.11+8"
361374

362375
[[Zygote]]
363-
deps = ["DiffRules", "FFTW", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"]
364-
git-tree-sha1 = "54872ae5411c8795ed52852759796a04fb771f68"
376+
deps = ["ArrayLayouts", "DiffRules", "FFTW", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"]
377+
git-tree-sha1 = "7dc5fdb4917ac5a84e199ae654316a01cd4a278b"
365378
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
366-
version = "0.4.7"
379+
version = "0.4.9"
367380

368381
[[ZygoteRules]]
369382
deps = ["MacroTools"]

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ CUDAapi = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3"
88
CUDAnative = "be33ccc6-a3ff-5ff2-a52e-74243cff1e17"
99
CuArrays = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
1010
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
11+
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
1112
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
1213
IRTools = "7869d1d1-7146-5819-86e3-90919afe41df"
1314
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -19,10 +20,11 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1920
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
2021

2122
[compat]
22-
CUDAapi = "2.0.0, 2.1.0, 3.0.0, 3.1.0"
23+
CUDAapi = "2.0.0, 2.1.0, 3.0.0, 3.1.0, 4.0"
2324
CUDAnative = "2.8.0, 2.8.1, 2.9.0, 2.9.1, 2.10.0, 2.10.1, 2.10.2"
2425
CuArrays = "1.4.7, 1.5.0, 1.6.0, 1.7.0, 1.7.1, 1.7.2"
2526
DataStructures = "~0.17"
27+
FillArrays = "^0.8.5"
2628
Flux = "~0.10"
2729
IRTools = "~0.3"
2830
Requires = "^1.0.0"

src/GeometricFlux.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ using Flux
1010
using Flux: glorot_uniform, leakyrelu, GRUCell
1111
using Flux: @functor
1212
using ZygoteRules
13+
using FillArrays: Fill
1314

1415
export
1516

src/scatter.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,13 @@ end
4646
@adjoint function scatter_add!(ys::AbstractArray, us::AbstractArray, xs::AbstractArray)
4747
ys_ = copy(ys)
4848
scatter_add!(ys_, us, xs)
49-
ys_, Δ -> (Δ, gather(zero(Δ)+Δ, xs), nothing)
49+
ys_, Δ -> (Δ, gather(Δ, xs), nothing)
5050
end
5151

5252
@adjoint function scatter_sub!(ys::AbstractArray, us::AbstractArray, xs::AbstractArray)
5353
ys_ = copy(ys)
5454
scatter_sub!(ys_, us, xs)
55-
ys_, Δ -> (Δ, -gather(zero(Δ)+Δ, xs), nothing)
55+
ys_, Δ -> (Δ, -gather(Δ, xs), nothing)
5656
end
5757

5858
@adjoint function scatter_mul!(ys::Array{T}, us::Array{T}, xs::Array{<:IntOrTuple}) where {T<:Real}
@@ -62,7 +62,7 @@ end
6262
Δy = Δ .+ zero(ys)
6363
scatter_mul!(Δy, us, xs)
6464
rev_xs = gather_indices(xs)
65-
Δu = gather(ys, xs) .* gather(zero(Δ)+Δ, xs)
65+
Δu = gather(ys, xs) .* gather(Δ, xs)
6666
@inbounds for ind = CartesianIndices(xs)
6767
inds = filter(x -> x != ind, rev_xs[xs[ind]])
6868
for i = 1:size(us, 1)
@@ -80,7 +80,7 @@ end
8080
Δy = Δ .+ zero(ys)
8181
scatter_div!(Δy, us, xs)
8282
rev_xs = gather_indices(xs)
83-
Δu = - gather(ys, xs) .* gather(zero(Δ)+Δ, xs) ./ us.^2
83+
Δu = - gather(ys, xs) .* gather(Δ, xs) ./ us.^2
8484
@inbounds for ind = CartesianIndices(xs)
8585
inds = filter(x -> x != ind, rev_xs[xs[ind]])
8686
for i = 1:size(us, 1)
@@ -96,7 +96,7 @@ end
9696
scatter_max!(max, us, xs)
9797
max, function (Δ)
9898
Δy = (ys .== max) .* Δ
99-
Δu = (us .== gather(max, xs)) .* gather(zero(Δ)+Δ, xs)
99+
Δu = (us .== gather(max, xs)) .* gather(Δ, xs)
100100
(Δy, Δu, nothing)
101101
end
102102
end
@@ -106,7 +106,7 @@ end
106106
scatter_min!(min, us, xs)
107107
min, function (Δ)
108108
Δy = (ys .== min) .* Δ
109-
Δu = (us .== gather(min, xs)) .* gather(zero(Δ)+Δ, xs)
109+
Δu = (us .== gather(min, xs)) .* gather(Δ, xs)
110110
(Δy, Δu, nothing)
111111
end
112112
end
@@ -115,7 +115,7 @@ end
115115
ys_ = copy(ys)
116116
scatter_mean!(ys_, us, xs)
117117
ys_, function (Δ)
118-
Δu = gather(zero(Δ)+Δ, xs)
118+
Δu = gather(Δ, xs)
119119
counts = zero.(xs)
120120
@inbounds for i = 1:size(ys, 2)
121121
counts += sum(xs.==i) * (xs.==i)

src/utils.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ function gather(input::Matrix{T}, index::Array{Int}) where T
3434
return out
3535
end
3636

37+
gather(input::Fill{T,2,<:Any}, index::Array{Int}) where T = gather(Matrix(input), index)
38+
3739
function gather_indices(X::Array{T}) where T
3840
Y = DefaultDict{T,Vector{CartesianIndex}}(CartesianIndex[])
3941
@inbounds for (ind, val) = pairs(X)

0 commit comments

Comments
 (0)