|
1 | 1 | """
|
2 |
| - EEquivGraphConv(in_dim, int_dim, out_dim; init=glorot_uniform) |
3 |
| - EEquivGraphConv(in_dim, nn_edge, nn_x, nn_h) |
| 2 | + EEquivGraphConv(in_dim=>out_dim, pos_dim, edge_dim; init=glorot_uniform) |
4 | 3 |
|
5 |
| -E(n)-equivariant graph neural network layer as defined in the paper "[E(n) Equivariant Neural Networks](https://arxiv.org/abs/2102.09844)" by Satorras, Hoogeboom, and Welling (2021). |
| 4 | +E(n)-equivariant graph neural network layer. |
6 | 5 |
|
7 | 6 | # Arguments
|
8 | 7 |
|
9 |
| -Either one of two sets of arguments: |
10 |
| -
|
11 |
| -Set 1: |
12 |
| -
|
13 |
| -- `in_dim`: node feature dimension. Data is assumed to be of the form [feature; coordinate], so `in_dim` must strictly be less than the dimension of the input vectors. |
14 |
| -- `int_dim`: intermediate dimension, can be arbitrary. |
| 8 | +- `in_dim::Int`: node feature dimension. Data is assumed to be of the form [feature; coordinate], so `in_dim` must be strictly less than the dimension of the input vectors. |
15 | 9 | - `out_dim`: the output of the layer will have dimension `out_dim` + (dimension of input vector - `in_dim`).
|
16 |
| -- `init`: neural network initialization function, should be compatible with `Flux.Dense`. |
17 |
| -
|
18 |
| -Set 2: |
| 10 | +- `pos_dim::Int`: dimension of the positional encoding. |
| 11 | +- `edge_dim::Int`: dimension of the edge features. |
| 12 | +- `init`: neural network initialization function. |
19 | 13 |
|
20 |
| -- `in_dim`: as in Set 1. |
21 |
| -- `nn_edge`: a differentiable function that must take vectors of dimension `in_dim * 2 + 2` (output designated `int_dim`) |
22 |
| -- `nn_x`: a differentiable function that must take vectors of dimension `int_dim` to dimension `1`. |
23 |
| -- `nn_h`: a differentiable function that must take vectors of dimension `in_dim + int_dim` to `out_dim`. |
| 14 | +# Examples |
24 | 15 |
|
25 | 16 | ```jldoctest
|
26 |
| -julia> in_dim, int_dim, out_dim = 3,6,5 |
27 |
| -(3, 5, 5) |
28 |
| -
|
29 |
| -julia> egnn = EEquivGraphConv(in_dim, int_dim, out_dim) |
30 |
| -EEquivGraphConv{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}(Dense(8 => 5), Dense(5 => 1), Dense(8 => 5), 3, 5, 5) |
31 |
| -
|
32 |
| -julia> m_len = 2*in_dim + 2 |
33 |
| -8 |
34 |
| -
|
35 |
| -julia> nn_edge = Flux.Dense(m_len, int_dim) |
36 |
| -Dense(8 => 5) # 45 parameters |
| 17 | +julia> in_dim, out_dim, pos_dim = 3, 5, 2 |
| 18 | +(3, 5, 2) |
37 | 19 |
|
38 |
| -julia> nn_x = Flux.Dense(int_dim, 1) |
39 |
| -Dense(5 => 1) # 6 parameters |
40 |
| -
|
41 |
| -julia> nn_h = Flux.Dense(in_dim + int_dim, out_dim) |
42 |
| -Dense(8 => 5) # 45 parameters |
43 |
| -
|
44 |
| -julia> egnn = EEquivGraphConv(in_dim, nn_edge, nn_x, nn_h) |
45 |
| -EEquivGraphConv{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}(Dense(8 => 5), Dense(5 => 1), Dense(8 => 5), 3, 5, 5) |
| 20 | +julia> egnn = EEquivGraphConv(in_dim=>out_dim, pos_dim, in_dim) |
| 21 | +EEquivGraphConv(ϕ_edge=Dense(10 => 5), ϕ_x=Dense(5 => 2), ϕ_h=Dense(8 => 5)) |
46 | 22 | ```
|
47 |
| -""" |
48 | 23 |
|
49 |
| -struct EEquivGraphConv{E,X,H} <: MessagePassing |
50 |
| - nn_edge::E |
51 |
| - nn_x::X |
52 |
| - nn_h::H |
53 |
| - in_dim::Int |
54 |
| - int_dim::Int |
55 |
| - out_dim::Int |
| 24 | +See also [`WithGraph`](@ref) for training the layer with a static graph and [`EEquivGraphPE`](@ref) for positional encoding. |
| 25 | +""" |
| 26 | +struct EEquivGraphConv{X,E,H} <: AbstractGraphLayer |
| 27 | + pe::X |
| 28 | + nn_edge::E |
| 29 | + nn_h::H |
56 | 30 | end
|
57 | 31 |
|
58 | 32 | @functor EEquivGraphConv
|
59 | 33 |
|
60 |
| -function EEquivGraphConv(in_dim::Int, int_dim::Int, out_dim::Int; init=glorot_uniform) |
61 |
| - m_len = 2in_dim + 2 |
62 |
| - nn_edge = Flux.Dense(m_len, int_dim; init=init) |
63 |
| - nn_x = Flux.Dense(int_dim, 1; init=init) |
64 |
| - nn_h = Flux.Dense(in_dim + int_dim, out_dim; init=init) |
65 |
| - return EEquivGraphConv(nn_edge, nn_x, nn_h, in_dim, int_dim, out_dim) |
66 |
| -end |
| 34 | +Flux.trainable(l::EEquivGraphConv) = (l.pe, l.nn_edge, l.nn_h) |
67 | 35 |
|
68 |
| -function EEquivGraphConv(in_dim::Int, nn_edge, nn_x, nn_h) |
69 |
| - m_len = 2in_dim + 2 |
70 |
| - int_dim = Flux.outputsize(nn_edge, (m_len, 2))[1] |
71 |
| - out_dim = Flux.outputsize(nn_h, (in_dim + int_dim, 2))[1] |
72 |
| - return EEquivGraphConv(nn_edge, nn_x, nn_h, in_dim, int_dim, out_dim) |
| 36 | +function EEquivGraphConv(ch::Pair{Int,Int}, pos_dim::Int, edge_dim::Int; init=glorot_uniform) |
| 37 | + in_dim, out_dim = ch |
| 38 | + nn_edge = Flux.Dense(2in_dim + edge_dim + 1, out_dim; init=init) |
| 39 | + pe = EEquivGraphPE(out_dim=>pos_dim; init=init) |
| 40 | + nn_h = Flux.Dense(in_dim + out_dim, out_dim; init=init) |
| 41 | + return EEquivGraphConv(pe, nn_edge, nn_h) |
73 | 42 | end
|
74 | 43 |
|
75 |
| -function ϕ_edge(egnn::EEquivGraphConv, h_i, h_j, dist, a) |
76 |
| - N = size(h_i, 2) |
77 |
| - return egnn.nn_edge(vcat(h_i, h_j, dist, ones(N)' * a)) |
78 |
| -end |
79 |
| - |
80 |
| -ϕ_x(egnn::EEquivGraphConv, m_ij) = egnn.nn_x(m_ij) |
81 |
| - |
82 |
| -function message(egnn::EEquivGraphConv, v_i, v_j, e) |
83 |
| - in_dim = egnn.in_dim |
84 |
| - h_i = v_i[1:in_dim,:] |
85 |
| - h_j = v_j[1:in_dim,:] |
| 44 | +ϕ_edge(l::EEquivGraphConv, h_i, h_j, dist, a) = l.nn_edge(vcat(h_i, h_j, dist, a)) |
86 | 45 |
|
87 |
| - N = size(h_i, 2) |
| 46 | +function message(l::EEquivGraphConv, h_i, h_j, x_i, x_j, e) |
| 47 | + dist = sum(abs2, x_i - x_j; dims=1) |
| 48 | + return ϕ_edge(l, h_i, h_j, dist, e) |
| 49 | +end |
88 | 50 |
|
89 |
| - x_i = v_i[in_dim+1:end,:] |
90 |
| - x_j = v_j[in_dim+1:end,:] |
| 51 | +update(l::EEquivGraphConv, m, h) = l.nn_h(vcat(h, m)) |
91 | 52 |
|
92 |
| - if isnothing(e) |
93 |
| - a = 1 |
94 |
| - else |
95 |
| - a = e[1] |
96 |
| - end |
| 53 | +# For variable graph |
| 54 | +function(egnn::EEquivGraphConv)(fg::AbstractFeaturedGraph) |
| 55 | + nf = node_feature(fg) |
| 56 | + ef = edge_feature(fg) |
| 57 | + pf = positional_feature(fg) |
| 58 | + GraphSignals.check_num_nodes(fg, nf) |
| 59 | + GraphSignals.check_num_edges(fg, ef) |
| 60 | + _, V, X = propagate(egnn, graph(fg), ef, nf, pf, +) |
| 61 | + return ConcreteFeaturedGraph(fg, nf=V, pf=X) |
| 62 | +end |
97 | 63 |
|
98 |
| - dist = sum(abs2.(x_i - x_j); dims=1) |
99 |
| - edge_msg = ϕ_edge(egnn, h_i, h_j, dist, a) |
100 |
| - output_vec = vcat(edge_msg, (x_i - x_j) .* ϕ_x(egnn, edge_msg)[1], ones(N)') |
101 |
| - return reshape(output_vec, :, N) |
| 64 | +# For static graph |
| 65 | +function(l::EEquivGraphConv)(el::NamedTuple, H::AbstractArray, E::AbstractArray, X::AbstractArray) |
| 66 | + GraphSignals.check_num_nodes(el.N, H) |
| 67 | + GraphSignals.check_num_nodes(el.N, X) |
| 68 | + GraphSignals.check_num_edges(el.E, E) |
| 69 | + _, V, X = propagate(l, el, E, H, X, +) |
| 70 | + return V, X |
102 | 71 | end
|
103 | 72 |
|
104 |
| -function update(e::EEquivGraphConv, m, h) |
105 |
| - N = size(m, 2) |
106 |
| - mi = m[1:e.int_dim,:] |
107 |
| - x_msg = m[e.int_dim+1:end-1,:] |
108 |
| - M = m[end,:] |
| 73 | +function Base.show(io::IO, l::EEquivGraphConv) |
| 74 | + print(io, "EEquivGraphConv(ϕ_edge=", l.nn_edge) |
| 75 | + print(io, ", ϕ_x=", l.pe.nn) |
| 76 | + print(io, ", ϕ_h=", l.nn_h) |
| 77 | + print(io, ")") |
| 78 | +end |
109 | 79 |
|
110 |
| - C = 1 ./ (M.-1) |
111 |
| - C = reshape(C, :, N) |
| 80 | +function aggregate_neighbors(::EEquivGraphConv, el::NamedTuple, aggr, E) |
| 81 | + batch_size = size(E)[end] |
| 82 | + dstsize = (size(E, 1), el.N, batch_size) |
| 83 | + xs = batched_index(el.xs, batch_size) |
| 84 | + return _scatter(aggr, E, xs, dstsize) |
| 85 | +end |
112 | 86 |
|
113 |
| - nn_node_out = e.nn_h(vcat(h[1:e.in_dim,:], mi)) |
| 87 | +aggregate_neighbors(::EEquivGraphConv, el::NamedTuple, aggr, E::AbstractMatrix) = _scatter(aggr, E, el.xs) |
114 | 88 |
|
115 |
| - coord_dim = size(h,1) - e.in_dim |
| 89 | +@inline aggregate_neighbors(::EEquivGraphConv, ::NamedTuple, ::Nothing, E) = nothing |
| 90 | +@inline aggregate_neighbors(::EEquivGraphConv, ::NamedTuple, ::Nothing, ::AbstractMatrix) = nothing |
116 | 91 |
|
117 |
| - z = zeros(e.out_dim + coord_dim, N) |
118 |
| - z[1:e.out_dim,:] = nn_node_out |
119 |
| - z[e.out_dim+1:end,:] = h[e.in_dim+1:end,:] + C .* x_msg |
120 |
| - return z |
| 92 | +function propagate(l::EEquivGraphConv, sg::SparseGraph, E, V, X, aggr) |
| 93 | + el = to_namedtuple(sg) |
| 94 | + return propagate(l, el, E, V, X, aggr) |
121 | 95 | end
|
122 | 96 |
|
123 |
| -function(egnn::EEquivGraphConv)(fg::AbstractFeaturedGraph) |
124 |
| - X = node_feature(fg) |
125 |
| - GraphSignals.check_num_nodes(fg, X) |
126 |
| - _, V, _ = propagate(egnn, graph(fg), nothing, X, nothing, +, nothing, nothing) |
127 |
| - return ConcreteFeaturedGraph(fg, nf=V) |
| 97 | +function propagate(l::EEquivGraphConv, el::NamedTuple, E, V, X, aggr) |
| 98 | + E = message( |
| 99 | + l, _gather(V, el.xs), _gather(V, el.nbrs), |
| 100 | + _gather(X, el.xs), _gather(X, el.nbrs), |
| 101 | + _gather(E, el.es) |
| 102 | + ) |
| 103 | + X = positional_encode(l.pe, el, X, E) |
| 104 | + Ē = aggregate_neighbors(l, el, aggr, E) |
| 105 | + V = update(l, Ē, V) |
| 106 | + return E, V, X |
128 | 107 | end
|
| 108 | + |
| 109 | +WithGraph(fg::AbstractFeaturedGraph, l::EEquivGraphConv) = WithGraph(to_namedtuple(fg), l) |
| 110 | +(wg::WithGraph{<:EEquivGraphConv})(args...) = wg.layer(wg.graph, args...) |
0 commit comments