Integrate WithGraph and AbstractGraphLayer #272

Merged 1 commit on Mar 8, 2022
1 change: 0 additions & 1 deletion src/GeometricFlux.jl
@@ -67,7 +67,6 @@ include("layers/graphlayers.jl")
include("layers/gn.jl")
include("layers/msgpass.jl")

include("layers/utils.jl")
include("layers/conv.jl")
include("layers/pool.jl")
include("models.jl")
98 changes: 68 additions & 30 deletions src/layers/conv.jl
@@ -23,7 +23,7 @@ GCNConv(1024 => 256, relu)

See also [`WithGraph`](@ref) for training a layer with a static graph.
"""
struct GCNConv{A<:AbstractMatrix,B,F}
struct GCNConv{A<:AbstractMatrix,B,F} <: AbstractGraphLayer
weight::A
bias::B
σ::F
@@ -57,12 +57,10 @@ end

# For static graph
WithGraph(fg::AbstractFeaturedGraph, l::GCNConv) =
WithGraph(l, GraphSignals.normalized_adjacency_matrix!(fg, eltype(l.weight); selfloop=true))
WithGraph(GraphSignals.normalized_adjacency_matrix(fg, eltype(l.weight); selfloop=true), l)

function (wg::WithGraph{<:GCNConv})(X::AbstractArray)
à = Zygote.ignore() do
GraphSignals.normalized_adjacency_matrix(wg.fg)
end
à = wg.graph
return wg.layer(Ã, X)
end
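
For context: with this change, the static-graph path caches the normalized adjacency matrix `Ã` inside `WithGraph` instead of recomputing it on every call. A minimal usage sketch, assuming only the constructors shown in this diff; the graph and feature sizes here are made up for illustration:

```julia
using GraphSignals, GeometricFlux

adj = [0 1 1;
       1 0 1;
       1 1 0]
fg = FeaturedGraph(adj)

wg = WithGraph(fg, GCNConv(16=>8, relu))   # Ã is computed once, at wrapping time
X  = rand(Float32, 16, 3)                  # features × nodes
Y  = wg(X)                                 # reuses the cached Ã; size (8, 3)
```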

@@ -75,66 +73,96 @@ end


"""
ChebConv([fg,] in=>out, k; bias=true, init=glorot_uniform)
ChebConv(in=>out, k; bias=true, init=glorot_uniform)

Chebyshev spectral graph convolutional layer.

# Arguments

- `fg`: Optionally pass a [`FeaturedGraph`](@ref).
- `in`: The dimension of input features.
- `out`: The dimension of output features.
- `k`: The order of Chebyshev polynomial.
- `bias`: Add learnable bias.
- `init`: Weights' initializer.

# Example

```jldoctest
julia> cc = ChebConv(1024=>256, 5, relu)
ChebConv(1024 => 256, k=5, relu)
```

See also [`WithGraph`](@ref) for training a layer with a static graph.
"""
struct ChebConv{A<:AbstractArray{<:Number,3}, B, S<:AbstractFeaturedGraph} <: AbstractGraphLayer
struct ChebConv{A<:AbstractArray{<:Number,3},B,F} <: AbstractGraphLayer
weight::A
bias::B
fg::S
k::Int
σ::F
end

function ChebConv(fg::AbstractFeaturedGraph, ch::Pair{Int,Int}, k::Int;
function ChebConv(ch::Pair{Int,Int}, k::Int, σ=identity;
init=glorot_uniform, bias::Bool=true)
in, out = ch
W = init(out, in, k)
b = Flux.create_bias(W, bias, out)
ChebConv(W, b, fg, k)
ChebConv(W, b, k, σ)
end

ChebConv(ch::Pair{Int,Int}, k::Int; kwargs...) =
ChebConv(NullGraph(), ch, k; kwargs...)

@functor ChebConv

Flux.trainable(l::ChebConv) = (l.weight, l.bias)

function (c::ChebConv)(fg::AbstractFeaturedGraph, X::AbstractMatrix{T}) where T
GraphSignals.check_num_nodes(fg, X)
@assert size(X, 1) == size(c.weight, 2) "Input feature size must match input channel size."

L̃ = Zygote.ignore() do
GraphSignals.scaled_laplacian(fg, eltype(X))
end

function (l::ChebConv)(L̃::AbstractMatrix, X::AbstractMatrix)
Z_prev = X
Z = X * L̃
Y = view(c.weight,:,:,1) * Z_prev
Y += view(c.weight,:,:,2) * Z
for k = 3:c.k
Y = view(l.weight,:,:,1) * Z_prev
Y += view(l.weight,:,:,2) * Z
for k = 3:l.k
Z, Z_prev = 2 .* Z * L̃ - Z_prev, Z
Y += view(c.weight,:,:,k) * Z
Y += view(l.weight,:,:,k) * Z
end
return l.σ.(Y .+ l.bias)
end

function (l::ChebConv)(L̃::AbstractMatrix, X::AbstractArray)
Z_prev = X
Z = NNlib.batched_mul(X, L̃)
Y = NNlib.batched_mul(view(l.weight,:,:,1), Z_prev)
Y += NNlib.batched_mul(view(l.weight,:,:,2), Z)
for k = 3:l.k
Z, Z_prev = 2 .* NNlib.batched_mul(Z, L̃) .- Z_prev, Z
Y += NNlib.batched_mul(view(l.weight,:,:,k), Z)
end
return l.σ.(Y .+ l.bias)
end
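
For reference, both methods above implement the same Chebyshev recurrence, written out here with the symbols used in the code, where ``W_k`` is `weight[:, :, k]` and ``K`` is `l.k`:

```math
Z_1 = X, \qquad Z_2 = X\tilde{L}, \qquad Z_k = 2\,Z_{k-1}\tilde{L} - Z_{k-2},
\qquad Y = \sigma\left(\sum_{k=1}^{K} W_k Z_k + b\right)
```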

# For variable graph
function (l::ChebConv)(fg::AbstractFeaturedGraph)
nf = node_feature(fg)
GraphSignals.check_num_nodes(fg, nf)
@assert size(nf, 1) == size(l.weight, 2) "Input feature size must match input channel size."

L̃ = Zygote.ignore() do
GraphSignals.scaled_laplacian(fg, eltype(nf))
end
return Y .+ c.bias
return ConcreteFeaturedGraph(fg, nf = l(L̃, nf))
end

(l::ChebConv)(fg::AbstractFeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
# For static graph
WithGraph(fg::AbstractFeaturedGraph, l::ChebConv) =
WithGraph(GraphSignals.scaled_laplacian(fg, eltype(l.weight)), l)

function (wg::WithGraph{<:ChebConv})(X::AbstractArray)
L̃ = wg.graph
return wg.layer(L̃, X)
end
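
A minimal sketch of the two call paths this change enables: a variable graph (pass a `FeaturedGraph` carrying node features) and a static graph (wrap the layer with `WithGraph`). The constructor signatures come from this diff; the `nf` keyword of `FeaturedGraph` and the sizes are assumptions for illustration:

```julia
using GraphSignals, GeometricFlux

adj = [0 1 1; 1 0 1; 1 1 0]
cc  = ChebConv(16=>8, 3, relu)

# Variable graph: the layer reads node features from the FeaturedGraph
# (`nf=` keyword assumed from GraphSignals' FeaturedGraph constructor).
fg  = FeaturedGraph(adj, nf=rand(Float32, 16, 3))
fg2 = cc(fg)                           # FeaturedGraph with updated node features

# Static graph: the scaled Laplacian L̃ is precomputed once by WithGraph.
wg = WithGraph(FeaturedGraph(adj), cc)
Y  = wg(rand(Float32, 16, 3))          # size (8, 3)
```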

function Base.show(io::IO, l::ChebConv)
out, in, k = size(l.weight)
print(io, "ChebConv(", in, " => ", out)
print(io, ", k=", k)
l.σ == identity || print(io, ", ", l.σ)
print(io, ")")
end

@@ -192,6 +220,8 @@ end

(l::GraphConv)(fg::AbstractFeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
# (l::GraphConv)(fg::AbstractFeaturedGraph) = propagate(l, fg, +) # edge number check break this
(l::GraphConv)(x::AbstractMatrix) = l(l.fg, x)
(l::GraphConv)(::NullGraph, x::AbstractMatrix) = throw(ArgumentError("concrete FeaturedGraph is not provided."))

function Base.show(io::IO, l::GraphConv)
in_channel = size(l.weight1, ndims(l.weight1))
@@ -307,6 +337,8 @@ end

(l::GATConv)(fg::AbstractFeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
# (l::GATConv)(fg::AbstractFeaturedGraph) = propagate(l, fg, +) # edge number check break this
(l::GATConv)(x::AbstractMatrix) = l(l.fg, x)
(l::GATConv)(::NullGraph, x::AbstractMatrix) = throw(ArgumentError("concrete FeaturedGraph is not provided."))

function Base.show(io::IO, l::GATConv)
in_channel = size(l.weight, ndims(l.weight))
@@ -358,13 +390,13 @@ message(ggc::GatedGraphConv, x_i, x_j::AbstractVector, e_ij) = x_j
update(ggc::GatedGraphConv, m::AbstractVector, x) = m


function (ggc::GatedGraphConv)(fg::AbstractFeaturedGraph, H::AbstractMatrix{S}) where {T<:AbstractVector,S<:Real}
function (ggc::GatedGraphConv)(fg::AbstractFeaturedGraph, H::AbstractMatrix{T}) where {T<:Real}
GraphSignals.check_num_nodes(fg, H)
m, n = size(H)
@assert (m <= ggc.out_ch) "The number of input features must be less than or equal to the number of output features."
if m < ggc.out_ch
Hpad = Zygote.ignore() do
fill!(similar(H, S, ggc.out_ch - m, n), 0)
fill!(similar(H, T, ggc.out_ch - m, n), 0)
end
H = vcat(H, Hpad)
end
@@ -378,6 +410,8 @@ end

(l::GatedGraphConv)(fg::AbstractFeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
# (l::GatedGraphConv)(fg::AbstractFeaturedGraph) = propagate(l, fg, +) # edge number check break this
(l::GatedGraphConv)(x::AbstractMatrix) = l(l.fg, x)
(l::GatedGraphConv)(::NullGraph, x::AbstractMatrix) = throw(ArgumentError("concrete FeaturedGraph is not provided."))


function Base.show(io::IO, l::GatedGraphConv)
@@ -423,6 +457,8 @@ end

(l::EdgeConv)(fg::AbstractFeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
# (l::EdgeConv)(fg::AbstractFeaturedGraph) = propagate(l, fg, l.aggr) # edge number check break this
(l::EdgeConv)(x::AbstractMatrix) = l(l.fg, x)
(l::EdgeConv)(::NullGraph, x::AbstractMatrix) = throw(ArgumentError("concrete FeaturedGraph is not provided."))

function Base.show(io::IO, l::EdgeConv)
print(io, "EdgeConv(", l.nn)
@@ -475,6 +511,8 @@ end

(l::GINConv)(fg::AbstractFeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
# (l::GINConv)(fg::AbstractFeaturedGraph) = propagate(l, fg, +) # edge number check break this
(l::GINConv)(x::AbstractMatrix) = l(l.fg, x)
(l::GINConv)(::NullGraph, x::AbstractMatrix) = throw(ArgumentError("concrete FeaturedGraph is not provided."))


"""
124 changes: 122 additions & 2 deletions src/layers/graphlayers.jl
@@ -5,5 +5,125 @@ An abstract type of graph neural network layer for GeometricFlux.
"""
abstract type AbstractGraphLayer end

(l::AbstractGraphLayer)(x::AbstractMatrix) = l(l.fg, x)
(l::AbstractGraphLayer)(::NullGraph, x::AbstractMatrix) = throw(ArgumentError("concrete FeaturedGraph is not provided."))
"""
WithGraph(fg, layer)

Train a GNN layer with a static graph.

# Arguments

- `fg`: A fixed `FeaturedGraph` to train with.
- `layer`: A GNN layer.

# Example

```jldoctest
julia> using GraphSignals, GeometricFlux

julia> adj = [0 1 0 1;
1 0 1 0;
0 1 0 1;
1 0 1 0];

julia> fg = FeaturedGraph(adj);

julia> gc = WithGraph(fg, GCNConv(1024=>256))
WithGraph(Graph(#V=4, #E=4), GCNConv(1024 => 256))

julia> WithGraph(fg, Dense(10, 5))
Dense(10, 5) # 55 parameters

julia> model = Chain(
GCNConv(32=>32),
gc,
);

julia> WithGraph(fg, model)
Chain(
WithGraph(
GCNConv(32 => 32), # 1_056 parameters
),
WithGraph(
GCNConv(1024 => 256), # 262_400 parameters
),
) # Total: 4 trainable arrays, 263_456 parameters,
# plus 2 non-trainable, 32 parameters, summarysize 1.006 MiB.
```
"""
struct WithGraph{L<:AbstractGraphLayer,G}
graph::G
layer::L
end

@functor WithGraph

Flux.trainable(l::WithGraph) = (l.layer, )

function Flux.destructure(m::WithGraph)
p, re = Flux.destructure(m.layer)
function re_withgraph(x)
WithGraph(m.graph, re(x))
end

return p, re_withgraph
end
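
A small usage sketch for this `destructure` overload, assuming a wrapper built as in the `WithGraph` docstring above; the reconstruction keeps the cached graph and only restores parameters:

```julia
wg = WithGraph(FeaturedGraph([0 1; 1 0]), GCNConv(4=>2))

p, re = Flux.destructure(wg)   # p: flat parameter vector, re: rebuilds a WithGraph
wg2 = re(p)                    # same cached graph, parameters restored from p
```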

function Base.show(io::IO, l::WithGraph)
print(io, "WithGraph(Graph(#V=", nv(l.graph))
print(io, ", #E=", ne(l.graph), "), ")
print(io, l.layer, ")")
end

WithGraph(fg::AbstractFeaturedGraph, model::Chain) = Chain(map(l -> WithGraph(fg, l), model.layers)...)
WithGraph(::AbstractFeaturedGraph, layer::WithGraph) = layer
WithGraph(::AbstractFeaturedGraph, layer) = layer

"""
GraphParallel(; node_layer=identity, edge_layer=identity, global_layer=identity)

Pass features of a `FeaturedGraph` through layers in parallel. It takes a `FeaturedGraph` as input,
and separate layers can be assigned to specific features (node, edge, and global).

# Arguments

- `node_layer`: A regular Flux layer for passing node features.
- `edge_layer`: A regular Flux layer for passing edge features.
- `global_layer`: A regular Flux layer for passing global features.

# Example

```jldoctest
julia> using Flux, GeometricFlux

julia> l = GraphParallel(
node_layer=Dropout(0.5),
global_layer=Dense(10, 5)
)
GraphParallel(node_layer=Dropout(0.5), edge_layer=identity, global_layer=Dense(10, 5))
```
"""
struct GraphParallel{N,E,G}
node_layer::N
edge_layer::E
global_layer::G
end

@functor GraphParallel

GraphParallel(; node_layer=identity, edge_layer=identity, global_layer=identity) =
GraphParallel(node_layer, edge_layer, global_layer)

function (l::GraphParallel)(fg::AbstractFeaturedGraph)
nf = l.node_layer(node_feature(fg))
ef = l.edge_layer(edge_feature(fg))
gf = l.global_layer(global_feature(fg))
return ConcreteFeaturedGraph(fg, nf=nf, ef=ef, gf=gf)
end
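
To complement the constructor example in the docstring, a brief sketch of calling a `GraphParallel` on a graph carrying node features; the `nf` keyword and the sizes are assumptions for illustration, and unspecified layers default to `identity`:

```julia
using Flux, GraphSignals, GeometricFlux

fg = FeaturedGraph([0 1; 1 0], nf=rand(Float32, 10, 2))

l   = GraphParallel(node_layer=Dense(10, 4))   # edge_layer and global_layer stay identity
fg2 = l(fg)
size(node_feature(fg2))                        # (4, 2); edge/global features untouched
```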

function Base.show(io::IO, l::GraphParallel)
print(io, "GraphParallel(")
print(io, "node_layer=", l.node_layer)
print(io, ", edge_layer=", l.edge_layer)
print(io, ", global_layer=", l.global_layer)
print(io, ")")
end