Skip to content

Commit e30d800

Browse files
committed
add DeepSet model
1 parent 66c9b9a commit e30d800

File tree

6 files changed

+89
-7
lines changed

6 files changed

+89
-7
lines changed

docs/src/manual/models.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,22 @@ Reference: [Variational Graph Auto-Encoders](https://arxiv.org/abs/1611.07308)
3737

3838
---
3939

40+
## DeepSet
41+
42+
```math
43+
Z = \rho ( \sum_{x_i \in \mathcal{V}} \phi (x_i) )
44+
```
45+
46+
where ``\phi`` and ``\rho`` denote two neural networks and ``x_i`` is the node feature for node ``i``.
47+
48+
```@docs
49+
GeometricFlux.DeepSet
50+
```
51+
52+
Reference: [Deep Sets](https://papers.nips.cc/paper/2017/hash/f22e4747da1aa27e363d86d40ff442fe-Abstract.html)
53+
54+
---
55+
4056
## Special Layers
4157

4258
### Inner-product Decoder

examples/deepsets.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
mae loss
2+
3+
Adam(lr=1e-4)

src/GeometricFlux.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ export
5151
VGAE,
5252
InnerProductDecoder,
5353
VariationalGraphEncoder,
54+
DeepSet,
5455

5556
# layer/utils
5657
WithGraph,

src/models.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,3 +146,38 @@ WithGraph(fg::AbstractFeaturedGraph, l::VariationalGraphEncoder) =
146146
WithGraph(fg, l.logσ),
147147
l.z_dim
148148
)
149+
150+
151+
struct DeepSet{T,S} <: GraphNet
152+
ϕ::T
153+
ρ::S
154+
end
155+
156+
update_vertex(l::DeepSet, Ē, V, u) = l.ϕ(V)
157+
158+
update_global(l::DeepSet, ē, v̄, u) = l.ρ(v̄)
159+
160+
# For variable graph
161+
function (l::DeepSet)(fg::AbstractFeaturedGraph)
162+
X = node_feature(fg)
163+
u = global_feature(fg)
164+
GraphSignals.check_num_nodes(fg, X)
165+
_, _, u = propagate(l, graph(fg), nothing, X, u, nothing, nothing, +)
166+
return ConcreteFeaturedGraph(fg, gf=u)
167+
end
168+
169+
# For static graph
170+
function (l::DeepSet)(el::NamedTuple, X::AbstractArray, u=nothing)
171+
GraphSignals.check_num_nodes(el.N, X)
172+
_, _, u = propagate(l, el, nothing, X, u, nothing, nothing, +)
173+
return u
174+
end
175+
176+
WithGraph(fg::AbstractFeaturedGraph, l::DeepSet) = WithGraph(to_namedtuple(fg), l)
177+
(wg::WithGraph{<:DeepSet})(args...) = wg.layer(wg.graph, args...)
178+
179+
function Base.show(io::IO, l::WithGraph{<:DeepSet})
180+
print(io, "WithGraph(Graph(#V=", l.graph.N)
181+
print(io, ", #E=", l.graph.E, "), ")
182+
print(io, l.layer, ")")
183+
end

src/operation.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@ function batched_index(idx::AbstractVector, batch_size::Integer)
1414
return tuple.(idx, b)
1515
end
1616

17-
aggregate(aggr::typeof(+), X) = vec(sum(X, dims=2))
18-
aggregate(aggr::typeof(-), X) = -vec(sum(X, dims=2))
19-
aggregate(aggr::typeof(*), X) = vec(prod(X, dims=2))
20-
aggregate(aggr::typeof(/), X) = 1 ./ vec(prod(X, dims=2))
21-
aggregate(aggr::typeof(max), X) = vec(maximum(X, dims=2))
22-
aggregate(aggr::typeof(min), X) = vec(minimum(X, dims=2))
23-
aggregate(aggr::typeof(mean), X) = vec(aggr(X, dims=2))
17+
aggregate(::typeof(+), X) = sum(X, dims=2)
18+
aggregate(::typeof(-), X) = -sum(X, dims=2)
19+
aggregate(::typeof(*), X) = prod(X, dims=2)
20+
aggregate(::typeof(/), X) = 1 ./ prod(X, dims=2)
21+
aggregate(::typeof(max), X) = maximum(X, dims=2)
22+
aggregate(::typeof(min), X) = minimum(X, dims=2)
23+
aggregate(::typeof(mean), X) = mean(X, dims=2)
2424

2525
@non_differentiable batched_index(x...)

test/models.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
@testset "models" begin
2+
batch_size = 10
23
in_channel = 3
34
out_channel = 5
45
N = 4
@@ -60,5 +61,31 @@
6061
Y = node_feature(fg_)
6162
@test size(Y) == (N, N)
6263
end
64+
65+
@testset "DeepSet" begin
66+
ϕ = Dense(64, 16)
67+
ρ = Dense(16, 4)
68+
@testset "layer without graph" begin
69+
deepset = DeepSet(ϕ, ρ)
70+
71+
X = rand(T, 64, N)
72+
fg = FeaturedGraph(adj, nf=X)
73+
fg_ = deepset(fg)
74+
@test size(global_feature(fg_)) == (4, 1)
75+
76+
g = Zygote.gradient(() -> sum(global_feature(deepset(fg))), Flux.params(deepset))
77+
@test length(g.grads) == 2
78+
end
79+
80+
@testset "layer with static graph" begin
81+
X = rand(T, 64, N, batch_size)
82+
deepset = WithGraph(fg, DeepSet(ϕ, ρ))
83+
Y = deepset(X)
84+
@test size(Y) == (4, 1, batch_size)
85+
86+
g = Zygote.gradient(() -> sum(deepset(X)), Flux.params(deepset))
87+
@test length(g.grads) == 0
88+
end
89+
end
6390
end
6491
end

0 commit comments

Comments
 (0)