Skip to content

Commit d1cef32

Browse files
committed
add self loops
avoid breaking down tensors in graph attention
1 parent 8bf0afd commit d1cef32

File tree

4 files changed

+32
-31
lines changed

4 files changed

+32
-31
lines changed

examples/gat.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ using Flux.Data: DataLoader
66
using GeometricFlux
77
using GeometricFlux.Datasets
88
using GraphSignals
9+
using Graphs
910
using Parameters: @with_kw
1011
using ProgressMeter: Progress, next!
1112
using Statistics
@@ -20,6 +21,7 @@ function load_data(dataset, batch_size, train_repeats=32, test_repeats=2)
2021
train_idx = 1:size(train_X, 2)
2122
test_idx = test_indices(Planetoid(), dataset)
2223

24+
add_all_self_loops!(g)
2325
fg = FeaturedGraph(g)
2426
train_data = (repeat(train_X, outer=(1,1,train_repeats)), repeat(train_y, outer=(1,1,train_repeats)))
2527
test_data = (repeat(test_X, outer=(1,1,test_repeats)), repeat(test_y, outer=(1,1,test_repeats)))
@@ -28,6 +30,13 @@ function load_data(dataset, batch_size, train_repeats=32, test_repeats=2)
2830
return train_loader, test_loader, fg, train_idx, test_idx
2931
end
3032

33+
function add_all_self_loops!(g)
34+
for i in vertices(g)
35+
add_edge!(g, i, i)
36+
end
37+
return g
38+
end
39+
3140
@with_kw mutable struct Args
3241
η = 0.01 # learning rate
3342
batch_size = 8 # batch size
@@ -70,7 +79,6 @@ function train(; kws...)
7079
# build model
7180
model = Chain(
7281
WithGraph(fg, GATConv(args.input_dim=>args.hidden_dim, heads=args.heads)),
73-
Dropout(0.6),
7482
WithGraph(fg, GATConv(args.hidden_dim*args.heads=>args.target_dim, heads=args.heads, concat=false)),
7583
) |> device
7684

src/GeometricFlux.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module GeometricFlux
33
using DelimitedFiles
44
using SparseArrays
55
using Statistics: mean
6-
using LinearAlgebra: Adjoint, norm, Transpose
6+
using LinearAlgebra
77
using Random
88
using Reexport
99

src/layers/conv.jl

Lines changed: 15 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -294,43 +294,30 @@ end
294294
Flux.trainable(l::GATConv) = (l.weight, l.bias, l.a)
295295

296296
# neighbor attention
297-
function message(gat::GATConv, Xi::AbstractMatrix, Xj::AbstractMatrix, e_ij)
298-
Xi = reshape(Xi, size(Xi)..., 1)
299-
Xj = reshape(Xj, size(Xj)..., 1)
300-
m = message(gat, Xi, Xj, nothing)
301-
return reshape(m, :)
297+
function update_batch_edge(gat::GATConv, el::NamedTuple, E, X::AbstractMatrix, u)
298+
X = reshape(X, size(X)..., 1)
299+
M = update_batch_edge(gat, el, E, X, u)
300+
return reshape(M, size(M)[1:2]...)
302301
end
303302

