Commit 3c62905

add EEquivGraphPE layer

1 parent 0378e18 commit 3c62905

5 files changed: +91 −8 lines changed

src/GeometricFlux.jl

Lines changed: 6 additions & 0 deletions
@@ -31,6 +31,11 @@ export
     # layers/msgpass
     MessagePassing,
 
+    # layers/positional
+    AbstractPE,
+    positional_encode,
+    EEquivGraphPE,
+
     # layers/graph_conv
     GCNConv,
     ChebConv,

@@ -75,6 +80,7 @@ include("layers/graphlayers.jl")
 include("layers/gn.jl")
 include("layers/msgpass.jl")
 
+include("layers/positional.jl")
 include("layers/graph_conv.jl")
 include("layers/group_conv.jl")
 include("layers/pool.jl")

src/layers/group_conv.jl

Lines changed: 14 additions & 8 deletions
@@ -6,22 +6,16 @@ E(n)-equivariant graph neural network layer as defined in the paper "[E(n) Equiv
 
 # Arguments
 
-Either one of two sets of arguments:
-
-Set 1:
-
 - `in_dim`: node feature dimension. Data is assumed to be of the form [feature; coordinate], so `in_dim` must be strictly less than the dimension of the input vectors.
 - `int_dim`: intermediate dimension, can be arbitrary.
 - `out_dim`: the output of the layer will have dimension `out_dim` + (dimension of input vector - `in_dim`).
 - `init`: neural network initialization function, should be compatible with `Flux.Dense`.
-
-Set 2:
-
-- `in_dim`: as in Set 1.
 - `nn_edge`: a differentiable function mapping vectors of dimension `in_dim * 2 + 2` to vectors of dimension `int_dim`.
 - `nn_x`: a differentiable function mapping vectors of dimension `int_dim` to dimension `1`.
 - `nn_h`: a differentiable function mapping vectors of dimension `in_dim + int_dim` to dimension `out_dim`.
 
+# Examples
+
 ```jldoctest
 julia> in_dim, int_dim, out_dim = 3,6,5
 (3, 6, 5)
@@ -57,6 +51,8 @@ end
 
 @functor EEquivGraphConv
 
+Flux.trainable(l::EEquivGraphConv) = (l.nn_edge, l.nn_x, l.nn_h)
+
 function EEquivGraphConv(in_dim::Int, int_dim::Int, out_dim::Int; init=glorot_uniform)
     m_len = 2in_dim + 2
     nn_edge = Flux.Dense(m_len, int_dim; init=init)
@@ -120,9 +116,19 @@ function update(e::EEquivGraphConv, m, h)
     return z
 end
 
+# For variable graph
 function(egnn::EEquivGraphConv)(fg::AbstractFeaturedGraph)
     X = node_feature(fg)
     GraphSignals.check_num_nodes(fg, X)
     _, V, _ = propagate(egnn, graph(fg), nothing, X, nothing, +, nothing, nothing)
     return ConcreteFeaturedGraph(fg, nf=V)
 end
+
+function Base.show(io::IO, l::EEquivGraphConv)
+    # EEquivGraphConv has no weight1/σ/aggr fields, so show its three networks
+    print(io, "EEquivGraphConv(")
+    print(io, "ϕ_edge=", l.nn_edge)
+    print(io, ", ϕ_x=", l.nn_x)
+    print(io, ", ϕ_h=", l.nn_h)
+    print(io, ")")
+end
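Taken together, the merged argument list fixes one dimension contract across the three inner networks: `nn_edge` maps 2·`in_dim` + 2 to `int_dim`, `nn_x` maps `int_dim` to 1, and `nn_h` maps `in_dim` + `int_dim` to `out_dim`. A hedged sketch using only the dimension-based constructor shown in this diff; the graph, coordinate dimension, and feature sizes are illustrative assumptions, and the behavior of the parts of group_conv.jl not shown here is taken on faith from the docstring:

```julia
using GeometricFlux, GraphSignals, Flux

in_dim, int_dim, out_dim = 3, 6, 5
egnn = EEquivGraphConv(in_dim, int_dim, out_dim)  # constructor from this diff

coord_dim = 2  # assumed coordinate dimension; any value ≥ 1 fits the contract
N = 4
adj = [0 1 0 1;
       1 0 1 0;
       0 1 0 1;
       1 0 1 0]
nf = rand(Float32, in_dim + coord_dim, N)  # node features are [feature; coordinate]

fg = FeaturedGraph(adj, nf=nf)
fg′ = egnn(fg)
size(node_feature(fg′))  # per the docstring: (out_dim + coord_dim, N) == (7, 4)
```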

src/layers/positional.jl

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
+"""
+    AbstractPE
+
+Abstract type of positional encoding for GNN.
+"""
+abstract type AbstractPE end
+
+positional_encode(l::AbstractPE, args...) = throw(ErrorException("positional_encode function for $l is not implemented."))
+
+struct EEquivGraphPE{X} <: MessagePassing
+    nn_x::X
+end
+
+function EEquivGraphPE(ch::Pair{Int,Int}; init=glorot_uniform, bias::Bool=true)
+    in, out = ch
+    nn_x = Flux.Dense(in, out; init=init, bias=bias)
+    return EEquivGraphPE(nn_x)
+end
+
+@functor EEquivGraphPE
+
+ϕ_x(l::EEquivGraphPE, m_ij) = l.nn_x(m_ij)
+
+message(l::EEquivGraphPE, x_i, x_j, e) = (x_i - x_j) .* ϕ_x(l, e)
+
+update(l::EEquivGraphPE, m::AbstractArray, x::AbstractArray) = m .+ x
+
+# For variable graph
+function(l::EEquivGraphPE)(fg::AbstractFeaturedGraph)
+    X = node_feature(fg)
+    E = edge_feature(fg)
+    GraphSignals.check_num_nodes(fg, X)
+    GraphSignals.check_num_edges(fg, E)
+    _, V, _ = propagate(l, graph(fg), E, X, nothing, mean, nothing, nothing)
+    return ConcreteFeaturedGraph(fg, nf=V)
+end
+
+# For static graph
+function(l::EEquivGraphPE)(el::NamedTuple, X::AbstractArray, E::AbstractArray)
+    GraphSignals.check_num_nodes(el.N, X)
+    GraphSignals.check_num_edges(el.E, E)
+    _, V, _ = propagate(l, el, E, X, nothing, mean, nothing, nothing)
+    return V
+end
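In this layer, `message` computes (x_i − x_j) .* ϕ_x(e_ij) per edge and `update` adds the aggregated messages (here `mean` over neighbors) back onto x_i, which is the coordinate-update rule of the E(n)-equivariant GNN: x_i′ = x_i + mean_j (x_i − x_j) ⊙ ϕ_x(e_ij). The `AbstractPE` fallback above throws until a subtype overrides `positional_encode`; a hypothetical subtype (the name and encoding scheme are illustrative, not from this commit) might look like:

```julia
using GeometricFlux, LinearAlgebra
import GeometricFlux: positional_encode  # import to add a method to the generic function

# Hypothetical positional encoder: k-step random-walk return probabilities per node.
struct RandomWalkPE <: AbstractPE
    k::Int
end

function positional_encode(l::RandomWalkPE, A::AbstractMatrix)
    P = A ./ sum(A, dims=2)      # row-normalize adjacency into a transition matrix
    Pk = copy(P)
    enc = diag(Pk)'              # 1 × N: return probability after one step
    for _ in 2:l.k
        Pk *= P
        enc = vcat(enc, diag(Pk)')  # stack steps into a k × N encoding
    end
    return enc
end

# Usage on the 4-node cycle adjacency used in the tests below:
adj = Float32[0 1 0 1; 1 0 1 0; 0 1 0 1; 1 0 1 0]
positional_encode(RandomWalkPE(3), adj)  # 3 × 4 positional encoding
```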

test/layers/positional.jl

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+@testset "positional" begin
+    T = Float32
+    batch_size = 10
+    in_channel = 3
+    out_channel = 5
+
+    N = 4
+    E = 4
+    adj = T[0. 1. 0. 1.;
+            1. 0. 1. 0.;
+            0. 1. 0. 1.;
+            1. 0. 1. 0.]
+
+    @testset "EEquivGraphPE" begin
+        l = EEquivGraphPE(in_channel=>out_channel)
+
+        nf = rand(T, out_channel, N)
+        ef = rand(T, in_channel, E)
+        fg = FeaturedGraph(adj, nf=nf, ef=ef)
+        fg_ = l(fg)
+        @test size(node_feature(fg_)) == (out_channel, N)
+
+        g = Zygote.gradient(() -> sum(node_feature(l(fg))), Flux.params(l))
+        @test length(g.grads) == 4
+    end
+end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@ cuda_tests = [
 tests = [
     "layers/gn",
     "layers/msgpass",
+    "layers/positional",
     "layers/graph_conv",
     "layers/group_conv",
    "layers/pool",
