Skip to content

Commit 5c0f20f

Browse files
committed
add tests
Conflicts: Project.toml
1 parent d0a7227 commit 5c0f20f

File tree

2 files changed

+72
-2
lines changed

2 files changed

+72
-2
lines changed

src/layers/conv.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -446,7 +446,7 @@ end
446446
function update_batch_vertex(gat::GATv2Conv, ::AbstractFeaturedGraph, M::AbstractMatrix, X::AbstractMatrix, u)
447447
if !gat.concat
448448
N = size(M, 2)
449-
M = reshape(mean(reshape(M, :, gat.heads, N), dims=2), :, N)
449+
M = reshape(mean(reshape(M, gat.heads, :, N), dims=1), :, N)
450450
end
451451
return M
452452
end

test/layers/conv.jl

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,77 @@
194194
end
195195
end
196196

197+
@testset "GATv2Conv" begin
198+
adj1 = [1 1 0 1;
199+
1 1 1 0;
200+
0 1 1 1;
201+
1 0 1 1]
202+
fg1 = FeaturedGraph(adj1)
203+
204+
# isolated_vertex
205+
adj2 = [1 0 0 1;
206+
0 1 0 0;
207+
0 0 1 1;
208+
1 0 1 1]
209+
fg2 = FeaturedGraph(adj2)
210+
211+
X = rand(T, in_channel, N)
212+
Xt = transpose(rand(T, N, in_channel))
213+
214+
@testset "layer with graph" begin
215+
for heads = [1, 2], concat = [true, false], adj_gat in [adj1, adj2]
216+
fg_gat = FeaturedGraph(adj_gat)
217+
gat2 = GATv2Conv(fg_gat, in_channel=>out_channel, heads=heads, concat=concat)
218+
219+
@test size(gat2.wi) == (out_channel * heads, in_channel)
220+
@test size(gat2.wi) == (out_channel * heads, in_channel)
221+
@test size(gat2.biasi) == (out_channel * heads,)
222+
@test size(gat2.biasj) == (out_channel * heads,)
223+
@test size(gat2.a) == (out_channel, heads)
224+
225+
Y = gat2(X)
226+
@test size(Y) == (concat ? (out_channel*heads, N) : (out_channel, N))
227+
228+
# Test with transposed features
229+
Y = gat2(Xt)
230+
@test size(Y) == (concat ? (out_channel*heads, N) : (out_channel, N))
231+
232+
g = Zygote.gradient(() -> sum(gat2(X)), Flux.params(gat2))
233+
@test length(g.grads) == 5
234+
end
235+
end
236+
237+
@testset "layer without graph" begin
238+
for heads = [1, 2], concat = [true, false], adj_gat in [adj1, adj2]
239+
fg_gat = FeaturedGraph(adj_gat, nf=X)
240+
gat2 = GATv2Conv(in_channel=>out_channel, heads=heads, concat=concat)
241+
@test size(gat2.wi) == (out_channel * heads, in_channel)
242+
@test size(gat2.wi) == (out_channel * heads, in_channel)
243+
@test size(gat2.biasi) == (out_channel * heads,)
244+
@test size(gat2.biasj) == (out_channel * heads,)
245+
@test size(gat2.a) == (out_channel, heads)
246+
247+
fg_ = gat2(fg_gat)
248+
Y = node_feature(fg_)
249+
@test size(Y) == (concat ? (out_channel*heads, N) : (out_channel, N))
250+
@test_throws ArgumentError gat2(X)
251+
252+
# Test with transposed features
253+
fgt = FeaturedGraph(adj_gat, nf=Xt)
254+
fgt_ = gat2(fgt)
255+
@test size(node_feature(fgt_)) == (concat ? (out_channel*heads, N) : (out_channel, N))
256+
257+
g = Zygote.gradient(() -> sum(node_feature(gat2(fg_gat))), Flux.params(gat2))
258+
@test length(g.grads) == 7
259+
end
260+
end
261+
262+
@testset "bias=false" begin
263+
@test length(Flux.params(GATv2Conv(2=>3))) == 5
264+
@test length(Flux.params(GATv2Conv(2=>3, bias=false))) == 3
265+
end
266+
end
267+
197268
# @testset "GatedGraphConv" begin
198269
# num_layers = 3
199270
# X = rand(T, in_channel, N)
@@ -202,7 +273,6 @@
202273
# ggc = GatedGraphConv(fg, out_channel, num_layers)
203274
# @test adjacency_list(ggc.fg) == [[2,4], [1,3], [2,4], [1,3]]
204275
# @test size(ggc.weight) == (out_channel, out_channel, num_layers)
205-
206276
# Y = ggc(X)
207277
# @test size(Y) == (out_channel, N)
208278

0 commit comments

Comments
 (0)