304-
function message(gat::GATConv, Xi::AbstractArray, Xj::AbstractArray, e_ij)
303+
function update_batch_edge(gat::GATConv, el::NamedTuple, E, X::AbstractArray, u)
304+
Xi, Xj = _gather(X, el.xs), _gather(X, el.nbrs)
305305
_, nb, bch_sz = size(Xj)
306306
heads = gat.heads
307307
Q = reshape(NNlib.batched_mul(gat.weight, Xi), :, heads, nb, bch_sz) # dims: (out, heads, nb, bch_sz)
308308
K = reshape(NNlib.batched_mul(gat.weight, Xj), :, heads, nb, bch_sz)
309309
V = reshape(NNlib.batched_mul(gat.weight, Xj), :, heads, nb, bch_sz)
310310
QK = vcat(Q, K) # dims: (2out, heads, nb, bch_sz)
311311
A = leakyrelu.(sum(QK .* gat.a, dims=1), gat.negative_slope) # dims: (1, heads, nb, bch_sz)
312-
α = Flux.softmax(A, dims=3) # dims: (1, heads, nb, bch_sz)
313-
return reshape(sum(V .* α, dims=3), :, 1, bch_sz) # dims: (out*heads, 1, bch_sz)
314-
end
315-
316-
# graph attention
317-
function update_batch_edge(gat::GATConv, el::NamedTuple, E, X::AbstractArray, u)
318-
function _message(gat, el, i, X)
319-
xs = el.xs[el.xs .== i]
320-
nbrs = el.nbrs[el.xs .== i]
321-
Xi = _gather(X, xs)
322-
Xj = _gather(X, nbrs)
323-
return message(gat, Xi, Xj, nothing)
324-
end
325-
hs = [_message(gat, el, i, X) for i in 1:el.N]
326-
return hcat(hs...) # dims: (out*heads, N, [bch_sz])
312+
A = Flux.softmax(A, dims=3) # dims: (1, heads, nb, bch_sz)
313+
A = reshape(V .* A, :, nb, bch_sz)
314+
N = incidence_matrix(el.xs, el.N)
315+
return NNlib.batched_mul(A, N) # dims: (out*heads, N, bch_sz)
327316
end
328317

329-
update_batch_edge(gat::GATConv, el::NamedTuple, E, X::AbstractArray, u) =
330-
[update_batch_edge(gat, el, X, i) for i in 1:el.N]
331-
332318
# graph attention
333-
aggregate_neighbors(gat::GATConv, el::NamedTuple, aggr, E) = aggr(E...) # dims: (out, N, heads, [bch_sz])
319+
aggregate_neighbors(gat::GATConv, el::NamedTuple, aggr, E::AbstractArray) = E # dims: (out*heads, N, [bch_sz])
320+
aggregate_neighbors(gat::GATConv, el::NamedTuple, aggr, E::AbstractMatrix) = E
334321

335322
function update(gat::GATConv, M::AbstractArray, X)
336323
M = M .+ gat.bias
@@ -342,7 +329,7 @@ function update(gat::GATConv, M::AbstractArray, X)
342329
M = gat.σ.(mean(M, dims=2))
343330
M = reshape(M, :, dims...) # dims: (out, N, [bch_sz])
344331
end
345-
return _reshape(M)
332+
return M
346333
end
347334

348335
# For variable graph
@@ -360,8 +347,7 @@ end
360347
function (l::GATConv)(el::NamedTuple, X::AbstractArray)
361348
GraphSignals.check_num_nodes(el.N, X)
362349
# TODO: should have self loops check for el
363-
Ē = update_batch_edge(l, el, nothing, X, nothing)
364-
V = update_batch_vertex(l, el, Ē, X, nothing)
350+
_, V, _ = propagate(l, el, nothing, X, nothing, hcat, nothing, nothing)
365351
return V
366352
end
367353

@@ -486,7 +472,7 @@ function (gat::GATv2Conv)(fg::AbstractFeaturedGraph)
486472
X = node_feature(fg)
487473
GraphSignals.check_num_nodes(fg, X)
488474
sg = graph(fg)
489-
@assert Zygote.ignore(() -> check_self_loops(sg)) "a vertex must have self loop (receive a message from itself)."
475+
@assert Zygote.ignore(() -> GraphSignals.has_all_self_loops(sg)) "a vertex must have self loop (receive a message from itself)."
490476
es, nbrs, xs = Zygote.ignore(() -> collect(edges(sg)))
491477
el = (N=nv(sg), E=ne(sg), es=es, nbrs=nbrs, xs=xs)
492478
Ē = update_batch_edge(gat, el, nothing, X, nothing)

src/operation.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,11 @@ aggregate(::typeof(max), X) = maximum(X, dims=2)
2222
aggregate(::typeof(min), X) = minimum(X, dims=2)
2323
aggregate(::typeof(mean), X) = mean(X, dims=2)
2424

25+
function incidence_matrix(xs::AbstractVector{T}, N) where {T}
26+
A = similar(xs, T, size(xs, 1), N)
27+
copyto!(A, Array(I(N))[Array(xs), :])
28+
return A
29+
end
30+
2531
@non_differentiable batched_index(x...)
32+
@non_differentiable incidence_matrix(x...)

0 commit comments

Comments
 (0)