
Commit db60d93

Merge branch 'master' into new_features
2 parents 227e3da + 2a4430c commit db60d93

File tree

11 files changed

+92
-48
lines changed


CHANGELOG.md

Lines changed: 12 additions & 0 deletions
@@ -2,6 +2,18 @@
 
 All notable changes to this project will be documented in this file.
 
+## [0.13.4]
+
+- support GraphSignals to 0.7
+
+## [0.13.3]
+
+- update doc for `FeaturedGraph`
+
+## [0.13.2]
+
+- fix doc
+
 ## [0.13.1]
 
 - `GraphParallel` support `positional_layer`

Project.toml

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,7 @@
 name = "GeometricFlux"
 uuid = "7e08b658-56d3-11e9-2997-919d5b31e4ea"
 authors = ["Yueh-Hua Tu <[email protected]>"]
-version = "0.13.1"
+version = "0.13.4"
 
 [deps]
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
@@ -30,7 +30,7 @@ ChainRulesCore = "1"
 DataStructures = "0.18"
 FillArrays = "0.13"
 Flux = "0.12 - 0.13"
-GraphSignals = "0.6"
+GraphSignals = "0.7"
 Graphs = "1"
 MLDatasets = "0.7"
 NNlib = "0.8"

docs/src/assets/FeaturedGraph-support-DataLoader.svg

Lines changed: 4 additions & 0 deletions

docs/src/assets/cuda-minibatch.svg

Lines changed: 4 additions & 0 deletions

docs/src/assets/graphparallel.svg

Lines changed: 4 additions & 0 deletions

docs/src/basics/batch.md

Lines changed: 24 additions & 10 deletions
@@ -1,23 +1,37 @@
 # Batch Learning
 
-## Batch Learning for Variable Graph Strategy
+## Mini-batch Learning for [`FeaturedGraph`](@ref)
 
-Batch learning for variable graph strategy can be prepared as follows:
+```@raw html
+<figure>
+<img src="../../assets/FeaturedGraph-support-DataLoader.svg" width="100%" alt="FeaturedGraph-support-DataLoader.svg" /><br>
+<figcaption><em>FeaturedGraph supports DataLoader.</em></figcaption>
+</figure>
+```
+
+Batch learning for [`FeaturedGraph`](@ref) can be prepared as follows:
 
 ```julia
-train_data = [(FeaturedGraph(g, nf=train_X), train_y) for _ in 1:N]
-train_batch = Flux.batch(train_data)
+train_data = (FeaturedGraph(g, nf=train_X), train_y)
+train_batch = DataLoader(train_data, batchsize=batch_size, shuffle=true)
 ```
 
-It batches up [`FeaturedGraph`](@ref) objects into specified mini-batch. A batch is passed to a GNN model and trained/inferred one by one. It is hard for [`FeaturedGraph`](@ref) objects to train or infer in real batch for GPU.
+[`FeaturedGraph`](@ref) now supports `DataLoader`, so one can specify a mini-batch size for it.
+A mini-batch is passed to a GNN model and trained/inferred as a single [`FeaturedGraph`](@ref).
 
-## Batch Learning for Static Graph Strategy
+## Mini-batch Learning for Arrays
+
+```@raw html
+<figure>
+<img src="../../assets/cuda-minibatch.svg" width="100%" alt="cuda-minibatch.svg" /><br>
+<figcaption><em>Mini-batch learning on CUDA.</em></figcaption>
+</figure>
+```
 
-A efficient batch learning should use static graph strategy. Batch learning for static graph strategy can be prepared as follows:
+Mini-batch learning for arrays can be prepared as follows:
 
 ```julia
-train_data = (repeat(train_X, outer=(1,1,N)), repeat(train_y, outer=(1,1,N)))
-train_loader = DataLoader(train_data, batchsize=batch_size, shuffle=true)
+train_loader = DataLoader((train_X, train_y), batchsize=batch_size, shuffle=true)
 ```
 
-An efficient batch learning should feed array to a GNN model. In the example, the mini-batch dimension is the third dimension for `train_X` array. The `train_X` array is split by `DataLoader` into mini-batches and feed a mini-batch to GNN model at a time. This strategy leverages the advantage of GPU training by accelerating training GNN model in a real batch learning.
+An array can be fed directly to a GNN model. In this example, the mini-batch dimension is the last dimension of the `train_X` array. `DataLoader` splits the `train_X` array into mini-batches and feeds one mini-batch to the GNN model at a time. This strategy leverages the advantages of GPU training by accelerating the GNN model in true batch learning.
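Combining the pieces above, a minimal end-to-end sketch of array-based mini-batch training could look like this; the toy graph, random data, dimensions, and the `GCNConv`/`WithGraph` model are illustrative assumptions (`DataLoader` import shown Flux 0.13-style):

```julia
using Flux, GeometricFlux
using Flux: DataLoader

g = [[2, 3], [1, 3], [1, 2]]           # toy triangle graph as an adjacency list
train_X = rand(Float32, 3, 3, 120)     # (feature_dim, num_nodes, num_observations)
train_y = rand(Float32, 5, 3, 120)     # targets share the trailing mini-batch dimension

# DataLoader slices the last dimension into mini-batches.
train_loader = DataLoader((train_X, train_y), batchsize=16, shuffle=true)

# WithGraph fixes a static graph so the layer can consume plain (batched) arrays.
model = WithGraph(FeaturedGraph(g), GCNConv(3 => 5))

for (X, y) in train_loader
    ŷ = model(X)                       # X is (3, 3, batchsize); ŷ is (5, 3, batchsize)
    @assert size(ŷ) == size(y)
end
```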

docs/src/cooperate.md

Lines changed: 7 additions & 0 deletions
@@ -26,6 +26,13 @@ model = Chain(
 
 ## Branching Different Features Through Different Layers
 
+```@raw html
+<figure>
+<img src="../assets/graphparallel.svg" width="70%" alt="graphparallel.svg" /><br>
+<figcaption><em>GraphParallel wraps regular Flux layers for different kinds of features for integration into GNN layers.</em></figcaption>
+</figure>
+```
+
 A [`GraphParallel`](@ref) construct is designed for passing each feature through different layers from a [`FeaturedGraph`](@ref). An example is given as follows:
 
 ```julia
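For reference, a self-contained sketch of such branching; the graph, feature sizes, and the particular Flux layers are illustrative assumptions:

```julia
using Flux, GeometricFlux

# Node features are 3-dim, edge features 5-dim; the triangle graph has
# 6 directed edges, matching the (dim, 2E) edge-feature layout used in the tests.
fg = FeaturedGraph([[2, 3], [1, 3], [1, 2]],
                   nf=rand(Float32, 3, 3),
                   ef=rand(Float32, 5, 6))

# Route each kind of feature through its own regular Flux layer.
branch = GraphParallel(node_layer=Dropout(0.5),
                       edge_layer=Dense(5 => 5, relu))

fg′ = branch(fg)   # a FeaturedGraph with transformed node and edge features
```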

docs/src/manual/graph_conv.md

Lines changed: 21 additions & 21 deletions
@@ -55,6 +55,26 @@ Reference: [Morris2019](@cite)
 
 ---
 
+## SAmple and aggreGatE (GraphSAGE) Network
+
+```math
+\hat{\textbf{x}}_j = sample(\textbf{x}_j), \forall j \in \mathcal{N}(i) \\
+\textbf{m}_i = aggregate(\hat{\textbf{x}}_j) \\
+\textbf{x}_i' = \sigma (\Theta_1 \textbf{x}_i + \Theta_2 \textbf{m}_i)
+```
+
+```@docs
+SAGEConv
+MeanAggregator
+MeanPoolAggregator
+MaxPoolAggregator
+LSTMAggregator
+```
+
+Reference: [Hamilton2017](@cite) and [GraphSAGE website](http://snap.stanford.edu/graphsage/)
+
+---
+
 ## Graph Attentional Layer
 
 ```math
@@ -122,7 +142,7 @@ Reference: [Wang2019](@cite)
 ## Graph Isomorphism Network
 
 ```math
-\textbf{x}_i' = f_{\Theta}\left((1 + \varepsilon) \dot \textbf{x}_i + \sum_{j \in \mathcal{N}(i)} \textbf{x}_j \right)
+\textbf{x}_i' = f_{\Theta}\left((1 + \varepsilon) \cdot \textbf{x}_i + \sum_{j \in \mathcal{N}(i)} \textbf{x}_j \right)
```
 
 where ``f_{\Theta}`` denotes a neural network parametrized by ``\Theta``, *i.e.*, a MLP.
@@ -148,23 +168,3 @@ CGConv
 ```
 
 Reference: [Xie2018](@cite)
-
----
-
-## SAmple and aggreGatE (GraphSAGE) Network
-
-```math
-\hat{\textbf{x}}_j = sample(\textbf{x}_j), \forall j \in \mathcal{N}(i) \\
-\textbf{m}_i = aggregate(\hat{\textbf{x}}_j) \\
-\textbf{x}_i' = \sigma (\Theta_1 \textbf{x}_i + \Theta_2 \textbf{m}_i)
-```
-
-```@docs
-SAGEConv
-MeanAggregator
-MeanPoolAggregator
-MaxPoolAggregator
-LSTMAggregator
-```
-
-Reference: [Hamilton2017](@cite) and [GraphSAGE website](http://snap.stanford.edu/graphsage/)
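To make the three GraphSAGE equations concrete, here is a hand-rolled sketch of a single update step in plain Julia, independent of the `SAGEConv`/aggregator API documented above; the toy graph, dimensions, and names are illustrative assumptions:

```julia
using Statistics  # for mean

neighbors = Dict(1 => [2, 3], 2 => [1, 3], 3 => [1, 2])  # toy triangle graph
X = rand(Float32, 4, 3)                            # one 4-dim feature per node
Θ₁, Θ₂ = rand(Float32, 5, 4), rand(Float32, 5, 4)  # learnable weights
σ = tanh

function sage_update(i; num_sample=2)
    js = rand(neighbors[i], num_sample)   # sample(x_j): sample neighbors with replacement
    mᵢ = mean(X[:, j] for j in js)        # aggregate(x̂_j): here, a mean aggregator
    return σ.(Θ₁ * X[:, i] + Θ₂ * mᵢ)     # x_i′ = σ(Θ₁ x_i + Θ₂ m_i)
end

x₁′ = sage_update(1)                      # updated 5-dim embedding of node 1
```

Replacing `mean` with a max-pooling or LSTM step over the sampled neighbors gives the other aggregators listed in the `@docs` block.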

src/layers/group_conv.jl

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@ julia> in_dim, out_dim, pos_dim = 3, 5, 2
 (3, 5, 2)
 
 julia> egnn = EEquivGraphConv(in_dim=>out_dim, pos_dim, in_dim)
-EEquivGraphConv(ϕ_edge=Dense(10 => 5), ϕ_x=Dense(5 => 2), ϕ_h=Dense(8 => 5))
+EEquivGraphConv(ϕ_edge=Chain(Dense(10 => 2), Dense(2 => 2)), ϕ_x=Chain(Dense(2 => 2), Dense(2 => 1; bias=false)), ϕ_h=Chain(Dense(5 => 2), Dense(2 => 5)))
 ```
 
 See also [`WithGraph`](@ref) for training layer with static graph and [`EEquivGraphPE`](@ref) for positional encoding.

test/layers/gn.jl

Lines changed: 10 additions & 11 deletions
@@ -21,19 +21,17 @@
     @testset "without aggregation" begin
         function (l::NewGNLayer)(fg::AbstractFeaturedGraph)
             nf = node_feature(fg)
-            ef = edge_feature(fg)
             GraphSignals.check_num_nodes(fg, nf)
-            GraphSignals.check_num_edges(fg, ef)
-            return GeometricFlux.propagate(l, graph(fg), ef, nf, global_feature(fg), nothing, nothing, nothing)
+            return GeometricFlux.propagate(l, graph(fg), nothing, nf, nothing, nothing, nothing, nothing)
         end
 
         fg = FeaturedGraph(adj, nf=nf)
         l = NewGNLayer()
         ef_, nf_, gf_ = l(fg)
 
         @test nf_ == nf
-        @test size(ef_) == (0, 2E)
-        @test size(gf_) == (0,)
+        @test isnothing(ef_)
+        @test isnothing(gf_)
     end
 
     @testset "with neighbor aggregation" begin
@@ -42,16 +40,16 @@
             ef = edge_feature(fg)
             GraphSignals.check_num_nodes(fg, nf)
             GraphSignals.check_num_edges(fg, ef)
-            return GeometricFlux.propagate(l, graph(fg), ef, nf, global_feature(fg), +, nothing, nothing)
+            return GeometricFlux.propagate(l, graph(fg), ef, nf, nothing, +, nothing, nothing)
         end
 
         fg = FeaturedGraph(adj, nf=nf, ef=ef, gf=zeros(0))
         l = NewGNLayer()
         ef_, nf_, gf_ = l(fg)
 
         @test size(nf_) == (in_channel, V)
-        @test size(ef_) == (0, 2E)
-        @test size(gf_) == (0,)
+        @test size(ef_) == (in_channel, 2E)
+        @test isnothing(gf_)
     end
 
     GeometricFlux.update_edge(l::NewGNLayer, e, vi, vj, u) = similar(e, out_channel, size(e)[2:end]...)
@@ -61,7 +59,7 @@
             ef = edge_feature(fg)
             GraphSignals.check_num_nodes(fg, nf)
             GraphSignals.check_num_edges(fg, ef)
-            return GeometricFlux.propagate(l, graph(fg), ef, nf, global_feature(fg), +, nothing, nothing)
+            return GeometricFlux.propagate(l, graph(fg), ef, nf, nothing, +, nothing, nothing)
         end
 
         fg = FeaturedGraph(adj, nf=nf, ef=ef, gf=zeros(0))
@@ -70,17 +68,18 @@
 
         @test size(nf_) == (in_channel, V)
         @test size(ef_) == (out_channel, 2E)
-        @test size(gf_) == (0,)
+        @test isnothing(gf_)
     end
 
     GeometricFlux.update_vertex(l::NewGNLayer, ē, vi, u) = similar(vi, out_channel, size(vi)[2:end]...)
     @testset "update edge/vertex with all aggregation" begin
        function (l::NewGNLayer)(fg::AbstractFeaturedGraph)
             nf = node_feature(fg)
             ef = edge_feature(fg)
+            gf = global_feature(fg)
             GraphSignals.check_num_nodes(fg, nf)
             GraphSignals.check_num_edges(fg, ef)
-            return GeometricFlux.propagate(l, graph(fg), ef, nf, global_feature(fg), +, +, +)
+            return GeometricFlux.propagate(l, graph(fg), ef, nf, gf, +, +, +)
         end
 
         fg = FeaturedGraph(adj, nf=nf, ef=ef, gf=gf)
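The `propagate` calls in these tests follow one positional convention: layer, graph, then edge/node/global features, then three aggregation operators, with `nothing` skipping the corresponding feature or update. Below is a condensed sketch of a custom GN layer in that style; the `GraphNet` supertype and the stub layer are assumptions mirroring the test code above:

```julia
using GeometricFlux, GraphSignals

struct MyGNLayer <: GeometricFlux.GraphNet end  # assumed GN abstraction supertype

function (l::MyGNLayer)(fg::AbstractFeaturedGraph)
    nf = node_feature(fg)
    ef = edge_feature(fg)
    gf = global_feature(fg)
    GraphSignals.check_num_nodes(fg, nf)
    GraphSignals.check_num_edges(fg, ef)
    # edge, node, and global features, then the three aggregators;
    # passing `nothing` in any slot leaves that part untouched.
    return GeometricFlux.propagate(l, graph(fg), ef, nf, gf, +, +, +)
end
```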

test/layers/msgpass.jl

Lines changed: 3 additions & 3 deletions
@@ -37,7 +37,7 @@
         @test GraphSignals.adjacency_matrix(fg_) == adj
         @test size(node_feature(fg_)) == (in_channel, num_V)
         @test size(edge_feature(fg_)) == (0, num_E)
-        @test size(global_feature(fg_)) == (0,)
+        @test !has_global_feature(fg_)
     end
 
     GeometricFlux.message(l::NewLayer, x_i, x_j::AbstractMatrix, e_ij) = l.weight * x_j
@@ -47,7 +47,7 @@
         @test GraphSignals.adjacency_matrix(fg_) == adj
         @test size(node_feature(fg_)) == (out_channel, num_V)
         @test size(edge_feature(fg_)) == (0, num_E)
-        @test size(global_feature(fg_)) == (0,)
+        @test !has_global_feature(fg_)
     end
 
     GeometricFlux.update(l::NewLayer, m::AbstractMatrix, x) = l.weight * x + m
@@ -57,6 +57,6 @@
         @test GraphSignals.adjacency_matrix(fg_) == adj
         @test size(node_feature(fg_)) == (out_channel, num_V)
         @test size(edge_feature(fg_)) == (0, num_E)
-        @test size(global_feature(fg_)) == (0,)
+        @test !has_global_feature(fg_)
     end
 end
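`has_global_feature` is the GraphSignals accessor these assertions now use in place of probing for a zero-size array. For context, here is a minimal self-contained layer carrying the same `message`/`update` overloads the test file defines; the struct and weight shape are illustrative assumptions:

```julia
using GeometricFlux, GraphSignals

struct NewLayer{T<:AbstractMatrix} <: GeometricFlux.MessagePassing
    weight::T
end

# message: what neighbor j sends to node i along edge (i, j)
GeometricFlux.message(l::NewLayer, x_i, x_j::AbstractMatrix, e_ij) = l.weight * x_j

# update: combine the aggregated message m with the node's own features x
GeometricFlux.update(l::NewLayer, m::AbstractMatrix, x) = l.weight * x + m
```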

0 commit comments
