Skip to content

Commit 7131803

Browse files
authored
Merge pull request #41 from kshyatt/ksh/ganex
Some updates to get the GCN example working...
2 parents cd7a2dc + f9b9f46 commit 7131803

File tree

5 files changed

+74
-23
lines changed

5 files changed

+74
-23
lines changed

examples/gat.jl

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
using GeometricFlux
2+
using Flux
3+
using Flux: onehotbatch, onecold, crossentropy, throttle
4+
using JLD2 # use v0.1.2
5+
using Statistics: mean
6+
using SparseArrays
7+
using LightGraphs.SimpleGraphs
8+
using CuArrays
9+
10+
@load "data/cora_features.jld2" features
11+
@load "data/cora_labels.jld2" labels
12+
@load "data/cora_graph.jld2" g
13+
14+
num_nodes = 2708
15+
num_features = 1433
16+
17+
heads = 8
18+
hidden = 8
19+
target_catg = 7
20+
epochs = 10
21+
22+
## Preprocessing data
23+
train_X = features |> gpu # dim: num_features * num_nodes
24+
train_y = labels |> gpu # dim: target_catg * num_nodes
25+
26+
## Model
27+
model = Chain(GATConv(g, num_features=>hidden, heads=heads),
28+
Dropout(0.6),
29+
GATConv(g, hidden=>target_catg, heads=heads),
30+
softmax) |> gpu
31+
# test model
32+
# model(train_X)
33+
34+
## Loss
35+
loss(x, y) = crossentropy(model(x), y)
36+
accuracy(x, y) = mean(onecold(model(x)) .== onecold(y))
37+
38+
## Training
39+
ps = Flux.params(model)
40+
train_data = [(train_X, train_y)]
41+
opt = ADAM(0.01)
42+
evalcb() = @show(accuracy(train_X, train_y))
43+
44+
for i = 1:epochs
45+
Flux.train!(loss, ps, train_data, opt, cb=throttle(evalcb, 10))
46+
end

examples/gcn.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using JLD2 # use v0.1.2
55
using Statistics: mean
66
using SparseArrays
77
using LightGraphs.SimpleGraphs
8+
using LightGraphs: adjacency_matrix
89
using CuArrays
910

1011
@load "data/cora_features.jld2" features
@@ -15,19 +16,19 @@ num_nodes = 2708
1516
num_features = 1433
1617
hidden = 16
1718
target_catg = 7
18-
epochs = 10
19+
epochs = 20
1920

2021
## Preprocessing data
21-
train_X = features |> gpu # dim: num_features * num_nodes
22-
train_y = labels |> gpu # dim: target_catg * num_nodes
22+
train_X = Float32.(features) |> gpu # dim: num_features * num_nodes
23+
train_y = Float32.(labels) |> gpu # dim: target_catg * num_nodes
24+
25+
adj_mat = Matrix{Float32}(adjacency_matrix(g)) |> gpu
2326

