@@ -253,16 +253,16 @@ Graph attentional layer.
 
 ```jldoctest
 julia> GATConv(1024=>256, relu)
-GATConv(1024=>256, heads=1, concat=true, LeakyReLU(λ=0.2))
+GATConv(1024=>256, relu, heads=1, concat=true, LeakyReLU(λ=0.2))
 
 julia> GATConv(1024=>256, relu, heads=4)
-GATConv(1024=>1024, heads=4, concat=true, LeakyReLU(λ=0.2))
+GATConv(1024=>1024, relu, heads=4, concat=true, LeakyReLU(λ=0.2))
 
 julia> GATConv(1024=>256, relu, heads=4, concat=false)
-GATConv(1024=>1024, heads=4, concat=false, LeakyReLU(λ=0.2))
+GATConv(1024=>1024, relu, heads=4, concat=false, LeakyReLU(λ=0.2))
 
 julia> GATConv(1024=>256, relu, negative_slope=0.1f0)
-GATConv(1024=>256, heads=1, concat=true, LeakyReLU(λ=0.1))
+GATConv(1024=>256, relu, heads=1, concat=true, LeakyReLU(λ=0.1))
 ```
 
 See also [`WithGraph`](@ref) for training the layer with a static graph.
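For orientation, here is a brief end-to-end sketch of how the layer touched by this patch is typically driven. It is illustrative only: the graph, the feature sizes, and the `FeaturedGraph`/`WithGraph` calls are assumed from the surrounding GeometricFlux/GraphSignals API rather than taken from this diff.

```julia
using GeometricFlux, GraphSignals, Flux

# A 3-node graph with self-loops; GATConv expects every vertex to message itself.
adj = Float32[1 1 0;
              1 1 1;
              0 1 1]

# Variable-graph style: node features travel inside the FeaturedGraph.
fg = FeaturedGraph(adj, nf=rand(Float32, 16, 3))
l = GATConv(16=>8, relu, heads=2)
node_feature(l(fg))                  # (8*2, 3) node features, since concat=true

# Static-graph style: fix the graph once, then feed plain feature arrays.
model = WithGraph(FeaturedGraph(adj), GATConv(16=>8, relu, heads=2))
model(rand(Float32, 16, 3))          # (8*2, 3)
```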
@@ -282,7 +282,7 @@ function GATConv(ch::Pair{Int,Int}, σ=identity; heads::Int=1, concat::Bool=true
                  negative_slope=0.2f0, init=glorot_uniform, bias::Bool=true)
     in, out = ch
     W = init(out*heads, in)
-    b = Flux.create_bias(W, bias, out, 1, heads)
+    b = Flux.create_bias(W, bias, out*heads)
     a = init(2*out, heads)
     GATConv(W, b, a, σ, negative_slope, ch, heads, concat)
 end
@@ -297,22 +297,20 @@ Flux.trainable(l::GATConv) = (l.weight, l.bias, l.a)
 function message(gat::GATConv, Xi::AbstractMatrix, Xj::AbstractMatrix, e_ij)
     Xi = reshape(Xi, size(Xi)..., 1)
     Xj = reshape(Xj, size(Xj)..., 1)
-    A = message(gat, Xi, Xj, nothing)
-    return reshape(A, size(A)[1:3]...)
+    m = message(gat, Xi, Xj, nothing)
+    return reshape(m, :)
 end
 
 function message(gat::GATConv, Xi::AbstractArray, Xj::AbstractArray, e_ij)
     _, nb, bch_sz = size(Xj)
     heads = gat.heads
-    Q = reshape(NNlib.batched_mul(gat.weight, Xi), :, nb, heads*bch_sz)  # dims: (out, nb, heads*bch_sz)
-    K = reshape(NNlib.batched_mul(gat.weight, Xj), :, nb, heads*bch_sz)
-    V = reshape(NNlib.batched_mul(gat.weight, Xj), :, nb, heads*bch_sz)
-    QK = reshape(vcat(Q, K), :, nb, heads, bch_sz)  # dims: (2out, nb, heads, bch_sz)
-    QK = permutedims(QK, (1, 3, 2, 4))  # dims: (2out, heads, nb, bch_sz)
+    Q = reshape(NNlib.batched_mul(gat.weight, Xi), :, heads, nb, bch_sz)  # dims: (out, heads, nb, bch_sz)
+    K = reshape(NNlib.batched_mul(gat.weight, Xj), :, heads, nb, bch_sz)
+    V = reshape(NNlib.batched_mul(gat.weight, Xj), :, heads, nb, bch_sz)
+    QK = vcat(Q, K)  # dims: (2out, heads, nb, bch_sz)
     A = leakyrelu.(sum(QK .* gat.a, dims=1), gat.negative_slope)  # dims: (1, heads, nb, bch_sz)
-    QK = permutedims(QK, (1, 3, 2, 4))  # dims: (1, nb, heads, bch_sz)
-    α = Flux.softmax(reshape(A, nb, 1, :), dims=1)  # dims: (nb, 1, heads*bch_sz)
-    return reshape(NNlib.batched_mul(V, α), :, 1, heads, bch_sz)  # dims: (out, 1, heads, bch_sz)
+    α = Flux.softmax(A, dims=3)  # dims: (1, heads, nb, bch_sz)
+    return reshape(sum(V .* α, dims=3), :, 1, bch_sz)  # dims: (out*heads, 1, bch_sz)
 end
 
 # graph attention
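For intuition, the following standalone snippet walks through the shapes produced by the rewritten multi-head message path above. The sizes are made up for illustration; it only mirrors the `+` lines and adds nothing to the patch.

```julia
using Flux, NNlib   # leakyrelu, softmax, batched_mul

# Made-up sizes: in=8, out=4, heads=2, 3 neighbours, batch of 1.
W  = rand(Float32, 4*2, 8)                            # stands in for gat.weight: (out*heads, in)
a  = rand(Float32, 2*4, 2)                            # stands in for gat.a: (2out, heads)
Xi = rand(Float32, 8, 3, 1)                           # target node features, repeated per neighbour
Xj = rand(Float32, 8, 3, 1)                           # neighbour features
Q  = reshape(NNlib.batched_mul(W, Xi), :, 2, 3, 1)    # (out, heads, nb, bch_sz)
K  = reshape(NNlib.batched_mul(W, Xj), :, 2, 3, 1)
A  = leakyrelu.(sum(vcat(Q, K) .* a, dims=1), 0.2f0)  # (1, heads, nb, bch_sz)
α  = softmax(A, dims=3)                               # softmax over the neighbour dimension
```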
@@ -325,7 +323,7 @@ function update_batch_edge(gat::GATConv, el::NamedTuple, E, X::AbstractArray, u)
         return message(gat, Xi, Xj, nothing)
     end
     hs = [_message(gat, el, i, X) for i in 1:el.N]
-    return hcat(hs...)  # dims: (out, N, heads, [bch_sz])
+    return hcat(hs...)  # dims: (out*heads, N, [bch_sz])
 end
 
 function check_self_loops(sg::SparseGraph)
@@ -337,20 +335,18 @@ function check_self_loops(sg::SparseGraph)
     return true
 end
 
-function update(gat::GATConv, M::AbstractArray, X::AbstractArray)
+function update(gat::GATConv, M::AbstractArray, X)
     M = M .+ gat.bias
-    if gat.concat
-        M = gat.σ.(M)  # dims: (out, N, heads, [bch_sz])
+    if gat.concat || gat.heads == 1
+        M = gat.σ.(M)  # dims: (out*heads, N, [bch_sz])
     else
-        M = gat.σ.(mean(M, dims=3))
-        M = _reshape(M)  # dims: (out, N, [bch_sz])
+        M = reshape(M, :, gat.heads, size(M)[2:end]...)
+        M = gat.σ.(mean(M, dims=2))
+        M = reshape(M, :, size(M)[2:end]...)  # dims: (out, N, [bch_sz])
     end
     return M
 end
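A minimal standalone sketch (made-up sizes, not part of the patch) of the head-averaging branch introduced above for the `concat=false` case:

```julia
using Statistics: mean

out, heads, N = 4, 2, 5
M = rand(Float32, out*heads, N)    # messages with heads stacked: (out*heads, N)
H = reshape(M, :, heads, N)        # split the head dimension: (out, heads, N)
H = mean(H, dims=2)                # average over heads: (out, 1, N)
H = reshape(H, :, N)               # collapse back to (out, N)
```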
 
-_reshape(M::AbstractArray{<:Real,3}) = reshape(M, size(M)[[1,2]]...)
-_reshape(M::AbstractArray{<:Real,4}) = reshape(M, size(M)[[1,2,4]]...)
-
 # For variable graph
 function (l::GATConv)(fg::AbstractFeaturedGraph)
     X = node_feature(fg)
@@ -365,7 +361,7 @@
 
 # For static graph
 function (l::GATConv)(el::NamedTuple, X::AbstractArray)
-    GraphSignals.check_num_nodes(el.N, size(X, 2))
+    GraphSignals.check_num_nodes(el.N, X)
     # TODO: should have self loops check for el
     Ē = update_batch_edge(l, el, nothing, X, nothing)
     V = update_batch_vertex(l, el, Ē, X, nothing)
@@ -376,6 +372,7 @@ function Base.show(io::IO, l::GATConv)
     in_channel = size(l.weight, ndims(l.weight))
     out_channel = size(l.weight, ndims(l.weight)-1)
     print(io, "GATConv(", in_channel, "=>", out_channel)
+    l.σ == identity || print(io, ", ", l.σ)
     print(io, ", heads=", l.heads)
     print(io, ", concat=", l.concat)
     print(io, ", LeakyReLU(λ=", l.negative_slope)
@@ -384,99 +381,141 @@ end
 
 
 """
-    GATv2Conv([fg,] in => out;
-              heads=1,
-              concat=true,
-              init=glorot_uniform
-              negative_slope=0.2)
+    GATv2Conv(in => out, σ=identity; heads=1, concat=true,
+              init=glorot_uniform, negative_slope=0.2)
 
-GATv2 Layer as introduced in https://arxiv.org/abs/2105.14491
+Graph attentional layer v2.
 
 # Arguments
 
-- `fg`: Optionally pass a [`FeaturedGraph`](@ref).
 - `in`: The dimension of input features.
 - `out`: The dimension of output features.
+- `σ`: Activation function.
 - `heads`: Number of attention heads.
 - `concat`: Concatenate layer output or not. If not, layer output is averaged.
 - `negative_slope::Real`: Keyword argument, the parameter of LeakyReLU.
+
+# Examples
+
+```jldoctest
+julia> GATv2Conv(1024=>256, relu)
+GATv2Conv(1024=>256, relu, heads=1, concat=true, LeakyReLU(λ=0.2))
+
+julia> GATv2Conv(1024=>256, relu, heads=4)
+GATv2Conv(1024=>1024, relu, heads=4, concat=true, LeakyReLU(λ=0.2))
+
+julia> GATv2Conv(1024=>256, relu, heads=4, concat=false)
+GATv2Conv(1024=>1024, relu, heads=4, concat=false, LeakyReLU(λ=0.2))
+
+julia> GATv2Conv(1024=>256, relu, negative_slope=0.1f0)
+GATv2Conv(1024=>256, relu, heads=1, concat=true, LeakyReLU(λ=0.1))
+```
+
+See also [`WithGraph`](@ref) for training the layer with a static graph.
 """
-struct GATv2Conv{V<:AbstractFeaturedGraph, T, A<:AbstractMatrix{T}, B} <: MessagePassing
-    fg::V
+struct GATv2Conv{T, A<:AbstractMatrix{T}, B, F} <: MessagePassing
     wi::A
     wj::A
     biasi::B
     biasj::B
     a::A
+    σ::F
     negative_slope::T
     channel::Pair{Int, Int}
     heads::Int
     concat::Bool
 end
 
-function GATv2Conv(
-    fg::AbstractFeaturedGraph,
-    ch::Pair{Int,Int};
-    heads::Int=1,
-    concat::Bool=true,
-    negative_slope=0.2f0,
-    bias::Bool=true,
-    init=glorot_uniform,
-)
+function GATv2Conv(ch::Pair{Int,Int}, σ=identity; heads::Int=1, concat::Bool=true,
+                   negative_slope=0.2f0, bias::Bool=true, init=glorot_uniform)
     in, out = ch
     wi = init(out*heads, in)
     wj = init(out*heads, in)
     bi = Flux.create_bias(wi, bias, out*heads)
     bj = Flux.create_bias(wj, bias, out*heads)
     a = init(out, heads)
-    GATv2Conv(fg, wi, wj, bi, bj, a, negative_slope, ch, heads, concat)
+    GATv2Conv(wi, wj, bi, bj, a, σ, negative_slope, ch, heads, concat)
 end
 
-GATv2Conv(ch::Pair{Int,Int}; kwargs...) = GATv2Conv(NullGraph(), ch; kwargs...)
-
 @functor GATv2Conv
 
 Flux.trainable(l::GATv2Conv) = (l.wi, l.wj, l.biasi, l.biasj, l.a)
 
-function message(gat::GATv2Conv, x_i::AbstractVector, x_j::AbstractVector)
-    xi = reshape(gat.wi * x_i + gat.biasi, :, gat.heads)
-    xj = reshape(gat.wj * x_j + gat.biasj, :, gat.heads)
-    eij = gat.a' * leakyrelu.(xi + xj, gat.negative_slope)
-    vcat(eij, xj)
+function message(gat::GATv2Conv, Xi::AbstractMatrix, Xj::AbstractMatrix, e_ij)
+    Xi = reshape(Xi, size(Xi)..., 1)
+    Xj = reshape(Xj, size(Xj)..., 1)
+    m = message(gat, Xi, Xj, nothing)
+    return reshape(m, :)
 end
 
-function graph_attention(gat::GATv2Conv, i, js, X::AbstractMatrix)
-    e_ij = mapreduce(j -> GeometricFlux.message(gat, _view(X, i), _view(X, j)), hcat, js)
-    n = size(e_ij, 1)
-    αs = Flux.softmax(reshape(view(e_ij, 1, :), gat.heads, :), dims=2)
-    msgs = view(e_ij, 2:n, :) .* reshape(αs, 1, :)
-    reshape(msgs, (n-1)*gat.heads, :)
+function message(gat::GATv2Conv, Xi::AbstractArray, Xj::AbstractArray, e_ij)
+    _, nb, bch_sz = size(Xj)
+    heads = gat.heads
+    Q = reshape(NNlib.batched_mul(gat.wi, Xi) .+ gat.biasi, :, heads, nb, bch_sz)  # dims: (out, heads, nb, bch_sz)
+    K = reshape(NNlib.batched_mul(gat.wj, Xj) .+ gat.biasj, :, heads, nb, bch_sz)
+    V = reshape(NNlib.batched_mul(gat.wj, Xj) .+ gat.biasj, :, heads, nb, bch_sz)
+    QK = Q + K  # dims: (out, heads, nb, bch_sz)
+    A = leakyrelu.(sum(QK .* gat.a, dims=1), gat.negative_slope)  # dims: (1, heads, nb, bch_sz)
+    α = Flux.softmax(A, dims=3)  # dims: (1, heads, nb, bch_sz)
+    return reshape(sum(V .* α, dims=3), :, 1, bch_sz)  # dims: (out*heads, 1, bch_sz)
 end
 
-function update_batch_edge(gat::GATv2Conv, fg::AbstractFeaturedGraph, E::AbstractMatrix, X::AbstractMatrix, u)
-    @assert Zygote.ignore(() -> check_self_loops(graph(fg))) "a vertex must have self loop (receive a message from itself)."
-    nodes = Zygote.ignore(() -> vertices(graph(fg)))
-    nbr = i -> cpu(GraphSignals.neighbors(graph(fg), i))
-    ms = map(i -> graph_attention(gat, i, Zygote.ignore(() -> nbr(i)), X), nodes)
-    M = hcat_by_sum(ms)
-    return M
+function update_batch_edge(gat::GATv2Conv, el::NamedTuple, E, X::AbstractArray, u)
+    function _message(gat, el, i, X)
+        xs = el.xs[el.xs .== i]
+        nbrs = el.nbrs[el.xs .== i]
+        Xi = _gather(X, xs)
+        Xj = _gather(X, nbrs)
+        return message(gat, Xi, Xj, nothing)
+    end
+    hs = [_message(gat, el, i, X) for i in 1:el.N]
+    return hcat(hs...)  # dims: (out*heads, N, [bch_sz])
 end
 
-function update_batch_vertex(gat::GATv2Conv, ::AbstractFeaturedGraph, M::AbstractMatrix, X::AbstractMatrix, u)
-    if !gat.concat
-        N = size(M, 2)
-        M = reshape(mean(reshape(M, :, gat.heads, N), dims=2), :, N)
+function update(gat::GATv2Conv, M::AbstractArray, X)
+    if gat.concat || gat.heads == 1
+        M = gat.σ.(M)  # dims: (out*heads, N, [bch_sz])
+    else
+        M = reshape(M, :, gat.heads, size(M)[2:end]...)
+        M = gat.σ.(mean(M, dims=2))
+        M = reshape(M, :, size(M)[2:end]...)  # dims: (out, N, [bch_sz])
     end
     return M
 end
 
-function (gat::GATv2Conv)(fg::ConcreteFeaturedGraph, X::AbstractMatrix)
+# For variable graph
+function (gat::GATv2Conv)(fg::AbstractFeaturedGraph)
+    X = node_feature(fg)
     GraphSignals.check_num_nodes(fg, X)
-    _, X, _ = propagate(gat, fg, edge_feature(fg), X, global_feature(fg), +)
-    return X
+    sg = graph(fg)
+    @assert Zygote.ignore(() -> check_self_loops(sg)) "a vertex must have self loop (receive a message from itself)."
+    es, nbrs, xs = Zygote.ignore(() -> collect(edges(sg)))
+    el = (N=nv(sg), E=ne(sg), es=es, nbrs=nbrs, xs=xs)
+    Ē = update_batch_edge(gat, el, nothing, X, nothing)
+    V = update_batch_vertex(gat, el, Ē, X, nothing)
+    return ConcreteFeaturedGraph(fg, nf=V)
+end
+
+# For static graph
+function (l::GATv2Conv)(el::NamedTuple, X::AbstractArray)
+    GraphSignals.check_num_nodes(el.N, X)
+    # TODO: should have self loops check for el
+    Ē = update_batch_edge(l, el, nothing, X, nothing)
+    V = update_batch_vertex(l, el, Ē, X, nothing)
+    return V
+end
+
+function Base.show(io::IO, l::GATv2Conv)
+    in_channel = size(l.wi, ndims(l.wi))
+    out_channel = size(l.wi, ndims(l.wi)-1)
+    print(io, "GATv2Conv(", in_channel, "=>", out_channel)
+    l.σ == identity || print(io, ", ", l.σ)
+    print(io, ", heads=", l.heads)
+    print(io, ", concat=", l.concat)
+    print(io, ", LeakyReLU(λ=", l.negative_slope)
+    print(io, "))")
 end
 
-(l::GATv2Conv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
 
 
 """
@@ -569,7 +608,7 @@ Edge convolutional layer.
 
 # Arguments
 
-- `nn`: A neural network (e.g. a Dense layer or a MLP).
+- `nn`: A neural network (e.g. a Dense layer or a MLP).
 - `aggr`: An aggregate function applied to the result of message function.
           `+`, `max` and `mean` are available.
 
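A hedged constructor sketch for the arguments described above; the exact `EdgeConv` keyword form is assumed from the argument list rather than shown in this hunk:

```julia
using GeometricFlux, Flux

# nn consumes the concatenated pair features (2*in values per edge); aggr reduces incoming messages.
nn = Dense(2*16, 32, relu)
layer = EdgeConv(nn, aggr=max)
```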