New features: positional encoding #317

Merged: 23 commits, Jul 11, 2022
2 changes: 1 addition & 1 deletion Project.toml
@@ -31,7 +31,7 @@ DataStructures = "0.18"
FillArrays = "0.13"
Flux = "0.12 - 0.13"
GraphMLDatasets = "0.1"
GraphSignals = "0.4 - 0.5"
GraphSignals = "0.6"
Graphs = "1"
NNlib = "0.8"
NNlibCUDA = "0.2"
1 change: 1 addition & 0 deletions docs/Project.toml
@@ -1,6 +1,7 @@
[deps]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"

[compat]
Documenter = "0.27"
23 changes: 23 additions & 0 deletions docs/bibliography.bib
@@ -184,6 +184,29 @@ @inproceedings{Hamilton2017
year = {2017},
}

@inproceedings{Satorras2021,
abstract = {This paper introduces a new model to learn graph neural networks equivariant to rotations, translations, reflections and permutations called E(n)-Equivariant Graph Neural Networks (EGNNs). In contrast with existing methods, our work does not require computationally expensive higher-order representations in intermediate layers while it still achieves competitive or better performance. In addition, whereas existing methods are limited to equivariance on 3 dimensional spaces, our model is easily scaled to higher-dimensional spaces. We demonstrate the effectiveness of our method on dynamical systems modelling, representation learning in graph autoencoders and predicting molecular properties.},
author = {Victor Garcia Satorras and Emiel Hoogeboom and Max Welling},
editor = {Marina Meila and Tong Zhang},
booktitle = {Proceedings of the 38th International Conference on Machine Learning},
month = {2},
pages = {9323-9332},
publisher = {PMLR},
title = {E(n) Equivariant Graph Neural Networks},
volume = {139},
url = {http://arxiv.org/abs/2102.09844},
year = {2021},
}

@article{Dwivedi2021,
abstract = {Graph neural networks (GNNs) have become the standard learning architectures for graphs. GNNs have been applied to numerous domains ranging from quantum chemistry, recommender systems to knowledge graphs and natural language processing. A major issue with arbitrary graphs is the absence of canonical positional information of nodes, which decreases the representation power of GNNs to distinguish e.g. isomorphic nodes and other graph symmetries. An approach to tackle this issue is to introduce Positional Encoding (PE) of nodes, and inject it into the input layer, like in Transformers. Possible graph PE are Laplacian eigenvectors. In this work, we propose to decouple structural and positional representations to make easy for the network to learn these two essential properties. We introduce a novel generic architecture which we call LSPE (Learnable Structural and Positional Encodings). We investigate several sparse and fully-connected (Transformer-like) GNNs, and observe a performance increase for molecular datasets, from 2.87% up to 64.14% when considering learnable PE for both GNN classes.},
author = {Vijay Prakash Dwivedi and Anh Tuan Luu and Thomas Laurent and Yoshua Bengio and Xavier Bresson},
month = {10},
title = {Graph Neural Networks with Learnable Structural and Positional Representations},
url = {http://arxiv.org/abs/2110.07875},
year = {2021},
}

@article{Battaglia2018,
abstract = {Artificial intelligence (AI) has undergone a renaissance recently, making major progress in key domains such as vision, language, control, and decision-making. This has been due, in part, to cheap data and cheap compute resources, which have fit the natural strengths of deep learning. However, many defining characteristics of human intelligence, which developed under much different pressures, remain out of reach for current approaches. In particular, generalizing beyond one's experiences--a hallmark of human intelligence from infancy--remains a formidable challenge for modern AI. The following is part position paper, part review, and part unification. We argue that combinatorial generalization must be a top priority for AI to achieve human-like abilities, and that structured representations and computations are key to realizing this objective. Just as biology uses nature and nurture cooperatively, we reject the false choice between "hand-engineering" and "end-to-end" learning, and instead advocate for an approach which benefits from their complementary strengths. We explore how using relational inductive biases within deep learning architectures can facilitate learning about entities, relations, and rules for composing them. We present a new building block for the AI toolkit with a strong relational inductive bias--the graph network--which generalizes and extends various approaches for neural networks that operate on graphs, and provides a straightforward interface for manipulating structured knowledge and producing structured behaviors. We discuss how graph networks can support relational reasoning and combinatorial generalization, laying the foundation for more sophisticated, interpretable, and flexible patterns of reasoning. As a companion to this paper, we have released an open-source software library for building graph networks, with demonstrations of how to use them in practice.},
author = {Peter W. Battaglia and Jessica B. Hamrick and Victor Bapst and Alvaro Sanchez-Gonzalez and Vinicius Zambaldi and Mateusz Malinowski and Andrea Tacchetti and David Raposo and Adam Santoro and Ryan Faulkner and Caglar Gulcehre and Francis Song and Andrew Ballard and Justin Gilmer and George Dahl and Ashish Vaswani and Kelsey Allen and Charles Nash and Victoria Langston and Chris Dyer and Nicolas Heess and Daan Wierstra and Pushmeet Kohli and Matt Botvinick and Oriol Vinyals and Yujia Li and Razvan Pascanu},
4 changes: 3 additions & 1 deletion docs/make.jl
@@ -42,8 +42,10 @@ makedocs(
"Dynamic Graph Update" => "dynamicgraph.md",
"Manual" => [
"FeaturedGraph" => "manual/featuredgraph.md",
"Graph Convolutional Layers" => "manual/conv.md",
"Graph Convolutional Layers" => "manual/graph_conv.md",
"Graph Pooling Layers" => "manual/pool.md",
"Group Convolutional Layers" => "manual/group_conv.md",
"Positional Encoding Layers" => "manual/positional.md",
"Embeddings" => "manual/embedding.md",
"Models" => "manual/models.md",
"Linear Algebra" => "manual/linalg.md",
2 changes: 2 additions & 0 deletions docs/src/manual/featuredgraph.md
@@ -9,6 +9,8 @@ GraphSignals.edge_feature
GraphSignals.has_edge_feature
GraphSignals.global_feature
GraphSignals.has_global_feature
GraphSignals.positional_feature
GraphSignals.has_positional_feature
GraphSignals.subgraph
GraphSignals.ConcreteFeaturedGraph
```
File renamed without changes.
19 changes: 19 additions & 0 deletions docs/src/manual/group_conv.md
@@ -0,0 +1,19 @@
# Group Convolutional Layers

## ``E(n)``-equivariant Convolutional Layer

It employs a message-passing scheme and can be defined by the following functions:

- message function (Eq. 3 from paper): ``m_{ij} = \phi_e(h_i^l, h_j^l, ||x_i^l - x_j^l||^2, a_{ij})``
- aggregate (Eq. 5 from paper): ``m_i = \sum_j m_{ij}``
- update function (Eq. 6 from paper): ``h_i^{l+1} = \phi_h(h_i^l, m_i)``

where ``h_i^l`` and ``h_j^l`` denote the node features for nodes ``i`` and ``j``, respectively, in the ``l``-th layer, and ``x_i^l`` and ``x_j^l`` denote their positional features. ``a_{ij}`` is the edge feature for edge ``(i,j)``. ``\phi_e`` and ``\phi_h`` are neural networks for edges and nodes, respectively.
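
To make the three equations concrete, below is a minimal, self-contained sketch of one such message-passing step in plain Julia with Flux. The feature dimensions, network shapes, and the `egnn_step` helper are illustrative assumptions for this example only, not the actual `EEquivGraphConv` implementation.

```julia
using Flux

# Illustrative MLPs standing in for ϕ_e and ϕ_h (dimensions are assumptions).
node_dim, edge_dim, msg_dim = 4, 2, 8
ϕ_e = Chain(Dense(2 * node_dim + 1 + edge_dim => msg_dim, swish), Dense(msg_dim => msg_dim, swish))
ϕ_h = Dense(node_dim + msg_dim => node_dim)

# One EGNN message-passing step.
# h:    node_dim × N node features (hᵢˡ)
# x:    3 × N positional features (xᵢˡ)
# a:    Dict mapping (i, j) to an edge feature vector aᵢⱼ of length edge_dim
# nbrs: nbrs[i] is the neighbor list of node i
function egnn_step(h, x, a, nbrs)
    h′ = similar(h)
    for i in axes(h, 2)
        mᵢ = zeros(Float32, msg_dim)
        for j in nbrs[i]
            d² = sum(abs2, x[:, i] .- x[:, j])                          # ‖xᵢˡ - xⱼˡ‖²
            mᵢ .+= ϕ_e(vcat(h[:, i], h[:, j], Float32[d²], a[(i, j)]))  # Eq. 3 message, Eq. 5 sum aggregation
        end
        h′[:, i] = ϕ_h(vcat(h[:, i], mᵢ))                               # Eq. 6 node update
    end
    return h′
end
```

The `EEquivGraphConv` layer documented below packages this scheme as a reusable Flux layer.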

```@docs
EEquivGraphConv
```

Reference: [Satorras2021](@cite)

---
43 changes: 43 additions & 0 deletions docs/src/manual/positional.md
@@ -0,0 +1,43 @@
# Positional Encoding

## Positional Encoding Methods

```@docs
AbstractPositionalEncoding
RandomWalkPE
LaplacianPE
positional_encode
```

## Positional Encoding Layers

### ``E(n)``-equivariant Positional Encoding Layer

It employs a message-passing scheme and can be defined by the following functions:

- message function: ``y_{ij}^l = (x_i^l - x_j^l)\phi_x(m_{ij})``
- aggregate: ``y_i^l = \frac{1}{M} \sum_{j \in \mathcal{N}(i)} y_{ij}^l``
- update function: ``x_i^{l+1} = x_i^l + y_i^l``

where ``x_i^l`` and ``x_j^l`` denote the positional features for nodes ``i`` and ``j``, respectively, in the ``l``-th layer, ``\phi_x`` is the neural network for positional encoding, and ``m_{ij}`` is the edge feature for edge ``(i,j)``. ``y_{ij}^l`` and ``y_i^l`` represent the encoded and aggregated positional features, respectively, and ``M`` denotes the number of neighbors of node ``i``.
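
As a minimal sketch, the update can be written in plain Julia with Flux as follows, assuming each edge ``(i,j)`` already carries a feature vector ``m_{ij}`` and that ``\phi_x`` maps it to a single scalar weight; the network shape and the `position_step` helper are assumptions for this example, not the `EEquivGraphPE` implementation itself.

```julia
using Flux

# Illustrative network standing in for ϕ_x: edge feature/message vector → scalar weight.
msg_dim = 8
ϕ_x = Chain(Dense(msg_dim => msg_dim, swish), Dense(msg_dim => 1))

# x:    pos_dim × N positional features (xᵢˡ)
# m:    Dict mapping (i, j) to the edge feature vector mᵢⱼ of length msg_dim
# nbrs: nbrs[i] is the neighbor list of node i
function position_step(x, m, nbrs)
    x′ = copy(x)
    for i in axes(x, 2)
        yᵢ = zeros(Float32, size(x, 1))
        for j in nbrs[i]
            yᵢ .+= (x[:, i] .- x[:, j]) .* only(ϕ_x(m[(i, j)]))  # yᵢⱼˡ = (xᵢˡ - xⱼˡ) ϕ_x(mᵢⱼ)
        end
        M = max(length(nbrs[i]), 1)                              # number of neighbors of node i
        x′[:, i] .+= yᵢ ./ M                                     # xᵢˡ⁺¹ = xᵢˡ + yᵢˡ (mean aggregation)
    end
    return x′
end
```

The `EEquivGraphPE` layer documented below provides this positional update as a layer.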

```@docs
EEquivGraphPE
```

Reference: [Satorras2021](@cite)

---

### Learnable Structural Positional Encoding Layer

(WIP)

```@docs
LSPE
```

Reference: [Dwivedi2021](@cite)

---

17 changes: 15 additions & 2 deletions src/GeometricFlux.jl
@@ -30,7 +30,15 @@ export
# layers/msgpass
MessagePassing,

# layers/conv
# layers/positional
AbstractPositionalEncoding,
RandomWalkPE,
LaplacianPE,
positional_encode,
EEquivGraphPE,
LSPE,

# layers/graph_conv
GCNConv,
ChebConv,
GraphConv,
@@ -44,6 +52,9 @@ export
MeanAggregator, MeanPoolAggregator, MaxPoolAggregator,
LSTMAggregator,

# layers/group_conv
EEquivGraphConv,

# layer/pool
GlobalPool,
LocalPool,
@@ -71,7 +82,9 @@ include("layers/graphlayers.jl")
include("layers/gn.jl")
include("layers/msgpass.jl")

include("layers/conv.jl")
include("layers/positional.jl")
include("layers/graph_conv.jl")
include("layers/group_conv.jl")
include("layers/pool.jl")
include("models.jl")

17 changes: 7 additions & 10 deletions src/layers/gn.jl
@@ -165,7 +165,7 @@ aggregate_vertices(::GraphNet, aggr, V) = aggregate(aggr, V)
@inline aggregate_vertices(::GraphNet, ::Nothing, V) = nothing

function propagate(gn::GraphNet, sg::SparseGraph, E, V, u, naggr, eaggr, vaggr)
el = to_namedtuple(sg)
el = GraphSignals.to_namedtuple(sg)
return propagate(gn, el, E, V, u, naggr, eaggr, vaggr)
end

@@ -179,14 +179,11 @@ function propagate(gn::GraphNet, el::NamedTuple, E, V, u, naggr, eaggr, vaggr)
return E, V, u
end

WithGraph(fg::AbstractFeaturedGraph, gn::GraphNet) = WithGraph(to_namedtuple(fg), gn)
WithGraph(gn::GraphNet; dynamic=nothing) = WithGraph(DynamicGraph(dynamic), gn)
WithGraph(fg::AbstractFeaturedGraph, gn::GraphNet) =
WithGraph(GraphSignals.to_namedtuple(fg), gn, positional_feature(fg))

to_namedtuple(fg::AbstractFeaturedGraph) = to_namedtuple(graph(fg))
WithGraph(gn::GraphNet; dynamic=nothing) =
WithGraph(DynamicGraph(dynamic), gn, GraphSignals.NullDomain())

function to_namedtuple(sg::SparseGraph)
es, nbrs, xs = collect(edges(sg))
return (N=nv(sg), E=ne(sg), es=es, nbrs=nbrs, xs=xs)
end

@non_differentiable to_namedtuple(x...)
WithGraph(fg::AbstractFeaturedGraph, gn::GraphNet, pos::GraphSignals.AbstractGraphDomain) =
WithGraph(GraphSignals.to_namedtuple(fg), gn, pos)
15 changes: 10 additions & 5 deletions src/layers/conv.jl → src/layers/graph_conv.jl
@@ -53,7 +53,9 @@ end

# For static graph
WithGraph(fg::AbstractFeaturedGraph, l::GCNConv) =
WithGraph(GraphSignals.normalized_adjacency_matrix(fg, eltype(l.weight); selfloop=true), l)
WithGraph(GraphSignals.normalized_adjacency_matrix(fg, eltype(l.weight); selfloop=true),
l,
GraphSignals.NullDomain())

function (wg::WithGraph{<:GCNConv})(X::AbstractArray)
à = wg.graph
@@ -135,7 +137,9 @@ end

# For static graph
WithGraph(fg::AbstractFeaturedGraph, l::ChebConv) =
WithGraph(GraphSignals.scaled_laplacian(fg, eltype(l.weight)), l)
WithGraph(GraphSignals.scaled_laplacian(fg, eltype(l.weight)),
l,
GraphSignals.NullDomain())

function (wg::WithGraph{<:ChebConv})(X::AbstractArray)
L̃ = wg.graph
@@ -332,7 +336,7 @@ function (l::GATConv)(fg::AbstractFeaturedGraph)
GraphSignals.check_num_nodes(fg, X)
sg = graph(fg)
@assert ChainRulesCore.ignore_derivatives(() -> GraphSignals.has_all_self_loops(sg)) "a vertex must have self loop (receive a message from itself)."
el = to_namedtuple(sg)
el = GraphSignals.to_namedtuple(sg)
_, V, _ = propagate(l, el, nothing, X, nothing, hcat, nothing, nothing)
return ConcreteFeaturedGraph(fg, nf=V)
end
@@ -460,7 +464,7 @@ function (l::GATv2Conv)(fg::AbstractFeaturedGraph)
GraphSignals.check_num_nodes(fg, X)
sg = graph(fg)
@assert ChainRulesCore.ignore_derivatives(() -> GraphSignals.has_all_self_loops(sg)) "a vertex must have self loop (receive a message from itself)."
el = to_namedtuple(sg)
el = GraphSignals.to_namedtuple(sg)
_, V, _ = propagate(l, el, nothing, X, nothing, hcat, nothing, nothing)
return ConcreteFeaturedGraph(fg, nf=V)
end
@@ -536,7 +540,7 @@ update(ggc::GatedGraphConv, m::AbstractArray, x) = m
function (l::GatedGraphConv)(fg::AbstractFeaturedGraph)
nf = node_feature(fg)
GraphSignals.check_num_nodes(fg, nf)
V = l(to_namedtuple(fg), nf)
V = l(GraphSignals.to_namedtuple(fg), nf)
return ConcreteFeaturedGraph(fg, nf=V)
end

@@ -723,6 +727,7 @@ end

function message(c::CGConv, x_i::AbstractArray, x_j::AbstractArray, e::AbstractArray)
z = vcat(x_i, x_j, e)

return σ.(_matmul(c.Wf, z) .+ c.bf) .* softplus.(_matmul(c.Ws, z) .+ c.bs)
end

13 changes: 8 additions & 5 deletions src/layers/graphlayers.jl
@@ -52,9 +52,10 @@ Chain(
# plus 2 non-trainable, 32 parameters, summarysize 1.006 MiB.
```
"""
struct WithGraph{L<:AbstractGraphLayer,G}
struct WithGraph{L<:AbstractGraphLayer,G,P}
graph::G
layer::L
position::P
end

@functor WithGraph
@@ -71,13 +72,15 @@ function Optimisers.destructure(m::WithGraph)
end

function Base.show(io::IO, l::WithGraph)
print(io, "WithGraph(Graph(#V=", nv(l.graph))
print(io, ", #E=", ne(l.graph), "), ")
print(io, l.layer, ")")
print(io, "WithGraph(Graph(#V=", nv(l.graph), ", #E=", ne(l.graph), "), ")
print(io, l.layer)
has_positional_feature(l.position) &&
print(io, ", domain_dim=", GraphSignals.pf_dims_repr(l.position))
print(io, ")")
end

WithGraph(fg::AbstractFeaturedGraph, model::Chain; kwargs...) =
Chain(map(l -> WithGraph(fg, l; kwargs...), model.layers)...)
Chain([WithGraph(fg, l; kwargs...) for l in model.layers]...)
WithGraph(::AbstractFeaturedGraph, layer::WithGraph; kwargs...) = layer
WithGraph(::AbstractFeaturedGraph, layer; kwargs...) = layer
