diff --git a/examples/vgae.jl b/examples/vgae.jl index d8ffb3c46..e0dcc5b46 100644 --- a/examples/vgae.jl +++ b/examples/vgae.jl @@ -15,8 +15,7 @@ using LightGraphs: adjacency_matrix num_nodes = 2708 num_features = 1433 -hidden1 = 128 -hidden2 = 32 +h_dim = 32 z_dim = 16 target_catg = 7 epochs = 200 @@ -27,9 +26,7 @@ adj_mat = Matrix{Float32}(adjacency_matrix(g)) train_data = [(FeaturedGraph(adj_mat.*M, Matrix{Float32}(features)), adj_mat) for M in masks] ## Model -encoder = Chain(GCNConv(num_features=>hidden1, relu; cache=false), - GCNConv(hidden1=>hidden2; cache=false)) -model = VGAE(encoder, hidden2, z_dim, σ) +model = VGAE(GCNConv(num_features=>h_dim, relu; cache=false), h_dim, z_dim, σ) encoder = model.encoder decoder = model.decoder ps = Flux.params(model) diff --git a/src/models.jl b/src/models.jl index 5948806c3..ac75acf3a 100644 --- a/src/models.jl +++ b/src/models.jl @@ -45,16 +45,10 @@ end @functor VGAE -function (g::VGAE)(X::AbstractMatrix) - Z = g.encoder(X) - A = g.decoder(Z) - A -end - function (g::VGAE)(fg::FeaturedGraph) - Z = g.encoder(X) - A = g.decoder(Z) - A + fg_ = g.encoder(fg) + fg_ = g.decoder(fg_) + fg_ end @@ -101,32 +95,24 @@ struct VariationalEncoder end function VariationalEncoder(nn, h_dim::Integer, z_dim::Integer) - VariationalEncoder(nn, Dense(h_dim, z_dim), Dense(h_dim, z_dim), z_dim) + VariationalEncoder(nn, + GCNConv(h_dim=>z_dim; cache=false), + GCNConv(h_dim=>z_dim; cache=false), + z_dim) end @functor VariationalEncoder -function (ve::VariationalEncoder)(X::AbstractMatrix)::AbstractMatrix - μ, logσ = summarize(ve, X) - Z = sample(μ, logσ) - Z -end - function (ve::VariationalEncoder)(fg::FeaturedGraph)::FeaturedGraph μ, logσ = summarize(ve, fg) Z = sample(μ, logσ) FeaturedGraph(graph(fg), Z) end -function summarize(ve::VariationalEncoder, X::AbstractMatrix) - h = ve.nn(X) - ve.μ(h), ve.logσ(h) -end - function summarize(ve::VariationalEncoder, fg::FeaturedGraph) fg_ = ve.nn(fg) - h = node_feature(fg_) - ve.μ(h), ve.logσ(h) + fg_μ, fg_logσ = ve.μ(fg_), ve.logσ(fg_) + node_feature(fg_μ), node_feature(fg_logσ) end sample(μ::AbstractArray{T}, logσ::AbstractArray{T}) where {T<:Real} = diff --git a/test/models.jl b/test/models.jl index 40609fb5a..244086c91 100644 --- a/test/models.jl +++ b/test/models.jl @@ -20,28 +20,42 @@ adj = [0. 1. 0. 1.; @testset "VGAE" begin @testset "InnerProductDecoder" begin ipd = InnerProductDecoder(identity) - Y = ipd(rand(1, N)) + X = rand(1, N) + Y = ipd(X) @test size(Y) == (N, N) - Y = ipd(rand(in_channel, N)) + X = rand(1, N) + fg = FeaturedGraph(adj, X) + fg_ = ipd(fg) + Y = node_feature(fg_) + @test size(Y) == (N, N) + + X = rand(in_channel, N) + fg = FeaturedGraph(adj, X) + fg_ = ipd(fg) + Y = node_feature(fg_) @test size(Y) == (N, N) end @testset "VariationalEncoder" begin z_dim = 2 - gc = GCNConv(adj, in_channel=>out_channel) + gc = GCNConv(in_channel=>out_channel) ve = VariationalEncoder(gc, out_channel, z_dim) X = rand(in_channel, N) - Z = ve(X) + fg = FeaturedGraph(adj, X) + fg_ = ve(fg) + Z = node_feature(fg_) @test size(Z) == (z_dim, N) end @testset "VGAE" begin z_dim = 2 - gc = GCNConv(adj, in_channel=>out_channel) + gc = GCNConv(in_channel=>out_channel) vgae = VGAE(gc, out_channel, z_dim) X = rand(in_channel, N) - Y = vgae(X) + fg = FeaturedGraph(adj, X) + fg_ = vgae(fg) + Y = node_feature(fg_) @test size(Y) == (N, N) end end