Skip to content

Commit 3292a49

Browse files
committed
add LSPE
1 parent 770fd14 commit 3292a49

File tree

3 files changed

+128
-0
lines changed

3 files changed

+128
-0
lines changed

src/GeometricFlux.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ include("models.jl")
8888

8989
include("sampling.jl")
9090
include("embedding/node2vec.jl")
91+
include("layers/positional.jl")
9192

9293
using .Datasets
9394

src/layers/positional.jl

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,126 @@ output_dim(l::EEquivGraphPE) = size(l.nn.weight, 1)
7878

7979
positional_encode(wg::WithGraph{<:EEquivGraphPE}, args...) = wg(args...)
8080
positional_encode(l::EEquivGraphPE, args...) = l(args...)
81+
82+
"""
83+
LSPE(f_h, f_e, f_p, pe_dim; init=glorot_uniform,
84+
init_pe=random_walk_pe)
85+
86+
Learnable structural positional encoding layer.
87+
88+
# Arguments
89+
90+
- `f_h::MessagePassing`: Neural network layer for node update.
91+
- `f_e`: Neural network layer for edge update.
92+
- `f_p`: Neural network layer for positional encoding.
93+
- `pe_dim::Int`: Dimension of positional encoding.
94+
- `init`: Initializer for layer weights.
95+
- `init_pe`: Initializer for positional encoding.
96+
"""
97+
struct LSPE{H<:MessagePassing,E,F,P} <: AbstractPE
98+
f_h::H
99+
f_e::E
100+
f_p::F
101+
pe::P
102+
end
103+
104+
function LSPE(f_h::MessagePassing, f_e, f_p, pe_dim::Int; init=glorot_uniform, init_pe=random_walk_pe)
105+
pe = init_pe(A, pe_dim)
106+
return LSPE(f_h, f_e, f_p, pe)
107+
end
108+
109+
# For variable graph
110+
function (l::LSPE)(fg::AbstractFeaturedGraph)
111+
X = node_feature(fg)
112+
E = edge_feature(fg)
113+
GraphSignals.check_num_nodes(fg, X)
114+
GraphSignals.check_num_edges(fg, E)
115+
E, V = propagate(l, graph(fg), E, X)
116+
return ConcreteFeaturedGraph(fg, nf=V, ef=E)
117+
end
118+
119+
# For static graph
120+
function (l::LSPE)(el::NamedTuple, X::AbstractArray, E::AbstractArray)
121+
GraphSignals.check_num_nodes(el.N, X)
122+
GraphSignals.check_num_edges(el.E, E)
123+
E, V = propagate(l, graph(fg), E, X)
124+
return V, E
125+
end
126+
127+
update_vertex(l::LSPE, el::NamedTuple, X, E::AbstractArray) = l.f_h(el, X, E)
128+
update_vertex(l::LSPE, el::NamedTuple, X, E::Nothing) = l.f_h(el, X)
129+
130+
update_edge(l::LSPE, h_i, h_j, e_ij) = l.f_e(e_ij)
131+
132+
positional_encode(l::LSPE, p_i, p_j, e_ij) = l.f_p(p_i)
133+
134+
propagate(l::LSPE, sg::SparseGraph, E, V) = propagate(l, to_namedtuple(sg), E, V)
135+
136+
function propagate(l::LSPE, el::NamedTuple, E, V)
137+
e_ij = _gather(E, el.es)
138+
h_i = _gather(V, el.xs)
139+
h_j = _gather(V, el.nbrs)
140+
p_i = _gather(l.pe, el.xs)
141+
p_j = _gather(l.pe, el.nbrs)
142+
143+
V = update_vertex(l, el, vcat(V, l.pe), E)
144+
E = update_edge(l, h_i, h_j, e_ij)
145+
l.pe = positional_encode(l, p_i, p_j, e_ij)
146+
return E, V
147+
end
148+
149+
function Base.show(io::IO, l::LSPE)
150+
print(io, "LSPE(node_layer=", l.f_h)
151+
print(io, ", edge_layer=", l.f_e)
152+
print(io, ", positional_encode=", l.f_p, ")")
153+
end
154+
155+
156+
"""
157+
random_walk_pe(A, k)
158+
159+
Returns positional encoding (PE) of size `(k, N)` where N is node number.
160+
PE is generated by `k`-step random walk over given graph.
161+
162+
# Arguments
163+
164+
- `A`: Adjacency matrix of a graph.
165+
- `k::Int`: First dimension of PE.
166+
"""
167+
function random_walk_pe(A::AbstractMatrix, k::Int)
168+
N = size(A, 1)
169+
@assert k N "k must less or equal to number of nodes"
170+
inv_D = GraphSignals.degree_matrix(A, Float32, inverse=true)
171+
172+
RW = similar(A, size(A)..., k)
173+
RW[:, :, 1] .= A * inv_D
174+
for i in 2:k
175+
RW[:, :, i] .= RW[:, :, i-1] * RW[:, :, 1]
176+
end
177+
178+
pe = similar(RW, k, N)
179+
for i in 1:N
180+
pe[:, i] .= RW[i, i, :]
181+
end
182+
183+
return pe
184+
end
185+
186+
"""
187+
laplacian_pe(A, k)
188+
189+
Returns positional encoding (PE) of size `(k, N)` where `N` is node number.
190+
PE is generated from eigenvectors of a graph Laplacian truncated by `k`.
191+
192+
# Arguments
193+
194+
- `A`: Adjacency matrix of a graph.
195+
- `k::Int`: First dimension of PE.
196+
"""
197+
function laplacian_pe(A::AbstractMatrix, k::Int)
198+
N = size(A, 1)
199+
@assert k N "k must less or equal to number of nodes"
200+
L = GraphSignals.normalized_laplacian(A)
201+
U = eigvecs(L)
202+
return U[1:k, :]
203+
end

test/layers/positional.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,4 +38,8 @@
3838
@test length(g.grads) == 2
3939
end
4040
end
41+
42+
@testset "LSPE" begin
43+
44+
end
4145
end

0 commit comments

Comments
 (0)