
Commit c4ea40a

API design fix
1 parent 63ce513 commit c4ea40a

File tree

6 files changed, +62 -46 lines changed


src/GeometricFlux.jl

Lines changed: 5 additions & 3 deletions
@@ -31,7 +31,7 @@ export
     # layers/msgpass
     MessagePassing,

-    # layers/conv
+    # layers/graph_conv
     GCNConv,
     ChebConv,
     GraphConv,
@@ -44,6 +44,8 @@ export
     SAGEConv,
     MeanAggregator, MeanPoolAggregator, MaxPoolAggregator,
     LSTMAggregator,
+
+    # layers/group_conv
     EEquivGraphConv,

     # layer/pool
@@ -73,8 +75,8 @@ include("layers/graphlayers.jl")
 include("layers/gn.jl")
 include("layers/msgpass.jl")

-include("layers/conv.jl")
-include("layers/groups.jl")
+include("layers/graph_conv.jl")
+include("layers/group_conv.jl")
 include("layers/pool.jl")
 include("models.jl")

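Only file names and comment grouping change in this file; the exported layer names are untouched, so a downstream import such as the following (illustrative) still resolves the same symbols:

    using GeometricFlux: GCNConv, EEquivGraphConv  # same exports as before; the layers now
                                                   # live in layers/graph_conv.jl and layers/group_conv.jl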
src/layers/conv.jl renamed to src/layers/graph_conv.jl

File renamed without changes.

src/layers/groups.jl renamed to src/layers/group_conv.jl

Lines changed: 24 additions & 30 deletions
@@ -1,45 +1,39 @@
 """
     EEquivGraphConv((in_dim, int_dim, out_dim); init)
 """
-
-struct EEquivGraphConv <: MessagePassing
-    nn_edge
-    nn_x
-    nn_h
-
-    in_dim
-    int_dim
-    out_dim
+struct EEquivGraphConv{E,X,H} <: MessagePassing
+    nn_edge::E
+    nn_x::X
+    nn_h::H
+    in_dim::Int
+    int_dim::Int
+    out_dim::Int
 end

 @functor EEquivGraphConv

-function EEquivGraphConv(dims::NTuple{3,Int}; init=glorot_uniform)
-    in_dim, int_dim, out_dim = dims
-
-    m_len = in_dim * 2 + 2
-
+function EEquivGraphConv(in_dim::Int, int_dim::Int, out_dim::Int; init=glorot_uniform)
+    m_len = 2in_dim + 2
     nn_edge = Flux.Dense(m_len, int_dim; init=init)
-
     nn_x = Flux.Dense(int_dim, 1; init=init)
     nn_h = Flux.Dense(in_dim + int_dim, out_dim; init=init)
-
-    return EEquivGraphConv(nn_edge, nn_x, nn_h, dims...)
+    return EEquivGraphConv(nn_edge, nn_x, nn_h, in_dim, int_dim, out_dim)
 end

-function EEquivGraphConv(nn_edge, nn_x, nn_h; init=glorot_uniform)
-
-    # Assume that these are strictly MLPs (no conv)
-    nn_edge.init(init)
-    nn_x.init(init)
-    nn_h.init(init)
-
-    in_dim = nn_edge.layers[1].W |> x->size(x)[2]
-    int_dim = nn_edge.layers[end].W |> x->size(x)[1]
-    out_dim = nn_h.layers[end].W |> x->size(x)[1]
+function EEquivGraphConv(in_dim::Int, nn_edge, nn_x, nn_h)
+    m_len = 2in_dim + 2
+    int_dim = Flux.outputsize(nn_edge, (m_len, 2))[1]
+    out_dim = Flux.outputsize(nn_h, (in_dim + int_dim, 2))[1]
     return EEquivGraphConv(nn_edge, nn_x, nn_h, in_dim, int_dim, out_dim)
 end

+function ϕ_edge(egnn::EEquivGraphConv, h_i, h_j, dist, a)
+    N = size(h_i, 2)
+    return egnn.nn_edge(vcat(h_i, h_j, dist, ones(N)' * a))
+end
+
+ϕ_x(egnn::EEquivGraphConv, m_ij) = egnn.nn_x(m_ij)
+
 function message(egnn::EEquivGraphConv, v_i, v_j, e)
     in_dim = egnn.in_dim
     h_i = v_i[1:in_dim,:]
@@ -56,9 +50,9 @@ function message(egnn::EEquivGraphConv, v_i, v_j, e)
         a = e[1]
     end

-    input = vcat(h_i, h_j, sum(abs2.(x_i - x_j); dims=1), ones(N)' * a)
-    edge_msg = egnn.nn_edge(input)
-    output_vec = vcat(edge_msg, (x_i - x_j) .* egnn.nn_x(edge_msg)[1], ones(N)')
+    dist = sum(abs2.(x_i - x_j); dims=1)
+    edge_msg = ϕ_edge(egnn, h_i, h_j, dist, a)
+    output_vec = vcat(edge_msg, (x_i - x_j) .* ϕ_x(egnn, edge_msg)[1], ones(N)')
     return reshape(output_vec, :, N)
 end

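Taken together, the refactor replaces the tuple-based constructor and the MLP-introspecting constructor with two explicit entry points: EEquivGraphConv(in_dim, int_dim, out_dim; init) builds the three Flux.Dense networks internally, while EEquivGraphConv(in_dim, nn_edge, nn_x, nn_h) takes user-supplied networks and infers int_dim and out_dim with Flux.outputsize. A minimal usage sketch based on this diff (the dimensions are illustrative, matching the new test):

    using Flux, GeometricFlux

    in_dim, int_dim, out_dim = 3, 5, 5

    # Entry point 1: the layer builds its own Dense networks (init defaults to glorot_uniform).
    egnn = EEquivGraphConv(in_dim, int_dim, out_dim)

    # Entry point 2: supply the networks yourself; int_dim and out_dim are inferred.
    m_len = 2in_dim + 2                      # h_i and h_j (in_dim each) + squared distance + edge weight
    nn_edge = Flux.Dense(m_len, int_dim)
    nn_x    = Flux.Dense(int_dim, 1)
    nn_h    = Flux.Dense(in_dim + int_dim, out_dim)
    egnn2   = EEquivGraphConv(in_dim, nn_edge, nn_x, nn_h)

Either layer is then applied to a FeaturedGraph whose node feature matrix stacks the in_dim features with the 3 coordinate rows, as in test/layers/group_conv.jl (nf = rand(Float32, in_dim + 3, N)).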
test/layers/conv.jl renamed to test/layers/graph_conv.jl

Lines changed: 1 addition & 12 deletions
@@ -1,4 +1,4 @@
-@testset "layer" begin
+@testset "graph conv" begin
     T = Float32
     batch_size = 10
     in_channel = 3
@@ -424,15 +424,4 @@
             end
         end
     end
-
-    @testset "EEquivGraphConv" begin
-        @testset "layer without static graph" begin
-            int_dim = 5
-            egnn = EEquivGraphConv((in_channel, int_dim, out_channel))
-
-            nf = rand(T, in_channel + 3, N)
-            fg = FeaturedGraph(adj, nf=nf)
-            fg_ = egnn(fg)
-        end
-    end
 end

test/layers/group_conv.jl

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
+@testset "group conv" begin
+    T = Float32
+    batch_size = 10
+    in_channel = 3
+    out_channel = 5
+
+    N = 4
+    E = 4
+    adj = T[0. 1. 0. 1.;
+            1. 0. 1. 0.;
+            0. 1. 0. 1.;
+            1. 0. 1. 0.]
+    fg = FeaturedGraph(adj)
+
+    @testset "EEquivGraphConv" begin
+        @testset "layer without static graph" begin
+            int_dim = 5
+            m_len = in_channel * 2 + 2
+
+            nn_edge = Flux.Dense(m_len, int_dim)
+            nn_x = Flux.Dense(int_dim, 1)
+            nn_h = Flux.Dense(in_channel + int_dim, out_channel)
+            egnn = EEquivGraphConv(in_channel, nn_edge, nn_x, nn_h)
+
+            nf = rand(T, in_channel + 3, N)
+            fg = FeaturedGraph(adj, nf=nf)
+            fg_ = egnn(fg)
+        end
+    end
+end

test/runtests.jl

Lines changed: 2 additions & 1 deletion
@@ -22,7 +22,8 @@ cuda_tests = [
 tests = [
     "layers/gn",
     "layers/msgpass",
-    "layers/conv",
+    "layers/graph_conv",
+    "layers/group_conv",
     "layers/pool",
     "layers/graphlayers",
     "sampling",

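With these entries updated, the renamed test files run with the normal package test suite; for example (assuming GeometricFlux is the active project, and that runtests.jl includes each entry of the tests list, as is conventional):

    using Pkg
    Pkg.test("GeometricFlux")   # executes test/runtests.jl, which picks up layers/graph_conv and layers/group_conv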