Add DeepSet model and example #278

Merged · 2 commits · Mar 18, 2022
16 changes: 16 additions & 0 deletions docs/src/manual/models.md
@@ -37,6 +37,22 @@ Reference: [Variational Graph Auto-Encoders](https://arxiv.org/abs/1611.07308)

---

## DeepSet

```math
Z = \rho \left( \sum_{x_i \in \mathcal{V}} \phi(x_i) \right)
```

where ``\phi`` and ``\rho`` denote two neural networks and ``x_i`` is the node feature for node ``i``.
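
Concretely, with sum aggregation this is just ``\rho`` applied to the sum of per-node embeddings. A plain-Flux sketch of the same computation (layer sizes are arbitrary and only illustrative, not part of the API):

```julia
using Flux

ϕ = Dense(64, 16)             # applied to every node feature independently
ρ = Dense(16, 4)              # applied once to the aggregated representation

X = rand(Float32, 64, 10)     # 10 node features of dimension 64
Z = ρ(sum(ϕ(X), dims=2))      # permutation-invariant in the columns of X
```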

```@docs
GeometricFlux.DeepSet
```

Reference: [Deep Sets](https://papers.nips.cc/paper/2017/hash/f22e4747da1aa27e363d86d40ff442fe-Abstract.html)

---

## Special Layers

### Inner-product Decoder
154 changes: 154 additions & 0 deletions examples/digitsum_deepsets.jl
@@ -0,0 +1,154 @@
using CUDA
using Flux
using Flux: onecold
using Flux.Losses: mae
using Flux.Data: DataLoader
using GeometricFlux
using Graphs
using GraphSignals
using MLDatasets
using Parameters: @with_kw
using ProgressMeter: Progress, next!
using Statistics
using Random

function load_data(
batch_size,
num_train_examples,
num_test_examples,
train_max_length,
test_min_length,
test_max_length
)
train_X, train_y = MLDatasets.MNIST.traindata(Float32)
test_X, test_y = MLDatasets.MNIST.testdata(Float32)

train_X, train_y = shuffle_data(train_X, train_y)
test_X, test_y = shuffle_data(test_X, test_y)

train_data = generate_featuredgraphs(train_X, train_y, num_train_examples, 1:train_max_length)
test_data = generate_featuredgraphs(test_X, test_y, num_test_examples, test_min_length:test_max_length)
train_batch = Flux.batch(train_data)
test_batch = Flux.batch(test_data)

train_loader = DataLoader(train_batch, batchsize=batch_size)
test_loader = DataLoader(test_batch, batchsize=batch_size)
return train_loader, test_loader
end

# Flatten each 28×28 image into a 784-element column and shuffle examples with a common permutation.
function shuffle_data(X, y)
X = reshape(X, :, size(y)...)
p = randperm(size(y)...)
return X[:,p], y[p]
end

# Pack a random-length run of digit images into a FeaturedGraph whose target is the sum of its digits.
function generate_featuredgraphs(X, y, num_examples, len_range)
len = size(y, 1)
data = []
start = 1
for _ in 1:num_examples
n = rand(len_range)
if start+n-1 > len
start = 1
end
last = start + n - 1
g = SimpleGraph(n)
d = (FeaturedGraph(g, nf=X[:,start:last]), sum(y[start:last], dims=1))
push!(data, d)
start = last + 1
end
return data
end

@with_kw mutable struct Args
η = 1e-4 # learning rate
num_train_examples = 1.5e5 # number of training examples
num_test_examples = 1e4 # number of testing examples
train_max_length = 10 # max number of digits in a training example
test_min_length = 5 # min number of digits in a testing example
test_max_length = 55 # max number of digits in a testing example
batch_size = 128 # batch size
epochs = 10 # number of epochs
seed = 0 # random seed
cuda = true # use GPU
input_dim = 28*28 # input dimension
hidden_dims = [300, 100, 30] # hidden dimensions
target_dim = 1 # target dimension
end

# Mean absolute error between predicted and true digit sums over a batch of graphs.
function model_loss(model, batch)
ŷ = vcat(map(x -> global_feature(model(x[1])), batch)...)
y = vcat(map(x -> x[2], batch)...)
return mae(ŷ, y)
end

model_loss(model, loader::DataLoader, device) = mean(model_loss(model, batch |> device) for batch in loader)

function train(; kws...)
# load hyperparameters
args = Args(; kws...)
args.seed > 0 && Random.seed!(args.seed)

# GPU config
if args.cuda && CUDA.has_cuda()
device = gpu
@info "Training on GPU"
else
device = cpu
@info "Training on CPU"
end

# load MNIST dataset
train_loader, test_loader = load_data(
args.batch_size,
args.num_train_examples,
args.num_test_examples,
args.train_max_length,
args.test_min_length,
args.test_max_length
)

# build model
ϕ = Chain(
Dense(args.input_dim, args.hidden_dims[1], tanh),
Dense(args.hidden_dims[1], args.hidden_dims[2], tanh),
Dense(args.hidden_dims[2], args.hidden_dims[3], tanh),
)
ρ = Dense(args.hidden_dims[3], args.target_dim)
model = DeepSet(ϕ, ρ) |> device

# ADAM optimizer
opt = ADAM(args.η)

# parameters
ps = Flux.params(model)

# training
train_steps = 0
@info "Start Training, total $(args.epochs) epochs"
for epoch = 1:args.epochs
@info "Epoch $(epoch)"
progress = Progress(length(train_loader))

for batch in train_loader
train_loss, back = Flux.pullback(ps) do
model_loss(model, batch |> device)
end
test_loss = model_loss(model, test_loader, device)
grad = back(1f0)
Flux.Optimise.update!(opt, ps, grad)

# progress meter
next!(progress; showvalues=[
(:train_loss, train_loss),
(:test_loss, test_loss)
])

train_steps += 1
end
end

return model, args
end

model, args = train()
1 change: 1 addition & 0 deletions src/GeometricFlux.jl
@@ -51,6 +51,7 @@ export
VGAE,
InnerProductDecoder,
VariationalGraphEncoder,
DeepSet,

# layer/utils
WithGraph,
6 changes: 3 additions & 3 deletions src/layers/gn.jl
@@ -42,9 +42,9 @@ end
- `update_batch_edge`: (E_in_dim, E) -> (E_out_dim, E)
- `aggregate_neighbors`: (E_out_dim, E) -> (E_out_dim, V)
- `update_batch_vertex`: (V_in_dim, V) -> (V_out_dim, V)
- `aggregate_edges`: (E_out_dim, E) -> (E_out_dim,)
- `aggregate_vertices`: (V_out_dim, V) -> (V_out_dim,)
- `update_global`: (dim,) -> (dim,)
- `aggregate_edges`: (E_out_dim, E) -> (E_out_dim, 1)
- `aggregate_vertices`: (V_out_dim, V) -> (V_out_dim, 1)
- `update_global`: (dim, 1) -> (dim, 1)
"""
function propagate(gn::GraphNet, el::NamedTuple, E, V, u, naggr, eaggr, vaggr)
E = update_batch_edge(gn, el, E, V, u)
78 changes: 78 additions & 0 deletions src/models.jl
@@ -146,3 +146,81 @@ WithGraph(fg::AbstractFeaturedGraph, l::VariationalGraphEncoder) =
WithGraph(fg, l.logσ),
l.z_dim
)


"""
DeepSet(ϕ, ρ, aggr=+)

Deep set model.

# Arguments

- `ϕ`: Neural network layer for each input before aggregation.
- `ρ`: Neural network layer after aggregation.
- `aggr`: An aggregate function applied over the outputs of `ϕ` before `ρ`. `+`, `-`, `*`, `/`, `max`, `min` and `mean` are available.

# Examples

```jldoctest
julia> ϕ = Dense(64, 16)
Dense(64, 16) # 1_040 parameters

julia> ρ = Dense(16, 4)
Dense(16, 4) # 68 parameters

julia> DeepSet(ϕ, ρ)
DeepSet(Dense(64, 16), Dense(16, 4), aggr=+)

julia> DeepSet(ϕ, ρ, aggr=max)
DeepSet(Dense(64, 16), Dense(16, 4), aggr=max)
```

See also [`WithGraph`](@ref) for training the layer with a static graph.
"""
struct DeepSet{T,S,O} <: GraphNet
ϕ::T
ρ::S
aggr::O
end

DeepSet(ϕ, ρ; aggr=+) = DeepSet(ϕ, ρ, aggr)

@functor DeepSet

update_batch_edge(l::DeepSet, el::NamedTuple, E, V, u) = nothing

update_vertex(l::DeepSet, Ē, V, u) = l.ϕ(V)

update_global(l::DeepSet, ē, v̄, u) = l.ρ(v̄)

# For variable graph
function (l::DeepSet)(fg::AbstractFeaturedGraph)
X = node_feature(fg)
u = global_feature(fg)
GraphSignals.check_num_nodes(fg, X)
_, _, u = propagate(l, graph(fg), nothing, X, u, nothing, nothing, l.aggr)
return ConcreteFeaturedGraph(fg, gf=u)
end

# For static graph
function (l::DeepSet)(el::NamedTuple, X::AbstractArray, u=nothing)
GraphSignals.check_num_nodes(el.N, X)
_, _, u = propagate(l, el, nothing, X, u, nothing, nothing, l.aggr)
return u
end

WithGraph(fg::AbstractFeaturedGraph, l::DeepSet) = WithGraph(to_namedtuple(fg), l)
(wg::WithGraph{<:DeepSet})(args...) = wg.layer(wg.graph, args...)

function Base.show(io::IO, l::DeepSet)
print(io, "DeepSet(", l.ϕ, ", ", l.ρ)
print(io, ", aggr=", l.aggr)
print(io, ")")
end

function Base.show(io::IO, l::WithGraph{<:DeepSet})
print(io, "WithGraph(Graph(#V=", l.graph.N)
print(io, ", #E=", l.graph.E, "), ")
print(io, l.layer, ")")
end
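
With the default `+` aggregation and no edge or global inputs, the layer above collapses to the closed form in the manual. A minimal sketch, assuming the constructors and exports shown in this PR (layer sizes arbitrary):

```julia
using Flux, Graphs, GraphSignals, GeometricFlux

ϕ = Dense(64, 16)
ρ = Dense(16, 4)
deepset = DeepSet(ϕ, ρ)

X = rand(Float32, 64, 5)
fg = FeaturedGraph(SimpleGraph(5), nf=X)             # edges are never read by DeepSet
global_feature(deepset(fg)) ≈ ρ(sum(ϕ(X), dims=2))   # expected to hold with `+` aggregation

# Static-graph form, batched along the third dimension (as exercised in the tests below):
Xb = rand(Float32, 64, 5, 8)
size(WithGraph(fg, DeepSet(ϕ, ρ))(Xb))               # expected (4, 1, 8)
```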
14 changes: 7 additions & 7 deletions src/operation.jl
@@ -14,12 +14,12 @@ function batched_index(idx::AbstractVector, batch_size::Integer)
return tuple.(idx, b)
end

aggregate(aggr::typeof(+), X) = vec(sum(X, dims=2))
aggregate(aggr::typeof(-), X) = -vec(sum(X, dims=2))
aggregate(aggr::typeof(*), X) = vec(prod(X, dims=2))
aggregate(aggr::typeof(/), X) = 1 ./ vec(prod(X, dims=2))
aggregate(aggr::typeof(max), X) = vec(maximum(X, dims=2))
aggregate(aggr::typeof(min), X) = vec(minimum(X, dims=2))
aggregate(aggr::typeof(mean), X) = vec(aggr(X, dims=2))
aggregate(::typeof(+), X) = sum(X, dims=2)
aggregate(::typeof(-), X) = -sum(X, dims=2)
aggregate(::typeof(*), X) = prod(X, dims=2)
aggregate(::typeof(/), X) = 1 ./ prod(X, dims=2)
aggregate(::typeof(max), X) = maximum(X, dims=2)
aggregate(::typeof(min), X) = minimum(X, dims=2)
aggregate(::typeof(mean), X) = mean(X, dims=2)

@non_differentiable batched_index(x...)
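
Dropping `vec` keeps the singleton aggregation dimension, so aggregated features stay `(dim, 1)` matrices, matching the updated signatures documented in `gn.jl`. A standalone illustration:

```julia
X = reshape(collect(1.0:6.0), 3, 2)  # 3 features for 2 nodes

vec(sum(X, dims=2))                  # previous behaviour: a length-3 vector
sum(X, dims=2)                       # new behaviour: a 3×1 matrix
```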
27 changes: 27 additions & 0 deletions test/models.jl
@@ -1,4 +1,5 @@
@testset "models" begin
batch_size = 10
in_channel = 3
out_channel = 5
N = 4
@@ -60,5 +61,31 @@
Y = node_feature(fg_)
@test size(Y) == (N, N)
end

@testset "DeepSet" begin
ϕ = Dense(64, 16)
ρ = Dense(16, 4)
@testset "layer without graph" begin
deepset = DeepSet(ϕ, ρ)

X = rand(T, 64, N)
fg = FeaturedGraph(adj, nf=X)
fg_ = deepset(fg)
@test size(global_feature(fg_)) == (4, 1)

g = Zygote.gradient(() -> sum(global_feature(deepset(fg))), Flux.params(deepset))
@test length(g.grads) == 6
end

@testset "layer with static graph" begin
X = rand(T, 64, N, batch_size)
deepset = WithGraph(fg, DeepSet(ϕ, ρ))
Y = deepset(X)
@test size(Y) == (4, 1, batch_size)

g = Zygote.gradient(() -> sum(deepset(X)), Flux.params(deepset))
@test length(g.grads) == 4
end
end
end
end