|
| 1 | +import argparse |
1 | 2 | import os |
| 3 | +import tempfile |
2 | 4 | import time |
3 | 5 |
|
4 | 6 | import torch |
|
7 | 9 | import torch.nn.functional as F |
8 | 10 | from ogb.nodeproppred import PygNodePropPredDataset |
9 | 11 | from torch.nn.parallel import DistributedDataParallel |
| 12 | +from torchmetrics import Accuracy |
10 | 13 |
|
| 14 | +import torch_geometric |
11 | 15 | from torch_geometric.loader import NeighborLoader |
12 | | -from torch_geometric.nn import GCNConv |
13 | 16 |
|
14 | 17 |
|
def get_num_workers(world_size: int) -> int:
    """Return the number of data-loader workers to use for each rank.

    Half of the CPUs available to this process are shared evenly between
    the ``world_size`` ranks.  The scheduler affinity mask is used when the
    platform exposes ``os.sched_getaffinity``; otherwise the total CPU count
    reported by ``os.cpu_count()`` is used.

    Args:
        world_size: Number of training processes sharing the CPUs.

    Returns:
        The per-rank worker count, rounded down (``0`` when there are fewer
        than ``2 * world_size`` CPUs).
    """
    num_cpus = None
    if hasattr(os, "sched_getaffinity"):
        try:
            num_cpus = len(os.sched_getaffinity(0))
        except OSError:
            # Best effort only: fall back to the total CPU count below.
            pass
    if num_cpus is None:
        # os.cpu_count() returns None when the count cannot be determined;
        # assume a single CPU instead of failing on ``None // int``.
        num_cpus = os.cpu_count() or 1
    return num_cpus // (2 * world_size)
25 | 28 |
|
26 | 29 |
|
def _evaluate(model, loader, acc, rank, max_steps=None):
    """Accumulate accuracy of ``model`` over ``loader`` and return it.

    The model is switched to evaluation mode and gradients are disabled.
    ``acc`` is updated with the predictions for the leading
    ``batch.num_sampled_nodes[0]`` rows of every batch; the caller is
    responsible for resetting ``acc`` afterwards.

    Args:
        model: The (wrapped) model to evaluate.
        loader: Iterable of mini-batches exposing ``x``, ``edge_index``,
            ``y`` and ``num_sampled_nodes``.
        acc: Metric object providing ``update`` and ``compute``.
        rank: Device index the batches are moved to.
        max_steps: Optional cap on the number of batches to evaluate.

    Returns:
        The value of ``acc.compute()`` after the last evaluated batch.
    """
    model.eval()
    with torch.no_grad():
        for i, batch in enumerate(loader):
            if max_steps is not None and i >= max_steps:
                break
            batch = batch.to(rank)
            # Presumably the number of seed nodes of this mini-batch; only
            # those leading rows are scored.
            num_seeds = batch.num_sampled_nodes[0]
            batch.y = batch.y.to(torch.long)
            out = model(batch.x, batch.edge_index)
            acc.update(out[:num_seeds].softmax(dim=-1), batch.y[:num_seeds])
        return acc.compute()


