Skip to content

Commit 497f45b

Browse files
authored
Merge branch 'master' into compathelper/new_version/2020-06-20-00-18-36-785-540576185
2 parents 6a86ac9 + be7b4b7 commit 497f45b

File tree

5 files changed

+51
-17
lines changed

5 files changed

+51
-17
lines changed

Manifest.toml

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@ version = "0.0.4"
2626

2727
[[ArrayLayouts]]
2828
deps = ["FillArrays", "LinearAlgebra"]
29-
git-tree-sha1 = "a504dca2ac7eda8761c8f7c1ed52427a1be75a3c"
29+
git-tree-sha1 = "89182776a99b69964e995cc2f1e37b5fc3476d56"
3030
uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
31-
version = "0.2.6"
31+
version = "0.3.4"
3232

3333
[[Base64]]
3434
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
@@ -62,6 +62,18 @@ git-tree-sha1 = "ac86db2b05fdfec96b011e25a504ffe7476e8a68"
6262
uuid = "be33ccc6-a3ff-5ff2-a52e-74243cff1e17"
6363
version = "3.1.0"
6464

65+
[[ChainRules]]
66+
deps = ["ChainRulesCore", "LinearAlgebra", "Reexport", "Requires", "Statistics"]
67+
git-tree-sha1 = "85f130f2c5ce208a5a395b550802398d2fcc5ee6"
68+
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
69+
version = "0.6.4"
70+
71+
[[ChainRulesCore]]
72+
deps = ["MuladdMacro"]
73+
git-tree-sha1 = "32e2c6e44d4fdd985b5688b5e85c1f6892cf3d15"
74+
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
75+
version = "0.8.0"
76+
6577
[[CodeTracking]]
6678
deps = ["InteractiveUtils", "UUIDs"]
6779
git-tree-sha1 = "9c173f62af93cce8af2bd3527d160b6ddd6eaf81"
@@ -191,9 +203,9 @@ version = "0.2.0"
191203

192204
[[IRTools]]
193205
deps = ["InteractiveUtils", "MacroTools", "Test"]
194-
git-tree-sha1 = "90ee39f9beaaa186e4968417ea2b8ed5673c91c0"
206+
git-tree-sha1 = "6875ae3cfcb9a50af80553d5cc825f406e8d13bc"
195207
uuid = "7869d1d1-7146-5819-86e3-90919afe41df"
196-
version = "0.3.3"
208+
version = "0.4.0"
197209

198210
[[Inflate]]
199211
git-tree-sha1 = "f5fc07d4e706b84f72d54eedcc1c13d92fb0871c"
@@ -260,6 +272,11 @@ version = "0.4.3"
260272
[[Mmap]]
261273
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
262274

275+
[[MuladdMacro]]
276+
git-tree-sha1 = "c6190f9a7fc5d9d5915ab29f2134421b12d24a68"
277+
uuid = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
278+
version = "0.2.2"
279+
263280
[[NNlib]]
264281
deps = ["BinaryProvider", "Libdl", "LinearAlgebra", "Requires", "Statistics"]
265282
git-tree-sha1 = "d9f196d911f55aeaff11b11f681b135980783824"
@@ -401,10 +418,10 @@ uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
401418
version = "1.2.11+11"
402419

