Skip to content

Commit 4dc9ed0

Browse files
committed
support batch
1 parent ac71234 commit 4dc9ed0

File tree

2 files changed

+26
-1
lines changed

2 files changed

+26
-1
lines changed

src/layers/group_conv.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ EEquivGraphConv(ϕ_edge=Dense(10 => 5), ϕ_x=Dense(5 => 2), ϕ_h=Dense(8 => 5))
2323
2424
See also [`WithGraph`](@ref) for training layer with static graph and [`EEquivGraphPE`](@ref) for positional encoding.
2525
"""
26-
struct EEquivGraphConv{X,E,H}
26+
struct EEquivGraphConv{X,E,H} <: AbstractGraphLayer
2727
pe::X
2828
nn_edge::E
2929
nn_h::H
@@ -61,6 +61,15 @@ function(egnn::EEquivGraphConv)(fg::AbstractFeaturedGraph)
6161
return ConcreteFeaturedGraph(fg, nf=V, pf=X)
6262
end
6363

64+
# For static graph
65+
function(l::EEquivGraphConv)(el::NamedTuple, H::AbstractArray, E::AbstractArray, X::AbstractArray)
66+
GraphSignals.check_num_nodes(el.N, H)
67+
GraphSignals.check_num_nodes(el.N, X)
68+
GraphSignals.check_num_edges(el.E, E)
69+
_, V, X = propagate(l, el, E, H, X, +)
70+
return V, X
71+
end
72+
6473
function Base.show(io::IO, l::EEquivGraphConv)
6574
print(io, "EEquivGraphConv(ϕ_edge=", l.nn_edge)
6675
print(io, ", ϕ_x=", l.pe.nn)
@@ -96,3 +105,6 @@ function propagate(l::EEquivGraphConv, el::NamedTuple, E, V, X, aggr)
96105
V = update(l, Ē, V)
97106
return E, V, X
98107
end
108+
109+
WithGraph(fg::AbstractFeaturedGraph, l::EEquivGraphConv) = WithGraph(to_namedtuple(fg), l)
110+
(wg::WithGraph{<:EEquivGraphConv})(args...) = wg.layer(wg.graph, args...)

test/layers/group_conv.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,5 +32,18 @@
3232
g = Zygote.gradient(() -> sum(node_feature(egnn(fg))), Flux.params(egnn))
3333
@test length(g.grads) == 8
3434
end
35+
36+
@testset "layer with static graph" begin
37+
nf = rand(T, in_channel, N, batch_size)
38+
ef = rand(T, in_channel_edge, E, batch_size)
39+
pf = rand(T, pos_dim, N, batch_size)
40+
l = WithGraph(fg, EEquivGraphConv(in_channel=>out_channel, pos_dim, in_channel_edge))
41+
H, Y = l(nf, ef, pf)
42+
@test size(H) == (out_channel, N, batch_size)
43+
@test size(Y) == (pos_dim, N, batch_size)
44+
45+
g = Zygote.gradient(() -> sum(l(nf, ef, pf)[1]), Flux.params(l))
46+
@test length(g.grads) == 6
47+
end
3548
end
3649
end

0 commit comments

Comments
 (0)