
Commit b6b6717

Merge pull request #280 from FluxML/gat
Fix GAT example
2 parents 0551528 + df36e1c commit b6b6717

File tree

4 files changed: +178 -87 lines changed

examples/gat.jl
src/GeometricFlux.jl
src/layers/conv.jl
src/operation.jl


examples/gat.jl

Lines changed: 119 additions & 52 deletions
@@ -1,54 +1,121 @@
-using GeometricFlux
-using Flux
-using Flux: onehotbatch, onecold, logitcrossentropy, throttle
-using Flux: @epochs
-using JLD2
-using Statistics: mean
-using SparseArrays
-using LinearAlgebra
-using Graphs.SimpleGraphs
-using Graphs: adjacency_matrix
 using CUDA
+using Flux
+using Flux: onecold
+using Flux.Losses: logitcrossentropy
+using Flux.Data: DataLoader
+using GeometricFlux
+using GeometricFlux.Datasets
+using GraphSignals
+using Graphs
+using Parameters: @with_kw
+using ProgressMeter: Progress, next!
+using Statistics
+using Random
+
+function load_data(dataset, batch_size, train_repeats=32, test_repeats=2)
+    # (train_X, train_y) dim: (num_features, target_dim) × 2708
+    train_X, train_y = map(x -> Matrix(x), alldata(Planetoid(), dataset, padding=true))
+    # (test_X, test_y) dim: (num_features, target_dim) × 2708
+    test_X, test_y = map(x -> Matrix(x), testdata(Planetoid(), dataset, padding=true))
+    g = graphdata(Planetoid(), dataset)
+    train_idx = 1:size(train_X, 2)
+    test_idx = test_indices(Planetoid(), dataset)
+
+    add_all_self_loops!(g)
+    fg = FeaturedGraph(g)
+    train_data = (repeat(train_X, outer=(1,1,train_repeats)), repeat(train_y, outer=(1,1,train_repeats)))
+    test_data = (repeat(test_X, outer=(1,1,test_repeats)), repeat(test_y, outer=(1,1,test_repeats)))
+    train_loader = DataLoader(train_data, batchsize=batch_size, shuffle=true)
+    test_loader = DataLoader(test_data, batchsize=batch_size, shuffle=true)
+    return train_loader, test_loader, fg, train_idx, test_idx
+end
+
+function add_all_self_loops!(g)
+    for i in vertices(g)
+        add_edge!(g, i, i)
+    end
+    return g
+end
+
+@with_kw mutable struct Args
+    η = 0.01            # learning rate
+    batch_size = 8      # batch size
+    epochs = 20         # number of epochs
+    seed = 0            # random seed
+    cuda = true         # use GPU
+    heads = 8           # attention heads
+    input_dim = 1433    # input dimension
+    hidden_dim = 16     # hidden dimension
+    target_dim = 7      # target dimension
+end
+
+## Loss: cross entropy
+model_loss(model, X, y, idx) =
+    logitcrossentropy(model(X)[:,idx,:], y[:,idx,:])
+
+accuracy(model, X::AbstractArray, y::AbstractArray, idx) =
+    mean(onecold(softmax(cpu(model(X))[:,idx,:])) .== onecold(cpu(y)[:,idx,:]))
+
+accuracy(model, loader::DataLoader, device, idx) =
+    mean(accuracy(model, X |> device, y |> device, idx) for (X, y) in loader)
+
+function train(; kws...)
+    # load hyperparameters
+    args = Args(; kws...)
+    args.seed > 0 && Random.seed!(args.seed)
+
+    # GPU config
+    if args.cuda && CUDA.has_cuda()
+        device = gpu
+        @info "Training on GPU"
+    else
+        device = cpu
+        @info "Training on CPU"
+    end
+
+    # load Cora from Planetoid dataset
+    train_loader, test_loader, fg, train_idx, test_idx = load_data(:cora, args.batch_size)
+
+    # build model
+    model = Chain(
+        WithGraph(fg, GATConv(args.input_dim=>args.hidden_dim, heads=args.heads)),
+        WithGraph(fg, GATConv(args.hidden_dim*args.heads=>args.target_dim, heads=args.heads, concat=false)),
+    ) |> device
+
+    # ADAM optimizer
+    opt = ADAM(args.η)
+
+    # parameters
+    ps = Flux.params(model)
+
+    # training
+    train_steps = 0
+    @info "Start Training, total $(args.epochs) epochs"
+    for epoch = 1:args.epochs
+        @info "Epoch $(epoch)"
+        progress = Progress(length(train_loader))
+
+        for (X, y) in train_loader
+            loss, back = Flux.pullback(ps) do
+                model_loss(model, X |> device, y |> device, train_idx |> device)
+            end
+            train_acc = accuracy(model, train_loader, device, train_idx)
+            test_acc = accuracy(model, test_loader, device, test_idx)
+            grad = back(1f0)
+            Flux.Optimise.update!(opt, ps, grad)
+
+            # progress meter
+            next!(progress; showvalues=[
+                (:loss, loss),
+                (:train_accuracy, train_acc),
+                (:test_accuracy, test_acc)
+            ])
+
+            train_steps += 1
+        end
+    end
+
+    return model, args
+end
 
-@load "data/cora_features.jld2" features
-@load "data/cora_labels.jld2" labels
-@load "data/cora_graph.jld2" g
-
-num_nodes = 2708
-num_features = 1433
-heads = 8
-hidden = 8
-target_catg = 7
-epochs = 10
-
-## Preprocessing data
-train_X = Matrix{Float32}(features) |> gpu    # dim: num_features * num_nodes
-train_y = Matrix{Float32}(labels) |> gpu      # dim: target_catg * num_nodes
-A = Matrix{Int}((adjacency_matrix(g) + I) .≥ 1)
-fg = FeaturedGraph(A, :adjm)
-
-## Model
-model = Chain(GATConv(fg, num_features=>hidden, heads=heads),
-              Dropout(0.6),
-              GATConv(fg, hidden*heads=>target_catg, heads=heads, concat=false)
-             ) |> gpu
-# test model
-@show model(train_X)
-
-## Loss
-loss(x, y) = logitcrossentropy(model(x), y)
-accuracy(x, y) = mean(onecold(cpu(model(x))) .== onecold(cpu(y)))
-
-# test loss
-@show loss(train_X, train_y)
-
-# test gradient
-@show gradient(()->loss(train_X, train_y), Flux.params(model))
-
-## Training
-ps = Flux.params(model)
-train_data = Flux.Data.DataLoader((train_X, train_y), batchsize=num_nodes)
-opt = ADAM(0.01)
-evalcb() = @show(accuracy(train_X, train_y))
-
-@epochs epochs Flux.train!(loss, ps, train_data, opt, cb=throttle(evalcb, 10))
+model, args = train()
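
Because Args is declared with Parameters.@with_kw and train accepts keyword arguments (train(; kws...) forwards them to Args(; kws...)), any hyperparameter in the struct can be overridden per call. A minimal usage sketch, assuming the script above has been included; the keyword values are illustrative and not part of the commit:

# include("examples/gat.jl") has defined Args, load_data, and train
model, args = train(epochs=5, batch_size=4, cuda=false)   # override three Args fields
args.η        # 0.01; untouched fields keep their @with_kw defaults
args.epochs   # 5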

src/GeometricFlux.jl

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@ module GeometricFlux
 using DelimitedFiles
 using SparseArrays
 using Statistics: mean
-using LinearAlgebra: Adjoint, norm, Transpose
+using LinearAlgebra
 using Random
 using Reexport
 

src/layers/conv.jl

Lines changed: 16 additions & 34 deletions
@@ -294,46 +294,30 @@
 Flux.trainable(l::GATConv) = (l.weight, l.bias, l.a)
 
 # neighbor attention
-function message(gat::GATConv, Xi::AbstractMatrix, Xj::AbstractMatrix, e_ij)
-    Xi = reshape(Xi, size(Xi)..., 1)
-    Xj = reshape(Xj, size(Xj)..., 1)
-    m = message(gat, Xi, Xj, nothing)
-    return reshape(m, :)
+function update_batch_edge(gat::GATConv, el::NamedTuple, E, X::AbstractMatrix, u)
+    X = reshape(X, size(X)..., 1)
+    M = update_batch_edge(gat, el, E, X, u)
+    return reshape(M, size(M)[1:2]...)
 end
 
-function message(gat::GATConv, Xi::AbstractArray, Xj::AbstractArray, e_ij)
+function update_batch_edge(gat::GATConv, el::NamedTuple, E, X::AbstractArray, u)
+    Xi, Xj = _gather(X, el.xs), _gather(X, el.nbrs)
     _, nb, bch_sz = size(Xj)
     heads = gat.heads
    Q = reshape(NNlib.batched_mul(gat.weight, Xi), :, heads, nb, bch_sz)  # dims: (out, heads, nb, bch_sz)
     K = reshape(NNlib.batched_mul(gat.weight, Xj), :, heads, nb, bch_sz)
     V = reshape(NNlib.batched_mul(gat.weight, Xj), :, heads, nb, bch_sz)
     QK = vcat(Q, K)  # dims: (2out, heads, nb, bch_sz)
     A = leakyrelu.(sum(QK .* gat.a, dims=1), gat.negative_slope)  # dims: (1, heads, nb, bch_sz)
-    α = Flux.softmax(A, dims=3)  # dims: (1, heads, nb, bch_sz)
-    return reshape(sum(V .* α, dims=3), :, 1, bch_sz)  # dims: (out*heads, 1, bch_sz)
+    α = indexed_softmax(A, el.xs, el.N, dims=3)  # dims: (1, heads, nb, bch_sz)
+    N = incidence_matrix(el.xs, el.N)
+    Y = NNlib.batched_mul(reshape(V .* α, :, nb, bch_sz), N)  # dims: (out*heads, N, bch_sz)
+    return Y
 end
 
 # graph attention
-function update_batch_edge(gat::GATConv, el::NamedTuple, E, X::AbstractArray, u)
-    function _message(gat, el, i, X)
-        xs = el.xs[el.xs .== i]
-        nbrs = el.nbrs[el.xs .== i]
-        Xi = _gather(X, xs)
-        Xj = _gather(X, nbrs)
-        return message(gat, Xi, Xj, nothing)
-    end
-    hs = [_message(gat, el, i, X) for i in 1:el.N]
-    return hcat(hs...)  # dims: (out*heads, N, [bch_sz])
-end
-
-function check_self_loops(sg::SparseGraph)
-    for i in 1:nv(sg)
-        if !(i in collect(GraphSignals.rowvalview(sg.S, i)))
-            return false
-        end
-    end
-    return true
-end
+aggregate_neighbors(gat::GATConv, el::NamedTuple, aggr, E::AbstractArray) = E  # dims: (out*heads, N, [bch_sz])
+aggregate_neighbors(gat::GATConv, el::NamedTuple, aggr, E::AbstractMatrix) = E
 
 function update(gat::GATConv, M::AbstractArray, X)
     M = M .+ gat.bias
@@ -353,19 +337,17 @@ function (l::GATConv)(fg::AbstractFeaturedGraph)
     X = node_feature(fg)
     GraphSignals.check_num_nodes(fg, X)
     sg = graph(fg)
-    @assert Zygote.ignore(() -> check_self_loops(sg)) "a vertex must have self loop (receive a message from itself)."
+    @assert Zygote.ignore(() -> GraphSignals.has_all_self_loops(sg)) "a vertex must have self loop (receive a message from itself)."
     el = to_namedtuple(sg)
-    Ē = update_batch_edge(l, el, nothing, X, nothing)
-    V = update_batch_vertex(l, el, Ē, X, nothing)
+    _, V, _ = propagate(l, el, nothing, X, nothing, hcat, nothing, nothing)
     return ConcreteFeaturedGraph(fg, nf=V)
 end
 
 # For static graph
 function (l::GATConv)(el::NamedTuple, X::AbstractArray)
     GraphSignals.check_num_nodes(el.N, X)
     # TODO: should have self loops check for el
-    Ē = update_batch_edge(l, el, nothing, X, nothing)
-    V = update_batch_vertex(l, el, Ē, X, nothing)
+    _, V, _ = propagate(l, el, nothing, X, nothing, hcat, nothing, nothing)
     return V
 end
 
@@ -490,7 +472,7 @@ function (gat::GATv2Conv)(fg::AbstractFeaturedGraph)
     X = node_feature(fg)
     GraphSignals.check_num_nodes(fg, X)
     sg = graph(fg)
-    @assert Zygote.ignore(() -> check_self_loops(sg)) "a vertex must have self loop (receive a message from itself)."
+    @assert Zygote.ignore(() -> GraphSignals.has_all_self_loops(sg)) "a vertex must have self loop (receive a message from itself)."
     es, nbrs, xs = Zygote.ignore(() -> collect(edges(sg)))
     el = (N=nv(sg), E=ne(sg), es=es, nbrs=nbrs, xs=xs)
     Ē = update_batch_edge(gat, el, nothing, X, nothing)
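
The private check_self_loops helper is gone; both GATConv and GATv2Conv now assert GraphSignals.has_all_self_loops(sg), so every vertex must receive a message from itself. A minimal sketch, not verified against this exact revision, of preparing a graph before applying the layer, mirroring the add_all_self_loops! helper in the updated example (graph size and feature dimensions are illustrative):

using Graphs, GraphSignals, GeometricFlux

g = SimpleGraph(4)
add_edge!(g, 1, 2); add_edge!(g, 2, 3); add_edge!(g, 3, 4)
for i in vertices(g)    # same idea as add_all_self_loops! in examples/gat.jl
    add_edge!(g, i, i)
end
fg = FeaturedGraph(g)
layer = WithGraph(fg, GATConv(8=>16, heads=4))   # concat defaults to true
X = rand(Float32, 8, nv(g))
Y = layer(X)    # expected dims: (16*4, 4), one block per concatenated head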

src/operation.jl

Lines changed: 42 additions & 0 deletions
@@ -22,4 +22,46 @@ aggregate(::typeof(max), X) = maximum(X, dims=2)
 aggregate(::typeof(min), X) = minimum(X, dims=2)
 aggregate(::typeof(mean), X) = mean(X, dims=2)
 
+function incidence_matrix(xs::AbstractVector{T}, N) where {T}
+    A = similar(xs, T, size(xs, 1), N)
+    copyto!(A, Array(I(N))[Array(xs), :])
+    return A
+end
+
+function indexed_softmax(x::AbstractArray, xs, N; dims=1)
+    # a memory pre-allocation approach led to loss fluctuation without the loss dropping,
+    # so keep an eye on the model loss when optimizing this code snippet
+    as = map(1:N) do i
+        idx = ntuple(j -> (j == dims) ? (xs .== i) : Colon(), ndims(x))
+        NNlib.softmax(x[idx...]; dims)
+    end
+    return cat(as...; dims)
+end
+
+function ∇indexed_softmax(dy::AbstractArray{T}, y::AbstractArray{S}, xs, N; dims=1) where {T,S}
+    dx = if NNlib.within_grad()
+        tmp = dy .* y
+        for i in 1:N
+            idx = ntuple(j -> (j == dims) ? (xs .== i) : Colon(), ndims(y))
+            tmp[idx...] .= tmp[idx...] .- y[idx...] .* sum(tmp[idx...]; dims)
+        end
+        tmp
+    else
+        out = similar(y, promote_type(T,S))
+        out .= dy .* y
+        for i in 1:N
+            idx = ntuple(j -> (j == dims) ? (xs .== i) : Colon(), ndims(y))
+            out[idx...] .= out[idx...] .- y[idx...] .* sum(out[idx...]; dims)
+        end
+        out
+    end
+end
+
+function ChainRulesCore.rrule(::typeof(indexed_softmax), x, xs, N; dims=1)
+    y = indexed_softmax(x, xs, N; dims)
+    indexed_softmax_pullback(dy) = (NoTangent(), ∇indexed_softmax(unthunk(dy), y, xs, N; dims), NoTangent(), NoTangent())
+    return y, indexed_softmax_pullback
+end
+
 @non_differentiable batched_index(x...)
+@non_differentiable incidence_matrix(x...)
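
These two helpers replace the per-vertex comprehension that previously aggregated attention in GATConv: indexed_softmax normalizes the attention scores within each destination-vertex segment of el.xs rather than across all edges, incidence_matrix builds the one-hot edge-to-vertex map that batched_mul uses to sum the weighted messages per vertex, and the hand-written rrule keeps indexed_softmax differentiable under Zygote. A small sketch of what they compute, using module-qualified names since the helpers are internal; the sizes and values are illustrative:

using GeometricFlux

xs = [1, 1, 2, 2, 2]                      # destination vertex of each of 5 edges
scores = reshape(Float32[1, 2, 3, 4, 5], 1, 5)

α = GeometricFlux.indexed_softmax(scores, xs, 2, dims=2)
# columns 1:2 sum to 1 and columns 3:5 sum to 1: softmax per vertex, not per row

N = GeometricFlux.incidence_matrix(xs, 2)  # 5×2, row i is one-hot at column xs[i]
messages = rand(Float32, 8, 5)             # one 8-dim message per edge
aggregated = (messages .* α) * N           # dims: (8, 2), attention-weighted sum per vertex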
