@@ -23,7 +23,7 @@ GCNConv(1024 => 256, relu)
 
 See also [`WithGraph`](@ref) for training layer with static graph.
 """
-struct GCNConv{A<:AbstractMatrix,B,F}
+struct GCNConv{A<:AbstractMatrix,B,F} <: AbstractGraphLayer
     weight::A
     bias::B
     σ::F
@@ -57,12 +57,10 @@
 
 # For static graph
 WithGraph(fg::AbstractFeaturedGraph, l::GCNConv) =
-    WithGraph(l, GraphSignals.normalized_adjacency_matrix!(fg, eltype(l.weight); selfloop=true))
+    WithGraph(GraphSignals.normalized_adjacency_matrix(fg, eltype(l.weight); selfloop=true), l)
 
 function (wg::WithGraph{<:GCNConv})(X::AbstractArray)
-    Ã = Zygote.ignore() do
-        GraphSignals.normalized_adjacency_matrix(wg.fg)
-    end
+    Ã = wg.graph
     return wg.layer(Ã, X)
 end
 
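The static-graph path above computes the normalized adjacency matrix once, when the layer is wrapped, and reuses it on every call via `wg.graph`. A minimal usage sketch — it assumes `FeaturedGraph` can be built directly from an adjacency matrix, as in GraphSignals.jl:

```julia
using GeometricFlux, GraphSignals, Flux

# 4-node ring graph; the FeaturedGraph-from-adjacency constructor is assumed here.
adjm = Float32[0 1 0 1;
               1 0 1 0;
               0 1 0 1;
               1 0 1 0]
fg = FeaturedGraph(adjm)

layer = WithGraph(fg, GCNConv(4=>8, relu))  # Ã is precomputed once, at wrap time
X = rand(Float32, 4, 4)                     # feature × node
Y = layer(X)                                # forward pass reuses the cached wg.graph
```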
@@ -75,66 +73,96 @@
 
 
 """
-    ChebConv([fg,] in=>out, k; bias=true, init=glorot_uniform)
+    ChebConv(in=>out, k, σ=identity; bias=true, init=glorot_uniform)
 
 Chebyshev spectral graph convolutional layer.
 
 # Arguments
 
-- `fg`: Optionally pass a [`FeaturedGraph`](@ref).
 - `in`: The dimension of input features.
 - `out`: The dimension of output features.
 - `k`: The order of Chebyshev polynomial.
 - `bias`: Add learnable bias.
 - `init`: Weights' initializer.
+
+# Example
+
+```jldoctest
+julia> cc = ChebConv(1024=>256, 5, relu)
+ChebConv(1024 => 256, k=5, relu)
+```
+
+See also [`WithGraph`](@ref) for training layer with static graph.
 """
-struct ChebConv{A<:AbstractArray{<:Number,3},B,S<:AbstractFeaturedGraph} <: AbstractGraphLayer
+struct ChebConv{A<:AbstractArray{<:Number,3},B,F} <: AbstractGraphLayer
     weight::A
     bias::B
-    fg::S
     k::Int
+    σ::F
 end
 
-function ChebConv(fg::AbstractFeaturedGraph, ch::Pair{Int,Int}, k::Int;
+function ChebConv(ch::Pair{Int,Int}, k::Int, σ=identity;
                   init=glorot_uniform, bias::Bool=true)
     in, out = ch
     W = init(out, in, k)
     b = Flux.create_bias(W, bias, out)
-    ChebConv(W, b, fg, k)
+    ChebConv(W, b, k, σ)
 end
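The constructor packs all Chebyshev coefficients into a single 3-array of size `(out, in, k)` — one `out × in` slice per polynomial order — with the bias vector created by `Flux.create_bias`. A quick shape check under those definitions:

```julia
using Flux, GeometricFlux

l = ChebConv(4=>8, 3)
size(l.weight)  # (8, 4, 3): one 8×4 coefficient slice per Chebyshev order
size(l.bias)    # (8,): created by Flux.create_bias(W, bias, out)
```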
 
-ChebConv(ch::Pair{Int,Int}, k::Int; kwargs...) =
-    ChebConv(NullGraph(), ch, k; kwargs...)
-
 @functor ChebConv
 
 Flux.trainable(l::ChebConv) = (l.weight, l.bias)
 
-function (c::ChebConv)(fg::AbstractFeaturedGraph, X::AbstractMatrix{T}) where T
-    GraphSignals.check_num_nodes(fg, X)
-    @assert size(X, 1) == size(c.weight, 2) "Input feature size must match input channel size."
-
-    L̃ = Zygote.ignore() do
-        GraphSignals.scaled_laplacian(fg, eltype(X))
-    end
-
+function (l::ChebConv)(L̃::AbstractMatrix, X::AbstractMatrix)
     Z_prev = X
     Z = X * L̃
-    Y = view(c.weight,:,:,1) * Z_prev
-    Y += view(c.weight,:,:,2) * Z
-    for k = 3:c.k
+    Y = view(l.weight,:,:,1) * Z_prev
+    Y += view(l.weight,:,:,2) * Z
+    for k = 3:l.k
         Z, Z_prev = 2 .* Z * L̃ - Z_prev, Z
-        Y += view(c.weight,:,:,k) * Z
+        Y += view(l.weight,:,:,k) * Z
+    end
+    return l.σ.(Y .+ l.bias)
+end
+
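The loop is the standard three-term Chebyshev recurrence, T_k(x) = 2x·T_{k-1}(x) − T_{k-2}(x), applied to L̃ from the right because features sit in the rows of `X`. The same recurrence checked on scalars, independent of any graph code:

```julia
# T[1] and T[2] hold T₀(x) = 1 and T₁(x) = x, exactly as Z_prev = X and Z = X * L̃ do.
x = 0.3
T = zeros(5)
T[1], T[2] = 1.0, x
for k in 3:5
    T[k] = 2x * T[k-1] - T[k-2]   # mirrors: Z, Z_prev = 2 .* Z * L̃ - Z_prev, Z
end
T[4] ≈ 4x^3 - 3x        # closed form of T₃, true
T[5] ≈ 8x^4 - 8x^2 + 1  # closed form of T₄, true
```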
+function (l::ChebConv)(L̃::AbstractMatrix, X::AbstractArray)
+    Z_prev = X
+    Z = NNlib.batched_mul(X, L̃)
+    Y = NNlib.batched_mul(view(l.weight,:,:,1), Z_prev)
+    Y += NNlib.batched_mul(view(l.weight,:,:,2), Z)
+    for k = 3:l.k
+        Z, Z_prev = 2 .* NNlib.batched_mul(Z, L̃) .- Z_prev, Z
+        Y += NNlib.batched_mul(view(l.weight,:,:,k), Z)
+    end
+    return l.σ.(Y .+ l.bias)
+end
+
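The batched method takes `X` of size `in × n × batch` while `L̃` remains a single `n × n` matrix; `NNlib.batched_mul` accepts a plain matrix on either side and applies it across the batch dimension. A dimension walk-through — the random symmetric matrix below is only a stand-in for a real scaled Laplacian:

```julia
using NNlib

in_ch, n, batch = 4, 6, 3
X = rand(Float32, in_ch, n, batch)
S = rand(Float32, n, n)
L̃ = (S + S') / 2                      # stand-in for GraphSignals.scaled_laplacian
W1 = rand(Float32, 8, in_ch)          # one slice of l.weight, with out = 8

Z = NNlib.batched_mul(X, L̃)           # (4, 6, 3): mixes the node axis per batch
Y = NNlib.batched_mul(W1, Z)          # (8, 6, 3): maps features to out channels
```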
+# For variable graph
+function (l::ChebConv)(fg::AbstractFeaturedGraph)
+    nf = node_feature(fg)
+    GraphSignals.check_num_nodes(fg, nf)
+    @assert size(nf, 1) == size(l.weight, 2) "Input feature size must match input channel size."
+
+    L̃ = Zygote.ignore() do
+        GraphSignals.scaled_laplacian(fg, eltype(nf))
     end
-    return Y .+ c.bias
+    return ConcreteFeaturedGraph(fg, nf = l(L̃, nf))
 end
 
-(l::ChebConv)(fg::AbstractFeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
+# For static graph
+WithGraph(fg::AbstractFeaturedGraph, l::ChebConv) =
+    WithGraph(GraphSignals.scaled_laplacian(fg, eltype(l.weight)), l)
+
+function (wg::WithGraph{<:ChebConv})(X::AbstractArray)
+    L̃ = wg.graph
+    return wg.layer(L̃, X)
+end
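As with `GCNConv`, wrapping in `WithGraph` moves the `scaled_laplacian` call out of the forward pass entirely; the variable-graph method instead keeps it inside a `Zygote.ignore` block on every call. A sketch reusing the assumed `fg` from the GCN example above:

```julia
cheb = WithGraph(fg, ChebConv(4=>8, 3, relu))  # L̃ computed once, stored in wg.graph
X = rand(Float32, 4, 4)
Y = cheb(X)                                    # dispatches to wg.layer(L̃, X)
```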
 
 function Base.show(io::IO, l::ChebConv)
     out, in, k = size(l.weight)
     print(io, "ChebConv(", in, " => ", out)
     print(io, ", k=", k)
+    l.σ == identity || print(io, ", ", l.σ)
     print(io, ")")
 end
 
@@ -192,6 +220,8 @@
 
 (l::GraphConv)(fg::AbstractFeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
 # (l::GraphConv)(fg::AbstractFeaturedGraph) = propagate(l, fg, +) # edge number check break this
+(l::GraphConv)(x::AbstractMatrix) = l(l.fg, x)
+(l::GraphConv)(::NullGraph, x::AbstractMatrix) = throw(ArgumentError("concrete FeaturedGraph is not provided."))
 
 function Base.show(io::IO, l::GraphConv)
     in_channel = size(l.weight1, ndims(l.weight1))
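These two one-liners give a layer that was built without a concrete graph a readable failure mode instead of a method error deep inside message passing; the same pair is repeated for `GATConv`, `GatedGraphConv`, `EdgeConv`, and `GINConv` below. For any such layer `l`, the error path looks like:

```julia
l(NullGraph(), rand(Float32, 4, 5))
# ERROR: ArgumentError: concrete FeaturedGraph is not provided.
```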
@@ -307,6 +337,8 @@
 
 (l::GATConv)(fg::AbstractFeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
 # (l::GATConv)(fg::AbstractFeaturedGraph) = propagate(l, fg, +) # edge number check break this
+(l::GATConv)(x::AbstractMatrix) = l(l.fg, x)
+(l::GATConv)(::NullGraph, x::AbstractMatrix) = throw(ArgumentError("concrete FeaturedGraph is not provided."))
 
 function Base.show(io::IO, l::GATConv)
     in_channel = size(l.weight, ndims(l.weight))
@@ -358,13 +390,13 @@ message(ggc::GatedGraphConv, x_i, x_j::AbstractVector, e_ij) = x_j
 update(ggc::GatedGraphConv, m::AbstractVector, x) = m
 
 
-function (ggc::GatedGraphConv)(fg::AbstractFeaturedGraph, H::AbstractMatrix{S}) where {T<:AbstractVector,S<:Real}
+function (ggc::GatedGraphConv)(fg::AbstractFeaturedGraph, H::AbstractMatrix{T}) where {T<:Real}
     GraphSignals.check_num_nodes(fg, H)
     m, n = size(H)
     @assert (m <= ggc.out_ch) "number of input features must less or equals to output features."
     if m < ggc.out_ch
         Hpad = Zygote.ignore() do
-            fill!(similar(H, S, ggc.out_ch - m, n), 0)
+            fill!(similar(H, T, ggc.out_ch - m, n), 0)
         end
         H = vcat(H, Hpad)
     end
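Besides tidying the signature (the old `where {T<:AbstractVector,S<:Real}` declared a `T` it never used), this method zero-pads `H` up to `out_ch` rows so inputs narrower than the hidden size are accepted. The padding step in isolation, in plain Julia:

```julia
H = rand(Float32, 3, 5)    # 3 input features on 5 nodes; out_ch = 8 assumed
out_ch = 8
Hpad = fill!(similar(H, Float32, out_ch - size(H, 1), size(H, 2)), 0)
H2 = vcat(H, Hpad)         # size(H2) == (8, 5); the appended rows are zeros
```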
@@ -378,6 +410,8 @@
 
 (l::GatedGraphConv)(fg::AbstractFeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
 # (l::GatedGraphConv)(fg::AbstractFeaturedGraph) = propagate(l, fg, +) # edge number check break this
+(l::GatedGraphConv)(x::AbstractMatrix) = l(l.fg, x)
+(l::GatedGraphConv)(::NullGraph, x::AbstractMatrix) = throw(ArgumentError("concrete FeaturedGraph is not provided."))
 
 
 function Base.show(io::IO, l::GatedGraphConv)
@@ -423,6 +457,8 @@
 
 (l::EdgeConv)(fg::AbstractFeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
 # (l::EdgeConv)(fg::AbstractFeaturedGraph) = propagate(l, fg, l.aggr) # edge number check break this
+(l::EdgeConv)(x::AbstractMatrix) = l(l.fg, x)
+(l::EdgeConv)(::NullGraph, x::AbstractMatrix) = throw(ArgumentError("concrete FeaturedGraph is not provided."))
 
 function Base.show(io::IO, l::EdgeConv)
     print(io, "EdgeConv(", l.nn)
@@ -475,6 +511,8 @@
 
 (l::GINConv)(fg::AbstractFeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
 # (l::GINConv)(fg::AbstractFeaturedGraph) = propagate(l, fg, +) # edge number check break this
+(l::GINConv)(x::AbstractMatrix) = l(l.fg, x)
+(l::GINConv)(::NullGraph, x::AbstractMatrix) = throw(ArgumentError("concrete FeaturedGraph is not provided."))
 
 
 """