|
1 | 1 | from collections import defaultdict, deque
|
2 | 2 | import datetime
|
3 |
| -import pickle |
| 3 | +import errno |
| 4 | +import os |
4 | 5 | import time
|
5 | 6 |
|
6 | 7 | import torch
|
7 | 8 | import torch.distributed as dist
|
8 | 9 |
|
9 |
| -import errno |
10 |
| -import os |
11 |
| - |
12 | 10 |
|
13 | 11 | class SmoothedValue(object):
|
14 | 12 | """Track a series of values and provide access to smoothed values over a
|
@@ -83,35 +81,8 @@ def all_gather(data):
|
83 | 81 | world_size = get_world_size()
|
84 | 82 | if world_size == 1:
|
85 | 83 | return [data]
|
86 |
| - |
87 |
| - # serialized to a Tensor |
88 |
| - buffer = pickle.dumps(data) |
89 |
| - storage = torch.ByteStorage.from_buffer(buffer) |
90 |
| - tensor = torch.ByteTensor(storage).to("cuda") |
91 |
| - |
92 |
| - # obtain Tensor size of each rank |
93 |
| - local_size = torch.tensor([tensor.numel()], device="cuda") |
94 |
| - size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] |
95 |
| - dist.all_gather(size_list, local_size) |
96 |
| - size_list = [int(size.item()) for size in size_list] |
97 |
| - max_size = max(size_list) |
98 |
| - |
99 |
| - # receiving Tensor from all ranks |
100 |
| - # we pad the tensor because torch all_gather does not support |
101 |
| - # gathering tensors of different shapes |
102 |
| - tensor_list = [] |
103 |
| - for _ in size_list: |
104 |
| - tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) |
105 |
| - if local_size != max_size: |
106 |
| - padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda") |
107 |
| - tensor = torch.cat((tensor, padding), dim=0) |
108 |
| - dist.all_gather(tensor_list, tensor) |
109 |
| - |
110 |
| - data_list = [] |
111 |
| - for size, tensor in zip(size_list, tensor_list): |
112 |
| - buffer = tensor.cpu().numpy().tobytes()[:size] |
113 |
| - data_list.append(pickle.loads(buffer)) |
114 |
| - |
| 84 | + data_list = [None] * world_size |
| 85 | + dist.all_gather_object(data_list, data) |
115 | 86 | return data_list
|
116 | 87 |
|
117 | 88 |
|
|
0 commit comments