Commit cc4f0d2

Merge pull request #324 from FluxML/doc
Update doc for FeaturedGraph
2 parents 985ad94 + 528adfe commit cc4f0d2

File tree

5 files changed: +43 −10 lines

docs/src/assets/FeaturedGraph-support-DataLoader.svg

Lines changed: 4 additions & 0 deletions

docs/src/assets/cuda-minibatch.svg

Lines changed: 4 additions & 0 deletions

docs/src/assets/graphparallel.svg

Lines changed: 4 additions & 0 deletions

docs/src/basics/batch.md

Lines changed: 24 additions & 10 deletions
@@ -1,23 +1,37 @@
 # Batch Learning
 
-## Batch Learning for Variable Graph Strategy
+## Mini-batch Learning for [`FeaturedGraph`](@ref)
 
-Batch learning for variable graph strategy can be prepared as follows:
+```@raw html
+<figure>
+<img src="../../assets/FeaturedGraph-support-DataLoader.svg" width="100%" alt="FeaturedGraph-support-DataLoader.svg" /><br>
+<figcaption><em>FeaturedGraph supports DataLoader.</em></figcaption>
+</figure>
+```
+
+Batch learning for [`FeaturedGraph`](@ref) can be prepared as follows:
 
 ```julia
-train_data = [(FeaturedGraph(g, nf=train_X), train_y) for _ in 1:N]
-train_batch = Flux.batch(train_data)
+train_data = (FeaturedGraph(g, nf=train_X), train_y)
+train_batch = DataLoader(train_data, batchsize=batch_size, shuffle=true)
 ```
 
-It batches up [`FeaturedGraph`](@ref) objects into specified mini-batch. A batch is passed to a GNN model and trained/inferred one by one. It is hard for [`FeaturedGraph`](@ref) objects to train or infer in real batch for GPU.
+[`FeaturedGraph`](@ref) now supports `DataLoader`, and one can specify a mini-batch size for it.
+A mini-batch is passed to a GNN model and trained or inferred as one [`FeaturedGraph`](@ref).
 
-## Batch Learning for Static Graph Strategy
+## Mini-batch Learning for Arrays
+
+```@raw html
+<figure>
+<img src="../../assets/cuda-minibatch.svg" width="100%" alt="cuda-minibatch.svg" /><br>
+<figcaption><em>Mini-batch learning on CUDA.</em></figcaption>
+</figure>
+```
 
-A efficient batch learning should use static graph strategy. Batch learning for static graph strategy can be prepared as follows:
+Mini-batch learning for arrays can be prepared as follows:
 
 ```julia
-train_data = (repeat(train_X, outer=(1,1,N)), repeat(train_y, outer=(1,1,N)))
-train_loader = DataLoader(train_data, batchsize=batch_size, shuffle=true)
+train_loader = DataLoader((train_X, train_y), batchsize=batch_size, shuffle=true)
 ```
 
-An efficient batch learning should feed array to a GNN model. In the example, the mini-batch dimension is the third dimension for `train_X` array. The `train_X` array is split by `DataLoader` into mini-batches and feed a mini-batch to GNN model at a time. This strategy leverages the advantage of GPU training by accelerating training GNN model in a real batch learning.
+An array can be fed to a GNN model. In the example, the mini-batch dimension is the last dimension of the `train_X` array. `DataLoader` splits the `train_X` array into mini-batches and feeds one mini-batch to the GNN model at a time. This strategy leverages the advantage of GPU training by running each mini-batch through the model as a real batch.
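As a hedged illustration of the array-based workflow above (not part of this commit), the following minimal sketch shows how `DataLoader` batches along the last dimension. The model, data shapes, and loss are invented dummies for demonstration; it assumes Flux's `DataLoader` (in `Flux.Data` at the time of this commit).

```julia
using Flux
using Flux.Data: DataLoader

# Dummy data (assumed shapes): features × samples, targets × samples.
train_X = rand(Float32, 10, 100)
train_y = rand(Float32, 2, 100)
batch_size = 20

# DataLoader splits along the last dimension into mini-batches.
train_loader = DataLoader((train_X, train_y), batchsize=batch_size, shuffle=true)

# Hypothetical stand-in model; a GNN layer would go here in GeometricFlux.
model = Dense(10, 2)
opt = Descent(0.01)
ps = Flux.params(model)

for (x, y) in train_loader
    # Each x is 10 × batch_size; the whole mini-batch runs through the model at once.
    gs = gradient(() -> Flux.Losses.mse(model(x), y), ps)
    Flux.Optimise.update!(opt, ps, gs)
end
```

With 100 samples and `batch_size = 20`, the loader yields five mini-batches per epoch, each processed in a single forward/backward pass on CPU or GPU.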

docs/src/cooperate.md

Lines changed: 7 additions & 0 deletions
@@ -26,6 +26,13 @@ model = Chain(
 
 ## Branching Different Features Through Different Layers
 
+```@raw html
+<figure>
+<img src="../assets/graphparallel.svg" width="70%" alt="graphparallel.svg" /><br>
+<figcaption><em>GraphParallel wraps regular Flux layers over different kinds of features for integration into GNN layers.</em></figcaption>
+</figure>
+```
+
 A [`GraphParallel`](@ref) construct is designed for passing each feature through different layers from a [`FeaturedGraph`](@ref). An example is given as follows:
 
 ```julia

0 commit comments
