diff --git a/docs/bibliography.bib b/docs/bibliography.bib
index abde6e773..dc848ab1c 100644
--- a/docs/bibliography.bib
+++ b/docs/bibliography.bib
@@ -154,3 +154,20 @@ @inproceedings{Gao2019
title = {Graph U-Nets},
year = {2019},
}
+
+@article{Wang2019,
+ abstract = {Point clouds provide a flexible geometric representation suitable for countless applications in computer graphics; they also comprise the raw output of most 3D data acquisition devices. While hand-designed features on point clouds have long been proposed in graphics and vision, however, the recent overwhelming success of convolutional neural networks (CNNs) for image analysis suggests the value of adapting insight from CNN to the point cloud world. Point clouds inherently lack topological information, so designing a model to recover topology can enrich the representation power of point clouds. To this end, we propose a new neural network module dubbed EdgeConv suitable for CNN-based high-level tasks on point clouds, including classification and segmentation. EdgeConv acts on graphs dynamically computed in each layer of the network. It is differentiable and can be plugged into existing architectures. Compared to existing modules operating in extrinsic space or treating each point independently, EdgeConv has several appealing properties: It incorporates local neighborhood information; it can be stacked applied to learn global shape properties; and in multi-layer systems affinity in feature space captures semantic characteristics over potentially long distances in the original embedding. We show the performance of our model on standard benchmarks, including ModelNet40, ShapeNetPart, and S3DIS.},
+ author = {Yue Wang and Yongbin Sun and Ziwei Liu and Sanjay E. Sarma and Michael M. Bronstein and Justin M. Solomon},
+ doi = {10.1145/3326362},
+ issn = {0730-0301},
+ number = {5},
+ journal = {ACM Transactions on Graphics},
+ keywords = {Classification,Point cloud,Segmentation},
+ month = {11},
+ pages = {1-12},
+ publisher = {Association for Computing Machinery},
+ title = {Dynamic Graph CNN for Learning on Point Clouds},
+ volume = {38},
+ url = {https://dl.acm.org/doi/10.1145/3326362},
+ year = {2019},
+}
diff --git a/docs/make.jl b/docs/make.jl
index 45ee77d03..81ad353d0 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -17,6 +17,14 @@ makedocs(
clean = false,
modules = [GeometricFlux,GraphSignals],
pages = ["Home" => "index.md",
+ "Tutorials" => [
+ "Semi-Supervised Learning with GCN" => "tutorials/semisupervised_gcn.md",
+ "GCN with Fixed Graph" => "tutorials/gcn_fixed_graph.md",
+ "Graph Attention Network" => "tutorials/gat.md",
+ "DeepSet for Digit Sum" => "tutorials/deepset.md",
+ "Variational Graph Autoencoder" => "tutorials/vgae.md",
+ "Graph Embedding" => "tutorials/graph_embedding.md",
+ ],
"Introduction" => "introduction.md",
"Basics" => [
"Graph Convolutions" => "basics/conv.md",
@@ -28,17 +36,10 @@ makedocs(
"Batch Learning" => "basics/batch.md",
],
"Cooperate with Flux Layers" => "cooperate.md",
- "Tutorials" => [
- "Semi-Supervised Learning with GCN" => "tutorials/semisupervised_gcn.md",
- "GCN with Fixed Graph" => "tutorials/gcn_fixed_graph.md",
- "Graph Attention Network" => "tutorials/gat.md",
- "DeepSet for Digit Sum" => "tutorials/deepset.md",
- "Variational Graph Autoencoder" => "tutorials/vgae.md",
- "Graph Embedding" => "tutorials/graph_embedding.md",
- ],
"Abstractions" => [
"Message passing scheme" => "abstractions/msgpass.md",
"Graph network block" => "abstractions/gn.md"],
+ "Dynamic Graph Update" => "dynamicgraph.md",
"Manual" => [
"FeaturedGraph" => "manual/featuredgraph.md",
"Convolutional Layers" => "manual/conv.md",
diff --git a/docs/src/dynamicgraph.md b/docs/src/dynamicgraph.md
new file mode 100644
index 000000000..3a3f1fc48
--- /dev/null
+++ b/docs/src/dynamicgraph.md
@@ -0,0 +1,36 @@
+# Dynamic Graph Update
+
+Dynamic graph update is a technique, proposed by [Wang2019](@cite), for generating a new graph within a graph convolutional layer.
+
+Most manifold learning approaches aim to capture manifold structures in a high-dimensional space: they construct a graph to approximate the manifold and then learn to reduce the dimension of the space. Separating the capture of the manifold from the learning of the dimensionality reduction limits the power of manifold learning. Thus, [latent graph learning](https://towardsdatascience.com/manifold-learning-2-99a25eeb677d), also known as manifold learning 2.0, is proposed to learn the manifold and the dimensionality reduction simultaneously: it leverages the power of graph neural networks and learns a latent graph structure within the layers of the network.
+
+Latent graph learning trains over a point cloud, or any set of features, without providing a fixed graph structure to the GNN model. Instead, the latent graph is constructed dynamically: in each graph convolutional layer, a neighborhood graph is built from the current features, and this neighborhood graph is then fed, together with the features, into the layer.
+
+Currently, we support the k-nearest neighbor method for constructing a neighborhood graph. To use dynamic graph update, just replace the static graph strategy
+
+```julia
+WithGraph(fg, EdgeConv(Dense(2*in_channel, out_channel)))
+```
+
+with a graph construction method:
+
+```julia
+WithGraph(
+ EdgeConv(Dense(2*in_channel, out_channel)),
+ dynamic=X -> GraphSignals.kneighbors_graph(X, 3)
+)
+```
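+
+As a minimal usage sketch (mirroring the package's test suite; the channel sizes and number of points below are illustrative), the wrapped layer is applied directly to a feature matrix, and the k-NN graph is rebuilt from the features on every call. Batched features of size `(in_channel, N, batch_size)` are handled the same way.
+
+```julia
+using GeometricFlux, GraphSignals, Flux
+
+in_channel, out_channel, N = 3, 16, 1024
+X = rand(Float32, in_channel, N)
+layer = WithGraph(
+ EdgeConv(Dense(2*in_channel, out_channel)),
+ dynamic=X -> GraphSignals.kneighbors_graph(X, 3)
+)
+Y = layer(X)  # Y has size (out_channel, N)
+```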
diff --git a/src/layers/gn.jl b/src/layers/gn.jl
index 7ef243769..31199aa8a 100644
--- a/src/layers/gn.jl
+++ b/src/layers/gn.jl
@@ -56,9 +56,15 @@ function propagate(gn::GraphNet, el::NamedTuple, E, V, u, naggr, eaggr, vaggr)
return E, V, u
end
+WithGraph(fg::AbstractFeaturedGraph, gn::GraphNet) = WithGraph(to_namedtuple(fg), gn)
+WithGraph(gn::GraphNet; dynamic=nothing) = WithGraph(DynamicGraph(dynamic), gn)
+
to_namedtuple(fg::AbstractFeaturedGraph) = to_namedtuple(graph(fg))
function to_namedtuple(sg::SparseGraph)
- es, nbrs, xs = Zygote.ignore(() -> collect(edges(sg)))
+ es, nbrs, xs = collect(edges(sg))
return (N=nv(sg), E=ne(sg), es=es, nbrs=nbrs, xs=xs)
end
+
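+# Extracting graph structure is a discrete operation, so it is excluded from differentiation.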
+@non_differentiable to_namedtuple(x...)
diff --git a/src/layers/graphlayers.jl b/src/layers/graphlayers.jl
index cb68514dc..d0b14585b 100644
--- a/src/layers/graphlayers.jl
+++ b/src/layers/graphlayers.jl
@@ -6,14 +6,16 @@ An abstract type of graph neural network layer for GeometricFlux.
abstract type AbstractGraphLayer end
"""
- WithGraph(fg, layer)
+ WithGraph([g], layer; dynamic=nothing)
-Train GNN layers with static graph.
+Train GNN layers with a static graph, or construct the graph dynamically from input features.
# Arguments
-- `fg`: A fixed `FeaturedGraph` to train with.
+- `g`: If a `FeaturedGraph` is given, the layer is trained with this fixed graph.
- `layer`: A GNN layer.
+- `dynamic`: If a function is given, it enables dynamic graph update: a graph is
+  constructed on the fly from input features by the given function within layers.
# Example
@@ -74,9 +76,10 @@ function Base.show(io::IO, l::WithGraph)
print(io, l.layer, ")")
end
-WithGraph(fg::AbstractFeaturedGraph, model::Chain) = Chain(map(l -> WithGraph(fg, l), model.layers)...)
-WithGraph(::AbstractFeaturedGraph, layer::WithGraph) = layer
-WithGraph(::AbstractFeaturedGraph, layer) = layer
+WithGraph(fg::AbstractFeaturedGraph, model::Chain; kwargs...) =
+ Chain(map(l -> WithGraph(fg, l; kwargs...), model.layers)...)
+WithGraph(::AbstractFeaturedGraph, layer::WithGraph; kwargs...) = layer
+WithGraph(::AbstractFeaturedGraph, layer; kwargs...) = layer
update_batch_edge(l::WithGraph, args...) = update_batch_edge(l.layer, l.graph, args...)
aggregate_neighbors(l::WithGraph, args...) = aggregate_neighbors(l.layer, l.graph, args...)
@@ -131,3 +134,8 @@ function Base.show(io::IO, l::GraphParallel)
print(io, ", global_layer=", l.global_layer)
print(io, ")")
end
+
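+# Wraps a user-provided function that constructs a graph on the fly from input features.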
+struct DynamicGraph{F}
+ method::F
+end
diff --git a/src/layers/msgpass.jl b/src/layers/msgpass.jl
index 4c5923467..5a0ad92c4 100644
--- a/src/layers/msgpass.jl
+++ b/src/layers/msgpass.jl
@@ -63,13 +63,22 @@ update_edge(mp::MessagePassing, e, vi, vj, u) = GeometricFlux.message(mp, vi, vj, e)
update_vertex(mp::MessagePassing, ē, vi, u) = GeometricFlux.update(mp, ē, vi)
# For static graph
-WithGraph(fg::AbstractFeaturedGraph, mp::MessagePassing) =
- WithGraph(to_namedtuple(fg), mp)
-
(wg::WithGraph{<:MessagePassing})(args...) = wg.layer(wg.graph, args...)
+# For dynamic graph
+function (wg::WithGraph{<:MessagePassing,<:DynamicGraph})(args...)
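+ # args[1] holds the input features; a graph is constructed from them on each call.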
+ fg = wg.graph.method(args[1])
+ return wg.layer(to_namedtuple(fg), args...)
+end
+
function Base.show(io::IO, l::WithGraph{<:MessagePassing})
print(io, "WithGraph(Graph(#V=", l.graph.N)
print(io, ", #E=", l.graph.E, "), ")
print(io, l.layer, ")")
end
+
+function Base.show(io::IO, l::WithGraph{<:MessagePassing,<:DynamicGraph})
+ print(io, "WithGraph(DynamicGraph(", l.graph.method, "), ")
+ print(io, l.layer, ")")
+end
diff --git a/test/layers/conv.jl b/test/layers/conv.jl
index e7018c0b0..6ce6c22ea 100644
--- a/test/layers/conv.jl
+++ b/test/layers/conv.jl
@@ -296,6 +296,26 @@
g = Zygote.gradient(() -> sum(ec(X)), Flux.params(ec))
@test length(g.grads) == 2
end
+
+ @testset "layer with dynamic graph" begin
+ X = rand(T, in_channel, N)
+ ec = WithGraph(EdgeConv(Dense(2*in_channel, out_channel)), dynamic=X -> GraphSignals.kneighbors_graph(X, 3))
+ Y = ec(X)
+ @test size(Y) == (out_channel, N)
+
+ g = Zygote.gradient(() -> sum(ec(X)), Flux.params(ec))
+ @test length(g.grads) == 2
+ end
+
+ @testset "layer with dynamic graph in batch" begin
+ X = rand(T, in_channel, N, batch_size)
+ ec = WithGraph(EdgeConv(Dense(2*in_channel, out_channel)), dynamic=X -> GraphSignals.kneighbors_graph(X, 3))
+ Y = ec(X)
+ @test size(Y) == (out_channel, N, batch_size)
+
+ g = Zygote.gradient(() -> sum(ec(X)), Flux.params(ec))
+ @test length(g.grads) == 2
+ end
end
@testset "GINConv" begin