2427
## Model
25-
model = Chain(GCNConv(g, num_features=>hidden, relu),
28+
model = Chain(GCNConv(adj_mat, num_features=>hidden, relu),
2629
Dropout(0.5),
27-
GCNConv(g, hidden=>target_catg),
30+
GCNConv(adj_mat, hidden=>target_catg),
2831
softmax) |> gpu
29-
# test model
30-
# model(train_X)
3132

3233
## Loss
3334
loss(x, y) = crossentropy(model(x), y)

src/graph/simplegraphs.jl

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
using LightGraphs: AbstractSimpleGraph, nv, adjacency_matrix
2-
1+
using LightGraphs: AbstractSimpleGraph, nv, adjacency_matrix, inneighbors, outneighbors, all_neighbors
2+
import LightGraphs: adjacency_matrix
33

44
## Linear algebra API for AbstractSimpleGraph
55

@@ -19,12 +19,7 @@ function laplacian_matrix(sg::AbstractSimpleGraph, T::DataType=eltype(sg); dir::
1919
laplacian_matrix(adjacency_matrix(sg, T; dir=dir), T; dir=dir)
2020
end
2121

22-
function normalized_laplacian(sg::AbstractSimpleGraph, T::DataType=eltype(sg); selfloop::Bool=false)
23-
adj = adjacency_matrix(sg, T)
24-
selfloop && (adj += I)
25-
normalized_laplacian(adj, T)
26-
end
27-
22+
# Unwrap a `Ref`-held graph and forward to the plain-graph `adjacency_matrix` method.
# The default `T` is read from the wrapped graph (`sg[]`): `eltype` of a `RefValue`
# is the type of the referenced graph itself, which is not a valid matrix element type.
function adjacency_matrix(sg::Base.RefValue{<:AbstractSimpleGraph},
                          T::DataType=eltype(sg[]))
    return adjacency_matrix(sg[], T)
end
2823

2924
## Convolution layers accepting AbstractSimpleGraph
3025

src/layers/conv.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,22 +26,24 @@ end
2626

2727
function GCNConv(ch::Pair{<:Integer,<:Integer}, σ = identity;
2828
init=glorot_uniform, T::DataType=Float32, bias::Bool=true, cache::Bool=true)
29-
b = bias ? init(ch[2]) : zeros(T, ch[2])
29+
b = bias ? T.(init(ch[2])) : zeros(T, ch[2])
3030
graph = cache ? FeaturedGraph(nothing, nothing) : NullGraph()
31-
GCNConv(init(ch[2], ch[1]), b, σ, graph)
31+
GCNConv(T.(init(ch[2], ch[1])), b, σ, graph)
3232
end
3333

3434
function GCNConv(adj::AbstractMatrix, ch::Pair{<:Integer,<:Integer}, σ = identity;
3535
init=glorot_uniform, T::DataType=Float32, bias::Bool=true, cache::Bool=true)
36-
b = bias ? init(ch[2]) : zeros(T, ch[2])
36+
b = bias ? T.(init(ch[2])) : zeros(T, ch[2])
3737
graph = cache ? FeaturedGraph(adj, nothing) : NullGraph()
38-
GCNConv(init(ch[2], ch[1]), b, σ, graph)
38+
GCNConv(T.(init(ch[2], ch[1])), b, σ, graph)
3939
end
4040

4141
@functor GCNConv
4242

4343
function (g::GCNConv)(X::AbstractMatrix{T}) where {T}
44-
g.σ.(g.weight * X * normalized_laplacian(graph(g.graph), T; selfloop=true) .+ g.bias)
44+
W, b, σ = g.weight, g.bias, g.σ
45+
nl = normalized_laplacian(graph(g.graph), float(T); selfloop=true)
46+
σ.(W * X * nl .+ b)
4547
end
4648

4749
function (g::GCNConv)(fg::FeaturedGraph)

src/linalg.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
## Linear algebra API for adjacency matrix
2+
using LinearAlgebra
23

34
"""
45
degrees(g[, T; dir=:out])
@@ -82,8 +83,8 @@ The values other than diagonal are zeros.
8283
- `dir`: direction of degree; should be `:in`, `:out`, or `:both` (optional).
8384
"""
8485
function inv_sqrt_degree_matrix(adj::AbstractMatrix, T::DataType=eltype(adj); dir::Symbol=:out)
85-
d = degrees(adj, T, dir=dir).^(-0.5)
86-
return SparseMatrixCSC(T.(diagm(0=>d)))
86+
d = inv.(sqrt.(degrees(adj, T, dir=dir)))
87+
return Diagonal(d)
8788
end
8889

8990
"""
@@ -110,10 +111,16 @@ Normalized Laplacian matrix of graph `g`.
110111
- `T`: result element type of degree vector; default is the element type of `g` (optional).
111112
- `selfloop`: adding self loop while calculating the matrix (optional).
112113
"""
114+
function normalized_laplacian(sg, T::DataType=eltype(sg); selfloop::Bool=false)
115+
adj = adjacency_matrix(sg, T)
116+
selfloop && (adj += I)
117+
normalized_laplacian(adj, T)
118+
end
119+
113120
function normalized_laplacian(adj::AbstractMatrix, T::DataType=eltype(adj); selfloop::Bool=false)
114121
selfloop && (adj += I)
115122
inv_sqrtD = inv_sqrt_degree_matrix(adj, T, dir=:both)
116-
I - inv_sqrtD * SparseMatrixCSC(T.(adj)) * inv_sqrtD
123+
I - inv_sqrtD * adj * inv_sqrtD
117124
end
118125

119126
function neighbors(adj::AbstractMatrix, T::DataType=eltype(adj))

0 commit comments

Comments
 (0)