
Commit abcb235

Merge pull request #259 from abieler/gatv2
Adds GATv2 layer
2 parents 795c187 + 47eaaa6 · commit abcb235

3 files changed: +170 −2 lines changed

src/GeometricFlux.jl

Lines changed: 1 addition & 0 deletions
@@ -35,6 +35,7 @@ export
     ChebConv,
     GraphConv,
     GATConv,
+    GATv2Conv,
     GatedGraphConv,
     EdgeConv,
     GINConv,
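
With the export added, the layer is reachable without qualification once the package is loaded. A minimal sketch:

using GeometricFlux

gat2 = GATv2Conv(3=>4)  # resolves via the new export instead of GeometricFlux.GATv2Conv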

src/layers/conv.jl

Lines changed: 98 additions & 2 deletions
@@ -272,7 +272,7 @@ end
 
 function update_batch_edge(gat::GATConv, fg::AbstractFeaturedGraph, E::AbstractMatrix, X::AbstractMatrix, u)
     @assert Zygote.ignore(() -> check_self_loops(graph(fg))) "a vertex must have self loop (receive a message from itself)."
-    nodes = Zygote.ignore(()->vertices(fg))
+    nodes = Zygote.ignore(()->vertices(graph(fg)))
     nbr = i->cpu(GraphSignals.neighbors(graph(fg), i))
     ms = map(i -> graph_attention(gat, i, Zygote.ignore(()->nbr(i)), X), nodes)
     M = hcat_by_sum(ms)
@@ -292,7 +292,7 @@ function update_batch_vertex(gat::GATConv, ::AbstractFeaturedGraph, M::AbstractM
     M = M .+ gat.bias
     if !gat.concat
         N = size(M, 2)
-        M = reshape(mean(reshape(M, :, gat.heads, N), dims=2), :, N)
+        M = reshape(mean(reshape(M, gat.heads, :, N), dims=1), :, N)
     end
     return M
 end
@@ -315,6 +315,102 @@ function Base.show(io::IO, l::GATConv)
 end
 
 
+"""
+    GATv2Conv([fg,] in => out;
+              heads=1,
+              concat=true,
+              init=glorot_uniform,
+              negative_slope=0.2)
+
+GATv2 layer, as introduced in https://arxiv.org/abs/2105.14491.
+
+# Arguments
+
+- `fg`: Optionally pass a [`FeaturedGraph`](@ref).
+- `in`: The dimension of input features.
+- `out`: The dimension of output features.
+- `heads`: Number of attention heads.
+- `concat`: Concatenate the per-head outputs or not. If not, the outputs are averaged.
+- `negative_slope::Real`: Keyword argument, the parameter of LeakyReLU.
+"""
+struct GATv2Conv{V<:AbstractFeaturedGraph, T, A<:AbstractMatrix{T}, B} <: MessagePassing
+    fg::V
+    wi::A
+    wj::A
+    biasi::B
+    biasj::B
+    a::A
+    negative_slope::T
+    channel::Pair{Int, Int}
+    heads::Int
+    concat::Bool
+end
+
+function GATv2Conv(
+    fg::AbstractFeaturedGraph,
+    ch::Pair{Int,Int};
+    heads::Int=1,
+    concat::Bool=true,
+    negative_slope=0.2f0,
+    bias::Bool=true,
+    init=glorot_uniform,
+)
+    in, out = ch
+    wi = init(out*heads, in)
+    wj = init(out*heads, in)
+    bi = Flux.create_bias(wi, bias, out*heads)
+    bj = Flux.create_bias(wj, bias, out*heads)
+    a = init(out, heads)
+    GATv2Conv(fg, wi, wj, bi, bj, a, negative_slope, ch, heads, concat)
+end
+
+GATv2Conv(ch::Pair{Int,Int}; kwargs...) = GATv2Conv(NullGraph(), ch; kwargs...)
+
+@functor GATv2Conv
+
+Flux.trainable(l::GATv2Conv) = (l.wi, l.wj, l.biasi, l.biasj, l.a)
+
+function message(gat::GATv2Conv, x_i::AbstractVector, x_j::AbstractVector)
+    xi = reshape(gat.wi * x_i + gat.biasi, :, gat.heads)
+    xj = reshape(gat.wj * x_j + gat.biasj, :, gat.heads)
+    eij = gat.a' * leakyrelu.(xi + xj, gat.negative_slope)
+    vcat(eij, xj)
+end
+
+function graph_attention(gat::GATv2Conv, i, js, X::AbstractMatrix)
+    e_ij = mapreduce(j -> GeometricFlux.message(gat, _view(X, i), _view(X, j)), hcat, js)
+    n = size(e_ij, 1)
+    αs = Flux.softmax(reshape(view(e_ij, 1, :), gat.heads, :), dims=2)
+    msgs = view(e_ij, 2:n, :) .* reshape(αs, 1, :)
+    reshape(msgs, (n-1)*gat.heads, :)
+end
+
+function update_batch_edge(gat::GATv2Conv, fg::AbstractFeaturedGraph, E::AbstractMatrix, X::AbstractMatrix, u)
+    @assert Zygote.ignore(() -> check_self_loops(graph(fg))) "a vertex must have self loop (receive a message from itself)."
+    nodes = Zygote.ignore(()->vertices(graph(fg)))
+    nbr = i->cpu(GraphSignals.neighbors(graph(fg), i))
+    ms = map(i -> graph_attention(gat, i, Zygote.ignore(()->nbr(i)), X), nodes)
+    M = hcat_by_sum(ms)
+    return M
+end
+
+function update_batch_vertex(gat::GATv2Conv, ::AbstractFeaturedGraph, M::AbstractMatrix, X::AbstractMatrix, u)
+    if !gat.concat
+        N = size(M, 2)
+        M = reshape(mean(reshape(M, gat.heads, :, N), dims=1), :, N)
+    end
+    return M
+end
+
+function (gat::GATv2Conv)(fg::ConcreteFeaturedGraph, X::AbstractMatrix)
+    GraphSignals.check_num_nodes(fg, X)
+    _, X, _ = propagate(gat, fg, edge_feature(fg), X, global_feature(fg), +)
+    return X
+end
+
+(l::GATv2Conv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
+
 """
     GatedGraphConv([fg,] out, num_layers; aggr=+, init=glorot_uniform)

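Compared to GATConv, the new layer projects source and target features through separate weights (wi, wj) and applies the LeakyReLU before the dot product with a, which is what gives GATv2 its "dynamic" attention in the paper's sense. A minimal usage sketch of the layer as merged here; the graph and dimensions are illustrative choices, not part of this commit:

using GeometricFlux, GraphSignals, Flux

# Every vertex carries a self loop, since update_batch_edge asserts check_self_loops.
adj = [1 1 0;
       1 1 1;
       0 1 1]
fg = FeaturedGraph(adj)

# 4 input features per node, 2 output features per head, 2 heads.
gat2 = GATv2Conv(fg, 4=>2, heads=2, concat=true)

X = rand(Float32, 4, 3)  # in_channel × num_nodes
Y = gat2(X)              # size (out*heads, num_nodes) == (4, 3) because concat=true
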
test/layers/conv.jl

Lines changed: 71 additions & 0 deletions
@@ -227,6 +227,77 @@
     end
 end
 
+@testset "GATv2Conv" begin
+    adj1 = [1 1 0 1;
+            1 1 1 0;
+            0 1 1 1;
+            1 0 1 1]
+    fg1 = FeaturedGraph(adj1)
+
+    # isolated_vertex
+    adj2 = [1 0 0 1;
+            0 1 0 0;
+            0 0 1 1;
+            1 0 1 1]
+    fg2 = FeaturedGraph(adj2)
+
+    X = rand(T, in_channel, N)
+    Xt = transpose(rand(T, N, in_channel))
+
+    @testset "layer with graph" begin
+        for heads = [1, 2], concat = [true, false], adj_gat in [adj1, adj2]
+            fg_gat = FeaturedGraph(adj_gat)
+            gat2 = GATv2Conv(fg_gat, in_channel=>out_channel, heads=heads, concat=concat)
+
+            @test size(gat2.wi) == (out_channel * heads, in_channel)
+            @test size(gat2.wj) == (out_channel * heads, in_channel)
+            @test size(gat2.biasi) == (out_channel * heads,)
+            @test size(gat2.biasj) == (out_channel * heads,)
+            @test size(gat2.a) == (out_channel, heads)
+
+            Y = gat2(X)
+            @test size(Y) == (concat ? (out_channel*heads, N) : (out_channel, N))
+
+            # Test with transposed features
+            Y = gat2(Xt)
+            @test size(Y) == (concat ? (out_channel*heads, N) : (out_channel, N))
+
+            g = Zygote.gradient(() -> sum(gat2(X)), Flux.params(gat2))
+            @test length(g.grads) == 5
+        end
+    end
+
+    @testset "layer without graph" begin
+        for heads = [1, 2], concat = [true, false], adj_gat in [adj1, adj2]
+            fg_gat = FeaturedGraph(adj_gat, nf=X)
+            gat2 = GATv2Conv(in_channel=>out_channel, heads=heads, concat=concat)
+            @test size(gat2.wi) == (out_channel * heads, in_channel)
+            @test size(gat2.wj) == (out_channel * heads, in_channel)
+            @test size(gat2.biasi) == (out_channel * heads,)
+            @test size(gat2.biasj) == (out_channel * heads,)
+            @test size(gat2.a) == (out_channel, heads)
+
+            fg_ = gat2(fg_gat)
+            Y = node_feature(fg_)
+            @test size(Y) == (concat ? (out_channel*heads, N) : (out_channel, N))
+            @test_throws ArgumentError gat2(X)
+
+            # Test with transposed features
+            fgt = FeaturedGraph(adj_gat, nf=Xt)
+            fgt_ = gat2(fgt)
+            @test size(node_feature(fgt_)) == (concat ? (out_channel*heads, N) : (out_channel, N))
+
+            g = Zygote.gradient(() -> sum(node_feature(gat2(fg_gat))), Flux.params(gat2))
+            @test length(g.grads) == 7
+        end
+    end
+
+    @testset "bias=false" begin
+        @test length(Flux.params(GATv2Conv(2=>3))) == 5
+        @test length(Flux.params(GATv2Conv(2=>3, bias=false))) == 3
+    end
+end
+
 @testset "GatedGraphConv" begin
     num_layers = 3
     X = rand(T, in_channel, N)