Skip to content

Commit f49858c

Browse files
authored
GCNConv layer supports FeaturedGraph (#34)
* Correct layer bias * Linear algebra API for AbstractSimpleGraph, AbstractSimpleWeightedGraph and AbstractMetaGraph * Add FeaturedGraph to decouple graph from layer * Fix degrees doc and add selfloop kwargs * GCNConv layer supports FeaturedGraph
1 parent 48edd0c commit f49858c

File tree

15 files changed

+435
-124
lines changed

15 files changed

+435
-124
lines changed

src/GeometricFlux.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,13 @@ export
8080
scatter_mean!,
8181
scatter!,
8282

83+
# graph/featuredgraphs
84+
AbstractFeaturedGraph,
85+
NullGraph,
86+
FeaturedGraph,
87+
graph,
88+
feature,
89+
8390
# graph/utils
8491
adjlist,
8592

@@ -107,6 +114,7 @@ const IntOrTuple = Union{Integer,Tuple}
107114

108115
include("scatter.jl")
109116
include("linalg.jl")
117+
include("graph/featuredgraphs.jl")
110118
include("utils.jl")
111119
include("layers/meta.jl")
112120
include("layers/msgpass.jl")

src/graph/featuredgraphs.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
abstract type AbstractFeaturedGraph end
2+
3+
struct NullGraph <: AbstractFeaturedGraph end
4+
5+
struct FeaturedGraph{T,S} <: AbstractFeaturedGraph
6+
graph::Ref{T}
7+
feature::Ref{S}
8+
FeaturedGraph(graph::T, feature::S) where {T,S} = new{T,S}(Ref(graph), Ref(feature))
9+
end
10+
11+
graph(::NullGraph) = nothing
12+
graph(fg::FeaturedGraph) = fg.graph[]
13+
14+
feature(::NullGraph) = nothing
15+
feature(fg::FeaturedGraph) = fg.feature[]
16+
17+
18+
## Linear algebra API for AbstractFeaturedGraph
19+
20+
function degrees(fg::FeaturedGraph, T::DataType=eltype(fg.graph[]); dir::Symbol=:out)
21+
degrees(fg.graph[], T; dir=dir)
22+
end
23+
24+
function degree_matrix(fg::FeaturedGraph, T::DataType=eltype(fg.graph[]); dir::Symbol=:out)
25+
degree_matrix(fg.graph[], T; dir=dir)
26+
end
27+
28+
function inv_sqrt_degree_matrix(fg::FeaturedGraph, T::DataType=eltype(fg.graph[]); dir::Symbol=:out)
29+
inv_sqrt_degree_matrix(fg.graph[], T; dir=dir)
30+
end
31+
32+
function laplacian_matrix(fg::FeaturedGraph, T::DataType=eltype(fg.graph[]); dir::Symbol=:out)
33+
laplacian_matrix(fg.graph[], T; dir=dir)
34+
end
35+
36+
function normalized_laplacian(fg::FeaturedGraph, T::DataType=eltype(fg.graph[]))
37+
normalized_laplacian(fg.graph[], T)
38+
end

src/graph/metagraphs.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,33 @@
11
using MetaGraphs: AbstractMetaGraph
22

3+
4+
## Linear algebra API for AbstractMetaGraph
5+
6+
function degrees(mg::AbstractMetaGraph, T::DataType=eltype(mg); dir::Symbol=:out)
7+
degrees(adjacency_matrix(mg.graph, T; dir=dir), T; dir=dir)
8+
end
9+
10+
function degree_matrix(mg::AbstractMetaGraph, T::DataType=eltype(mg); dir::Symbol=:out)
11+
degree_matrix(adjacency_matrix(mg.graph, T; dir=dir), T; dir=dir)
12+
end
13+
14+
function inv_sqrt_degree_matrix(mg::AbstractMetaGraph, T::DataType=eltype(mg); dir::Symbol=:out)
15+
inv_sqrt_degree_matrix(adjacency_matrix(mg.graph, T; dir=dir), T; dir=dir)
16+
end
17+
18+
function laplacian_matrix(mg::AbstractMetaGraph, T::DataType=eltype(mg); dir::Symbol=:out)
19+
laplacian_matrix(adjacency_matrix(mg.graph, T; dir=dir), T; dir=dir)
20+
end
21+
22+
function normalized_laplacian(mg::AbstractMetaGraph, T::DataType=eltype(mg); selfloop::Bool=false)
23+
adj = adjacency_matrix(mg.graph, T)
24+
selfloop && (adj += I)
25+
normalized_laplacian(adj, T)
26+
end
27+
28+
29+
## Convolution layers accepting AbstractMetaGraph
30+
331
GCNConv(g::AbstractMetaGraph, ch::Pair{<:Integer,<:Integer}, σ=identity; kwargs...) =
432
GCNConv(g.graph, ch, σ; kwargs...)
533

src/graph/simplegraphs.jl

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,45 @@
11
using LightGraphs: AbstractSimpleGraph, nv, adjacency_matrix
22

3+
4+
## Linear algebra API for AbstractSimpleGraph
5+
6+
function degrees(sg::AbstractSimpleGraph, T::DataType=eltype(sg); dir::Symbol=:out)
7+
degrees(adjacency_matrix(sg, T; dir=dir), T; dir=dir)
8+
end
9+
10+
function degree_matrix(sg::AbstractSimpleGraph, T::DataType=eltype(sg); dir::Symbol=:out)
11+
degree_matrix(adjacency_matrix(sg, T; dir=dir), T; dir=dir)
12+
end
13+
14+
function inv_sqrt_degree_matrix(sg::AbstractSimpleGraph, T::DataType=eltype(sg); dir::Symbol=:out)
15+
inv_sqrt_degree_matrix(adjacency_matrix(sg, T; dir=dir), T; dir=dir)
16+
end
17+
18+
function laplacian_matrix(sg::AbstractSimpleGraph, T::DataType=eltype(sg); dir::Symbol=:out)
19+
laplacian_matrix(adjacency_matrix(sg, T; dir=dir), T; dir=dir)
20+
end
21+
22+
function normalized_laplacian(sg::AbstractSimpleGraph, T::DataType=eltype(sg); selfloop::Bool=false)
23+
adj = adjacency_matrix(sg, T)
24+
selfloop && (adj += I)
25+
normalized_laplacian(adj, T)
26+
end
27+
28+
29+
## Convolution layers accepting AbstractSimpleGraph
30+
331
function GCNConv(g::AbstractSimpleGraph, ch::Pair{<:Integer,<:Integer}, σ = identity;
432
init = glorot_uniform, T::DataType=Float32, bias::Bool=true)
5-
N = nv(g)
6-
b = bias ? init(ch[2], N) : zeros(T, ch[2], N)
7-
adj = adjacency_matrix(g)
8-
GCNConv(init(ch[2], ch[1]), b, normalized_laplacian(adj+I, T), σ)
33+
b = bias ? init(ch[2]) : zeros(T, ch[2])
34+
fg = FeaturedGraph(Ref(g), Ref(nothing))
35+
GCNConv(init(ch[2], ch[1]), b, σ, fg)
936
end
1037

1138

1239
function ChebConv(g::AbstractSimpleGraph, ch::Pair{<:Integer,<:Integer}, k::Integer;
1340
init = glorot_uniform, T::DataType=Float32, bias::Bool=true)
1441
N = nv(g)
15-
b = bias ? init(ch[2], N) : zeros(T, ch[2], N)
42+
b = bias ? init(ch[2]) : zeros(T, ch[2])
1643
adj = adjacency_matrix(g)
1744
= T(2. / eigmax(Matrix(adj))) * normalized_laplacian(adj, T) - I
1845
ChebConv(init(ch[2], ch[1], k), b, L̃, k, ch[1], ch[2])
@@ -22,7 +49,7 @@ end
2249
function GraphConv(g::AbstractSimpleGraph, ch::Pair{<:Integer,<:Integer}, aggr=:add;
2350
init = glorot_uniform, bias::Bool=true)
2451
N = nv(g)
25-
b = bias ? init(ch[2], N) : zeros(T, ch[2], N)
52+
b = bias ? init(ch[2]) : zeros(T, ch[2])
2653
GraphConv(adjlist(g), init(ch[2], ch[1]), init(ch[2], ch[1]), b, aggr)
2754
end
2855

@@ -32,7 +59,7 @@ function GATConv(g::AbstractSimpleGraph, ch::Pair{<:Integer,<:Integer}; heads=1,
3259
bias::Bool=true)
3360
N = nv(g)
3461
w = init(ch[2]*heads, ch[1])
35-
b = bias ? init(ch[2]*heads, N) : zeros(T, ch[2]*heads, N)
62+
b = bias ? init(ch[2]*heads) : zeros(T, ch[2]*heads)
3663
a = init(2*ch[2], heads, 1)
3764
GATConv(adjlist(g), w, b, a, negative_slope, ch, heads, concat)
3865
end

src/graph/weightedgraphs.jl

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,45 @@
11
using SimpleWeightedGraphs: AbstractSimpleWeightedGraph, nv
22

3+
4+
## Linear algebra API for AbstractSimpleWeightedGraph
5+
6+
function degrees(wg::AbstractSimpleWeightedGraph, T::DataType=eltype(wg); dir::Symbol=:out)
7+
degrees(adjacency_matrix(wg, T; dir=dir), T; dir=dir)
8+
end
9+
10+
function degree_matrix(wg::AbstractSimpleWeightedGraph, T::DataType=eltype(wg); dir::Symbol=:out)
11+
degree_matrix(adjacency_matrix(wg, T; dir=dir), T; dir=dir)
12+
end
13+
14+
function inv_sqrt_degree_matrix(wg::AbstractSimpleWeightedGraph, T::DataType=eltype(wg); dir::Symbol=:out)
15+
inv_sqrt_degree_matrix(adjacency_matrix(wg, T; dir=dir), T; dir=dir)
16+
end
17+
18+
function laplacian_matrix(wg::AbstractSimpleWeightedGraph, T::DataType=eltype(wg); dir::Symbol=:out)
19+
laplacian_matrix(adjacency_matrix(wg, T; dir=dir), T; dir=dir)
20+
end
21+
22+
function normalized_laplacian(wg::AbstractSimpleWeightedGraph, T::DataType=eltype(wg); selfloop::Bool=false)
23+
adj = adjacency_matrix(wg, T)
24+
selfloop && (adj += I)
25+
normalized_laplacian(adj, T)
26+
end
27+
28+
29+
## Convolution layers accepting AbstractSimpleWeightedGraph
30+
331
function GCNConv(g::AbstractSimpleWeightedGraph, ch::Pair{<:Integer,<:Integer}, σ = identity;
432
init = glorot_uniform, T::DataType=Float32, bias::Bool=true)
5-
N = nv(g)
6-
b = bias ? init(ch[2], N) : zeros(T, ch[2], N)
7-
adj = adjacency_matrix(g)
8-
GCNConv(init(ch[2], ch[1]), b, normalized_laplacian(adj+I, T), σ)
33+
b = bias ? init(ch[2]) : zeros(T, ch[2])
34+
fg = FeaturedGraph(Ref(g), Ref(nothing))
35+
GCNConv(init(ch[2], ch[1]), b, σ, fg)
936
end
1037

1138

1239
function ChebConv(g::AbstractSimpleWeightedGraph, ch::Pair{<:Integer,<:Integer}, k::Integer;
1340
init = glorot_uniform, T::DataType=Float32, bias::Bool=true)
1441
N = nv(g)
15-
b = bias ? init(ch[2], N) : zeros(T, ch[2], N)
42+
b = bias ? init(ch[2]) : zeros(T, ch[2])
1643
adj = adjacency_matrix(g)
1744
= T(2. / eigmax(Matrix(adj))) * normalized_laplacian(adj, T) - I
1845
ChebConv(init(ch[2], ch[1], k), b, L̃, k, ch[1], ch[2])
@@ -22,7 +49,7 @@ end
2249
function GraphConv(g::AbstractSimpleWeightedGraph, ch::Pair{<:Integer,<:Integer}, aggr=:add;
2350
init = glorot_uniform, bias::Bool=true)
2451
N = nv(g)
25-
b = bias ? init(ch[2], N) : zeros(T, ch[2], N)
52+
b = bias ? init(ch[2]) : zeros(T, ch[2])
2653
GraphConv(adjlist(g), init(ch[2], ch[1]), init(ch[2], ch[1]), b, aggr)
2754
end
2855

@@ -32,7 +59,7 @@ function GATConv(g::AbstractSimpleWeightedGraph, ch::Pair{<:Integer,<:Integer};
3259
bias::Bool=true)
3360
N = nv(g)
3461
w = init(ch[2]*heads, ch[1])
35-
b = bias ? init(ch[2]*heads, N) : zeros(T, ch[2]*heads, N)
62+
b = bias ? init(ch[2]*heads) : zeros(T, ch[2]*heads)
3663
a = init(2*ch[2], heads, 1)
3764
GATConv(adjlist(g), w, b, a, negative_slope, ch, heads, concat)
3865
end

src/layers/conv.jl

Lines changed: 34 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@ const AGGR2STR = Dict{Symbol,String}(:add => "∑", :sub => "-∑", :mul => "∏
22
:max => "max", :min => "min", :mean => "𝔼[]")
33

44
"""
5-
GCNConv(graph, in=>out)
6-
GCNConv(graph, in=>out, σ)
5+
GCNConv([graph, ]in=>out)
6+
GCNConv([graph, ]in=>out, σ)
77
88
Graph convolutional layer.
99
@@ -17,25 +17,38 @@ Data should be stored in (# features, # nodes) order.
1717
For example, a 1000-node graph each node of which poses 100 feautres is constructed.
1818
The input data would be a `1000×100` array.
1919
"""
20-
struct GCNConv{T,F}
20+
struct GCNConv{T,F,S<:AbstractFeaturedGraph}
2121
weight::AbstractMatrix{T}
22-
bias::AbstractMatrix{T}
23-
norm::AbstractMatrix{T}
22+
bias::AbstractVector{T}
2423
σ::F
24+
graph::S
25+
end
26+
27+
function GCNConv(ch::Pair{<:Integer,<:Integer}, σ = identity;
28+
init=glorot_uniform, T::DataType=Float32, bias::Bool=true, cache::Bool=true)
29+
b = bias ? init(ch[2]) : zeros(T, ch[2])
30+
graph = cache ? FeaturedGraph(nothing, nothing) : NullGraph()
31+
GCNConv(init(ch[2], ch[1]), b, σ, graph)
2532
end
2633

2734
function GCNConv(adj::AbstractMatrix, ch::Pair{<:Integer,<:Integer}, σ = identity;
28-
init = glorot_uniform, T::DataType=Float32, bias::Bool=true)
29-
N = size(adj, 1)
30-
b = bias ? init(ch[2], N) : zeros(T, ch[2], N)
31-
GCNConv(init(ch[2], ch[1]), b, normalized_laplacian(adj+I, T), σ)
35+
init=glorot_uniform, T::DataType=Float32, bias::Bool=true, cache::Bool=true)
36+
b = bias ? init(ch[2]) : zeros(T, ch[2])
37+
graph = cache ? FeaturedGraph(adj, nothing) : NullGraph()
38+
GCNConv(init(ch[2], ch[1]), b, σ, graph)
3239
end
3340

