Skip to content

Commit 491ad51

Browse files
authored
Merge pull request #341 from FluxML/refactor
Refactor and make internal functions internal APIs
2 parents 6d2d9b1 + e96217e commit 491ad51

File tree

6 files changed

+63
-88
lines changed

6 files changed

+63
-88
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "GeometricFlux"
22
uuid = "7e08b658-56d3-11e9-2997-919d5b31e4ea"
33
authors = ["Yueh-Hua Tu <[email protected]>"]
4-
version = "0.13.6"
4+
version = "0.14.0"
55

66
[deps]
77
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

src/layers/gn.jl

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,9 @@ See also [`update_edge`](@ref), [`update_vertex`](@ref), [`update_global`](@ref)
7575
update_batch_edge(gn::GraphNet, el::NamedTuple, E, V, u) =
7676
update_edge(
7777
gn,
78-
_gather(E, el.es),
79-
_gather(V, el.xs),
80-
_gather(V, el.nbrs),
78+
gather(E, el.es),
79+
gather(V, el.xs),
80+
gather(V, el.nbrs),
8181
u
8282
)
8383

@@ -116,17 +116,11 @@ See also [`update_edge`](@ref), [`update_vertex`](@ref), [`update_global`](@ref)
116116
[`update_batch_edge`](@ref), [`update_batch_vertex`](@ref), [`aggregate_edges`](@ref),
117117
[`aggregate_vertices`](@ref).
118118
"""
119-
function aggregate_neighbors(::GraphNet, el::NamedTuple, aggr, E)
120-
batch_size = size(E)[end]
121-
dstsize = (size(E, 1), el.N, batch_size)
122-
xs = batched_index(el.xs, batch_size)
123-
return _scatter(aggr, E, xs, dstsize)
124-
end
125-
126-
aggregate_neighbors(::GraphNet, el::NamedTuple, aggr, E::AbstractMatrix) = _scatter(aggr, E, el.xs)
119+
aggregate_neighbors(::GraphNet, el::NamedTuple, aggr, E) = scatter(aggr, E, el.xs, el.N)
120+
aggregate_neighbors(::GraphNet, el::NamedTuple, aggr, E::AbstractMatrix) = scatter(aggr, E, el.xs)
127121

128-
@inline aggregate_neighbors(::GraphNet, ::NamedTuple, ::Nothing, E) = nothing
129-
@inline aggregate_neighbors(::GraphNet, ::NamedTuple, ::Nothing, ::AbstractMatrix) = nothing
122+
aggregate_neighbors(::GraphNet, ::NamedTuple, ::Nothing, E) = nothing
123+
aggregate_neighbors(::GraphNet, ::NamedTuple, ::Nothing, ::AbstractMatrix) = nothing
130124

131125
"""
132126
aggregate_edges(gn, aggr, E)

src/layers/graph_conv.jl

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ end
3838
@functor GCNConv
3939

4040
function (l::GCNConv)(Ã::AbstractMatrix, X::AbstractArray)
41-
z = _matmul(l.weight, _matmul(X, Ã))
41+
z = matmul(l.weight, matmul(X, Ã))
4242
return l.σ.(z .+ l.bias)
4343
end
4444

@@ -113,12 +113,12 @@ Flux.trainable(l::ChebConv) = (l.weight, l.bias)
113113

114114
function (l::ChebConv)(L̃::AbstractMatrix, X::AbstractArray)
115115
Z_prev = X
116-
Z = _matmul(X, L̃)
117-
Y = _matmul(view(l.weight,:,:,1), Z_prev)
118-
Y += _matmul(view(l.weight,:,:,2), Z)
116+
Z = matmul(X, L̃)
117+
Y = matmul(view(l.weight,:,:,1), Z_prev)
118+
Y += matmul(view(l.weight,:,:,2), Z)
119119
for k = 3:l.k
120-
Z, Z_prev = 2 .* _matmul(Z, L̃) .- Z_prev, Z
121-
Y += _matmul(view(l.weight,:,:,k), Z)
120+
Z, Z_prev = 2 .* matmul(Z, L̃) .- Z_prev, Z
121+
Y += matmul(view(l.weight,:,:,k), Z)
122122
end
123123
return l.σ.(Y .+ l.bias)
124124
end
@@ -203,9 +203,9 @@ end
203203

204204
Flux.trainable(l::GraphConv) = (l.weight1, l.weight2, l.bias)
205205

206-
message(gc::GraphConv, x_i, x_j::AbstractArray, e_ij) = _matmul(gc.weight2, x_j)
206+
message(gc::GraphConv, x_i, x_j::AbstractArray, e_ij) = matmul(gc.weight2, x_j)
207207

