Skip to content

Commit 59c6731

Browse files
authored
Updated all_gather() to make use of all_gather_object() (#3857)
1 parent 3c47bfd commit 59c6731

File tree

1 file changed

+4
-33
lines changed

1 file changed

+4
-33
lines changed

references/detection/utils.py

Lines changed: 4 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
11
from collections import defaultdict, deque
22
import datetime
3-
import pickle
3+
import errno
4+
import os
45
import time
56

67
import torch
78
import torch.distributed as dist
89

9-
import errno
10-
import os
11-
1210

1311
class SmoothedValue(object):
1412
"""Track a series of values and provide access to smoothed values over a
@@ -83,35 +81,8 @@ def all_gather(data):
8381
world_size = get_world_size()
8482
if world_size == 1:
8583
return [data]
86-
87-
# serialized to a Tensor
88-
buffer = pickle.dumps(data)
89-
storage = torch.ByteStorage.from_buffer(buffer)
90-
tensor = torch.ByteTensor(storage).to("cuda")
91-
92-
# obtain Tensor size of each rank
93-
local_size = torch.tensor([tensor.numel()], device="cuda")
94-
size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
95-
dist.all_gather(size_list, local_size)
96-
size_list = [int(size.item()) for size in size_list]
97-
max_size = max(size_list)
98-
99-
# receiving Tensor from all ranks
100-
# we pad the tensor because torch all_gather does not support
101-
# gathering tensors of different shapes
102-
tensor_list = []
103-
for _ in size_list:
104-
tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
105-
if local_size != max_size:
106-
padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
107-
tensor = torch.cat((tensor, padding), dim=0)
108-
dist.all_gather(tensor_list, tensor)
109-
110-
data_list = []
111-
for size, tensor in zip(size_list, tensor_list):
112-
buffer = tensor.cpu().numpy().tobytes()[:size]
113-
data_list.append(pickle.loads(buffer))
114-
84+
data_list = [None] * world_size
85+
dist.all_gather_object(data_list, data)
11586
return data_list
11687

11788

0 commit comments

Comments (0)