
Commit a4cd230

Merge pull request #278 from FluxML/deepset
Add DeepSet model and example
2 parents: 66c9b9a + 1ca11db

File tree

7 files changed: +286 -10 lines

docs/src/manual/models.md

Lines changed: 16 additions & 0 deletions

@@ -37,6 +37,22 @@ Reference: [Variational Graph Auto-Encoders](https://arxiv.org/abs/1611.07308)

---

## DeepSet

```math
Z = \rho \left( \sum_{x_i \in \mathcal{V}} \phi (x_i) \right)
```

where ``\phi`` and ``\rho`` denote two neural networks and ``x_i`` is the node feature for node ``i``.

```@docs
GeometricFlux.DeepSet
```

Reference: [Deep Sets](https://papers.nips.cc/paper/2017/hash/f22e4747da1aa27e363d86d40ff442fe-Abstract.html)

---

## Special Layers

### Inner-product Decoder
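The formula maps almost directly onto code. A minimal plain-Julia sketch of the idea (the name `deepset_sketch` is illustrative only; the actual layer routes through `propagate`, as shown in `src/models.jl` below):

```julia
using Flux

# Sketch of Z = ρ(Σᵢ ϕ(xᵢ)): X stores one node feature per column, so
# summing ϕ(X) along dims=2 gives the permutation-invariant pooling.
deepset_sketch(ϕ, ρ, X) = ρ(sum(ϕ(X), dims=2))

ϕ, ρ = Dense(64, 16), Dense(16, 4)
deepset_sketch(ϕ, ρ, rand(Float32, 64, 10))   # 4×1 output
```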

examples/digitsum_deepsets.jl (new file)

Lines changed: 154 additions & 0 deletions

@@ -0,0 +1,154 @@
using CUDA
using Flux
using Flux: onecold
using Flux.Losses: mae
using Flux.Data: DataLoader
using GeometricFlux
using Graphs
using GraphSignals
using MLDatasets
using Parameters: @with_kw
using ProgressMeter: Progress, next!
using Statistics
using Random

function load_data(
    batch_size,
    num_train_examples,
    num_test_examples,
    train_max_length,
    test_min_length,
    test_max_length
)
    train_X, train_y = MLDatasets.MNIST.traindata(Float32)
    test_X, test_y = MLDatasets.MNIST.testdata(Float32)

    train_X, train_y = shuffle_data(train_X, train_y)
    test_X, test_y = shuffle_data(test_X, test_y)

    train_data = generate_featuredgraphs(train_X, train_y, num_train_examples, 1:train_max_length)
    test_data = generate_featuredgraphs(test_X, test_y, num_test_examples, test_min_length:test_max_length)
    train_batch = Flux.batch(train_data)
    test_batch = Flux.batch(test_data)

    train_loader = DataLoader(train_batch, batchsize=batch_size)
    test_loader = DataLoader(test_batch, batchsize=batch_size)
    return train_loader, test_loader
end

function shuffle_data(X, y)
    X = reshape(X, :, size(y)...)
    p = randperm(size(y)...)
    return X[:,p], y[p]
end

function generate_featuredgraphs(X, y, num_examples, len_range)
    len = size(y, 1)
    data = []
    start = 1
    for _ in 1:num_examples
        n = rand(len_range)
        if start+n-1 > len
            start = 1
        end
        last = start + n - 1
        g = SimpleGraph(n)
        d = (FeaturedGraph(g, nf=X[:,start:last]), sum(y[start:last], dims=1))
        push!(data, d)
        start = last + 1
    end
    return data
end

@with_kw mutable struct Args
    η = 1e-4                     # learning rate
    num_train_examples = 1.5e5   # number of training examples
    num_test_examples = 1e4      # number of testing examples
    train_max_length = 10        # max number of digits in a training example
    test_min_length = 5          # min number of digits in a testing example
    test_max_length = 55         # max number of digits in a testing example
    batch_size = 128             # batch size
    epochs = 10                  # number of epochs
    seed = 0                     # random seed
    cuda = true                  # use GPU
    input_dim = 28*28            # input dimension
    hidden_dims = [300, 100, 30] # hidden dimensions
    target_dim = 1               # target dimension
end

function model_loss(model, batch)
    ŷ = vcat(map(x -> global_feature(model(x[1])), batch)...)
    y = vcat(map(x -> x[2], batch)...)
    return mae(ŷ, y)
end

model_loss(model, loader::DataLoader, device) = mean(model_loss(model, batch |> device) for batch in loader)

function train(; kws...)
    # load hyperparameters
    args = Args(; kws...)
    args.seed > 0 && Random.seed!(args.seed)

    # GPU config
    if args.cuda && CUDA.has_cuda()
        device = gpu
        @info "Training on GPU"
    else
        device = cpu
        @info "Training on CPU"
    end

    # load MNIST dataset
    train_loader, test_loader = load_data(
        args.batch_size,
        args.num_train_examples,
        args.num_test_examples,
        args.train_max_length,
        args.test_min_length,
        args.test_max_length
    )

    # build model
    ϕ = Chain(
        Dense(args.input_dim, args.hidden_dims[1], tanh),
        Dense(args.hidden_dims[1], args.hidden_dims[2], tanh),
        Dense(args.hidden_dims[2], args.hidden_dims[3], tanh),
    )
    ρ = Dense(args.hidden_dims[3], args.target_dim)
    model = DeepSet(ϕ, ρ) |> device

    # ADAM optimizer
    opt = ADAM(args.η)

    # parameters
    ps = Flux.params(model)

    # training
    train_steps = 0
    @info "Start Training, total $(args.epochs) epochs"
    for epoch = 1:args.epochs
        @info "Epoch $(epoch)"
        progress = Progress(length(train_loader))

        for batch in train_loader
            train_loss, back = Flux.pullback(ps) do
                model_loss(model, batch |> device)
            end
            test_loss = model_loss(model, test_loader, device)
            grad = back(1f0)
            Flux.Optimise.update!(opt, ps, grad)

            # progress meter
            next!(progress; showvalues=[
                (:train_loss, train_loss),
                (:test_loss, test_loss)
            ])

            train_steps += 1
        end
    end

    return model, args
end

model, args = train()
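Each training example pairs a `FeaturedGraph` holding a handful of flattened MNIST images as node features with the sum of their labels as the target; the graph's edges are never used, since `DeepSet` only pools node features. A hypothetical single pair, mirroring what `generate_featuredgraphs` produces (assuming the MLDatasets API used above):

```julia
using Graphs, GraphSignals, MLDatasets

X, y = MLDatasets.MNIST.traindata(Float32)
X = reshape(X, :, length(y))             # flatten to 784 × 60000, as shuffle_data does
n = 3                                    # this set holds 3 digits
fg = FeaturedGraph(SimpleGraph(n), nf=X[:, 1:n])  # edgeless graph; only node count matters
target = sum(y[1:n], dims=1)             # label: the sum of the 3 digit labels
```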

src/GeometricFlux.jl

Lines changed: 1 addition & 0 deletions

@@ -51,6 +51,7 @@ export
     VGAE,
     InnerProductDecoder,
     VariationalGraphEncoder,
+    DeepSet,

     # layer/utils
     WithGraph,

src/layers/gn.jl

Lines changed: 3 additions & 3 deletions

@@ -42,9 +42,9 @@ end
 - `update_batch_edge`: (E_in_dim, E) -> (E_out_dim, E)
 - `aggregate_neighbors`: (E_out_dim, E) -> (E_out_dim, V)
 - `update_batch_vertex`: (V_in_dim, V) -> (V_out_dim, V)
-- `aggregate_edges`: (E_out_dim, E) -> (E_out_dim,)
-- `aggregate_vertices`: (V_out_dim, V) -> (V_out_dim,)
-- `update_global`: (dim,) -> (dim,)
+- `aggregate_edges`: (E_out_dim, E) -> (E_out_dim, 1)
+- `aggregate_vertices`: (V_out_dim, V) -> (V_out_dim, 1)
+- `update_global`: (dim, 1) -> (dim, 1)
 """
 function propagate(gn::GraphNet, el::NamedTuple, E, V, u, naggr, eaggr, vaggr)
     E = update_batch_edge(gn, el, E, V, u)

src/models.jl

Lines changed: 78 additions & 0 deletions

@@ -146,3 +146,81 @@ WithGraph(fg::AbstractFeaturedGraph, l::VariationalGraphEncoder) =
        WithGraph(fg, l.logσ),
        l.z_dim
    )

"""
    DeepSet(ϕ, ρ; aggr=+)

Deep set model.

# Arguments

- `ϕ`: Neural network layer applied to each input before aggregation.
- `ρ`: Neural network layer applied after aggregation.
- `aggr`: Aggregate function applied over the `ϕ` outputs of all nodes. `+`, `-`,
    `*`, `/`, `max`, `min` and `mean` are available.

# Examples

```jldoctest
julia> ϕ = Dense(64, 16)
Dense(64, 16)       # 1_040 parameters

julia> ρ = Dense(16, 4)
Dense(16, 4)        # 68 parameters

julia> DeepSet(ϕ, ρ)
DeepSet(Dense(64, 16), Dense(16, 4), aggr=+)

julia> DeepSet(ϕ, ρ, aggr=max)
DeepSet(Dense(64, 16), Dense(16, 4), aggr=max)
```

See also [`WithGraph`](@ref) for training a layer with a static graph.
"""
struct DeepSet{T,S,O} <: GraphNet
    ϕ::T
    ρ::S
    aggr::O
end

DeepSet(ϕ, ρ; aggr=+) = DeepSet(ϕ, ρ, aggr)

@functor DeepSet

update_batch_edge(l::DeepSet, el::NamedTuple, E, V, u) = nothing

update_vertex(l::DeepSet, Ē, V, u) = l.ϕ(V)

update_global(l::DeepSet, ē, v̄, u) = l.ρ(v̄)

# For variable graph
function (l::DeepSet)(fg::AbstractFeaturedGraph)
    X = node_feature(fg)
    u = global_feature(fg)
    GraphSignals.check_num_nodes(fg, X)
    _, _, u = propagate(l, graph(fg), nothing, X, u, nothing, nothing, l.aggr)
    return ConcreteFeaturedGraph(fg, gf=u)
end

# For static graph
function (l::DeepSet)(el::NamedTuple, X::AbstractArray, u=nothing)
    GraphSignals.check_num_nodes(el.N, X)
    _, _, u = propagate(l, el, nothing, X, u, nothing, nothing, l.aggr)
    return u
end

WithGraph(fg::AbstractFeaturedGraph, l::DeepSet) = WithGraph(to_namedtuple(fg), l)
(wg::WithGraph{<:DeepSet})(args...) = wg.layer(wg.graph, args...)

function Base.show(io::IO, l::DeepSet)
    print(io, "DeepSet(", l.ϕ, ", ", l.ρ)
    print(io, ", aggr=", l.aggr)
    print(io, ")")
end

function Base.show(io::IO, l::WithGraph{<:DeepSet})
    print(io, "WithGraph(Graph(#V=", l.graph.N)
    print(io, ", #E=", l.graph.E, "), ")
    print(io, l.layer, ")")
end
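A quick usage sketch of the variable-graph path (the output shape matches the tests below; the graph structure itself is irrelevant to the pooling):

```julia
using Flux, Graphs, GraphSignals, GeometricFlux

ϕ, ρ = Dense(64, 16), Dense(16, 4)
deepset = DeepSet(ϕ, ρ)

fg = FeaturedGraph(SimpleGraph(10), nf=rand(Float32, 64, 10))
fg′ = deepset(fg)                 # pools ϕ(node features), then applies ρ
size(global_feature(fg′))         # (4, 1), one pooled global feature
```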

src/operation.jl

Lines changed: 7 additions & 7 deletions

@@ -14,12 +14,12 @@ function batched_index(idx::AbstractVector, batch_size::Integer)
     return tuple.(idx, b)
 end

-aggregate(aggr::typeof(+), X) = vec(sum(X, dims=2))
-aggregate(aggr::typeof(-), X) = -vec(sum(X, dims=2))
-aggregate(aggr::typeof(*), X) = vec(prod(X, dims=2))
-aggregate(aggr::typeof(/), X) = 1 ./ vec(prod(X, dims=2))
-aggregate(aggr::typeof(max), X) = vec(maximum(X, dims=2))
-aggregate(aggr::typeof(min), X) = vec(minimum(X, dims=2))
-aggregate(aggr::typeof(mean), X) = vec(aggr(X, dims=2))
+aggregate(::typeof(+), X) = sum(X, dims=2)
+aggregate(::typeof(-), X) = -sum(X, dims=2)
+aggregate(::typeof(*), X) = prod(X, dims=2)
+aggregate(::typeof(/), X) = 1 ./ prod(X, dims=2)
+aggregate(::typeof(max), X) = maximum(X, dims=2)
+aggregate(::typeof(min), X) = minimum(X, dims=2)
+aggregate(::typeof(mean), X) = mean(X, dims=2)

 @non_differentiable batched_index(x...)
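This change keeps the singleton column instead of `vec`-ing it away, so aggregated edge/vertex features stay matrices and compose with `update_global`'s new `(dim, 1) -> (dim, 1)` contract documented in `src/layers/gn.jl` above. A standalone sketch of the revised semantics (a reimplementation for illustration, not the package internals):

```julia
using Statistics: mean

# Illustrative copies of the new definitions:
agg(::typeof(+), X) = sum(X, dims=2)
agg(::typeof(mean), X) = mean(X, dims=2)

X = rand(Float32, 5, 8)   # (dim, n)
size(agg(+, X))           # (5, 1): the singleton column survives,
                          # where vec(sum(X, dims=2)) gave a length-5 vector
```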

test/models.jl

Lines changed: 27 additions & 0 deletions

@@ -1,4 +1,5 @@
 @testset "models" begin
+    batch_size = 10
     in_channel = 3
     out_channel = 5
     N = 4
@@ -60,5 +61,31 @@
         Y = node_feature(fg_)
         @test size(Y) == (N, N)
     end
+
+    @testset "DeepSet" begin
+        ϕ = Dense(64, 16)
+        ρ = Dense(16, 4)
+        @testset "layer without graph" begin
+            deepset = DeepSet(ϕ, ρ)
+
+            X = rand(T, 64, N)
+            fg = FeaturedGraph(adj, nf=X)
+            fg_ = deepset(fg)
+            @test size(global_feature(fg_)) == (4, 1)
+
+            g = Zygote.gradient(() -> sum(global_feature(deepset(fg))), Flux.params(deepset))
+            @test length(g.grads) == 6
+        end
+
+        @testset "layer with static graph" begin
+            X = rand(T, 64, N, batch_size)
+            deepset = WithGraph(fg, DeepSet(ϕ, ρ))
+            Y = deepset(X)
+            @test size(Y) == (4, 1, batch_size)
+
+            g = Zygote.gradient(() -> sum(deepset(X)), Flux.params(deepset))
+            @test length(g.grads) == 4
+        end
+    end
 end
 end
