@@ -78,3 +78,126 @@ output_dim(l::EEquivGraphPE) = size(l.nn.weight, 1)
78
78
79
79
positional_encode (wg:: WithGraph{<:EEquivGraphPE} , args... ) = wg (args... )
80
80
positional_encode (l:: EEquivGraphPE , args... ) = l (args... )
81
+
82
+ """
83
+ LSPE(f_h, f_e, f_p, pe_dim; init=glorot_uniform,
84
+ init_pe=random_walk_pe)
85
+
86
+ Learnable structural positional encoding layer.
87
+
88
+ # Arguments
89
+
90
+ - `f_h::MessagePassing`: Neural network layer for node update.
91
+ - `f_e`: Neural network layer for edge update.
92
+ - `f_p`: Neural network layer for positional encoding.
93
+ - `pe_dim::Int`: Dimension of positional encoding.
94
+ - `init`: Initializer for layer weights.
95
+ - `init_pe`: Initializer for positional encoding.
96
+ """
97
+ struct LSPE{H<: MessagePassing ,E,F,P} <: AbstractPE
98
+ f_h:: H
99
+ f_e:: E
100
+ f_p:: F
101
+ pe:: P
102
+ end
103
+
104
+ function LSPE (f_h:: MessagePassing , f_e, f_p, pe_dim:: Int ; init= glorot_uniform, init_pe= random_walk_pe)
105
+ pe = init_pe (A, pe_dim)
106
+ return LSPE (f_h, f_e, f_p, pe)
107
+ end
108
+
109
+ # For variable graph
110
+ function (l:: LSPE )(fg:: AbstractFeaturedGraph )
111
+ X = node_feature (fg)
112
+ E = edge_feature (fg)
113
+ GraphSignals. check_num_nodes (fg, X)
114
+ GraphSignals. check_num_edges (fg, E)
115
+ E, V = propagate (l, graph (fg), E, X)
116
+ return ConcreteFeaturedGraph (fg, nf= V, ef= E)
117
+ end
118
+
119
+ # For static graph
120
+ function (l:: LSPE )(el:: NamedTuple , X:: AbstractArray , E:: AbstractArray )
121
+ GraphSignals. check_num_nodes (el. N, X)
122
+ GraphSignals. check_num_edges (el. E, E)
123
+ E, V = propagate (l, graph (fg), E, X)
124
+ return V, E
125
+ end
126
+
127
+ update_vertex (l:: LSPE , el:: NamedTuple , X, E:: AbstractArray ) = l. f_h (el, X, E)
128
+ update_vertex (l:: LSPE , el:: NamedTuple , X, E:: Nothing ) = l. f_h (el, X)
129
+
130
+ update_edge (l:: LSPE , h_i, h_j, e_ij) = l. f_e (e_ij)
131
+
132
+ positional_encode (l:: LSPE , p_i, p_j, e_ij) = l. f_p (p_i)
133
+
134
+ propagate (l:: LSPE , sg:: SparseGraph , E, V) = propagate (l, to_namedtuple (sg), E, V)
135
+
136
+ function propagate (l:: LSPE , el:: NamedTuple , E, V)
137
+ e_ij = _gather (E, el. es)
138
+ h_i = _gather (V, el. xs)
139
+ h_j = _gather (V, el. nbrs)
140
+ p_i = _gather (l. pe, el. xs)
141
+ p_j = _gather (l. pe, el. nbrs)
142
+
143
+ V = update_vertex (l, el, vcat (V, l. pe), E)
144
+ E = update_edge (l, h_i, h_j, e_ij)
145
+ l. pe = positional_encode (l, p_i, p_j, e_ij)
146
+ return E, V
147
+ end
148
+
149
+ function Base. show (io:: IO , l:: LSPE )
150
+ print (io, " LSPE(node_layer=" , l. f_h)
151
+ print (io, " , edge_layer=" , l. f_e)
152
+ print (io, " , positional_encode=" , l. f_p, " )" )
153
+ end
154
+
155
+
156
+ """
157
+ random_walk_pe(A, k)
158
+
159
+ Returns positional encoding (PE) of size `(k, N)` where N is node number.
160
+ PE is generated by `k`-step random walk over given graph.
161
+
162
+ # Arguments
163
+
164
+ - `A`: Adjacency matrix of a graph.
165
+ - `k::Int`: First dimension of PE.
166
+ """
167
+ function random_walk_pe (A:: AbstractMatrix , k:: Int )
168
+ N = size (A, 1 )
169
+ @assert k ≤ N " k must less or equal to number of nodes"
170
+ inv_D = GraphSignals. degree_matrix (A, Float32, inverse= true )
171
+
172
+ RW = similar (A, size (A)... , k)
173
+ RW[:, :, 1 ] .= A * inv_D
174
+ for i in 2 : k
175
+ RW[:, :, i] .= RW[:, :, i- 1 ] * RW[:, :, 1 ]
176
+ end
177
+
178
+ pe = similar (RW, k, N)
179
+ for i in 1 : N
180
+ pe[:, i] .= RW[i, i, :]
181
+ end
182
+
183
+ return pe
184
+ end
185
+
186
+ """
187
+ laplacian_pe(A, k)
188
+
189
+ Returns positional encoding (PE) of size `(k, N)` where `N` is node number.
190
+ PE is generated from eigenvectors of a graph Laplacian truncated by `k`.
191
+
192
+ # Arguments
193
+
194
+ - `A`: Adjacency matrix of a graph.
195
+ - `k::Int`: First dimension of PE.
196
+ """
197
+ function laplacian_pe (A:: AbstractMatrix , k:: Int )
198
+ N = size (A, 1 )
199
+ @assert k ≤ N " k must less or equal to number of nodes"
200
+ L = GraphSignals. normalized_laplacian (A)
201
+ U = eigvecs (L)
202
+ return U[1 : k, :]
203
+ end
0 commit comments