Fix GAT example #280

Merged · 3 commits · Mar 25, 2022
171 changes: 119 additions & 52 deletions examples/gat.jl
@@ -1,54 +1,121 @@
using GeometricFlux
using Flux
using Flux: onehotbatch, onecold, logitcrossentropy, throttle
using Flux: @epochs
using JLD2
using Statistics: mean
using SparseArrays
using LinearAlgebra
using Graphs.SimpleGraphs
using Graphs: adjacency_matrix
using CUDA
using Flux
using Flux: onecold
using Flux.Losses: logitcrossentropy
using Flux.Data: DataLoader
using GeometricFlux
using GeometricFlux.Datasets
using GraphSignals
using Graphs
using Parameters: @with_kw
using ProgressMeter: Progress, next!
using Statistics
using Random

function load_data(dataset, batch_size, train_repeats=32, test_repeats=2)
    # (train_X, train_y) dim: (num_features, target_dim) × 2708
    train_X, train_y = map(x -> Matrix(x), alldata(Planetoid(), dataset, padding=true))
    # (test_X, test_y) dim: (num_features, target_dim) × 2708
    test_X, test_y = map(x -> Matrix(x), testdata(Planetoid(), dataset, padding=true))
    g = graphdata(Planetoid(), dataset)
    train_idx = 1:size(train_X, 2)
    test_idx = test_indices(Planetoid(), dataset)

    add_all_self_loops!(g)
    fg = FeaturedGraph(g)
    train_data = (repeat(train_X, outer=(1,1,train_repeats)), repeat(train_y, outer=(1,1,train_repeats)))
    test_data = (repeat(test_X, outer=(1,1,test_repeats)), repeat(test_y, outer=(1,1,test_repeats)))
    train_loader = DataLoader(train_data, batchsize=batch_size, shuffle=true)
    test_loader = DataLoader(test_data, batchsize=batch_size, shuffle=true)
    return train_loader, test_loader, fg, train_idx, test_idx
end

function add_all_self_loops!(g)
    for i in vertices(g)
        add_edge!(g, i, i)
    end
    return g
end

@with_kw mutable struct Args
    η = 0.01            # learning rate
    batch_size = 8      # batch size
    epochs = 20         # number of epochs
    seed = 0            # random seed
    cuda = true         # use GPU
    heads = 8           # attention heads
    input_dim = 1433    # input dimension
    hidden_dim = 16     # hidden dimension
    target_dim = 7      # target dimension
end

## Loss: cross entropy
model_loss(model, X, y, idx) =
    logitcrossentropy(model(X)[:,idx,:], y[:,idx,:])

accuracy(model, X::AbstractArray, y::AbstractArray, idx) =
    mean(onecold(softmax(cpu(model(X))[:,idx,:])) .== onecold(cpu(y)[:,idx,:]))

accuracy(model, loader::DataLoader, device, idx) =
    mean(accuracy(model, X |> device, y |> device, idx) for (X, y) in loader)

function train(; kws...)
    # load hyperparameters
    args = Args(; kws...)
    args.seed > 0 && Random.seed!(args.seed)

    # GPU config
    if args.cuda && CUDA.has_cuda()
        device = gpu
        @info "Training on GPU"
    else
        device = cpu
        @info "Training on CPU"
    end

    # load Cora from Planetoid dataset
    train_loader, test_loader, fg, train_idx, test_idx = load_data(:cora, args.batch_size)

    # build model
    model = Chain(
        WithGraph(fg, GATConv(args.input_dim=>args.hidden_dim, heads=args.heads)),
        WithGraph(fg, GATConv(args.hidden_dim*args.heads=>args.target_dim, heads=args.heads, concat=false)),
    ) |> device

    # ADAM optimizer
    opt = ADAM(args.η)

    # parameters
    ps = Flux.params(model)

    # training
    train_steps = 0
    @info "Start Training, total $(args.epochs) epochs"
    for epoch = 1:args.epochs
        @info "Epoch $(epoch)"
        progress = Progress(length(train_loader))

        for (X, y) in train_loader
            loss, back = Flux.pullback(ps) do
                model_loss(model, X |> device, y |> device, train_idx |> device)
            end
            train_acc = accuracy(model, train_loader, device, train_idx)
            test_acc = accuracy(model, test_loader, device, test_idx)
            grad = back(1f0)
            Flux.Optimise.update!(opt, ps, grad)

            # progress meter
            next!(progress; showvalues=[
                (:loss, loss),
                (:train_accuracy, train_acc),
                (:test_accuracy, test_acc)
            ])

            train_steps += 1
        end
    end

    return model, args
end

@load "data/cora_features.jld2" features
@load "data/cora_labels.jld2" labels
@load "data/cora_graph.jld2" g

num_nodes = 2708
num_features = 1433
heads = 8
hidden = 8
target_catg = 7
epochs = 10

## Preprocessing data
train_X = Matrix{Float32}(features) |> gpu # dim: num_features * num_nodes
train_y = Matrix{Float32}(labels) |> gpu # dim: target_catg * num_nodes
A = Matrix{Int}((adjacency_matrix(g) + I) .≥ 1)
fg = FeaturedGraph(A, :adjm)

## Model
model = Chain(GATConv(fg, num_features=>hidden, heads=heads),
              Dropout(0.6),
              GATConv(fg, hidden*heads=>target_catg, heads=heads, concat=false)
              ) |> gpu
# test model
@show model(train_X)

## Loss
loss(x, y) = logitcrossentropy(model(x), y)
accuracy(x, y) = mean(onecold(cpu(model(x))) .== onecold(cpu(y)))

# test loss
@show loss(train_X, train_y)

# test gradient
@show gradient(()->loss(train_X, train_y), Flux.params(model))

## Training
ps = Flux.params(model)
train_data = Flux.Data.DataLoader((train_X, train_y), batchsize=num_nodes)
opt = ADAM(0.01)
evalcb() = @show(accuracy(train_X, train_y))

@epochs epochs Flux.train!(loss, ps, train_data, opt, cb=throttle(evalcb, 10))
model, args = train()
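
A quick way to see why load_data calls add_all_self_loops! before wrapping the graph: the GATConv layers assert that every vertex has a self-loop so it can attend to its own features. Below is a minimal sketch (not part of this PR; the layer sizes and random features are made up for illustration) of the same pattern on a toy graph:

using GeometricFlux, GraphSignals, Graphs

g = SimpleGraph(3)
add_edge!(g, 1, 2)
add_edge!(g, 2, 3)
for i in vertices(g)          # same idea as add_all_self_loops! above
    add_edge!(g, i, i)
end

fg = FeaturedGraph(g)                            # now satisfies the self-loop assertion in GATConv
layer = WithGraph(fg, GATConv(4=>2, heads=2))    # hypothetical small dimensions
X = rand(Float32, 4, 3)                          # num_features × num_nodes
Y = layer(X)                                     # expected size: (2 * heads) × num_nodes = 4 × 3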
2 changes: 1 addition & 1 deletion src/GeometricFlux.jl
@@ -3,7 +3,7 @@ module GeometricFlux
using DelimitedFiles
using SparseArrays
using Statistics: mean
using LinearAlgebra: Adjoint, norm, Transpose
using LinearAlgebra
using Random
using Reexport

50 changes: 16 additions & 34 deletions src/layers/conv.jl
@@ -294,46 +294,30 @@ end
Flux.trainable(l::GATConv) = (l.weight, l.bias, l.a)

# neighbor attention
function message(gat::GATConv, Xi::AbstractMatrix, Xj::AbstractMatrix, e_ij)
    Xi = reshape(Xi, size(Xi)..., 1)
    Xj = reshape(Xj, size(Xj)..., 1)
    m = message(gat, Xi, Xj, nothing)
    return reshape(m, :)
function update_batch_edge(gat::GATConv, el::NamedTuple, E, X::AbstractMatrix, u)
    X = reshape(X, size(X)..., 1)
    M = update_batch_edge(gat, el, E, X, u)
    return reshape(M, size(M)[1:2]...)
end

function message(gat::GATConv, Xi::AbstractArray, Xj::AbstractArray, e_ij)
function update_batch_edge(gat::GATConv, el::NamedTuple, E, X::AbstractArray, u)
    Xi, Xj = _gather(X, el.xs), _gather(X, el.nbrs)
    _, nb, bch_sz = size(Xj)
    heads = gat.heads
    Q = reshape(NNlib.batched_mul(gat.weight, Xi), :, heads, nb, bch_sz) # dims: (out, heads, nb, bch_sz)
    K = reshape(NNlib.batched_mul(gat.weight, Xj), :, heads, nb, bch_sz)
    V = reshape(NNlib.batched_mul(gat.weight, Xj), :, heads, nb, bch_sz)
    QK = vcat(Q, K) # dims: (2out, heads, nb, bch_sz)
    A = leakyrelu.(sum(QK .* gat.a, dims=1), gat.negative_slope) # dims: (1, heads, nb, bch_sz)
    α = Flux.softmax(A, dims=3) # dims: (1, heads, nb, bch_sz)
    return reshape(sum(V .* α, dims=3), :, 1, bch_sz) # dims: (out*heads, 1, bch_sz)
    α = indexed_softmax(A, el.xs, el.N, dims=3) # dims: (1, heads, nb, bch_sz)
    N = incidence_matrix(el.xs, el.N)
    Y = NNlib.batched_mul(reshape(V .* α, :, nb, bch_sz), N) # dims: (out*heads, N, bch_sz)
    return Y
end

# graph attention
function update_batch_edge(gat::GATConv, el::NamedTuple, E, X::AbstractArray, u)
    function _message(gat, el, i, X)
        xs = el.xs[el.xs .== i]
        nbrs = el.nbrs[el.xs .== i]
        Xi = _gather(X, xs)
        Xj = _gather(X, nbrs)
        return message(gat, Xi, Xj, nothing)
    end
    hs = [_message(gat, el, i, X) for i in 1:el.N]
    return hcat(hs...) # dims: (out*heads, N, [bch_sz])
end

function check_self_loops(sg::SparseGraph)
    for i in 1:nv(sg)
        if !(i in collect(GraphSignals.rowvalview(sg.S, i)))
            return false
        end
    end
    return true
end
aggregate_neighbors(gat::GATConv, el::NamedTuple, aggr, E::AbstractArray) = E # dims: (out*heads, N, [bch_sz])
aggregate_neighbors(gat::GATConv, el::NamedTuple, aggr, E::AbstractMatrix) = E

function update(gat::GATConv, M::AbstractArray, X)
    M = M .+ gat.bias
@@ -353,19 +337,17 @@ function (l::GATConv)(fg::AbstractFeaturedGraph)
    X = node_feature(fg)
    GraphSignals.check_num_nodes(fg, X)
    sg = graph(fg)
    @assert Zygote.ignore(() -> check_self_loops(sg)) "a vertex must have self loop (receive a message from itself)."
    @assert Zygote.ignore(() -> GraphSignals.has_all_self_loops(sg)) "a vertex must have self loop (receive a message from itself)."
    el = to_namedtuple(sg)
    Ē = update_batch_edge(l, el, nothing, X, nothing)
    V = update_batch_vertex(l, el, Ē, X, nothing)
    _, V, _ = propagate(l, el, nothing, X, nothing, hcat, nothing, nothing)
    return ConcreteFeaturedGraph(fg, nf=V)
end

# For static graph
function (l::GATConv)(el::NamedTuple, X::AbstractArray)
    GraphSignals.check_num_nodes(el.N, X)
    # TODO: should have self loops check for el
    Ē = update_batch_edge(l, el, nothing, X, nothing)
    V = update_batch_vertex(l, el, Ē, X, nothing)
    _, V, _ = propagate(l, el, nothing, X, nothing, hcat, nothing, nothing)
    return V
end

@@ -490,7 +472,7 @@ function (gat::GATv2Conv)(fg::AbstractFeaturedGraph)
    X = node_feature(fg)
    GraphSignals.check_num_nodes(fg, X)
    sg = graph(fg)
    @assert Zygote.ignore(() -> check_self_loops(sg)) "a vertex must have self loop (receive a message from itself)."
    @assert Zygote.ignore(() -> GraphSignals.has_all_self_loops(sg)) "a vertex must have self loop (receive a message from itself)."
    es, nbrs, xs = Zygote.ignore(() -> collect(edges(sg)))
    el = (N=nv(sg), E=ne(sg), es=es, nbrs=nbrs, xs=xs)
    Ē = update_batch_edge(gat, el, nothing, X, nothing)
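
To follow the reworked GATConv message path above, here is a small shape walkthrough (not part of the diff; the edge list and sizes are invented, and el.xs is read as the receiving vertex of each edge) of how the attention-weighted values are summed per vertex with an incidence matrix and batched_mul, mirroring the pattern in update_batch_edge:

using NNlib, LinearAlgebra

xs = [1, 1, 2, 2, 2, 3]                          # receiving vertex of each of nb = 6 edges
N, nb, bch_sz, out_heads = 3, 6, 1, 4

Vα = rand(Float32, out_heads, 1, nb, bch_sz)     # stands in for V .* α, one slice per edge

# incidence_matrix(xs, N) builds an nb × N 0/1 matrix whose column i selects the
# edges pointing at vertex i; batched_mul then sums those edge messages per vertex.
M = Float32.(Matrix(I(N))[xs, :])                # same 0/1 pattern as incidence_matrix(xs, N)
Y = NNlib.batched_mul(reshape(Vα, :, nb, bch_sz), M)   # dims: (out_heads, N, bch_sz)
@assert size(Y) == (out_heads, N, bch_sz)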
42 changes: 42 additions & 0 deletions src/operation.jl
@@ -22,4 +22,46 @@ aggregate(::typeof(max), X) = maximum(X, dims=2)
aggregate(::typeof(min), X) = minimum(X, dims=2)
aggregate(::typeof(mean), X) = mean(X, dims=2)

function incidence_matrix(xs::AbstractVector{T}, N) where {T}
    A = similar(xs, T, size(xs, 1), N)
    copyto!(A, Array(I(N))[Array(xs), :])
    return A
end

function indexed_softmax(x::AbstractArray, xs, N; dims=1)
    # a memory pre-allocation approach here makes the loss fluctuate instead of decreasing,
    # so keep an eye on the model loss when optimizing this code snippet
    as = map(1:N) do i
        idx = ntuple(j -> (j == dims) ? (xs .== i) : Colon(), ndims(x))
        NNlib.softmax(x[idx...]; dims)
    end
    return cat(as...; dims)
end

function ∇indexed_softmax(dy::AbstractArray{T}, y::AbstractArray{S}, xs, N; dims=1) where {T,S}
    dx = if NNlib.within_grad()
        tmp = dy .* y
        for i in 1:N
            idx = ntuple(j -> (j == dims) ? (xs .== i) : Colon(), ndims(y))
            tmp[idx...] .= tmp[idx...] .- y[idx...] .* sum(tmp[idx...]; dims)
        end
        tmp
    else
        out = similar(y, promote_type(T,S))
        out .= dy .* y
        for i in 1:N
            idx = ntuple(j -> (j == dims) ? (xs .== i) : Colon(), ndims(y))
            out[idx...] .= out[idx...] .- y[idx...] .* sum(out[idx...]; dims)
        end
        out
    end
end

function ChainRulesCore.rrule(::typeof(indexed_softmax), x, xs, N; dims=1)
    y = indexed_softmax(x, xs, N; dims)
    indexed_softmax_pullback(dy) = (NoTangent(), ∇indexed_softmax(unthunk(dy), y, xs, N; dims), NoTangent(), NoTangent())
    return y, indexed_softmax_pullback
end

@non_differentiable batched_index(x...)
@non_differentiable incidence_matrix(x...)
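
For reference, a tiny check (not from the PR; the numbers and grouping are invented) of what indexed_softmax computes: a softmax restricted to the entries of each group defined by xs, here along the second dimension. With a sorted xs, the result matches stitching the per-group softmaxes back together:

using GeometricFlux, NNlib

x  = [1.0 2.0 3.0 4.0]    # 1 × 4 attention logits for 4 edges
xs = [1, 1, 2, 2]         # edges 1–2 belong to vertex 1, edges 3–4 to vertex 2

y = GeometricFlux.indexed_softmax(x, xs, 2, dims=2)   # softmax within each group
y ≈ hcat(NNlib.softmax(x[:, 1:2], dims=2), NNlib.softmax(x[:, 3:4], dims=2))  # true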