Commit 8acbe28

Merge pull request #272 from FluxML/withgraph

Integrate WithGraph and AbstractGraphLayer

2 parents d617a68 + d29b714

9 files changed: +294 −198 lines changed

src/GeometricFlux.jl

Lines changed: 0 additions & 1 deletion

@@ -67,7 +67,6 @@ include("layers/graphlayers.jl")
 include("layers/gn.jl")
 include("layers/msgpass.jl")
 
-include("layers/utils.jl")
 include("layers/conv.jl")
 include("layers/pool.jl")
 include("models.jl")

src/layers/conv.jl

Lines changed: 68 additions & 30 deletions
@@ -23,7 +23,7 @@ GCNConv(1024 => 256, relu)
 
 See also [`WithGraph`](@ref) for training a layer with a static graph.
 """
-struct GCNConv{A<:AbstractMatrix,B,F}
+struct GCNConv{A<:AbstractMatrix,B,F} <: AbstractGraphLayer
     weight::A
     bias::B
     σ::F
@@ -57,12 +57,10 @@ end
 
 # For static graph
 WithGraph(fg::AbstractFeaturedGraph, l::GCNConv) =
-    WithGraph(l, GraphSignals.normalized_adjacency_matrix!(fg, eltype(l.weight); selfloop=true))
+    WithGraph(GraphSignals.normalized_adjacency_matrix(fg, eltype(l.weight); selfloop=true), l)
 
 function (wg::WithGraph{<:GCNConv})(X::AbstractArray)
-    Ã = Zygote.ignore() do
-        GraphSignals.normalized_adjacency_matrix(wg.fg)
-    end
+    Ã = wg.graph
     return wg.layer(Ã, X)
 end
 
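The payoff of this hunk is that the normalized adjacency matrix is now computed once, at wrap time, and the forward pass simply reads the cached `wg.graph` instead of rebuilding it inside `Zygote.ignore` on every call. As a hedged illustration, here is roughly the computation a wrapped `GCNConv` performs, written in plain Julia with made-up sizes and assuming the usual GCN rule `σ.(W * X * Ã .+ b)`:

```julia
using LinearAlgebra

# Illustrative 4-node graph; GeometricFlux builds Ã with
# GraphSignals.normalized_adjacency_matrix(fg, T; selfloop=true).
A = [0 1 0 1;
     1 0 1 0;
     0 1 0 1;
     1 0 1 0]

# Ã = D^(-1/2) (A + I) D^(-1/2), computed once and cached, as WithGraph does.
Â = A + I
d = vec(sum(Â, dims=2))
Ã = Float32.(Diagonal(d .^ -0.5) * Â * Diagonal(d .^ -0.5))

W = randn(Float32, 8, 3)      # out = 8, in = 3 (made-up sizes)
b = zeros(Float32, 8)
X = randn(Float32, 3, 4)      # in_features × num_nodes
Y = tanh.(W * X * Ã .+ b)     # 8 × 4; Ã never re-enters the gradient tape
```

Because `Ã` is a plain cached array, repeated forward passes over the same graph skip all graph preprocessing.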
@@ -75,66 +73,96 @@ end
 
 
 """
-    ChebConv([fg,] in=>out, k; bias=true, init=glorot_uniform)
+    ChebConv(in=>out, k; bias=true, init=glorot_uniform)
 
 Chebyshev spectral graph convolutional layer.
 
 # Arguments
 
-- `fg`: Optionally pass a [`FeaturedGraph`](@ref).
 - `in`: The dimension of input features.
 - `out`: The dimension of output features.
 - `k`: The order of Chebyshev polynomial.
 - `bias`: Add learnable bias.
 - `init`: Weights' initializer.
+
+# Example
+
+```jldoctest
+julia> cc = ChebConv(1024=>256, 5, relu)
+ChebConv(1024 => 256, k=5, relu)
+```
+
+See also [`WithGraph`](@ref) for training a layer with a static graph.
 """
-struct ChebConv{A<:AbstractArray{<:Number,3}, B, S<:AbstractFeaturedGraph} <: AbstractGraphLayer
+struct ChebConv{A<:AbstractArray{<:Number,3},B,F} <: AbstractGraphLayer
     weight::A
     bias::B
-    fg::S
     k::Int
+    σ::F
 end
 
-function ChebConv(fg::AbstractFeaturedGraph, ch::Pair{Int,Int}, k::Int;
+function ChebConv(ch::Pair{Int,Int}, k::Int, σ=identity;
                   init=glorot_uniform, bias::Bool=true)
     in, out = ch
     W = init(out, in, k)
     b = Flux.create_bias(W, bias, out)
-    ChebConv(W, b, fg, k)
+    ChebConv(W, b, k, σ)
 end
 
-ChebConv(ch::Pair{Int,Int}, k::Int; kwargs...) =
-    ChebConv(NullGraph(), ch, k; kwargs...)
-
 @functor ChebConv
 
 Flux.trainable(l::ChebConv) = (l.weight, l.bias)
 
-function (c::ChebConv)(fg::AbstractFeaturedGraph, X::AbstractMatrix{T}) where T
-    GraphSignals.check_num_nodes(fg, X)
-    @assert size(X, 1) == size(c.weight, 2) "Input feature size must match input channel size."
-
-    L̃ = Zygote.ignore() do
-        GraphSignals.scaled_laplacian(fg, eltype(X))
-    end
-
+function (l::ChebConv)(L̃::AbstractMatrix, X::AbstractMatrix)
     Z_prev = X
     Z = X * L̃
-    Y = view(c.weight,:,:,1) * Z_prev
-    Y += view(c.weight,:,:,2) * Z
-    for k = 3:c.k
+    Y = view(l.weight,:,:,1) * Z_prev
+    Y += view(l.weight,:,:,2) * Z
+    for k = 3:l.k
         Z, Z_prev = 2 .* Z * L̃ .- Z_prev, Z
-        Y += view(c.weight,:,:,k) * Z
+        Y += view(l.weight,:,:,k) * Z
+    end
+    return l.σ.(Y .+ l.bias)
+end
+
+function (l::ChebConv)(L̃::AbstractMatrix, X::AbstractArray)
+    Z_prev = X
+    Z = NNlib.batched_mul(X, L̃)
+    Y = NNlib.batched_mul(view(l.weight,:,:,1), Z_prev)
+    Y += NNlib.batched_mul(view(l.weight,:,:,2), Z)
+    for k = 3:l.k
+        Z, Z_prev = 2 .* NNlib.batched_mul(Z, L̃) .- Z_prev, Z
+        Y += NNlib.batched_mul(view(l.weight,:,:,k), Z)
+    end
+    return l.σ.(Y .+ l.bias)
+end
+
+# For variable graph
+function (l::ChebConv)(fg::AbstractFeaturedGraph)
+    nf = node_feature(fg)
+    GraphSignals.check_num_nodes(fg, nf)
+    @assert size(nf, 1) == size(l.weight, 2) "Input feature size must match input channel size."
+
+    L̃ = Zygote.ignore() do
+        GraphSignals.scaled_laplacian(fg, eltype(nf))
     end
-    return Y .+ c.bias
+    return ConcreteFeaturedGraph(fg, nf = l(L̃, nf))
 end
 
-(l::ChebConv)(fg::AbstractFeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
+# For static graph
+WithGraph(fg::AbstractFeaturedGraph, l::ChebConv) =
+    WithGraph(GraphSignals.scaled_laplacian(fg, eltype(l.weight)), l)
+
+function (wg::WithGraph{<:ChebConv})(X::AbstractArray)
+    L̃ = wg.graph
+    return wg.layer(L̃, X)
+end
 
 function Base.show(io::IO, l::ChebConv)
     out, in, k = size(l.weight)
     print(io, "ChebConv(", in, " => ", out)
     print(io, ", k=", k)
+    l.σ == identity || print(io, ", ", l.σ)
     print(io, ")")
 end
 
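The refactored layer separates the spectral filter from graph handling: the `(l::ChebConv)(L̃, X)` methods just run the Chebyshev recurrence Z₁ = X, Z₂ = X·L̃, Zₖ = 2·Zₖ₋₁·L̃ − Zₖ₋₂, accumulating Y = Σₖ Wₖ·Zₖ, while the variable-graph and static-graph paths only differ in where L̃ comes from. A standalone sketch of that recurrence (the sizes and the random stand-in for the scaled Laplacian are illustrative):

```julia
# Chebyshev filtering Y = Σ_k W_k * Z_k, mirroring the matrix method above.
function cheb_filter(W::AbstractArray{<:Real,3}, L̃::AbstractMatrix, X::AbstractMatrix)
    Z_prev, Z = X, X * L̃
    Y = view(W,:,:,1) * Z_prev + view(W,:,:,2) * Z
    for k in 3:size(W, 3)
        Z, Z_prev = 2 .* Z * L̃ .- Z_prev, Z   # Chebyshev recurrence
        Y += view(W,:,:,k) * Z
    end
    return Y
end

W = randn(Float32, 8, 3, 5)   # out = 8, in = 3, k = 5 (made-up sizes)
M = randn(Float32, 4, 4)
L̃ = (M + M') / 2              # symmetric stand-in for scaled_laplacian(fg)
X = randn(Float32, 3, 4)      # in_features × num_nodes
Y = cheb_filter(W, L̃, X)      # 8 × 4
```

The new batched `AbstractArray` method is the same loop written with `NNlib.batched_mul`, so a batch of feature arrays can share one L̃.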
@@ -192,6 +220,8 @@ end
 
 (l::GraphConv)(fg::AbstractFeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
 # (l::GraphConv)(fg::AbstractFeaturedGraph) = propagate(l, fg, +)  # edge number check breaks this
+(l::GraphConv)(x::AbstractMatrix) = l(l.fg, x)
+(l::GraphConv)(::NullGraph, x::AbstractMatrix) = throw(ArgumentError("concrete FeaturedGraph is not provided."))
 
 function Base.show(io::IO, l::GraphConv)
     in_channel = size(l.weight1, ndims(l.weight1))
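Since the generic `AbstractGraphLayer` fallbacks were removed (see `src/layers/graphlayers.jl` below), each layer that still stores its own graph now defines the matrix-input methods locally: a bare feature matrix dispatches to the stored graph, and a `NullGraph` fails fast with a readable `ArgumentError`. A toy sketch of that dispatch pattern, using invented stand-in types rather than the package's own:

```julia
# Stand-ins for FeaturedGraph/NullGraph, only to show the dispatch shape.
abstract type ToyGraph end
struct ToyNullGraph <: ToyGraph end
struct ToyFeaturedGraph <: ToyGraph end

struct ToyConv
    fg::ToyGraph
end

(l::ToyConv)(::ToyFeaturedGraph, x::AbstractMatrix) = x   # real message passing goes here
(l::ToyConv)(x::AbstractMatrix) = l(l.fg, x)              # fall back to the stored graph
(l::ToyConv)(::ToyNullGraph, x::AbstractMatrix) =
    throw(ArgumentError("concrete FeaturedGraph is not provided."))

ToyConv(ToyFeaturedGraph())(rand(2, 3))   # works
# ToyConv(ToyNullGraph())(rand(2, 3))     # would throw the ArgumentError early
```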
@@ -307,6 +337,8 @@ end
 
 (l::GATConv)(fg::AbstractFeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
 # (l::GATConv)(fg::AbstractFeaturedGraph) = propagate(l, fg, +)  # edge number check breaks this
+(l::GATConv)(x::AbstractMatrix) = l(l.fg, x)
+(l::GATConv)(::NullGraph, x::AbstractMatrix) = throw(ArgumentError("concrete FeaturedGraph is not provided."))
 
 function Base.show(io::IO, l::GATConv)
     in_channel = size(l.weight, ndims(l.weight))
@@ -358,13 +390,13 @@ message(ggc::GatedGraphConv, x_i, x_j::AbstractVector, e_ij) = x_j
 update(ggc::GatedGraphConv, m::AbstractVector, x) = m
 
 
-function (ggc::GatedGraphConv)(fg::AbstractFeaturedGraph, H::AbstractMatrix{S}) where {T<:AbstractVector,S<:Real}
+function (ggc::GatedGraphConv)(fg::AbstractFeaturedGraph, H::AbstractMatrix{T}) where {T<:Real}
     GraphSignals.check_num_nodes(fg, H)
     m, n = size(H)
     @assert (m <= ggc.out_ch) "number of input features must be less than or equal to the number of output features."
     if m < ggc.out_ch
         Hpad = Zygote.ignore() do
-            fill!(similar(H, S, ggc.out_ch - m, n), 0)
+            fill!(similar(H, T, ggc.out_ch - m, n), 0)
         end
         H = vcat(H, Hpad)
     end
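`GatedGraphConv` runs a GRU over node states of width `out_ch`, so narrower inputs are zero-padded first; building the pad inside `Zygote.ignore` keeps the constant fill out of the gradient tape. A small sketch of just the padding step (sizes are made up):

```julia
out_ch = 8
H = rand(Float32, 5, 4)      # m = 5 input features for n = 4 nodes
m, n = size(H)
if m < out_ch
    Hpad = fill!(similar(H, Float32, out_ch - m, n), 0)  # zeros matching H's storage
    H = vcat(H, Hpad)        # now out_ch × n, ready for the GRU update
end
@assert size(H) == (8, 4)
```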
@@ -378,6 +410,8 @@ end
 
 (l::GatedGraphConv)(fg::AbstractFeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
 # (l::GatedGraphConv)(fg::AbstractFeaturedGraph) = propagate(l, fg, +)  # edge number check breaks this
+(l::GatedGraphConv)(x::AbstractMatrix) = l(l.fg, x)
+(l::GatedGraphConv)(::NullGraph, x::AbstractMatrix) = throw(ArgumentError("concrete FeaturedGraph is not provided."))
 
 
 function Base.show(io::IO, l::GatedGraphConv)
@@ -423,6 +457,8 @@ end
 
 (l::EdgeConv)(fg::AbstractFeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
 # (l::EdgeConv)(fg::AbstractFeaturedGraph) = propagate(l, fg, l.aggr)  # edge number check breaks this
+(l::EdgeConv)(x::AbstractMatrix) = l(l.fg, x)
+(l::EdgeConv)(::NullGraph, x::AbstractMatrix) = throw(ArgumentError("concrete FeaturedGraph is not provided."))
 
 function Base.show(io::IO, l::EdgeConv)
     print(io, "EdgeConv(", l.nn)
@@ -475,6 +511,8 @@ end
 
 (l::GINConv)(fg::AbstractFeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
 # (l::GINConv)(fg::AbstractFeaturedGraph) = propagate(l, fg, +)  # edge number check breaks this
+(l::GINConv)(x::AbstractMatrix) = l(l.fg, x)
+(l::GINConv)(::NullGraph, x::AbstractMatrix) = throw(ArgumentError("concrete FeaturedGraph is not provided."))
 
 
 """

src/layers/graphlayers.jl

Lines changed: 122 additions & 2 deletions
@@ -5,5 +5,125 @@ An abstract type of graph neural network layer for GeometricFlux.
 """
 abstract type AbstractGraphLayer end
 
-(l::AbstractGraphLayer)(x::AbstractMatrix) = l(l.fg, x)
-(l::AbstractGraphLayer)(::NullGraph, x::AbstractMatrix) = throw(ArgumentError("concrete FeaturedGraph is not provided."))
+"""
+    WithGraph(fg, layer)
+
+Train a GNN layer with a static graph.
+
+# Arguments
+
+- `fg`: A fixed `FeaturedGraph` to train with.
+- `layer`: A GNN layer.
+
+# Example
+
+```jldoctest
+julia> using GraphSignals, GeometricFlux
+
+julia> adj = [0 1 0 1;
+              1 0 1 0;
+              0 1 0 1;
+              1 0 1 0];
+
+julia> fg = FeaturedGraph(adj);
+
+julia> gc = WithGraph(fg, GCNConv(1024=>256))
+WithGraph(Graph(#V=4, #E=4), GCNConv(1024 => 256))
+
+julia> WithGraph(fg, Dense(10, 5))
+Dense(10, 5)  # 55 parameters
+
+julia> model = Chain(
+           GCNConv(32=>32),
+           gc,
+       );
+
+julia> WithGraph(fg, model)
+Chain(
+  WithGraph(
+    GCNConv(32 => 32),                  # 1_056 parameters
+  ),
+  WithGraph(
+    GCNConv(1024 => 256),               # 262_400 parameters
+  ),
+)  # Total: 4 trainable arrays, 263_456 parameters,
+   # plus 2 non-trainable, 32 parameters, summarysize 1.006 MiB.
+```
+"""
+struct WithGraph{L<:AbstractGraphLayer,G}
+    graph::G
+    layer::L
+end
+
+@functor WithGraph
+
+Flux.trainable(l::WithGraph) = (l.layer,)
+
+function Flux.destructure(m::WithGraph)
+    p, re = Flux.destructure(m.layer)
+    function re_withgraph(x)
+        WithGraph(m.graph, re(x))
+    end
+
+    return p, re_withgraph
+end
+
+function Base.show(io::IO, l::WithGraph)
+    print(io, "WithGraph(Graph(#V=", nv(l.graph))
+    print(io, ", #E=", ne(l.graph), "), ")
+    print(io, l.layer, ")")
+end
+
+WithGraph(fg::AbstractFeaturedGraph, model::Chain) = Chain(map(l -> WithGraph(fg, l), model.layers)...)
+WithGraph(::AbstractFeaturedGraph, layer::WithGraph) = layer
+WithGraph(::AbstractFeaturedGraph, layer) = layer
+
+"""
+    GraphParallel(; node_layer=identity, edge_layer=identity, global_layer=identity)
+
+Pass the features of a `FeaturedGraph` in parallel. It takes a `FeaturedGraph` as input,
+and separate layers can be assigned to specific (node, edge, and global) features.
+
+# Arguments
+
+- `node_layer`: A regular Flux layer for passing node features.
+- `edge_layer`: A regular Flux layer for passing edge features.
+- `global_layer`: A regular Flux layer for passing global features.
+
+# Example
+
+```jldoctest
+julia> using Flux, GeometricFlux
+
+julia> l = GraphParallel(
+           node_layer=Dropout(0.5),
+           global_layer=Dense(10, 5)
+       )
+GraphParallel(node_layer=Dropout(0.5), edge_layer=identity, global_layer=Dense(10, 5))
+```
+"""
+struct GraphParallel{N,E,G}
+    node_layer::N
+    edge_layer::E
+    global_layer::G
+end
+
+@functor GraphParallel
+
+GraphParallel(; node_layer=identity, edge_layer=identity, global_layer=identity) =
+    GraphParallel(node_layer, edge_layer, global_layer)
+
+function (l::GraphParallel)(fg::AbstractFeaturedGraph)
+    nf = l.node_layer(node_feature(fg))
+    ef = l.edge_layer(edge_feature(fg))
+    gf = l.global_layer(global_feature(fg))
+    return ConcreteFeaturedGraph(fg, nf=nf, ef=ef, gf=gf)
+end
+
+function Base.show(io::IO, l::GraphParallel)
+    print(io, "GraphParallel(")
+    print(io, "node_layer=", l.node_layer)
+    print(io, ", edge_layer=", l.edge_layer)
+    print(io, ", global_layer=", l.global_layer)
+    print(io, ")")
+end
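Taken together, the two new pieces compose: `WithGraph` maps over a `Chain` so a whole model can be bound to one static graph, and `GraphParallel` routes node, edge, and global features through separate layers. A hedged end-to-end sketch (the graph, the sizes, and the `nf` keyword for attaching node features are illustrative assumptions):

```julia
using Flux, GraphSignals, GeometricFlux

adj = [0 1 0 1;
       1 0 1 0;
       0 1 0 1;
       1 0 1 0]
fg = FeaturedGraph(adj)

# WithGraph maps over the Chain, wrapping each graph layer with the same graph.
model = WithGraph(fg, Chain(GCNConv(3=>8, relu), GCNConv(8=>8)))
X = rand(Float32, 3, 4)      # in_features × num_nodes
Y = model(X)                 # 8 × 4, no graph preprocessing in the forward pass

# GraphParallel transforms each feature kind independently on a FeaturedGraph.
gp = GraphParallel(node_layer=Dense(3, 8))
out = gp(FeaturedGraph(adj, nf=X))   # node features become 8 × 4; edge/global pass through
```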

0 commit comments

Comments
 (0)