def run_train(rank, data, world_size, model, epochs, batch_size, fan_out,
              split_idx, num_classes, wall_clock_start, tempdir=None,
              num_layers=3, lr=0.01):
    """Train and evaluate ``model`` as one rank of a single-node DDP job.

    Initializes an NCCL process group on ``localhost:12355``, shards the
    train/valid/test indices across ranks, trains for ``epochs`` epochs with
    neighbor sampling, reports validation accuracy after every epoch and
    test accuracy once at the end.  Progress and timings are printed by
    rank 0 only.

    NOTE(review): the nesting of the per-epoch validation and the final test
    pass was reconstructed from a whitespace-stripped listing; confirm it
    against the intended layout.

    Args:
        rank: Index of this process; also used as its CUDA device index.
        data: Graph object handed to ``NeighborLoader``.
        world_size: Total number of participating processes.
        model: Model to train; moved to ``rank`` and wrapped in
            ``DistributedDataParallel``.
        epochs: Number of training epochs.
        batch_size: Number of input nodes per mini-batch.
        fan_out: Neighbors sampled per layer.
        split_idx: Mapping with ``'train'``, ``'valid'`` and ``'test'``
            index tensors.  Modified in place when ``world_size > 1``.
        num_classes: Number of target classes for the accuracy metric.
        wall_clock_start: ``time.perf_counter()`` reading taken at program
            start, used for the timing reports.
        tempdir: Unused; accepted for call compatibility.
        num_layers: Number of sampling hops (length of ``num_neighbors``).
        lr: Learning rate for the Adam optimizer.
    """
    # init pytorch worker
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group('nccl', rank=rank, world_size=world_size)

    if world_size > 1:
        # Every rank keeps only its own chunk of each split.
        for split in ('train', 'valid', 'test'):
            idx = split_idx[split]
            split_idx[split] = idx.split(idx.size(0) // world_size,
                                         dim=0)[rank].clone()
    model = model.to(rank)
    model = DistributedDataParallel(model, device_ids=[rank])
    optimizer = torch.optim.Adam(model.parameters(), lr=lr,
                                 weight_decay=0.0005)

    kwargs = dict(
        num_neighbors=[fan_out] * num_layers,
        batch_size=batch_size,
    )
    num_work = get_num_workers(world_size)
    train_loader = NeighborLoader(data, input_nodes=split_idx['train'],
                                  num_workers=num_work, shuffle=True,
                                  drop_last=True, **kwargs)
    val_loader = NeighborLoader(data, input_nodes=split_idx['valid'],
                                num_workers=num_work, **kwargs)
    test_loader = NeighborLoader(data, input_nodes=split_idx['test'],
                                 num_workers=num_work, **kwargs)

    eval_steps = 1000
    warmup_steps = 20
    acc = Accuracy(task="multiclass", num_classes=num_classes).to(rank)
    dist.barrier()
    torch.cuda.synchronize()
    if rank == 0:
        prep_time = round(time.perf_counter() - wall_clock_start, 2)
        print("Total time before training begins (prep_time) =", prep_time,
              "seconds")
        print("Beginning training...")
    for epoch in range(epochs):
        # Evaluation below switches to eval mode, so re-enable training mode
        # at the start of every epoch.
        model.train()
        # Reset per epoch so a short epoch can never reuse a stale (or
        # unbound) timestamp.
        start = None
        num_batches = 0
        for i, batch in enumerate(train_loader):
            if i == warmup_steps:
                # Start timing only after the warm-up iterations.
                torch.cuda.synchronize()
                start = time.time()
            batch = batch.to(rank)
            # Presumably the number of seed nodes; the loss is restricted
            # to those leading rows of the output.
            num_seeds = batch.num_sampled_nodes[0]
            batch.y = batch.y.to(torch.long)
            optimizer.zero_grad()
            out = model(batch.x, batch.edge_index)
            loss = F.cross_entropy(out[:num_seeds], batch.y[:num_seeds])
            loss.backward()
            optimizer.step()
            if rank == 0 and i % 10 == 0:
                print("Epoch: " + str(epoch) + ", Iteration: " + str(i) +
                      ", Loss: " + str(loss))
            num_batches = i + 1
        dist.barrier()
        torch.cuda.synchronize()
        # The average is only defined when the epoch ran past the warm-up.
        if rank == 0 and start is not None:
            elapsed = time.time() - start
            print("Average Training Iteration Time:",
                  elapsed / (num_batches - warmup_steps), "s/iter")

        acc_sum = _evaluate(model, val_loader, acc, rank,
                            max_steps=eval_steps)
        if rank == 0:
            print(f"Validation Accuracy: {acc_sum * 100.0:.4f}%", )
        dist.barrier()
        acc.reset()

    acc_sum = _evaluate(model, test_loader, acc, rank)
    if rank == 0:
        print(f"Test Accuracy: {acc_sum * 100.0:.4f}%", )
    dist.barrier()
    acc.reset()
    if rank == 0:
        total_time = round(time.perf_counter() - wall_clock_start, 2)
        print("Total Program Runtime (total_time) =", total_time, "seconds")
        print("total_time - prep_time =", total_time - prep_time, "seconds")

    # Release the process group created at the top of this function.
    dist.destroy_process_group()
128 | 132 |
|
129 | 133 |
|
if __name__ == '__main__':
    # Command-line entry point: parse options, load ogbn-papers100M, build
    # a GCN or GAT model and launch one training process per GPU.
    parser = argparse.ArgumentParser()
    parser.add_argument('--hidden_channels', type=int, default=256)
    parser.add_argument('--num_layers', type=int, default=2)
    # NOTE(review): --lr is parsed but not passed to run_train below.
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--epochs', type=int, default=20)
    parser.add_argument('--batch_size', type=int, default=1024)
    parser.add_argument('--fan_out', type=int, default=30)
    parser.add_argument(
        "--use_gat_conv",
        action='store_true',
        help="Whether or not to use GATConv. (Defaults to using GCNConv)",
    )
    parser.add_argument(
        "--n_gat_conv_heads",
        type=int,
        default=4,
        help="If using GATConv, number of attention heads to use",
    )
    parser.add_argument(
        "--n_devices", type=int, default=-1,
        help="1-8 to use that many GPUs. Defaults to all available GPUs")

    args = parser.parse_args()
    wall_clock_start = time.perf_counter()

    # Resolve and validate the GPU count before the expensive dataset load,
    # so a bad --n_devices fails immediately with a clear message instead of
    # deep inside a worker process.
    num_gpus = torch.cuda.device_count()
    world_size = num_gpus if args.n_devices == -1 else args.n_devices
    if world_size < 1 or world_size > num_gpus:
        parser.error(f"cannot run on {world_size} GPU(s): "
                     f"{num_gpus} CUDA device(s) available")

    dataset = PygNodePropPredDataset(name='ogbn-papers100M',
                                     root='/datasets/ogb_datasets')
    split_idx = dataset.get_idx_split()
    data = dataset[0]
    # Flatten the labels to one dimension.
    data.y = data.y.reshape(-1)
    if args.use_gat_conv:
        model = torch_geometric.nn.models.GAT(dataset.num_features,
                                              args.hidden_channels,
                                              args.num_layers,
                                              dataset.num_classes,
                                              heads=args.n_gat_conv_heads)
    else:
        model = torch_geometric.nn.models.GCN(
            dataset.num_features,
            args.hidden_channels,
            args.num_layers,
            dataset.num_classes,
        )

    print("Data =", data)
    print('Let\'s use', world_size, 'GPUs!')
    with tempfile.TemporaryDirectory() as tempdir:
        if world_size > 1:
            # mp.spawn supplies the rank as the first positional argument.
            mp.spawn(
                run_train,
                args=(data, world_size, model, args.epochs, args.batch_size,
                      args.fan_out, split_idx, dataset.num_classes,
                      wall_clock_start, tempdir, args.num_layers),
                nprocs=world_size, join=True)
        else:
            # Single GPU: run in-process as rank 0.
            run_train(0, data, world_size, model, args.epochs, args.batch_size,
                      args.fan_out, split_idx, dataset.num_classes,
                      wall_clock_start, tempdir, args.num_layers)
0 commit comments