208-
update(gc::GraphConv, m::AbstractArray, x::AbstractArray) = gc.σ.(_matmul(gc.weight1, x) .+ m .+ gc.bias)
208+
update(gc::GraphConv, m::AbstractArray, x::AbstractArray) = gc.σ.(matmul(gc.weight1, x) .+ m .+ gc.bias)
209209

210210
# For variable graph
211211
function (l::GraphConv)(fg::AbstractFeaturedGraph)
@@ -299,7 +299,7 @@ function update_batch_edge(gat::GATConv, el::NamedTuple, E, X::AbstractMatrix, u
299299
end
300300

301301
function update_batch_edge(gat::GATConv, el::NamedTuple, E, X::AbstractArray, u)
302-
Xi, Xj = _gather(X, el.xs), _gather(X, el.nbrs)
302+
Xi, Xj = gather(X, el.xs), gather(X, el.nbrs)
303303
_, nb, bch_sz = size(Xj)
304304
heads = gat.heads
305305
Q = reshape(NNlib.batched_mul(gat.weight, Xi), :, heads, nb, bch_sz) # dims: (out, heads, nb, bch_sz)
@@ -429,7 +429,7 @@ function update_batch_edge(gat::GATv2Conv, el::NamedTuple, E, X::AbstractMatrix,
429429
end
430430

431431
function update_batch_edge(gat::GATv2Conv, el::NamedTuple, E, X::AbstractArray, u)
432-
Xi, Xj = _gather(X, el.xs), _gather(X, el.nbrs)
432+
Xi, Xj = gather(X, el.xs), gather(X, el.nbrs)
433433
_, nb, bch_sz = size(Xj)
434434
heads = gat.heads
435435
Q = reshape(NNlib.batched_mul(gat.wi, Xi) .+ gat.biasi, :, heads, nb, bch_sz) # dims: (out, heads, nb, bch_sz)
@@ -556,7 +556,7 @@ function (l::GatedGraphConv)(el::NamedTuple, H::AbstractArray{T}) where {T<:Real
556556
H = vcat(H, Hpad)
557557
end
558558
for i = 1:l.num_layers
559-
M = _matmul(view(l.weight, :, :, i), H)
559+
M = matmul(view(l.weight, :, :, i), H)
560560
_, M = propagate(l, el, nothing, M, nothing, l.aggr, nothing, nothing)
561561
H, _ = l.gru(H, M)
562562
end
@@ -822,20 +822,17 @@ end
822822
message(l::SAGEConv, x_i, x_j::AbstractArray, e) = l.proj(x_j)
823823

824824
function aggregate_neighbors(l::SAGEConv, el::NamedTuple, aggr, E)
825-
batch_size = size(E)[end]
826825
sample_idx = sample_node_index(E, l.num_sample; dims=2)
827-
idx = ntuple(i -> (i == 2) ? sample_idx : Colon(), ndims(E))
828-
dstsize = (size(E, 1), el.N, batch_size) # ensure outcome has the same dimension as x in update
829-
xs = batched_index(el.xs[sample_idx], batch_size)
830-
Ē = _scatter(aggr, E[idx...], xs, dstsize)
826+
indexed_E = selectdim(E, 2, sample_idx)
827+
Ē = scatter(aggr, indexed_E, el.xs[sample_idx], el.N)
831828
return Ē
832829
end
833830

834831
function aggregate_neighbors(l::SAGEConv, el::NamedTuple, aggr, E::AbstractMatrix)
835832
sample_idx = sample_node_index(E, l.num_sample; dims=2)
836-
idx = ntuple(i -> (i == 2) ? sample_idx : Colon(), ndims(E))
833+
indexed_E = selectdim(E, 2, sample_idx)
837834
dstsize = (size(E, 1), el.N) # ensure outcome has the same dimension as x in update
838-
Ē = _scatter(aggr, E[idx...], el.xs[sample_idx], dstsize)
835+
Ē = NNlib.scatter(aggr, indexed_E, el.xs[sample_idx]; dstsize=dstsize)
839836
return Ē
840837
end
841838

@@ -844,13 +841,13 @@ aggregate_neighbors(::SAGEConv, el::NamedTuple, lstm::Flux.LSTMCell, E::Abstract
844841

845842
function aggregate_neighbors(::SAGEConv, el::NamedTuple, lstm::Flux.LSTMCell, E::AbstractMatrix)
846843
sample_idx = sample_node_index(E, el.N; dims=2)
847-
idx = ntuple(i -> (i == 2) ? sample_idx : Colon(), ndims(E))
848-
state, Ē = lstm(lstm.state0, E[idx...])
844+
indexed_E = selectdim(E, 2, sample_idx)
845+
state, Ē = lstm(lstm.state0, indexed_E)
849846
return Ē
850847
end
851848

852849
function update(l::SAGEConv, m::AbstractArray, x::AbstractArray)
853-
y = l.σ.(_matmul(l.weight1, x) + _matmul(l.weight2, m) .+ l.bias)
850+
y = l.σ.(matmul(l.weight1, x) + matmul(l.weight2, m) .+ l.bias)
854851
l.normalize && (y = l2normalize(y; dims=2)) # across all nodes
855852
return y
856853
end

src/layers/group_conv.jl

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -104,26 +104,20 @@ function Base.show(io::IO, l::EEquivGraphConv)
104104
print(io, ")")
105105
end
106106

107-
function aggregate_neighbors(::EEquivGraphConv, el::NamedTuple, aggr, E)
108-
batch_size = size(E)[end]
109-
dstsize = (size(E, 1), el.N, batch_size)
110-
xs = batched_index(el.xs, batch_size)
111-
return _scatter(aggr, E, xs, dstsize)
112-
end
113-
114-
aggregate_neighbors(::EEquivGraphConv, el::NamedTuple, aggr, E::AbstractMatrix) = _scatter(aggr, E, el.xs)
107+
aggregate_neighbors(::EEquivGraphConv, el::NamedTuple, aggr, E) = scatter(aggr, E, el.xs, el.N)
108+
aggregate_neighbors(::EEquivGraphConv, el::NamedTuple, aggr, E::AbstractMatrix) = scatter(aggr, E, el.xs)
115109

116-
@inline aggregate_neighbors(::EEquivGraphConv, ::NamedTuple, ::Nothing, E) = nothing
117-
@inline aggregate_neighbors(::EEquivGraphConv, ::NamedTuple, ::Nothing, ::AbstractMatrix) = nothing
110+
aggregate_neighbors(::EEquivGraphConv, ::NamedTuple, ::Nothing, E) = nothing
111+
aggregate_neighbors(::EEquivGraphConv, ::NamedTuple, ::Nothing, ::AbstractMatrix) = nothing
118112

119113
propagate(l::EEquivGraphConv, sg::SparseGraph, E, V, X, aggr) =
120114
propagate(l, GraphSignals.to_namedtuple(sg), E, V, X, aggr)
121115

122116
function propagate(l::EEquivGraphConv, el::NamedTuple, E, V, X, aggr)
123117
E = message(
124-
l, _gather(V, el.xs), _gather(V, el.nbrs),
125-
_gather(X, el.xs), _gather(X, el.nbrs),
126-
_gather(E, el.es)
118+
l, gather(V, el.xs), gather(V, el.nbrs),
119+
gather(X, el.xs), gather(X, el.nbrs),
120+
gather(E, el.es)
127121
)
128122
X = positional_encode(l.pe, el, X, E)
129123
Ē = aggregate_neighbors(l, el, aggr, E)

src/layers/positional.jl

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -275,20 +275,13 @@ update_edge(l::GatedGCNLSPEConv, h_i, h_j, e_ij) = σ.(l.B1(h_i) + l.B2(h_j) + l
275275

276276
function normalize_η(l::GatedGCNLSPEConv, el::NamedTuple, η̂)
277277
summed_η = aggregate_neighbors(l, el, +, η̂)
278-
return η̂ ./ _gather(summed_η .+ 1f-6, el.xs)
278+
return η̂ ./ gather(summed_η .+ 1f-6, el.xs)
279279
end
280280

281281
message_vertex(l::GatedGCNLSPEConv, h_j, p_j, η_ij) = η_ij .* l.A2(vcat(h_j, p_j))
282282

283-
function aggregate_neighbors(l::GatedGCNLSPEConv, el::NamedTuple, aggr, E)
284-
batch_size = size(E)[end]
285-
dstsize = (size(E, 1), el.N, batch_size)
286-
xs = batched_index(el.xs, batch_size)
287-
return _scatter(aggr, E, xs, dstsize)
288-
end
289-
290-
aggregate_neighbors(l::GatedGCNLSPEConv, el::NamedTuple, aggr, E::AbstractMatrix) =
291-
_scatter(aggr, E, el.xs)
283+
aggregate_neighbors(l::GatedGCNLSPEConv, el::NamedTuple, aggr, E) = scatter(aggr, E, el.xs, el.N)
284+
aggregate_neighbors(l::GatedGCNLSPEConv, el::NamedTuple, aggr, E::AbstractMatrix) = scatter(aggr, E, el.xs)
292285

293286
update_vertex(l::GatedGCNLSPEConv, m, h, p) = l.σ.(l.A1(vcat(h, p)) + m)
294287

@@ -300,11 +293,11 @@ propagate(l::GatedGCNLSPEConv, sg::SparseGraph, E, H, X) =
300293
propagate(l, GraphSignals.to_namedtuple(sg), E, H, X)
301294

302295
function propagate(l::GatedGCNLSPEConv, el::NamedTuple, E, H, X)
303-
e_ij = _gather(E, el.es)
304-
h_i = _gather(H, el.xs)
305-
h_j = _gather(H, el.nbrs)
306-
p_i = _gather(X, el.xs)
307-
p_j = _gather(X, el.nbrs)
296+
e_ij = gather(E, el.es)
297+
h_i = gather(H, el.xs)
298+
h_j = gather(H, el.nbrs)
299+
p_i = gather(X, el.xs)
300+
p_j = gather(X, el.nbrs)
308301

309302
η̂ = update_edge(l, h_i, h_j, e_ij)
310303
= l.σ.(η̂)

src/operation.jl

Lines changed: 22 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,19 @@
1-
_gather(::Nothing, idx) = nothing
2-
_gather(A::Fill{T,2,Axes}, idx) where {T,Axes} = fill(A.value, A.axes[1], length(idx))
3-
_gather(A::AbstractMatrix, idx) = NNlib.gather(A, idx)
4-
_gather(A::AbstractArray, idx) = NNlib.gather(A, batched_index(idx, size(A)[end]))
1+
gather(::Nothing, idx) = nothing
2+
gather(A::Fill{T,2,Axes}, idx) where {T,Axes} = fill(A.value, A.axes[1], length(idx))
3+
gather(A::AbstractMatrix, idx) = NNlib.gather(A, idx)
4+
gather(A::AbstractArray, idx) = NNlib.gather(A, batched_index(idx, size(A)[end]))
55

6-
_scatter(aggr, E, xs::AbstractArray) = NNlib.scatter(aggr, E, xs)
7-
_scatter(aggr, E, xs::AbstractArray, dstsize) = NNlib.scatter(aggr, E, xs; dstsize=dstsize)
6+
scatter(aggr, E, xs::AbstractArray) = NNlib.scatter(aggr, E, xs)
87

9-
_matmul(A::AbstractMatrix, B::AbstractMatrix) = A * B
10-
_matmul(A::AbstractArray, B::AbstractArray) = NNlib.batched_mul(A, B)
8+
function scatter(aggr, E, xs::AbstractArray, N::Int)
9+
dim, batch_size = size(E)[1], size(E)[end]
10+
dstsize = (dim, N, batch_size)
11+
batched_xs = batched_index(xs, batch_size)
12+
return NNlib.scatter(aggr, E, batched_xs; dstsize=dstsize)
13+
end
14+
15+
matmul(A::AbstractMatrix, B::AbstractMatrix) = A * B
16+
matmul(A::AbstractArray, B::AbstractArray) = NNlib.batched_mul(A, B)
1117

1218
function batched_index(idx::AbstractVector, batch_size::Integer)
1319
b = copyto!(similar(idx, 1, batch_size), collect(1:batch_size))
@@ -36,29 +42,20 @@ end
3642
function indexed_softmax(x::AbstractArray, xs, N; dims=1)
3743
y = copy(x)
3844
for i in 1:N
39-
idx = ntuple(j -> (j == dims) ? (xs .== i) : Colon(), ndims(y))
40-
NNlib.softmax!(view(y, idx...); dims)
45+
indexed_y = selectdim(y, dims, (xs .== i))
46+
NNlib.softmax!(indexed_y; dims)
4147
end
4248
return y
4349
end
4450

4551
function ∇indexed_softmax(dy::AbstractArray{T}, y::AbstractArray{S}, xs, N; dims=1) where {T,S}
46-
dx = if NNlib.within_grad()
47-
tmp = dy .* y
48-
for i in 1:N
49-
idx = ntuple(j -> (j == dims) ? (xs .== i) : Colon(), ndims(y))
50-
tmp[idx...] .= tmp[idx...] .- y[idx...] .* sum(tmp[idx...]; dims)
51-
end
52-
tmp
53-
else
54-
out = similar(y, promote_type(T,S))
55-
out .= dy .* y
56-
for i in 1:N
57-
idx = ntuple(j -> (j == dims) ? (xs .== i) : Colon(), ndims(y))
58-
out[idx...] .= out[idx...] .- y[idx...] .* sum(out[idx...]; dims)
59-
end
60-
out
52+
out = NNlib.∇softmax_data(dy, y; dims)
53+
for i in 1:N
54+
indexed_y = selectdim(y, dims, (xs .== i))
55+
indexed_out = selectdim(out, dims, (xs .== i))
56+
indexed_out .= indexed_out .- indexed_y .* sum(indexed_out; dims)
6157
end
58+
return out
6259
end
6360

6461
function ChainRulesCore.rrule(::typeof(indexed_softmax), x, xs, N; dims=1)

0 commit comments

Comments
 (0)