# Batch Learning

## Mini-batch Learning for [`FeaturedGraph`](@ref)

```@raw html
<figure>
  <img src="../../assets/FeaturedGraph-support-DataLoader.svg" width="100%" alt="FeaturedGraph-support-DataLoader.svg" /><br>
  <figcaption><em>FeaturedGraph supports DataLoader.</em></figcaption>
</figure>
```

Batch learning for [`FeaturedGraph`](@ref) can be prepared as follows:

```julia
using GeometricFlux, Flux  # `DataLoader` is re-exported by Flux

train_data = (FeaturedGraph(g, nf=train_X), train_y)
train_batch = DataLoader(train_data, batchsize=batch_size, shuffle=true)
```

[`FeaturedGraph`](@ref) now supports `DataLoader`, so a mini-batch size can be specified for it directly.
Each mini-batch is then passed to a GNN model and trained or inferred as a single [`FeaturedGraph`](@ref).
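
A minimal sketch of consuming these mini-batches in a training loop is shown below; `model`, `loss`, and `opt` are assumed to be defined elsewhere and are not part of the original example:

```julia
using Flux

ps = Flux.params(model)
for (fg, y) in train_batch
    # Each iteration yields one mini-batch packed as a single FeaturedGraph.
    gs = gradient(() -> loss(model(fg), y), ps)
    Flux.Optimise.update!(opt, ps, gs)
end
```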

## Mini-batch Learning for Arrays

```@raw html
<figure>
  <img src="../../assets/cuda-minibatch.svg" width="100%" alt="cuda-minibatch.svg" /><br>
  <figcaption><em>Mini-batch learning on CUDA.</em></figcaption>
</figure>
```

Mini-batch learning for arrays can be prepared as follows:

```julia
using Flux  # for DataLoader

train_loader = DataLoader((train_X, train_y), batchsize=batch_size, shuffle=true)
```

Arrays can be fed to a GNN model directly. In this example, the mini-batch dimension is the last dimension of the `train_X` array. `DataLoader` splits `train_X` into mini-batches and feeds one mini-batch to the GNN model at a time. This strategy takes full advantage of GPU training, since each mini-batch is processed as a single batched operation.
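
As a hedged illustration of this strategy, the sketch below moves each mini-batch to the GPU and takes one optimization step per batch; `model`, `loss`, and `opt` are assumptions (e.g. a GNN fixed to a static graph), not part of the original example:

```julia
using Flux, CUDA

gpu_model = model |> gpu
ps = Flux.params(gpu_model)
for (x, y) in train_loader
    x, y = gpu(x), gpu(y)                           # move one mini-batch to GPU
    gs = gradient(() -> loss(gpu_model(x), y), ps)
    Flux.Optimise.update!(opt, ps, gs)              # one update per mini-batch
end
```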