403420
[[Zygote]]
404-
deps = ["AbstractFFTs", "ArrayLayouts", "DiffRules", "FillArrays", "ForwardDiff", "Future", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"]
405-
git-tree-sha1 = "707ceea58e2bd0ff3077ab13a92f8355181d3ee4"
421+
deps = ["AbstractFFTs", "ArrayLayouts", "ChainRules", "FillArrays", "ForwardDiff", "Future", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "Random", "Requires", "Statistics", "ZygoteRules"]
422+
git-tree-sha1 = "6d0f78976db6dbea9a36865efe068e6e2a5db6ed"
406423
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
407-
version = "0.4.20"
424+
version = "0.4.21"
408425

409426
[[ZygoteRules]]
410427
deps = ["MacroTools"]

Project.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ CuArrays = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
1010
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
1111
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
1212
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
13-
IRTools = "7869d1d1-7146-5819-86e3-90919afe41df"
1413
LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"
1514
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1615
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -28,7 +27,7 @@ CuArrays = "1.7.1, 1.7.2, 2.0"
2827
DataStructures = "~0.17"
2928
FillArrays = "^0.8.5"
3029
Flux = "~0.10"
31-
IRTools = "~0.3"
30+
IRTools = "~0.3, 0.4"
3231
LightGraphs = "1.3"
3332
Requires = "^1.0.0"
3433
StaticArrays = "^0.12.1"

src/cuda/pool.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,11 @@ meanpool(cluster::Array{Int}, X::CuArray{T}, c::Integer=length(Set(cluster))) wh
8585
ind = Tuple(ind)
8686
inds = filter(x -> x != ind, rev_cluster[cluster[ind...]])
8787
for i = 1:size(X, 1)
88-
∇X[i, ind...] *= mapreduce(j -> X[i, j...], *, inds; init=one(T))
88+
multiplier = one(T)
89+
for j = inds
90+
multiplier *= X[i, j...]
91+
end
92+
∇X[i, ind...] *= multiplier
8993
end
9094
end
9195
(nothing, ∇X)
@@ -101,7 +105,11 @@ end
101105
ind = Tuple(ind)
102106
inds = filter(x -> x != ind, rev_cluster[cluster[ind...]])
103107
for i = 1:size(X, 1)
104-
∇X[i, ind...] /= mapreduce(j -> X[i, j...], *, inds; init=one(T))
108+
denom = one(T)
109+
for j = inds
110+
denom *= X[i, j...]
111+
end
112+
∇X[i, ind...] /= denom
105113
end
106114
end
107115
(nothing, ∇X)

src/cuda/scatter.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,11 @@ end
122122
ind = Tuple(ind)
123123
inds = filter(x -> x != ind, rev_xs[xs[ind...]])
124124
for i = 1:size(us, 1)
125-
Δu[i, ind...] *= mapreduce(j -> us[i, j...], *, inds; init=one(T))
125+
multiplier = one(T)
126+
for j = inds
127+
multiplier *= us[i, j...]
128+
end
129+
Δu[i, ind...] *= multiplier
126130
end
127131
end
128132
(Δy, Δu, nothing)
@@ -143,7 +147,11 @@ end
143147
ind = Tuple(ind)
144148
inds = filter(x -> x != ind, rev_xs[xs[ind...]])
145149
for i = 1:size(us, 1)
146-
Δu[i, ind...] /= mapreduce(j -> us[i, j...], *, inds; init=one(T))
150+
denom = one(T)
151+
for j = inds
152+
denom *= us[i, j...]
153+
end
154+
Δu[i, ind...] /= denom
147155
end
148156
end
149157
(Δy, Δu, nothing)

test/cuda/pool.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -123,12 +123,14 @@ X = CuArray(reshape(1:24, 2, 3, 4))
123123
end
124124

125125
@testset "divpool" begin
126+
# It seems that y have to convert to CuArray to avoid error,
127+
# instead of broadcastly casting an array
126128
y = 1 ./ [1729, 4480, 27, 40, 315, 352, 55, 72, 391, 432]
127129
y = reshape(y, 2, 5)
128-
@test divpool(CuArray{Int64}(cluster), T.(X)) T.(y)
129-
@test pool(:div, CuArray{Int64}(cluster), T.(X)) T.(y)
130-
@test divpool(cluster, T.(X)) T.(y)
131-
@test pool(:div, cluster, T.(X)) T.(y)
130+
@test divpool(CuArray{Int64}(cluster), T.(X)) CuArray{T}(y)
131+
@test pool(:div, CuArray{Int64}(cluster), T.(X)) CuArray{T}(y)
132+
@test divpool(cluster, T.(X)) CuArray{T}(y)
133+
@test pool(:div, cluster, T.(X)) CuArray{T}(y)
132134
end
133135

134136
@testset "meanpool" begin

0 commit comments

Comments
 (0)