|
194 | 194 | end
|
195 | 195 | end
|
196 | 196 |
|
| 197 | + @testset "GATv2Conv" begin |
| 198 | + adj1 = [1 1 0 1; |
| 199 | + 1 1 1 0; |
| 200 | + 0 1 1 1; |
| 201 | + 1 0 1 1] |
| 202 | + fg1 = FeaturedGraph(adj1) |
| 203 | + |
| 204 | + # isolated_vertex |
| 205 | + adj2 = [1 0 0 1; |
| 206 | + 0 1 0 0; |
| 207 | + 0 0 1 1; |
| 208 | + 1 0 1 1] |
| 209 | + fg2 = FeaturedGraph(adj2) |
| 210 | + |
| 211 | + X = rand(T, in_channel, N) |
| 212 | + Xt = transpose(rand(T, N, in_channel)) |
| 213 | + |
| 214 | + @testset "layer with graph" begin |
| 215 | + for heads = [1, 2], concat = [true, false], adj_gat in [adj1, adj2] |
| 216 | + fg_gat = FeaturedGraph(adj_gat) |
| 217 | + gat2 = GATv2Conv(fg_gat, in_channel=>out_channel, heads=heads, concat=concat) |
| 218 | + |
| 219 | + @test size(gat2.wi) == (out_channel * heads, in_channel) |
| 220 | + @test size(gat2.wi) == (out_channel * heads, in_channel) |
| 221 | + @test size(gat2.biasi) == (out_channel * heads,) |
| 222 | + @test size(gat2.biasj) == (out_channel * heads,) |
| 223 | + @test size(gat2.a) == (out_channel, heads) |
| 224 | + |
| 225 | + Y = gat2(X) |
| 226 | + @test size(Y) == (concat ? (out_channel*heads, N) : (out_channel, N)) |
| 227 | + |
| 228 | + # Test with transposed features |
| 229 | + Y = gat2(Xt) |
| 230 | + @test size(Y) == (concat ? (out_channel*heads, N) : (out_channel, N)) |
| 231 | + |
| 232 | + g = Zygote.gradient(() -> sum(gat2(X)), Flux.params(gat2)) |
| 233 | + @test length(g.grads) == 5 |
| 234 | + end |
| 235 | + end |
| 236 | + |
| 237 | + @testset "layer without graph" begin |
| 238 | + for heads = [1, 2], concat = [true, false], adj_gat in [adj1, adj2] |
| 239 | + fg_gat = FeaturedGraph(adj_gat, nf=X) |
| 240 | + gat2 = GATv2Conv(in_channel=>out_channel, heads=heads, concat=concat) |
| 241 | + @test size(gat2.wi) == (out_channel * heads, in_channel) |
| 242 | + @test size(gat2.wi) == (out_channel * heads, in_channel) |
| 243 | + @test size(gat2.biasi) == (out_channel * heads,) |
| 244 | + @test size(gat2.biasj) == (out_channel * heads,) |
| 245 | + @test size(gat2.a) == (out_channel, heads) |
| 246 | + |
| 247 | + fg_ = gat2(fg_gat) |
| 248 | + Y = node_feature(fg_) |
| 249 | + @test size(Y) == (concat ? (out_channel*heads, N) : (out_channel, N)) |
| 250 | + @test_throws ArgumentError gat2(X) |
| 251 | + |
| 252 | + # Test with transposed features |
| 253 | + fgt = FeaturedGraph(adj_gat, nf=Xt) |
| 254 | + fgt_ = gat2(fgt) |
| 255 | + @test size(node_feature(fgt_)) == (concat ? (out_channel*heads, N) : (out_channel, N)) |
| 256 | + |
| 257 | + g = Zygote.gradient(() -> sum(node_feature(gat2(fg_gat))), Flux.params(gat2)) |
| 258 | + @test length(g.grads) == 7 |
| 259 | + end |
| 260 | + end |
| 261 | + |
| 262 | + @testset "bias=false" begin |
| 263 | + @test length(Flux.params(GATv2Conv(2=>3))) == 5 |
| 264 | + @test length(Flux.params(GATv2Conv(2=>3, bias=false))) == 3 |
| 265 | + end |
| 266 | + end |
| 267 | + |
197 | 268 | # @testset "GatedGraphConv" begin
|
198 | 269 | # num_layers = 3
|
199 | 270 | # X = rand(T, in_channel, N)
|
|
202 | 273 | # ggc = GatedGraphConv(fg, out_channel, num_layers)
|
203 | 274 | # @test adjacency_list(ggc.fg) == [[2,4], [1,3], [2,4], [1,3]]
|
204 | 275 | # @test size(ggc.weight) == (out_channel, out_channel, num_layers)
|
205 |
| - |
206 | 276 | # Y = ggc(X)
|
207 | 277 | # @test size(Y) == (out_channel, N)
|
208 | 278 |
|
|
0 commit comments