3441
@functor GCNConv
3542

36-
(g::GCNConv)(X::AbstractMatrix) = g.σ.(g.weight * X * g.norm + g.bias)
37-
(g::GCNConv)(X::AbstractMatrix{T}, A::AbstractMatrix) where T =
38-
g.σ.(g.weight * X * normalized_laplacian(A+I, T) + g.bias)
43+
function (g::GCNConv)(X::AbstractMatrix{T}) where {T}
44+
g.σ.(g.weight * X * normalized_laplacian(graph(g.graph), T; selfloop=true) .+ g.bias)
45+
end
46+
47+
function (g::GCNConv)(gr::FeaturedGraph)
48+
X = feature(gr)
49+
A = graph(gr)
50+
g.σ.(g.weight * X * normalized_laplacian(A, eltype(X); selfloop=true) .+ g.bias)
51+
end
3952

4053
function Base.show(io::IO, l::GCNConv)
4154
in_channel = size(l.weight, ndims(l.weight))
@@ -62,7 +75,7 @@ Chebyshev spectral graph convolutional layer.
6275
"""
6376
struct ChebConv{T}
6477
weight::AbstractArray{T,3}
65-
bias::AbstractMatrix{T}
78+
bias::AbstractVector{T}
6679
::AbstractMatrix{T}
6780
k::Integer
6881
in_channel::Integer
@@ -72,7 +85,7 @@ end
7285
function ChebConv(adj::AbstractMatrix, ch::Pair{<:Integer,<:Integer}, k::Integer;
7386
init = glorot_uniform, T::DataType=Float32, bias::Bool=true)
7487
N = size(adj, 1)
75-
b = bias ? init(ch[2], N) : zeros(T, ch[2], N)
88+
b = bias ? init(ch[2]) : zeros(T, ch[2])
7689
= T(2. / eigmax(adj)) * normalized_laplacian(adj, T) - I
7790
ChebConv(init(ch[2], ch[1], k), b, L̃, k, ch[1], ch[2])
7891
end
@@ -97,7 +110,7 @@ function (c::ChebConv)(X::AbstractMatrix{T}) where {T<:Real}
97110
for k = 2:c.k
98111
Y += view(c.weight, :, :, k) * view(Z, :, :, k)
99112
end
100-
Y += c.bias
113+
Y .+= c.bias
101114
return Y
102115
end
103116

@@ -127,29 +140,29 @@ struct GraphConv{V,T} <: MessagePassing
127140
adjlist::V
128141
weight1::AbstractMatrix{T}
129142
weight2::AbstractMatrix{T}
130-
bias::AbstractMatrix{T}
143+
bias::AbstractVector{T}
131144
aggr::Symbol
132145
end
133146

134147
function GraphConv(el::AbstractVector{<:AbstractVector{<:Integer}},
135148
ch::Pair{<:Integer,<:Integer}, aggr=:add;
136149
init = glorot_uniform, bias::Bool=true)
137150
N = size(el, 1)
138-
b = bias ? init(ch[2], N) : zeros(T, ch[2], N)
151+
b = bias ? init(ch[2]) : zeros(T, ch[2])
139152
GraphConv(el, init(ch[2], ch[1]), init(ch[2], ch[1]), b, aggr)
140153
end
141154

142155
function GraphConv(adj::AbstractMatrix, ch::Pair{<:Integer,<:Integer}, aggr=:add;
143156
init = glorot_uniform, bias::Bool=true, T::DataType=Float32)
144157
N = size(adj, 1)
145-
b = bias ? init(ch[2], N) : zeros(T, ch[2], N)
158+
b = bias ? init(ch[2]) : zeros(T, ch[2])
146159
GraphConv(neighbors(adj), init(ch[2], ch[1]), init(ch[2], ch[1]), b, aggr)
147160
end
148161

149162
@functor GraphConv
150163

151164
message(g::GraphConv; x_i=zeros(0), x_j=zeros(0)) = g.weight2 * x_j
152-
update(g::GraphConv; X=zeros(0), M=zeros(0)) = g.weight1*X + M + g.bias
165+
update(g::GraphConv; X=zeros(0), M=zeros(0)) = g.weight1*X + M .+ g.bias
153166
(g::GraphConv)(X::AbstractMatrix) = propagate(g, X=X, aggr=:add)
154167

155168
function Base.show(io::IO, l::GraphConv)
@@ -178,7 +191,7 @@ Graph attentional layer.
178191
struct GATConv{V,T} <: MessagePassing
179192
adjlist::V
180193
weight::AbstractMatrix{T}
181-
bias::AbstractMatrix{T}
194+
bias::AbstractVector{T}
182195
a::AbstractArray{T,3}
183196
negative_slope::Real
184197
channel::Pair{<:Integer,<:Integer}
@@ -191,7 +204,7 @@ function GATConv(adj::AbstractMatrix, ch::Pair{<:Integer,<:Integer}; heads::Inte
191204
bias::Bool=true, T::DataType=Float32)
192205
N = size(adj, 1)
193206
w = init(ch[2]*heads, ch[1])
194-
b = bias ? init(ch[2]*heads, N) : zeros(T, ch[2]*heads, N)
207+
b = bias ? init(ch[2]*heads) : zeros(T, ch[2]*heads)
195208
a = init(2*ch[2], heads, 1)
196209
GATConv(neighbors(adj), w, b, a, negative_slope, ch, heads, concat)
197210
end

0 commit comments

Comments
 (0)