Skip to content

Commit 9e2c938

Browse files
committed
fix
1 parent 76545f4 commit 9e2c938

File tree

4 files changed

+10
-9
lines changed

4 files changed

+10
-9
lines changed

docs/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
[deps]
22
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
33
DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
4+
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
45

56
[compat]
67
Documenter = "0.27"

src/layers/graph_conv.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,7 @@ function (l::GATConv)(fg::AbstractFeaturedGraph)
336336
GraphSignals.check_num_nodes(fg, X)
337337
sg = graph(fg)
338338
@assert ChainRulesCore.ignore_derivatives(() -> GraphSignals.has_all_self_loops(sg)) "a vertex must have self loop (receive a message from itself)."
339-
el = to_namedtuple(sg)
339+
el = GraphSignals.to_namedtuple(sg)
340340
_, V, _ = propagate(l, el, nothing, X, nothing, hcat, nothing, nothing)
341341
return ConcreteFeaturedGraph(fg, nf=V)
342342
end
@@ -464,7 +464,7 @@ function (l::GATv2Conv)(fg::AbstractFeaturedGraph)
464464
GraphSignals.check_num_nodes(fg, X)
465465
sg = graph(fg)
466466
@assert ChainRulesCore.ignore_derivatives(() -> GraphSignals.has_all_self_loops(sg)) "a vertex must have self loop (receive a message from itself)."
467-
el = to_namedtuple(sg)
467+
el = GraphSignals.to_namedtuple(sg)
468468
_, V, _ = propagate(l, el, nothing, X, nothing, hcat, nothing, nothing)
469469
return ConcreteFeaturedGraph(fg, nf=V)
470470
end
@@ -540,7 +540,7 @@ update(ggc::GatedGraphConv, m::AbstractArray, x) = m
540540
function (l::GatedGraphConv)(fg::AbstractFeaturedGraph)
541541
nf = node_feature(fg)
542542
GraphSignals.check_num_nodes(fg, nf)
543-
V = l(GraphSignals.to_namedtuple(fg), nf)
543+
V = l(GraphSignals.GraphSignals.to_namedtuple(fg), nf)
544544
return ConcreteFeaturedGraph(fg, nf=V)
545545
end
546546

test/layers/group_conv.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
@test size(nf_) == (out_channel, N)
2929
@test size(pf_) == (pos_dim, N)
3030

31-
g = Zygote.gradient(() -> sum(node_feature(egnn(fg))), Flux.params(egnn))
31+
g = gradient(() -> sum(node_feature(egnn(fg))), Flux.params(egnn))
3232
@test length(g.grads) == 13
3333
end
3434

@@ -45,7 +45,7 @@
4545
@test size(nf_) == (out_channel, N)
4646
@test size(pf_) == (pos_dim, N)
4747

48-
g = Zygote.gradient(() -> sum(node_feature(egnn(fg))), Flux.params(egnn))
48+
g = gradient(() -> sum(node_feature(egnn(fg))), Flux.params(egnn))
4949
@test length(g.grads) == 13
5050
end
5151

@@ -58,7 +58,7 @@
5858
@test size(H) == (out_channel, N, batch_size)
5959
@test size(Y) == (pos_dim, N, batch_size)
6060

61-
g = Zygote.gradient(() -> sum(l(nf, ef)[1]), Flux.params(l))
61+
g = gradient(() -> sum(l(nf, ef)[1]), Flux.params(l))
6262
@test length(g.grads) == 11
6363
end
6464

@@ -70,7 +70,7 @@
7070
@test size(H) == (out_channel, N, batch_size)
7171
@test size(Y) == (pos_dim, N, batch_size)
7272

73-
g = Zygote.gradient(() -> sum(l(nf)[1]), Flux.params(l))
73+
g = gradient(() -> sum(l(nf)[1]), Flux.params(l))
7474
@test length(g.grads) == 11
7575
end
7676
end

test/layers/positional.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
fg_ = l(fg)
2424
@test size(node_feature(fg_)) == (out_channel, N)
2525

26-
g = Zygote.gradient(() -> sum(node_feature(l(fg))), Flux.params(l))
26+
g = gradient(() -> sum(node_feature(l(fg))), Flux.params(l))
2727
@test length(g.grads) == 4
2828
end
2929

@@ -34,7 +34,7 @@
3434
Y = l(nf, ef)
3535
@test size(Y) == (out_channel, N, batch_size)
3636

37-
g = Zygote.gradient(() -> sum(l(nf, ef)), Flux.params(l))
37+
g = gradient(() -> sum(l(nf, ef)), Flux.params(l))
3838
@test length(g.grads) == 2
3939
end
4040
end

0 commit comments

Comments
 (0)