@@ -294,43 +294,30 @@ end
294
294
Flux. trainable (l:: GATConv ) = (l. weight, l. bias, l. a)
295
295
296
296
# neighbor attention
297
- function message (gat:: GATConv , Xi:: AbstractMatrix , Xj:: AbstractMatrix , e_ij)
298
- Xi = reshape (Xi, size (Xi)... , 1 )
299
- Xj = reshape (Xj, size (Xj)... , 1 )
300
- m = message (gat, Xi, Xj, nothing )
301
- return reshape (m, :)
297
+ function update_batch_edge (gat:: GATConv , el:: NamedTuple , E, X:: AbstractMatrix , u)
298
+ X = reshape (X, size (X)... , 1 )
299
+ M = update_batch_edge (gat, el, E, X, u)
300
+ return reshape (M, size (M)[1 : 2 ]. .. )
302
301
end
303
302
304
- function message (gat:: GATConv , Xi:: AbstractArray , Xj:: AbstractArray , e_ij)
303
+ function update_batch_edge (gat:: GATConv , el:: NamedTuple , E, X:: AbstractArray , u)
304
+ Xi, Xj = _gather (X, el. xs), _gather (X, el. nbrs)
305
305
_, nb, bch_sz = size (Xj)
306
306
heads = gat. heads
307
307
Q = reshape (NNlib. batched_mul (gat. weight, Xi), :, heads, nb, bch_sz) # dims: (out, heads, nb, bch_sz)
308
308
K = reshape (NNlib. batched_mul (gat. weight, Xj), :, heads, nb, bch_sz)
309
309
V = reshape (NNlib. batched_mul (gat. weight, Xj), :, heads, nb, bch_sz)
310
310
QK = vcat (Q, K) # dims: (2out, heads, nb, bch_sz)
311
311
A = leakyrelu .(sum (QK .* gat. a, dims= 1 ), gat. negative_slope) # dims: (1, heads, nb, bch_sz)
312
- α = Flux. softmax (A, dims= 3 ) # dims: (1, heads, nb, bch_sz)
313
- return reshape (sum (V .* α, dims= 3 ), :, 1 , bch_sz) # dims: (out*heads, 1, bch_sz)
314
- end
315
-
316
- # graph attention
317
- function update_batch_edge (gat:: GATConv , el:: NamedTuple , E, X:: AbstractArray , u)
318
- function _message (gat, el, i, X)
319
- xs = el. xs[el. xs .== i]
320
- nbrs = el. nbrs[el. xs .== i]
321
- Xi = _gather (X, xs)
322
- Xj = _gather (X, nbrs)
323
- return message (gat, Xi, Xj, nothing )
324
- end
325
- hs = [_message (gat, el, i, X) for i in 1 : el. N]
326
- return hcat (hs... ) # dims: (out*heads, N, [bch_sz])
312
+ A = Flux. softmax (A, dims= 3 ) # dims: (1, heads, nb, bch_sz)
313
+ A = reshape (V .* A, :, nb, bch_sz)
314
+ N = incidence_matrix (el. xs, el. N)
315
+ return NNlib. batched_mul (A, N) # dims: (out*heads, N, bch_sz)
327
316
end
328
317
329
- update_batch_edge (gat:: GATConv , el:: NamedTuple , E, X:: AbstractArray , u) =
330
- [update_batch_edge (gat, el, X, i) for i in 1 : el. N]
331
-
332
318
# graph attention
333
- aggregate_neighbors (gat:: GATConv , el:: NamedTuple , aggr, E) = aggr (E... ) # dims: (out, N, heads, [bch_sz])
319
+ aggregate_neighbors (gat:: GATConv , el:: NamedTuple , aggr, E:: AbstractArray ) = E # dims: (out*heads, N, [bch_sz])
320
+ aggregate_neighbors (gat:: GATConv , el:: NamedTuple , aggr, E:: AbstractMatrix ) = E
334
321
335
322
function update (gat:: GATConv , M:: AbstractArray , X)
336
323
M = M .+ gat. bias
@@ -342,7 +329,7 @@ function update(gat::GATConv, M::AbstractArray, X)
342
329
M = gat. σ .(mean (M, dims= 2 ))
343
330
M = reshape (M, :, dims... ) # dims: (out, N, [bch_sz])
344
331
end
345
- return _reshape (M)
332
+ return M
346
333
end
347
334
348
335
# For variable graph
360
347
function (l:: GATConv )(el:: NamedTuple , X:: AbstractArray )
361
348
GraphSignals. check_num_nodes (el. N, X)
362
349
# TODO : should have self loops check for el
363
- Ē = update_batch_edge (l, el, nothing , X, nothing )
364
- V = update_batch_vertex (l, el, Ē, X, nothing )
350
+ _, V, _ = propagate (l, el, nothing , X, nothing , hcat, nothing , nothing )
365
351
return V
366
352
end
367
353
@@ -486,7 +472,7 @@ function (gat::GATv2Conv)(fg::AbstractFeaturedGraph)
486
472
X = node_feature (fg)
487
473
GraphSignals. check_num_nodes (fg, X)
488
474
sg = graph (fg)
489
- @assert Zygote. ignore (() -> check_self_loops (sg)) " a vertex must have self loop (receive a message from itself)."
475
+ @assert Zygote. ignore (() -> GraphSignals . has_all_self_loops (sg)) " a vertex must have self loop (receive a message from itself)."
490
476
es, nbrs, xs = Zygote. ignore (() -> collect (edges (sg)))
491
477
el = (N= nv (sg), E= ne (sg), es= es, nbrs= nbrs, xs= xs)
492
478
Ē = update_batch_edge (gat, el, nothing , X, nothing )
0 commit comments