
Commit 2b27fd7

add tests and refactor graph attention

1 parent 6c20dd5 commit 2b27fd7

4 files changed: 181 additions, 79 deletions

docs/src/manual/conv.md

Lines changed: 10 additions & 0 deletions
@@ -75,6 +75,16 @@ Reference: [Graph Attention Networks](https://arxiv.org/abs/1710.10903)
 
 ---
 
+## Graph Attentional Layer v2
+
+```@docs
+GATv2Conv
+```
+
+Reference: [How Attentive are Graph Attention Networks?](https://arxiv.org/abs/2105.14491)
+
+---
+
 ## Gated Graph Convolution Layer
 
 ```math
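The new manual section only lists the docstring, so here is a hypothetical usage sketch for the layer it documents (the adjacency `adj` and feature sizes are invented; it assumes GeometricFlux/GraphSignals as used elsewhere in this commit, and a graph with self loops, which the forward pass asserts):

```julia
using GeometricFlux, GraphSignals, Flux

# adjacency with a self loop on every vertex (required by the layer's assert)
adj = [1 1 0;
       1 1 1;
       0 1 1]
fg = FeaturedGraph(adj, nf=rand(Float32, 8, 3))   # 8 input features, 3 nodes

layer = GATv2Conv(8=>16, relu, heads=4)           # concat=true by default
fg2 = layer(fg)
size(node_feature(fg2))                           # (16*4, 3): heads are concatenated
```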

src/layers/conv.jl

Lines changed: 112 additions & 73 deletions
@@ -253,16 +253,16 @@ Graph attentional layer.
 
 ```jldoctest
 julia> GATConv(1024=>256, relu)
-GATConv(1024=>256, heads=1, concat=true, LeakyReLU(λ=0.2))
+GATConv(1024=>256, relu, heads=1, concat=true, LeakyReLU(λ=0.2))
 
 julia> GATConv(1024=>256, relu, heads=4)
-GATConv(1024=>1024, heads=4, concat=true, LeakyReLU(λ=0.2))
+GATConv(1024=>1024, relu, heads=4, concat=true, LeakyReLU(λ=0.2))
 
 julia> GATConv(1024=>256, relu, heads=4, concat=false)
-GATConv(1024=>1024, heads=4, concat=false, LeakyReLU(λ=0.2))
+GATConv(1024=>1024, relu, heads=4, concat=false, LeakyReLU(λ=0.2))
 
 julia> GATConv(1024=>256, relu, negative_slope=0.1f0)
-GATConv(1024=>256, heads=1, concat=true, LeakyReLU(λ=0.1))
+GATConv(1024=>256, relu, heads=1, concat=true, LeakyReLU(λ=0.1))
 ```
 
 See also [`WithGraph`](@ref) for training layer with static graph.
@@ -282,7 +282,7 @@ function GATConv(ch::Pair{Int,Int}, σ=identity; heads::Int=1, concat::Bool=true
                  negative_slope=0.2f0, init=glorot_uniform, bias::Bool=true)
     in, out = ch
     W = init(out*heads, in)
-    b = Flux.create_bias(W, bias, out, 1, heads)
+    b = Flux.create_bias(W, bias, out*heads)
     a = init(2*out, heads)
     GATConv(W, b, a, σ, negative_slope, ch, heads, concat)
 end
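For context on the bias change above, a minimal shape sketch (standalone, assuming only Flux's `create_bias`; the sizes are made up):

```julia
using Flux

out, heads, in_dim = 4, 2, 8
W = Flux.glorot_uniform(out*heads, in_dim)
b = Flux.create_bias(W, true, out*heads)   # now a flat vector of length out*heads
size(b)                                    # (8,)

M = rand(Float32, out*heads, 5)            # messages for 5 nodes
size(M .+ b)                               # (8, 5): the bias broadcasts over nodes
```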
@@ -297,22 +297,20 @@ Flux.trainable(l::GATConv) = (l.weight, l.bias, l.a)
 function message(gat::GATConv, Xi::AbstractMatrix, Xj::AbstractMatrix, e_ij)
     Xi = reshape(Xi, size(Xi)..., 1)
     Xj = reshape(Xj, size(Xj)..., 1)
-    A = message(gat, Xi, Xj, nothing)
-    return reshape(A, size(A)[1:3]...)
+    m = message(gat, Xi, Xj, nothing)
+    return reshape(m, :)
 end
 
 function message(gat::GATConv, Xi::AbstractArray, Xj::AbstractArray, e_ij)
     _, nb, bch_sz = size(Xj)
     heads = gat.heads
-    Q = reshape(NNlib.batched_mul(gat.weight, Xi), :, nb, heads*bch_sz)  # dims: (out, nb, heads*bch_sz)
-    K = reshape(NNlib.batched_mul(gat.weight, Xj), :, nb, heads*bch_sz)
-    V = reshape(NNlib.batched_mul(gat.weight, Xj), :, nb, heads*bch_sz)
-    QK = reshape(vcat(Q, K), :, nb, heads, bch_sz)  # dims: (2out, nb, heads, bch_sz)
-    QK = permutedims(QK, (1, 3, 2, 4))  # dims: (2out, heads, nb, bch_sz)
+    Q = reshape(NNlib.batched_mul(gat.weight, Xi), :, heads, nb, bch_sz)  # dims: (out, heads, nb, bch_sz)
+    K = reshape(NNlib.batched_mul(gat.weight, Xj), :, heads, nb, bch_sz)
+    V = reshape(NNlib.batched_mul(gat.weight, Xj), :, heads, nb, bch_sz)
+    QK = vcat(Q, K)  # dims: (2out, heads, nb, bch_sz)
     A = leakyrelu.(sum(QK .* gat.a, dims=1), gat.negative_slope)  # dims: (1, heads, nb, bch_sz)
-    QK = permutedims(QK, (1, 3, 2, 4))  # dims: (1, nb, heads, bch_sz)
-    α = Flux.softmax(reshape(A, nb, 1, :), dims=1)  # dims: (nb, 1, heads*bch_sz)
-    return reshape(NNlib.batched_mul(V, α), :, 1, heads, bch_sz)  # dims: (out, 1, heads, bch_sz)
+    α = Flux.softmax(A, dims=3)  # dims: (1, heads, nb, bch_sz)
+    return reshape(sum(V .* α, dims=3), :, 1, bch_sz)  # dims: (out*heads, 1, bch_sz)
 end
 
 # graph attention
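A standalone shape walk-through of the refactored multi-head attention above (not the package function; the dimensions are invented for illustration):

```julia
using NNlib, Flux

out, heads, in_dim, nb, bch_sz = 4, 2, 8, 3, 1
W  = randn(Float32, out*heads, in_dim)
a  = randn(Float32, 2out, heads)
Xi = randn(Float32, in_dim, nb, bch_sz)   # receiver features, repeated once per neighbor
Xj = randn(Float32, in_dim, nb, bch_sz)   # neighbor features

Q = reshape(NNlib.batched_mul(W, Xi), :, heads, nb, bch_sz)   # (out, heads, nb, bch_sz)
K = reshape(NNlib.batched_mul(W, Xj), :, heads, nb, bch_sz)
V = reshape(NNlib.batched_mul(W, Xj), :, heads, nb, bch_sz)
A = leakyrelu.(sum(vcat(Q, K) .* a, dims=1), 0.2f0)           # (1, heads, nb, bch_sz)
α = Flux.softmax(A, dims=3)                                   # normalize over neighbors
m = reshape(sum(V .* α, dims=3), :, 1, bch_sz)
size(m)                                                       # (out*heads, 1, bch_sz) = (8, 1, 1)
```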
@@ -325,7 +323,7 @@ function update_batch_edge(gat::GATConv, el::NamedTuple, E, X::AbstractArray, u)
         return message(gat, Xi, Xj, nothing)
     end
     hs = [_message(gat, el, i, X) for i in 1:el.N]
-    return hcat(hs...)  # dims: (out, N, heads, [bch_sz])
+    return hcat(hs...)  # dims: (out*heads, N, [bch_sz])
 end
 
 function check_self_loops(sg::SparseGraph)
@@ -337,20 +335,18 @@ function check_self_loops(sg::SparseGraph)
     return true
 end
 
-function update(gat::GATConv, M::AbstractArray, X::AbstractArray)
+function update(gat::GATConv, M::AbstractArray, X)
     M = M .+ gat.bias
-    if gat.concat
-        M = gat.σ.(M)  # dims: (out, N, heads, [bch_sz])
+    if gat.concat || gat.heads == 1
+        M = gat.σ.(M)  # dims: (out*heads, N, [bch_sz])
     else
-        M = gat.σ.(mean(M, dims=3))
-        M = _reshape(M)  # dims: (out, N, [bch_sz])
+        M = reshape(M, :, gat.heads, size(M)[2:end]...)
+        M = gat.σ.(mean(M, dims=2))
+        M = reshape(M, :, size(M)[2:end]...)  # dims: (out, N, [bch_sz])
     end
     return M
 end
 
-_reshape(M::AbstractArray{<:Real,3}) = reshape(M, size(M)[[1,2]]...)
-_reshape(M::AbstractArray{<:Real,4}) = reshape(M, size(M)[[1,2,4]]...)
-
 # For variable graph
 function (l::GATConv)(fg::AbstractFeaturedGraph)
     X = node_feature(fg)
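And the concat-versus-average branch in `update`, as a standalone shape sketch (sizes invented; when `concat=false` the heads are averaged back down to `out` features):

```julia
using Statistics

out, heads, N = 4, 2, 5
M = rand(Float32, out*heads, N)                 # concatenated heads: (out*heads, N)

M3  = reshape(M, :, heads, size(M)[2:end]...)   # (out, heads, N)
avg = dropdims(mean(M3, dims=2), dims=2)        # (out, N): average over heads
size(avg)                                       # (4, 5)
```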
@@ -365,7 +361,7 @@ end
 
 # For static graph
 function (l::GATConv)(el::NamedTuple, X::AbstractArray)
-    GraphSignals.check_num_nodes(el.N, size(X, 2))
+    GraphSignals.check_num_nodes(el.N, X)
     # TODO: should have self loops check for el
     Ē = update_batch_edge(l, el, nothing, X, nothing)
     V = update_batch_vertex(l, el, Ē, X, nothing)
@@ -376,6 +372,7 @@ function Base.show(io::IO, l::GATConv)
     in_channel = size(l.weight, ndims(l.weight))
     out_channel = size(l.weight, ndims(l.weight)-1)
     print(io, "GATConv(", in_channel, "=>", out_channel)
+    l.σ == identity || print(io, ", ", l.σ)
     print(io, ", heads=", l.heads)
     print(io, ", concat=", l.concat)
     print(io, ", LeakyReLU(λ=", l.negative_slope)
@@ -384,99 +381,141 @@ end
 
 
 """
-    GATv2Conv([fg,] in => out;
-              heads=1,
-              concat=true,
-              init=glorot_uniform
-              negative_slope=0.2)
+    GATv2Conv(in => out, σ=identity; heads=1, concat=true,
+              init=glorot_uniform, negative_slope=0.2)
 
-GATv2 Layer as introduced in https://arxiv.org/abs/2105.14491
+Graph attentional layer v2.
 
 # Arguments
 
-- `fg`: Optionally pass a [`FeaturedGraph`](@ref).
 - `in`: The dimension of input features.
 - `out`: The dimension of output features.
+- `σ`: Activation function.
 - `heads`: Number attention heads
 - `concat`: Concatenate layer output or not. If not, layer output is averaged.
 - `negative_slope::Real`: Keyword argument, the parameter of LeakyReLU.
+
+# Examples
+
+```jldoctest
+julia> GATv2Conv(1024=>256, relu)
+GATv2Conv(1024=>256, relu, heads=1, concat=true, LeakyReLU(λ=0.2))
+
+julia> GATv2Conv(1024=>256, relu, heads=4)
+GATv2Conv(1024=>1024, relu, heads=4, concat=true, LeakyReLU(λ=0.2))
+
+julia> GATv2Conv(1024=>256, relu, heads=4, concat=false)
+GATv2Conv(1024=>1024, relu, heads=4, concat=false, LeakyReLU(λ=0.2))
+
+julia> GATv2Conv(1024=>256, relu, negative_slope=0.1f0)
+GATv2Conv(1024=>256, relu, heads=1, concat=true, LeakyReLU(λ=0.1))
+```
+
+See also [`WithGraph`](@ref) for training layer with static graph.
 """
-struct GATv2Conv{V<:AbstractFeaturedGraph, T, A<:AbstractMatrix{T}, B} <: MessagePassing
-    fg::V
+struct GATv2Conv{T, A<:AbstractMatrix{T}, B, F} <: MessagePassing
     wi::A
     wj::A
    biasi::B
    biasj::B
    a::A
+    σ::F
    negative_slope::T
    channel::Pair{Int, Int}
    heads::Int
    concat::Bool
 end
 
-function GATv2Conv(
-    fg::AbstractFeaturedGraph,
-    ch::Pair{Int,Int};
-    heads::Int=1,
-    concat::Bool=true,
-    negative_slope=0.2f0,
-    bias::Bool=true,
-    init=glorot_uniform,
-)
+function GATv2Conv(ch::Pair{Int,Int}, σ=identity; heads::Int=1, concat::Bool=true,
+                   negative_slope=0.2f0, bias::Bool=true, init=glorot_uniform)
    in, out = ch
    wi = init(out*heads, in)
    wj = init(out*heads, in)
    bi = Flux.create_bias(wi, bias, out*heads)
    bj = Flux.create_bias(wj, bias, out*heads)
    a = init(out, heads)
-    GATv2Conv(fg, wi, wj, bi, bj, a, negative_slope, ch, heads, concat)
+    GATv2Conv(wi, wj, bi, bj, a, σ, negative_slope, ch, heads, concat)
 end
 
-GATv2Conv(ch::Pair{Int,Int}; kwargs...) = GATv2Conv(NullGraph(), ch; kwargs...)
-
 @functor GATv2Conv
 
 Flux.trainable(l::GATv2Conv) = (l.wi, l.wj, l.biasi, l.biasj, l.a)
 
-function message(gat::GATv2Conv, x_i::AbstractVector, x_j::AbstractVector)
-    xi = reshape(gat.wi * x_i + gat.biasi, :, gat.heads)
-    xj = reshape(gat.wj * x_j + gat.biasj, :, gat.heads)
-    eij = gat.a' * leakyrelu.(xi + xj, gat.negative_slope)
-    vcat(eij, xj)
+function message(gat::GATv2Conv, Xi::AbstractMatrix, Xj::AbstractMatrix, e_ij)
+    Xi = reshape(Xi, size(Xi)..., 1)
+    Xj = reshape(Xj, size(Xj)..., 1)
+    m = message(gat, Xi, Xj, nothing)
+    return reshape(m, :)
 end
 
-function graph_attention(gat::GATv2Conv, i, js, X::AbstractMatrix)
-    e_ij = mapreduce(j -> GeometricFlux.message(gat, _view(X, i), _view(X, j)), hcat, js)
-    n = size(e_ij, 1)
-    αs = Flux.softmax(reshape(view(e_ij, 1, :), gat.heads, :), dims=2)
-    msgs = view(e_ij, 2:n, :) .* reshape(αs, 1, :)
-    reshape(msgs, (n-1)*gat.heads, :)
+function message(gat::GATv2Conv, Xi::AbstractArray, Xj::AbstractArray, e_ij)
+    _, nb, bch_sz = size(Xj)
+    heads = gat.heads
+    Q = reshape(NNlib.batched_mul(gat.wi, Xi) .+ gat.biasi, :, heads, nb, bch_sz)  # dims: (out, heads, nb, bch_sz)
+    K = reshape(NNlib.batched_mul(gat.wj, Xj) .+ gat.biasj, :, heads, nb, bch_sz)
+    V = reshape(NNlib.batched_mul(gat.wj, Xj) .+ gat.biasj, :, heads, nb, bch_sz)
+    QK = Q + K  # dims: (out, heads, nb, bch_sz)
+    A = leakyrelu.(sum(QK .* gat.a, dims=1), gat.negative_slope)  # dims: (1, heads, nb, bch_sz)
+    α = Flux.softmax(A, dims=3)  # dims: (1, heads, nb, bch_sz)
+    return reshape(sum(V .* α, dims=3), :, 1, bch_sz)  # dims: (out*heads, 1, bch_sz)
 end
 
-function update_batch_edge(gat::GATv2Conv, fg::AbstractFeaturedGraph, E::AbstractMatrix, X::AbstractMatrix, u)
-    @assert Zygote.ignore(() -> check_self_loops(graph(fg))) "a vertex must have self loop (receive a message from itself)."
-    nodes = Zygote.ignore(()->vertices(graph(fg)))
-    nbr = i->cpu(GraphSignals.neighbors(graph(fg), i))
-    ms = map(i -> graph_attention(gat, i, Zygote.ignore(()->nbr(i)), X), nodes)
-    M = hcat_by_sum(ms)
-    return M
+function update_batch_edge(gat::GATv2Conv, el::NamedTuple, E, X::AbstractArray, u)
+    function _message(gat, el, i, X)
+        xs = el.xs[el.xs .== i]
+        nbrs = el.nbrs[el.xs .== i]
+        Xi = _gather(X, xs)
+        Xj = _gather(X, nbrs)
+        return message(gat, Xi, Xj, nothing)
+    end
+    hs = [_message(gat, el, i, X) for i in 1:el.N]
+    return hcat(hs...)  # dims: (out*heads, N, [bch_sz])
 end
 
-function update_batch_vertex(gat::GATv2Conv, ::AbstractFeaturedGraph, M::AbstractMatrix, X::AbstractMatrix, u)
-    if !gat.concat
-        N = size(M, 2)
-        M = reshape(mean(reshape(M, :, gat.heads, N), dims=2), :, N)
+function update(gat::GATv2Conv, M::AbstractArray, X)
+    if gat.concat || gat.heads == 1
+        M = gat.σ.(M)  # dims: (out*heads, N, [bch_sz])
+    else
+        M = reshape(M, :, gat.heads, size(M)[2:end]...)
+        M = gat.σ.(mean(M, dims=2))
+        M = reshape(M, :, size(M)[2:end]...)  # dims: (out, N, [bch_sz])
     end
     return M
 end
 
-function (gat::GATv2Conv)(fg::ConcreteFeaturedGraph, X::AbstractMatrix)
+# For variable graph
+function (gat::GATv2Conv)(fg::AbstractFeaturedGraph)
+    X = node_feature(fg)
     GraphSignals.check_num_nodes(fg, X)
-    _, X, _ = propagate(gat, fg, edge_feature(fg), X, global_feature(fg), +)
-    return X
+    sg = graph(fg)
+    @assert Zygote.ignore(() -> check_self_loops(sg)) "a vertex must have self loop (receive a message from itself)."
+    es, nbrs, xs = Zygote.ignore(() -> collect(edges(sg)))
+    el = (N=nv(sg), E=ne(sg), es=es, nbrs=nbrs, xs=xs)
+    Ē = update_batch_edge(gat, el, nothing, X, nothing)
+    V = update_batch_vertex(gat, el, Ē, X, nothing)
+    return ConcreteFeaturedGraph(fg, nf=V)
+end
+
+# For static graph
+function (l::GATv2Conv)(el::NamedTuple, X::AbstractArray)
+    GraphSignals.check_num_nodes(el.N, X)
+    # TODO: should have self loops check for el
+    Ē = update_batch_edge(l, el, nothing, X, nothing)
+    V = update_batch_vertex(l, el, Ē, X, nothing)
+    return V
+end
+
+function Base.show(io::IO, l::GATv2Conv)
+    in_channel = size(l.wi, ndims(l.wi))
+    out_channel = size(l.wi, ndims(l.wi)-1)
+    print(io, "GATv2Conv(", in_channel, "=>", out_channel)
+    l.σ == identity || print(io, ", ", l.σ)
+    print(io, ", heads=", l.heads)
+    print(io, ", concat=", l.concat)
+    print(io, ", LeakyReLU(λ=", l.negative_slope)
+    print(io, "))")
 end
 
-(l::GATv2Conv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
 
 
 """
@@ -569,7 +608,7 @@ Edge convolutional layer.
 
 # Arguments
 
-- `nn`: A neural network (e.g. a Dense layer or a MLP).
+- `nn`: A neural network (e.g. a Dense layer or a MLP).
 - `aggr`: An aggregate function applied to the result of message function.
 `+`, `max` and `mean` are available.

test/cuda/conv.jl

Lines changed: 3 additions & 3 deletions
@@ -104,13 +104,13 @@
     @testset "layer without graph" begin
         gat = GATConv(in_channel=>out_channel, heads=2) |> gpu
         @test size(gat.weight) == (out_channel * heads, in_channel)
-        @test size(gat.bias) == (out_channel, 1, heads)
+        @test size(gat.bias) == (out_channel * heads,)
         @test size(gat.a) == (2*out_channel, heads)
 
         X = rand(T, in_channel, N)
         fg = FeaturedGraph(adj, nf=X) |> gpu
         fg_ = gat(fg)
-        @test size(node_feature(fg_)) == (out_channel, N, heads)
+        @test size(node_feature(fg_)) == (out_channel * heads, N)
 
         g = Zygote.gradient(() -> sum(node_feature(gat(fg))), Flux.params(gat))
         @test length(g.grads) == 5
@@ -120,7 +120,7 @@
         X = rand(T, in_channel, N, batch_size)
         gat = WithGraph(fg, GATConv(in_channel=>out_channel, heads=2)) |> gpu
         Y = gat(X |> gpu)
-        @test size(Y) == (out_channel, N, heads, batch_size)
+        @test size(Y) == (out_channel * heads, N, batch_size)
 
         g = Zygote.gradient(() -> sum(gat(X |> gpu)), Flux.params(gat))
         @test length(g.grads) == 4
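A CPU-side sketch of the shapes these updated tests now assert (hypothetical sizes; `adj` here is a small fully connected graph so every vertex has a self loop):

```julia
using GeometricFlux, GraphSignals, Flux

in_channel, out_channel, heads, N = 3, 5, 2, 4
adj = ones(Int, N, N)                        # fully connected, self loops included
X = rand(Float32, in_channel, N)

gat = GATConv(in_channel=>out_channel, heads=heads)
size(gat.bias)                               # (out_channel * heads,)

fg = FeaturedGraph(adj, nf=X)
size(node_feature(gat(fg)))                  # (out_channel * heads, N)
```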
