@@ -23,7 +23,7 @@ EEquivGraphConv(ϕ_edge=Dense(10 => 5), ϕ_x=Dense(5 => 2), ϕ_h=Dense(8 => 5))
23
23
24
24
See also [`WithGraph`](@ref) for training layer with static graph and [`EEquivGraphPE`](@ref) for positional encoding.
25
25
"""
26
- struct EEquivGraphConv{X,E,H}
26
+ struct EEquivGraphConv{X,E,H} <: AbstractGraphLayer
27
27
pe:: X
28
28
nn_edge:: E
29
29
nn_h:: H
@@ -61,6 +61,15 @@ function(egnn::EEquivGraphConv)(fg::AbstractFeaturedGraph)
61
61
return ConcreteFeaturedGraph (fg, nf= V, pf= X)
62
62
end
63
63
64
+ # For static graph
65
+ function (l:: EEquivGraphConv )(el:: NamedTuple , H:: AbstractArray , E:: AbstractArray , X:: AbstractArray )
66
+ GraphSignals. check_num_nodes (el. N, H)
67
+ GraphSignals. check_num_nodes (el. N, X)
68
+ GraphSignals. check_num_edges (el. E, E)
69
+ _, V, X = propagate (l, el, E, H, X, + )
70
+ return V, X
71
+ end
72
+
64
73
function Base. show (io:: IO , l:: EEquivGraphConv )
65
74
print (io, " EEquivGraphConv(ϕ_edge=" , l. nn_edge)
66
75
print (io, " , ϕ_x=" , l. pe. nn)
@@ -96,3 +105,6 @@ function propagate(l::EEquivGraphConv, el::NamedTuple, E, V, X, aggr)
96
105
V = update (l, Ē, V)
97
106
return E, V, X
98
107
end
108
+
109
+ WithGraph (fg:: AbstractFeaturedGraph , l:: EEquivGraphConv ) = WithGraph (to_namedtuple (fg), l)
110
+ (wg:: WithGraph{<:EEquivGraphConv} )(args... ) = wg. layer (wg. graph, args... )
0 commit comments