272
272
273
273
function update_batch_edge (gat:: GATConv , fg:: AbstractFeaturedGraph , E:: AbstractMatrix , X:: AbstractMatrix , u)
274
274
@assert Zygote. ignore (() -> check_self_loops (graph (fg))) " a vertex must have self loop (receive a message from itself)."
275
- nodes = Zygote. ignore (()-> vertices (fg ))
275
+ nodes = Zygote. ignore (()-> vertices (graph (fg) ))
276
276
nbr = i-> cpu (GraphSignals. neighbors (graph (fg), i))
277
277
ms = map (i -> graph_attention (gat, i, Zygote. ignore (()-> nbr (i)), X), nodes)
278
278
M = hcat_by_sum (ms)
@@ -292,7 +292,7 @@ function update_batch_vertex(gat::GATConv, ::AbstractFeaturedGraph, M::AbstractM
292
292
M = M .+ gat. bias
293
293
if ! gat. concat
294
294
N = size (M, 2 )
295
- M = reshape (mean (reshape (M, :, gat. heads, N), dims= 2 ), :, N)
295
+ M = reshape (mean (reshape (M, gat. heads, :, N), dims= 1 ), :, N)
296
296
end
297
297
return M
298
298
end
@@ -315,6 +315,102 @@ function Base.show(io::IO, l::GATConv)
315
315
end
316
316
317
317
318
+ """
319
+ GATv2Conv([fg,] in => out;
320
+ heads=1,
321
+ concat=true,
322
+ init=glorot_uniform
323
+ negative_slope=0.2)
324
+
325
+ GATv2 Layer as introduced in https://arxiv.org/abs/2105.14491
326
+
327
+ # Arguments
328
+
329
+ - `fg`: Optionally pass a [`FeaturedGraph`](@ref).
330
+ - `in`: The dimension of input features.
331
+ - `out`: The dimension of output features.
332
+ - `heads`: Number attention heads
333
+ - `concat`: Concatenate layer output or not. If not, layer output is averaged.
334
+ - `negative_slope::Real`: Keyword argument, the parameter of LeakyReLU.
335
+ """
336
+ struct GATv2Conv{V<: AbstractFeaturedGraph , T, A<: AbstractMatrix{T} , B} <: MessagePassing
337
+ fg:: V
338
+ wi:: A
339
+ wj:: A
340
+ biasi:: B
341
+ biasj:: B
342
+ a:: A
343
+ negative_slope:: T
344
+ channel:: Pair{Int, Int}
345
+ heads:: Int
346
+ concat:: Bool
347
+ end
348
+
349
+ function GATv2Conv (
350
+ fg:: AbstractFeaturedGraph ,
351
+ ch:: Pair{Int,Int} ;
352
+ heads:: Int = 1 ,
353
+ concat:: Bool = true ,
354
+ negative_slope= 0.2f0 ,
355
+ bias:: Bool = true ,
356
+ init= glorot_uniform,
357
+ )
358
+ in, out = ch
359
+ wi = init (out* heads, in)
360
+ wj = init (out* heads, in)
361
+ bi = Flux. create_bias (wi, bias, out* heads)
362
+ bj = Flux. create_bias (wj, bias, out* heads)
363
+ a = init (out, heads)
364
+ GATv2Conv (fg, wi, wj, bi, bj, a, negative_slope, ch, heads, concat)
365
+ end
366
+
367
+ GATv2Conv (ch:: Pair{Int,Int} ; kwargs... ) = GATv2Conv (NullGraph (), ch; kwargs... )
368
+
369
+ @functor GATv2Conv
370
+
371
+ Flux. trainable (l:: GATv2Conv ) = (l. wi, l. wj, l. biasi, l. biasj, l. a)
372
+
373
+ function message (gat:: GATv2Conv , x_i:: AbstractVector , x_j:: AbstractVector )
374
+ xi = reshape (gat. wi * x_i + gat. biasi, :, gat. heads)
375
+ xj = reshape (gat. wj * x_j + gat. biasj, :, gat. heads)
376
+ eij = gat. a' * leakyrelu .(xi + xj, gat. negative_slope)
377
+ vcat (eij, xj)
378
+ end
379
+
380
+ function graph_attention (gat:: GATv2Conv , i, js, X:: AbstractMatrix )
381
+ e_ij = mapreduce (j -> GeometricFlux. message (gat, _view (X, i), _view (X, j)), hcat, js)
382
+ n = size (e_ij, 1 )
383
+ αs = Flux. softmax (reshape (view (e_ij, 1 , :), gat. heads, :), dims= 2 )
384
+ msgs = view (e_ij, 2 : n, :) .* reshape (αs, 1 , :)
385
+ reshape (msgs, (n- 1 )* gat. heads, :)
386
+ end
387
+
388
+ function update_batch_edge (gat:: GATv2Conv , fg:: AbstractFeaturedGraph , E:: AbstractMatrix , X:: AbstractMatrix , u)
389
+ @assert Zygote. ignore (() -> check_self_loops (graph (fg))) " a vertex must have self loop (receive a message from itself)."
390
+ nodes = Zygote. ignore (()-> vertices (graph (fg)))
391
+ nbr = i-> cpu (GraphSignals. neighbors (graph (fg), i))
392
+ ms = map (i -> graph_attention (gat, i, Zygote. ignore (()-> nbr (i)), X), nodes)
393
+ M = hcat_by_sum (ms)
394
+ return M
395
+ end
396
+
397
+ function update_batch_vertex (gat:: GATv2Conv , :: AbstractFeaturedGraph , M:: AbstractMatrix , X:: AbstractMatrix , u)
398
+ if ! gat. concat
399
+ N = size (M, 2 )
400
+ M = reshape (mean (reshape (M, gat. heads, :, N), dims= 1 ), :, N)
401
+ end
402
+ return M
403
+ end
404
+
405
+ function (gat:: GATv2Conv )(fg:: ConcreteFeaturedGraph , X:: AbstractMatrix )
406
+ GraphSignals. check_num_nodes (fg, X)
407
+ _, X, _ = propagate (gat, fg, edge_feature (fg), X, global_feature (fg), + )
408
+ return X
409
+ end
410
+
411
+ (l:: GATv2Conv )(fg:: FeaturedGraph ) = FeaturedGraph (fg, nf = l (fg, node_feature (fg)))
412
+
413
+
318
414
"""
319
415
GatedGraphConv([fg,] out, num_layers; aggr=+, init=glorot_uniform)
320
416
0 commit comments