Add EEquivGraphPE layer and introduce nested EEquivGraphConv #308

Merged · 4 commits · Jun 21, 2022
2 changes: 1 addition & 1 deletion Project.toml
@@ -32,7 +32,7 @@ DataStructures = "0.18"
FillArrays = "0.13"
Flux = "0.12 - 0.13"
GraphMLDatasets = "0.1"
GraphSignals = "0.4 - 0.5"
GraphSignals = "0.6"
Graphs = "1"
NNlib = "0.8"
NNlibCUDA = "0.2"
14 changes: 14 additions & 0 deletions docs/bibliography.bib
@@ -183,3 +183,17 @@ @inproceedings{Hamilton2017
title = {Inductive Representation Learning on Large Graphs},
year = {2017},
}

@inproceedings{Satorras2021,
abstract = {This paper introduces a new model to learn graph neural networks equivariant to rotations, translations, reflections and permutations called E(n)-Equivariant Graph Neural Networks (EGNNs). In contrast with existing methods, our work does not require computationally expensive higher-order representations in intermediate layers while it still achieves competitive or better performance. In addition, whereas existing methods are limited to equivariance on 3 dimensional spaces, our model is easily scaled to higher-dimensional spaces. We demonstrate the effectiveness of our method on dynamical systems modelling, representation learning in graph autoencoders and predicting molecular properties.},
author = {Victor Garcia Satorras and Emiel Hoogeboom and Max Welling},
editor = {Marina Meila and Tong Zhang},
booktitle = {Proceedings of the 38th International Conference on Machine Learning},
month = {7},
pages = {9323--9332},
publisher = {PMLR},
title = {E(n) Equivariant Graph Neural Networks},
volume = {139},
url = {http://arxiv.org/abs/2102.09844},
year = {2021},
}
4 changes: 3 additions & 1 deletion docs/make.jl
@@ -42,8 +42,10 @@ makedocs(
"Dynamic Graph Update" => "dynamicgraph.md",
"Manual" => [
"FeaturedGraph" => "manual/featuredgraph.md",
"Graph Convolutional Layers" => "manual/conv.md",
"Graph Convolutional Layers" => "manual/graph_conv.md",
"Graph Pooling Layers" => "manual/pool.md",
"Group Convolutional Layers" => "manual/group_conv.md",
"Positional Encoding Layers" => "manual/positional.md",
"Embeddings" => "manual/embedding.md",
"Models" => "manual/models.md",
"Linear Algebra" => "manual/linalg.md",
2 changes: 2 additions & 0 deletions docs/src/manual/featuredgraph.md
@@ -9,6 +9,8 @@ GraphSignals.edge_feature
GraphSignals.has_edge_feature
GraphSignals.global_feature
GraphSignals.has_global_feature
GraphSignals.positional_feature
GraphSignals.has_positional_feature
GraphSignals.subgraph
GraphSignals.ConcreteFeaturedGraph
```
File renamed without changes.
19 changes: 19 additions & 0 deletions docs/src/manual/group_conv.md
@@ -0,0 +1,19 @@
# Group Convolutional Layers

## ``E(n)``-equivariant Convolutional Layer

It employs a message-passing scheme and can be defined by the following functions:

- message function (Eq. 3 from paper): ``m_{ij} = \phi_e(h_i^l, h_j^l, ||x_i^l - x_j^l||^2, a_{ij})``
- aggregate (Eq. 5 from paper): ``m_i = \sum_j m_{ij}``
- update function (Eq. 6 from paper): ``h_i^{l+1} = \phi_h(h_i^l, m_i)``

where ``h_i^l`` and ``h_j^l`` denote the node features of nodes ``i`` and ``j``, respectively, in the ``l``-th layer, and ``x_i^l`` and ``x_j^l`` denote their positional features in the ``l``-th layer. ``a_{ij}`` is the edge feature for edge ``(i,j)``. ``\phi_e`` and ``\phi_h`` are neural networks for edges and nodes, respectively.
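
For instance, a layer defined by these functions can be applied to a `FeaturedGraph` carrying node, edge, and positional features. Below is a minimal sketch; the graph and the random features are illustrative, and the `pf` keyword assumes GraphSignals 0.6:

```julia
using GeometricFlux, GraphSignals

adj = [0 1 1;
       1 0 1;
       1 1 0]                                   # 3 nodes, 6 directed edges

in_dim, out_dim, pos_dim, edge_dim = 3, 5, 2, 3

fg = FeaturedGraph(adj;
                   nf=rand(Float32, in_dim, 3),    # node features h_i
                   ef=rand(Float32, edge_dim, 6),  # edge features a_ij
                   pf=rand(Float32, pos_dim, 3))   # positional features x_i

egnn = EEquivGraphConv(in_dim=>out_dim, pos_dim, edge_dim)
fg′ = egnn(fg)   # returns a FeaturedGraph with updated nf and pf
```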

```@docs
EEquivGraphConv
```

Reference: [Satorras2021](@cite)

---
19 changes: 19 additions & 0 deletions docs/src/manual/positional.md
@@ -0,0 +1,19 @@
# Positional Encoding Layers

## ``E(n)``-equivariant Positional Encoding Layer

It employs a message-passing scheme and can be defined by the following functions:

- message function: ``y_{ij}^l = (x_i^l - x_j^l)\phi_x(m_{ij})``
- aggregate: ``y_i^l = \frac{1}{M} \sum_{j \in \mathcal{N}(i)} y_{ij}^l``
- update function: ``x_i^{l+1} = x_i^l + y_i^l``

where ``x_i^l`` and ``x_j^l`` denote the positional features of nodes ``i`` and ``j``, respectively, in the ``l``-th layer, ``\phi_x`` is the neural network for positional encoding, and ``m_{ij}`` is the edge feature for edge ``(i,j)``. ``y_{ij}^l`` and ``y_i^l`` represent the encoded and aggregated positional features, respectively, and ``M`` denotes the number of neighbors of node ``i``.
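
As a concrete illustration, the three steps can be written out by hand as follows. This is a plain-Julia sketch, not the library implementation; the `Dense` choice for ``\phi_x``, the dimensions, and the random data are assumptions:

```julia
using Flux

pos_dim, msg_dim, num_nodes = 2, 5, 4
ϕ_x = Dense(msg_dim => pos_dim)        # learnable map applied to edge features

x = rand(Float32, pos_dim, num_nodes)  # positions x_i^l
m = Dict((i, j) => rand(Float32, msg_dim) for i in 1:num_nodes, j in 1:num_nodes)

# message: y_ij = (x_i - x_j) .* ϕ_x(m_ij)
pe_message(i, j) = (x[:, i] .- x[:, j]) .* ϕ_x(m[(i, j)])

# aggregate over neighbors js, then residual update: x_i^{l+1} = x_i^l + y_i^l
function pe_update(i, js)
    y_i = sum(pe_message(i, j) for j in js) ./ length(js)
    return x[:, i] .+ y_i
end

pe_update(1, [2, 3, 4])   # updated position of node 1
```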

```@docs
EEquivGraphPE
```

Reference: [Satorras2021](@cite)

---
6 changes: 6 additions & 0 deletions src/GeometricFlux.jl
@@ -31,6 +31,11 @@ export
# layers/msgpass
MessagePassing,

# layers/positional
AbstractPE,
positional_encode,
EEquivGraphPE,

# layers/graph_conv
GCNConv,
ChebConv,
@@ -75,6 +80,7 @@ include("layers/graphlayers.jl")
include("layers/gn.jl")
include("layers/msgpass.jl")

include("layers/positional.jl")
include("layers/graph_conv.jl")
include("layers/group_conv.jl")
include("layers/pool.jl")
159 changes: 159 additions & 0 deletions src/bspline.jl

Large diffs are not rendered by default.

28 changes: 28 additions & 0 deletions src/groups.jl
@@ -0,0 +1,28 @@
import Base: *, inv
using LinearAlgebra: norm

# Translation group on R^N: composition is addition, so the identity is the zero vector.
struct R{N} end

Base.ndims(::R{N}) where {N} = N

Base.identity(::R{N}, ::Type{T}=Float32) where {N,T<:Number} = zeros(T, N)


# Multiplicative (scaling) group on R^N: the identity is the one vector.
struct H{N} end

Base.ndims(::H{N}) where {N} = N

Base.identity(::H{N}, ::Type{T}=Float32) where {N,T<:Number} = ones(T, N)

# Composition defaults to `*`; a concrete element type must provide its own more
# specific method, otherwise this fallback recurses.
(*)(h1, h2) = h1 * h2

# Inverse in the multiplicative group.
inv(h) = 1. / h

# Elementwise group logarithm and exponential.
Base.log(h) = log.(h)

Base.exp(c) = exp.(c)

"""
    dist(h1, h2)

The logarithmic distance ||log(inv(h1) * h2)||.
"""
dist(h1, h2) = norm(log(inv(h1) * h2))
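
A quick sanity check of `dist` on scalar elements of the multiplicative group, using the definitions above (values illustrative):

```julia
h1, h2 = 2.0, 8.0
dist(h1, h2)   # norm(log(inv(2.0) * 8.0)) == log(4.0) ≈ 1.386
```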
174 changes: 78 additions & 96 deletions src/layers/group_conv.jl
@@ -1,128 +1,110 @@
"""
EEquivGraphConv(in_dim, int_dim, out_dim; init=glorot_uniform)
EEquivGraphConv(in_dim, nn_edge, nn_x, nn_h)
EEquivGraphConv(in_dim=>out_dim, pos_dim, edge_dim; init=glorot_uniform)

E(n)-equivariant graph neural network layer as defined in the paper "[E(n) Equivariant Graph Neural Networks](https://arxiv.org/abs/2102.09844)" by Satorras, Hoogeboom, and Welling (2021).
E(n)-equivariant graph neural network layer.

# Arguments

Either one of two sets of arguments:

Set 1:

- `in_dim`: node feature dimension. Data is assumed to be of the form [feature; coordinate], so `in_dim` must strictly be less than the dimension of the input vectors.
- `int_dim`: intermediate dimension, can be arbitrary.
- `in_dim::Int`: node feature dimension. Data is assumed to be of the form [feature; coordinate], so `in_dim` must strictly be less than the dimension of the input vectors.
- `out_dim`: the output of the layer will have dimension `out_dim` + (dimension of input vector - `in_dim`).
- `init`: neural network initialization function, should be compatible with `Flux.Dense`.

Set 2:
- `pos_dim::Int`: dimension of positional encoding.
- `edge_dim::Int`: dimension of edge feature.
- `init`: neural network initialization function.

- `in_dim`: as in Set 1.
- `nn_edge`: a differentiable function that must take vectors of dimension `in_dim * 2 + 2` to vectors of some dimension (designated `int_dim`).
- `nn_x`: a differentiable function that must take vectors of dimension `int_dim` to dimension `1`.
- `nn_h`: a differentiable function that must take vectors of dimension `in_dim + int_dim` to `out_dim`.
# Examples

```jldoctest
julia> in_dim, int_dim, out_dim = 3, 5, 5
(3, 5, 5)

julia> egnn = EEquivGraphConv(in_dim, int_dim, out_dim)
EEquivGraphConv{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}(Dense(8 => 5), Dense(5 => 1), Dense(8 => 5), 3, 5, 5)

julia> m_len = 2*in_dim + 2
8

julia> nn_edge = Flux.Dense(m_len, int_dim)
Dense(8 => 5) # 45 parameters
julia> in_dim, out_dim, pos_dim = 3, 5, 2
(3, 5, 2)

julia> nn_x = Flux.Dense(int_dim, 1)
Dense(5 => 1) # 6 parameters

julia> nn_h = Flux.Dense(in_dim + int_dim, out_dim)
Dense(8 => 5) # 45 parameters

julia> egnn = EEquivGraphConv(in_dim, nn_edge, nn_x, nn_h)
EEquivGraphConv{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}(Dense(8 => 5), Dense(5 => 1), Dense(8 => 5), 3, 5, 5)
julia> egnn = EEquivGraphConv(in_dim=>out_dim, pos_dim, in_dim)
EEquivGraphConv(ϕ_edge=Dense(10 => 5), ϕ_x=Dense(5 => 2), ϕ_h=Dense(8 => 5))
```
"""

struct EEquivGraphConv{E,X,H} <: MessagePassing
nn_edge::E
nn_x::X
nn_h::H
in_dim::Int
int_dim::Int
out_dim::Int
See also [`WithGraph`](@ref) for training a layer with a static graph and [`EEquivGraphPE`](@ref) for positional encoding.
"""
struct EEquivGraphConv{X,E,H} <: AbstractGraphLayer
pe::X
nn_edge::E
nn_h::H
end

@functor EEquivGraphConv

function EEquivGraphConv(in_dim::Int, int_dim::Int, out_dim::Int; init=glorot_uniform)
m_len = 2in_dim + 2
nn_edge = Flux.Dense(m_len, int_dim; init=init)
nn_x = Flux.Dense(int_dim, 1; init=init)
nn_h = Flux.Dense(in_dim + int_dim, out_dim; init=init)
return EEquivGraphConv(nn_edge, nn_x, nn_h, in_dim, int_dim, out_dim)
end
Flux.trainable(l::EEquivGraphConv) = (l.pe, l.nn_edge, l.nn_h)

function EEquivGraphConv(in_dim::Int, nn_edge, nn_x, nn_h)
m_len = 2in_dim + 2
int_dim = Flux.outputsize(nn_edge, (m_len, 2))[1]
out_dim = Flux.outputsize(nn_h, (in_dim + int_dim, 2))[1]
return EEquivGraphConv(nn_edge, nn_x, nn_h, in_dim, int_dim, out_dim)
function EEquivGraphConv(ch::Pair{Int,Int}, pos_dim::Int, edge_dim::Int; init=glorot_uniform)
in_dim, out_dim = ch
nn_edge = Flux.Dense(2in_dim + edge_dim + 1, out_dim; init=init)
pe = EEquivGraphPE(out_dim=>pos_dim; init=init)
nn_h = Flux.Dense(in_dim + out_dim, out_dim; init=init)
return EEquivGraphConv(pe, nn_edge, nn_h)
end

function ϕ_edge(egnn::EEquivGraphConv, h_i, h_j, dist, a)
N = size(h_i, 2)
return egnn.nn_edge(vcat(h_i, h_j, dist, ones(N)' * a))
end

ϕ_x(egnn::EEquivGraphConv, m_ij) = egnn.nn_x(m_ij)

function message(egnn::EEquivGraphConv, v_i, v_j, e)
in_dim = egnn.in_dim
h_i = v_i[1:in_dim,:]
h_j = v_j[1:in_dim,:]
ϕ_edge(l::EEquivGraphConv, h_i, h_j, dist, a) = l.nn_edge(vcat(h_i, h_j, dist, a))

N = size(h_i, 2)
function message(l::EEquivGraphConv, h_i, h_j, x_i, x_j, e)
dist = sum(abs2, x_i - x_j; dims=1)
return ϕ_edge(l, h_i, h_j, dist, e)
end

x_i = v_i[in_dim+1:end,:]
x_j = v_j[in_dim+1:end,:]
update(l::EEquivGraphConv, m, h) = l.nn_h(vcat(h, m))

if isnothing(e)
a = 1
else
a = e[1]
end
# For variable graph
function(egnn::EEquivGraphConv)(fg::AbstractFeaturedGraph)
nf = node_feature(fg)
ef = edge_feature(fg)
pf = positional_feature(fg)
GraphSignals.check_num_nodes(fg, nf)
GraphSignals.check_num_edges(fg, ef)
_, V, X = propagate(egnn, graph(fg), ef, nf, pf, +)
return ConcreteFeaturedGraph(fg, nf=V, pf=X)
end

dist = sum(abs2.(x_i - x_j); dims=1)
edge_msg = ϕ_edge(egnn, h_i, h_j, dist, a)
output_vec = vcat(edge_msg, (x_i - x_j) .* ϕ_x(egnn, edge_msg)[1], ones(N)')
return reshape(output_vec, :, N)
# For static graph
function(l::EEquivGraphConv)(el::NamedTuple, H::AbstractArray, E::AbstractArray, X::AbstractArray)
GraphSignals.check_num_nodes(el.N, H)
GraphSignals.check_num_nodes(el.N, X)
GraphSignals.check_num_edges(el.E, E)
_, V, X = propagate(l, el, E, H, X, +)
return V, X
end

function update(e::EEquivGraphConv, m, h)
N = size(m, 2)
mi = m[1:e.int_dim,:]
x_msg = m[e.int_dim+1:end-1,:]
M = m[end,:]
function Base.show(io::IO, l::EEquivGraphConv)
print(io, "EEquivGraphConv(ϕ_edge=", l.nn_edge)
print(io, ", ϕ_x=", l.pe.nn)
print(io, ", ϕ_h=", l.nn_h)
print(io, ")")
end

C = 1 ./ (M.-1)
C = reshape(C, :, N)
function aggregate_neighbors(::EEquivGraphConv, el::NamedTuple, aggr, E)
batch_size = size(E)[end]
dstsize = (size(E, 1), el.N, batch_size)
xs = batched_index(el.xs, batch_size)
return _scatter(aggr, E, xs, dstsize)
end

nn_node_out = e.nn_h(vcat(h[1:e.in_dim,:], mi))
aggregate_neighbors(::EEquivGraphConv, el::NamedTuple, aggr, E::AbstractMatrix) = _scatter(aggr, E, el.xs)

coord_dim = size(h,1) - e.in_dim
@inline aggregate_neighbors(::EEquivGraphConv, ::NamedTuple, ::Nothing, E) = nothing
@inline aggregate_neighbors(::EEquivGraphConv, ::NamedTuple, ::Nothing, ::AbstractMatrix) = nothing

z = zeros(e.out_dim + coord_dim, N)
z[1:e.out_dim,:] = nn_node_out
z[e.out_dim+1:end,:] = h[e.in_dim+1:end,:] + C .* x_msg
return z
function propagate(l::EEquivGraphConv, sg::SparseGraph, E, V, X, aggr)
el = to_namedtuple(sg)
return propagate(l, el, E, V, X, aggr)
end

function(egnn::EEquivGraphConv)(fg::AbstractFeaturedGraph)
X = node_feature(fg)
GraphSignals.check_num_nodes(fg, X)
_, V, _ = propagate(egnn, graph(fg), nothing, X, nothing, +, nothing, nothing)
return ConcreteFeaturedGraph(fg, nf=V)
function propagate(l::EEquivGraphConv, el::NamedTuple, E, V, X, aggr)
    # compute edge messages m_ij from node, positional, and edge features
    E = message(
        l, _gather(V, el.xs), _gather(V, el.nbrs),
        _gather(X, el.xs), _gather(X, el.nbrs),
        _gather(E, el.es)
    )
    # update positions x_i via the nested EEquivGraphPE layer
    X = positional_encode(l.pe, el, X, E)
    # aggregate messages (m_i = Σ_j m_ij), then update node features h_i
    Ē = aggregate_neighbors(l, el, aggr, E)
    V = update(l, Ē, V)
    return E, V, X
end

WithGraph(fg::AbstractFeaturedGraph, l::EEquivGraphConv) = WithGraph(to_namedtuple(fg), l)
(wg::WithGraph{<:EEquivGraphConv})(args...) = wg.layer(wg.graph, args...)
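
For training with a fixed graph, the static path can be exercised through this PR's `WithGraph` overload. A sketch with illustrative shapes (`H`, `E`, `X` are random placeholder features):

```julia
fg = FeaturedGraph([0 1 1; 1 0 1; 1 1 0])
layer = WithGraph(fg, EEquivGraphConv(3=>5, 2, 3))

H = rand(Float32, 3, 3)   # node features, in_dim × num_nodes
E = rand(Float32, 3, 6)   # edge features, edge_dim × num_edges
X = rand(Float32, 2, 3)   # positional features, pos_dim × num_nodes

V, X′ = layer(H, E, X)    # updated node and positional features
```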