Skip to content

Commit e2ab9ee

Browse files
committed
add DeepSet model
fix
1 parent 66c9b9a commit e2ab9ee

File tree

6 files changed

+128
-7
lines changed

6 files changed

+128
-7
lines changed

docs/src/manual/models.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,22 @@ Reference: [Variational Graph Auto-Encoders](https://arxiv.org/abs/1611.07308)
3737

3838
---
3939

40+
## DeepSet
41+
42+
```math
43+
Z = \rho ( \sum_{x_i \in \mathcal{V}} \phi (x_i) )
44+
```
45+
46+
where ``\phi`` and ``\rho`` denote two neural networks and ``x_i`` is the node feature for node ``i``.
47+
48+
```@docs
49+
GeometricFlux.DeepSet
50+
```
51+
52+
Reference: [Deep Sets](https://papers.nips.cc/paper/2017/hash/f22e4747da1aa27e363d86d40ff442fe-Abstract.html)
53+
54+
---
55+
4056
## Special Layers
4157

4258
### Inner-product Decoder

examples/deepsets.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
mae loss
2+
3+
Adam(lr=1e-4)

src/GeometricFlux.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ export
5151
VGAE,
5252
InnerProductDecoder,
5353
VariationalGraphEncoder,
54+
DeepSet,
5455

5556
# layer/utils
5657
WithGraph,

src/models.jl

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,3 +146,77 @@ WithGraph(fg::AbstractFeaturedGraph, l::VariationalGraphEncoder) =
146146
WithGraph(fg, l.logσ),
147147
l.z_dim
148148
)
149+
150+
151+
"""
152+
DeepSet(ϕ, ρ, aggr=+)
153+
154+
Deep set model.
155+
156+
# Arguments
157+
158+
- `ϕ`: The dimension of input features.
159+
- `ρ`: The dimension of output features.
160+
- `aggr`: An aggregate function applied to the result of message function. `+`, `-`,
161+
`*`, `/`, `max`, `min` and `mean` are available.
162+
163+
# Examples
164+
165+
```jldoctest
166+
julia> ϕ = Dense(64, 16)
167+
Dense(64, 16) # 1_040 parameters
168+
169+
julia> ρ = Dense(16, 4)
170+
Dense(16, 4) # 68 parameters
171+
172+
julia> DeepSet(ϕ, ρ)
173+
DeepSet(Dense(64, 16), Dense(16, 4), aggr=+)
174+
175+
julia> DeepSet(ϕ, ρ, aggr=max)
176+
DeepSet(Dense(64, 16), Dense(16, 4), aggr=max)
177+
```
178+
179+
See also [`WithGraph`](@ref) for training layer with static graph.
180+
"""
181+
struct DeepSet{T,S,O} <: GraphNet
182+
ϕ::T
183+
ρ::S
184+
aggr::O
185+
end
186+
187+
DeepSet(ϕ, ρ; aggr=+) = DeepSet(ϕ, ρ, aggr)
188+
189+
update_vertex(l::DeepSet, Ē, V, u) = l.ϕ(V)
190+
191+
update_global(l::DeepSet, ē, v̄, u) = l.ρ(v̄)
192+
193+
# For variable graph
194+
function (l::DeepSet)(fg::AbstractFeaturedGraph)
195+
X = node_feature(fg)
196+
u = global_feature(fg)
197+
GraphSignals.check_num_nodes(fg, X)
198+
_, _, u = propagate(l, graph(fg), nothing, X, u, nothing, nothing, l.aggr)
199+
return ConcreteFeaturedGraph(fg, gf=u)
200+
end
201+
202+
# For static graph
203+
function (l::DeepSet)(el::NamedTuple, X::AbstractArray, u=nothing)
204+
GraphSignals.check_num_nodes(el.N, X)
205+
_, _, u = propagate(l, el, nothing, X, u, nothing, nothing, l.aggr)
206+
return u
207+
end
208+
209+
WithGraph(fg::AbstractFeaturedGraph, l::DeepSet) = WithGraph(to_namedtuple(fg), l)
210+
(wg::WithGraph{<:DeepSet})(args...) = wg.layer(wg.graph, args...)
211+
212+
function Base.show(io::IO, l::DeepSet)
213+
print(io, "DeepSet(", l.ϕ, ", ", l.ρ)
214+
print(io, ", aggr=", l.aggr)
215+
print(io, ")")
216+
end
217+
218+
function Base.show(io::IO, l::WithGraph{<:DeepSet})
219+
print(io, "WithGraph(Graph(#V=", l.graph.N)
220+
print(io, ", #E=", l.graph.E, "), ")
221+
print(io, l.layer, ")")
222+
end

src/operation.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@ function batched_index(idx::AbstractVector, batch_size::Integer)
1414
return tuple.(idx, b)
1515
end
1616

17-
aggregate(aggr::typeof(+), X) = vec(sum(X, dims=2))
18-
aggregate(aggr::typeof(-), X) = -vec(sum(X, dims=2))
19-
aggregate(aggr::typeof(*), X) = vec(prod(X, dims=2))
20-
aggregate(aggr::typeof(/), X) = 1 ./ vec(prod(X, dims=2))
21-
aggregate(aggr::typeof(max), X) = vec(maximum(X, dims=2))
22-
aggregate(aggr::typeof(min), X) = vec(minimum(X, dims=2))
23-
aggregate(aggr::typeof(mean), X) = vec(aggr(X, dims=2))
17+
aggregate(::typeof(+), X) = sum(X, dims=2)
18+
aggregate(::typeof(-), X) = -sum(X, dims=2)
19+
aggregate(::typeof(*), X) = prod(X, dims=2)
20+
aggregate(::typeof(/), X) = 1 ./ prod(X, dims=2)
21+
aggregate(::typeof(max), X) = maximum(X, dims=2)
22+
aggregate(::typeof(min), X) = minimum(X, dims=2)
23+
aggregate(::typeof(mean), X) = mean(X, dims=2)
2424

2525
@non_differentiable batched_index(x...)

test/models.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
@testset "models" begin
2+
batch_size = 10
23
in_channel = 3
34
out_channel = 5
45
N = 4
@@ -60,5 +61,31 @@
6061
Y = node_feature(fg_)
6162
@test size(Y) == (N, N)
6263
end
64+
65+
@testset "DeepSet" begin
66+
ϕ = Dense(64, 16)
67+
ρ = Dense(16, 4)
68+
@testset "layer without graph" begin
69+
deepset = DeepSet(ϕ, ρ)
70+
71+
X = rand(T, 64, N)
72+
fg = FeaturedGraph(adj, nf=X)
73+
fg_ = deepset(fg)
74+
@test size(global_feature(fg_)) == (4, 1)
75+
76+
g = Zygote.gradient(() -> sum(global_feature(deepset(fg))), Flux.params(deepset))
77+
@test length(g.grads) == 2
78+
end
79+
80+
@testset "layer with static graph" begin
81+
X = rand(T, 64, N, batch_size)
82+
deepset = WithGraph(fg, DeepSet(ϕ, ρ))
83+
Y = deepset(X)
84+
@test size(Y) == (4, 1, batch_size)
85+
86+
g = Zygote.gradient(() -> sum(deepset(X)), Flux.params(deepset))
87+
@test length(g.grads) == 0
88+
end
89+
end
6390
end
6491
end

0 commit comments

Comments
 (0)