
Commit 2d2f20d

Merge pull request #308 from FluxML/develop
Add EEquivGraphPE layer and introduce nested EEquivGraphConv
2 parents 0378e18 + 4dc9ed0 commit 2d2f20d

File tree

15 files changed: +475 −118 lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ DataStructures = "0.18"
 FillArrays = "0.13"
 Flux = "0.12 - 0.13"
 GraphMLDatasets = "0.1"
-GraphSignals = "0.4 - 0.5"
+GraphSignals = "0.6"
 Graphs = "1"
 NNlib = "0.8"
 NNlibCUDA = "0.2"

docs/bibliography.bib

Lines changed: 14 additions & 0 deletions
@@ -183,3 +183,17 @@ @inproceedings{Hamilton2017
 title = {Inductive Representation Learning on Large Graphs},
 year = {2017},
 }
+
+@inproceedings{Satorras2021,
+abstract = {This paper introduces a new model to learn graph neural networks equivariant to rotations, translations, reflections and permutations called E(n)-Equivariant Graph Neural Networks (EGNNs). In contrast with existing methods, our work does not require computationally expensive higher-order representations in intermediate layers while it still achieves competitive or better performance. In addition, whereas existing methods are limited to equivariance on 3 dimensional spaces, our model is easily scaled to higher-dimensional spaces. We demonstrate the effectiveness of our method on dynamical systems modelling, representation learning in graph autoencoders and predicting molecular properties.},
+author = {Victor Garcia Satorras and Emiel Hoogeboom and Max Welling},
+editor = {Marina Meila and Tong Zhang},
+booktitle = {Proceedings of the 38th International Conference on Machine Learning},
+month = {2},
+pages = {9323-9332},
+publisher = {PMLR},
+title = {E(n) Equivariant Graph Neural Networks},
+volume = {139},
+url = {http://arxiv.org/abs/2102.09844},
+year = {2021},
+}

docs/make.jl

Lines changed: 3 additions & 1 deletion
@@ -42,8 +42,10 @@ makedocs(
         "Dynamic Graph Update" => "dynamicgraph.md",
         "Manual" => [
             "FeaturedGraph" => "manual/featuredgraph.md",
-            "Graph Convolutional Layers" => "manual/conv.md",
+            "Graph Convolutional Layers" => "manual/graph_conv.md",
             "Graph Pooling Layers" => "manual/pool.md",
+            "Group Convolutional Layers" => "manual/group_conv.md",
+            "Positional Encoding Layers" => "manual/positional.md",
             "Embeddings" => "manual/embedding.md",
             "Models" => "manual/models.md",
             "Linear Algebra" => "manual/linalg.md",

docs/src/manual/featuredgraph.md

Lines changed: 2 additions & 0 deletions
@@ -9,6 +9,8 @@ GraphSignals.edge_feature
 GraphSignals.has_edge_feature
 GraphSignals.global_feature
 GraphSignals.has_global_feature
+GraphSignals.positional_feature
+GraphSignals.has_positional_feature
 GraphSignals.subgraph
 GraphSignals.ConcreteFeaturedGraph
 ```
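For illustration, a short hedged example of the two new accessors; the `pf` keyword follows the GraphSignals 0.6 `FeaturedGraph` interface assumed here, and the graph and data are made up:

```julia
using GraphSignals

# Hypothetical 3-node triangle graph with 2-d positional features.
adj = [0 1 1;
       1 0 1;
       1 1 0]
fg = FeaturedGraph(adj, pf=rand(Float32, 2, 3))

has_positional_feature(fg)  # true
positional_feature(fg)      # the 2×3 positional feature matrix
```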
docs/src/manual/graph_conv.md

File renamed without changes (from docs/src/manual/conv.md, per the docs/make.jl change above).

docs/src/manual/group_conv.md

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+# Group Convolutional Layers
+
+## ``E(n)``-equivariant Convolutional Layer
+
+It employs a message-passing scheme and can be defined by the following functions:
+
+- message function (Eq. 3 in the paper): ``m_{ij} = \phi_e(h_i^l, h_j^l, ||x_i^l - x_j^l||^2, a_{ij})``
+- aggregate (Eq. 5 in the paper): ``m_i = \sum_j m_{ij}``
+- update function (Eq. 6 in the paper): ``h_i^{l+1} = \phi_h(h_i^l, m_i)``
+
+where ``h_i^l`` and ``h_j^l`` denote the node features of nodes ``i`` and ``j``, respectively, in the ``l``-th layer, and ``x_i^l`` and ``x_j^l`` denote their positional features. ``a_{ij}`` is the edge feature of edge ``(i,j)``. ``\phi_e`` and ``\phi_h`` are the neural networks for edges and nodes, respectively.
+
+```@docs
+EEquivGraphConv
+```
+
+Reference: [Satorras2021](@cite)
+
+---
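As an illustration of the three functions above, here is a minimal standalone sketch (hypothetical `Dense` networks and dimensions, not the layer's actual internals):

```julia
using Flux

# Hypothetical dimensions: node features in R^3, scalar edge feature.
in_dim, out_dim, edge_dim = 3, 5, 1
ϕ_e = Dense(2in_dim + 1 + edge_dim => out_dim)  # edge (message) network
ϕ_h = Dense(in_dim + out_dim => out_dim)        # node (update) network

# message (Eq. 3): m_ij = ϕ_e(h_i, h_j, ‖x_i - x_j‖², a_ij)
msg(h_i, h_j, x_i, x_j, a) = ϕ_e(vcat(h_i, h_j, sum(abs2, x_i - x_j; dims=1), a))

# aggregate (Eq. 5) and update (Eq. 6): h_i ← ϕ_h(h_i, Σ_j m_ij)
update_h(h_i, ms) = ϕ_h(vcat(h_i, sum(ms; dims=2)))
```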

docs/src/manual/positional.md

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+# Positional Encoding Layers
+
+## ``E(n)``-equivariant Positional Encoding Layer
+
+It employs a message-passing scheme and can be defined by the following functions:
+
+- message function: ``y_{ij}^l = (x_i^l - x_j^l)\phi_x(m_{ij})``
+- aggregate: ``y_i^l = \frac{1}{M} \sum_{j \in \mathcal{N}(i)} y_{ij}^l``
+- update function: ``x_i^{l+1} = x_i^l + y_i^l``
+
+where ``x_i^l`` and ``x_j^l`` denote the positional features of nodes ``i`` and ``j``, respectively, in the ``l``-th layer, ``\phi_x`` is the neural network for positional encoding, and ``m_{ij}`` is the edge feature of edge ``(i,j)``. ``y_{ij}^l`` and ``y_i^l`` are the encoded and aggregated positional features, respectively, and ``M`` is the number of neighbors of node ``i``.
+
+```@docs
+EEquivGraphPE
+```
+
+Reference: [Satorras2021](@cite)
+
+---
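To make the update concrete, a minimal standalone sketch (hypothetical names and dimensions; not `EEquivGraphPE`'s actual internals):

```julia
using Flux
using Statistics: mean

edge_dim = 4
ϕ_x = Dense(edge_dim => 1)  # maps an edge message m_ij to a scalar

# One positional update for node i, given its position x_i (d×1),
# neighbor positions xs (d×M), and edge messages ms (edge_dim×M).
function positional_update(x_i, xs, ms)
    y = (x_i .- xs) .* ϕ_x(ms)     # y_ij = (x_i - x_j) ϕ_x(m_ij)
    return x_i .+ mean(y; dims=2)  # x_i^{l+1} = x_i + (1/M) Σ_j y_ij
end
```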

src/GeometricFlux.jl

Lines changed: 6 additions & 0 deletions
@@ -31,6 +31,11 @@ export
     # layers/msgpass
     MessagePassing,
 
+    # layers/positional
+    AbstractPE,
+    positional_encode,
+    EEquivGraphPE,
+
     # layers/graph_conv
     GCNConv,
     ChebConv,
@@ -75,6 +80,7 @@ include("layers/graphlayers.jl")
 include("layers/gn.jl")
 include("layers/msgpass.jl")
 
+include("layers/positional.jl")
 include("layers/graph_conv.jl")
 include("layers/group_conv.jl")
 include("layers/pool.jl")

src/bspline.jl

Lines changed: 159 additions & 0 deletions
Large diffs are not rendered by default.

src/groups.jl

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+import Base: *, inv
+
+struct R{N} end
+
+Base.ndims(::R{N}) where {N} = N
+
+Base.identity(::R{N}, ::Type{T}=Float32) where {N,T<:Number} = zeros(T, N)
+
+
+struct H{N} end
+
+Base.ndims(::H{N}) where {N} = N
+
+Base.identity(::H{N}, ::Type{T}=Float32) where {N,T<:Number} = ones(T, N)
+
+(*)(h1, h2) = h1 * h2
+
+inv(h) = 1. / h
+
+Base.log(h) = log.(h)
+
+Base.exp(c) = exp.(c)
+
+"""
+The logarithmic distance ||log(inv(h1).h2)||
+
+"""
+dist(h1, h2) = log(inv(h1) * h2)
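To make the docstring's ``||log(inv(h1).h2)||`` concrete, a self-contained sketch for the translation group (a hypothetical helper, not part of this commit): composition is addition, the inverse is negation, and ``log`` is the identity, so the logarithmic distance reduces to a Euclidean norm:

```julia
using LinearAlgebra: norm

# Logarithmic distance specialized to the translation group R^N:
# ||log(inv(h1) ∘ h2)|| = ||h2 - h1||.
translation_dist(h1::AbstractVector, h2::AbstractVector) = norm(h2 - h1)

translation_dist([0.0, 0.0, 0.0], [3.0, 4.0, 0.0])  # 5.0
```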

src/layers/group_conv.jl

Lines changed: 78 additions & 96 deletions
@@ -1,128 +1,110 @@
 """
-    EEquivGraphConv(in_dim, int_dim, out_dim; init=glorot_uniform)
-    EEquivGraphConv(in_dim, nn_edge, nn_x, nn_h)
+    EEquivGraphConv(in_dim=>out_dim, pos_dim, edge_dim; init=glorot_uniform)
 
-E(n)-equivariant graph neural network layer as defined in the paper "[E(n) Equivariant Neural Networks](https://arxiv.org/abs/2102.09844)" by Satorras, Hoogeboom, and Welling (2021).
+E(n)-equivariant graph neural network layer.
 
 # Arguments
 
-Either one of two sets of arguments:
-
-Set 1:
-
-- `in_dim`: node feature dimension. Data is assumed to be of the form [feature; coordinate], so `in_dim` must strictly be less than the dimension of the input vectors.
-- `int_dim`: intermediate dimension, can be arbitrary.
+- `in_dim::Int`: node feature dimension. Data is assumed to be of the form [feature; coordinate], so `in_dim` must strictly be less than the dimension of the input vectors.
 - `out_dim`: the output of the layer will have dimension `out_dim` + (dimension of input vector - `in_dim`).
-- `init`: neural network initialization function, should be compatible with `Flux.Dense`.
-
-Set 2:
+- `pos_dim::Int`: dimension of positional encoding.
+- `edge_dim::Int`: dimension of edge feature.
+- `init`: neural network initialization function.
 
-- `in_dim`: as in Set 1.
-- `nn_edge`: a differentiable function that must take vectors of dimension `in_dim * 2 + 2` (output designated `int_dim`)
-- `nn_x`: a differentiable function that must take vectors of dimension `int_dim` to dimension `1`.
-- `nn_h`: a differentiable function that must take vectors of dimension `in_dim + int_dim` to `out_dim`.
+# Examples
 
 ```jldoctest
-julia> in_dim, int_dim, out_dim = 3,6,5
-(3, 5, 5)
-
-julia> egnn = EEquivGraphConv(in_dim, int_dim, out_dim)
-EEquivGraphConv{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}(Dense(8 => 5), Dense(5 => 1), Dense(8 => 5), 3, 5, 5)
-
-julia> m_len = 2*in_dim + 2
-8
-
-julia> nn_edge = Flux.Dense(m_len, int_dim)
-Dense(8 => 5)  # 45 parameters
+julia> in_dim, out_dim, pos_dim = 3, 5, 2
+(3, 5, 2)
 
-julia> nn_x = Flux.Dense(int_dim, 1)
-Dense(5 => 1)  # 6 parameters
-
-julia> nn_h = Flux.Dense(in_dim + int_dim, out_dim)
-Dense(8 => 5)  # 45 parameters
-
-julia> egnn = EEquivGraphConv(in_dim, nn_edge, nn_x, nn_h)
-EEquivGraphConv{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}(Dense(8 => 5), Dense(5 => 1), Dense(8 => 5), 3, 5, 5)
+julia> egnn = EEquivGraphConv(in_dim=>out_dim, pos_dim, in_dim)
+EEquivGraphConv(ϕ_edge=Dense(10 => 5), ϕ_x=Dense(5 => 2), ϕ_h=Dense(8 => 5))
 ```
-"""
 
-struct EEquivGraphConv{E,X,H} <: MessagePassing
-    nn_edge::E
-    nn_x::X
-    nn_h::H
-    in_dim::Int
-    int_dim::Int
-    out_dim::Int
+See also [`WithGraph`](@ref) for training layer with static graph and [`EEquivGraphPE`](@ref) for positional encoding.
+"""
+struct EEquivGraphConv{X,E,H} <: AbstractGraphLayer
+    pe::X
+    nn_edge::E
+    nn_h::H
 end
 
 @functor EEquivGraphConv
 
-function EEquivGraphConv(in_dim::Int, int_dim::Int, out_dim::Int; init=glorot_uniform)
-    m_len = 2in_dim + 2
-    nn_edge = Flux.Dense(m_len, int_dim; init=init)
-    nn_x = Flux.Dense(int_dim, 1; init=init)
-    nn_h = Flux.Dense(in_dim + int_dim, out_dim; init=init)
-    return EEquivGraphConv(nn_edge, nn_x, nn_h, in_dim, int_dim, out_dim)
-end
+Flux.trainable(l::EEquivGraphConv) = (l.pe, l.nn_edge, l.nn_h)
 
-function EEquivGraphConv(in_dim::Int, nn_edge, nn_x, nn_h)
-    m_len = 2in_dim + 2
-    int_dim = Flux.outputsize(nn_edge, (m_len, 2))[1]
-    out_dim = Flux.outputsize(nn_h, (in_dim + int_dim, 2))[1]
-    return EEquivGraphConv(nn_edge, nn_x, nn_h, in_dim, int_dim, out_dim)
+function EEquivGraphConv(ch::Pair{Int,Int}, pos_dim::Int, edge_dim::Int; init=glorot_uniform)
+    in_dim, out_dim = ch
+    nn_edge = Flux.Dense(2in_dim + edge_dim + 1, out_dim; init=init)
+    pe = EEquivGraphPE(out_dim=>pos_dim; init=init)
+    nn_h = Flux.Dense(in_dim + out_dim, out_dim; init=init)
+    return EEquivGraphConv(pe, nn_edge, nn_h)
 end
 
-function ϕ_edge(egnn::EEquivGraphConv, h_i, h_j, dist, a)
-    N = size(h_i, 2)
-    return egnn.nn_edge(vcat(h_i, h_j, dist, ones(N)' * a))
-end
-
-ϕ_x(egnn::EEquivGraphConv, m_ij) = egnn.nn_x(m_ij)
-
-function message(egnn::EEquivGraphConv, v_i, v_j, e)
-    in_dim = egnn.in_dim
-    h_i = v_i[1:in_dim,:]
-    h_j = v_j[1:in_dim,:]
+ϕ_edge(l::EEquivGraphConv, h_i, h_j, dist, a) = l.nn_edge(vcat(h_i, h_j, dist, a))
 
-    N = size(h_i, 2)
+function message(l::EEquivGraphConv, h_i, h_j, x_i, x_j, e)
+    dist = sum(abs2, x_i - x_j; dims=1)
+    return ϕ_edge(l, h_i, h_j, dist, e)
+end
 
-    x_i = v_i[in_dim+1:end,:]
-    x_j = v_j[in_dim+1:end,:]
+update(l::EEquivGraphConv, m, h) = l.nn_h(vcat(h, m))
 
-    if isnothing(e)
-        a = 1
-    else
-        a = e[1]
-    end
+# For variable graph
+function(egnn::EEquivGraphConv)(fg::AbstractFeaturedGraph)
+    nf = node_feature(fg)
+    ef = edge_feature(fg)
+    pf = positional_feature(fg)
+    GraphSignals.check_num_nodes(fg, nf)
+    GraphSignals.check_num_edges(fg, ef)
+    _, V, X = propagate(egnn, graph(fg), ef, nf, pf, +)
+    return ConcreteFeaturedGraph(fg, nf=V, pf=X)
+end
 
-    dist = sum(abs2.(x_i - x_j); dims=1)
-    edge_msg = ϕ_edge(egnn, h_i, h_j, dist, a)
-    output_vec = vcat(edge_msg, (x_i - x_j) .* ϕ_x(egnn, edge_msg)[1], ones(N)')
-    return reshape(output_vec, :, N)
+# For static graph
+function(l::EEquivGraphConv)(el::NamedTuple, H::AbstractArray, E::AbstractArray, X::AbstractArray)
+    GraphSignals.check_num_nodes(el.N, H)
+    GraphSignals.check_num_nodes(el.N, X)
+    GraphSignals.check_num_edges(el.E, E)
+    _, V, X = propagate(l, el, E, H, X, +)
+    return V, X
 end
 
-function update(e::EEquivGraphConv, m, h)
-    N = size(m, 2)
-    mi = m[1:e.int_dim,:]
-    x_msg = m[e.int_dim+1:end-1,:]
-    M = m[end,:]
+function Base.show(io::IO, l::EEquivGraphConv)
+    print(io, "EEquivGraphConv(ϕ_edge=", l.nn_edge)
+    print(io, ", ϕ_x=", l.pe.nn)
+    print(io, ", ϕ_h=", l.nn_h)
+    print(io, ")")
+end
 
-    C = 1 ./ (M.-1)
-    C = reshape(C, :, N)
+function aggregate_neighbors(::EEquivGraphConv, el::NamedTuple, aggr, E)
+    batch_size = size(E)[end]
+    dstsize = (size(E, 1), el.N, batch_size)
+    xs = batched_index(el.xs, batch_size)
+    return _scatter(aggr, E, xs, dstsize)
+end
 
-    nn_node_out = e.nn_h(vcat(h[1:e.in_dim,:], mi))
+aggregate_neighbors(::EEquivGraphConv, el::NamedTuple, aggr, E::AbstractMatrix) = _scatter(aggr, E, el.xs)
 
-    coord_dim = size(h,1) - e.in_dim
+@inline aggregate_neighbors(::EEquivGraphConv, ::NamedTuple, ::Nothing, E) = nothing
+@inline aggregate_neighbors(::EEquivGraphConv, ::NamedTuple, ::Nothing, ::AbstractMatrix) = nothing
 
-    z = zeros(e.out_dim + coord_dim, N)
-    z[1:e.out_dim,:] = nn_node_out
-    z[e.out_dim+1:end,:] = h[e.in_dim+1:end,:] + C .* x_msg
-    return z
+function propagate(l::EEquivGraphConv, sg::SparseGraph, E, V, X, aggr)
+    el = to_namedtuple(sg)
+    return propagate(l, el, E, V, X, aggr)
 end
 
-function(egnn::EEquivGraphConv)(fg::AbstractFeaturedGraph)
-    X = node_feature(fg)
-    GraphSignals.check_num_nodes(fg, X)
-    _, V, _ = propagate(egnn, graph(fg), nothing, X, nothing, +, nothing, nothing)
-    return ConcreteFeaturedGraph(fg, nf=V)
+function propagate(l::EEquivGraphConv, el::NamedTuple, E, V, X, aggr)
+    E = message(
+        l, _gather(V, el.xs), _gather(V, el.nbrs),
+        _gather(X, el.xs), _gather(X, el.nbrs),
+        _gather(E, el.es)
+    )
+    X = positional_encode(l.pe, el, X, E)
+    Ē = aggregate_neighbors(l, el, aggr, E)
+    V = update(l, Ē, V)
+    return E, V, X
 end
+
+WithGraph(fg::AbstractFeaturedGraph, l::EEquivGraphConv) = WithGraph(to_namedtuple(fg), l)
+(wg::WithGraph{<:EEquivGraphConv})(args...) = wg.layer(wg.graph, args...)
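Putting the pieces together, a hedged usage sketch based on the docstring and the `WithGraph` definition above (the toy graph and random features are made up; dimensions follow the constructor's `in_dim=>out_dim`, `pos_dim`, and `edge_dim`):

```julia
using GeometricFlux, GraphSignals, Flux
using Graphs: nv, ne

# Toy ring graph on 4 nodes.
adj = [0 1 0 1;
       1 0 1 0;
       0 1 0 1;
       1 0 1 0]
fg = FeaturedGraph(adj)

layer = EEquivGraphConv(3=>5, 2, 3)  # in_dim=>out_dim, pos_dim, edge_dim
model = WithGraph(fg, layer)         # bake the static graph into the layer

H = rand(Float32, 3, nv(fg))  # node features
E = rand(Float32, 3, ne(fg))  # edge features
X = rand(Float32, 2, nv(fg))  # positional features

V, X′ = model(H, E, X)        # updated node and positional features
```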
