Skip to content

Correct VGAE example #89

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 29, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions examples/vgae.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@ using LightGraphs: adjacency_matrix

num_nodes = 2708
num_features = 1433
hidden1 = 128
hidden2 = 32
h_dim = 32
z_dim = 16
target_catg = 7
epochs = 200
Expand All @@ -27,9 +26,7 @@ adj_mat = Matrix{Float32}(adjacency_matrix(g))
train_data = [(FeaturedGraph(adj_mat.*M, Matrix{Float32}(features)), adj_mat) for M in masks]

## Model
encoder = Chain(GCNConv(num_features=>hidden1, relu; cache=false),
GCNConv(hidden1=>hidden2; cache=false))
model = VGAE(encoder, hidden2, z_dim, σ)
model = VGAE(GCNConv(num_features=>h_dim, relu; cache=false), h_dim, z_dim, σ)
encoder = model.encoder
decoder = model.decoder
ps = Flux.params(model)
Expand Down
32 changes: 9 additions & 23 deletions src/models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,10 @@ end

@functor VGAE

function (g::VGAE)(X::AbstractMatrix)
Z = g.encoder(X)
A = g.decoder(Z)
A
end

function (g::VGAE)(fg::FeaturedGraph)
Z = g.encoder(X)
A = g.decoder(Z)
A
fg_ = g.encoder(fg)
fg_ = g.decoder(fg_)
fg_
end


Expand Down Expand Up @@ -101,32 +95,24 @@ struct VariationalEncoder
end

function VariationalEncoder(nn, h_dim::Integer, z_dim::Integer)
VariationalEncoder(nn, Dense(h_dim, z_dim), Dense(h_dim, z_dim), z_dim)
VariationalEncoder(nn,
GCNConv(h_dim=>z_dim; cache=false),
GCNConv(h_dim=>z_dim; cache=false),
z_dim)
end

@functor VariationalEncoder

function (ve::VariationalEncoder)(X::AbstractMatrix)::AbstractMatrix
μ, logσ = summarize(ve, X)
Z = sample(μ, logσ)
Z
end

function (ve::VariationalEncoder)(fg::FeaturedGraph)::FeaturedGraph
μ, logσ = summarize(ve, fg)
Z = sample(μ, logσ)
FeaturedGraph(graph(fg), Z)
end

function summarize(ve::VariationalEncoder, X::AbstractMatrix)
h = ve.nn(X)
ve.μ(h), ve.logσ(h)
end

function summarize(ve::VariationalEncoder, fg::FeaturedGraph)
fg_ = ve.nn(fg)
h = node_feature(fg_)
ve.μ(h), ve.logσ(h)
fg_μ, fg_logσ = ve.μ(fg_), ve.logσ(fg_)
node_feature(fg_μ), node_feature(fg_logσ)
end

sample(μ::AbstractArray{T}, logσ::AbstractArray{T}) where {T<:Real} =
Expand Down
26 changes: 20 additions & 6 deletions test/models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,28 +20,42 @@ adj = [0. 1. 0. 1.;
@testset "VGAE" begin
@testset "InnerProductDecoder" begin
ipd = InnerProductDecoder(identity)
Y = ipd(rand(1, N))
X = rand(1, N)
Y = ipd(X)
@test size(Y) == (N, N)

Y = ipd(rand(in_channel, N))
X = rand(1, N)
fg = FeaturedGraph(adj, X)
fg_ = ipd(fg)
Y = node_feature(fg_)
@test size(Y) == (N, N)

X = rand(in_channel, N)
fg = FeaturedGraph(adj, X)
fg_ = ipd(fg)
Y = node_feature(fg_)
@test size(Y) == (N, N)
end

@testset "VariationalEncoder" begin
z_dim = 2
gc = GCNConv(adj, in_channel=>out_channel)
gc = GCNConv(in_channel=>out_channel)
ve = VariationalEncoder(gc, out_channel, z_dim)
X = rand(in_channel, N)
Z = ve(X)
fg = FeaturedGraph(adj, X)
fg_ = ve(fg)
Z = node_feature(fg_)
@test size(Z) == (z_dim, N)
end

@testset "VGAE" begin
z_dim = 2
gc = GCNConv(adj, in_channel=>out_channel)
gc = GCNConv(in_channel=>out_channel)
vgae = VGAE(gc, out_channel, z_dim)
X = rand(in_channel, N)
Y = vgae(X)
fg = FeaturedGraph(adj, X)
fg_ = vgae(fg)
Y = node_feature(fg_)
@test size(Y) == (N, N)
end
end
Expand Down